Fix tiktoken cache env var and support encoding names
- Set cache env var before import
- Support raw encoding names
- Add cl100k_base to default list
- Improve cache path resolution
This commit is contained in:
@@ -10,16 +10,45 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Known tiktoken encoding names (not model names)
|
||||
# These need to be loaded with tiktoken.get_encoding() instead of tiktoken.encoding_for_model()
|
||||
TIKTOKEN_ENCODING_NAMES = {"cl100k_base", "p50k_base", "r50k_base", "o200k_base"}
|
||||
|
||||
|
||||
def download_tiktoken_cache(cache_dir: str = None, models: list = None):
|
||||
"""Download tiktoken models to local cache
|
||||
|
||||
Args:
|
||||
cache_dir: Directory to store the cache files. If None, uses default location.
|
||||
models: List of model names to download. If None, downloads common models.
|
||||
cache_dir: Directory to store the cache files. If None, uses tiktoken's default location.
|
||||
models: List of model names or encoding names to download. If None, downloads common ones.
|
||||
|
||||
Returns:
|
||||
Tuple of (success_count, failed_models)
|
||||
Tuple of (success_count, failed_models, actual_cache_dir)
|
||||
"""
|
||||
# If user specified a cache directory, set it BEFORE importing tiktoken
|
||||
# tiktoken looks up TIKTOKEN_CACHE_DIR when it loads encoding files, so the variable must be set before any encoding is fetched
|
||||
user_specified_cache = cache_dir is not None
|
||||
|
||||
if user_specified_cache:
|
||||
cache_dir = os.path.abspath(cache_dir)
|
||||
os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir
|
||||
cache_path = Path(cache_dir)
|
||||
cache_path.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Using specified cache directory: {cache_dir}")
|
||||
else:
|
||||
# Check if TIKTOKEN_CACHE_DIR is already set in environment
|
||||
env_cache_dir = os.environ.get("TIKTOKEN_CACHE_DIR")
|
||||
if env_cache_dir:
|
||||
cache_dir = env_cache_dir
|
||||
print(f"Using TIKTOKEN_CACHE_DIR from environment: {cache_dir}")
|
||||
else:
|
||||
# Use tiktoken's default location (tempdir/data-gym-cache)
|
||||
import tempfile
|
||||
|
||||
cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache")
|
||||
print(f"Using tiktoken default cache directory: {cache_dir}")
|
||||
|
||||
# Now import tiktoken (it will use the cache directory we determined)
|
||||
try:
|
||||
import tiktoken
|
||||
except ImportError:
|
||||
@@ -27,19 +56,6 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
|
||||
print("Install with: pip install tiktoken")
|
||||
sys.exit(1)
|
||||
|
||||
# Set cache directory if provided
|
||||
if cache_dir:
|
||||
cache_dir = os.path.abspath(cache_dir)
|
||||
os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir
|
||||
cache_path = Path(cache_dir)
|
||||
cache_path.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Using cache directory: {cache_dir}")
|
||||
else:
|
||||
cache_dir = os.environ.get(
|
||||
"TIKTOKEN_CACHE_DIR", str(Path.home() / ".tiktoken_cache")
|
||||
)
|
||||
print(f"Using default cache directory: {cache_dir}")
|
||||
|
||||
# Common models used by LightRAG and OpenAI
|
||||
if models is None:
|
||||
models = [
|
||||
@@ -50,6 +66,7 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
|
||||
"text-embedding-ada-002", # Legacy embedding model
|
||||
"text-embedding-3-small", # Small embedding model
|
||||
"text-embedding-3-large", # Large embedding model
|
||||
"cl100k_base", # Default encoding for LightRAG
|
||||
]
|
||||
|
||||
print(f"\nDownloading {len(models)} tiktoken models...")
|
||||
@@ -61,13 +78,17 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
|
||||
for i, model in enumerate(models, 1):
|
||||
try:
|
||||
print(f"[{i}/{len(models)}] Downloading {model}...", end=" ", flush=True)
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
# Use get_encoding for encoding names, encoding_for_model for model names
|
||||
if model in TIKTOKEN_ENCODING_NAMES:
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
else:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
# Trigger download by encoding a test string
|
||||
encoding.encode("test")
|
||||
print("✓ Done")
|
||||
success_count += 1
|
||||
except KeyError as e:
|
||||
print(f"✗ Failed: Unknown model '{model}'")
|
||||
print(f"✗ Failed: Unknown model or encoding '{model}'")
|
||||
failed_models.append((model, str(e)))
|
||||
except Exception as e:
|
||||
print(f"✗ Failed: {e}")
|
||||
|
||||
Reference in New Issue
Block a user