# Tests for noteflow.infrastructure.ai.nodes.web_search.
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Final
|
|
|
|
import pytest
|
|
|
|
from noteflow.infrastructure.ai.nodes.web_search import (
|
|
DEFAULT_MAX_RESULTS,
|
|
DEFAULT_TIMEOUT_SECONDS,
|
|
DisabledWebSearchProvider,
|
|
WebSearchConfig,
|
|
WebSearchResponse,
|
|
WebSearchResult,
|
|
derive_search_query,
|
|
execute_web_search,
|
|
format_results_for_context,
|
|
merge_contexts,
|
|
)
|
|
|
|
if TYPE_CHECKING:
    # No type-only imports are needed at present; the guard is kept as a placeholder.
    pass

# Field values used to build sample WebSearchResult instances.
SAMPLE_TITLE: Final[str] = "Example Result"
SAMPLE_URL: Final[str] = "https://example.com/page"
SAMPLE_SNIPPET: Final[str] = "This is a sample search result snippet."
SAMPLE_SCORE: Final[float] = 0.9
DEFAULT_SCORE: Final[float] = 1.0

# Query text and transcript context fed to the functions under test.
SAMPLE_QUERY: Final[str] = "What is AI?"
TRANSCRIPT_CONTEXT: Final[str] = "Meeting transcript content here"

# Values the module-level defaults are compared against.
EXPECTED_MAX_RESULTS: Final[int] = 5
EXPECTED_TIMEOUT: Final[float] = 10.0

# Result counts and timings used when building WebSearchResponse instances.
ZERO_RESULTS: Final[int] = 0
ONE_RESULT: Final[int] = 1
ZERO_SEARCH_TIME: Final[float] = 0.0
SAMPLE_SEARCH_TIME: Final[float] = 100.0

# Upper bound asserted on the length of a derived search query.
MAX_QUERY_LENGTH: Final[int] = 256
|
|
|
|
|
|
class TestWebSearchResult:
    """Field storage, payload conversion and immutability of ``WebSearchResult``."""

    @staticmethod
    def _sample_result() -> WebSearchResult:
        """Build a result from the sample title, URL and snippet (no explicit score)."""
        return WebSearchResult(title=SAMPLE_TITLE, url=SAMPLE_URL, snippet=SAMPLE_SNIPPET)

    def test_result_stores_title(self) -> None:
        """The ``title`` passed at construction is readable back."""
        item = self._sample_result()
        assert item.title == SAMPLE_TITLE, "should store title"

    def test_result_stores_url(self) -> None:
        """The ``url`` passed at construction is readable back."""
        item = self._sample_result()
        assert item.url == SAMPLE_URL, "should store url"

    def test_result_stores_snippet(self) -> None:
        """The ``snippet`` passed at construction is readable back."""
        item = self._sample_result()
        assert item.snippet == SAMPLE_SNIPPET, "should store snippet"

    def test_result_has_default_score(self) -> None:
        """Omitting ``score`` yields ``DEFAULT_SCORE``."""
        item = self._sample_result()
        assert item.score == DEFAULT_SCORE, "should have default score"

    def test_result_accepts_custom_score(self) -> None:
        """An explicit ``score`` overrides the default."""
        scored = WebSearchResult(
            title=SAMPLE_TITLE,
            url=SAMPLE_URL,
            snippet=SAMPLE_SNIPPET,
            score=SAMPLE_SCORE,
        )
        assert scored.score == SAMPLE_SCORE, "should accept custom score"

    def test_to_result_payload_has_title(self) -> None:
        """``to_result_payload()`` exposes the title under ``"title"``."""
        as_payload = self._sample_result().to_result_payload()
        assert as_payload["title"] == SAMPLE_TITLE, "payload should have title"

    def test_to_result_payload_has_url(self) -> None:
        """``to_result_payload()`` exposes the URL under ``"url"``."""
        as_payload = self._sample_result().to_result_payload()
        assert as_payload["url"] == SAMPLE_URL, "payload should have url"

    def test_to_result_payload_has_snippet(self) -> None:
        """``to_result_payload()`` exposes the snippet under ``"snippet"``."""
        as_payload = self._sample_result().to_result_payload()
        assert as_payload["snippet"] == SAMPLE_SNIPPET, "payload should have snippet"

    def test_to_result_payload_has_score(self) -> None:
        """``to_result_payload()`` exposes the (default) score under ``"score"``."""
        as_payload = self._sample_result().to_result_payload()
        assert as_payload["score"] == DEFAULT_SCORE, "payload should have score"

    def test_web_search_result_is_frozen(self) -> None:
        """Assigning to a field raises ``AttributeError`` matching "cannot assign"."""
        item = self._sample_result()
        with pytest.raises(AttributeError, match="cannot assign"):
            item.title = "New Title"
|
|
|
|
|
|
class TestWebSearchResponse:
    """Field storage and the ``has_results`` flag of ``WebSearchResponse``."""

    @staticmethod
    def _empty_response(search_time_ms: float = ZERO_SEARCH_TIME) -> WebSearchResponse:
        """Build a response for the sample query that carries no results."""
        return WebSearchResponse(
            query=SAMPLE_QUERY,
            results=(),
            total_results=ZERO_RESULTS,
            search_time_ms=search_time_ms,
        )

    @staticmethod
    def _single_result_response() -> WebSearchResponse:
        """Build a response for the sample query that carries one sample result."""
        only_hit = WebSearchResult(title=SAMPLE_TITLE, url=SAMPLE_URL, snippet=SAMPLE_SNIPPET)
        return WebSearchResponse(
            query=SAMPLE_QUERY,
            results=(only_hit,),
            total_results=ONE_RESULT,
            search_time_ms=SAMPLE_SEARCH_TIME,
        )

    def test_response_stores_query(self) -> None:
        """The ``query`` passed at construction is readable back."""
        reply = self._empty_response()
        assert reply.query == SAMPLE_QUERY, "should store query"

    def test_response_stores_results(self) -> None:
        """A one-element ``results`` tuple is kept with its length intact."""
        reply = self._single_result_response()
        assert len(reply.results) == ONE_RESULT, "should store results"

    def test_response_stores_total_results(self) -> None:
        """The ``total_results`` passed at construction is readable back."""
        reply = self._empty_response()
        assert reply.total_results == ZERO_RESULTS, "should store total_results"

    def test_response_stores_search_time_ms(self) -> None:
        """The ``search_time_ms`` passed at construction is readable back."""
        reply = self._empty_response(search_time_ms=SAMPLE_SEARCH_TIME)
        assert reply.search_time_ms == SAMPLE_SEARCH_TIME, "should store search_time_ms"

    def test_has_results_true_with_results(self) -> None:
        """``has_results`` is ``True`` when at least one result is present."""
        reply = self._single_result_response()
        assert reply.has_results is True, "should have results"

    def test_has_results_false_without_results(self) -> None:
        """``has_results`` is ``False`` for an empty ``results`` tuple."""
        reply = self._empty_response()
        assert reply.has_results is False, "should not have results"
|
|
|
|
|
|
class TestWebSearchConfig:
    """Default and overridden field values of ``WebSearchConfig``."""

    def test_default_enabled_is_false(self) -> None:
        """A config built with no arguments has ``enabled`` set to ``False``."""
        defaults = WebSearchConfig()
        assert defaults.enabled is False, "default enabled should be False"

    def test_default_max_results(self) -> None:
        """A config built with no arguments uses ``DEFAULT_MAX_RESULTS``."""
        defaults = WebSearchConfig()
        assert defaults.max_results == DEFAULT_MAX_RESULTS, "default max_results"

    def test_default_timeout_seconds(self) -> None:
        """A config built with no arguments uses ``DEFAULT_TIMEOUT_SECONDS``."""
        defaults = WebSearchConfig()
        assert defaults.timeout_seconds == DEFAULT_TIMEOUT_SECONDS, "default timeout"

    def test_default_require_approval(self) -> None:
        """A config built with no arguments has ``require_approval`` set to ``True``."""
        defaults = WebSearchConfig()
        assert defaults.require_approval is True, "default require_approval"

    def test_custom_enabled(self) -> None:
        """``enabled=True`` is stored as given."""
        switched_on = WebSearchConfig(enabled=True)
        assert switched_on.enabled is True, "should accept enabled"

    def test_custom_max_results(self) -> None:
        """``max_results=10`` is stored as given."""
        widened = WebSearchConfig(max_results=10)
        assert widened.max_results == 10, "should accept max_results"
|
|
|
|
|
|
class TestDisabledWebSearchProvider:
    """Response returned by ``DisabledWebSearchProvider.search`` for the sample query."""

    @pytest.mark.asyncio
    async def test_search_returns_empty_response(self) -> None:
        """The response reports zero total results."""
        reply = await DisabledWebSearchProvider().search(SAMPLE_QUERY)
        assert reply.total_results == ZERO_RESULTS, "should return empty"

    @pytest.mark.asyncio
    async def test_search_returns_query_in_response(self) -> None:
        """The response echoes the query that was searched."""
        reply = await DisabledWebSearchProvider().search(SAMPLE_QUERY)
        assert reply.query == SAMPLE_QUERY, "should include query"

    @pytest.mark.asyncio
    async def test_search_returns_empty_results_tuple(self) -> None:
        """The response's ``results`` compares equal to an empty tuple."""
        reply = await DisabledWebSearchProvider().search(SAMPLE_QUERY)
        assert reply.results == (), "should have empty results"
|
|
|
|
|
|
class TestFormatResultsForContext:
    """Text produced by ``format_results_for_context`` for empty and single-result input."""

    @staticmethod
    def _format_single_sample() -> str:
        """Format a one-element tuple holding the sample result."""
        only_hit = WebSearchResult(title=SAMPLE_TITLE, url=SAMPLE_URL, snippet=SAMPLE_SNIPPET)
        return format_results_for_context((only_hit,))

    def test_empty_results_returns_empty_string(self) -> None:
        """An empty tuple formats to the empty string."""
        rendered = format_results_for_context(())
        assert rendered == "", "empty results should return empty string"

    def test_formats_single_result_with_header(self) -> None:
        """The formatted text contains the "Web Search Results" header."""
        rendered = self._format_single_sample()
        assert "Web Search Results" in rendered, "should have header"

    def test_formats_single_result_with_title(self) -> None:
        """The formatted text contains the result's title."""
        rendered = self._format_single_sample()
        assert SAMPLE_TITLE in rendered, "should include title"

    def test_formats_single_result_with_url(self) -> None:
        """The formatted text contains the result's URL."""
        rendered = self._format_single_sample()
        assert SAMPLE_URL in rendered, "should include url"

    def test_formats_single_result_with_snippet(self) -> None:
        """The formatted text contains the result's snippet."""
        rendered = self._format_single_sample()
        assert SAMPLE_SNIPPET in rendered, "should include snippet"
|
|
|
|
|
|
class TestExecuteWebSearch:
    """Behaviour of ``execute_web_search`` with the config disabled and enabled."""

    @pytest.mark.asyncio
    async def test_returns_empty_when_disabled(self) -> None:
        """With ``enabled=False`` the response reports zero total results."""
        disabled_config = WebSearchConfig(enabled=False)
        backend = DisabledWebSearchProvider()
        reply = await execute_web_search(SAMPLE_QUERY, backend, disabled_config)
        assert reply.total_results == ZERO_RESULTS, "should return empty when disabled"

    @pytest.mark.asyncio
    async def test_calls_provider_when_enabled(self) -> None:
        """With ``enabled=True`` the response carries the query that was passed in."""
        enabled_config = WebSearchConfig(enabled=True)
        backend = DisabledWebSearchProvider()
        reply = await execute_web_search(SAMPLE_QUERY, backend, enabled_config)
        assert reply.query == SAMPLE_QUERY, "should call provider"
|
|
|
|
|
|
class TestMergeContexts:
    """Output of ``merge_contexts`` for responses with and without web results."""

    @staticmethod
    def _merge_with_single_result() -> str:
        """Merge the transcript context with a response holding one sample result."""
        only_hit = WebSearchResult(title=SAMPLE_TITLE, url=SAMPLE_URL, snippet=SAMPLE_SNIPPET)
        populated = WebSearchResponse(
            query=SAMPLE_QUERY,
            results=(only_hit,),
            total_results=ONE_RESULT,
            search_time_ms=SAMPLE_SEARCH_TIME,
        )
        return merge_contexts(TRANSCRIPT_CONTEXT, populated)

    def test_returns_transcript_when_no_web_results(self) -> None:
        """With an empty response the merged context equals the transcript unchanged."""
        nothing_found = WebSearchResponse(
            query=SAMPLE_QUERY,
            results=(),
            total_results=ZERO_RESULTS,
            search_time_ms=ZERO_SEARCH_TIME,
        )
        combined = merge_contexts(TRANSCRIPT_CONTEXT, nothing_found)
        assert combined == TRANSCRIPT_CONTEXT, "should return transcript only"

    def test_includes_transcript_with_web_results(self) -> None:
        """With one web result the merged context still contains the transcript."""
        combined = self._merge_with_single_result()
        assert TRANSCRIPT_CONTEXT in combined, "should include transcript"

    def test_includes_web_results_when_present(self) -> None:
        """With one web result the merged context contains that result's title."""
        combined = self._merge_with_single_result()
        assert SAMPLE_TITLE in combined, "should include web results"
|
|
|
|
|
|
class TestDeriveSearchQuery:
    """Whitespace stripping, context appending and truncation in ``derive_search_query``."""

    def test_returns_question_stripped(self) -> None:
        """Leading and trailing whitespace around the question is removed."""
        derived = derive_search_query(" What is AI? ")
        assert derived == "What is AI?", "should strip whitespace"

    def test_appends_meeting_context(self) -> None:
        """The ``meeting_context`` text appears in the derived query."""
        derived = derive_search_query("What is AI?", meeting_context="AI discussion")
        assert "AI discussion" in derived, "should append context"

    def test_truncates_long_query(self) -> None:
        """A 300-character question yields a query no longer than ``MAX_QUERY_LENGTH``."""
        oversized_question = "x" * 300
        derived = derive_search_query(oversized_question)
        assert len(derived) <= MAX_QUERY_LENGTH, "should truncate long query"
|
|
|
|
|
|
class TestWebSearchConstants:
    """Pins the values of the defaults imported from the web_search module."""

    def test_default_max_results_value(self) -> None:
        """``DEFAULT_MAX_RESULTS`` equals ``EXPECTED_MAX_RESULTS``."""
        actual = DEFAULT_MAX_RESULTS
        assert actual == EXPECTED_MAX_RESULTS, "constant should match"

    def test_default_timeout_seconds_value(self) -> None:
        """``DEFAULT_TIMEOUT_SECONDS`` equals ``EXPECTED_TIMEOUT``."""
        actual = DEFAULT_TIMEOUT_SECONDS
        assert actual == EXPECTED_TIMEOUT, "constant should match"
|