Improve EmbeddingFunc unwrapping safety and docs

- Limit nesting depth in EmbeddingFunc
- Log warning when auto-unwrapping
- Fix typo in EmbeddingFunc docstring
- Document .func usage in examples
This commit is contained in:
yangdx
2025-12-22 14:49:59 +08:00
parent 705e8c6c8e
commit 8aeb234aaa
4 changed files with 31 additions and 5 deletions

View File

@@ -93,11 +93,15 @@ async def initialize_rag():
"options": {"num_ctx": 8192},
"timeout": int(os.getenv("TIMEOUT", "300")),
},
# Note: ollama_embed is decorated with @wrap_embedding_func_with_attrs,
# which wraps it in an EmbeddingFunc. Using .func accesses the original
# unwrapped function to avoid double wrapping when we create our own
# EmbeddingFunc with custom configuration (embedding_dim, max_token_size).
embedding_func=EmbeddingFunc(
embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")),
func=partial(
ollama_embed.func, # Use .func to access the unwrapped function
ollama_embed.func, # Access the unwrapped function to avoid double EmbeddingFunc wrapping
embed_model=os.getenv("EMBEDDING_MODEL", "bge-m3:latest"),
host=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434"),
),

View File

@@ -109,11 +109,15 @@ async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
# Note: ollama_embed is decorated with @wrap_embedding_func_with_attrs,
# which wraps it in an EmbeddingFunc. Using .func accesses the original
# unwrapped function to avoid double wrapping when we create our own
# EmbeddingFunc with custom configuration (embedding_dim, max_token_size).
embedding_func=EmbeddingFunc(
embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")),
func=partial(
ollama_embed.func, # Use .func to access the unwrapped function
ollama_embed.func, # Access the unwrapped function to avoid double EmbeddingFunc wrapping
embed_model=os.getenv("EMBEDDING_MODEL", "bge-m3:latest"),
host=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434"),
),

View File

@@ -26,6 +26,10 @@ EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
async def embedding_func(texts: list[str]) -> np.ndarray:
# Note: openai_embed is decorated with @wrap_embedding_func_with_attrs,
# which wraps it in an EmbeddingFunc. Using .func accesses the original
# unwrapped function to avoid double wrapping when we create our own
# EmbeddingFunc with custom configuration in create_embedding_function_instance().
return await openai_embed.func(
texts,
model=EMBEDDING_MODEL,

View File

@@ -434,8 +434,7 @@ class EmbeddingFunc:
func: The actual embedding function to wrap
max_token_size: Enable embedding token limit checking for description summarization (set embedding_token_limit in LightRAG)
send_dimensions: Whether to inject the embedding_dim argument into the underlying function
model_name: Model name for implementating workspace data isolation in vector DB
)
model_name: Model name for implementing workspace data isolation in vector DB
"""
embedding_dim: int
@@ -443,7 +442,7 @@ class EmbeddingFunc:
max_token_size: int | None = None
send_dimensions: bool = False
model_name: str | None = (
None # Model name for implementating workspace data isolation in vector DB
None # Model name for implementing workspace data isolation in vector DB
)
def __post_init__(self):
@@ -455,10 +454,25 @@ class EmbeddingFunc:
that only the outermost wrapper's configuration is applied.
"""
# Check if func is already an EmbeddingFunc instance and unwrap it
max_unwrap_depth = 3 # Safety limit to prevent infinite loops
unwrap_count = 0
while isinstance(self.func, EmbeddingFunc):
unwrap_count += 1
if unwrap_count > max_unwrap_depth:
raise ValueError(
f"EmbeddingFunc unwrap depth exceeded {max_unwrap_depth}. "
"Possible circular reference detected."
)
# Unwrap to get the original function
self.func = self.func.func
if unwrap_count > 0:
logger.warning(
f"Detected nested EmbeddingFunc wrapping (depth: {unwrap_count}), "
"auto-unwrapped to prevent configuration conflicts. "
"Consider using .func to access the unwrapped function directly."
)
async def __call__(self, *args, **kwargs) -> np.ndarray:
# Only inject embedding_dim when send_dimensions is True
if self.send_dimensions: