# Last commit: "Modernize research graph metadata for LangGraph v1"
#   * Update src/biz_bud/core/langgraph/graph_builder.py
#   Co-authored-by: qodo-merge-pro[bot] <151058649+qodo-merge-pro[bot]@users.noreply.github.com>
# File listing: 181 lines, 6.7 KiB, Python.
"""Custom assertion helpers used across the Business Buddy test suite.

These helpers intentionally mirror the semantics of the original code base so
that higher-level tests can focus on behaviour rather than hand-crafting
boilerplate assertions. The goal is not to be exhaustive but to provide the
handful of lightweight checks that the modernised LangGraph tests rely on.

The helpers raise ``AssertionError`` explicitly instead of using bare
``assert`` statements, so the checks still run when Python is started with
``-O``.
"""


from __future__ import annotations
|
|
|
|
from typing import Iterable, Mapping, Sequence
|
|
|
|
from langchain_core.messages import BaseMessage
|
|
|
|
|
|
def _normalise_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
|
|
"""Return a list copy of ``messages`` ensuring each entry is a message."""
|
|
|
|
normalised: list[BaseMessage] = []
|
|
for message in messages:
|
|
if not isinstance(message, BaseMessage):
|
|
raise AssertionError(f"Expected BaseMessage instance, got {type(message)!r}")
|
|
normalised.append(message)
|
|
return normalised
|
|
|
|
|
|
def assert_message_types(
    messages: Sequence[BaseMessage], expected_types: Sequence[type[BaseMessage]]
) -> None:
    """Check that ``messages`` line up one-to-one, in order, with ``expected_types``.

    Raises:
        AssertionError: If any entry is not a ``BaseMessage``, if the number of
            messages differs from the number of expected types, or on the first
            message that is not an instance of the type expected at its position.
    """
    checked = _normalise_messages(messages)
    received, wanted = len(checked), len(expected_types)
    if received != wanted:
        raise AssertionError(f"Expected {wanted} messages, received {received}")

    # Pair each message with the type expected at the same position and stop
    # at the first pair that does not match.
    mismatch = next(
        (
            (position, message, expected)
            for position, (message, expected) in enumerate(zip(checked, expected_types))
            if not isinstance(message, expected)
        ),
        None,
    )
    if mismatch is not None:
        position, message, expected = mismatch
        raise AssertionError(
            f"Message at position {position} expected {expected.__name__}, "
            f"received {type(message).__name__}"
        )


def assert_state_has_messages(
    state: Mapping[str, object], *, min_count: int = 1
) -> None:
    """Check that ``state["messages"]`` is a sequence of at least ``min_count`` messages.

    A non-mapping ``state`` is treated as having no ``messages`` entry.

    Raises:
        AssertionError: If there is no messages sequence, if it is shorter than
            ``min_count``, or if any entry is not a ``BaseMessage``.
    """
    if isinstance(state, Mapping):
        messages = state.get("messages")
    else:
        messages = None

    if not isinstance(messages, Sequence):
        raise AssertionError("State does not contain a messages sequence")

    total = len(messages)
    if total < min_count:
        raise AssertionError(f"Expected at least {min_count} messages, received {total}")

    # Called only for its validation side effect; the returned copy is unused.
    _normalise_messages(messages)


def assert_state_has_no_errors(state: Mapping[str, object]) -> None:
    """Assert that the workflow state does not contain any recorded errors.

    A missing ``errors`` key, an explicit ``None`` and any empty collection
    (list, tuple, dict, set, ...) all count as "no errors".

    Args:
        state: Workflow state. A non-mapping value is treated as having no
            ``errors`` entry.

    Raises:
        AssertionError: If ``errors`` holds anything other than ``None`` or an
            empty collection.
    """
    from collections.abc import Sized

    errors = state.get("errors") if isinstance(state, Mapping) else None
    if errors is None:
        return
    # Check ``Sized`` rather than ``Sequence`` so an empty dict or set is also
    # accepted as "nothing recorded" instead of being reported as errors.
    if isinstance(errors, Sized) and len(errors) == 0:
        return
    raise AssertionError(f"Expected state to have no errors, found: {errors}")


def assert_state_has_errors(
    state: Mapping[str, object], *, min_errors: int = 1, phases: Iterable[str] | None = None
) -> None:
    """Assert that errors exist on the workflow state and optionally check phases.

    Args:
        state: Workflow state expected to carry an ``errors`` sequence. A
            non-mapping value is treated as having no ``errors`` entry.
        min_errors: Minimum number of recorded errors required.
        phases: Optional phase names. Each must appear as the ``phase`` of at
            least one error. The phase is read via ``error.get("phase")`` for
            mapping errors and the ``phase`` attribute otherwise.

    Raises:
        AssertionError: If ``errors`` is missing or not a sequence, if it holds
            fewer than ``min_errors`` entries, or if a requested phase has no
            matching error.
    """
    errors = state.get("errors") if isinstance(state, Mapping) else None
    # Only call len() on a real sequence. A non-sequence value (e.g. an int)
    # would otherwise raise TypeError while building the failure message,
    # masking the intended AssertionError.
    received = len(errors) if isinstance(errors, Sequence) else 0
    if not isinstance(errors, Sequence) or received < min_errors:
        raise AssertionError(f"Expected at least {min_errors} errors, received {received}")

    if not phases:
        return

    # Materialise once so generators can be scanned repeatedly below.
    wanted = list(phases)
    found: set[object] = set()
    for error in errors:
        if isinstance(error, Mapping):
            phase = error.get("phase")
        else:
            phase = getattr(error, "phase", None)
        if phase in wanted:
            found.add(phase)

    missing = [phase for phase in wanted if phase not in found]
    if missing:
        raise AssertionError(f"Expected errors for phases {missing!r} but they were not present")


def assert_metadata_contains(state: Mapping[str, object], keys: Iterable[str]) -> None:
    """Fail unless ``state["metadata"]`` is a mapping that holds every key in ``keys``.

    A non-mapping ``state`` is treated as having no metadata.

    Raises:
        AssertionError: If there is no metadata mapping, or if any key is absent.
            All absent keys are reported, in the order they were given.
    """
    if isinstance(state, Mapping):
        metadata = state.get("metadata")
    else:
        metadata = None
    if not isinstance(metadata, Mapping):
        raise AssertionError("State does not include metadata mapping")

    absent: list[str] = []
    for required in keys:
        if required not in metadata:
            absent.append(required)
    if absent:
        raise AssertionError(f"Metadata missing required keys: {absent}")


def assert_search_results_valid(
    results: Sequence[Mapping[str, object]], *, min_results: int = 1
) -> None:
    """Check that there are enough search results and that each one is well formed.

    Every result must be a mapping with at least ``"title"`` and ``"url"`` keys;
    extra keys are allowed.

    Raises:
        AssertionError: If fewer than ``min_results`` results are given, or on the
            first result that is not a mapping or lacks a required key.
    """
    total = len(results)
    if total < min_results:
        raise AssertionError(
            f"Expected at least {min_results} search results, received {total}"
        )

    for position, item in enumerate(results):
        if not isinstance(item, Mapping):
            raise AssertionError(f"Result at index {position} is not a mapping: {item!r}")
        absent = sorted(key for key in ("title", "url") if key not in item)
        if absent:
            raise AssertionError(
                f"Result at index {position} missing keys {absent}: {item!r}"
            )


def assert_synthesis_quality(
    synthesis: str,
    *,
    min_length: int = 0,
    max_length: int | None = None,
    required_phrases: Iterable[str] | None = None,
) -> None:
    """Check the synthesis length bounds and that it mentions every required phrase.

    Args:
        synthesis: Text to check.
        min_length: Minimum number of characters (inclusive).
        max_length: Maximum number of characters (inclusive); ``None`` disables
            the upper bound.
        required_phrases: Substrings that must all appear in ``synthesis``.

    Raises:
        AssertionError: If the text is too short or too long, or if any required
            phrase is absent (all absent phrases are reported).
    """
    size = len(synthesis)
    if size < min_length:
        raise AssertionError(
            f"Synthesis too short; expected >= {min_length} characters, got {size}"
        )
    if max_length is not None and size > max_length:
        raise AssertionError(
            f"Synthesis too long; expected <= {max_length} characters, got {size}"
        )

    if not required_phrases:
        return
    absent = [phrase for phrase in required_phrases if phrase not in synthesis]
    if absent:
        raise AssertionError(
            f"Synthesis missing required phrases: {absent}. Synthesis: {synthesis!r}"
        )


def assert_workflow_status(state: Mapping[str, object], expected_status: str) -> None:
    """Fail unless ``state["workflow_status"]`` equals ``expected_status``.

    A non-mapping ``state`` or a missing key is reported as status ``None``.
    """
    actual = None
    if isinstance(state, Mapping):
        actual = state.get("workflow_status")
    if actual != expected_status:
        raise AssertionError(
            f"Expected workflow status '{expected_status}', received '{actual}'"
        )


def assert_valid_response(response: Mapping[str, object]) -> None:
    """Legacy response check, kept so that older tests keep working.

    Raises:
        AssertionError: If ``response`` is not a mapping, or if it has neither a
            ``"status"`` nor a ``"success"`` key.
    """
    if not isinstance(response, Mapping):
        raise AssertionError("Response should be a mapping")
    has_indicator = any(key in ("status", "success") for key in response.keys())
    if not has_indicator:
        raise AssertionError("Response missing status indicator")


def assert_contains_keys(data: Mapping[str, object], keys: Iterable[str]) -> None:
    """Fail if any of ``keys`` is absent from ``data``.

    Raises:
        AssertionError: Listing every absent key, in the order it was given.
    """
    absent = list(filter(lambda key: key not in data, keys))
    if absent:
        raise AssertionError(f"Missing keys: {absent}")