Merge pull request #2531 from danielaskdd/gemini-aio
refact(gemini): Migrate Gemini LLM to native async Google GenAI client
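Note: the change swaps thread-offloaded calls to the synchronous `client.models` API for direct awaits on the SDK's native async surface, `client.aio.models`. A minimal sketch of the two call patterns (illustrative only; the API key and model name are placeholders, not values from this patch):

import asyncio
from google import genai

async def demo() -> None:
    client = genai.Client(api_key="YOUR_API_KEY")  # placeholder key

    # Before: run the blocking client in a worker thread
    old_style = await asyncio.to_thread(
        client.models.generate_content,
        model="gemini-2.0-flash",  # placeholder model name
        contents="Hello",
    )

    # After: await the native async client directly
    new_style = await client.aio.models.generate_content(
        model="gemini-2.0-flash",
        contents="Hello",
    )
    print(old_style.text, new_style.text)

asyncio.run(demo())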
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

-__version__ = "1.4.9.10"
+__version__ = "1.4.9.11"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
@@ -1 +1 @@
-__api_version__ = "0263"
+__api_version__ = "0264"
@@ -9,7 +9,6 @@ implementation mirrors the OpenAI helpers while relying on the official

 from __future__ import annotations

-import asyncio
 import os
 from collections.abc import AsyncIterator
 from functools import lru_cache
@@ -267,8 +266,6 @@ async def gemini_complete_if_cache(
         RuntimeError: If the response from Gemini is empty.
         ValueError: If API key is not provided or configured.
     """
-    loop = asyncio.get_running_loop()
-
     key = _ensure_api_key(api_key)
     # Convert timeout from seconds to milliseconds for Gemini API
     timeout_ms = timeout * 1000 if timeout else None
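The hunk above keeps the comment about converting the timeout from seconds to milliseconds, since the google-genai client takes millisecond timeouts through `types.HttpOptions`. A sketch of how such a client might be built; the helper name and exact wiring are assumptions for illustration, not code from this patch:

from google import genai
from google.genai import types

def build_client(api_key: str, timeout: float | None = None) -> genai.Client:
    # HttpOptions.timeout is specified in milliseconds, hence the
    # seconds -> milliseconds conversion noted in the diff above.
    timeout_ms = int(timeout * 1000) if timeout else None
    http_options = types.HttpOptions(timeout=timeout_ms) if timeout_ms else None
    return genai.Client(api_key=api_key, http_options=http_options)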
@@ -294,26 +291,25 @@ async def gemini_complete_if_cache(
     if config_obj is not None:
         request_kwargs["config"] = config_obj

-    def _call_model():
-        return client.models.generate_content(**request_kwargs)
-
     if stream:
-        queue: asyncio.Queue[Any] = asyncio.Queue()
-        usage_container: dict[str, Any] = {}

-        def _stream_model() -> None:
+        async def _async_stream() -> AsyncIterator[str]:
             # COT state tracking for streaming
             cot_active = False
             cot_started = False
             initial_content_seen = False
+            usage_metadata = None

             try:
-                stream_kwargs = dict(request_kwargs)
-                stream_iterator = client.models.generate_content_stream(**stream_kwargs)
-                for chunk in stream_iterator:
+                # Use native async streaming from genai SDK
+                # Note: generate_content_stream returns Awaitable[AsyncIterator], need to await first
+                stream = await client.aio.models.generate_content_stream(
+                    **request_kwargs
+                )
+                async for chunk in stream:
                     usage = getattr(chunk, "usage_metadata", None)
                     if usage is not None:
-                        usage_container["usage"] = usage
+                        usage_metadata = usage

                     # Extract both regular and thought content
                     regular_text, thought_text = _extract_response_text(
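As the note in the new code points out, `generate_content_stream` on the async client returns an awaitable that resolves to an async iterator, so it must be awaited before iterating. A minimal, self-contained sketch of that consumption pattern (placeholder key and model name; this is not the LightRAG wrapper itself):

import asyncio
from google import genai

async def stream_demo() -> None:
    client = genai.Client(api_key="YOUR_API_KEY")  # placeholder key
    # Await first, then iterate the resolved stream with `async for`
    stream = await client.aio.models.generate_content_stream(
        model="gemini-2.0-flash",  # placeholder model name
        contents="Tell me a short story.",
    )
    async for chunk in stream:
        if chunk.text:
            print(chunk.text, end="", flush=True)

asyncio.run(stream_demo())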
@@ -328,78 +324,74 @@ async def gemini_complete_if_cache(

                         # Close COT section if it was active
                         if cot_active:
-                            loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                            yield "</think>"
                             cot_active = False

-                        # Send regular content
-                        loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+                        # Process and yield regular content
+                        if "\\u" in regular_text:
+                            regular_text = safe_unicode_decode(
+                                regular_text.encode("utf-8")
+                            )
+                        yield regular_text

                         # Process thought content
                         if thought_text:
                             if not initial_content_seen and not cot_started:
                                 # Start COT section
-                                loop.call_soon_threadsafe(queue.put_nowait, "<think>")
+                                yield "<think>"
                                 cot_active = True
                                 cot_started = True

-                            # Send thought content if COT is active
+                            # Yield thought content if COT is active
                             if cot_active:
-                                loop.call_soon_threadsafe(
-                                    queue.put_nowait, thought_text
-                                )
+                                if "\\u" in thought_text:
+                                    thought_text = safe_unicode_decode(
+                                        thought_text.encode("utf-8")
+                                    )
+                                yield thought_text
                     else:
-                        # COT disabled - only send regular content
+                        # COT disabled - only yield regular content
                         if regular_text:
-                            loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+                            if "\\u" in regular_text:
+                                regular_text = safe_unicode_decode(
+                                    regular_text.encode("utf-8")
+                                )
+                            yield regular_text

                 # Ensure COT is properly closed if still active
                 if cot_active:
-                    loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                    yield "</think>"
                     cot_active = False

-                loop.call_soon_threadsafe(queue.put_nowait, None)
-            except Exception as exc:  # pragma: no cover - surface runtime issues
-                # Try to close COT tag before reporting error
+            except Exception as exc:
+                # Try to close COT tag before re-raising
                 if cot_active:
                     try:
-                        loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                        yield "</think>"
                     except Exception:
                         pass
-                loop.call_soon_threadsafe(queue.put_nowait, exc)
-
-        loop.run_in_executor(None, _stream_model)
-
-        async def _async_stream() -> AsyncIterator[str]:
-            try:
-                while True:
-                    item = await queue.get()
-                    if item is None:
-                        break
-                    if isinstance(item, Exception):
-                        raise item
-
-                    chunk_text = str(item)
-                    if "\\u" in chunk_text:
-                        chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
-
-                    # Yield the chunk directly without filtering
-                    # COT filtering is already handled in _stream_model()
-                    yield chunk_text
+                raise exc
             finally:
-                usage = usage_container.get("usage")
-                if token_tracker and usage:
+                # Track token usage after streaming completes
+                if token_tracker and usage_metadata:
                     token_tracker.add_usage(
                         {
-                            "prompt_tokens": getattr(usage, "prompt_token_count", 0),
-                            "completion_tokens": getattr(
-                                usage, "candidates_token_count", 0
+                            "prompt_tokens": getattr(
+                                usage_metadata, "prompt_token_count", 0
+                            ),
+                            "completion_tokens": getattr(
+                                usage_metadata, "candidates_token_count", 0
                             ),
-                            "total_tokens": getattr(usage, "total_token_count", 0),
+                            "total_tokens": getattr(
+                                usage_metadata, "total_token_count", 0
+                            ),
                         }
                     )

         return _async_stream()

-    response = await asyncio.to_thread(_call_model)
+    # Non-streaming: use native async client
+    response = await client.aio.models.generate_content(**request_kwargs)

     # Extract both regular text and thought text
     regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
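The `finally:` block above now reads the token counts straight from the last `usage_metadata` seen during streaming and hands them to `token_tracker.add_usage()`. A minimal tracker compatible with that call might look like the following (a sketch only; LightRAG's actual TokenTracker implementation is not shown in this diff):

class SimpleTokenTracker:
    """Accumulates the usage dict emitted by the streaming finally: block."""

    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0

    def add_usage(self, usage: dict[str, int]) -> None:
        # Keys mirror the dict built from usage_metadata above
        self.prompt_tokens += usage.get("prompt_tokens", 0)
        self.completion_tokens += usage.get("completion_tokens", 0)
        self.total_tokens += usage.get("total_tokens", 0)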
@@ -543,7 +535,6 @@ async def gemini_embed(
     # Note: max_token_size is received but not used for client-side truncation.
     # Gemini API handles truncation automatically with autoTruncate=True (default).
     _ = max_token_size  # Acknowledge parameter to avoid unused variable warning
-    loop = asyncio.get_running_loop()

     key = _ensure_api_key(api_key)
     # Convert timeout from seconds to milliseconds for Gemini API
@@ -564,19 +555,15 @@
     # Create config object if we have parameters
     config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None

-    def _call_embed() -> Any:
-        """Call Gemini embedding API in executor thread."""
-        request_kwargs: dict[str, Any] = {
-            "model": model,
-            "contents": texts,
-        }
-        if config_obj is not None:
-            request_kwargs["config"] = config_obj
+    request_kwargs: dict[str, Any] = {
+        "model": model,
+        "contents": texts,
+    }
+    if config_obj is not None:
+        request_kwargs["config"] = config_obj

-        return client.models.embed_content(**request_kwargs)
-
-    # Execute API call in thread pool
-    response = await loop.run_in_executor(None, _call_embed)
+    # Use native async client for embedding
+    response = await client.aio.models.embed_content(**request_kwargs)

     # Extract embeddings from response
     if not hasattr(response, "embeddings") or not response.embeddings:
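For completeness, the embedding path follows the same pattern: a single await on `client.aio.models.embed_content` replaces the executor round-trip. A small end-to-end sketch (placeholder key and embedding model name; the numpy conversion is illustrative and not taken from this patch):

import asyncio
import numpy as np
from google import genai

async def embed_texts(texts: list[str]) -> np.ndarray:
    client = genai.Client(api_key="YOUR_API_KEY")  # placeholder key
    # Native async embedding call, mirroring the new code path above
    response = await client.aio.models.embed_content(
        model="gemini-embedding-001",  # placeholder model name
        contents=texts,
    )
    return np.array([e.values for e in response.embeddings], dtype=np.float32)

vectors = asyncio.run(embed_texts(["hello", "world"]))
print(vectors.shape)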