Merge pull request #2523 from danielaskdd/embedding-max-token
feat: Add Automatic Text Truncation Support for Embedding Functions
@@ -75,7 +75,7 @@ LightRAG provides flexible dependency groups for different use cases:
### Available Dependency Groups

| Group | Description | Use Case |
-|-------|-------------|----------|
+| ----- | ----------- | -------- |
| `api` | API server + document processing | FastAPI server with PDF, DOCX, PPTX, XLSX support |
| `offline-storage` | Storage backends | Redis, Neo4j, MongoDB, PostgreSQL, etc. |
| `offline-llm` | LLM providers | OpenAI, Anthropic, Ollama, etc. |
@@ -120,7 +120,7 @@ Tiktoken downloads BPE encoding models on first use. In offline environments, yo
After installing LightRAG, use the built-in command:

```bash
-# Download to default location (~/.tiktoken_cache)
+# Download to default location (see output for exact path)
lightrag-download-cache

# Download to specific directory
@@ -966,7 +966,9 @@ def create_app(args):
f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
)
else:
-logger.info("Embedding max_token_size: not set (90% token warning disabled)")
+logger.info(
+"Embedding max_token_size: None (Embedding token limit is disabled)."
+)

# Configure rerank function based on args.rerank_binding parameter
rerank_model_func = None
@@ -476,6 +476,7 @@ async def gemini_embed(
base_url: str | None = None,
api_key: str | None = None,
embedding_dim: int | None = None,
+max_token_size: int | None = None,
task_type: str = "RETRIEVAL_DOCUMENT",
timeout: int | None = None,
token_tracker: Any | None = None,
@@ -497,6 +498,11 @@ async def gemini_embed(
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
or the EMBEDDING_DIM environment variable.
Supported range: 128-3072. Recommended values: 768, 1536, 3072.
+max_token_size: Maximum tokens per text. This parameter is automatically
+injected by the EmbeddingFunc wrapper when the underlying function
+signature supports it (via inspect.signature check). Gemini API will
+automatically truncate texts exceeding this limit (autoTruncate=True
+by default), so no client-side truncation is needed.
task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
@@ -516,7 +522,11 @@ async def gemini_embed(
- For dimension 3072: Embeddings are already normalized by the API
- For dimensions < 3072: Embeddings are L2-normalized after retrieval
- Normalization ensures accurate semantic similarity via cosine distance
+- Gemini API automatically truncates texts exceeding max_token_size (autoTruncate=True)
"""
+# Note: max_token_size is received but not used for client-side truncation.
+# Gemini API handles truncation automatically with autoTruncate=True (default).
+_ = max_token_size # Acknowledge parameter to avoid unused variable warning
loop = asyncio.get_running_loop()

key = _ensure_api_key(api_key)
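The "injected by the EmbeddingFunc wrapper when the underlying function signature supports it" behaviour described in the docstring boils down to an `inspect.signature` check (shown verbatim in the `EmbeddingFunc` hunk at the end of this diff). A minimal standalone sketch of that mechanism; the names here are illustrative, not from the codebase:

```python
import inspect

async def fake_embed(texts: list[str], max_token_size: int | None = None, **kwargs):
    """Stand-in for an embedding function that declares max_token_size."""
    return texts

def inject_max_token_size(func, max_token_size: int | None) -> dict:
    # Only pass max_token_size when the wrapped function's signature declares it,
    # mirroring the wrapper's inspect.signature check in this PR.
    kwargs = {}
    if max_token_size is not None and "max_token_size" in inspect.signature(func).parameters:
        kwargs["max_token_size"] = max_token_size
    return kwargs

print(inject_max_token_size(fake_embed, 8192))        # {'max_token_size': 8192}
print(inject_max_token_size(lambda texts: texts, 8192))  # {} (parameter not declared)
```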
@@ -176,8 +176,33 @@ async def ollama_model_complete(
embedding_dim=1024, max_token_size=8192, model_name="bge-m3:latest"
)
async def ollama_embed(
-texts: list[str], embed_model: str = "bge-m3:latest", **kwargs
+texts: list[str],
+embed_model: str = "bge-m3:latest",
+max_token_size: int | None = None,
+**kwargs,
) -> np.ndarray:
+"""Generate embeddings using Ollama's API.
+
+Args:
+texts: List of texts to embed.
+embed_model: The Ollama embedding model to use. Default is "bge-m3:latest".
+max_token_size: Maximum tokens per text. This parameter is automatically
+injected by the EmbeddingFunc wrapper when the underlying function
+signature supports it (via inspect.signature check). Ollama will
+automatically truncate texts exceeding the model's context length
+(num_ctx), so no client-side truncation is needed.
+**kwargs: Additional arguments passed to the Ollama client.
+
+Returns:
+A numpy array of embeddings, one per input text.
+
+Note:
+- Ollama API automatically truncates texts exceeding the model's context length
+- The max_token_size parameter is received but not used for client-side truncation
+"""
+# Note: max_token_size is received but not used for client-side truncation.
+# Ollama API handles truncation automatically based on the model's num_ctx setting.
+_ = max_token_size # Acknowledge parameter to avoid unused variable warning
api_key = kwargs.pop("api_key", None)
if not api_key:
api_key = os.getenv("OLLAMA_API_KEY")
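For completeness, a hedged usage sketch of the new signature; the module path, and the availability of a local Ollama server with the `bge-m3` model, are assumptions rather than verified facts:

```python
import asyncio

# Assumed module path based on this diff.
from lightrag.llm.ollama import ollama_embed

async def main():
    # max_token_size is normally injected by the EmbeddingFunc wrapper; passing it
    # explicitly is harmless since Ollama truncates to the model's num_ctx anyway.
    vectors = await ollama_embed(
        ["hello world", "a second text"],
        embed_model="bge-m3:latest",
        max_token_size=8192,
    )
    print(vectors.shape)  # (2, 1024) for bge-m3

asyncio.run(main())
```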
@@ -5,6 +5,7 @@ import logging
from collections.abc import AsyncIterator

import pipmaster as pm
+import tiktoken

# install specific modules
if not pm.is_installed("openai"):
@@ -74,6 +75,30 @@ class InvalidResponseError(Exception):
pass


+# Module-level cache for tiktoken encodings
+_TIKTOKEN_ENCODING_CACHE: dict[str, Any] = {}
+
+
+def _get_tiktoken_encoding_for_model(model: str) -> Any:
+"""Get tiktoken encoding for the specified model with caching.
+
+Args:
+model: The model name to get encoding for.
+
+Returns:
+The tiktoken encoding for the model.
+"""
+if model not in _TIKTOKEN_ENCODING_CACHE:
+try:
+_TIKTOKEN_ENCODING_CACHE[model] = tiktoken.encoding_for_model(model)
+except KeyError:
+logger.debug(
+f"Encoding for model '{model}' not found, falling back to cl100k_base"
+)
+_TIKTOKEN_ENCODING_CACHE[model] = tiktoken.get_encoding("cl100k_base")
+return _TIKTOKEN_ENCODING_CACHE[model]
+
+
def create_openai_async_client(
api_key: str | None = None,
base_url: str | None = None,
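For readers unfamiliar with the tiktoken encode/decode round trip that the truncation logic below relies on, here is a minimal standalone sketch (the text and the tiny limit are illustrative only):

```python
import tiktoken

# cl100k_base is the fallback encoding used above when a model name is unknown
enc = tiktoken.get_encoding("cl100k_base")

text = "LightRAG truncates over-long inputs before embedding them."
tokens = enc.encode(text)
print(len(tokens))  # token count for the text

# Truncation is just slicing the token list and decoding it back to a string
max_token_size = 8
truncated = enc.decode(tokens[:max_token_size])
print(truncated)
```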
@@ -695,15 +720,17 @@ async def openai_embed(
base_url: str | None = None,
api_key: str | None = None,
embedding_dim: int | None = None,
+max_token_size: int | None = None,
client_configs: dict[str, Any] | None = None,
token_tracker: Any | None = None,
use_azure: bool = False,
azure_deployment: str | None = None,
api_version: str | None = None,
) -> np.ndarray:
-"""Generate embeddings for a list of texts using OpenAI's API.
+"""Generate embeddings for a list of texts using OpenAI's API with automatic text truncation.

-This function supports both standard OpenAI and Azure OpenAI services.
+This function supports both standard OpenAI and Azure OpenAI services. It automatically
+truncates texts that exceed the model's token limit to prevent API errors.

Args:
texts: List of texts to embed.
@@ -719,6 +746,10 @@ async def openai_embed(
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
Manually passing a different value will trigger a warning and be ignored.
When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
+max_token_size: Maximum tokens per text. Texts exceeding this limit will be truncated.
+**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper
+when the underlying function signature supports it (via inspect.signature check).
+The value is controlled by the @wrap_embedding_func_with_attrs decorator.
client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url). Supports proxy configuration,
@@ -740,6 +771,35 @@ async def openai_embed(
RateLimitError: If the OpenAI API rate limit is exceeded.
APITimeoutError: If the OpenAI API request times out.
"""
+# Apply text truncation if max_token_size is provided
+if max_token_size is not None and max_token_size > 0:
+encoding = _get_tiktoken_encoding_for_model(model)
+truncated_texts = []
+truncation_count = 0
+
+for text in texts:
+if not text:
+truncated_texts.append(text)
+continue
+
+tokens = encoding.encode(text)
+if len(tokens) > max_token_size:
+truncated_tokens = tokens[:max_token_size]
+truncated_texts.append(encoding.decode(truncated_tokens))
+truncation_count += 1
+logger.debug(
+f"Text truncated from {len(tokens)} to {max_token_size} tokens"
+)
+else:
+truncated_texts.append(text)
+
+if truncation_count > 0:
+logger.info(
+f"Truncated {truncation_count}/{len(texts)} texts to fit token limit ({max_token_size})"
+)
+
+texts = truncated_texts
+
# Create the OpenAI client (supports both OpenAI and Azure)
openai_async_client = create_openai_async_client(
api_key=api_key,
@@ -365,12 +365,12 @@ async def _summarize_descriptions(
if embedding_token_limit is not None and summary:
tokenizer = global_config["tokenizer"]
summary_token_count = len(tokenizer.encode(summary))
-threshold = int(embedding_token_limit * 0.9)
+threshold = int(embedding_token_limit)

if summary_token_count > threshold:
logger.warning(
-f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
-f"({embedding_token_limit}) for {description_type}: {description_name}"
+f"Summary tokens ({summary_token_count}) exceeds embedding_token_limit ({embedding_token_limit}) "
+f"for {description_type}: {description_name}"
)

return summary
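The behavioural difference introduced here (dropping the 0.9 safety factor) can be illustrated with a small, self-contained sketch; the token counts and limit are illustrative only:

```python
def warns(summary_token_count: int, embedding_token_limit: int, use_90_percent: bool) -> bool:
    """Return True if the summary would trigger the warning under the given rule."""
    threshold = int(embedding_token_limit * 0.9) if use_90_percent else int(embedding_token_limit)
    return summary_token_count > threshold

limit = 8192
print(warns(7500, limit, use_90_percent=True))   # True  (old rule: warns above 7372)
print(warns(7500, limit, use_90_percent=False))  # False (new rule: only above 8192)
print(warns(8500, limit, use_90_percent=False))  # True
```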
@@ -10,16 +10,45 @@ import sys
from pathlib import Path


+# Known tiktoken encoding names (not model names)
+# These need to be loaded with tiktoken.get_encoding() instead of tiktoken.encoding_for_model()
+TIKTOKEN_ENCODING_NAMES = {"cl100k_base", "p50k_base", "r50k_base", "o200k_base"}
+
+
def download_tiktoken_cache(cache_dir: str = None, models: list = None):
"""Download tiktoken models to local cache

Args:
-cache_dir: Directory to store the cache files. If None, uses default location.
-models: List of model names to download. If None, downloads common models.
+cache_dir: Directory to store the cache files. If None, uses tiktoken's default location.
+models: List of model names or encoding names to download. If None, downloads common ones.

Returns:
-Tuple of (success_count, failed_models)
+Tuple of (success_count, failed_models, actual_cache_dir)
"""
+# If user specified a cache directory, set it BEFORE importing tiktoken
+# tiktoken reads TIKTOKEN_CACHE_DIR at import time
+user_specified_cache = cache_dir is not None
+
+if user_specified_cache:
+cache_dir = os.path.abspath(cache_dir)
+os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir
+cache_path = Path(cache_dir)
+cache_path.mkdir(parents=True, exist_ok=True)
+print(f"Using specified cache directory: {cache_dir}")
+else:
+# Check if TIKTOKEN_CACHE_DIR is already set in environment
+env_cache_dir = os.environ.get("TIKTOKEN_CACHE_DIR")
+if env_cache_dir:
+cache_dir = env_cache_dir
+print(f"Using TIKTOKEN_CACHE_DIR from environment: {cache_dir}")
+else:
+# Use tiktoken's default location (tempdir/data-gym-cache)
+import tempfile
+
+cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache")
+print(f"Using tiktoken default cache directory: {cache_dir}")
+
+# Now import tiktoken (it will use the cache directory we determined)
try:
import tiktoken
except ImportError:
@@ -27,19 +56,6 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
print("Install with: pip install tiktoken")
sys.exit(1)

-# Set cache directory if provided
-if cache_dir:
-cache_dir = os.path.abspath(cache_dir)
-os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir
-cache_path = Path(cache_dir)
-cache_path.mkdir(parents=True, exist_ok=True)
-print(f"Using cache directory: {cache_dir}")
-else:
-cache_dir = os.environ.get(
-"TIKTOKEN_CACHE_DIR", str(Path.home() / ".tiktoken_cache")
-)
-print(f"Using default cache directory: {cache_dir}")
-
# Common models used by LightRAG and OpenAI
if models is None:
models = [
@@ -50,6 +66,7 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
"text-embedding-ada-002", # Legacy embedding model
"text-embedding-3-small", # Small embedding model
"text-embedding-3-large", # Large embedding model
+"cl100k_base", # Default encoding for LightRAG
]

print(f"\nDownloading {len(models)} tiktoken models...")
@@ -61,13 +78,17 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
for i, model in enumerate(models, 1):
try:
print(f"[{i}/{len(models)}] Downloading {model}...", end=" ", flush=True)
-encoding = tiktoken.encoding_for_model(model)
+# Use get_encoding for encoding names, encoding_for_model for model names
+if model in TIKTOKEN_ENCODING_NAMES:
+encoding = tiktoken.get_encoding(model)
+else:
+encoding = tiktoken.encoding_for_model(model)
# Trigger download by encoding a test string
encoding.encode("test")
print("✓ Done")
success_count += 1
except KeyError as e:
-print(f"✗ Failed: Unknown model '{model}'")
+print(f"✗ Failed: Unknown model or encoding '{model}'")
failed_models.append((model, str(e)))
except Exception as e:
print(f"✗ Failed: {e}")
@@ -6,6 +6,7 @@ import sys
import asyncio
import html
import csv
+import inspect
import json
import logging
import logging.handlers
@@ -492,6 +493,12 @@ class EmbeddingFunc:
# Inject embedding_dim from decorator
kwargs["embedding_dim"] = self.embedding_dim

+# Check if underlying function supports max_token_size and inject if not provided
+if self.max_token_size is not None and "max_token_size" not in kwargs:
+sig = inspect.signature(self.func)
+if "max_token_size" in sig.parameters:
+kwargs["max_token_size"] = self.max_token_size
+
# Call the actual embedding function
result = await self.func(*args, **kwargs)
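Putting the pieces together, a typical wiring constructs an `EmbeddingFunc` around `openai_embed` so that the wrapper injects `max_token_size` and client-side truncation kicks in automatically. A hedged usage sketch; the import paths, constructor arguments, and dimensions are assumptions based on this diff rather than verified API:

```python
import asyncio

# Assumed module paths; adjust to the actual package layout.
from lightrag.utils import EmbeddingFunc
from lightrag.llm.openai import openai_embed

embedding_func = EmbeddingFunc(
    embedding_dim=1536,
    max_token_size=8191,  # injected into openai_embed because its signature declares it
    func=openai_embed,
)

async def main():
    # Over-long inputs are truncated inside openai_embed before the API call.
    vectors = await embedding_func(["some very long document ..."])
    print(vectors.shape)

asyncio.run(main())
```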