diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 9151f02e..8b1560bb 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -966,7 +966,9 @@ def create_app(args):
             f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
         )
     else:
-        logger.info("Embedding max_token_size: not set (90% token warning disabled)")
+        logger.info(
+            "Embedding max_token_size: None (embedding token limit is disabled)"
+        )
 
     # Configure rerank function based on args.rerank_binding parameter
     rerank_model_func = None
diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py
index 5e438ceb..0692ce6a 100644
--- a/lightrag/llm/gemini.py
+++ b/lightrag/llm/gemini.py
@@ -476,6 +476,7 @@ async def gemini_embed(
     base_url: str | None = None,
     api_key: str | None = None,
     embedding_dim: int | None = None,
+    max_token_size: int | None = None,
     task_type: str = "RETRIEVAL_DOCUMENT",
     timeout: int | None = None,
     token_tracker: Any | None = None,
@@ -497,6 +498,11 @@
             The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
             or the EMBEDDING_DIM environment variable. Supported range: 128-3072.
             Recommended values: 768, 1536, 3072.
+        max_token_size: Maximum tokens per text. This parameter is automatically
+            injected by the EmbeddingFunc wrapper when the underlying function
+            signature supports it (via inspect.signature check). Gemini API will
+            automatically truncate texts exceeding this limit (autoTruncate=True
+            by default), so no client-side truncation is needed.
         task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
             Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
             RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
@@ -516,7 +522,11 @@
         - For dimension 3072: Embeddings are already normalized by the API
         - For dimensions < 3072: Embeddings are L2-normalized after retrieval
         - Normalization ensures accurate semantic similarity via cosine distance
+        - Gemini API automatically truncates texts exceeding max_token_size (autoTruncate=True)
     """
+    # Note: max_token_size is received but not used for client-side truncation.
+    # Gemini API handles truncation automatically with autoTruncate=True (default).
+    _ = max_token_size  # Acknowledge parameter to avoid unused variable warning
 
     loop = asyncio.get_running_loop()
     key = _ensure_api_key(api_key)
diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py
index 62269296..ab10a42b 100644
--- a/lightrag/llm/ollama.py
+++ b/lightrag/llm/ollama.py
@@ -176,8 +176,33 @@ async def ollama_model_complete(
     embedding_dim=1024, max_token_size=8192, model_name="bge-m3:latest"
 )
 async def ollama_embed(
-    texts: list[str], embed_model: str = "bge-m3:latest", **kwargs
+    texts: list[str],
+    embed_model: str = "bge-m3:latest",
+    max_token_size: int | None = None,
+    **kwargs,
 ) -> np.ndarray:
+    """Generate embeddings using Ollama's API.
+
+    Args:
+        texts: List of texts to embed.
+        embed_model: The Ollama embedding model to use. Default is "bge-m3:latest".
+        max_token_size: Maximum tokens per text. This parameter is automatically
+            injected by the EmbeddingFunc wrapper when the underlying function
+            signature supports it (via inspect.signature check). Ollama will
+            automatically truncate texts exceeding the model's context length
+            (num_ctx), so no client-side truncation is needed.
+        **kwargs: Additional arguments passed to the Ollama client.
+
+    Returns:
+        A numpy array of embeddings, one per input text.
+
+    Note:
+        - Ollama API automatically truncates texts exceeding the model's context length
+        - The max_token_size parameter is received but not used for client-side truncation
+    """
+    # Note: max_token_size is received but not used for client-side truncation.
+    # Ollama API handles truncation automatically based on the model's num_ctx setting.
+    _ = max_token_size  # Acknowledge parameter to avoid unused variable warning
     api_key = kwargs.pop("api_key", None)
     if not api_key:
         api_key = os.getenv("OLLAMA_API_KEY")
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index b49cac71..6b0f9d48 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -5,6 +5,7 @@ import logging
 from collections.abc import AsyncIterator
 
 import pipmaster as pm
+import tiktoken
 
 # install specific modules
 if not pm.is_installed("openai"):
@@ -74,6 +75,30 @@ class InvalidResponseError(Exception):
     pass
 
 
+# Module-level cache for tiktoken encodings
+_TIKTOKEN_ENCODING_CACHE: dict[str, Any] = {}
+
+
+def _get_tiktoken_encoding_for_model(model: str) -> Any:
+    """Get tiktoken encoding for the specified model with caching.
+
+    Args:
+        model: The model name to get encoding for.
+
+    Returns:
+        The tiktoken encoding for the model.
+    """
+    if model not in _TIKTOKEN_ENCODING_CACHE:
+        try:
+            _TIKTOKEN_ENCODING_CACHE[model] = tiktoken.encoding_for_model(model)
+        except KeyError:
+            logger.debug(
+                f"Encoding for model '{model}' not found, falling back to cl100k_base"
+            )
+            _TIKTOKEN_ENCODING_CACHE[model] = tiktoken.get_encoding("cl100k_base")
+    return _TIKTOKEN_ENCODING_CACHE[model]
+
+
 def create_openai_async_client(
     api_key: str | None = None,
     base_url: str | None = None,
@@ -695,15 +720,17 @@ async def openai_embed(
     base_url: str | None = None,
     api_key: str | None = None,
     embedding_dim: int | None = None,
+    max_token_size: int | None = None,
     client_configs: dict[str, Any] | None = None,
     token_tracker: Any | None = None,
     use_azure: bool = False,
     azure_deployment: str | None = None,
     api_version: str | None = None,
 ) -> np.ndarray:
-    """Generate embeddings for a list of texts using OpenAI's API.
+    """Generate embeddings for a list of texts using OpenAI's API with automatic text truncation.
 
-    This function supports both standard OpenAI and Azure OpenAI services.
+    This function supports both standard OpenAI and Azure OpenAI services. It automatically
+    truncates texts that exceed the model's token limit to prevent API errors.
 
     Args:
         texts: List of texts to embed.
@@ -719,6 +746,10 @@
             The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
             Manually passing a different value will trigger a warning and be ignored.
             When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
+        max_token_size: Maximum tokens per text. Texts exceeding this limit will be truncated.
+            **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper
+            when the underlying function signature supports it (via inspect.signature check).
+            The value is controlled by the @wrap_embedding_func_with_attrs decorator.
         client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI
             client. These will override any default configurations but will be overridden by
             explicit parameters (api_key, base_url). Supports proxy configuration,
@@ -740,6 +771,35 @@
         RateLimitError: If the OpenAI API rate limit is exceeded.
         APITimeoutError: If the OpenAI API request times out.
""" + # Apply text truncation if max_token_size is provided + if max_token_size is not None and max_token_size > 0: + encoding = _get_tiktoken_encoding_for_model(model) + truncated_texts = [] + truncation_count = 0 + + for text in texts: + if not text: + truncated_texts.append(text) + continue + + tokens = encoding.encode(text) + if len(tokens) > max_token_size: + truncated_tokens = tokens[:max_token_size] + truncated_texts.append(encoding.decode(truncated_tokens)) + truncation_count += 1 + logger.debug( + f"Text truncated from {len(tokens)} to {max_token_size} tokens" + ) + else: + truncated_texts.append(text) + + if truncation_count > 0: + logger.info( + f"Truncated {truncation_count}/{len(texts)} texts to fit token limit ({max_token_size})" + ) + + texts = truncated_texts + # Create the OpenAI client (supports both OpenAI and Azure) openai_async_client = create_openai_async_client( api_key=api_key, diff --git a/lightrag/operate.py b/lightrag/operate.py index faab7f26..207f0614 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -365,12 +365,12 @@ async def _summarize_descriptions( if embedding_token_limit is not None and summary: tokenizer = global_config["tokenizer"] summary_token_count = len(tokenizer.encode(summary)) - threshold = int(embedding_token_limit * 0.9) + threshold = int(embedding_token_limit) if summary_token_count > threshold: logger.warning( - f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit " - f"({embedding_token_limit}) for {description_type}: {description_name}" + f"Summary tokens({summary_token_count}) exceeds embedding_token_limit({embedding_token_limit}) " + f" for {description_type}: {description_name}" ) return summary diff --git a/lightrag/utils.py b/lightrag/utils.py index cd63f6e4..cd3f26d1 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -6,6 +6,7 @@ import sys import asyncio import html import csv +import inspect import json import logging import logging.handlers @@ -492,6 +493,12 @@ class EmbeddingFunc: # Inject embedding_dim from decorator kwargs["embedding_dim"] = self.embedding_dim + # Check if underlying function supports max_token_size and inject if not provided + if self.max_token_size is not None and "max_token_size" not in kwargs: + sig = inspect.signature(self.func) + if "max_token_size" in sig.parameters: + kwargs["max_token_size"] = self.max_token_size + # Call the actual embedding function result = await self.func(*args, **kwargs)