diff --git a/lightrag/llm/azure_openai.py b/lightrag/llm/azure_openai.py index 98437ca8..c183c3a9 100644 --- a/lightrag/llm/azure_openai.py +++ b/lightrag/llm/azure_openai.py @@ -26,6 +26,7 @@ from lightrag.utils import ( safe_unicode_decode, logger, ) +from lightrag.types import GPTKeywordExtractionFormat import numpy as np @@ -46,6 +47,7 @@ async def azure_openai_complete_if_cache( base_url: str | None = None, api_key: str | None = None, api_version: str | None = None, + keyword_extraction: bool = False, **kwargs, ): if enable_cot: @@ -66,9 +68,12 @@ async def azure_openai_complete_if_cache( ) kwargs.pop("hashing_kv", None) - kwargs.pop("keyword_extraction", None) timeout = kwargs.pop("timeout", None) + # Handle keyword extraction mode + if keyword_extraction: + kwargs["response_format"] = GPTKeywordExtractionFormat + openai_async_client = AsyncAzureOpenAI( azure_endpoint=base_url, azure_deployment=deployment, @@ -117,12 +122,12 @@ async def azure_openai_complete_if_cache( async def azure_openai_complete( prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs ) -> str: - kwargs.pop("keyword_extraction", None) result = await azure_openai_complete_if_cache( os.getenv("LLM_MODEL", "gpt-4o-mini"), prompt, system_prompt=system_prompt, history_messages=history_messages, + keyword_extraction=keyword_extraction, **kwargs, ) return result diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 8c984e51..948ae270 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -203,6 +203,10 @@ async def openai_complete_if_cache( # Extract client configuration options client_configs = kwargs.pop("openai_client_configs", {}) + # Handle keyword extraction mode + if keyword_extraction: + kwargs["response_format"] = GPTKeywordExtractionFormat + # Create the OpenAI client openai_async_client = create_openai_async_client( api_key=api_key, @@ -522,8 +526,6 @@ async def openai_complete( ) -> Union[str, AsyncIterator[str]]: if history_messages is None: history_messages = [] - if keyword_extraction: - kwargs["response_format"] = "json" model_name = kwargs["hashing_kv"].global_config["llm_model_name"] return await openai_complete_if_cache( model_name, @@ -545,8 +547,6 @@ async def gpt_4o_complete( ) -> str: if history_messages is None: history_messages = [] - if keyword_extraction: - kwargs["response_format"] = GPTKeywordExtractionFormat return await openai_complete_if_cache( "gpt-4o", prompt, @@ -568,8 +568,6 @@ async def gpt_4o_mini_complete( ) -> str: if history_messages is None: history_messages = [] - if keyword_extraction: - kwargs["response_format"] = GPTKeywordExtractionFormat return await openai_complete_if_cache( "gpt-4o-mini", prompt,