Merge pull request #2523 from danielaskdd/embedding-max-token

feat: Add Automatic Text Truncation Support for Embedding Functions
Daniel.y
2025-12-22 20:26:46 +08:00
committed by GitHub
8 changed files with 152 additions and 27 deletions

View File

@@ -75,7 +75,7 @@ LightRAG provides flexible dependency groups for different use cases:
### Available Dependency Groups
| Group | Description | Use Case |
|-------|-------------|----------|
| ----- | ----------- | -------- |
| `api` | API server + document processing | FastAPI server with PDF, DOCX, PPTX, XLSX support |
| `offline-storage` | Storage backends | Redis, Neo4j, MongoDB, PostgreSQL, etc. |
| `offline-llm` | LLM providers | OpenAI, Anthropic, Ollama, etc. |
@@ -120,7 +120,7 @@ Tiktoken downloads BPE encoding models on first use. In offline environments, yo
After installing LightRAG, use the built-in command:
```bash
# Download to default location (~/.tiktoken_cache)
# Download to default location (see output for exact path)
lightrag-download-cache
# Download to specific directory

View File

@@ -966,7 +966,9 @@ def create_app(args):
f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
)
else:
logger.info("Embedding max_token_size: not set (90% token warning disabled)")
logger.info(
"Embedding max_token_size: None (Embedding token limit is disabled)."
)
# Configure rerank function based on args.rerank_binding parameter
rerank_model_func = None

View File

@@ -476,6 +476,7 @@ async def gemini_embed(
base_url: str | None = None,
api_key: str | None = None,
embedding_dim: int | None = None,
max_token_size: int | None = None,
task_type: str = "RETRIEVAL_DOCUMENT",
timeout: int | None = None,
token_tracker: Any | None = None,
@@ -497,6 +498,11 @@ async def gemini_embed(
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
or the EMBEDDING_DIM environment variable.
Supported range: 128-3072. Recommended values: 768, 1536, 3072.
max_token_size: Maximum tokens per text. This parameter is automatically
injected by the EmbeddingFunc wrapper when the underlying function
signature supports it (via inspect.signature check). Gemini API will
automatically truncate texts exceeding this limit (autoTruncate=True
by default), so no client-side truncation is needed.
task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
@@ -516,7 +522,11 @@ async def gemini_embed(
- For dimension 3072: Embeddings are already normalized by the API
- For dimensions < 3072: Embeddings are L2-normalized after retrieval
- Normalization ensures accurate semantic similarity via cosine distance
- Gemini API automatically truncates texts exceeding max_token_size (autoTruncate=True)
"""
# Note: max_token_size is received but not used for client-side truncation.
# Gemini API handles truncation automatically with autoTruncate=True (default).
_ = max_token_size # Acknowledge parameter to avoid unused variable warning
loop = asyncio.get_running_loop()
key = _ensure_api_key(api_key)

View File

@@ -176,8 +176,33 @@ async def ollama_model_complete(
embedding_dim=1024, max_token_size=8192, model_name="bge-m3:latest"
)
async def ollama_embed(
texts: list[str], embed_model: str = "bge-m3:latest", **kwargs
texts: list[str],
embed_model: str = "bge-m3:latest",
max_token_size: int | None = None,
**kwargs,
) -> np.ndarray:
"""Generate embeddings using Ollama's API.
Args:
texts: List of texts to embed.
embed_model: The Ollama embedding model to use. Default is "bge-m3:latest".
max_token_size: Maximum tokens per text. This parameter is automatically
injected by the EmbeddingFunc wrapper when the underlying function
signature supports it (via inspect.signature check). Ollama will
automatically truncate texts exceeding the model's context length
(num_ctx), so no client-side truncation is needed.
**kwargs: Additional arguments passed to the Ollama client.
Returns:
A numpy array of embeddings, one per input text.
Note:
- Ollama API automatically truncates texts exceeding the model's context length
- The max_token_size parameter is received but not used for client-side truncation
"""
# Note: max_token_size is received but not used for client-side truncation.
# Ollama API handles truncation automatically based on the model's num_ctx setting.
_ = max_token_size # Acknowledge parameter to avoid unused variable warning
api_key = kwargs.pop("api_key", None)
if not api_key:
api_key = os.getenv("OLLAMA_API_KEY")

View File

@@ -5,6 +5,7 @@ import logging
from collections.abc import AsyncIterator
import pipmaster as pm
import tiktoken
# install specific modules
if not pm.is_installed("openai"):
@@ -74,6 +75,30 @@ class InvalidResponseError(Exception):
pass
# Module-level cache for tiktoken encodings
_TIKTOKEN_ENCODING_CACHE: dict[str, Any] = {}
def _get_tiktoken_encoding_for_model(model: str) -> Any:
"""Get tiktoken encoding for the specified model with caching.
Args:
model: The model name to get encoding for.
Returns:
The tiktoken encoding for the model.
"""
if model not in _TIKTOKEN_ENCODING_CACHE:
try:
_TIKTOKEN_ENCODING_CACHE[model] = tiktoken.encoding_for_model(model)
except KeyError:
logger.debug(
f"Encoding for model '{model}' not found, falling back to cl100k_base"
)
_TIKTOKEN_ENCODING_CACHE[model] = tiktoken.get_encoding("cl100k_base")
return _TIKTOKEN_ENCODING_CACHE[model]
def create_openai_async_client(
api_key: str | None = None,
base_url: str | None = None,
@@ -695,15 +720,17 @@ async def openai_embed(
base_url: str | None = None,
api_key: str | None = None,
embedding_dim: int | None = None,
max_token_size: int | None = None,
client_configs: dict[str, Any] | None = None,
token_tracker: Any | None = None,
use_azure: bool = False,
azure_deployment: str | None = None,
api_version: str | None = None,
) -> np.ndarray:
"""Generate embeddings for a list of texts using OpenAI's API.
"""Generate embeddings for a list of texts using OpenAI's API with automatic text truncation.
This function supports both standard OpenAI and Azure OpenAI services.
This function supports both standard OpenAI and Azure OpenAI services. It automatically
truncates texts that exceed the model's token limit to prevent API errors.
Args:
texts: List of texts to embed.
@@ -719,6 +746,10 @@ async def openai_embed(
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
Manually passing a different value will trigger a warning and be ignored.
When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
max_token_size: Maximum tokens per text. Texts exceeding this limit will be truncated.
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper
when the underlying function signature supports it (via inspect.signature check).
The value is controlled by the @wrap_embedding_func_with_attrs decorator.
client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url). Supports proxy configuration,
@@ -740,6 +771,35 @@ async def openai_embed(
RateLimitError: If the OpenAI API rate limit is exceeded.
APITimeoutError: If the OpenAI API request times out.
"""
# Apply text truncation if max_token_size is provided
if max_token_size is not None and max_token_size > 0:
encoding = _get_tiktoken_encoding_for_model(model)
truncated_texts = []
truncation_count = 0
for text in texts:
if not text:
truncated_texts.append(text)
continue
tokens = encoding.encode(text)
if len(tokens) > max_token_size:
truncated_tokens = tokens[:max_token_size]
truncated_texts.append(encoding.decode(truncated_tokens))
truncation_count += 1
logger.debug(
f"Text truncated from {len(tokens)} to {max_token_size} tokens"
)
else:
truncated_texts.append(text)
if truncation_count > 0:
logger.info(
f"Truncated {truncation_count}/{len(texts)} texts to fit token limit ({max_token_size})"
)
texts = truncated_texts
# Create the OpenAI client (supports both OpenAI and Azure)
openai_async_client = create_openai_async_client(
api_key=api_key,

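For reference, the truncation step added above can be reproduced in isolation. The snippet below is a minimal sketch of the same tiktoken encode/slice/decode pattern; the model name and the 8191-token limit are illustrative choices, not values taken from this diff.

```python
# Minimal sketch of the client-side truncation added to openai_embed (illustrative only).
import tiktoken

def truncate_to_token_limit(text: str, max_token_size: int, model: str = "text-embedding-3-small") -> str:
    """Cut `text` to at most `max_token_size` tokens using the model's tiktoken encoding."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Same fallback the new _get_tiktoken_encoding_for_model helper uses.
        encoding = tiktoken.get_encoding("cl100k_base")
    tokens = encoding.encode(text)
    if len(tokens) <= max_token_size:
        return text
    return encoding.decode(tokens[:max_token_size])

# Example: a long input is cut down before it would trigger an API-side length error.
clipped = truncate_to_token_limit("word " * 20_000, max_token_size=8191)
```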
View File

@@ -365,12 +365,12 @@ async def _summarize_descriptions(
if embedding_token_limit is not None and summary:
tokenizer = global_config["tokenizer"]
summary_token_count = len(tokenizer.encode(summary))
threshold = int(embedding_token_limit * 0.9)
threshold = int(embedding_token_limit)
if summary_token_count > threshold:
logger.warning(
f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
f"({embedding_token_limit}) for {description_type}: {description_name}"
f"Summary tokens ({summary_token_count}) exceeds embedding_token_limit ({embedding_token_limit}) "
f"for {description_type}: {description_name}"
)
return summary

View File

@@ -10,16 +10,45 @@ import sys
from pathlib import Path
# Known tiktoken encoding names (not model names)
# These need to be loaded with tiktoken.get_encoding() instead of tiktoken.encoding_for_model()
TIKTOKEN_ENCODING_NAMES = {"cl100k_base", "p50k_base", "r50k_base", "o200k_base"}
def download_tiktoken_cache(cache_dir: str = None, models: list = None):
"""Download tiktoken models to local cache
Args:
cache_dir: Directory to store the cache files. If None, uses default location.
models: List of model names to download. If None, downloads common models.
cache_dir: Directory to store the cache files. If None, uses tiktoken's default location.
models: List of model names or encoding names to download. If None, downloads common ones.
Returns:
Tuple of (success_count, failed_models)
Tuple of (success_count, failed_models, actual_cache_dir)
"""
# If user specified a cache directory, set it BEFORE importing tiktoken
# tiktoken reads TIKTOKEN_CACHE_DIR at import time
user_specified_cache = cache_dir is not None
if user_specified_cache:
cache_dir = os.path.abspath(cache_dir)
os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir
cache_path = Path(cache_dir)
cache_path.mkdir(parents=True, exist_ok=True)
print(f"Using specified cache directory: {cache_dir}")
else:
# Check if TIKTOKEN_CACHE_DIR is already set in environment
env_cache_dir = os.environ.get("TIKTOKEN_CACHE_DIR")
if env_cache_dir:
cache_dir = env_cache_dir
print(f"Using TIKTOKEN_CACHE_DIR from environment: {cache_dir}")
else:
# Use tiktoken's default location (tempdir/data-gym-cache)
import tempfile
cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache")
print(f"Using tiktoken default cache directory: {cache_dir}")
# Now import tiktoken (it will use the cache directory we determined)
try:
import tiktoken
except ImportError:
@@ -27,19 +56,6 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
print("Install with: pip install tiktoken")
sys.exit(1)
# Set cache directory if provided
if cache_dir:
cache_dir = os.path.abspath(cache_dir)
os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir
cache_path = Path(cache_dir)
cache_path.mkdir(parents=True, exist_ok=True)
print(f"Using cache directory: {cache_dir}")
else:
cache_dir = os.environ.get(
"TIKTOKEN_CACHE_DIR", str(Path.home() / ".tiktoken_cache")
)
print(f"Using default cache directory: {cache_dir}")
# Common models used by LightRAG and OpenAI
if models is None:
models = [
@@ -50,6 +66,7 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
"text-embedding-ada-002", # Legacy embedding model
"text-embedding-3-small", # Small embedding model
"text-embedding-3-large", # Large embedding model
"cl100k_base", # Default encoding for LightRAG
]
print(f"\nDownloading {len(models)} tiktoken models...")
@@ -61,13 +78,17 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
for i, model in enumerate(models, 1):
try:
print(f"[{i}/{len(models)}] Downloading {model}...", end=" ", flush=True)
encoding = tiktoken.encoding_for_model(model)
# Use get_encoding for encoding names, encoding_for_model for model names
if model in TIKTOKEN_ENCODING_NAMES:
encoding = tiktoken.get_encoding(model)
else:
encoding = tiktoken.encoding_for_model(model)
# Trigger download by encoding a test string
encoding.encode("test")
print("✓ Done")
success_count += 1
except KeyError as e:
print(f"✗ Failed: Unknown model '{model}'")
print(f"✗ Failed: Unknown model or encoding '{model}'")
failed_models.append((model, str(e)))
except Exception as e:
print(f"✗ Failed: {e}")

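A hedged usage sketch of the updated helper follows. The import path is an assumption (the diff does not show the module name); only the signature and the three-element return value are taken from the docstring above.

```python
# Hypothetical usage of download_tiktoken_cache(); the import path below is assumed,
# not confirmed by this diff -- only the signature and return tuple are.
from lightrag.tools.download_tiktoken_cache import download_tiktoken_cache  # assumed path

success_count, failed_models, cache_dir = download_tiktoken_cache(
    cache_dir="./tiktoken_cache",                       # sets TIKTOKEN_CACHE_DIR before tiktoken is imported
    models=["text-embedding-3-large", "cl100k_base"],   # model names and raw encoding names both work now
)
print(f"Downloaded {success_count} encoding(s) into {cache_dir}; failures: {failed_models}")
```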
View File

@@ -6,6 +6,7 @@ import sys
import asyncio
import html
import csv
import inspect
import json
import logging
import logging.handlers
@@ -492,6 +493,12 @@ class EmbeddingFunc:
# Inject embedding_dim from decorator
kwargs["embedding_dim"] = self.embedding_dim
# Check if underlying function supports max_token_size and inject if not provided
if self.max_token_size is not None and "max_token_size" not in kwargs:
sig = inspect.signature(self.func)
if "max_token_size" in sig.parameters:
kwargs["max_token_size"] = self.max_token_size
# Call the actual embedding function
result = await self.func(*args, **kwargs)
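To make the injection mechanism concrete, here is a minimal sketch of the same inspect.signature check, together with a hypothetical embedding function that opts in simply by declaring the parameter. `my_embed` and its character-level cap are illustrative and not part of LightRAG.

```python
# Sketch of the opt-in pattern used by EmbeddingFunc; my_embed is a hypothetical example.
import inspect
import numpy as np

async def my_embed(texts: list[str], max_token_size: int | None = None) -> np.ndarray:
    # Declaring max_token_size in the signature is what lets the wrapper inject its
    # configured limit; how the function honours it is up to the implementation.
    if max_token_size is not None:
        texts = [t[: max_token_size * 4] for t in texts]  # crude character-level cap, for illustration
    return np.zeros((len(texts), 1024))  # placeholder vectors

# The same check the wrapper performs before injecting the kwarg:
print("max_token_size" in inspect.signature(my_embed).parameters)  # True
```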