resolve: merge conflicts between refine and main branches
- Resolved conflicts in src/biz_bud/graphs/graph.py: - Unified create_initial_state functions to use main branch implementation - Maintained optimized async/sync patterns from main - Preserved performance improvements and caching logic - Resolved conflicts in tests/integration_tests/graphs/test_main_graph_integration.py: - Fixed test assertions to use consistent state access patterns - Removed deprecated .get() calls in favor of direct dictionary access - Maintained test compatibility with resolved state structure 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1321,22 +1321,47 @@ async def create_initial_state(
|
||||
Returns:
|
||||
InputState properly formatted for the graph.
|
||||
"""
|
||||
from biz_bud.core.utils.graph_helpers import (
|
||||
extract_state_update_data,
|
||||
format_raw_input,
|
||||
process_state_query,
|
||||
)
|
||||
# If state_update is provided (from LangGraph API), extract data from it
|
||||
if state_update:
|
||||
# Extract messages if present
|
||||
if "messages" in state_update and not messages:
|
||||
messages = state_update["messages"]
|
||||
|
||||
# Extract data from state_update if provided
|
||||
messages, raw_input, thread_id = extract_state_update_data(
|
||||
state_update, messages, raw_input, thread_id
|
||||
)
|
||||
# Extract other fields if present
|
||||
if "raw_input" in state_update and not raw_input:
|
||||
raw_input = state_update["raw_input"]
|
||||
if "thread_id" in state_update and not thread_id:
|
||||
thread_id = state_update["thread_id"]
|
||||
|
||||
# Process query
|
||||
user_query = process_state_query(query, messages, state_update, DEFAULT_USER_QUERY)
|
||||
# Extract query from messages if not provided (optimized - early exit)
|
||||
if query is None and messages:
|
||||
# Look for the last human/user message
|
||||
for msg in reversed(messages):
|
||||
# Handle different message formats
|
||||
role = msg.get("role")
|
||||
msg_type = msg.get("type")
|
||||
if role in ("user", "human") or msg_type == "human":
|
||||
query = msg.get("content", "")
|
||||
break
|
||||
|
||||
# Format raw input
|
||||
raw_input_str, user_query = format_raw_input(raw_input, user_query)
|
||||
# Use provided query or fall back to default
|
||||
user_query = query or DEFAULT_USER_QUERY
|
||||
|
||||
# Optimized raw_input handling - avoid JSON serialization when possible
|
||||
if raw_input is None:
|
||||
raw_input_str = f'{{"query": "{user_query}"}}'
|
||||
elif isinstance(raw_input, dict):
|
||||
# If raw_input has a query field, use it
|
||||
if "query" in raw_input:
|
||||
user_query = raw_input["query"]
|
||||
# Avoid json.dumps for simple cases
|
||||
if len(raw_input) == 1 and "query" in raw_input:
|
||||
raw_input_str = f'{{"query": "{raw_input["query"]}"}}'
|
||||
else:
|
||||
import json
|
||||
raw_input_str = json.dumps(raw_input)
|
||||
else:
|
||||
raw_input_str = str(raw_input)
|
||||
|
||||
# Create messages if not provided
|
||||
if messages is None:
|
||||
@@ -1381,23 +1406,47 @@ def create_initial_state_sync(
|
||||
This is a synchronous version for backward compatibility.
|
||||
For best performance, use create_initial_state() in async contexts.
|
||||
"""
|
||||
from biz_bud.core.utils.graph_helpers import (
|
||||
create_initial_state_dict,
|
||||
extract_state_update_data,
|
||||
format_raw_input,
|
||||
process_state_query,
|
||||
)
|
||||
# If state_update is provided (from LangGraph API), extract data from it
|
||||
if state_update:
|
||||
# Extract messages if present
|
||||
if "messages" in state_update and not messages:
|
||||
messages = state_update["messages"]
|
||||
|
||||
# Extract data from state_update if provided
|
||||
messages, raw_input, thread_id = extract_state_update_data(
|
||||
state_update, messages, raw_input, thread_id
|
||||
)
|
||||
# Extract other fields if present
|
||||
if "raw_input" in state_update and not raw_input:
|
||||
raw_input = state_update["raw_input"]
|
||||
if "thread_id" in state_update and not thread_id:
|
||||
thread_id = state_update["thread_id"]
|
||||
|
||||
# Process query
|
||||
user_query = process_state_query(query, messages, state_update, DEFAULT_USER_QUERY)
|
||||
# Extract query from messages if not provided
|
||||
if query is None and messages:
|
||||
# Look for the last human/user message
|
||||
for msg in reversed(messages):
|
||||
# Handle different message formats
|
||||
role = msg.get("role")
|
||||
msg_type = msg.get("type")
|
||||
if role in ("user", "human") or msg_type == "human":
|
||||
query = msg.get("content", "")
|
||||
break
|
||||
|
||||
# Format raw input
|
||||
raw_input_str, user_query = format_raw_input(raw_input, user_query)
|
||||
# Use provided query or fall back to default
|
||||
user_query = query or DEFAULT_USER_QUERY
|
||||
|
||||
# Handle raw_input - optimized
|
||||
if raw_input is None:
|
||||
raw_input_str = f'{{"query": "{user_query}"}}'
|
||||
elif isinstance(raw_input, dict):
|
||||
# If raw_input has a query field, use it
|
||||
if "query" in raw_input:
|
||||
user_query = raw_input["query"]
|
||||
# Avoid json.dumps for simple cases
|
||||
if len(raw_input) == 1 and "query" in raw_input:
|
||||
raw_input_str = f'{{"query": "{raw_input["query"]}"}}'
|
||||
else:
|
||||
import json
|
||||
raw_input_str = json.dumps(raw_input)
|
||||
else:
|
||||
raw_input_str = str(raw_input)
|
||||
|
||||
# Create messages if not provided
|
||||
if messages is None:
|
||||
@@ -1407,16 +1456,30 @@ def create_initial_state_sync(
|
||||
if thread_id is None:
|
||||
thread_id = "initial_thread_mvp"
|
||||
|
||||
# Create state using helper
|
||||
return cast(
|
||||
"InputState",
|
||||
create_initial_state_dict(
|
||||
raw_input_str,
|
||||
user_query,
|
||||
messages,
|
||||
thread_id,
|
||||
get_config_dict_sync()
|
||||
)
|
||||
{
|
||||
"raw_input": raw_input_str,
|
||||
"parsed_input": {
|
||||
"raw_payload": {
|
||||
"query": user_query,
|
||||
},
|
||||
"user_query": user_query,
|
||||
},
|
||||
"input_metadata": {},
|
||||
"messages": messages,
|
||||
"initial_input": {
|
||||
"query": user_query,
|
||||
},
|
||||
"config": get_config_dict_sync(),
|
||||
"context": {},
|
||||
"status": "pending",
|
||||
"errors": [],
|
||||
"run_metadata": {},
|
||||
"thread_id": thread_id,
|
||||
"final_result": None,
|
||||
"is_last_step": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -326,11 +326,8 @@ async def test_optimized_state_creation() -> None:
|
||||
|
||||
# Verify state structure
|
||||
assert state["raw_input"] == f'{{"query": "{query}"}}'
|
||||
assert state["parsed_input"].get("user_query") == query
|
||||
# Messages are created as dictionaries in create_initial_state
|
||||
from typing import cast
|
||||
first_message = cast(dict[str, str], state["messages"][0])
|
||||
assert first_message["content"] == query
|
||||
assert state["parsed_input"]["user_query"] == query
|
||||
assert state["messages"][0]["content"] == query
|
||||
assert "config" in state
|
||||
assert "context" in state
|
||||
assert "errors" in state
|
||||
@@ -351,7 +348,7 @@ async def test_backward_compatibility() -> None:
|
||||
assert graph is not None
|
||||
|
||||
state_sync = create_initial_state_sync(query="sync test")
|
||||
assert state_sync["parsed_input"].get("user_query") == "sync test"
|
||||
assert state_sync["parsed_input"]["user_query"] == "sync test"
|
||||
|
||||
state_legacy = get_initial_state()
|
||||
assert "messages" in state_legacy
|
||||
|
||||
Reference in New Issue
Block a user