Merge pull request #2531 from danielaskdd/gemini-aio

refactor(gemini): Migrate Gemini LLM to native async Google GenAI client
Daniel.y
2025-12-23 21:13:33 +08:00
committed by GitHub
3 changed files with 57 additions and 70 deletions
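
The heart of the change: calls that previously went through the synchronous `client.models` surface (offloaded to a worker thread or a queue) now go through the SDK's native async surface, `client.aio.models`. A minimal before/after sketch of that pattern outside LightRAG — the model name and prompt are placeholders, and `genai.Client()` is assumed to pick up the API key from the environment:

```python
import asyncio

from google import genai


async def demo() -> None:
    client = genai.Client()  # assumes GEMINI_API_KEY / GOOGLE_API_KEY is set in the environment

    # Before: the synchronous client had to be offloaded to a worker thread.
    response = await asyncio.to_thread(
        client.models.generate_content,
        model="gemini-2.0-flash",  # placeholder model name
        contents="Say hello.",
    )

    # After: the native async client is awaited directly, no executor involved.
    response = await client.aio.models.generate_content(
        model="gemini-2.0-flash",
        contents="Say hello.",
    )
    print(response.text)


if __name__ == "__main__":
    asyncio.run(demo())
```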


@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
-__version__ = "1.4.9.10"
+__version__ = "1.4.9.11"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"


@@ -1 +1 @@
-__api_version__ = "0263"
+__api_version__ = "0264"


@@ -9,7 +9,6 @@ implementation mirrors the OpenAI helpers while relying on the official
 from __future__ import annotations
-import asyncio
 import os
 from collections.abc import AsyncIterator
 from functools import lru_cache
@@ -267,8 +266,6 @@ async def gemini_complete_if_cache(
         RuntimeError: If the response from Gemini is empty.
         ValueError: If API key is not provided or configured.
     """
-    loop = asyncio.get_running_loop()
     key = _ensure_api_key(api_key)
     # Convert timeout from seconds to milliseconds for Gemini API
     timeout_ms = timeout * 1000 if timeout else None
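
For reference, the seconds-to-milliseconds conversion above exists because the google-genai SDK expresses HTTP timeouts in milliseconds. A hedged sketch of how such a value is typically wired into the client via `types.HttpOptions` — the helper below is hypothetical, and this hunk does not show how LightRAG actually passes `timeout_ms` downstream:

```python
from google import genai
from google.genai import types


def build_client(api_key: str, timeout: float | None = None) -> genai.Client:
    """Hypothetical helper: construct a client with an optional request timeout."""
    # HttpOptions.timeout is interpreted in milliseconds, hence the conversion.
    timeout_ms = int(timeout * 1000) if timeout else None
    http_options = types.HttpOptions(timeout=timeout_ms) if timeout_ms else None
    return genai.Client(api_key=api_key, http_options=http_options)
```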
@@ -294,26 +291,25 @@ async def gemini_complete_if_cache(
     if config_obj is not None:
         request_kwargs["config"] = config_obj
-    def _call_model():
-        return client.models.generate_content(**request_kwargs)
     if stream:
-        queue: asyncio.Queue[Any] = asyncio.Queue()
-        usage_container: dict[str, Any] = {}
-        def _stream_model() -> None:
+        async def _async_stream() -> AsyncIterator[str]:
             # COT state tracking for streaming
             cot_active = False
             cot_started = False
             initial_content_seen = False
+            usage_metadata = None
             try:
-                stream_kwargs = dict(request_kwargs)
-                stream_iterator = client.models.generate_content_stream(**stream_kwargs)
-                for chunk in stream_iterator:
+                # Use native async streaming from genai SDK
+                # Note: generate_content_stream returns Awaitable[AsyncIterator], need to await first
+                stream = await client.aio.models.generate_content_stream(
+                    **request_kwargs
+                )
+                async for chunk in stream:
                     usage = getattr(chunk, "usage_metadata", None)
                     if usage is not None:
-                        usage_container["usage"] = usage
+                        usage_metadata = usage
                     # Extract both regular and thought content
                     regular_text, thought_text = _extract_response_text(
@@ -328,78 +324,74 @@ async def gemini_complete_if_cache(
                             # Close COT section if it was active
                             if cot_active:
-                                loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                                yield "</think>"
                                 cot_active = False
-                            # Send regular content
-                            loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+                            # Process and yield regular content
+                            if "\\u" in regular_text:
+                                regular_text = safe_unicode_decode(
+                                    regular_text.encode("utf-8")
+                                )
+                            yield regular_text
                         # Process thought content
                         if thought_text:
                             if not initial_content_seen and not cot_started:
                                 # Start COT section
-                                loop.call_soon_threadsafe(queue.put_nowait, "<think>")
+                                yield "<think>"
                                 cot_active = True
                                 cot_started = True
-                            # Send thought content if COT is active
+                            # Yield thought content if COT is active
                             if cot_active:
-                                loop.call_soon_threadsafe(
-                                    queue.put_nowait, thought_text
-                                )
+                                if "\\u" in thought_text:
+                                    thought_text = safe_unicode_decode(
+                                        thought_text.encode("utf-8")
+                                    )
+                                yield thought_text
                     else:
-                        # COT disabled - only send regular content
+                        # COT disabled - only yield regular content
                         if regular_text:
-                            loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+                            if "\\u" in regular_text:
+                                regular_text = safe_unicode_decode(
+                                    regular_text.encode("utf-8")
+                                )
+                            yield regular_text
                 # Ensure COT is properly closed if still active
                 if cot_active:
-                    loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                    yield "</think>"
                     cot_active = False
-                loop.call_soon_threadsafe(queue.put_nowait, None)
-            except Exception as exc:  # pragma: no cover - surface runtime issues
-                # Try to close COT tag before reporting error
+            except Exception as exc:
+                # Try to close COT tag before re-raising
                 if cot_active:
                     try:
-                        loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                        yield "</think>"
                     except Exception:
                         pass
-                loop.call_soon_threadsafe(queue.put_nowait, exc)
-        loop.run_in_executor(None, _stream_model)
-        async def _async_stream() -> AsyncIterator[str]:
-            try:
-                while True:
-                    item = await queue.get()
-                    if item is None:
-                        break
-                    if isinstance(item, Exception):
-                        raise item
-                    chunk_text = str(item)
-                    if "\\u" in chunk_text:
-                        chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
-                    # Yield the chunk directly without filtering
-                    # COT filtering is already handled in _stream_model()
-                    yield chunk_text
+                raise exc
             finally:
-                usage = usage_container.get("usage")
-                if token_tracker and usage:
+                # Track token usage after streaming completes
+                if token_tracker and usage_metadata:
                     token_tracker.add_usage(
                         {
-                            "prompt_tokens": getattr(usage, "prompt_token_count", 0),
-                            "completion_tokens": getattr(
-                                usage, "candidates_token_count", 0
-                            ),
-                            "total_tokens": getattr(usage, "total_token_count", 0),
+                            "prompt_tokens": getattr(
+                                usage_metadata, "prompt_token_count", 0
+                            ),
+                            "completion_tokens": getattr(
+                                usage_metadata, "candidates_token_count", 0
+                            ),
+                            "total_tokens": getattr(
+                                usage_metadata, "total_token_count", 0
+                            ),
                         }
                     )
         return _async_stream()
-    response = await asyncio.to_thread(_call_model)
+    # Non-streaming: use native async client
+    response = await client.aio.models.generate_content(**request_kwargs)
     # Extract both regular text and thought text
     regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
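
As the inline note in the hunk says, `generate_content_stream` on the async client returns an awaitable that resolves to an async iterator, so it is awaited once and then consumed with `async for`. A minimal standalone sketch of that pattern, including the per-chunk `usage_metadata` capture the new code relies on (model name and prompt are placeholders):

```python
from google import genai


async def stream_demo() -> None:
    client = genai.Client()  # assumes the API key is configured in the environment
    # Await once to obtain the async iterator, then iterate natively.
    stream = await client.aio.models.generate_content_stream(
        model="gemini-2.0-flash",  # placeholder model name
        contents="Explain retrieval-augmented generation in one paragraph.",
    )
    usage = None
    async for chunk in stream:
        # Keep the most recent usage metadata, as the patch does.
        if chunk.usage_metadata is not None:
            usage = chunk.usage_metadata
        if chunk.text:
            print(chunk.text, end="", flush=True)
    if usage is not None:
        print(
            "\nprompt:", usage.prompt_token_count,
            "completion:", usage.candidates_token_count,
            "total:", usage.total_token_count,
        )
```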
@@ -543,7 +535,6 @@ async def gemini_embed(
     # Note: max_token_size is received but not used for client-side truncation.
     # Gemini API handles truncation automatically with autoTruncate=True (default).
     _ = max_token_size  # Acknowledge parameter to avoid unused variable warning
-    loop = asyncio.get_running_loop()
     key = _ensure_api_key(api_key)
     # Convert timeout from seconds to milliseconds for Gemini API
@@ -564,19 +555,15 @@ async def gemini_embed(
     # Create config object if we have parameters
     config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None
-    def _call_embed() -> Any:
-        """Call Gemini embedding API in executor thread."""
-        request_kwargs: dict[str, Any] = {
-            "model": model,
-            "contents": texts,
-        }
-        if config_obj is not None:
-            request_kwargs["config"] = config_obj
-        return client.models.embed_content(**request_kwargs)
-    # Execute API call in thread pool
-    response = await loop.run_in_executor(None, _call_embed)
+    request_kwargs: dict[str, Any] = {
+        "model": model,
+        "contents": texts,
+    }
+    if config_obj is not None:
+        request_kwargs["config"] = config_obj
+    # Use native async client for embedding
+    response = await client.aio.models.embed_content(**request_kwargs)
     # Extract embeddings from response
     if not hasattr(response, "embeddings") or not response.embeddings:
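
And the corresponding embedding path, now a single awaited call on the async client. A small sketch under the same assumptions (placeholder model name; LightRAG builds `request_kwargs` and an optional `types.EmbedContentConfig` as shown above):

```python
from google import genai
from google.genai import types


async def embed_demo() -> None:
    client = genai.Client()  # assumes the API key is configured in the environment
    response = await client.aio.models.embed_content(
        model="gemini-embedding-001",  # placeholder model name
        contents=["first passage", "second passage"],
        config=types.EmbedContentConfig(output_dimensionality=768),
    )
    # Each entry in response.embeddings exposes a .values list of floats.
    vectors = [e.values for e in response.embeddings]
    print(len(vectors), "embeddings of dimension", len(vectors[0]))
```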