From 0b3d31507e798e9b27aa30443b09ffc2c262275d Mon Sep 17 00:00:00 2001 From: Humphry Date: Mon, 20 Oct 2025 13:17:16 +0300 Subject: [PATCH] extended to use Gemini, switched to gemini-flash-latest --- README.md | 2 + env.example | 10 +- lightrag/api/config.py | 15 ++ lightrag/api/lightrag_server.py | 46 +++++ lightrag/llm/binding_options.py | 57 +++++- lightrag/llm/gemini.py | 297 ++++++++++++++++++++++++++++++++ lightrag/utils.py | 2 +- pyproject.toml | 3 + requirements-offline-llm.txt | 1 + requirements-offline.txt | 1 + 10 files changed, 429 insertions(+), 5 deletions(-) create mode 100644 lightrag/llm/gemini.py diff --git a/README.md b/README.md index 9ba35c4b..fa79b53b 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,8 @@ cp env.example .env docker compose up ``` +> Tip: When targeting Google Gemini, set `LLM_BINDING=gemini`, choose a model such as `LLM_MODEL=gemini-flash-latest`, and provide your Gemini key via `LLM_BINDING_API_KEY` (or `GEMINI_API_KEY`). The server supports this binding out of the box. + > Historical versions of LightRAG docker images can be found here: [LightRAG Docker Images]( https://github.com/HKUDS/LightRAG/pkgs/container/lightrag) ### Install LightRAG Core diff --git a/env.example b/env.example index a7ce53a4..2c7faded 100644 --- a/env.example +++ b/env.example @@ -154,7 +154,7 @@ MAX_PARALLEL_INSERT=2 ########################################################### ### LLM Configuration -### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock +### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock, gemini ########################################################### ### LLM request timeout setting for all llm (0 means no timeout for Ollma) # LLM_TIMEOUT=180 @@ -174,6 +174,14 @@ LLM_BINDING_API_KEY=your_api_key # LLM_BINDING_API_KEY=your_api_key # LLM_BINDING=openai +### Gemini example +# LLM_BINDING=gemini +# LLM_MODEL=gemini-flash-latest +# LLM_BINDING_HOST=https://generativelanguage.googleapis.com +# LLM_BINDING_API_KEY=your_gemini_api_key +# GEMINI_LLM_MAX_OUTPUT_TOKENS=8192 +# GEMINI_LLM_TEMPERATURE=0.7 + ### OpenAI Compatible API Specific Parameters ### Increased temperature values may mitigate infinite inference loops in certain LLM, such as Qwen3-30B. 
# OPENAI_LLM_TEMPERATURE=0.9 diff --git a/lightrag/api/config.py b/lightrag/api/config.py index de569f47..ff5e65b1 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -8,6 +8,7 @@ import logging from dotenv import load_dotenv from lightrag.utils import get_env_value from lightrag.llm.binding_options import ( + GeminiLLMOptions, OllamaEmbeddingOptions, OllamaLLMOptions, OpenAILLMOptions, @@ -63,6 +64,9 @@ def get_default_host(binding_type: str) -> str: "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"), "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"), "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"), + "gemini": os.getenv( + "LLM_BINDING_HOST", "https://generativelanguage.googleapis.com" + ), } return default_hosts.get( binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434") @@ -226,6 +230,7 @@ def parse_args() -> argparse.Namespace: "openai-ollama", "azure_openai", "aws_bedrock", + "gemini", ], help="LLM binding type (default: from env or ollama)", ) @@ -281,6 +286,16 @@ def parse_args() -> argparse.Namespace: elif os.environ.get("LLM_BINDING") in ["openai", "azure_openai"]: OpenAILLMOptions.add_args(parser) + if "--llm-binding" in sys.argv: + try: + idx = sys.argv.index("--llm-binding") + if idx + 1 < len(sys.argv) and sys.argv[idx + 1] == "gemini": + GeminiLLMOptions.add_args(parser) + except IndexError: + pass + elif os.environ.get("LLM_BINDING") == "gemini": + GeminiLLMOptions.add_args(parser) + args = parser.parse_args() # convert relative path to absolute path diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 19bc549a..89feca32 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -104,6 +104,7 @@ class LLMConfigCache: # Initialize configurations based on binding conditions self.openai_llm_options = None + self.gemini_llm_options = None self.ollama_llm_options = None self.ollama_embedding_options = None @@ -114,6 +115,12 @@ class LLMConfigCache: self.openai_llm_options = OpenAILLMOptions.options_dict(args) logger.info(f"OpenAI LLM Options: {self.openai_llm_options}") + if args.llm_binding == "gemini": + from lightrag.llm.binding_options import GeminiLLMOptions + + self.gemini_llm_options = GeminiLLMOptions.options_dict(args) + logger.info(f"Gemini LLM Options: {self.gemini_llm_options}") + # Only initialize and log Ollama LLM options when using Ollama LLM binding if args.llm_binding == "ollama": try: @@ -282,6 +289,7 @@ def create_app(args): "openai", "azure_openai", "aws_bedrock", + "gemini", ]: raise Exception("llm binding not supported") @@ -500,6 +508,42 @@ def create_app(args): return optimized_azure_openai_model_complete + def create_optimized_gemini_llm_func( + config_cache: LLMConfigCache, args + ): + """Create optimized Gemini LLM function with cached configuration""" + + async def optimized_gemini_model_complete( + prompt, + system_prompt=None, + history_messages=None, + keyword_extraction=False, + **kwargs, + ) -> str: + from lightrag.llm.gemini import gemini_complete_if_cache + + if history_messages is None: + history_messages = [] + + if ( + config_cache.gemini_llm_options is not None + and "generation_config" not in kwargs + ): + kwargs["generation_config"] = dict(config_cache.gemini_llm_options) + + return await gemini_complete_if_cache( + args.llm_model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=args.llm_binding_api_key, + base_url=args.llm_binding_host, + 
keyword_extraction=keyword_extraction, + **kwargs, + ) + + return optimized_gemini_model_complete + def create_llm_model_func(binding: str): """ Create LLM model function based on binding type. @@ -521,6 +565,8 @@ def create_app(args): return create_optimized_azure_openai_llm_func( config_cache, args, llm_timeout ) + elif binding == "gemini": + return create_optimized_gemini_llm_func(config_cache, args) else: # openai and compatible # Use optimized function with pre-processed configuration return create_optimized_openai_llm_func(config_cache, args, llm_timeout) diff --git a/lightrag/llm/binding_options.py b/lightrag/llm/binding_options.py index f17ba0f8..e2f94649 100644 --- a/lightrag/llm/binding_options.py +++ b/lightrag/llm/binding_options.py @@ -9,12 +9,26 @@ from argparse import ArgumentParser, Namespace import argparse import json from dataclasses import asdict, dataclass, field -from typing import Any, ClassVar, List +from typing import Any, ClassVar, List, get_args, get_origin from lightrag.utils import get_env_value from lightrag.constants import DEFAULT_TEMPERATURE +def _resolve_optional_type(field_type: Any) -> Any: + """Return the concrete type for Optional/Union annotations.""" + origin = get_origin(field_type) + if origin in (list, dict, tuple): + return field_type + + args = get_args(field_type) + if args: + non_none_args = [arg for arg in args if arg is not type(None)] + if len(non_none_args) == 1: + return non_none_args[0] + return field_type + + # ============================================================================= # BindingOptions Base Class # ============================================================================= @@ -177,9 +191,13 @@ class BindingOptions: help=arg_item["help"], ) else: + resolved_type = arg_item["type"] + if resolved_type is not None: + resolved_type = _resolve_optional_type(resolved_type) + group.add_argument( f"--{arg_item['argname']}", - type=arg_item["type"], + type=resolved_type, default=get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS), help=arg_item["help"], ) @@ -210,7 +228,7 @@ class BindingOptions: argdef = { "argname": f"{args_prefix}-{field.name}", "env_name": f"{env_var_prefix}{field.name.upper()}", - "type": field.type, + "type": _resolve_optional_type(field.type), "default": default_value, "help": f"{cls._binding_name} -- " + help.get(field.name, ""), } @@ -454,6 +472,39 @@ class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions): _binding_name: ClassVar[str] = "ollama_llm" +@dataclass +class GeminiLLMOptions(BindingOptions): + """Options for Google Gemini models.""" + + _binding_name: ClassVar[str] = "gemini_llm" + + temperature: float = DEFAULT_TEMPERATURE + top_p: float = 0.95 + top_k: int = 40 + max_output_tokens: int | None = None + candidate_count: int = 1 + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + stop_sequences: List[str] = field(default_factory=list) + response_mime_type: str | None = None + safety_settings: dict | None = None + system_instruction: str | None = None + + _help: ClassVar[dict[str, str]] = { + "temperature": "Controls randomness (0.0-2.0, higher = more creative)", + "top_p": "Nucleus sampling parameter (0.0-1.0)", + "top_k": "Limits sampling to the top K tokens (1 disables the limit)", + "max_output_tokens": "Maximum tokens generated in the response", + "candidate_count": "Number of candidates returned per request", + "presence_penalty": "Penalty for token presence (-2.0 to 2.0)", + "frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)", + 
"stop_sequences": 'Stop sequences (JSON array of strings, e.g., \'["END"]\')', + "response_mime_type": "Desired MIME type for the response (e.g., application/json)", + "safety_settings": "JSON object with Gemini safety settings overrides", + "system_instruction": "Default system instruction applied to every request", + } + + # ============================================================================= # Binding Options for OpenAI # ============================================================================= diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py new file mode 100644 index 00000000..14a1b238 --- /dev/null +++ b/lightrag/llm/gemini.py @@ -0,0 +1,297 @@ +""" +Gemini LLM binding for LightRAG. + +This module provides asynchronous helpers that adapt Google's Gemini models +to the same interface used by the rest of the LightRAG LLM bindings. The +implementation mirrors the OpenAI helpers while relying on the official +``google-genai`` client under the hood. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from collections.abc import AsyncIterator +from functools import lru_cache +from typing import Any + +from lightrag.utils import logger, remove_think_tags, safe_unicode_decode + +import pipmaster as pm + +# Install the Google Gemini client on demand +if not pm.is_installed("google-genai"): + pm.install("google-genai") + +from google import genai # type: ignore +from google.genai import types # type: ignore + +DEFAULT_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com" + +LOG = logging.getLogger(__name__) + + +@lru_cache(maxsize=8) +def _get_gemini_client(api_key: str, base_url: str | None) -> genai.Client: + """ + Create (or fetch cached) Gemini client. + + Args: + api_key: Google Gemini API key. + base_url: Optional custom API endpoint. + + Returns: + genai.Client: Configured Gemini client instance. + """ + client_kwargs: dict[str, Any] = {"api_key": api_key} + + if base_url and base_url != DEFAULT_GEMINI_ENDPOINT: + try: + client_kwargs["http_options"] = types.HttpOptions(api_endpoint=base_url) + except Exception as exc: # pragma: no cover - defensive + LOG.warning("Failed to apply custom Gemini endpoint %s: %s", base_url, exc) + + try: + return genai.Client(**client_kwargs) + except TypeError: + # Older google-genai releases don't accept http_options; retry without it. + client_kwargs.pop("http_options", None) + return genai.Client(**client_kwargs) + + +def _ensure_api_key(api_key: str | None) -> str: + key = api_key or os.getenv("LLM_BINDING_API_KEY") or os.getenv("GEMINI_API_KEY") + if not key: + raise ValueError( + "Gemini API key not provided. " + "Set LLM_BINDING_API_KEY or GEMINI_API_KEY in the environment." 
+ ) + return key + + +def _build_generation_config( + base_config: dict[str, Any] | None, + system_prompt: str | None, + keyword_extraction: bool, +) -> types.GenerateContentConfig | None: + config_data = dict(base_config or {}) + + if system_prompt: + if config_data.get("system_instruction"): + config_data["system_instruction"] = ( + f"{config_data['system_instruction']}\n{system_prompt}" + ) + else: + config_data["system_instruction"] = system_prompt + + if keyword_extraction and not config_data.get("response_mime_type"): + config_data["response_mime_type"] = "application/json" + + # Remove entries that are explicitly set to None to avoid type errors + sanitized = { + key: value + for key, value in config_data.items() + if value is not None and value != "" + } + + if not sanitized: + return None + + return types.GenerateContentConfig(**sanitized) + + +def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> str: + if not history_messages: + return "" + + history_lines: list[str] = [] + for message in history_messages: + role = message.get("role", "user") + content = message.get("content", "") + history_lines.append(f"[{role}] {content}") + + return "\n".join(history_lines) + + +def _extract_response_text(response: Any) -> str: + if getattr(response, "text", None): + return response.text + + candidates = getattr(response, "candidates", None) + if not candidates: + return "" + + parts: list[str] = [] + for candidate in candidates: + if not getattr(candidate, "content", None): + continue + for part in getattr(candidate.content, "parts", []): + text = getattr(part, "text", None) + if text: + parts.append(text) + + return "\n".join(parts) + + +async def gemini_complete_if_cache( + model: str, + prompt: str, + system_prompt: str | None = None, + history_messages: list[dict[str, Any]] | None = None, + *, + api_key: str | None = None, + base_url: str | None = None, + generation_config: dict[str, Any] | None = None, + keyword_extraction: bool = False, + token_tracker: Any | None = None, + hashing_kv: Any | None = None, # noqa: ARG001 - present for interface parity + stream: bool | None = None, + enable_cot: bool = False, # noqa: ARG001 - not supported by Gemini currently + timeout: float | None = None, # noqa: ARG001 - handled by caller if needed + **_: Any, +) -> str | AsyncIterator[str]: + loop = asyncio.get_running_loop() + + key = _ensure_api_key(api_key) + client = _get_gemini_client(key, base_url) + + history_block = _format_history_messages(history_messages) + prompt_sections = [] + if history_block: + prompt_sections.append(history_block) + prompt_sections.append(f"[user] {prompt}") + combined_prompt = "\n".join(prompt_sections) + + config_obj = _build_generation_config( + generation_config, + system_prompt=system_prompt, + keyword_extraction=keyword_extraction, + ) + + request_kwargs: dict[str, Any] = { + "model": model, + "contents": [combined_prompt], + } + if config_obj is not None: + request_kwargs["config"] = config_obj + + def _call_model(): + return client.models.generate_content(**request_kwargs) + + if stream: + queue: asyncio.Queue[Any] = asyncio.Queue() + usage_container: dict[str, Any] = {} + + def _stream_model() -> None: + try: + stream_kwargs = dict(request_kwargs) + stream_iterator = client.models.generate_content_stream(**stream_kwargs) + for chunk in stream_iterator: + usage = getattr(chunk, "usage_metadata", None) + if usage is not None: + usage_container["usage"] = usage + text_piece = getattr(chunk, "text", None) or _extract_response_text(chunk) 
+ if text_piece: + loop.call_soon_threadsafe(queue.put_nowait, text_piece) + loop.call_soon_threadsafe(queue.put_nowait, None) + except Exception as exc: # pragma: no cover - surface runtime issues + loop.call_soon_threadsafe(queue.put_nowait, exc) + + loop.run_in_executor(None, _stream_model) + + async def _async_stream() -> AsyncIterator[str]: + accumulated = "" + emitted = "" + try: + while True: + item = await queue.get() + if item is None: + break + if isinstance(item, Exception): + raise item + + chunk_text = str(item) + if "\\u" in chunk_text: + chunk_text = safe_unicode_decode(chunk_text.encode("utf-8")) + + accumulated += chunk_text + sanitized = remove_think_tags(accumulated) + if sanitized.startswith(emitted): + delta = sanitized[len(emitted) :] + else: + delta = sanitized + emitted = sanitized + + if delta: + yield delta + finally: + usage = usage_container.get("usage") + if token_tracker and usage: + token_tracker.add_usage( + { + "prompt_tokens": getattr(usage, "prompt_token_count", 0), + "completion_tokens": getattr( + usage, "candidates_token_count", 0 + ), + "total_tokens": getattr(usage, "total_token_count", 0), + } + ) + + return _async_stream() + + response = await asyncio.to_thread(_call_model) + + text = _extract_response_text(response) + if not text: + raise RuntimeError("Gemini response did not contain any text content.") + + if "\\u" in text: + text = safe_unicode_decode(text.encode("utf-8")) + + text = remove_think_tags(text) + + usage = getattr(response, "usage_metadata", None) + if token_tracker and usage: + token_tracker.add_usage( + { + "prompt_tokens": getattr(usage, "prompt_token_count", 0), + "completion_tokens": getattr(usage, "candidates_token_count", 0), + "total_tokens": getattr(usage, "total_token_count", 0), + } + ) + + logger.debug("Gemini response length: %s", len(text)) + return text + + +async def gemini_model_complete( + prompt: str, + system_prompt: str | None = None, + history_messages: list[dict[str, Any]] | None = None, + keyword_extraction: bool = False, + **kwargs: Any, +) -> str | AsyncIterator[str]: + hashing_kv = kwargs.get("hashing_kv") + model_name = None + if hashing_kv is not None: + model_name = hashing_kv.global_config.get("llm_model_name") + if model_name is None: + model_name = kwargs.pop("model_name", None) + if model_name is None: + raise ValueError("Gemini model name not provided in configuration.") + + return await gemini_complete_if_cache( + model_name, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + keyword_extraction=keyword_extraction, + **kwargs, + ) + + +__all__ = [ + "gemini_complete_if_cache", + "gemini_model_complete", +] diff --git a/lightrag/utils.py b/lightrag/utils.py index 94f1ff27..b85edc8e 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1783,7 +1783,7 @@ def normalize_extracted_info(name: str, remove_inner_quotes=False) -> str: - Filter out short numeric-only text (length < 3 and only digits/dots) - remove_inner_quotes = True remove Chinese quotes - remove English queotes in and around chinese + remove English quotes in and around chinese Convert non-breaking spaces to regular spaces Convert narrow non-breaking spaces after non-digits to regular spaces diff --git a/pyproject.toml b/pyproject.toml index b35c09be..a4f16ddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "aiohttp", "configparser", "future", + "google-genai>=1.0.0,<2.0.0", "json_repair", "nano-vectordb", "networkx", @@ -59,6 +60,7 @@ api = [ "tenacity", "tiktoken", 
"xlsxwriter>=3.1.0", + "google-genai>=1.0.0,<2.0.0", # API-specific dependencies "aiofiles", "ascii_colors", @@ -105,6 +107,7 @@ offline-llm = [ "aioboto3>=12.0.0,<16.0.0", "voyageai>=0.2.0,<1.0.0", "llama-index>=0.9.0,<1.0.0", + "google-genai>=1.0.0,<2.0.0", ] offline = [ diff --git a/requirements-offline-llm.txt b/requirements-offline-llm.txt index fe3fc747..441abc6e 100644 --- a/requirements-offline-llm.txt +++ b/requirements-offline-llm.txt @@ -13,5 +13,6 @@ anthropic>=0.18.0,<1.0.0 llama-index>=0.9.0,<1.0.0 ollama>=0.1.0,<1.0.0 openai>=1.0.0,<2.0.0 +google-genai>=1.0.0,<2.0.0 voyageai>=0.2.0,<1.0.0 zhipuai>=2.0.0,<3.0.0 diff --git a/requirements-offline.txt b/requirements-offline.txt index f3616e6b..d6943b11 100644 --- a/requirements-offline.txt +++ b/requirements-offline.txt @@ -19,6 +19,7 @@ llama-index>=0.9.0,<1.0.0 neo4j>=5.0.0,<7.0.0 ollama>=0.1.0,<1.0.0 openai>=1.0.0,<2.0.0 +google-genai>=1.0.0,<2.0.0 openpyxl>=3.0.0,<4.0.0 pymilvus>=2.6.2,<3.0.0 pymongo>=4.0.0,<5.0.0