Improve EmbeddingFunc unwrapping safety and docs
- Limit nesting depth in EmbeddingFunc
- Log warning when auto-unwrapping
- Fix typo in EmbeddingFunc docstring
- Document .func usage in examples
This commit is contained in:
@@ -93,11 +93,15 @@ async def initialize_rag():
|
||||
"options": {"num_ctx": 8192},
|
||||
"timeout": int(os.getenv("TIMEOUT", "300")),
|
||||
},
|
||||
# Note: ollama_embed is decorated with @wrap_embedding_func_with_attrs,
|
||||
# which wraps it in an EmbeddingFunc. Using .func accesses the original
|
||||
# unwrapped function to avoid double wrapping when we create our own
|
||||
# EmbeddingFunc with custom configuration (embedding_dim, max_token_size).
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
|
||||
max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")),
|
||||
func=partial(
|
||||
ollama_embed.func, # Use .func to access the unwrapped function
|
||||
ollama_embed.func, # Access the unwrapped function to avoid double EmbeddingFunc wrapping
|
||||
embed_model=os.getenv("EMBEDDING_MODEL", "bge-m3:latest"),
|
||||
host=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434"),
|
||||
),
|
||||
|
||||
@@ -109,11 +109,15 @@ async def initialize_rag():
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=llm_model_func,
|
||||
# Note: ollama_embed is decorated with @wrap_embedding_func_with_attrs,
|
||||
# which wraps it in an EmbeddingFunc. Using .func accesses the original
|
||||
# unwrapped function to avoid double wrapping when we create our own
|
||||
# EmbeddingFunc with custom configuration (embedding_dim, max_token_size).
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
|
||||
max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")),
|
||||
func=partial(
|
||||
ollama_embed.func, # Use .func to access the unwrapped function
|
||||
ollama_embed.func, # Access the unwrapped function to avoid double EmbeddingFunc wrapping
|
||||
embed_model=os.getenv("EMBEDDING_MODEL", "bge-m3:latest"),
|
||||
host=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434"),
|
||||
),
|
||||
|
||||
@@ -26,6 +26,10 @@ EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
|
||||
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
# Note: openai_embed is decorated with @wrap_embedding_func_with_attrs,
|
||||
# which wraps it in an EmbeddingFunc. Using .func accesses the original
|
||||
# unwrapped function to avoid double wrapping when we create our own
|
||||
# EmbeddingFunc with custom configuration in create_embedding_function_instance().
|
||||
return await openai_embed.func(
|
||||
texts,
|
||||
model=EMBEDDING_MODEL,
|
||||
|
||||
@@ -434,8 +434,7 @@ class EmbeddingFunc:
|
||||
func: The actual embedding function to wrap
|
||||
max_token_size: Enable embedding token limit checking for description summarization(Set embedding_token_limit in LightRAG)
|
||||
send_dimensions: Whether to inject embedding_dim argument to underlying function
|
||||
model_name: Model name for implementating workspace data isolation in vector DB
|
||||
)
|
||||
model_name: Model name for implementing workspace data isolation in vector DB
|
||||
"""
|
||||
|
||||
embedding_dim: int
|
||||
@@ -443,7 +442,7 @@ class EmbeddingFunc:
|
||||
max_token_size: int | None = None
|
||||
send_dimensions: bool = False
|
||||
model_name: str | None = (
|
||||
None # Model name for implementating workspace data isolation in vector DB
|
||||
None # Model name for implementing workspace data isolation in vector DB
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -455,10 +454,25 @@ class EmbeddingFunc:
|
||||
that only the outermost wrapper's configuration is applied.
|
||||
"""
|
||||
# Check if func is already an EmbeddingFunc instance and unwrap it
|
||||
max_unwrap_depth = 3 # Safety limit to prevent infinite loops
|
||||
unwrap_count = 0
|
||||
while isinstance(self.func, EmbeddingFunc):
|
||||
unwrap_count += 1
|
||||
if unwrap_count > max_unwrap_depth:
|
||||
raise ValueError(
|
||||
f"EmbeddingFunc unwrap depth exceeded {max_unwrap_depth}. "
|
||||
"Possible circular reference detected."
|
||||
)
|
||||
# Unwrap to get the original function
|
||||
self.func = self.func.func
|
||||
|
||||
if unwrap_count > 0:
|
||||
logger.warning(
|
||||
f"Detected nested EmbeddingFunc wrapping (depth: {unwrap_count}), "
|
||||
"auto-unwrapped to prevent configuration conflicts. "
|
||||
"Consider using .func to access the unwrapped function directly."
|
||||
)
|
||||
|
||||
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
||||
# Only inject embedding_dim when send_dimensions is True
|
||||
if self.send_dimensions:
|
||||
|
||||
Reference in New Issue
Block a user