feat: inject max_token_size and add client-side truncation for OpenAI

- Auto-inject max_token_size in wrapper
- Implement OpenAI client-side truncation
- Update Gemini/Ollama embed signatures
- Relax summary token warning threshold
- Update server startup logging
yangdx
2025-12-22 19:33:43 +08:00
parent b31b910e99
commit 2678005448
6 changed files with 111 additions and 7 deletions

View File

@@ -966,7 +966,9 @@ def create_app(args):
f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
)
else:
logger.info("Embedding max_token_size: not set (90% token warning disabled)")
logger.info(
"Embedding max_token_size: None (Embedding token limit is disabled."
)
# Configure rerank function based on args.rerank_binding parameter
rerank_model_func = None

View File

@@ -476,6 +476,7 @@ async def gemini_embed(
base_url: str | None = None,
api_key: str | None = None,
embedding_dim: int | None = None,
max_token_size: int | None = None,
task_type: str = "RETRIEVAL_DOCUMENT",
timeout: int | None = None,
token_tracker: Any | None = None,
@@ -497,6 +498,11 @@ async def gemini_embed(
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
or the EMBEDDING_DIM environment variable.
Supported range: 128-3072. Recommended values: 768, 1536, 3072.
max_token_size: Maximum tokens per text. This parameter is automatically
injected by the EmbeddingFunc wrapper when the underlying function
signature supports it (via inspect.signature check). Gemini API will
automatically truncate texts exceeding this limit (autoTruncate=True
by default), so no client-side truncation is needed.
task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
@@ -516,7 +522,11 @@ async def gemini_embed(
- For dimension 3072: Embeddings are already normalized by the API
- For dimensions < 3072: Embeddings are L2-normalized after retrieval
- Normalization ensures accurate semantic similarity via cosine distance
- Gemini API automatically truncates texts exceeding max_token_size (autoTruncate=True)
"""
# Note: max_token_size is received but not used for client-side truncation.
# Gemini API handles truncation automatically with autoTruncate=True (default).
_ = max_token_size # Acknowledge parameter to avoid unused variable warning
loop = asyncio.get_running_loop()
key = _ensure_api_key(api_key)
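
A minimal usage sketch (the token limit and input texts are illustrative, not taken from this commit): because gemini_embed now declares max_token_size in its signature, the EmbeddingFunc wrapper injects the decorator's value automatically, and the function defers truncation to the Gemini API instead of trimming texts itself.

# Hedged sketch: assumes gemini_embed is wrapped via
# @wrap_embedding_func_with_attrs(..., max_token_size=2048), as its docstring suggests.
embeddings = await gemini_embed(
    ["first document", "second document"],
    task_type="RETRIEVAL_DOCUMENT",
)
# The wrapper checks inspect.signature(gemini_embed), finds max_token_size,
# and injects max_token_size=2048 into the call; gemini_embed acknowledges it
# but relies on the API's autoTruncate=True behaviour for oversized texts.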

View File

@@ -176,8 +176,33 @@ async def ollama_model_complete(
embedding_dim=1024, max_token_size=8192, model_name="bge-m3:latest"
)
async def ollama_embed(
texts: list[str], embed_model: str = "bge-m3:latest", **kwargs
texts: list[str],
embed_model: str = "bge-m3:latest",
max_token_size: int | None = None,
**kwargs,
) -> np.ndarray:
"""Generate embeddings using Ollama's API.
Args:
texts: List of texts to embed.
embed_model: The Ollama embedding model to use. Default is "bge-m3:latest".
max_token_size: Maximum tokens per text. This parameter is automatically
injected by the EmbeddingFunc wrapper when the underlying function
signature supports it (via inspect.signature check). Ollama will
automatically truncate texts exceeding the model's context length
(num_ctx), so no client-side truncation is needed.
**kwargs: Additional arguments passed to the Ollama client.
Returns:
A numpy array of embeddings, one per input text.
Note:
- Ollama API automatically truncates texts exceeding the model's context length
- The max_token_size parameter is received but not used for client-side truncation
"""
# Note: max_token_size is received but not used for client-side truncation.
# Ollama API handles truncation automatically based on the model's num_ctx setting.
_ = max_token_size # Acknowledge parameter to avoid unused variable warning
api_key = kwargs.pop("api_key", None)
if not api_key:
api_key = os.getenv("OLLAMA_API_KEY")

View File

@@ -5,6 +5,7 @@ import logging
from collections.abc import AsyncIterator
import pipmaster as pm
import tiktoken
# install specific modules
if not pm.is_installed("openai"):
@@ -74,6 +75,30 @@ class InvalidResponseError(Exception):
pass
# Module-level cache for tiktoken encodings
_TIKTOKEN_ENCODING_CACHE: dict[str, Any] = {}
def _get_tiktoken_encoding_for_model(model: str) -> Any:
"""Get tiktoken encoding for the specified model with caching.
Args:
model: The model name to get encoding for.
Returns:
The tiktoken encoding for the model.
"""
if model not in _TIKTOKEN_ENCODING_CACHE:
try:
_TIKTOKEN_ENCODING_CACHE[model] = tiktoken.encoding_for_model(model)
except KeyError:
logger.debug(
f"Encoding for model '{model}' not found, falling back to cl100k_base"
)
_TIKTOKEN_ENCODING_CACHE[model] = tiktoken.get_encoding("cl100k_base")
return _TIKTOKEN_ENCODING_CACHE[model]
def create_openai_async_client(
api_key: str | None = None,
base_url: str | None = None,
@@ -695,15 +720,17 @@ async def openai_embed(
base_url: str | None = None,
api_key: str | None = None,
embedding_dim: int | None = None,
max_token_size: int | None = None,
client_configs: dict[str, Any] | None = None,
token_tracker: Any | None = None,
use_azure: bool = False,
azure_deployment: str | None = None,
api_version: str | None = None,
) -> np.ndarray:
"""Generate embeddings for a list of texts using OpenAI's API.
"""Generate embeddings for a list of texts using OpenAI's API with automatic text truncation.
This function supports both standard OpenAI and Azure OpenAI services.
This function supports both standard OpenAI and Azure OpenAI services. It automatically
truncates texts that exceed the model's token limit to prevent API errors.
Args:
texts: List of texts to embed.
@@ -719,6 +746,10 @@ async def openai_embed(
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
Manually passing a different value will trigger a warning and be ignored.
When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
max_token_size: Maximum tokens per text. Texts exceeding this limit will be truncated.
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper
when the underlying function signature supports it (via inspect.signature check).
The value is controlled by the @wrap_embedding_func_with_attrs decorator.
client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url). Supports proxy configuration,
@@ -740,6 +771,35 @@ async def openai_embed(
RateLimitError: If the OpenAI API rate limit is exceeded.
APITimeoutError: If the OpenAI API request times out.
"""
# Apply text truncation if max_token_size is provided
if max_token_size is not None and max_token_size > 0:
encoding = _get_tiktoken_encoding_for_model(model)
truncated_texts = []
truncation_count = 0
for text in texts:
if not text:
truncated_texts.append(text)
continue
tokens = encoding.encode(text)
if len(tokens) > max_token_size:
truncated_tokens = tokens[:max_token_size]
truncated_texts.append(encoding.decode(truncated_tokens))
truncation_count += 1
logger.debug(
f"Text truncated from {len(tokens)} to {max_token_size} tokens"
)
else:
truncated_texts.append(text)
if truncation_count > 0:
logger.info(
f"Truncated {truncation_count}/{len(texts)} texts to fit token limit ({max_token_size})"
)
texts = truncated_texts
# Create the OpenAI client (supports both OpenAI and Azure)
openai_async_client = create_openai_async_client(
api_key=api_key,
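
A standalone sketch of the client-side truncation this commit adds (model name and limit are illustrative): encode with tiktoken, keep at most max_token_size tokens, and decode back, falling back to cl100k_base when the model has no registered encoding.

import tiktoken

def truncate_for_embedding(
    text: str,
    model: str = "text-embedding-3-small",
    max_token_size: int = 8192,
) -> str:
    # Same idea as openai_embed: unknown models fall back to cl100k_base.
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    tokens = encoding.encode(text)
    if len(tokens) <= max_token_size:
        return text
    # Keep only the first max_token_size tokens and decode back to text.
    return encoding.decode(tokens[:max_token_size])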

View File

@@ -365,12 +365,12 @@ async def _summarize_descriptions(
if embedding_token_limit is not None and summary:
tokenizer = global_config["tokenizer"]
summary_token_count = len(tokenizer.encode(summary))
threshold = int(embedding_token_limit * 0.9)
threshold = int(embedding_token_limit)
if summary_token_count > threshold:
logger.warning(
f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
f"({embedding_token_limit}) for {description_type}: {description_name}"
f"Summary tokens({summary_token_count}) exceeds embedding_token_limit({embedding_token_limit}) "
f" for {description_type}: {description_name}"
)
return summary
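
Illustrative numbers for the relaxed warning threshold (values assumed, not from the commit):

embedding_token_limit = 8192
old_threshold = int(embedding_token_limit * 0.9)  # 7372: a 7500-token summary used to warn
new_threshold = int(embedding_token_limit)        # 8192: the same summary now passes silently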

View File

@@ -6,6 +6,7 @@ import sys
import asyncio
import html
import csv
import inspect
import json
import logging
import logging.handlers
@@ -492,6 +493,12 @@ class EmbeddingFunc:
# Inject embedding_dim from decorator
kwargs["embedding_dim"] = self.embedding_dim
# Check if underlying function supports max_token_size and inject if not provided
if self.max_token_size is not None and "max_token_size" not in kwargs:
sig = inspect.signature(self.func)
if "max_token_size" in sig.parameters:
kwargs["max_token_size"] = self.max_token_size
# Call the actual embedding function
result = await self.func(*args, **kwargs)
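
A self-contained sketch of the injection pattern used above (the function name is illustrative, not the real EmbeddingFunc API): inspect the wrapped function's signature and pass max_token_size only when the target declares that parameter and the caller has not already supplied it.

import inspect

async def call_embedding(func, texts, max_token_size=None, **kwargs):
    # Inject only when the wrapped function can accept the parameter;
    # providers such as gemini_embed/ollama_embed acknowledge and ignore it,
    # while openai_embed uses it for client-side truncation.
    if max_token_size is not None and "max_token_size" not in kwargs:
        if "max_token_size" in inspect.signature(func).parameters:
            kwargs["max_token_size"] = max_token_size
    return await func(texts, **kwargs)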