From 2678005448ed1663d662c22c5c2481e1b797174b Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 22 Dec 2025 19:33:43 +0800 Subject: [PATCH 1/5] feat: inject max_token_size and add client-side truncation for OpenAI - Auto-inject max_token_size in wrapper - Implement OpenAI client-side truncation - Update Gemini/Ollama embed signatures - Relax summary token warning threshold - Update server startup logging --- lightrag/api/lightrag_server.py | 4 ++- lightrag/llm/gemini.py | 10 ++++++ lightrag/llm/ollama.py | 27 +++++++++++++- lightrag/llm/openai.py | 64 +++++++++++++++++++++++++++++++-- lightrag/operate.py | 6 ++-- lightrag/utils.py | 7 ++++ 6 files changed, 111 insertions(+), 7 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 9151f02e..8b1560bb 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -966,7 +966,9 @@ def create_app(args): f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})" ) else: - logger.info("Embedding max_token_size: not set (90% token warning disabled)") + logger.info( + "Embedding max_token_size: None (Embedding token limit is disabled." + ) # Configure rerank function based on args.rerank_bindingparameter rerank_model_func = None diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py index 5e438ceb..0692ce6a 100644 --- a/lightrag/llm/gemini.py +++ b/lightrag/llm/gemini.py @@ -476,6 +476,7 @@ async def gemini_embed( base_url: str | None = None, api_key: str | None = None, embedding_dim: int | None = None, + max_token_size: int | None = None, task_type: str = "RETRIEVAL_DOCUMENT", timeout: int | None = None, token_tracker: Any | None = None, @@ -497,6 +498,11 @@ async def gemini_embed( The dimension is controlled by the @wrap_embedding_func_with_attrs decorator or the EMBEDDING_DIM environment variable. Supported range: 128-3072. Recommended values: 768, 1536, 3072. + max_token_size: Maximum tokens per text. This parameter is automatically + injected by the EmbeddingFunc wrapper when the underlying function + signature supports it (via inspect.signature check). Gemini API will + automatically truncate texts exceeding this limit (autoTruncate=True + by default), so no client-side truncation is needed. task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT". Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING, RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY, @@ -516,7 +522,11 @@ async def gemini_embed( - For dimension 3072: Embeddings are already normalized by the API - For dimensions < 3072: Embeddings are L2-normalized after retrieval - Normalization ensures accurate semantic similarity via cosine distance + - Gemini API automatically truncates texts exceeding max_token_size (autoTruncate=True) """ + # Note: max_token_size is received but not used for client-side truncation. + # Gemini API handles truncation automatically with autoTruncate=True (default). 
+ _ = max_token_size # Acknowledge parameter to avoid unused variable warning loop = asyncio.get_running_loop() key = _ensure_api_key(api_key) diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py index 62269296..ab10a42b 100644 --- a/lightrag/llm/ollama.py +++ b/lightrag/llm/ollama.py @@ -176,8 +176,33 @@ async def ollama_model_complete( embedding_dim=1024, max_token_size=8192, model_name="bge-m3:latest" ) async def ollama_embed( - texts: list[str], embed_model: str = "bge-m3:latest", **kwargs + texts: list[str], + embed_model: str = "bge-m3:latest", + max_token_size: int | None = None, + **kwargs, ) -> np.ndarray: + """Generate embeddings using Ollama's API. + + Args: + texts: List of texts to embed. + embed_model: The Ollama embedding model to use. Default is "bge-m3:latest". + max_token_size: Maximum tokens per text. This parameter is automatically + injected by the EmbeddingFunc wrapper when the underlying function + signature supports it (via inspect.signature check). Ollama will + automatically truncate texts exceeding the model's context length + (num_ctx), so no client-side truncation is needed. + **kwargs: Additional arguments passed to the Ollama client. + + Returns: + A numpy array of embeddings, one per input text. + + Note: + - Ollama API automatically truncates texts exceeding the model's context length + - The max_token_size parameter is received but not used for client-side truncation + """ + # Note: max_token_size is received but not used for client-side truncation. + # Ollama API handles truncation automatically based on the model's num_ctx setting. + _ = max_token_size # Acknowledge parameter to avoid unused variable warning api_key = kwargs.pop("api_key", None) if not api_key: api_key = os.getenv("OLLAMA_API_KEY") diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index b49cac71..6b0f9d48 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -5,6 +5,7 @@ import logging from collections.abc import AsyncIterator import pipmaster as pm +import tiktoken # install specific modules if not pm.is_installed("openai"): @@ -74,6 +75,30 @@ class InvalidResponseError(Exception): pass +# Module-level cache for tiktoken encodings +_TIKTOKEN_ENCODING_CACHE: dict[str, Any] = {} + + +def _get_tiktoken_encoding_for_model(model: str) -> Any: + """Get tiktoken encoding for the specified model with caching. + + Args: + model: The model name to get encoding for. + + Returns: + The tiktoken encoding for the model. + """ + if model not in _TIKTOKEN_ENCODING_CACHE: + try: + _TIKTOKEN_ENCODING_CACHE[model] = tiktoken.encoding_for_model(model) + except KeyError: + logger.debug( + f"Encoding for model '{model}' not found, falling back to cl100k_base" + ) + _TIKTOKEN_ENCODING_CACHE[model] = tiktoken.get_encoding("cl100k_base") + return _TIKTOKEN_ENCODING_CACHE[model] + + def create_openai_async_client( api_key: str | None = None, base_url: str | None = None, @@ -695,15 +720,17 @@ async def openai_embed( base_url: str | None = None, api_key: str | None = None, embedding_dim: int | None = None, + max_token_size: int | None = None, client_configs: dict[str, Any] | None = None, token_tracker: Any | None = None, use_azure: bool = False, azure_deployment: str | None = None, api_version: str | None = None, ) -> np.ndarray: - """Generate embeddings for a list of texts using OpenAI's API. + """Generate embeddings for a list of texts using OpenAI's API with automatic text truncation. - This function supports both standard OpenAI and Azure OpenAI services. 
+ This function supports both standard OpenAI and Azure OpenAI services. It automatically + truncates texts that exceed the model's token limit to prevent API errors. Args: texts: List of texts to embed. @@ -719,6 +746,10 @@ async def openai_embed( The dimension is controlled by the @wrap_embedding_func_with_attrs decorator. Manually passing a different value will trigger a warning and be ignored. When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction. + max_token_size: Maximum tokens per text. Texts exceeding this limit will be truncated. + **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper + when the underlying function signature supports it (via inspect.signature check). + The value is controlled by the @wrap_embedding_func_with_attrs decorator. client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client. These will override any default configurations but will be overridden by explicit parameters (api_key, base_url). Supports proxy configuration, @@ -740,6 +771,35 @@ async def openai_embed( RateLimitError: If the OpenAI API rate limit is exceeded. APITimeoutError: If the OpenAI API request times out. """ + # Apply text truncation if max_token_size is provided + if max_token_size is not None and max_token_size > 0: + encoding = _get_tiktoken_encoding_for_model(model) + truncated_texts = [] + truncation_count = 0 + + for text in texts: + if not text: + truncated_texts.append(text) + continue + + tokens = encoding.encode(text) + if len(tokens) > max_token_size: + truncated_tokens = tokens[:max_token_size] + truncated_texts.append(encoding.decode(truncated_tokens)) + truncation_count += 1 + logger.debug( + f"Text truncated from {len(tokens)} to {max_token_size} tokens" + ) + else: + truncated_texts.append(text) + + if truncation_count > 0: + logger.info( + f"Truncated {truncation_count}/{len(texts)} texts to fit token limit ({max_token_size})" + ) + + texts = truncated_texts + # Create the OpenAI client (supports both OpenAI and Azure) openai_async_client = create_openai_async_client( api_key=api_key, diff --git a/lightrag/operate.py b/lightrag/operate.py index faab7f26..207f0614 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -365,12 +365,12 @@ async def _summarize_descriptions( if embedding_token_limit is not None and summary: tokenizer = global_config["tokenizer"] summary_token_count = len(tokenizer.encode(summary)) - threshold = int(embedding_token_limit * 0.9) + threshold = int(embedding_token_limit) if summary_token_count > threshold: logger.warning( - f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit " - f"({embedding_token_limit}) for {description_type}: {description_name}" + f"Summary tokens({summary_token_count}) exceeds embedding_token_limit({embedding_token_limit}) " + f" for {description_type}: {description_name}" ) return summary diff --git a/lightrag/utils.py b/lightrag/utils.py index cd63f6e4..cd3f26d1 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -6,6 +6,7 @@ import sys import asyncio import html import csv +import inspect import json import logging import logging.handlers @@ -492,6 +493,12 @@ class EmbeddingFunc: # Inject embedding_dim from decorator kwargs["embedding_dim"] = self.embedding_dim + # Check if underlying function supports max_token_size and inject if not provided + if self.max_token_size is not None and "max_token_size" not in kwargs: + sig = inspect.signature(self.func) + if "max_token_size" in 
sig.parameters: + kwargs["max_token_size"] = self.max_token_size + # Call the actual embedding function result = await self.func(*args, **kwargs) From 9c9dfcd48865cd9c3809b479f417a117d6fec979 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 22 Dec 2025 20:06:22 +0800 Subject: [PATCH 2/5] Fix tiktoken cache env var and support encoding names - Set cache env var before import - Support raw encoding names - Add cl100k_base to default list - Improve cache path resolution --- lightrag/tools/download_cache.py | 57 ++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/lightrag/tools/download_cache.py b/lightrag/tools/download_cache.py index 43a21f0b..c16473c8 100644 --- a/lightrag/tools/download_cache.py +++ b/lightrag/tools/download_cache.py @@ -10,16 +10,45 @@ import sys from pathlib import Path +# Known tiktoken encoding names (not model names) +# These need to be loaded with tiktoken.get_encoding() instead of tiktoken.encoding_for_model() +TIKTOKEN_ENCODING_NAMES = {"cl100k_base", "p50k_base", "r50k_base", "o200k_base"} + + def download_tiktoken_cache(cache_dir: str = None, models: list = None): """Download tiktoken models to local cache Args: - cache_dir: Directory to store the cache files. If None, uses default location. - models: List of model names to download. If None, downloads common models. + cache_dir: Directory to store the cache files. If None, uses tiktoken's default location. + models: List of model names or encoding names to download. If None, downloads common ones. Returns: - Tuple of (success_count, failed_models) + Tuple of (success_count, failed_models, actual_cache_dir) """ + # If user specified a cache directory, set it BEFORE importing tiktoken + # tiktoken reads TIKTOKEN_CACHE_DIR at import time + user_specified_cache = cache_dir is not None + + if user_specified_cache: + cache_dir = os.path.abspath(cache_dir) + os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir + cache_path = Path(cache_dir) + cache_path.mkdir(parents=True, exist_ok=True) + print(f"Using specified cache directory: {cache_dir}") + else: + # Check if TIKTOKEN_CACHE_DIR is already set in environment + env_cache_dir = os.environ.get("TIKTOKEN_CACHE_DIR") + if env_cache_dir: + cache_dir = env_cache_dir + print(f"Using TIKTOKEN_CACHE_DIR from environment: {cache_dir}") + else: + # Use tiktoken's default location (tempdir/data-gym-cache) + import tempfile + + cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache") + print(f"Using tiktoken default cache directory: {cache_dir}") + + # Now import tiktoken (it will use the cache directory we determined) try: import tiktoken except ImportError: @@ -27,19 +56,6 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None): print("Install with: pip install tiktoken") sys.exit(1) - # Set cache directory if provided - if cache_dir: - cache_dir = os.path.abspath(cache_dir) - os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir - cache_path = Path(cache_dir) - cache_path.mkdir(parents=True, exist_ok=True) - print(f"Using cache directory: {cache_dir}") - else: - cache_dir = os.environ.get( - "TIKTOKEN_CACHE_DIR", str(Path.home() / ".tiktoken_cache") - ) - print(f"Using default cache directory: {cache_dir}") - # Common models used by LightRAG and OpenAI if models is None: models = [ @@ -50,6 +66,7 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None): "text-embedding-ada-002", # Legacy embedding model "text-embedding-3-small", # Small embedding model "text-embedding-3-large", # Large embedding model + 
"cl100k_base", # Default encoding for LightRAG ] print(f"\nDownloading {len(models)} tiktoken models...") @@ -61,13 +78,17 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None): for i, model in enumerate(models, 1): try: print(f"[{i}/{len(models)}] Downloading {model}...", end=" ", flush=True) - encoding = tiktoken.encoding_for_model(model) + # Use get_encoding for encoding names, encoding_for_model for model names + if model in TIKTOKEN_ENCODING_NAMES: + encoding = tiktoken.get_encoding(model) + else: + encoding = tiktoken.encoding_for_model(model) # Trigger download by encoding a test string encoding.encode("test") print("✓ Done") success_count += 1 except KeyError as e: - print(f"✗ Failed: Unknown model '{model}'") + print(f"✗ Failed: Unknown model or encoding '{model}'") failed_models.append((model, str(e))) except Exception as e: print(f"✗ Failed: {e}") From 5a455985dfa5046754bef2f58844b46e4f4d3ee6 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 22 Dec 2025 20:12:21 +0800 Subject: [PATCH 3/5] Update default cache path comment in docs --- docs/OfflineDeployment.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/OfflineDeployment.md b/docs/OfflineDeployment.md index e186dda0..3b857424 100644 --- a/docs/OfflineDeployment.md +++ b/docs/OfflineDeployment.md @@ -120,7 +120,7 @@ Tiktoken downloads BPE encoding models on first use. In offline environments, yo After installing LightRAG, use the built-in command: ```bash -# Download to default location (~/.tiktoken_cache) +# Download to default location (see output for exact path) lightrag-download-cache # Download to specific directory From 3527c68daebea174e3f383382770c07b8bf08670 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 22 Dec 2025 20:13:45 +0800 Subject: [PATCH 4/5] Fix table formatting in OfflineDeployment docs --- docs/OfflineDeployment.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/OfflineDeployment.md b/docs/OfflineDeployment.md index 3b857424..54eb7eb5 100644 --- a/docs/OfflineDeployment.md +++ b/docs/OfflineDeployment.md @@ -75,7 +75,7 @@ LightRAG provides flexible dependency groups for different use cases: ### Available Dependency Groups | Group | Description | Use Case | -|-------|-------------|----------| +| ----- | ----------- | -------- | | `api` | API server + document processing | FastAPI server with PDF, DOCX, PPTX, XLSX support | | `offline-storage` | Storage backends | Redis, Neo4j, MongoDB, PostgreSQL, etc. | | `offline-llm` | LLM providers | OpenAI, Anthropic, Ollama, etc. | From e2a95ab5a678b2e8e289fdb00f27eafe448d06bc Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 22 Dec 2025 20:15:22 +0800 Subject: [PATCH 5/5] Fix missing parenthesis in log message * Fix typo in log message * Add missing closing parenthesis --- lightrag/api/lightrag_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 8b1560bb..f128c2a8 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -967,7 +967,7 @@ def create_app(args): ) else: logger.info( - "Embedding max_token_size: None (Embedding token limit is disabled." + "Embedding max_token_size: None (Embedding token limit is disabled)." ) # Configure rerank function based on args.rerank_bindingparameter