Files
biz-bud/tests/helpers/assertions/custom_assertions.py
Travis Vasceannie 8ad47a7640 Modernize research graph metadata for LangGraph v1 (#60)
* Modernize research graph metadata for LangGraph v1

* Update src/biz_bud/core/langgraph/graph_builder.py

Co-authored-by: qodo-merge-pro[bot] <151058649+qodo-merge-pro[bot]@users.noreply.github.com>

---------

Co-authored-by: qodo-merge-pro[bot] <151058649+qodo-merge-pro[bot]@users.noreply.github.com>
2025-09-19 03:01:18 -04:00

181 lines
6.7 KiB
Python

"""Custom assertion helpers used across the Business Buddy test-suite.
These helpers intentionally mirror the semantics of the original code base so
that higher-level tests can focus on behaviour rather than hand-crafting
boilerplate assertions. The goal is not to be exhaustive, but to provide the
handful of lightweight checks that the modernised LangGraph tests rely on.
"""
from __future__ import annotations
from typing import Iterable, Mapping, Sequence
from langchain_core.messages import BaseMessage
def _normalise_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
    """Copy ``messages`` into a fresh list, checking each item is a message.

    The input is walked exactly once, so one-shot iterables behave the same
    as real sequences.

    Raises:
        AssertionError: On the first entry that is not a ``BaseMessage``.
    """
    checked: list[BaseMessage] = []
    for item in messages:
        if isinstance(item, BaseMessage):
            checked.append(item)
            continue
        raise AssertionError(f"Expected BaseMessage instance, got {type(item)!r}")
    return checked
def assert_message_types(
    messages: Sequence[BaseMessage], expected_types: Sequence[type[BaseMessage]]
) -> None:
    """Check that ``messages`` line up one-to-one, in order, with ``expected_types``.

    Args:
        messages: Emitted messages; every entry must be a ``BaseMessage``.
        expected_types: The message class expected at each position.

    Raises:
        AssertionError: If an entry is not a ``BaseMessage``, if the two
            sequences differ in length, or if a message is not an instance of
            the class expected at its position (the first mismatch is reported).
    """
    actual = _normalise_messages(messages)
    received, expected = len(actual), len(expected_types)
    if received != expected:
        raise AssertionError(f"Expected {expected} messages, received {received}")

    first_mismatch = next(
        (
            (position, message, wanted)
            for position, (message, wanted) in enumerate(zip(actual, expected_types))
            if not isinstance(message, wanted)
        ),
        None,
    )
    if first_mismatch is None:
        return
    position, message, wanted = first_mismatch
    raise AssertionError(
        f"Message at position {position} expected {wanted.__name__}, "
        f"received {type(message).__name__}"
    )
def assert_state_has_messages(
    state: Mapping[str, object], *, min_count: int = 1
) -> None:
    """Check that ``state["messages"]`` is a sequence of at least ``min_count`` messages.

    A ``state`` that is not a mapping is treated as having no messages.

    Raises:
        AssertionError: If the messages entry is missing or not a sequence,
            holds fewer than ``min_count`` items, or contains a non-message.
    """
    if isinstance(state, Mapping):
        messages = state.get("messages")
    else:
        messages = None

    if not isinstance(messages, Sequence):
        raise AssertionError("State does not contain a messages sequence")

    count = len(messages)
    if count < min_count:
        raise AssertionError(
            f"Expected at least {min_count} messages, received {count}"
        )

    # Final pass: every entry must be a BaseMessage (raises AssertionError otherwise).
    _normalise_messages(messages)
def assert_state_has_no_errors(state: Mapping[str, object]) -> None:
    """Fail if the workflow state records any errors.

    A non-mapping ``state``, a missing ``"errors"`` entry, ``None``, or an
    empty sequence all count as "no errors". Anything else, including an
    empty non-sequence such as ``{}``, is reported as errors.

    Raises:
        AssertionError: If the state carries a non-empty or non-sequence
            ``errors`` value.
    """
    if not isinstance(state, Mapping):
        return

    errors = state.get("errors")
    clean = errors in (None, []) or (isinstance(errors, Sequence) and not len(errors))
    if clean:
        return
    raise AssertionError(f"Expected state to have no errors, found: {errors}")
def assert_state_has_errors(
    state: Mapping[str, object], *, min_errors: int = 1, phases: Iterable[str] | None = None
) -> None:
    """Assert that errors exist on the workflow state and optionally check phases.

    Args:
        state: Workflow state; its ``"errors"`` entry is inspected when
            ``state`` is a mapping (otherwise it is treated as missing).
        min_errors: Minimum number of recorded errors required.
        phases: Optional phase names that must each appear on at least one
            error. An error's phase is read from its ``"phase"`` key when the
            error is a mapping, or from its ``phase`` attribute otherwise. A
            single string is treated as one phase name, not as characters.

    Raises:
        AssertionError: If ``errors`` is missing or not a sequence, holds fewer
            than ``min_errors`` entries, or lacks an error for a requested phase.
    """
    errors = state.get("errors") if isinstance(state, Mapping) else None
    if not isinstance(errors, Sequence):
        # Describe the value rather than calling len() on it: ints, generators
        # and similar objects have no length, and len() would raise TypeError
        # instead of the intended AssertionError.
        received = "0" if errors is None else f"a non-sequence {type(errors).__name__}"
        raise AssertionError(f"Expected at least {min_errors} errors, received {received}")
    if len(errors) < min_errors:
        raise AssertionError(
            f"Expected at least {min_errors} errors, received {len(errors)}"
        )

    if not phases:
        return
    # A bare string is itself an Iterable[str]; without this, phases="search"
    # would be checked as the phases "s", "e", "a", ...
    wanted = [phases] if isinstance(phases, str) else list(phases)

    found = set()
    for error in errors:
        if isinstance(error, Mapping):
            phase = error.get("phase")
        else:
            phase = getattr(error, "phase", None)
        # Only record requested phases so unrelated (possibly unhashable)
        # phase values never reach the set.
        if phase in wanted:
            found.add(phase)

    missing = [phase for phase in wanted if phase not in found]
    if missing:
        raise AssertionError(f"Expected errors for phases {missing!r} but they were not present")
def assert_metadata_contains(state: Mapping[str, object], keys: Iterable[str]) -> None:
    """Check that ``state["metadata"]`` is a mapping holding every key in ``keys``.

    Raises:
        AssertionError: If the state has no metadata mapping, or if any of
            ``keys`` is absent (all absent keys are listed in order).
    """
    metadata = None
    if isinstance(state, Mapping):
        metadata = state.get("metadata")
    if not isinstance(metadata, Mapping):
        raise AssertionError("State does not include metadata mapping")

    absent = [name for name in keys if name not in metadata]
    if not absent:
        return
    raise AssertionError(f"Metadata missing required keys: {absent}")
def assert_search_results_valid(
    results: Sequence[Mapping[str, object]], *, min_results: int = 1
) -> None:
    """Check there are enough search results and each has ``title`` and ``url``.

    Raises:
        AssertionError: If fewer than ``min_results`` results are given, or
            the first offending result is not a mapping or lacks a required key.
    """
    count = len(results)
    if count < min_results:
        raise AssertionError(
            f"Expected at least {min_results} search results, received {count}"
        )

    required = {"title", "url"}
    for position, entry in enumerate(results):
        if not isinstance(entry, Mapping):
            raise AssertionError(f"Result at index {position} is not a mapping: {entry!r}")
        absent = required.difference(entry.keys())
        if not absent:
            continue
        raise AssertionError(
            f"Result at index {position} missing keys {sorted(absent)}: {entry!r}"
        )
def assert_synthesis_quality(
    synthesis: str,
    *,
    min_length: int = 0,
    max_length: int | None = None,
    required_phrases: Iterable[str] | None = None,
) -> None:
    """Check synthesis length bounds and that every required phrase appears.

    Args:
        synthesis: Text to validate.
        min_length: Inclusive lower bound on ``len(synthesis)``.
        max_length: Inclusive upper bound on ``len(synthesis)``; ``None``
            disables the check.
        required_phrases: Substrings that must each occur in ``synthesis``.

    Raises:
        AssertionError: On the first failing check (length first, then phrases).
    """
    size = len(synthesis)
    if size < min_length:
        raise AssertionError(
            f"Synthesis too short; expected >= {min_length} characters, got {size}"
        )
    if max_length is not None and size > max_length:
        raise AssertionError(
            f"Synthesis too long; expected <= {max_length} characters, got {size}"
        )

    if not required_phrases:
        return
    absent = [text for text in required_phrases if text not in synthesis]
    if absent:
        raise AssertionError(
            f"Synthesis missing required phrases: {absent}. Synthesis: {synthesis!r}"
        )
def assert_workflow_status(state: Mapping[str, object], expected_status: str) -> None:
    """Check that ``state["workflow_status"]`` equals ``expected_status``.

    A non-mapping ``state`` or a missing key is treated as a status of ``None``.

    Raises:
        AssertionError: If the recorded status differs from ``expected_status``.
    """
    actual = None
    if isinstance(state, Mapping):
        actual = state.get("workflow_status")
    if actual != expected_status:
        raise AssertionError(
            f"Expected workflow status '{expected_status}', received '{actual}'"
        )
def assert_valid_response(response: Mapping[str, object]) -> None:
    """Backward compatible helper kept for older tests.

    A response is valid when it is a mapping with a ``"status"`` or
    ``"success"`` key (either one suffices).

    Raises:
        AssertionError: If ``response`` is not a mapping or has neither key.
    """
    if not isinstance(response, Mapping):
        raise AssertionError("Response should be a mapping")
    indicators = {"status", "success"}
    if indicators.isdisjoint(response.keys()):
        raise AssertionError("Response missing status indicator")
def assert_contains_keys(data: Mapping[str, object], keys: Iterable[str]) -> None:
    """Check that every name in ``keys`` is present in ``data``.

    Membership uses ``in``, so any container supporting it works.

    Raises:
        AssertionError: Listing all absent keys, in the order given.
    """
    absent = [name for name in keys if name not in data]
    if not absent:
        return
    raise AssertionError(f"Missing keys: {absent}")