diff --git a/lightrag/llm/azure_openai.py b/lightrag/llm/azure_openai.py
index 98437ca8..cb8d68df 100644
--- a/lightrag/llm/azure_openai.py
+++ b/lightrag/llm/azure_openai.py
@@ -26,6 +26,7 @@ from lightrag.utils import (
     safe_unicode_decode,
     logger,
 )
+from lightrag.types import GPTKeywordExtractionFormat
 
 import numpy as np
 
@@ -46,6 +47,7 @@ async def azure_openai_complete_if_cache(
     base_url: str | None = None,
     api_key: str | None = None,
     api_version: str | None = None,
+    keyword_extraction: bool = False,
     **kwargs,
 ):
     if enable_cot:
@@ -66,9 +68,12 @@ async def azure_openai_complete_if_cache(
     )
 
     kwargs.pop("hashing_kv", None)
-    kwargs.pop("keyword_extraction", None)
     timeout = kwargs.pop("timeout", None)
 
+    # Handle keyword extraction mode
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
+
     openai_async_client = AsyncAzureOpenAI(
         azure_endpoint=base_url,
         azure_deployment=deployment,
@@ -85,7 +90,7 @@ async def azure_openai_complete_if_cache(
     messages.append({"role": "user", "content": prompt})
 
     if "response_format" in kwargs:
-        response = await openai_async_client.beta.chat.completions.parse(
+        response = await openai_async_client.chat.completions.parse(
             model=model, messages=messages, **kwargs
         )
     else:
@@ -108,21 +113,32 @@ async def azure_openai_complete_if_cache(
 
         return inner()
     else:
-        content = response.choices[0].message.content
-        if r"\u" in content:
-            content = safe_unicode_decode(content.encode("utf-8"))
+        message = response.choices[0].message
+
+        # Handle parsed responses (structured output via response_format)
+        # When using chat.completions.parse(), the response is in message.parsed
+        if hasattr(message, "parsed") and message.parsed is not None:
+            # Serialize the parsed structured response to JSON
+            content = message.parsed.model_dump_json()
+            logger.debug("Using parsed structured response from API")
+        else:
+            # Handle regular content responses
+            content = message.content
+            if content and r"\u" in content:
+                content = safe_unicode_decode(content.encode("utf-8"))
+
         return content
 
 
 async def azure_openai_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
-    kwargs.pop("keyword_extraction", None)
     result = await azure_openai_complete_if_cache(
         os.getenv("LLM_MODEL", "gpt-4o-mini"),
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        keyword_extraction=keyword_extraction,
         **kwargs,
     )
     return result
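
Note on the azure_openai.py changes: keyword extraction is now an explicit parameter instead of a silently discarded kwarg, and setting it routes the request through structured output. A minimal usage sketch, assuming the usual AZURE_OPENAI_* credentials are configured and that GPTKeywordExtractionFormat is the Pydantic model defined in lightrag/types.py (the prompt text is illustrative):

    import asyncio
    from lightrag.llm.azure_openai import azure_openai_complete

    async def main() -> None:
        # keyword_extraction=True now injects
        # kwargs["response_format"] = GPTKeywordExtractionFormat, so the
        # deployment returns structured output, which the new message.parsed
        # branch serializes back to a JSON string.
        result = await azure_openai_complete(
            "Extract keywords from: LightRAG indexes documents into a graph.",
            keyword_extraction=True,
        )
        print(result)  # JSON matching GPTKeywordExtractionFormat

    asyncio.run(main())
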
"reasoning_content", "") - # Handle COT logic for non-streaming responses (only if enabled) - final_content = "" + # Handle parsed responses (structured output via response_format) + # When using beta.chat.completions.parse(), the response is in message.parsed + if hasattr(message, "parsed") and message.parsed is not None: + # Serialize the parsed structured response to JSON + final_content = message.parsed.model_dump_json() + logger.debug("Using parsed structured response from API") + else: + # Handle regular content responses + content = getattr(message, "content", None) + reasoning_content = getattr(message, "reasoning_content", "") - if enable_cot: - # Check if we should include reasoning content - should_include_reasoning = False - if reasoning_content and reasoning_content.strip(): - if not content or content.strip() == "": - # Case 1: Only reasoning content, should include COT - should_include_reasoning = True - final_content = ( - content or "" - ) # Use empty string if content is None + # Handle COT logic for non-streaming responses (only if enabled) + final_content = "" + + if enable_cot: + # Check if we should include reasoning content + should_include_reasoning = False + if reasoning_content and reasoning_content.strip(): + if not content or content.strip() == "": + # Case 1: Only reasoning content, should include COT + should_include_reasoning = True + final_content = ( + content or "" + ) # Use empty string if content is None + else: + # Case 3: Both content and reasoning_content present, ignore reasoning + should_include_reasoning = False + final_content = content else: - # Case 3: Both content and reasoning_content present, ignore reasoning - should_include_reasoning = False - final_content = content + # No reasoning content, use regular content + final_content = content or "" + + # Apply COT wrapping if needed + if should_include_reasoning: + if r"\u" in reasoning_content: + reasoning_content = safe_unicode_decode( + reasoning_content.encode("utf-8") + ) + final_content = ( + f"{reasoning_content}{final_content}" + ) else: - # No reasoning content, use regular content + # COT disabled, only use regular content final_content = content or "" - # Apply COT wrapping if needed - if should_include_reasoning: - if r"\u" in reasoning_content: - reasoning_content = safe_unicode_decode( - reasoning_content.encode("utf-8") - ) - final_content = f"{reasoning_content}{final_content}" - else: - # COT disabled, only use regular content - final_content = content or "" - - # Validate final content - if not final_content or final_content.strip() == "": - logger.error("Received empty content from OpenAI API") - await openai_async_client.close() # Ensure client is closed - raise InvalidResponseError("Received empty content from OpenAI API") + # Validate final content + if not final_content or final_content.strip() == "": + logger.error("Received empty content from OpenAI API") + await openai_async_client.close() # Ensure client is closed + raise InvalidResponseError("Received empty content from OpenAI API") # Apply Unicode decoding to final content if needed if r"\u" in final_content: @@ -522,8 +537,6 @@ async def openai_complete( ) -> Union[str, AsyncIterator[str]]: if history_messages is None: history_messages = [] - if keyword_extraction: - kwargs["response_format"] = "json" model_name = kwargs["hashing_kv"].global_config["llm_model_name"] return await openai_complete_if_cache( model_name, @@ -545,8 +558,6 @@ async def gpt_4o_complete( ) -> str: if history_messages is None: 
diff --git a/pyproject.toml b/pyproject.toml
index b76315d9..8d48b5df 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,7 +58,7 @@ api = [
     "nano-vectordb",
     "networkx",
     "numpy>=1.24.0,<2.0.0",
-    "openai>=1.0.0,<3.0.0",
+    "openai>=2.0.0,<3.0.0",
     "pandas>=2.0.0,<2.4.0",
     "pipmaster",
     "pydantic",
@@ -115,7 +115,7 @@ offline-storage = [
 
 offline-llm = [
     # LLM provider dependencies
-    "openai>=1.0.0,<3.0.0",
+    "openai>=2.0.0,<3.0.0",
     "anthropic>=0.18.0,<1.0.0",
     "ollama>=0.1.0,<1.0.0",
     "zhipuai>=2.0.0,<3.0.0",
diff --git a/requirements-offline-llm.txt b/requirements-offline-llm.txt
index 1539552a..bcfb1451 100644
--- a/requirements-offline-llm.txt
+++ b/requirements-offline-llm.txt
@@ -14,6 +14,6 @@ google-api-core>=2.0.0,<3.0.0
 google-genai>=1.0.0,<2.0.0
 llama-index>=0.9.0,<1.0.0
 ollama>=0.1.0,<1.0.0
-openai>=1.0.0,<3.0.0
+openai>=2.0.0,<3.0.0
 voyageai>=0.2.0,<1.0.0
 zhipuai>=2.0.0,<3.0.0
diff --git a/requirements-offline.txt b/requirements-offline.txt
index 50848093..87ca7a6a 100644
--- a/requirements-offline.txt
+++ b/requirements-offline.txt
@@ -19,7 +19,7 @@ google-genai>=1.0.0,<2.0.0
 llama-index>=0.9.0,<1.0.0
 neo4j>=5.0.0,<7.0.0
 ollama>=0.1.0,<1.0.0
-openai>=1.0.0,<3.0.0
+openai>=2.0.0,<3.0.0
 openpyxl>=3.0.0,<4.0.0
 pycryptodome>=3.0.0,<4.0.0
 pymilvus>=2.6.2,<3.0.0
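
Note on the dependency changes: the openai floor moves from 1.x to 2.x everywhere it is declared (including the uv.lock resolution below), since the code now relies on the non-beta parse() endpoint. A quick smoke test for the capability the pin is meant to guarantee; constructing the client makes no network calls and the key is a placeholder:

    from openai import AsyncOpenAI

    client = AsyncOpenAI(api_key="sk-placeholder")
    assert hasattr(client.chat.completions, "parse"), (
        "openai SDK too old: chat.completions.parse() requires a recent release"
    )
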
diff --git a/uv.lock b/uv.lock
index 97703af0..a4f17ab4 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2735,7 +2735,6 @@ requires-dist = [
     { name = "json-repair", marker = "extra == 'api'" },
     { name = "langfuse", marker = "extra == 'observability'", specifier = ">=3.8.1" },
     { name = "lightrag-hku", extras = ["api", "offline-llm", "offline-storage"], marker = "extra == 'offline'" },
-    { name = "lightrag-hku", extras = ["pytest"], marker = "extra == 'evaluation'" },
     { name = "llama-index", marker = "extra == 'offline-llm'", specifier = ">=0.9.0,<1.0.0" },
     { name = "nano-vectordb" },
     { name = "nano-vectordb", marker = "extra == 'api'" },
@@ -2745,14 +2744,15 @@ requires-dist = [
     { name = "numpy", specifier = ">=1.24.0,<2.0.0" },
     { name = "numpy", marker = "extra == 'api'", specifier = ">=1.24.0,<2.0.0" },
     { name = "ollama", marker = "extra == 'offline-llm'", specifier = ">=0.1.0,<1.0.0" },
-    { name = "openai", marker = "extra == 'api'", specifier = ">=1.0.0,<3.0.0" },
-    { name = "openai", marker = "extra == 'offline-llm'", specifier = ">=1.0.0,<3.0.0" },
+    { name = "openai", marker = "extra == 'api'", specifier = ">=2.0.0,<3.0.0" },
+    { name = "openai", marker = "extra == 'offline-llm'", specifier = ">=2.0.0,<3.0.0" },
     { name = "openpyxl", marker = "extra == 'api'", specifier = ">=3.0.0,<4.0.0" },
     { name = "pandas", specifier = ">=2.0.0,<2.4.0" },
     { name = "pandas", marker = "extra == 'api'", specifier = ">=2.0.0,<2.4.0" },
     { name = "passlib", extras = ["bcrypt"], marker = "extra == 'api'" },
     { name = "pipmaster" },
     { name = "pipmaster", marker = "extra == 'api'" },
+    { name = "pre-commit", marker = "extra == 'evaluation'" },
     { name = "pre-commit", marker = "extra == 'pytest'" },
     { name = "psutil", marker = "extra == 'api'" },
     { name = "pycryptodome", marker = "extra == 'api'", specifier = ">=3.0.0,<4.0.0" },
@@ -2764,7 +2764,9 @@ requires-dist = [
     { name = "pypdf", marker = "extra == 'api'", specifier = ">=6.1.0" },
     { name = "pypinyin" },
     { name = "pypinyin", marker = "extra == 'api'" },
+    { name = "pytest", marker = "extra == 'evaluation'", specifier = ">=8.4.2" },
     { name = "pytest", marker = "extra == 'pytest'", specifier = ">=8.4.2" },
+    { name = "pytest-asyncio", marker = "extra == 'evaluation'", specifier = ">=1.2.0" },
     { name = "pytest-asyncio", marker = "extra == 'pytest'", specifier = ">=1.2.0" },
     { name = "python-docx", marker = "extra == 'api'", specifier = ">=0.8.11,<2.0.0" },
     { name = "python-dotenv" },