From 9c9dfcd48865cd9c3809b479f417a117d6fec979 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 22 Dec 2025 20:06:22 +0800 Subject: [PATCH] Fix tiktoken cache env var and support encoding names - Set cache env var before import - Support raw encoding names - Add cl100k_base to default list - Improve cache path resolution --- lightrag/tools/download_cache.py | 57 ++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/lightrag/tools/download_cache.py b/lightrag/tools/download_cache.py index 43a21f0b..c16473c8 100644 --- a/lightrag/tools/download_cache.py +++ b/lightrag/tools/download_cache.py @@ -10,16 +10,45 @@ import sys from pathlib import Path +# Known tiktoken encoding names (not model names) +# These need to be loaded with tiktoken.get_encoding() instead of tiktoken.encoding_for_model() +TIKTOKEN_ENCODING_NAMES = {"cl100k_base", "p50k_base", "r50k_base", "o200k_base"} + + def download_tiktoken_cache(cache_dir: str = None, models: list = None): """Download tiktoken models to local cache Args: - cache_dir: Directory to store the cache files. If None, uses default location. - models: List of model names to download. If None, downloads common models. + cache_dir: Directory to store the cache files. If None, uses tiktoken's default location. + models: List of model names or encoding names to download. If None, downloads common ones. Returns: - Tuple of (success_count, failed_models) + Tuple of (success_count, failed_models, actual_cache_dir) """ + # If user specified a cache directory, set it BEFORE importing tiktoken + # tiktoken reads TIKTOKEN_CACHE_DIR at import time + user_specified_cache = cache_dir is not None + + if user_specified_cache: + cache_dir = os.path.abspath(cache_dir) + os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir + cache_path = Path(cache_dir) + cache_path.mkdir(parents=True, exist_ok=True) + print(f"Using specified cache directory: {cache_dir}") + else: + # Check if TIKTOKEN_CACHE_DIR is already set in environment + env_cache_dir = os.environ.get("TIKTOKEN_CACHE_DIR") + if env_cache_dir: + cache_dir = env_cache_dir + print(f"Using TIKTOKEN_CACHE_DIR from environment: {cache_dir}") + else: + # Use tiktoken's default location (tempdir/data-gym-cache) + import tempfile + + cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache") + print(f"Using tiktoken default cache directory: {cache_dir}") + + # Now import tiktoken (it will use the cache directory we determined) try: import tiktoken except ImportError: @@ -27,19 +56,6 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None): print("Install with: pip install tiktoken") sys.exit(1) - # Set cache directory if provided - if cache_dir: - cache_dir = os.path.abspath(cache_dir) - os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir - cache_path = Path(cache_dir) - cache_path.mkdir(parents=True, exist_ok=True) - print(f"Using cache directory: {cache_dir}") - else: - cache_dir = os.environ.get( - "TIKTOKEN_CACHE_DIR", str(Path.home() / ".tiktoken_cache") - ) - print(f"Using default cache directory: {cache_dir}") - # Common models used by LightRAG and OpenAI if models is None: models = [ @@ -50,6 +66,7 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None): "text-embedding-ada-002", # Legacy embedding model "text-embedding-3-small", # Small embedding model "text-embedding-3-large", # Large embedding model + "cl100k_base", # Default encoding for LightRAG ] print(f"\nDownloading {len(models)} tiktoken models...") @@ -61,13 +78,17 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None): for i, model in enumerate(models, 1): try: print(f"[{i}/{len(models)}] Downloading {model}...", end=" ", flush=True) - encoding = tiktoken.encoding_for_model(model) + # Use get_encoding for encoding names, encoding_for_model for model names + if model in TIKTOKEN_ENCODING_NAMES: + encoding = tiktoken.get_encoding(model) + else: + encoding = tiktoken.encoding_for_model(model) # Trigger download by encoding a test string encoding.encode("test") print("✓ Done") success_count += 1 except KeyError as e: - print(f"✗ Failed: Unknown model '{model}'") + print(f"✗ Failed: Unknown model or encoding '{model}'") failed_models.append((model, str(e))) except Exception as e: print(f"✗ Failed: {e}")