Fix tiktoken cache env var and support encoding names

- Set cache env var before import
- Support raw encoding names
- Add cl100k_base to default list
- Improve cache path resolution
This commit is contained in:
yangdx
2025-12-22 20:06:22 +08:00
parent 2678005448
commit 9c9dfcd488

View File

@@ -10,16 +10,45 @@ import sys
from pathlib import Path
# Known tiktoken encoding names (not model names)
# These need to be loaded with tiktoken.get_encoding() instead of tiktoken.encoding_for_model()
TIKTOKEN_ENCODING_NAMES = {"cl100k_base", "p50k_base", "r50k_base", "o200k_base"}
def download_tiktoken_cache(cache_dir: str = None, models: list = None):
"""Download tiktoken models to local cache
Args:
cache_dir: Directory to store the cache files. If None, uses default location.
models: List of model names to download. If None, downloads common models.
cache_dir: Directory to store the cache files. If None, uses tiktoken's default location.
models: List of model names or encoding names to download. If None, downloads common ones.
Returns:
Tuple of (success_count, failed_models)
Tuple of (success_count, failed_models, actual_cache_dir)
"""
# If user specified a cache directory, set it BEFORE importing tiktoken
# tiktoken reads TIKTOKEN_CACHE_DIR at import time
user_specified_cache = cache_dir is not None
if user_specified_cache:
cache_dir = os.path.abspath(cache_dir)
os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir
cache_path = Path(cache_dir)
cache_path.mkdir(parents=True, exist_ok=True)
print(f"Using specified cache directory: {cache_dir}")
else:
# Check if TIKTOKEN_CACHE_DIR is already set in environment
env_cache_dir = os.environ.get("TIKTOKEN_CACHE_DIR")
if env_cache_dir:
cache_dir = env_cache_dir
print(f"Using TIKTOKEN_CACHE_DIR from environment: {cache_dir}")
else:
# Use tiktoken's default location (tempdir/data-gym-cache)
import tempfile
cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache")
print(f"Using tiktoken default cache directory: {cache_dir}")
# Now import tiktoken (it will use the cache directory we determined)
try:
import tiktoken
except ImportError:
@@ -27,19 +56,6 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
print("Install with: pip install tiktoken")
sys.exit(1)
# Set cache directory if provided
if cache_dir:
cache_dir = os.path.abspath(cache_dir)
os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir
cache_path = Path(cache_dir)
cache_path.mkdir(parents=True, exist_ok=True)
print(f"Using cache directory: {cache_dir}")
else:
cache_dir = os.environ.get(
"TIKTOKEN_CACHE_DIR", str(Path.home() / ".tiktoken_cache")
)
print(f"Using default cache directory: {cache_dir}")
# Common models used by LightRAG and OpenAI
if models is None:
models = [
@@ -50,6 +66,7 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
"text-embedding-ada-002", # Legacy embedding model
"text-embedding-3-small", # Small embedding model
"text-embedding-3-large", # Large embedding model
"cl100k_base", # Default encoding for LightRAG
]
print(f"\nDownloading {len(models)} tiktoken models...")
@@ -61,13 +78,17 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
for i, model in enumerate(models, 1):
try:
print(f"[{i}/{len(models)}] Downloading {model}...", end=" ", flush=True)
encoding = tiktoken.encoding_for_model(model)
# Use get_encoding for encoding names, encoding_for_model for model names
if model in TIKTOKEN_ENCODING_NAMES:
encoding = tiktoken.get_encoding(model)
else:
encoding = tiktoken.encoding_for_model(model)
# Trigger download by encoding a test string
encoding.encode("test")
print("✓ Done")
success_count += 1
except KeyError as e:
print(f"✗ Failed: Unknown model '{model}'")
print(f"✗ Failed: Unknown model or encoding '{model}'")
failed_models.append((model, str(e)))
except Exception as e:
print(f"✗ Failed: {e}")