recovery
This commit is contained in:
@@ -0,0 +1,590 @@
|
|||||||
|
# Sprint 25: LangGraph Foundation
|
||||||
|
|
||||||
|
> **Size**: L | **Owner**: Backend | **Phase**: 5 - Platform Evolution
|
||||||
|
> **Effort**: ~1 sprint | **Prerequisites**: Sprint 19 (Embeddings)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Objective
|
||||||
|
|
||||||
|
Establish LangGraph infrastructure and wrap existing summarization as proof of pattern.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Current State Analysis
|
||||||
|
|
||||||
|
### What Exists
|
||||||
|
|
||||||
|
| Component | Location | Status |
|
||||||
|
|-----------|----------|--------|
|
||||||
|
| Summarization service | `application/services/summarization/` | ✅ Working |
|
||||||
|
| Segment semantic search | `infrastructure/persistence/repositories/segment_repo.py` | ✅ Working |
|
||||||
|
| Cloud consent pattern | `application/services/summarization/_consent_manager.py` | ✅ Working |
|
||||||
|
| Usage event tracking | `application/observability/ports.py` | ✅ Working |
|
||||||
|
| gRPC mixin pattern | `grpc/_mixins/` | ✅ Working |
|
||||||
|
|
||||||
|
### What's Missing
|
||||||
|
|
||||||
|
| Component | Target Location | Sprint Task |
|
||||||
|
|-----------|-----------------|-------------|
|
||||||
|
| LangGraph dependencies | `pyproject.toml` | Task 1 |
|
||||||
|
| State schemas | `domain/ai/state.py` | Task 2 |
|
||||||
|
| Checkpointer factory | `infrastructure/ai/checkpointer.py` | Task 3 |
|
||||||
|
| Retrieval tools | `infrastructure/ai/tools/retrieval.py` | Task 4 |
|
||||||
|
| Synthesis tools | `infrastructure/ai/tools/synthesis.py` | Task 5 |
|
||||||
|
| Summarization graph | `infrastructure/ai/graphs/summarization.py` | Task 6 |
|
||||||
|
| AssistantService | `application/services/assistant/` | Task 7 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Implementation Tasks
|
||||||
|
|
||||||
|
### Task 1: Add LangGraph Dependencies
|
||||||
|
|
||||||
|
**File**: `pyproject.toml`
|
||||||
|
|
||||||
|
Add new optional dependency group:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
langgraph = [
|
||||||
|
"langgraph>=0.2",
|
||||||
|
"langgraph-checkpoint-postgres>=2.0",
|
||||||
|
"langchain-core>=0.3",
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Also add to `optional` and `all` groups.
|
||||||
|
|
||||||
|
**Verification**: `uv pip install -e ".[langgraph]"` succeeds
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Create State Schemas
|
||||||
|
|
||||||
|
**Files to create**:
|
||||||
|
- `src/noteflow/domain/ai/__init__.py`
|
||||||
|
- `src/noteflow/domain/ai/state.py`
|
||||||
|
- `src/noteflow/domain/ai/citations.py`
|
||||||
|
- `src/noteflow/domain/ai/ports.py`
|
||||||
|
|
||||||
|
#### `state.py` - Input/Output/Internal Separation
|
||||||
|
|
||||||
|
```python
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated, TypedDict
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import operator
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantInputState(TypedDict):
|
||||||
|
"""Public API input - what clients send."""
|
||||||
|
question: str
|
||||||
|
meeting_id: UUID | None
|
||||||
|
thread_id: str | None
|
||||||
|
allow_web: bool
|
||||||
|
top_k: int
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantOutputState(TypedDict):
|
||||||
|
"""Public API output - what clients receive."""
|
||||||
|
answer: str
|
||||||
|
citations: list[dict]
|
||||||
|
suggested_annotations: list[dict]
|
||||||
|
thread_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantInternalState(AssistantInputState):
|
||||||
|
"""Internal graph state - can evolve without breaking API."""
|
||||||
|
# Retrieval
|
||||||
|
retrieved_segment_ids: Annotated[list[int], operator.add]
|
||||||
|
retrieved_segments: list[dict]
|
||||||
|
|
||||||
|
# Synthesis
|
||||||
|
draft_answer: str
|
||||||
|
verification_passed: bool
|
||||||
|
|
||||||
|
# Tracking
|
||||||
|
loop_count: int
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `citations.py` - Value Object
|
||||||
|
|
||||||
|
```python
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SegmentCitation:
|
||||||
|
"""Reference to transcript segment used as evidence."""
|
||||||
|
meeting_id: UUID
|
||||||
|
segment_id: int
|
||||||
|
start_time: float
|
||||||
|
end_time: float
|
||||||
|
text: str
|
||||||
|
score: float = 0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `ports.py` - Protocol Definition
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Protocol
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from noteflow.domain.ai.state import AssistantOutputState
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantPort(Protocol):
|
||||||
|
"""Protocol for AI assistant operations."""
|
||||||
|
|
||||||
|
async def ask(
|
||||||
|
self,
|
||||||
|
question: str,
|
||||||
|
meeting_id: UUID | None = None,
|
||||||
|
thread_id: str | None = None,
|
||||||
|
allow_web: bool = False,
|
||||||
|
top_k: int = 8,
|
||||||
|
) -> AssistantOutputState:
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
**Verification**: `basedpyright src/noteflow/domain/ai/` passes
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: Create Checkpointer Factory
|
||||||
|
|
||||||
|
**File**: `src/noteflow/infrastructure/ai/checkpointer.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Final
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
CHECKPOINTER_POOL_SIZE: Final[int] = 5
|
||||||
|
|
||||||
|
|
||||||
|
async def create_checkpointer(
|
||||||
|
database_url: str,
|
||||||
|
pool_size: int = CHECKPOINTER_POOL_SIZE,
|
||||||
|
) -> AsyncPostgresSaver:
|
||||||
|
"""Create async Postgres checkpointer for LangGraph state persistence."""
|
||||||
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
pool = AsyncConnectionPool(
|
||||||
|
conninfo=database_url,
|
||||||
|
max_size=pool_size,
|
||||||
|
kwargs={"autocommit": True},
|
||||||
|
)
|
||||||
|
checkpointer = AsyncPostgresSaver(pool)
|
||||||
|
await checkpointer.setup()
|
||||||
|
return checkpointer
|
||||||
|
```
|
||||||
|
|
||||||
|
**Verification**: Unit test with mock pool
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: Create Retrieval Tools
|
||||||
|
|
||||||
|
**File**: `src/noteflow/infrastructure/ai/tools/retrieval.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from noteflow.domain.entities import Segment
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedderProtocol(Protocol):
|
||||||
|
"""Protocol for text embedding."""
|
||||||
|
async def embed(self, text: str) -> list[float]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentSearchProtocol(Protocol):
|
||||||
|
"""Protocol for semantic segment search."""
|
||||||
|
async def search_semantic(
|
||||||
|
self,
|
||||||
|
query_embedding: list[float],
|
||||||
|
meeting_id: UUID | None,
|
||||||
|
limit: int,
|
||||||
|
) -> list[tuple[Segment, float]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetrievalResult:
|
||||||
|
"""Result from segment retrieval."""
|
||||||
|
segment_id: int
|
||||||
|
meeting_id: UUID
|
||||||
|
text: str
|
||||||
|
start_time: float
|
||||||
|
end_time: float
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
async def retrieve_segments(
|
||||||
|
query: str,
|
||||||
|
embedder: EmbedderProtocol,
|
||||||
|
segment_repo: SegmentSearchProtocol,
|
||||||
|
meeting_id: UUID | None = None,
|
||||||
|
top_k: int = 8,
|
||||||
|
) -> list[RetrievalResult]:
|
||||||
|
"""Retrieve relevant transcript segments via semantic search."""
|
||||||
|
query_embedding = await embedder.embed(query)
|
||||||
|
results = await segment_repo.search_semantic(
|
||||||
|
query_embedding=query_embedding,
|
||||||
|
meeting_id=meeting_id,
|
||||||
|
limit=top_k,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
RetrievalResult(
|
||||||
|
segment_id=segment.segment_id,
|
||||||
|
meeting_id=segment.meeting_id,
|
||||||
|
text=segment.text,
|
||||||
|
start_time=segment.start_time,
|
||||||
|
end_time=segment.end_time,
|
||||||
|
score=score,
|
||||||
|
)
|
||||||
|
for segment, score in results
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Verification**: Unit test with mock embedder and repo
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: Create Synthesis Tools
|
||||||
|
|
||||||
|
**File**: `src/noteflow/infrastructure/ai/tools/synthesis.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from noteflow.infrastructure.ai.tools.retrieval import RetrievalResult
|
||||||
|
|
||||||
|
|
||||||
|
class LLMProtocol(Protocol):
|
||||||
|
"""Protocol for LLM completion."""
|
||||||
|
async def complete(self, prompt: str) -> str: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SynthesisResult:
|
||||||
|
"""Result from answer synthesis."""
|
||||||
|
answer: str
|
||||||
|
cited_segment_ids: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
SYNTHESIS_PROMPT_TEMPLATE = '''Answer the question based on the following transcript segments.
|
||||||
|
Cite specific segments by their ID when making claims.
|
||||||
|
|
||||||
|
Question: {question}
|
||||||
|
|
||||||
|
Segments:
|
||||||
|
{segments}
|
||||||
|
|
||||||
|
Answer (cite segment IDs in brackets like [1], [3]):'''
|
||||||
|
|
||||||
|
|
||||||
|
async def synthesize_answer(
|
||||||
|
question: str,
|
||||||
|
segments: list[RetrievalResult],
|
||||||
|
llm: LLMProtocol,
|
||||||
|
) -> SynthesisResult:
|
||||||
|
"""Generate answer with segment citations."""
|
||||||
|
segment_text = "\n".join(
|
||||||
|
f"[{s.segment_id}] ({s.start_time:.1f}s-{s.end_time:.1f}s): {s.text}"
|
||||||
|
for s in segments
|
||||||
|
)
|
||||||
|
prompt = SYNTHESIS_PROMPT_TEMPLATE.format(
|
||||||
|
question=question,
|
||||||
|
segments=segment_text,
|
||||||
|
)
|
||||||
|
answer = await llm.complete(prompt)
|
||||||
|
cited_ids = _extract_cited_ids(answer, [s.segment_id for s in segments])
|
||||||
|
return SynthesisResult(answer=answer, cited_segment_ids=cited_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_cited_ids(answer: str, valid_ids: list[int]) -> list[int]:
|
||||||
|
"""Extract segment IDs cited in the answer."""
|
||||||
|
import re
|
||||||
|
pattern = r'\[(\d+)\]'
|
||||||
|
matches = re.findall(pattern, answer)
|
||||||
|
cited = [int(m) for m in matches if int(m) in valid_ids]
|
||||||
|
return list(dict.fromkeys(cited)) # Dedupe preserving order
|
||||||
|
```
|
||||||
|
|
||||||
|
**Verification**: Unit test with mock LLM
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 6: Create Summarization Graph Wrapper
|
||||||
|
|
||||||
|
**File**: `src/noteflow/infrastructure/ai/graphs/summarization.py`
|
||||||
|
|
||||||
|
This wraps the existing SummarizationService in a LangGraph graph as proof of pattern.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, TypedDict
|
||||||
|
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from uuid import UUID
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from noteflow.domain.entities import Segment
|
||||||
|
from noteflow.application.services.summarization import SummarizationService
|
||||||
|
|
||||||
|
|
||||||
|
class SummarizationState(TypedDict):
|
||||||
|
"""State for summarization graph."""
|
||||||
|
meeting_id: UUID
|
||||||
|
segments: Sequence[Segment]
|
||||||
|
summary_text: str
|
||||||
|
key_points: list[dict]
|
||||||
|
action_items: list[dict]
|
||||||
|
|
||||||
|
|
||||||
|
def build_summarization_graph(
|
||||||
|
summarization_service: SummarizationService,
|
||||||
|
) -> StateGraph:
|
||||||
|
"""Build LangGraph wrapper around existing summarization service."""
|
||||||
|
|
||||||
|
async def summarize_node(state: SummarizationState) -> dict:
|
||||||
|
result = await summarization_service.summarize(
|
||||||
|
meeting_id=state["meeting_id"],
|
||||||
|
segments=state["segments"],
|
||||||
|
)
|
||||||
|
summary = result.summary
|
||||||
|
return {
|
||||||
|
"summary_text": summary.executive_summary,
|
||||||
|
"key_points": [
|
||||||
|
{"text": kp.text, "segment_ids": kp.segment_ids}
|
||||||
|
for kp in summary.key_points
|
||||||
|
],
|
||||||
|
"action_items": [
|
||||||
|
{"text": ai.text, "segment_ids": ai.segment_ids, "assignee": ai.assignee}
|
||||||
|
for ai in summary.action_items
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
builder = StateGraph(SummarizationState)
|
||||||
|
builder.add_node("summarize", summarize_node)
|
||||||
|
builder.add_edge(START, "summarize")
|
||||||
|
builder.add_edge("summarize", END)
|
||||||
|
|
||||||
|
return builder.compile()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Verification**: Integration test with mock summarization service
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 7: Create AssistantService Shell
|
||||||
|
|
||||||
|
**Files**:
|
||||||
|
- `src/noteflow/application/services/assistant/__init__.py`
|
||||||
|
- `src/noteflow/application/services/assistant/assistant_service.py`
|
||||||
|
|
||||||
|
#### `__init__.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from noteflow.application.services.assistant.assistant_service import AssistantService
|
||||||
|
|
||||||
|
__all__ = ["AssistantService"]
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `assistant_service.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Final
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from noteflow.application.observability.ports import NullUsageEventSink, UsageEventSink
|
||||||
|
from noteflow.domain.ai.state import AssistantOutputState
|
||||||
|
from noteflow.infrastructure.logging import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Callable
|
||||||
|
from noteflow.domain.ports.unit_of_work import UnitOfWork
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_TOP_K: Final[int] = 8
|
||||||
|
THREAD_ID_PREFIX: Final[str] = "meeting"
|
||||||
|
|
||||||
|
|
||||||
|
def build_thread_id(meeting_id: UUID | None, user_id: UUID, graph_name: str) -> str:
|
||||||
|
"""Build deterministic thread_id for checkpointing."""
|
||||||
|
meeting_part = str(meeting_id) if meeting_id else "workspace"
|
||||||
|
return f"{THREAD_ID_PREFIX}:{meeting_part}:user:{user_id}:graph:{graph_name}:v1"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AssistantService:
|
||||||
|
"""Orchestrates AI assistant workflows via LangGraph."""
|
||||||
|
|
||||||
|
uow_factory: Callable[[], UnitOfWork]
|
||||||
|
usage_events: UsageEventSink = field(default_factory=NullUsageEventSink)
|
||||||
|
|
||||||
|
async def ask(
|
||||||
|
self,
|
||||||
|
question: str,
|
||||||
|
user_id: UUID,
|
||||||
|
meeting_id: UUID | None = None,
|
||||||
|
thread_id: str | None = None,
|
||||||
|
allow_web: bool = False,
|
||||||
|
top_k: int = DEFAULT_TOP_K,
|
||||||
|
) -> AssistantOutputState:
|
||||||
|
"""Ask a question about meeting transcript(s)."""
|
||||||
|
effective_thread_id = thread_id or build_thread_id(
|
||||||
|
meeting_id, user_id, "meeting_qa"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"assistant_ask",
|
||||||
|
question_length=len(question),
|
||||||
|
meeting_id=str(meeting_id) if meeting_id else None,
|
||||||
|
thread_id=effective_thread_id,
|
||||||
|
)
|
||||||
|
# TODO: Implement in Sprint 26
|
||||||
|
return AssistantOutputState(
|
||||||
|
answer="Not implemented yet",
|
||||||
|
citations=[],
|
||||||
|
suggested_annotations=[],
|
||||||
|
thread_id=effective_thread_id,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Verification**: Unit test for thread_id generation
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Test Plan
|
||||||
|
|
||||||
|
### Unit Tests (`tests/domain/ai/`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# test_citations.py
|
||||||
|
def test_segment_citation_creation():
|
||||||
|
citation = SegmentCitation(
|
||||||
|
meeting_id=uuid4(),
|
||||||
|
segment_id=1,
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=5.0,
|
||||||
|
text="Test segment",
|
||||||
|
score=0.95,
|
||||||
|
)
|
||||||
|
assert citation.duration == 5.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_segment_citation_invalid_times():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
SegmentCitation(
|
||||||
|
meeting_id=uuid4(),
|
||||||
|
segment_id=1,
|
||||||
|
start_time=10.0,
|
||||||
|
end_time=5.0, # Invalid: end < start
|
||||||
|
text="Test",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Unit Tests (`tests/infrastructure/ai/`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# test_retrieval.py
|
||||||
|
async def test_retrieve_segments_success(mock_embedder, mock_segment_repo):
|
||||||
|
mock_embedder.embed.return_value = [0.1, 0.2, 0.3]
|
||||||
|
mock_segment_repo.search_semantic.return_value = [
|
||||||
|
(sample_segment, 0.95),
|
||||||
|
]
|
||||||
|
|
||||||
|
results = await retrieve_segments(
|
||||||
|
query="test query",
|
||||||
|
embedder=mock_embedder,
|
||||||
|
segment_repo=mock_segment_repo,
|
||||||
|
top_k=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].score == 0.95
|
||||||
|
|
||||||
|
|
||||||
|
# test_synthesis.py
|
||||||
|
async def test_synthesize_answer_extracts_citations(mock_llm):
|
||||||
|
mock_llm.complete.return_value = "The answer is X [1] and Y [3]."
|
||||||
|
|
||||||
|
result = await synthesize_answer(
|
||||||
|
question="What happened?",
|
||||||
|
segments=[...],
|
||||||
|
llm=mock_llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.cited_segment_ids == [1, 3]
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Acceptance Criteria
|
||||||
|
|
||||||
|
- [ ] `uv pip install -e ".[langgraph]"` succeeds
|
||||||
|
- [ ] `basedpyright src/noteflow/domain/ai/` passes with 0 errors
|
||||||
|
- [ ] `basedpyright src/noteflow/infrastructure/ai/` passes with 0 errors
|
||||||
|
- [ ] `pytest tests/domain/ai/` passes
|
||||||
|
- [ ] `pytest tests/infrastructure/ai/` passes
|
||||||
|
- [ ] Existing summarization behavior unchanged (`pytest tests/application/services/test_summarization*`)
|
||||||
|
- [ ] `make quality` passes
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Rollback Plan
|
||||||
|
|
||||||
|
If issues arise:
|
||||||
|
1. Remove langgraph from dependencies
|
||||||
|
2. Delete `domain/ai/` and `infrastructure/ai/` directories
|
||||||
|
3. AssistantService is not wired to gRPC yet, so no API impact
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Files Created/Modified
|
||||||
|
|
||||||
|
| Action | Path |
|
||||||
|
|--------|------|
|
||||||
|
| Modified | `pyproject.toml` |
|
||||||
|
| Created | `src/noteflow/domain/ai/__init__.py` |
|
||||||
|
| Created | `src/noteflow/domain/ai/state.py` |
|
||||||
|
| Created | `src/noteflow/domain/ai/citations.py` |
|
||||||
|
| Created | `src/noteflow/domain/ai/ports.py` |
|
||||||
|
| Created | `src/noteflow/infrastructure/ai/__init__.py` |
|
||||||
|
| Created | `src/noteflow/infrastructure/ai/checkpointer.py` |
|
||||||
|
| Created | `src/noteflow/infrastructure/ai/tools/__init__.py` |
|
||||||
|
| Created | `src/noteflow/infrastructure/ai/tools/retrieval.py` |
|
||||||
|
| Created | `src/noteflow/infrastructure/ai/tools/synthesis.py` |
|
||||||
|
| Created | `src/noteflow/infrastructure/ai/graphs/__init__.py` |
|
||||||
|
| Created | `src/noteflow/infrastructure/ai/graphs/summarization.py` |
|
||||||
|
| Created | `src/noteflow/application/services/assistant/__init__.py` |
|
||||||
|
| Created | `src/noteflow/application/services/assistant/assistant_service.py` |
|
||||||
|
| Created | `tests/domain/ai/test_citations.py` |
|
||||||
|
| Created | `tests/domain/ai/test_state.py` |
|
||||||
|
| Created | `tests/infrastructure/ai/test_retrieval.py` |
|
||||||
|
| Created | `tests/infrastructure/ai/test_synthesis.py` |
|
||||||
|
| Created | `tests/infrastructure/ai/test_checkpointer.py` |
|
||||||
@@ -0,0 +1,530 @@
|
|||||||
|
# Sprint 26: Meeting Q&A MVP
|
||||||
|
|
||||||
|
> **Size**: L | **Owner**: Backend + Client | **Phase**: 5 - Platform Evolution
|
||||||
|
> **Effort**: ~1 sprint | **Prerequisites**: Sprint 25 (Foundation)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Objective
|
||||||
|
|
||||||
|
Implement single-meeting Q&A with segment citations via gRPC API and React UI.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Current State (After Sprint 25)
|
||||||
|
|
||||||
|
| Component | Status |
|
||||||
|
|-----------|--------|
|
||||||
|
| LangGraph infrastructure | ✅ Ready |
|
||||||
|
| State schemas | ✅ Ready |
|
||||||
|
| Retrieval tools | ✅ Ready |
|
||||||
|
| Synthesis tools | ✅ Ready |
|
||||||
|
| AssistantService shell | ✅ Ready |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Implementation Tasks
|
||||||
|
|
||||||
|
### Task 1: Define MeetingQA Graph
|
||||||
|
|
||||||
|
**File**: `src/noteflow/infrastructure/ai/graphs/meeting_qa.py`
|
||||||
|
|
||||||
|
Graph flow: `retrieve → verify → synthesize`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langgraph.graph import StateGraph, START, END
|
||||||
|
|
||||||
|
class MeetingQAState(TypedDict):
|
||||||
|
# Input
|
||||||
|
question: str
|
||||||
|
meeting_id: UUID
|
||||||
|
top_k: int
|
||||||
|
|
||||||
|
# Internal
|
||||||
|
retrieved_segments: list[RetrievalResult]
|
||||||
|
verification_passed: bool
|
||||||
|
|
||||||
|
# Output
|
||||||
|
answer: str
|
||||||
|
citations: list[SegmentCitation]
|
||||||
|
|
||||||
|
|
||||||
|
def build_meeting_qa_graph(
|
||||||
|
embedder: EmbedderProtocol,
|
||||||
|
segment_repo: SegmentSearchProtocol,
|
||||||
|
llm: LLMProtocol,
|
||||||
|
verifier: CitationVerifier,
|
||||||
|
) -> StateGraph:
|
||||||
|
|
||||||
|
async def retrieve_node(state: MeetingQAState) -> dict:
|
||||||
|
results = await retrieve_segments(
|
||||||
|
query=state["question"],
|
||||||
|
embedder=embedder,
|
||||||
|
segment_repo=segment_repo,
|
||||||
|
meeting_id=state["meeting_id"],
|
||||||
|
top_k=state["top_k"],
|
||||||
|
)
|
||||||
|
return {"retrieved_segments": results}
|
||||||
|
|
||||||
|
async def verify_node(state: MeetingQAState) -> dict:
|
||||||
|
# Verify segments exist and are relevant
|
||||||
|
valid = len(state["retrieved_segments"]) > 0
|
||||||
|
return {"verification_passed": valid}
|
||||||
|
|
||||||
|
async def synthesize_node(state: MeetingQAState) -> dict:
|
||||||
|
if not state["verification_passed"]:
|
||||||
|
return {
|
||||||
|
"answer": "I couldn't find relevant information in this meeting.",
|
||||||
|
"citations": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await synthesize_answer(
|
||||||
|
question=state["question"],
|
||||||
|
segments=state["retrieved_segments"],
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
citations = [
|
||||||
|
SegmentCitation(
|
||||||
|
meeting_id=state["meeting_id"],
|
||||||
|
segment_id=seg.segment_id,
|
||||||
|
start_time=seg.start_time,
|
||||||
|
end_time=seg.end_time,
|
||||||
|
text=seg.text,
|
||||||
|
score=seg.score,
|
||||||
|
)
|
||||||
|
for seg in state["retrieved_segments"]
|
||||||
|
if seg.segment_id in result.cited_segment_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
return {"answer": result.answer, "citations": citations}
|
||||||
|
|
||||||
|
builder = StateGraph(MeetingQAState)
|
||||||
|
builder.add_node("retrieve", retrieve_node)
|
||||||
|
builder.add_node("verify", verify_node)
|
||||||
|
builder.add_node("synthesize", synthesize_node)
|
||||||
|
|
||||||
|
builder.add_edge(START, "retrieve")
|
||||||
|
builder.add_edge("retrieve", "verify")
|
||||||
|
builder.add_edge("verify", "synthesize")
|
||||||
|
builder.add_edge("synthesize", END)
|
||||||
|
|
||||||
|
return builder.compile()
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Create Citation Verifier Node
|
||||||
|
|
||||||
|
**File**: `src/noteflow/infrastructure/ai/nodes/verification.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VerificationResult:
|
||||||
|
is_valid: bool
|
||||||
|
invalid_citation_indices: list[int]
|
||||||
|
reason: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def verify_citations(
|
||||||
|
answer: str,
|
||||||
|
cited_ids: list[int],
|
||||||
|
available_ids: set[int],
|
||||||
|
) -> VerificationResult:
|
||||||
|
"""Verify all cited segment IDs exist in available segments."""
|
||||||
|
invalid = [i for i, cid in enumerate(cited_ids) if cid not in available_ids]
|
||||||
|
return VerificationResult(
|
||||||
|
is_valid=len(invalid) == 0,
|
||||||
|
invalid_citation_indices=invalid,
|
||||||
|
reason=f"Invalid citations: {invalid}" if invalid else None,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: Add Proto Messages
|
||||||
|
|
||||||
|
**File**: `src/noteflow/grpc/proto/noteflow.proto`
|
||||||
|
|
||||||
|
```protobuf
|
||||||
|
// Add to existing proto
|
||||||
|
|
||||||
|
message SegmentCitation {
|
||||||
|
string meeting_id = 1;
|
||||||
|
int32 segment_id = 2;
|
||||||
|
float start_time = 3;
|
||||||
|
float end_time = 4;
|
||||||
|
string text = 5;
|
||||||
|
float score = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
message AskAssistantRequest {
|
||||||
|
string question = 1;
|
||||||
|
optional string meeting_id = 2;
|
||||||
|
optional string thread_id = 3;
|
||||||
|
bool allow_web = 4;
|
||||||
|
int32 top_k = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message AskAssistantResponse {
|
||||||
|
string answer = 1;
|
||||||
|
repeated SegmentCitation citations = 2;
|
||||||
|
repeated SuggestedAnnotation suggested_annotations = 3;
|
||||||
|
string thread_id = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message SuggestedAnnotation {
|
||||||
|
string text = 1;
|
||||||
|
AnnotationType type = 2;
|
||||||
|
repeated int32 segment_ids = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to NoteFlowService
|
||||||
|
rpc AskAssistant(AskAssistantRequest) returns (AskAssistantResponse);
|
||||||
|
```
|
||||||
|
|
||||||
|
After modifying proto:
|
||||||
|
```bash
|
||||||
|
python -m grpc_tools.protoc -I src/noteflow/grpc/proto \
|
||||||
|
--python_out=src/noteflow/grpc/proto \
|
||||||
|
--grpc_python_out=src/noteflow/grpc/proto \
|
||||||
|
src/noteflow/grpc/proto/noteflow.proto
|
||||||
|
|
||||||
|
python scripts/patch_grpc_stubs.py
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: Add gRPC Mixin
|
||||||
|
|
||||||
|
**File**: `src/noteflow/grpc/_mixins/assistant.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from noteflow.grpc.proto import noteflow_pb2 as pb
|
||||||
|
from noteflow.grpc._mixins.protocols import ServicerHost
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from grpc.aio import ServicerContext
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantMixin:
|
||||||
|
"""gRPC mixin for AI assistant operations."""
|
||||||
|
|
||||||
|
async def AskAssistant(
|
||||||
|
self: ServicerHost,
|
||||||
|
request: pb.AskAssistantRequest,
|
||||||
|
context: ServicerContext,
|
||||||
|
) -> pb.AskAssistantResponse:
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
meeting_id = UUID(request.meeting_id) if request.meeting_id else None
|
||||||
|
op_context = await self.get_operation_context(context)
|
||||||
|
|
||||||
|
result = await self.assistant_service.ask(
|
||||||
|
question=request.question,
|
||||||
|
user_id=op_context.user.id,
|
||||||
|
meeting_id=meeting_id,
|
||||||
|
thread_id=request.thread_id or None,
|
||||||
|
allow_web=request.allow_web,
|
||||||
|
top_k=request.top_k or 8,
|
||||||
|
)
|
||||||
|
|
||||||
|
return pb.AskAssistantResponse(
|
||||||
|
answer=result["answer"],
|
||||||
|
citations=[
|
||||||
|
pb.SegmentCitation(
|
||||||
|
meeting_id=str(c["meeting_id"]),
|
||||||
|
segment_id=c["segment_id"],
|
||||||
|
start_time=c["start_time"],
|
||||||
|
end_time=c["end_time"],
|
||||||
|
text=c["text"],
|
||||||
|
score=c["score"],
|
||||||
|
)
|
||||||
|
for c in result["citations"]
|
||||||
|
],
|
||||||
|
thread_id=result["thread_id"],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: Add Rust Command
|
||||||
|
|
||||||
|
**File**: `client/src-tauri/src/commands/assistant.rs`
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use crate::grpc::client::GrpcClient;
|
||||||
|
use crate::grpc::types::assistant::{AskAssistantRequest, AskAssistantResponse};
|
||||||
|
use tauri::State;
|
||||||
|
|
||||||
|
#[tauri::command]
|
||||||
|
pub async fn ask_assistant(
|
||||||
|
client: State<'_, GrpcClient>,
|
||||||
|
question: String,
|
||||||
|
meeting_id: Option<String>,
|
||||||
|
thread_id: Option<String>,
|
||||||
|
allow_web: bool,
|
||||||
|
top_k: i32,
|
||||||
|
) -> Result<AskAssistantResponse, String> {
|
||||||
|
let request = AskAssistantRequest {
|
||||||
|
question,
|
||||||
|
meeting_id,
|
||||||
|
thread_id,
|
||||||
|
allow_web,
|
||||||
|
top_k,
|
||||||
|
};
|
||||||
|
|
||||||
|
client
|
||||||
|
.ask_assistant(request)
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.to_string())
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 6: Add TypeScript Adapter
|
||||||
|
|
||||||
|
**File**: `client/src/api/tauri-adapter.ts` (add method)
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
async askAssistant(params: AskAssistantParams): Promise<AskAssistantResponse> {
|
||||||
|
return invoke('ask_assistant', {
|
||||||
|
question: params.question,
|
||||||
|
meetingId: params.meetingId,
|
||||||
|
threadId: params.threadId,
|
||||||
|
allowWeb: params.allowWeb ?? false,
|
||||||
|
topK: params.topK ?? 8,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**File**: `client/src/api/types/assistant.ts`
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export interface SegmentCitation {
|
||||||
|
meetingId: string;
|
||||||
|
segmentId: number;
|
||||||
|
startTime: number;
|
||||||
|
endTime: number;
|
||||||
|
text: string;
|
||||||
|
score: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AskAssistantParams {
|
||||||
|
question: string;
|
||||||
|
meetingId?: string;
|
||||||
|
threadId?: string;
|
||||||
|
allowWeb?: boolean;
|
||||||
|
topK?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AskAssistantResponse {
|
||||||
|
answer: string;
|
||||||
|
citations: SegmentCitation[];
|
||||||
|
suggestedAnnotations: SuggestedAnnotation[];
|
||||||
|
threadId: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SuggestedAnnotation {
|
||||||
|
text: string;
|
||||||
|
type: AnnotationType;
|
||||||
|
segmentIds: number[];
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 7: Create Ask UI Component
|
||||||
|
|
||||||
|
**File**: `client/src/components/meeting/AskPanel.tsx`
|
||||||
|
|
||||||
|
```tsx
|
||||||
|
import { useState } from 'react';
|
||||||
|
import { useAssistant } from '@/hooks/use-assistant';
|
||||||
|
import { Button } from '@/components/ui/button';
|
||||||
|
import { Textarea } from '@/components/ui/textarea';
|
||||||
|
import { Card } from '@/components/ui/card';
|
||||||
|
|
||||||
|
interface AskPanelProps {
|
||||||
|
meetingId: string;
|
||||||
|
onCitationClick?: (segmentId: number) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function AskPanel({ meetingId, onCitationClick }: AskPanelProps) {
|
||||||
|
const [question, setQuestion] = useState('');
|
||||||
|
const { ask, isLoading, response, error } = useAssistant();
|
||||||
|
|
||||||
|
const handleAsk = async () => {
|
||||||
|
if (!question.trim()) return;
|
||||||
|
await ask({ question, meetingId });
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col gap-4 p-4">
|
||||||
|
<Textarea
|
||||||
|
placeholder="Ask a question about this meeting..."
|
||||||
|
value={question}
|
||||||
|
onChange={(e) => setQuestion(e.target.value)}
|
||||||
|
disabled={isLoading}
|
||||||
|
/>
|
||||||
|
<Button onClick={handleAsk} disabled={isLoading || !question.trim()}>
|
||||||
|
{isLoading ? 'Thinking...' : 'Ask'}
|
||||||
|
</Button>
|
||||||
|
|
||||||
|
{response && (
|
||||||
|
<Card className="p-4">
|
||||||
|
<p className="whitespace-pre-wrap">{response.answer}</p>
|
||||||
|
{response.citations.length > 0 && (
|
||||||
|
<div className="mt-4 border-t pt-2">
|
||||||
|
<p className="text-sm text-muted-foreground mb-2">Sources:</p>
|
||||||
|
{response.citations.map((citation) => (
|
||||||
|
<button
|
||||||
|
key={citation.segmentId}
|
||||||
|
onClick={() => onCitationClick?.(citation.segmentId)}
|
||||||
|
className="text-sm text-blue-600 hover:underline block"
|
||||||
|
>
|
||||||
|
[{citation.startTime.toFixed(1)}s] {citation.text.slice(0, 50)}...
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{error && (
|
||||||
|
<p className="text-sm text-red-600">{error}</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**File**: `client/src/hooks/use-assistant.ts`
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { useState, useCallback } from 'react';
|
||||||
|
import { api } from '@/api';
|
||||||
|
import type { AskAssistantParams, AskAssistantResponse } from '@/api/types/assistant';
|
||||||
|
|
||||||
|
export function useAssistant() {
|
||||||
|
const [isLoading, setIsLoading] = useState(false);
|
||||||
|
const [response, setResponse] = useState<AskAssistantResponse | null>(null);
|
||||||
|
const [error, setError] = useState<string | null>(null);
|
||||||
|
|
||||||
|
const ask = useCallback(async (params: AskAssistantParams) => {
|
||||||
|
setIsLoading(true);
|
||||||
|
setError(null);
|
||||||
|
try {
|
||||||
|
const result = await api.askAssistant(params);
|
||||||
|
setResponse(result);
|
||||||
|
return result;
|
||||||
|
} catch (err) {
|
||||||
|
setError(err instanceof Error ? err.message : 'Failed to get answer');
|
||||||
|
throw err;
|
||||||
|
} finally {
|
||||||
|
setIsLoading(false);
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const reset = useCallback(() => {
|
||||||
|
setResponse(null);
|
||||||
|
setError(null);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return { ask, isLoading, response, error, reset };
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 8: Implement AssistantService.ask()
|
||||||
|
|
||||||
|
Complete the implementation in `application/services/assistant/assistant_service.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def ask(
|
||||||
|
self,
|
||||||
|
question: str,
|
||||||
|
user_id: UUID,
|
||||||
|
meeting_id: UUID | None = None,
|
||||||
|
thread_id: str | None = None,
|
||||||
|
allow_web: bool = False,
|
||||||
|
top_k: int = DEFAULT_TOP_K,
|
||||||
|
) -> AssistantOutputState:
|
||||||
|
"""Ask a question about meeting transcript(s)."""
|
||||||
|
effective_thread_id = thread_id or build_thread_id(
|
||||||
|
meeting_id, user_id, "meeting_qa"
|
||||||
|
)
|
||||||
|
|
||||||
|
async with self.uow_factory() as uow:
|
||||||
|
# Build and run graph
|
||||||
|
graph = build_meeting_qa_graph(
|
||||||
|
embedder=self._embedder,
|
||||||
|
segment_repo=uow.segments,
|
||||||
|
llm=self._llm,
|
||||||
|
verifier=self._verifier,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = {"configurable": {"thread_id": effective_thread_id}}
|
||||||
|
result = await graph.ainvoke(
|
||||||
|
{
|
||||||
|
"question": question,
|
||||||
|
"meeting_id": meeting_id,
|
||||||
|
"top_k": top_k,
|
||||||
|
},
|
||||||
|
config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record usage
|
||||||
|
self.usage_events.record_simple(
|
||||||
|
"assistant.ask",
|
||||||
|
meeting_id=str(meeting_id) if meeting_id else None,
|
||||||
|
question_length=len(question),
|
||||||
|
citation_count=len(result.get("citations", [])),
|
||||||
|
)
|
||||||
|
|
||||||
|
return AssistantOutputState(
|
||||||
|
answer=result["answer"],
|
||||||
|
citations=[asdict(c) for c in result["citations"]],
|
||||||
|
suggested_annotations=[],
|
||||||
|
thread_id=effective_thread_id,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Acceptance Criteria
|
||||||
|
|
||||||
|
- [ ] Q&A returns answers with valid segment citations
|
||||||
|
- [ ] Citations link to correct timestamps in transcript
|
||||||
|
- [ ] Feature hidden when `rag_enabled=false` in project rules
|
||||||
|
- [ ] Thread ID persists conversation context
|
||||||
|
- [ ] `make quality` passes
|
||||||
|
- [ ] `pytest tests/grpc/test_assistant.py` passes
|
||||||
|
- [ ] UI component displays answer and clickable citations
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Files Created/Modified
|
||||||
|
|
||||||
|
| Action | Path |
|
||||||
|
|--------|------|
|
||||||
|
| Created | `src/noteflow/infrastructure/ai/graphs/meeting_qa.py` |
|
||||||
|
| Created | `src/noteflow/infrastructure/ai/nodes/verification.py` |
|
||||||
|
| Modified | `src/noteflow/grpc/proto/noteflow.proto` |
|
||||||
|
| Created | `src/noteflow/grpc/_mixins/assistant.py` |
|
||||||
|
| Modified | `src/noteflow/grpc/service.py` (add mixin) |
|
||||||
|
| Created | `client/src-tauri/src/commands/assistant.rs` |
|
||||||
|
| Modified | `client/src/api/tauri-adapter.ts` |
|
||||||
|
| Created | `client/src/api/types/assistant.ts` |
|
||||||
|
| Created | `client/src/components/meeting/AskPanel.tsx` |
|
||||||
|
| Created | `client/src/hooks/use-assistant.ts` |
|
||||||
|
| Modified | `src/noteflow/application/services/assistant/assistant_service.py` |
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
# Sprint 27: Cross-Meeting RAG
|
||||||
|
|
||||||
|
> **Size**: M | **Owner**: Backend + Client | **Phase**: 5 - Platform Evolution
|
||||||
|
> **Effort**: ~1 sprint | **Prerequisites**: Sprint 26 (Meeting Q&A MVP)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Objective
|
||||||
|
|
||||||
|
Enable workspace-scoped Q&A and annotation suggestions across multiple meetings.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Tasks
|
||||||
|
|
||||||
|
### Task 1: Workspace-Scoped Semantic Search
|
||||||
|
|
||||||
|
Extend `SegmentRepository` with workspace-scoped search:
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def search_semantic_workspace(
|
||||||
|
self,
|
||||||
|
query_embedding: list[float],
|
||||||
|
workspace_id: UUID,
|
||||||
|
project_id: UUID | None = None,
|
||||||
|
limit: int = 20,
|
||||||
|
) -> list[tuple[Segment, float]]:
|
||||||
|
"""Search segments across all meetings in workspace/project."""
|
||||||
|
```
|
||||||
|
|
||||||
|
### Task 2: WorkspaceQA Graph
|
||||||
|
|
||||||
|
Create `infrastructure/ai/graphs/workspace_qa.py`:
|
||||||
|
|
||||||
|
- Similar to MeetingQA but omits `meeting_id` filter
|
||||||
|
- Groups results by meeting for citation display
|
||||||
|
- Returns cross-meeting citations
|
||||||
|
|
||||||
|
### Task 3: Annotation Suggester
|
||||||
|
|
||||||
|
Add annotation suggestion output to graph:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class SuggestedAnnotation:
|
||||||
|
text: str
|
||||||
|
annotation_type: AnnotationType
|
||||||
|
segment_ids: list[int]
|
||||||
|
confidence: float
|
||||||
|
```
|
||||||
|
|
||||||
|
### Task 4: Conversation History
|
||||||
|
|
||||||
|
Implement thread persistence with checkpointer:
|
||||||
|
|
||||||
|
- Store conversation turns in graph state
|
||||||
|
- Support follow-up questions
|
||||||
|
- Maintain context across requests
|
||||||
|
|
||||||
|
### Task 5: Apply Annotation Flow
|
||||||
|
|
||||||
|
UI flow to apply suggested annotations:
|
||||||
|
|
||||||
|
1. Display suggested annotations in AskPanel
|
||||||
|
2. User clicks "Apply" on suggestion
|
||||||
|
3. Call existing `AddAnnotation` RPC
|
||||||
|
4. Update UI to show applied status
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Acceptance Criteria
|
||||||
|
|
||||||
|
- [ ] Cross-meeting queries return results from multiple meetings
|
||||||
|
- [ ] Suggested annotations can be applied with one click
|
||||||
|
- [ ] Conversation history persists across requests
|
||||||
|
- [ ] Follow-up questions reference previous context
|
||||||
|
- [ ] `make quality` passes
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Files Created/Modified
|
||||||
|
|
||||||
|
| Action | Path |
|
||||||
|
|--------|------|
|
||||||
|
| Modified | `src/noteflow/infrastructure/persistence/repositories/segment_repo.py` |
|
||||||
|
| Created | `src/noteflow/infrastructure/ai/graphs/workspace_qa.py` |
|
||||||
|
| Modified | `src/noteflow/application/services/assistant/assistant_service.py` |
|
||||||
|
| Modified | `client/src/components/meeting/AskPanel.tsx` |
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
# Sprint 28: Advanced Capabilities
|
||||||
|
|
||||||
|
> **Size**: L | **Owner**: Backend + Client | **Phase**: 5 - Platform Evolution
|
||||||
|
> **Effort**: ~1 sprint | **Prerequisites**: Sprint 27 (Cross-Meeting RAG)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Objective
|
||||||
|
|
||||||
|
Production hardening with streaming responses, caching, guardrails, and optional web search.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Tasks
|
||||||
|
|
||||||
|
### Task 1: Streaming Responses
|
||||||
|
|
||||||
|
Implement `StreamAssistant` RPC for progressive answer generation:
|
||||||
|
|
||||||
|
```protobuf
|
||||||
|
rpc StreamAssistant(AskAssistantRequest) returns (stream AskAssistantResponse);
|
||||||
|
```
|
||||||
|
|
||||||
|
Use LangGraph's `astream` with `stream_mode="messages"`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
async for chunk in graph.astream(state, config, stream_mode="messages"):
|
||||||
|
yield pb.AskAssistantResponse(answer=chunk.content, partial=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Task 2: Embedding Cache
|
||||||
|
|
||||||
|
Create `infrastructure/ai/cache.py`:
|
||||||
|
|
||||||
|
- LRU cache for query embeddings
|
||||||
|
- Reduce redundant embedding calls
|
||||||
|
- Configurable TTL and max size
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class EmbeddingCache:
|
||||||
|
max_size: int = 1000
|
||||||
|
ttl_seconds: int = 3600
|
||||||
|
|
||||||
|
async def get_or_compute(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
embedder: EmbedderProtocol,
|
||||||
|
) -> list[float]:
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Task 3: Content Guardrails
|
||||||
|
|
||||||
|
Create `infrastructure/ai/guardrails.py`:
|
||||||
|
|
||||||
|
- Input validation (length, content)
|
||||||
|
- Output filtering (PII, harmful content)
|
||||||
|
- Configurable rules per workspace
|
||||||
|
|
||||||
|
```python
|
||||||
|
class GuardrailResult:
|
||||||
|
allowed: bool
|
||||||
|
reason: str | None
|
||||||
|
filtered_content: str | None
|
||||||
|
|
||||||
|
|
||||||
|
async def check_input(text: str, rules: GuardrailRules) -> GuardrailResult:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def filter_output(text: str, rules: GuardrailRules) -> GuardrailResult:
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Task 4: Web Search Node
|
||||||
|
|
||||||
|
Create `infrastructure/ai/nodes/web_search.py`:
|
||||||
|
|
||||||
|
- Optional node triggered by `allow_web=true`
|
||||||
|
- Integrate with web search API
|
||||||
|
- Merge web results with transcript evidence
|
||||||
|
|
||||||
|
### Task 5: AGENT_PROGRESS Tauri Event
|
||||||
|
|
||||||
|
Emit progress events for UI feedback:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Emit during graph execution
|
||||||
|
app.emit_all("AGENT_PROGRESS", AgentProgressPayload {
|
||||||
|
stage: "retrieving",
|
||||||
|
progress: 0.3,
|
||||||
|
message: "Searching transcript...",
|
||||||
|
})?;
|
||||||
|
```
|
||||||
|
|
||||||
|
### Task 6: Interrupts for Approval
|
||||||
|
|
||||||
|
Implement LangGraph interrupts for:
|
||||||
|
|
||||||
|
- Web search approval
|
||||||
|
- Annotation creation approval
|
||||||
|
- Sensitive action confirmation
|
||||||
|
|
||||||
|
### Task 7: Performance Optimization
|
||||||
|
|
||||||
|
- Batch embedding requests
|
||||||
|
- Parallel segment retrieval
|
||||||
|
- Connection pool tuning
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Acceptance Criteria
|
||||||
|
|
||||||
|
- [ ] Streaming shows progressive answer generation in UI
|
||||||
|
- [ ] Cache reduces embedding latency by >50% for repeated queries
|
||||||
|
- [ ] Guardrails block inappropriate content
|
||||||
|
- [ ] Web search gated by `allow_web` flag
|
||||||
|
- [ ] Progress events update UI during long operations
|
||||||
|
- [ ] `make quality` passes
|
||||||
|
- [ ] E2E tests pass (`client/e2e/assistant.spec.ts`)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Success Metrics
|
||||||
|
|
||||||
|
| Metric | Target |
|
||||||
|
| ------------------ | ------ |
|
||||||
|
| Q&A latency (p95) | < 3s |
|
||||||
|
| Citation accuracy | > 90% |
|
||||||
|
| Cache hit rate | > 60% |
|
||||||
|
| Hallucination rate | < 5% |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Files Created/Modified
|
||||||
|
|
||||||
|
| Action | Path |
|
||||||
|
| -------- | ---------------------------------------------------- |
|
||||||
|
| Modified | `src/noteflow/grpc/proto/noteflow.proto` |
|
||||||
|
| Created | `src/noteflow/grpc/_mixins/assistant_streaming.py` |
|
||||||
|
| Created | `src/noteflow/infrastructure/ai/cache.py` |
|
||||||
|
| Created | `src/noteflow/infrastructure/ai/guardrails.py` |
|
||||||
|
| Created | `src/noteflow/infrastructure/ai/nodes/web_search.py` |
|
||||||
|
| Modified | `client/src-tauri/src/commands/assistant.rs` |
|
||||||
|
| Modified | `client/src/components/meeting/AskPanel.tsx` |
|
||||||
|
| Created | `client/e2e/assistant.spec.ts` |
|
||||||
38
src/noteflow/domain/ai/__init__.py
Normal file
38
src/noteflow/domain/ai/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""AI domain types for LangGraph workflows.
|
||||||
|
|
||||||
|
State schemas, citations, interrupts, and protocols for AI assistant functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from noteflow.domain.ai.citations import SegmentCitation
|
||||||
|
from noteflow.domain.ai.interrupts import (
|
||||||
|
InterruptAction,
|
||||||
|
InterruptConfig,
|
||||||
|
InterruptRequest,
|
||||||
|
InterruptResponse,
|
||||||
|
InterruptType,
|
||||||
|
create_annotation_interrupt,
|
||||||
|
create_sensitive_action_interrupt,
|
||||||
|
create_web_search_interrupt,
|
||||||
|
)
|
||||||
|
from noteflow.domain.ai.ports import AssistantPort
|
||||||
|
from noteflow.domain.ai.state import (
|
||||||
|
AssistantInputState,
|
||||||
|
AssistantInternalState,
|
||||||
|
AssistantOutputState,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AssistantInputState",
|
||||||
|
"AssistantInternalState",
|
||||||
|
"AssistantOutputState",
|
||||||
|
"AssistantPort",
|
||||||
|
"InterruptAction",
|
||||||
|
"InterruptConfig",
|
||||||
|
"InterruptRequest",
|
||||||
|
"InterruptResponse",
|
||||||
|
"InterruptType",
|
||||||
|
"SegmentCitation",
|
||||||
|
"create_annotation_interrupt",
|
||||||
|
"create_sensitive_action_interrupt",
|
||||||
|
"create_web_search_interrupt",
|
||||||
|
]
|
||||||
43
src/noteflow/domain/ai/citations.py
Normal file
43
src/noteflow/domain/ai/citations.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""Citation value objects for AI-generated responses."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SegmentCitation:
|
||||||
|
"""Reference to a transcript segment used as evidence.
|
||||||
|
|
||||||
|
Links AI-generated claims to source transcript segments for verification.
|
||||||
|
"""
|
||||||
|
|
||||||
|
meeting_id: UUID
|
||||||
|
segment_id: int
|
||||||
|
start_time: float
|
||||||
|
end_time: float
|
||||||
|
text: str
|
||||||
|
score: float = 0.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def duration(self) -> float:
|
||||||
|
"""Duration of the cited segment in seconds."""
|
||||||
|
return self.end_time - self.start_time
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if self.segment_id < 0:
|
||||||
|
msg = "segment_id must be non-negative"
|
||||||
|
raise ValueError(msg)
|
||||||
|
if self.start_time < 0:
|
||||||
|
msg = "start_time must be non-negative"
|
||||||
|
raise ValueError(msg)
|
||||||
|
if self.end_time < self.start_time:
|
||||||
|
msg = "end_time must be >= start_time"
|
||||||
|
raise ValueError(msg)
|
||||||
|
if self.score < 0 or self.score > 1:
|
||||||
|
msg = "score must be between 0 and 1"
|
||||||
|
raise ValueError(msg)
|
||||||
202
src/noteflow/domain/ai/interrupts.py
Normal file
202
src/noteflow/domain/ai/interrupts.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
"""Domain types for LangGraph human-in-the-loop interrupts.
|
||||||
|
|
||||||
|
Defines interrupt request/response types for approval workflows:
|
||||||
|
- Web search approval
|
||||||
|
- Annotation creation approval
|
||||||
|
- Sensitive action confirmation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptType(StrEnum):
|
||||||
|
"""Types of human-in-the-loop interrupts."""
|
||||||
|
|
||||||
|
WEB_SEARCH_APPROVAL = "web_search_approval"
|
||||||
|
ANNOTATION_APPROVAL = "annotation_approval"
|
||||||
|
SENSITIVE_ACTION = "sensitive_action"
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptAction(StrEnum):
|
||||||
|
"""Possible user actions for an interrupt."""
|
||||||
|
|
||||||
|
APPROVE = "approve"
|
||||||
|
REJECT = "reject"
|
||||||
|
MODIFY = "modify"
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_WEB_SEARCH_OPTIONS: Final[tuple[str, ...]] = ("approve", "reject")
|
||||||
|
DEFAULT_ANNOTATION_OPTIONS: Final[tuple[str, ...]] = ("approve", "reject", "modify")
|
||||||
|
DEFAULT_SENSITIVE_OPTIONS: Final[tuple[str, ...]] = ("approve", "reject")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class InterruptConfig:
|
||||||
|
"""Configuration for interrupt behavior."""
|
||||||
|
|
||||||
|
allow_ignore: bool = False
|
||||||
|
allow_modify: bool = False
|
||||||
|
timeout_seconds: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class InterruptRequest:
|
||||||
|
"""A request for human approval during graph execution.
|
||||||
|
|
||||||
|
Sent to the client when the graph hits an interrupt point.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
interrupt_type: Category of interrupt (web_search, annotation, etc.)
|
||||||
|
message: Human-readable description of what needs approval.
|
||||||
|
context: Additional context data for the decision (query, entities, etc.)
|
||||||
|
options: Available response options.
|
||||||
|
config: Interrupt configuration (allow_ignore, timeout, etc.)
|
||||||
|
request_id: Unique identifier for this interrupt request.
|
||||||
|
"""
|
||||||
|
|
||||||
|
interrupt_type: InterruptType
|
||||||
|
message: str
|
||||||
|
context: dict[str, object] = field(default_factory=dict)
|
||||||
|
options: tuple[str, ...] = field(default_factory=lambda: ("approve", "reject"))
|
||||||
|
config: InterruptConfig = field(default_factory=InterruptConfig)
|
||||||
|
request_id: str = ""
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, object]:
|
||||||
|
"""Convert to dictionary for serialization."""
|
||||||
|
return {
|
||||||
|
"interrupt_type": self.interrupt_type,
|
||||||
|
"message": self.message,
|
||||||
|
"context": self.context,
|
||||||
|
"options": list(self.options),
|
||||||
|
"config": {
|
||||||
|
"allow_ignore": self.config.allow_ignore,
|
||||||
|
"allow_modify": self.config.allow_modify,
|
||||||
|
"timeout_seconds": self.config.timeout_seconds,
|
||||||
|
},
|
||||||
|
"request_id": self.request_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class InterruptResponse:
|
||||||
|
"""User's response to an interrupt request.
|
||||||
|
|
||||||
|
Returned from the client to resume graph execution.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
action: The action taken (approve, reject, modify).
|
||||||
|
request_id: ID of the interrupt request being responded to.
|
||||||
|
modified_value: If action is MODIFY, the modified value.
|
||||||
|
user_message: Optional message from the user.
|
||||||
|
"""
|
||||||
|
|
||||||
|
action: InterruptAction
|
||||||
|
request_id: str = ""
|
||||||
|
modified_value: dict[str, object] | None = None
|
||||||
|
user_message: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_approved(self) -> bool:
|
||||||
|
"""Check if the action was approved."""
|
||||||
|
return self.action == InterruptAction.APPROVE
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_rejected(self) -> bool:
|
||||||
|
"""Check if the action was rejected."""
|
||||||
|
return self.action == InterruptAction.REJECT
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_modified(self) -> bool:
|
||||||
|
"""Check if the action was modified."""
|
||||||
|
return self.action == InterruptAction.MODIFY
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, object]:
|
||||||
|
"""Convert to dictionary for serialization."""
|
||||||
|
result: dict[str, object] = {
|
||||||
|
"action": self.action,
|
||||||
|
"request_id": self.request_id,
|
||||||
|
}
|
||||||
|
if self.modified_value is not None:
|
||||||
|
result["modified_value"] = self.modified_value
|
||||||
|
if self.user_message is not None:
|
||||||
|
result["user_message"] = self.user_message
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def create_web_search_interrupt(
|
||||||
|
query: str,
|
||||||
|
request_id: str,
|
||||||
|
*,
|
||||||
|
allow_modify: bool = False,
|
||||||
|
) -> InterruptRequest:
|
||||||
|
"""Create an interrupt request for web search approval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query to be executed.
|
||||||
|
request_id: Unique identifier for this request.
|
||||||
|
allow_modify: Whether the user can modify the query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
InterruptRequest configured for web search approval.
|
||||||
|
"""
|
||||||
|
return InterruptRequest(
|
||||||
|
interrupt_type=InterruptType.WEB_SEARCH_APPROVAL,
|
||||||
|
message=f"Allow web search for additional context? Query: {query[:100]}",
|
||||||
|
context={"query": query},
|
||||||
|
options=("approve", "reject", "modify") if allow_modify else DEFAULT_WEB_SEARCH_OPTIONS,
|
||||||
|
config=InterruptConfig(allow_modify=allow_modify),
|
||||||
|
request_id=request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_annotation_interrupt(
|
||||||
|
annotations: list[dict[str, object]],
|
||||||
|
request_id: str,
|
||||||
|
) -> InterruptRequest:
|
||||||
|
"""Create an interrupt request for annotation approval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
annotations: List of suggested annotations to approve.
|
||||||
|
request_id: Unique identifier for this request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
InterruptRequest configured for annotation approval.
|
||||||
|
"""
|
||||||
|
count = len(annotations)
|
||||||
|
return InterruptRequest(
|
||||||
|
interrupt_type=InterruptType.ANNOTATION_APPROVAL,
|
||||||
|
message=f"Apply {count} suggested annotation(s)?",
|
||||||
|
context={"annotations": annotations, "count": count},
|
||||||
|
options=DEFAULT_ANNOTATION_OPTIONS,
|
||||||
|
config=InterruptConfig(allow_modify=True, allow_ignore=True),
|
||||||
|
request_id=request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_sensitive_action_interrupt(
|
||||||
|
action_name: str,
|
||||||
|
action_description: str,
|
||||||
|
request_id: str,
|
||||||
|
) -> InterruptRequest:
|
||||||
|
"""Create an interrupt request for sensitive action confirmation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_name: Name of the sensitive action.
|
||||||
|
action_description: Description of what the action does.
|
||||||
|
request_id: Unique identifier for this request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
InterruptRequest configured for sensitive action confirmation.
|
||||||
|
"""
|
||||||
|
return InterruptRequest(
|
||||||
|
interrupt_type=InterruptType.SENSITIVE_ACTION,
|
||||||
|
message=f"Confirm action: {action_name}",
|
||||||
|
context={"action_name": action_name, "description": action_description},
|
||||||
|
options=DEFAULT_SENSITIVE_OPTIONS,
|
||||||
|
config=InterruptConfig(allow_ignore=False),
|
||||||
|
request_id=request_id,
|
||||||
|
)
|
||||||
26
src/noteflow/domain/ai/ports.py
Normal file
26
src/noteflow/domain/ai/ports.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""Protocol definitions for AI assistant operations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from noteflow.domain.ai.state import AssistantOutputState
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantPort(Protocol):
|
||||||
|
"""Protocol for AI assistant operations."""
|
||||||
|
|
||||||
|
async def ask(
|
||||||
|
self,
|
||||||
|
question: str,
|
||||||
|
user_id: UUID,
|
||||||
|
meeting_id: UUID | None = None,
|
||||||
|
thread_id: str | None = None,
|
||||||
|
allow_web: bool = False,
|
||||||
|
top_k: int = 8,
|
||||||
|
) -> AssistantOutputState:
|
||||||
|
"""Ask a question about meeting transcript(s)."""
|
||||||
|
...
|
||||||
39
src/noteflow/domain/ai/state.py
Normal file
39
src/noteflow/domain/ai/state.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""State schemas for LangGraph AI workflows.
|
||||||
|
|
||||||
|
Separates Input/Output (public API) from Internal (can evolve freely).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import operator
|
||||||
|
from typing import Annotated, TypedDict
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantInputState(TypedDict):
|
||||||
|
"""Public API input - what clients send."""
|
||||||
|
|
||||||
|
question: str
|
||||||
|
meeting_id: UUID | None
|
||||||
|
thread_id: str | None
|
||||||
|
allow_web: bool
|
||||||
|
top_k: int
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantOutputState(TypedDict):
|
||||||
|
"""Public API output - what clients receive."""
|
||||||
|
|
||||||
|
answer: str
|
||||||
|
citations: list[dict[str, object]]
|
||||||
|
suggested_annotations: list[dict[str, object]]
|
||||||
|
thread_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantInternalState(AssistantInputState):
|
||||||
|
"""Internal graph state - can evolve without breaking API."""
|
||||||
|
|
||||||
|
retrieved_segment_ids: Annotated[list[int], operator.add]
|
||||||
|
retrieved_segments: list[dict[str, object]]
|
||||||
|
draft_answer: str
|
||||||
|
verification_passed: bool
|
||||||
|
loop_count: int
|
||||||
45
src/noteflow/infrastructure/ai/__init__.py
Normal file
45
src/noteflow/infrastructure/ai/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
"""AI infrastructure components for LangGraph workflows."""
|
||||||
|
|
||||||
|
from noteflow.infrastructure.ai.cache import (
|
||||||
|
CachedEmbedder,
|
||||||
|
EmbeddingCache,
|
||||||
|
EmbeddingCacheStats,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.ai.checkpointer import create_checkpointer
|
||||||
|
from noteflow.infrastructure.ai.guardrails import (
|
||||||
|
GuardrailResult,
|
||||||
|
GuardrailRules,
|
||||||
|
GuardrailViolation,
|
||||||
|
check_input,
|
||||||
|
create_default_rules,
|
||||||
|
create_strict_rules,
|
||||||
|
filter_output,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.ai.interrupts import (
|
||||||
|
InterruptHandler,
|
||||||
|
check_annotation_approval,
|
||||||
|
check_web_search_approval,
|
||||||
|
create_resume_command,
|
||||||
|
request_annotation_approval,
|
||||||
|
request_web_search_approval,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CachedEmbedder",
|
||||||
|
"EmbeddingCache",
|
||||||
|
"EmbeddingCacheStats",
|
||||||
|
"GuardrailResult",
|
||||||
|
"GuardrailRules",
|
||||||
|
"GuardrailViolation",
|
||||||
|
"InterruptHandler",
|
||||||
|
"check_annotation_approval",
|
||||||
|
"check_input",
|
||||||
|
"check_web_search_approval",
|
||||||
|
"create_checkpointer",
|
||||||
|
"create_default_rules",
|
||||||
|
"create_resume_command",
|
||||||
|
"create_strict_rules",
|
||||||
|
"filter_output",
|
||||||
|
"request_annotation_approval",
|
||||||
|
"request_web_search_approval",
|
||||||
|
]
|
||||||
212
src/noteflow/infrastructure/ai/cache.py
Normal file
212
src/noteflow/infrastructure/ai/cache.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""Embedding cache with LRU eviction and TTL expiration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Final
|
||||||
|
|
||||||
|
from noteflow.infrastructure.logging import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from noteflow.infrastructure.ai.tools.retrieval import EmbedderProtocol
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_MAX_SIZE: Final[int] = 1000
|
||||||
|
DEFAULT_TTL_SECONDS: Final[int] = 3600
|
||||||
|
HASH_ALGORITHM: Final[str] = "sha256"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CacheEntry:
|
||||||
|
"""Cached embedding with creation timestamp."""
|
||||||
|
|
||||||
|
embedding: tuple[float, ...]
|
||||||
|
created_at: float
|
||||||
|
|
||||||
|
def is_expired(self, ttl_seconds: float, current_time: float) -> bool:
|
||||||
|
"""Check if entry has expired based on TTL."""
|
||||||
|
return (current_time - self.created_at) > ttl_seconds
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmbeddingCacheStats:
|
||||||
|
"""Statistics for cache performance monitoring."""
|
||||||
|
|
||||||
|
hits: int = 0
|
||||||
|
misses: int = 0
|
||||||
|
evictions: int = 0
|
||||||
|
expirations: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hit_rate(self) -> float:
|
||||||
|
"""Calculate cache hit rate."""
|
||||||
|
total = self.hits + self.misses
|
||||||
|
return self.hits / total if total > 0 else 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmbeddingCache:
|
||||||
|
"""LRU cache for text embeddings with TTL expiration and deduplication."""
|
||||||
|
|
||||||
|
max_size: int = DEFAULT_MAX_SIZE
|
||||||
|
ttl_seconds: int = DEFAULT_TTL_SECONDS
|
||||||
|
_cache: OrderedDict[str, CacheEntry] = field(default_factory=OrderedDict)
|
||||||
|
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||||
|
_stats: EmbeddingCacheStats = field(default_factory=EmbeddingCacheStats)
|
||||||
|
_in_flight: dict[str, asyncio.Future[list[float]]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def _compute_key(self, text: str) -> str:
|
||||||
|
"""Compute cache key from text using hash."""
|
||||||
|
return hashlib.new(HASH_ALGORITHM, text.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
async def get_or_compute(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
embedder: EmbedderProtocol,
|
||||||
|
) -> list[float]:
|
||||||
|
key = self._compute_key(text)
|
||||||
|
current_time = time.monotonic()
|
||||||
|
|
||||||
|
existing_future: asyncio.Future[list[float]] | None = None
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
if key in self._cache:
|
||||||
|
entry = self._cache[key]
|
||||||
|
if not entry.is_expired(self.ttl_seconds, current_time):
|
||||||
|
self._cache.move_to_end(key)
|
||||||
|
self._stats.hits += 1
|
||||||
|
logger.debug("cache_hit", key=key[:16])
|
||||||
|
return list(entry.embedding)
|
||||||
|
del self._cache[key]
|
||||||
|
self._stats.expirations += 1
|
||||||
|
logger.debug("cache_expired", key=key[:16])
|
||||||
|
|
||||||
|
if key in self._in_flight:
|
||||||
|
logger.debug("cache_in_flight_join", key=key[:16])
|
||||||
|
existing_future = self._in_flight[key]
|
||||||
|
|
||||||
|
if existing_future is not None:
|
||||||
|
return list(await existing_future)
|
||||||
|
|
||||||
|
new_future: asyncio.Future[list[float]] = asyncio.get_running_loop().create_future()
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
if key in self._in_flight:
|
||||||
|
existing_future = self._in_flight[key]
|
||||||
|
else:
|
||||||
|
self._stats.misses += 1
|
||||||
|
self._in_flight[key] = new_future
|
||||||
|
|
||||||
|
if existing_future is not None:
|
||||||
|
return list(await existing_future)
|
||||||
|
|
||||||
|
try:
|
||||||
|
embedding = await embedder.embed(text)
|
||||||
|
except Exception:
|
||||||
|
async with self._lock:
|
||||||
|
_ = self._in_flight.pop(key, None)
|
||||||
|
new_future.set_exception(asyncio.CancelledError())
|
||||||
|
raise
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
_ = self._in_flight.pop(key, None)
|
||||||
|
|
||||||
|
while len(self._cache) >= self.max_size:
|
||||||
|
evicted_key, _ = self._cache.popitem(last=False)
|
||||||
|
self._stats.evictions += 1
|
||||||
|
logger.debug("cache_eviction", evicted_key=evicted_key[:16])
|
||||||
|
|
||||||
|
self._cache[key] = CacheEntry(
|
||||||
|
embedding=tuple(embedding),
|
||||||
|
created_at=current_time,
|
||||||
|
)
|
||||||
|
logger.debug("cache_store", key=key[:16])
|
||||||
|
|
||||||
|
new_future.set_result(embedding)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
async def get(self, text: str) -> list[float] | None:
|
||||||
|
"""Get embedding from cache without computing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to look up.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embedding if cached and not expired, None otherwise.
|
||||||
|
"""
|
||||||
|
key = self._compute_key(text)
|
||||||
|
current_time = time.monotonic()
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
if key in self._cache:
|
||||||
|
entry = self._cache[key]
|
||||||
|
if not entry.is_expired(self.ttl_seconds, current_time):
|
||||||
|
self._cache.move_to_end(key)
|
||||||
|
return list(entry.embedding)
|
||||||
|
else:
|
||||||
|
del self._cache[key]
|
||||||
|
self._stats.expirations += 1
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def clear(self) -> int:
|
||||||
|
"""Clear all cached entries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries cleared.
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
count = len(self._cache)
|
||||||
|
self._cache.clear()
|
||||||
|
logger.info("cache_cleared", entries_cleared=count)
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def size(self) -> int:
|
||||||
|
"""Get current number of cached entries."""
|
||||||
|
async with self._lock:
|
||||||
|
return len(self._cache)
|
||||||
|
|
||||||
|
def get_stats(self) -> EmbeddingCacheStats:
|
||||||
|
"""Get cache statistics (not async - reads are atomic)."""
|
||||||
|
return EmbeddingCacheStats(
|
||||||
|
hits=self._stats.hits,
|
||||||
|
misses=self._stats.misses,
|
||||||
|
evictions=self._stats.evictions,
|
||||||
|
expirations=self._stats.expirations,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CachedEmbedder:
|
||||||
|
"""Wrapper that adds caching to any EmbedderProtocol implementation.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
base_embedder = MyEmbedder()
|
||||||
|
cached = CachedEmbedder(base_embedder, max_size=500, ttl_seconds=1800)
|
||||||
|
embedding = await cached.embed("hello world")
|
||||||
|
"""
|
||||||
|
|
||||||
|
_embedder: EmbedderProtocol
|
||||||
|
_cache: EmbeddingCache
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedder: EmbedderProtocol,
|
||||||
|
max_size: int = DEFAULT_MAX_SIZE,
|
||||||
|
ttl_seconds: int = DEFAULT_TTL_SECONDS,
|
||||||
|
) -> None:
|
||||||
|
self._embedder = embedder
|
||||||
|
self._cache = EmbeddingCache(max_size=max_size, ttl_seconds=ttl_seconds)
|
||||||
|
|
||||||
|
async def embed(self, text: str) -> list[float]:
|
||||||
|
"""Embed text with caching."""
|
||||||
|
return await self._cache.get_or_compute(text, self._embedder)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache(self) -> EmbeddingCache:
|
||||||
|
"""Access underlying cache for stats/management."""
|
||||||
|
return self._cache
|
||||||
56
src/noteflow/infrastructure/ai/checkpointer.py
Normal file
56
src/noteflow/infrastructure/ai/checkpointer.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""PostgreSQL checkpointer factory for LangGraph state persistence."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Final
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
CHECKPOINTER_POOL_SIZE: Final[int] = 5
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CheckpointerResult:
|
||||||
|
"""Wraps checkpointer with pool lifecycle management to prevent connection leaks."""
|
||||||
|
|
||||||
|
checkpointer: AsyncPostgresSaver
|
||||||
|
_pool: AsyncConnectionPool
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
# psycopg_pool.AsyncConnectionPool.close() exists at runtime but type stubs
|
||||||
|
# are incomplete. Use getattr to satisfy the type checker.
|
||||||
|
close_fn = getattr(self._pool, "close", None)
|
||||||
|
if close_fn is not None:
|
||||||
|
await close_fn()
|
||||||
|
|
||||||
|
async def __aenter__(self) -> AsyncPostgresSaver:
|
||||||
|
return self.checkpointer
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: object,
|
||||||
|
) -> None:
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def create_checkpointer(
|
||||||
|
database_url: str,
|
||||||
|
pool_size: int = CHECKPOINTER_POOL_SIZE,
|
||||||
|
) -> CheckpointerResult:
|
||||||
|
"""Create async Postgres checkpointer for LangGraph state persistence."""
|
||||||
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
|
pool = AsyncConnectionPool(
|
||||||
|
conninfo=database_url,
|
||||||
|
max_size=pool_size,
|
||||||
|
kwargs={"autocommit": True},
|
||||||
|
)
|
||||||
|
checkpointer = AsyncPostgresSaver(pool)
|
||||||
|
await checkpointer.setup()
|
||||||
|
return CheckpointerResult(checkpointer=checkpointer, _pool=pool)
|
||||||
51
src/noteflow/infrastructure/ai/graphs/__init__.py
Normal file
51
src/noteflow/infrastructure/ai/graphs/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""LangGraph workflow definitions."""
|
||||||
|
|
||||||
|
from noteflow.infrastructure.ai.graphs.meeting_qa import (
|
||||||
|
MEETING_QA_GRAPH_NAME,
|
||||||
|
MEETING_QA_GRAPH_VERSION,
|
||||||
|
MeetingQAConfig,
|
||||||
|
MeetingQAInputState,
|
||||||
|
MeetingQAInternalState,
|
||||||
|
MeetingQAOutputState,
|
||||||
|
build_meeting_qa_graph,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.ai.graphs.summarization import (
|
||||||
|
SUMMARIZATION_GRAPH_NAME,
|
||||||
|
SUMMARIZATION_GRAPH_VERSION,
|
||||||
|
SummarizationInputState,
|
||||||
|
SummarizationOutputState,
|
||||||
|
SummarizationState,
|
||||||
|
build_summarization_graph,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.ai.graphs.workspace_qa import (
|
||||||
|
WORKSPACE_QA_GRAPH_NAME,
|
||||||
|
WORKSPACE_QA_GRAPH_VERSION,
|
||||||
|
WorkspaceQAConfig,
|
||||||
|
WorkspaceQAInputState,
|
||||||
|
WorkspaceQAInternalState,
|
||||||
|
WorkspaceQAOutputState,
|
||||||
|
build_workspace_qa_graph,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MEETING_QA_GRAPH_NAME",
|
||||||
|
"MEETING_QA_GRAPH_VERSION",
|
||||||
|
"MeetingQAConfig",
|
||||||
|
"MeetingQAInputState",
|
||||||
|
"MeetingQAInternalState",
|
||||||
|
"MeetingQAOutputState",
|
||||||
|
"SUMMARIZATION_GRAPH_NAME",
|
||||||
|
"SUMMARIZATION_GRAPH_VERSION",
|
||||||
|
"SummarizationInputState",
|
||||||
|
"SummarizationOutputState",
|
||||||
|
"SummarizationState",
|
||||||
|
"WORKSPACE_QA_GRAPH_NAME",
|
||||||
|
"WORKSPACE_QA_GRAPH_VERSION",
|
||||||
|
"WorkspaceQAConfig",
|
||||||
|
"WorkspaceQAInputState",
|
||||||
|
"WorkspaceQAInternalState",
|
||||||
|
"WorkspaceQAOutputState",
|
||||||
|
"build_meeting_qa_graph",
|
||||||
|
"build_summarization_graph",
|
||||||
|
"build_workspace_qa_graph",
|
||||||
|
]
|
||||||
193
src/noteflow/infrastructure/ai/graphs/meeting_qa.py
Normal file
193
src/noteflow/infrastructure/ai/graphs/meeting_qa.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""Meeting Q&A graph for single-meeting question answering with citations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Final, TypedDict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langgraph.graph import CompiledStateGraph
|
||||||
|
|
||||||
|
from noteflow.domain.ai.citations import SegmentCitation
|
||||||
|
from noteflow.domain.value_objects import MeetingId
|
||||||
|
from noteflow.infrastructure.ai.nodes.annotation_suggester import SuggestedAnnotation
|
||||||
|
from noteflow.infrastructure.ai.nodes.web_search import WebSearchProvider
|
||||||
|
from noteflow.infrastructure.ai.tools.retrieval import (
|
||||||
|
EmbedderProtocol,
|
||||||
|
RetrievalResult,
|
||||||
|
SegmentSearchProtocol,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.ai.tools.synthesis import LLMProtocol
|
||||||
|
|
||||||
|
MEETING_QA_GRAPH_NAME: Final[str] = "meeting_qa"
|
||||||
|
MEETING_QA_GRAPH_VERSION: Final[int] = 2
|
||||||
|
NO_INFORMATION_ANSWER: Final[str] = "I couldn't find relevant information in this meeting."
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MeetingQAConfig:
|
||||||
|
enable_web_search: bool = False
|
||||||
|
require_web_approval: bool = True
|
||||||
|
require_annotation_approval: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class MeetingQAInputState(TypedDict):
|
||||||
|
question: str
|
||||||
|
meeting_id: MeetingId
|
||||||
|
top_k: int
|
||||||
|
|
||||||
|
|
||||||
|
class MeetingQAOutputState(TypedDict):
|
||||||
|
answer: str
|
||||||
|
citations: list[SegmentCitation]
|
||||||
|
suggested_annotations: list[SuggestedAnnotation]
|
||||||
|
|
||||||
|
|
||||||
|
class MeetingQAInternalState(MeetingQAInputState, MeetingQAOutputState):
|
||||||
|
retrieved_segments: list[RetrievalResult]
|
||||||
|
verification_passed: bool
|
||||||
|
web_search_approved: bool
|
||||||
|
web_context: str
|
||||||
|
annotations_approved: bool
|
||||||
|
|
||||||
|
|
||||||
|
def build_meeting_qa_graph(
|
||||||
|
embedder: EmbedderProtocol,
|
||||||
|
segment_repo: SegmentSearchProtocol,
|
||||||
|
llm: LLMProtocol,
|
||||||
|
*,
|
||||||
|
web_search_provider: WebSearchProvider | None = None,
|
||||||
|
config: MeetingQAConfig | None = None,
|
||||||
|
checkpointer: object | None = None,
|
||||||
|
) -> CompiledStateGraph[MeetingQAInternalState]:
|
||||||
|
"""Build a Q&A graph for single-meeting questions with segment citations.
|
||||||
|
|
||||||
|
Graph flow (with web search): retrieve -> verify -> [web_search_approval] -> [web_search] -> synthesize
|
||||||
|
Graph flow (without): retrieve -> verify -> synthesize
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedder: Protocol for generating text embeddings.
|
||||||
|
segment_repo: Protocol for semantic segment search.
|
||||||
|
llm: Protocol for LLM text completion.
|
||||||
|
web_search_provider: Optional web search provider for augmentation.
|
||||||
|
config: Graph configuration for features/interrupts.
|
||||||
|
checkpointer: Optional checkpointer for interrupt support.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compiled graph that accepts question/meeting_id and returns answer/citations.
|
||||||
|
"""
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
|
||||||
|
from noteflow.domain.ai.citations import SegmentCitation
|
||||||
|
from noteflow.infrastructure.ai.interrupts import check_web_search_approval
|
||||||
|
from noteflow.infrastructure.ai.nodes.annotation_suggester import (
|
||||||
|
extract_annotations_from_answer,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.ai.nodes.web_search import (
|
||||||
|
WebSearchConfig,
|
||||||
|
derive_search_query,
|
||||||
|
execute_web_search,
|
||||||
|
format_results_for_context,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.ai.tools.retrieval import retrieve_segments
|
||||||
|
from noteflow.infrastructure.ai.tools.synthesis import synthesize_answer
|
||||||
|
|
||||||
|
effective_config = config or MeetingQAConfig()
|
||||||
|
|
||||||
|
async def retrieve_node(state: MeetingQAInternalState) -> dict[str, object]:
|
||||||
|
results = await retrieve_segments(
|
||||||
|
query=state["question"],
|
||||||
|
embedder=embedder,
|
||||||
|
segment_repo=segment_repo,
|
||||||
|
meeting_id=state["meeting_id"],
|
||||||
|
top_k=state["top_k"],
|
||||||
|
)
|
||||||
|
return {"retrieved_segments": results}
|
||||||
|
|
||||||
|
async def verify_node(state: MeetingQAInternalState) -> dict[str, object]:
|
||||||
|
has_segments = len(state["retrieved_segments"]) > 0
|
||||||
|
return {"verification_passed": has_segments}
|
||||||
|
|
||||||
|
def web_search_approval_node(state: MeetingQAInternalState) -> dict[str, object]:
|
||||||
|
if not effective_config.enable_web_search or web_search_provider is None:
|
||||||
|
return {"web_search_approved": False}
|
||||||
|
|
||||||
|
if not effective_config.require_web_approval:
|
||||||
|
return {"web_search_approved": True}
|
||||||
|
|
||||||
|
query = derive_search_query(state["question"])
|
||||||
|
approved = check_web_search_approval(query, require_approval=True)
|
||||||
|
return {"web_search_approved": approved}
|
||||||
|
|
||||||
|
async def web_search_node(state: MeetingQAInternalState) -> dict[str, object]:
|
||||||
|
if not state.get("web_search_approved", False) or web_search_provider is None:
|
||||||
|
return {"web_context": ""}
|
||||||
|
|
||||||
|
query = derive_search_query(state["question"])
|
||||||
|
search_config = WebSearchConfig(enabled=True, require_approval=False)
|
||||||
|
response = await execute_web_search(query, web_search_provider, search_config)
|
||||||
|
context = format_results_for_context(response.results)
|
||||||
|
return {"web_context": context}
|
||||||
|
|
||||||
|
async def synthesize_node(state: MeetingQAInternalState) -> dict[str, object]:
|
||||||
|
if not state["verification_passed"]:
|
||||||
|
return {
|
||||||
|
"answer": NO_INFORMATION_ANSWER,
|
||||||
|
"citations": [],
|
||||||
|
"suggested_annotations": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await synthesize_answer(
|
||||||
|
question=state["question"],
|
||||||
|
segments=state["retrieved_segments"],
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
citations = [
|
||||||
|
SegmentCitation(
|
||||||
|
meeting_id=state["meeting_id"],
|
||||||
|
segment_id=seg.segment_id,
|
||||||
|
start_time=seg.start_time,
|
||||||
|
end_time=seg.end_time,
|
||||||
|
text=seg.text,
|
||||||
|
score=seg.score,
|
||||||
|
)
|
||||||
|
for seg in state["retrieved_segments"]
|
||||||
|
if seg.segment_id in result.cited_segment_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
suggested_annotations = extract_annotations_from_answer(
|
||||||
|
answer=result.answer,
|
||||||
|
cited_segment_ids=tuple(result.cited_segment_ids),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"answer": result.answer,
|
||||||
|
"citations": citations,
|
||||||
|
"suggested_annotations": suggested_annotations,
|
||||||
|
}
|
||||||
|
|
||||||
|
builder: StateGraph[MeetingQAInternalState] = StateGraph(MeetingQAInternalState)
|
||||||
|
builder.add_node("retrieve", retrieve_node)
|
||||||
|
builder.add_node("verify", verify_node)
|
||||||
|
builder.add_node("synthesize", synthesize_node)
|
||||||
|
|
||||||
|
if effective_config.enable_web_search and web_search_provider is not None:
|
||||||
|
builder.add_node("web_search_approval", web_search_approval_node)
|
||||||
|
builder.add_node("web_search", web_search_node)
|
||||||
|
|
||||||
|
builder.add_edge(START, "retrieve")
|
||||||
|
builder.add_edge("retrieve", "verify")
|
||||||
|
builder.add_edge("verify", "web_search_approval")
|
||||||
|
builder.add_edge("web_search_approval", "web_search")
|
||||||
|
builder.add_edge("web_search", "synthesize")
|
||||||
|
builder.add_edge("synthesize", END)
|
||||||
|
else:
|
||||||
|
builder.add_edge(START, "retrieve")
|
||||||
|
builder.add_edge("retrieve", "verify")
|
||||||
|
builder.add_edge("verify", "synthesize")
|
||||||
|
builder.add_edge("synthesize", END)
|
||||||
|
|
||||||
|
compile_method = getattr(builder, "compile")
|
||||||
|
compiled: CompiledStateGraph[MeetingQAInternalState] = compile_method(checkpointer=checkpointer)
|
||||||
|
return compiled
|
||||||
103
src/noteflow/infrastructure/ai/graphs/summarization.py
Normal file
103
src/noteflow/infrastructure/ai/graphs/summarization.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"""LangGraph wrapper for existing SummarizationService.
|
||||||
|
|
||||||
|
Demonstrates the LangGraph integration pattern by wrapping the existing
|
||||||
|
summarization infrastructure in a StateGraph.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Final, TypedDict, cast
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from langgraph.graph import CompiledStateGraph
|
||||||
|
|
||||||
|
from noteflow.application.services.summarization import SummarizationService
|
||||||
|
from noteflow.domain.entities import Segment
|
||||||
|
from noteflow.domain.value_objects import MeetingId
|
||||||
|
|
||||||
|
SUMMARIZATION_GRAPH_NAME: Final[str] = "summarization"
|
||||||
|
SUMMARIZATION_GRAPH_VERSION: Final[int] = 1
|
||||||
|
|
||||||
|
|
||||||
|
class SummarizationInputState(TypedDict):
|
||||||
|
"""Input state for summarization graph."""
|
||||||
|
|
||||||
|
meeting_id: UUID
|
||||||
|
segments: Sequence[Segment]
|
||||||
|
|
||||||
|
|
||||||
|
class SummarizationOutputState(TypedDict):
|
||||||
|
"""Output state from summarization graph."""
|
||||||
|
|
||||||
|
summary_text: str
|
||||||
|
key_points: list[dict[str, object]]
|
||||||
|
action_items: list[dict[str, object]]
|
||||||
|
provider_used: str
|
||||||
|
tokens_used: int | None
|
||||||
|
latency_ms: float | None
|
||||||
|
|
||||||
|
|
||||||
|
class SummarizationState(SummarizationInputState, SummarizationOutputState):
|
||||||
|
"""Full internal state for summarization graph."""
|
||||||
|
|
||||||
|
|
||||||
|
def build_summarization_graph(
|
||||||
|
summarization_service: SummarizationService,
|
||||||
|
) -> CompiledStateGraph[SummarizationState]:
|
||||||
|
"""Build LangGraph wrapper around existing SummarizationService.
|
||||||
|
|
||||||
|
This demonstrates the pattern of wrapping existing services in LangGraph
|
||||||
|
graphs, enabling future expansion (checkpointing, conditional branching,
|
||||||
|
human-in-the-loop, etc.) while maintaining backward compatibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
summarization_service: The existing summarization service to wrap.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A compiled StateGraph that can be invoked with meeting_id and segments.
|
||||||
|
"""
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
|
||||||
|
async def summarize_node(state: SummarizationState) -> dict[str, object]:
|
||||||
|
# MeetingId is NewType of UUID, cast is safe
|
||||||
|
meeting_id = cast("MeetingId", state["meeting_id"])
|
||||||
|
result = await summarization_service.summarize(
|
||||||
|
meeting_id=meeting_id,
|
||||||
|
segments=state["segments"],
|
||||||
|
)
|
||||||
|
summary = result.summary
|
||||||
|
return {
|
||||||
|
"summary_text": summary.executive_summary,
|
||||||
|
"key_points": [
|
||||||
|
{
|
||||||
|
"text": kp.text,
|
||||||
|
"segment_ids": kp.segment_ids,
|
||||||
|
"start_time": kp.start_time,
|
||||||
|
"end_time": kp.end_time,
|
||||||
|
}
|
||||||
|
for kp in summary.key_points
|
||||||
|
],
|
||||||
|
"action_items": [
|
||||||
|
{
|
||||||
|
"text": ai.text,
|
||||||
|
"segment_ids": ai.segment_ids,
|
||||||
|
"assignee": ai.assignee,
|
||||||
|
"start_time": ai.start_time,
|
||||||
|
"end_time": ai.end_time,
|
||||||
|
}
|
||||||
|
for ai in summary.action_items
|
||||||
|
],
|
||||||
|
"provider_used": result.provider_used,
|
||||||
|
"tokens_used": summary.tokens_used,
|
||||||
|
"latency_ms": summary.latency_ms,
|
||||||
|
}
|
||||||
|
|
||||||
|
builder = StateGraph(SummarizationState)
|
||||||
|
builder.add_node("summarize", summarize_node)
|
||||||
|
builder.add_edge(START, "summarize")
|
||||||
|
builder.add_edge("summarize", END)
|
||||||
|
|
||||||
|
return builder.compile()
|
||||||
198
src/noteflow/infrastructure/ai/graphs/workspace_qa.py
Normal file
198
src/noteflow/infrastructure/ai/graphs/workspace_qa.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
"""Workspace Q&A graph for cross-meeting question answering with citations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Final, TypedDict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||||
|
from langgraph.graph import CompiledStateGraph
|
||||||
|
|
||||||
|
from noteflow.domain.ai.citations import SegmentCitation
|
||||||
|
from noteflow.infrastructure.ai.nodes.annotation_suggester import SuggestedAnnotation
|
||||||
|
from noteflow.infrastructure.ai.nodes.web_search import WebSearchProvider
|
||||||
|
from noteflow.infrastructure.ai.tools.retrieval import (
|
||||||
|
EmbedderProtocol,
|
||||||
|
RetrievalResult,
|
||||||
|
WorkspaceSegmentSearchProtocol,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.ai.tools.synthesis import LLMProtocol
|
||||||
|
|
||||||
|
WORKSPACE_QA_GRAPH_NAME: Final[str] = "workspace_qa"
|
||||||
|
WORKSPACE_QA_GRAPH_VERSION: Final[int] = 2
|
||||||
|
NO_INFORMATION_ANSWER: Final[str] = "I couldn't find relevant information across your meetings."
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class WorkspaceQAConfig:
|
||||||
|
enable_web_search: bool = False
|
||||||
|
require_web_approval: bool = True
|
||||||
|
require_annotation_approval: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceQAInputState(TypedDict):
|
||||||
|
question: str
|
||||||
|
workspace_id: UUID
|
||||||
|
project_id: UUID | None
|
||||||
|
top_k: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceQAOutputState(TypedDict):
|
||||||
|
answer: str
|
||||||
|
citations: list[SegmentCitation]
|
||||||
|
suggested_annotations: list[SuggestedAnnotation]
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceQAInternalState(WorkspaceQAInputState, WorkspaceQAOutputState):
|
||||||
|
retrieved_segments: list[RetrievalResult]
|
||||||
|
verification_passed: bool
|
||||||
|
web_search_approved: bool
|
||||||
|
web_context: str
|
||||||
|
|
||||||
|
|
||||||
|
def build_workspace_qa_graph(
|
||||||
|
embedder: EmbedderProtocol,
|
||||||
|
segment_repo: WorkspaceSegmentSearchProtocol,
|
||||||
|
llm: LLMProtocol,
|
||||||
|
*,
|
||||||
|
web_search_provider: WebSearchProvider | None = None,
|
||||||
|
config: WorkspaceQAConfig | None = None,
|
||||||
|
checkpointer: BaseCheckpointSaver[str] | None = None,
|
||||||
|
) -> CompiledStateGraph[WorkspaceQAInternalState]:
|
||||||
|
"""Build Q&A graph for cross-meeting questions with segment citations.
|
||||||
|
|
||||||
|
Graph flow (with web search): retrieve -> verify -> [web_search_approval] -> [web_search] -> synthesize
|
||||||
|
Graph flow (without): retrieve -> verify -> synthesize
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedder: Protocol for generating text embeddings.
|
||||||
|
segment_repo: Protocol for workspace-scoped semantic segment search.
|
||||||
|
llm: Protocol for LLM text completion.
|
||||||
|
web_search_provider: Optional web search provider for augmentation.
|
||||||
|
config: Graph configuration for features/interrupts.
|
||||||
|
checkpointer: Optional checkpointer for interrupt support.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compiled graph that accepts question/workspace_id and returns answer/citations.
|
||||||
|
"""
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
|
||||||
|
from noteflow.domain.ai.citations import SegmentCitation
|
||||||
|
from noteflow.infrastructure.ai.interrupts import check_web_search_approval
|
||||||
|
from noteflow.infrastructure.ai.nodes.annotation_suggester import (
|
||||||
|
extract_annotations_from_answer,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.ai.nodes.web_search import (
|
||||||
|
WebSearchConfig,
|
||||||
|
derive_search_query,
|
||||||
|
execute_web_search,
|
||||||
|
format_results_for_context,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.ai.tools.retrieval import retrieve_segments_workspace
|
||||||
|
from noteflow.infrastructure.ai.tools.synthesis import synthesize_answer
|
||||||
|
|
||||||
|
effective_config = config or WorkspaceQAConfig()
|
||||||
|
|
||||||
|
async def retrieve_node(state: WorkspaceQAInternalState) -> dict[str, object]:
|
||||||
|
results = await retrieve_segments_workspace(
|
||||||
|
query=state["question"],
|
||||||
|
embedder=embedder,
|
||||||
|
segment_repo=segment_repo,
|
||||||
|
workspace_id=state["workspace_id"],
|
||||||
|
project_id=state["project_id"],
|
||||||
|
top_k=state["top_k"],
|
||||||
|
)
|
||||||
|
return {"retrieved_segments": results}
|
||||||
|
|
||||||
|
async def verify_node(state: WorkspaceQAInternalState) -> dict[str, object]:
|
||||||
|
has_segments = len(state["retrieved_segments"]) > 0
|
||||||
|
return {"verification_passed": has_segments}
|
||||||
|
|
||||||
|
def web_search_approval_node(state: WorkspaceQAInternalState) -> dict[str, object]:
|
||||||
|
if not effective_config.enable_web_search or web_search_provider is None:
|
||||||
|
return {"web_search_approved": False}
|
||||||
|
|
||||||
|
if not effective_config.require_web_approval:
|
||||||
|
return {"web_search_approved": True}
|
||||||
|
|
||||||
|
query = derive_search_query(state["question"])
|
||||||
|
approved = check_web_search_approval(query, require_approval=True)
|
||||||
|
return {"web_search_approved": approved}
|
||||||
|
|
||||||
|
async def web_search_node(state: WorkspaceQAInternalState) -> dict[str, object]:
|
||||||
|
if not state.get("web_search_approved", False) or web_search_provider is None:
|
||||||
|
return {"web_context": ""}
|
||||||
|
|
||||||
|
query = derive_search_query(state["question"])
|
||||||
|
search_config = WebSearchConfig(enabled=True, require_approval=False)
|
||||||
|
response = await execute_web_search(query, web_search_provider, search_config)
|
||||||
|
context = format_results_for_context(response.results)
|
||||||
|
return {"web_context": context}
|
||||||
|
|
||||||
|
async def synthesize_node(state: WorkspaceQAInternalState) -> dict[str, object]:
|
||||||
|
if not state["verification_passed"]:
|
||||||
|
return {
|
||||||
|
"answer": NO_INFORMATION_ANSWER,
|
||||||
|
"citations": [],
|
||||||
|
"suggested_annotations": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await synthesize_answer(
|
||||||
|
question=state["question"],
|
||||||
|
segments=state["retrieved_segments"],
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
citations = [
|
||||||
|
SegmentCitation(
|
||||||
|
meeting_id=seg.meeting_id,
|
||||||
|
segment_id=seg.segment_id,
|
||||||
|
start_time=seg.start_time,
|
||||||
|
end_time=seg.end_time,
|
||||||
|
text=seg.text,
|
||||||
|
score=seg.score,
|
||||||
|
)
|
||||||
|
for seg in state["retrieved_segments"]
|
||||||
|
if seg.segment_id in result.cited_segment_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
suggested_annotations = extract_annotations_from_answer(
|
||||||
|
answer=result.answer,
|
||||||
|
cited_segment_ids=tuple(result.cited_segment_ids),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"answer": result.answer,
|
||||||
|
"citations": citations,
|
||||||
|
"suggested_annotations": suggested_annotations,
|
||||||
|
}
|
||||||
|
|
||||||
|
builder: StateGraph[WorkspaceQAInternalState] = StateGraph(WorkspaceQAInternalState)
|
||||||
|
builder.add_node("retrieve", retrieve_node)
|
||||||
|
builder.add_node("verify", verify_node)
|
||||||
|
builder.add_node("synthesize", synthesize_node)
|
||||||
|
|
||||||
|
if effective_config.enable_web_search and web_search_provider is not None:
|
||||||
|
builder.add_node("web_search_approval", web_search_approval_node)
|
||||||
|
builder.add_node("web_search", web_search_node)
|
||||||
|
|
||||||
|
builder.add_edge(START, "retrieve")
|
||||||
|
builder.add_edge("retrieve", "verify")
|
||||||
|
builder.add_edge("verify", "web_search_approval")
|
||||||
|
builder.add_edge("web_search_approval", "web_search")
|
||||||
|
builder.add_edge("web_search", "synthesize")
|
||||||
|
builder.add_edge("synthesize", END)
|
||||||
|
else:
|
||||||
|
builder.add_edge(START, "retrieve")
|
||||||
|
builder.add_edge("retrieve", "verify")
|
||||||
|
builder.add_edge("verify", "synthesize")
|
||||||
|
builder.add_edge("synthesize", END)
|
||||||
|
|
||||||
|
compile_method = getattr(builder, "compile")
|
||||||
|
compiled: CompiledStateGraph[WorkspaceQAInternalState] = compile_method(
|
||||||
|
checkpointer=checkpointer
|
||||||
|
)
|
||||||
|
return compiled
|
||||||
312
src/noteflow/infrastructure/ai/guardrails.py
Normal file
312
src/noteflow/infrastructure/ai/guardrails.py
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
"""Content guardrails for AI input validation and output filtering."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
from noteflow.infrastructure.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Input validation limits
|
||||||
|
DEFAULT_MIN_INPUT_LENGTH: Final[int] = 3
|
||||||
|
DEFAULT_MAX_INPUT_LENGTH: Final[int] = 4000
|
||||||
|
DEFAULT_MAX_OUTPUT_LENGTH: Final[int] = 10000
|
||||||
|
|
||||||
|
# PII patterns (simplified - production would use more comprehensive detection)
|
||||||
|
EMAIL_PATTERN: Final[re.Pattern[str]] = re.compile(
|
||||||
|
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
|
||||||
|
)
|
||||||
|
PHONE_PATTERN: Final[re.Pattern[str]] = re.compile(
|
||||||
|
r"\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b"
|
||||||
|
)
|
||||||
|
SSN_PATTERN: Final[re.Pattern[str]] = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")
|
||||||
|
CREDIT_CARD_PATTERN: Final[re.Pattern[str]] = re.compile(r"\b(?:\d{4}[-\s]?){3}\d{4}\b")
|
||||||
|
|
||||||
|
PII_PATTERNS: Final[tuple[tuple[str, re.Pattern[str]], ...]] = (
|
||||||
|
("email", EMAIL_PATTERN),
|
||||||
|
("phone", PHONE_PATTERN),
|
||||||
|
("ssn", SSN_PATTERN),
|
||||||
|
("credit_card", CREDIT_CARD_PATTERN),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redaction placeholder
|
||||||
|
PII_REDACTION: Final[str] = "[REDACTED]"
|
||||||
|
|
||||||
|
|
||||||
|
class GuardrailViolation(str, Enum):
|
||||||
|
"""Types of guardrail violations."""
|
||||||
|
|
||||||
|
INPUT_TOO_SHORT = "input_too_short"
|
||||||
|
INPUT_TOO_LONG = "input_too_long"
|
||||||
|
OUTPUT_TOO_LONG = "output_too_long"
|
||||||
|
CONTAINS_PII = "contains_pii"
|
||||||
|
BLOCKED_CONTENT = "blocked_content"
|
||||||
|
INJECTION_ATTEMPT = "injection_attempt"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GuardrailResult:
|
||||||
|
"""Result of a guardrail check."""
|
||||||
|
|
||||||
|
allowed: bool
|
||||||
|
violation: GuardrailViolation | None = None
|
||||||
|
reason: str | None = None
|
||||||
|
filtered_content: str | None = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ok(content: str | None = None) -> GuardrailResult:
|
||||||
|
"""Create a passing result."""
|
||||||
|
return GuardrailResult(allowed=True, filtered_content=content)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def blocked(
|
||||||
|
violation: GuardrailViolation,
|
||||||
|
reason: str,
|
||||||
|
) -> GuardrailResult:
|
||||||
|
"""Create a blocking result."""
|
||||||
|
return GuardrailResult(allowed=False, violation=violation, reason=reason)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def filtered(
|
||||||
|
content: str,
|
||||||
|
violation: GuardrailViolation,
|
||||||
|
reason: str,
|
||||||
|
) -> GuardrailResult:
|
||||||
|
"""Create a result with filtered content."""
|
||||||
|
return GuardrailResult(
|
||||||
|
allowed=True,
|
||||||
|
violation=violation,
|
||||||
|
reason=reason,
|
||||||
|
filtered_content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GuardrailRules:
|
||||||
|
"""Configurable guardrail rules.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
min_input_length: Minimum allowed input length.
|
||||||
|
max_input_length: Maximum allowed input length.
|
||||||
|
max_output_length: Maximum allowed output length.
|
||||||
|
block_pii: Whether to block content containing PII.
|
||||||
|
redact_pii: Whether to redact PII instead of blocking.
|
||||||
|
blocked_phrases: Phrases that should block content entirely.
|
||||||
|
detect_injection: Whether to detect prompt injection attempts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
min_input_length: int = DEFAULT_MIN_INPUT_LENGTH
|
||||||
|
max_input_length: int = DEFAULT_MAX_INPUT_LENGTH
|
||||||
|
max_output_length: int = DEFAULT_MAX_OUTPUT_LENGTH
|
||||||
|
block_pii: bool = False
|
||||||
|
redact_pii: bool = True
|
||||||
|
blocked_phrases: frozenset[str] = field(default_factory=frozenset)
|
||||||
|
detect_injection: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
# Common injection patterns
|
||||||
|
INJECTION_PATTERNS: Final[tuple[re.Pattern[str], ...]] = (
|
||||||
|
re.compile(r"ignore\s+(?:all\s+)?(?:previous|above)\s+instructions", re.IGNORECASE),
|
||||||
|
re.compile(r"disregard\s+(?:all\s+)?(?:previous|prior)\s+", re.IGNORECASE),
|
||||||
|
re.compile(r"you\s+are\s+now\s+(?:a|an)\s+", re.IGNORECASE),
|
||||||
|
re.compile(r"forget\s+(?:everything|all)\s+(?:you|and)\s+", re.IGNORECASE),
|
||||||
|
re.compile(r"new\s+(?:system\s+)?instructions?:", re.IGNORECASE),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_length(
|
||||||
|
text: str,
|
||||||
|
min_length: int,
|
||||||
|
max_length: int,
|
||||||
|
is_input: bool,
|
||||||
|
) -> GuardrailResult | None:
|
||||||
|
"""Check text length constraints."""
|
||||||
|
if is_input and len(text) < min_length:
|
||||||
|
return GuardrailResult.blocked(
|
||||||
|
GuardrailViolation.INPUT_TOO_SHORT,
|
||||||
|
f"Input must be at least {min_length} characters",
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_input and len(text) > max_length:
|
||||||
|
return GuardrailResult.blocked(
|
||||||
|
GuardrailViolation.INPUT_TOO_LONG,
|
||||||
|
f"Input must be at most {max_length} characters",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_input and len(text) > max_length:
|
||||||
|
return GuardrailResult.blocked(
|
||||||
|
GuardrailViolation.OUTPUT_TOO_LONG,
|
||||||
|
f"Output exceeds {max_length} characters",
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _check_blocked_phrases(
|
||||||
|
text: str,
|
||||||
|
blocked_phrases: frozenset[str],
|
||||||
|
) -> GuardrailResult | None:
|
||||||
|
"""Check for blocked phrases."""
|
||||||
|
text_lower = text.lower()
|
||||||
|
for phrase in blocked_phrases:
|
||||||
|
if phrase.lower() in text_lower:
|
||||||
|
logger.warning("blocked_phrase_detected", phrase=phrase[:20])
|
||||||
|
return GuardrailResult.blocked(
|
||||||
|
GuardrailViolation.BLOCKED_CONTENT,
|
||||||
|
"Content contains blocked phrase",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _check_injection(text: str) -> GuardrailResult | None:
|
||||||
|
"""Check for prompt injection attempts."""
|
||||||
|
for pattern in INJECTION_PATTERNS:
|
||||||
|
if pattern.search(text):
|
||||||
|
logger.warning("injection_attempt_detected")
|
||||||
|
return GuardrailResult.blocked(
|
||||||
|
GuardrailViolation.INJECTION_ATTEMPT,
|
||||||
|
"Potential prompt injection detected",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_pii(text: str) -> list[tuple[str, str]]:
|
||||||
|
"""Detect PII in text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (pii_type, matched_text) tuples.
|
||||||
|
"""
|
||||||
|
findings: list[tuple[str, str]] = []
|
||||||
|
for pii_type, pattern in PII_PATTERNS:
|
||||||
|
for match in pattern.finditer(text):
|
||||||
|
findings.append((pii_type, match.group()))
|
||||||
|
return findings
|
||||||
|
|
||||||
|
|
||||||
|
def _redact_pii(text: str) -> str:
|
||||||
|
"""Redact all PII in text."""
|
||||||
|
result = text
|
||||||
|
for _, pattern in PII_PATTERNS:
|
||||||
|
result = pattern.sub(PII_REDACTION, result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def check_input(text: str, rules: GuardrailRules) -> GuardrailResult:
|
||||||
|
"""Validate input text against guardrail rules.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to validate.
|
||||||
|
rules: Guardrail rules to apply.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GuardrailResult indicating if input is allowed.
|
||||||
|
"""
|
||||||
|
# Length checks
|
||||||
|
length_result = _check_length(
|
||||||
|
text,
|
||||||
|
rules.min_input_length,
|
||||||
|
rules.max_input_length,
|
||||||
|
is_input=True,
|
||||||
|
)
|
||||||
|
if length_result is not None:
|
||||||
|
return length_result
|
||||||
|
|
||||||
|
# Blocked phrases
|
||||||
|
phrase_result = _check_blocked_phrases(text, rules.blocked_phrases)
|
||||||
|
if phrase_result is not None:
|
||||||
|
return phrase_result
|
||||||
|
|
||||||
|
# Injection detection
|
||||||
|
if rules.detect_injection:
|
||||||
|
injection_result = _check_injection(text)
|
||||||
|
if injection_result is not None:
|
||||||
|
return injection_result
|
||||||
|
|
||||||
|
# PII checks
|
||||||
|
if rules.block_pii or rules.redact_pii:
|
||||||
|
pii_findings = _detect_pii(text)
|
||||||
|
if pii_findings:
|
||||||
|
pii_types = [f[0] for f in pii_findings]
|
||||||
|
logger.info("pii_detected_in_input", pii_types=pii_types)
|
||||||
|
|
||||||
|
if rules.block_pii:
|
||||||
|
return GuardrailResult.blocked(
|
||||||
|
GuardrailViolation.CONTAINS_PII,
|
||||||
|
f"Input contains PII: {', '.join(pii_types)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redact instead of block
|
||||||
|
redacted = _redact_pii(text)
|
||||||
|
return GuardrailResult.filtered(
|
||||||
|
redacted,
|
||||||
|
GuardrailViolation.CONTAINS_PII,
|
||||||
|
f"PII redacted: {', '.join(pii_types)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return GuardrailResult.ok(text)
|
||||||
|
|
||||||
|
|
||||||
|
async def filter_output(text: str, rules: GuardrailRules) -> GuardrailResult:
|
||||||
|
"""Filter output text, redacting sensitive content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Output text to filter.
|
||||||
|
rules: Guardrail rules to apply.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GuardrailResult with potentially filtered content.
|
||||||
|
"""
|
||||||
|
# Length check
|
||||||
|
length_result = _check_length(
|
||||||
|
text,
|
||||||
|
min_length=0, # No minimum for output
|
||||||
|
max_length=rules.max_output_length,
|
||||||
|
is_input=False,
|
||||||
|
)
|
||||||
|
if length_result is not None:
|
||||||
|
# Truncate instead of blocking for output
|
||||||
|
truncated = text[: rules.max_output_length]
|
||||||
|
return GuardrailResult.filtered(
|
||||||
|
truncated,
|
||||||
|
GuardrailViolation.OUTPUT_TOO_LONG,
|
||||||
|
f"Output truncated to {rules.max_output_length} characters",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Blocked phrases in output
|
||||||
|
phrase_result = _check_blocked_phrases(text, rules.blocked_phrases)
|
||||||
|
if phrase_result is not None:
|
||||||
|
return phrase_result
|
||||||
|
|
||||||
|
# PII redaction in output (always redact, never block output)
|
||||||
|
if rules.redact_pii:
|
||||||
|
pii_findings = _detect_pii(text)
|
||||||
|
if pii_findings:
|
||||||
|
pii_types = [f[0] for f in pii_findings]
|
||||||
|
logger.info("pii_detected_in_output", pii_types=pii_types)
|
||||||
|
redacted = _redact_pii(text)
|
||||||
|
return GuardrailResult.filtered(
|
||||||
|
redacted,
|
||||||
|
GuardrailViolation.CONTAINS_PII,
|
||||||
|
f"PII redacted: {', '.join(pii_types)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return GuardrailResult.ok(text)
|
||||||
|
|
||||||
|
|
||||||
|
def create_default_rules() -> GuardrailRules:
|
||||||
|
"""Create default guardrail rules."""
|
||||||
|
return GuardrailRules()
|
||||||
|
|
||||||
|
|
||||||
|
def create_strict_rules() -> GuardrailRules:
|
||||||
|
"""Create strict guardrail rules with PII blocking."""
|
||||||
|
return GuardrailRules(
|
||||||
|
block_pii=True,
|
||||||
|
redact_pii=False,
|
||||||
|
detect_injection=True,
|
||||||
|
max_input_length=2000,
|
||||||
|
)
|
||||||
231
src/noteflow/infrastructure/ai/interrupts.py
Normal file
231
src/noteflow/infrastructure/ai/interrupts.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
"""Infrastructure utilities for LangGraph human-in-the-loop interrupts.
|
||||||
|
|
||||||
|
Wraps LangGraph's interrupt() and Command APIs for consistent usage across graphs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Final
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from langgraph.types import Command, interrupt
|
||||||
|
|
||||||
|
from noteflow.domain.ai.interrupts import (
|
||||||
|
InterruptAction,
|
||||||
|
InterruptResponse,
|
||||||
|
create_annotation_interrupt,
|
||||||
|
create_web_search_interrupt,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.logging import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from noteflow.infrastructure.ai.nodes.annotation_suggester import SuggestedAnnotation
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
INTERRUPT_RESPONSE_KEY: Final[str] = "response"
|
||||||
|
INTERRUPT_APPROVED_VALUE: Final[str] = "approved"
|
||||||
|
|
||||||
|
|
||||||
|
def request_web_search_approval(
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
allow_modify: bool = False,
|
||||||
|
) -> InterruptResponse:
|
||||||
|
"""Request user approval for web search via LangGraph interrupt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query to be executed.
|
||||||
|
allow_modify: Whether the user can modify the query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
InterruptResponse with user's decision.
|
||||||
|
"""
|
||||||
|
request_id = str(uuid4())
|
||||||
|
interrupt_request = create_web_search_interrupt(
|
||||||
|
query=query,
|
||||||
|
request_id=request_id,
|
||||||
|
allow_modify=allow_modify,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"interrupt_web_search_requested",
|
||||||
|
request_id=request_id,
|
||||||
|
query_preview=query[:50],
|
||||||
|
)
|
||||||
|
|
||||||
|
response_data = interrupt(interrupt_request.to_dict())
|
||||||
|
|
||||||
|
return _parse_interrupt_response(response_data, request_id)
|
||||||
|
|
||||||
|
|
||||||
|
def request_annotation_approval(
|
||||||
|
annotations: list[SuggestedAnnotation],
|
||||||
|
) -> InterruptResponse:
|
||||||
|
"""Request user approval for suggested annotations via LangGraph interrupt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
annotations: List of suggested annotations to approve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
InterruptResponse with user's decision.
|
||||||
|
"""
|
||||||
|
request_id = str(uuid4())
|
||||||
|
annotation_dicts = [ann.to_dict() for ann in annotations]
|
||||||
|
interrupt_request = create_annotation_interrupt(
|
||||||
|
annotations=annotation_dicts,
|
||||||
|
request_id=request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"interrupt_annotation_requested",
|
||||||
|
request_id=request_id,
|
||||||
|
annotation_count=len(annotations),
|
||||||
|
)
|
||||||
|
|
||||||
|
response_data = interrupt(interrupt_request.to_dict())
|
||||||
|
|
||||||
|
return _parse_interrupt_response(response_data, request_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_interrupt_response(
|
||||||
|
response_data: object,
|
||||||
|
request_id: str,
|
||||||
|
) -> InterruptResponse:
|
||||||
|
"""Parse LangGraph interrupt response into domain type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response_data: Raw response from LangGraph interrupt.
|
||||||
|
request_id: ID of the original request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed InterruptResponse.
|
||||||
|
"""
|
||||||
|
if isinstance(response_data, str):
|
||||||
|
action = _string_to_action(response_data)
|
||||||
|
return InterruptResponse(action=action, request_id=request_id)
|
||||||
|
|
||||||
|
if isinstance(response_data, dict):
|
||||||
|
action_str = str(response_data.get("action", "reject"))
|
||||||
|
action = _string_to_action(action_str)
|
||||||
|
|
||||||
|
modified_value = response_data.get("modified_value")
|
||||||
|
if modified_value is not None and not isinstance(modified_value, dict):
|
||||||
|
modified_value = None
|
||||||
|
|
||||||
|
user_message = response_data.get("user_message")
|
||||||
|
if user_message is not None:
|
||||||
|
user_message = str(user_message)
|
||||||
|
|
||||||
|
return InterruptResponse(
|
||||||
|
action=action,
|
||||||
|
request_id=request_id,
|
||||||
|
modified_value=modified_value,
|
||||||
|
user_message=user_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"interrupt_response_unknown_format",
|
||||||
|
request_id=request_id,
|
||||||
|
response_type=type(response_data).__name__,
|
||||||
|
)
|
||||||
|
return InterruptResponse(action=InterruptAction.REJECT, request_id=request_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _string_to_action(value: str) -> InterruptAction:
|
||||||
|
"""Convert string response to InterruptAction."""
|
||||||
|
normalized = value.lower().strip()
|
||||||
|
if normalized in ("approve", "yes", "approved", "accept"):
|
||||||
|
return InterruptAction.APPROVE
|
||||||
|
if normalized in ("modify", "edit", "change"):
|
||||||
|
return InterruptAction.MODIFY
|
||||||
|
return InterruptAction.REJECT
|
||||||
|
|
||||||
|
|
||||||
|
def create_resume_command(response: InterruptResponse) -> Command[None]:
|
||||||
|
"""Create a LangGraph Command to resume execution with user response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: User's interrupt response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Command to resume graph execution.
|
||||||
|
"""
|
||||||
|
return Command(resume=response.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptHandler:
|
||||||
|
"""Handles interrupt requests and responses for a graph execution."""
|
||||||
|
|
||||||
|
_require_web_approval: bool
|
||||||
|
|
||||||
|
def __init__(self, require_web_approval: bool = True) -> None:
|
||||||
|
self._require_web_approval = require_web_approval
|
||||||
|
|
||||||
|
def should_interrupt_for_web_search(self) -> bool:
|
||||||
|
return self._require_web_approval
|
||||||
|
|
||||||
|
def request_web_search(self, query: str) -> InterruptResponse:
|
||||||
|
return request_web_search_approval(query)
|
||||||
|
|
||||||
|
def request_annotation_approval(
|
||||||
|
self,
|
||||||
|
annotations: list[SuggestedAnnotation],
|
||||||
|
) -> InterruptResponse:
|
||||||
|
return request_annotation_approval(annotations)
|
||||||
|
|
||||||
|
|
||||||
|
def check_web_search_approval(
|
||||||
|
query: str,
|
||||||
|
require_approval: bool,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if web search should proceed (with optional interrupt).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query to execute.
|
||||||
|
require_approval: Whether to interrupt for user approval.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if search should proceed, False if rejected.
|
||||||
|
"""
|
||||||
|
if not require_approval:
|
||||||
|
return True
|
||||||
|
|
||||||
|
response = request_web_search_approval(query)
|
||||||
|
return response.is_approved
|
||||||
|
|
||||||
|
|
||||||
|
def check_annotation_approval(
|
||||||
|
annotations: list[SuggestedAnnotation],
|
||||||
|
) -> tuple[bool, list[SuggestedAnnotation]]:
|
||||||
|
"""Check if annotations should be applied (with interrupt).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
annotations: Suggested annotations to approve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (should_apply, possibly_modified_annotations).
|
||||||
|
"""
|
||||||
|
if not annotations:
|
||||||
|
return False, []
|
||||||
|
|
||||||
|
response = request_annotation_approval(annotations)
|
||||||
|
|
||||||
|
if response.is_rejected:
|
||||||
|
return False, []
|
||||||
|
|
||||||
|
if response.is_modified and response.modified_value:
|
||||||
|
modified_list_raw = response.modified_value.get("annotations", [])
|
||||||
|
if isinstance(modified_list_raw, list):
|
||||||
|
from noteflow.infrastructure.ai.nodes.annotation_suggester import (
|
||||||
|
SuggestedAnnotation,
|
||||||
|
)
|
||||||
|
|
||||||
|
modified_annotations: list[SuggestedAnnotation] = []
|
||||||
|
for item in modified_list_raw:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
item_dict: dict[str, object] = {str(k): v for k, v in item.items()}
|
||||||
|
modified_annotations.append(SuggestedAnnotation.from_dict(item_dict))
|
||||||
|
return True, modified_annotations
|
||||||
|
|
||||||
|
return response.is_approved, annotations
|
||||||
39
src/noteflow/infrastructure/ai/nodes/__init__.py
Normal file
39
src/noteflow/infrastructure/ai/nodes/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""LangGraph node implementations."""
|
||||||
|
|
||||||
|
from noteflow.infrastructure.ai.nodes.annotation_suggester import (
|
||||||
|
SuggestedAnnotation,
|
||||||
|
SuggestedAnnotationType,
|
||||||
|
extract_annotations_from_answer,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.ai.nodes.verification import (
|
||||||
|
VerificationResult,
|
||||||
|
verify_citations,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.ai.nodes.web_search import (
|
||||||
|
DisabledWebSearchProvider,
|
||||||
|
WebSearchConfig,
|
||||||
|
WebSearchProvider,
|
||||||
|
WebSearchResponse,
|
||||||
|
WebSearchResult,
|
||||||
|
derive_search_query,
|
||||||
|
execute_web_search,
|
||||||
|
format_results_for_context,
|
||||||
|
merge_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DisabledWebSearchProvider",
|
||||||
|
"SuggestedAnnotation",
|
||||||
|
"SuggestedAnnotationType",
|
||||||
|
"VerificationResult",
|
||||||
|
"WebSearchConfig",
|
||||||
|
"WebSearchProvider",
|
||||||
|
"WebSearchResponse",
|
||||||
|
"WebSearchResult",
|
||||||
|
"derive_search_query",
|
||||||
|
"execute_web_search",
|
||||||
|
"extract_annotations_from_answer",
|
||||||
|
"format_results_for_context",
|
||||||
|
"merge_contexts",
|
||||||
|
"verify_citations",
|
||||||
|
]
|
||||||
121
src/noteflow/infrastructure/ai/nodes/annotation_suggester.py
Normal file
121
src/noteflow/infrastructure/ai/nodes/annotation_suggester.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""Annotation suggester for extracting action items and decisions from answers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestedAnnotationType(str, Enum):
|
||||||
|
ACTION_ITEM = "action_item"
|
||||||
|
DECISION = "decision"
|
||||||
|
NOTE = "note"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SuggestedAnnotation:
|
||||||
|
text: str
|
||||||
|
annotation_type: SuggestedAnnotationType
|
||||||
|
segment_ids: tuple[int, ...]
|
||||||
|
confidence: float = 0.8
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, object]:
|
||||||
|
return {
|
||||||
|
"text": self.text,
|
||||||
|
"type": self.annotation_type.value,
|
||||||
|
"segment_ids": list(self.segment_ids),
|
||||||
|
"confidence": self.confidence,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, object]) -> SuggestedAnnotation:
|
||||||
|
text = str(data.get("text", ""))
|
||||||
|
type_str = str(data.get("type", "note"))
|
||||||
|
segment_ids_raw = data.get("segment_ids", [])
|
||||||
|
if isinstance(segment_ids_raw, list):
|
||||||
|
segment_ids = tuple(
|
||||||
|
int(sid) for sid in segment_ids_raw if isinstance(sid, (int, float))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
segment_ids = ()
|
||||||
|
confidence_raw = data.get("confidence", 0.8)
|
||||||
|
confidence = float(confidence_raw) if isinstance(confidence_raw, (int, float)) else 0.8
|
||||||
|
|
||||||
|
try:
|
||||||
|
annotation_type = SuggestedAnnotationType(type_str)
|
||||||
|
except ValueError:
|
||||||
|
annotation_type = SuggestedAnnotationType.NOTE
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
text=text,
|
||||||
|
annotation_type=annotation_type,
|
||||||
|
segment_ids=segment_ids,
|
||||||
|
confidence=confidence,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
ACTION_ITEM_PATTERNS: Final[tuple[re.Pattern[str], ...]] = (
|
||||||
|
re.compile(r"(?:need to|should|must|will|going to|has to)\s+(.+?)(?:\.|$)", re.IGNORECASE),
|
||||||
|
re.compile(r"(?:action item|TODO|task):\s*(.+?)(?:\.|$)", re.IGNORECASE),
|
||||||
|
re.compile(r"(?:follow[- ]up|next step):\s*(.+?)(?:\.|$)", re.IGNORECASE),
|
||||||
|
)
|
||||||
|
|
||||||
|
DECISION_PATTERNS: Final[tuple[re.Pattern[str], ...]] = (
|
||||||
|
re.compile(r"(?:decided to|agreed to|will go with|chose to)\s+(.+?)(?:\.|$)", re.IGNORECASE),
|
||||||
|
re.compile(r"(?:decision|resolution):\s*(.+?)(?:\.|$)", re.IGNORECASE),
|
||||||
|
re.compile(r"(?:the team|we|they) (?:decided|agreed|chose)\s+(.+?)(?:\.|$)", re.IGNORECASE),
|
||||||
|
)
|
||||||
|
|
||||||
|
MIN_TEXT_LENGTH: Final[int] = 10
|
||||||
|
MAX_TEXT_LENGTH: Final[int] = 200
|
||||||
|
|
||||||
|
|
||||||
|
def extract_annotations_from_answer(
|
||||||
|
answer: str,
|
||||||
|
cited_segment_ids: tuple[int, ...],
|
||||||
|
) -> list[SuggestedAnnotation]:
|
||||||
|
"""Extract action items and decisions from synthesized answer."""
|
||||||
|
suggestions: list[SuggestedAnnotation] = []
|
||||||
|
|
||||||
|
for pattern in ACTION_ITEM_PATTERNS:
|
||||||
|
for match in pattern.finditer(answer):
|
||||||
|
text = match.group(1).strip()
|
||||||
|
if MIN_TEXT_LENGTH <= len(text) <= MAX_TEXT_LENGTH:
|
||||||
|
suggestions.append(
|
||||||
|
SuggestedAnnotation(
|
||||||
|
text=text,
|
||||||
|
annotation_type=SuggestedAnnotationType.ACTION_ITEM,
|
||||||
|
segment_ids=cited_segment_ids,
|
||||||
|
confidence=0.7,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for pattern in DECISION_PATTERNS:
|
||||||
|
for match in pattern.finditer(answer):
|
||||||
|
text = match.group(1).strip()
|
||||||
|
if MIN_TEXT_LENGTH <= len(text) <= MAX_TEXT_LENGTH:
|
||||||
|
suggestions.append(
|
||||||
|
SuggestedAnnotation(
|
||||||
|
text=text,
|
||||||
|
annotation_type=SuggestedAnnotationType.DECISION,
|
||||||
|
segment_ids=cited_segment_ids,
|
||||||
|
confidence=0.75,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return _dedupe_suggestions(suggestions)
|
||||||
|
|
||||||
|
|
||||||
|
def _dedupe_suggestions(suggestions: list[SuggestedAnnotation]) -> list[SuggestedAnnotation]:
|
||||||
|
seen_texts: set[str] = set()
|
||||||
|
deduped: list[SuggestedAnnotation] = []
|
||||||
|
|
||||||
|
for suggestion in suggestions:
|
||||||
|
normalized = suggestion.text.lower().strip()
|
||||||
|
if normalized not in seen_texts:
|
||||||
|
seen_texts.add(normalized)
|
||||||
|
deduped.append(suggestion)
|
||||||
|
|
||||||
|
return deduped
|
||||||
61
src/noteflow/infrastructure/ai/nodes/verification.py
Normal file
61
src/noteflow/infrastructure/ai/nodes/verification.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""Citation verification for AI-generated answers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class VerificationResult:
|
||||||
|
"""Result of citation verification.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
is_valid: True if all cited segment IDs exist in available segments.
|
||||||
|
invalid_citation_indices: Indices of citations that failed validation.
|
||||||
|
reason: Human-readable explanation if validation failed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
is_valid: bool
|
||||||
|
invalid_citation_indices: tuple[int, ...]
|
||||||
|
reason: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
NO_SEGMENTS_REASON: Final[str] = "No segments retrieved for question"
|
||||||
|
INVALID_CITATIONS_PREFIX: Final[str] = "Invalid citation indices: "
|
||||||
|
|
||||||
|
|
||||||
|
def verify_citations(
|
||||||
|
cited_ids: list[int],
|
||||||
|
available_ids: set[int],
|
||||||
|
) -> VerificationResult:
|
||||||
|
"""Verify all cited segment IDs exist in available segments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cited_ids: List of segment IDs cited in the answer.
|
||||||
|
available_ids: Set of valid segment IDs from retrieval.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VerificationResult with validation status and any invalid indices.
|
||||||
|
"""
|
||||||
|
if not available_ids:
|
||||||
|
return VerificationResult(
|
||||||
|
is_valid=False,
|
||||||
|
invalid_citation_indices=(),
|
||||||
|
reason=NO_SEGMENTS_REASON,
|
||||||
|
)
|
||||||
|
|
||||||
|
invalid_indices = tuple(i for i, cid in enumerate(cited_ids) if cid not in available_ids)
|
||||||
|
|
||||||
|
if invalid_indices:
|
||||||
|
return VerificationResult(
|
||||||
|
is_valid=False,
|
||||||
|
invalid_citation_indices=invalid_indices,
|
||||||
|
reason=f"{INVALID_CITATIONS_PREFIX}{invalid_indices}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return VerificationResult(
|
||||||
|
is_valid=True,
|
||||||
|
invalid_citation_indices=(),
|
||||||
|
reason=None,
|
||||||
|
)
|
||||||
226
src/noteflow/infrastructure/ai/nodes/web_search.py
Normal file
226
src/noteflow/infrastructure/ai/nodes/web_search.py
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
"""Web search node for augmenting RAG with external knowledge."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Final, Protocol
|
||||||
|
|
||||||
|
from noteflow.infrastructure.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_MAX_RESULTS: Final[int] = 5
|
||||||
|
DEFAULT_TIMEOUT_SECONDS: Final[float] = 10.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class WebSearchResult:
|
||||||
|
"""A single web search result."""
|
||||||
|
|
||||||
|
title: str
|
||||||
|
url: str
|
||||||
|
snippet: str
|
||||||
|
score: float = 1.0
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, object]:
|
||||||
|
"""Convert to dictionary for serialization."""
|
||||||
|
return {
|
||||||
|
"title": self.title,
|
||||||
|
"url": self.url,
|
||||||
|
"snippet": self.snippet,
|
||||||
|
"score": self.score,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class WebSearchResponse:
|
||||||
|
"""Response from a web search query."""
|
||||||
|
|
||||||
|
query: str
|
||||||
|
results: tuple[WebSearchResult, ...]
|
||||||
|
total_results: int
|
||||||
|
search_time_ms: float
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_results(self) -> bool:
|
||||||
|
"""Check if search returned any results."""
|
||||||
|
return len(self.results) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearchProvider(Protocol):
|
||||||
|
"""Protocol for web search providers.
|
||||||
|
|
||||||
|
Implementations can integrate with:
|
||||||
|
- Exa AI
|
||||||
|
- SerpAPI
|
||||||
|
- Brave Search API
|
||||||
|
- Bing Web Search API
|
||||||
|
- Google Custom Search
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
max_results: int = DEFAULT_MAX_RESULTS,
|
||||||
|
) -> WebSearchResponse:
|
||||||
|
"""Execute a web search query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query string.
|
||||||
|
max_results: Maximum number of results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WebSearchResponse with search results.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class DisabledWebSearchProvider:
|
||||||
|
"""Stub provider that returns empty results.
|
||||||
|
|
||||||
|
Used when web search is not configured or disabled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
_max_results: int = DEFAULT_MAX_RESULTS,
|
||||||
|
) -> WebSearchResponse:
|
||||||
|
"""Return empty results - web search disabled."""
|
||||||
|
logger.debug("web_search_disabled", query=query[:50])
|
||||||
|
return WebSearchResponse(
|
||||||
|
query=query,
|
||||||
|
results=(),
|
||||||
|
total_results=0,
|
||||||
|
search_time_ms=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WebSearchConfig:
|
||||||
|
"""Configuration for web search node."""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
max_results: int = DEFAULT_MAX_RESULTS
|
||||||
|
timeout_seconds: float = DEFAULT_TIMEOUT_SECONDS
|
||||||
|
require_approval: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
def format_results_for_context(results: tuple[WebSearchResult, ...]) -> str:
|
||||||
|
"""Format web search results for LLM context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: Web search results to format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string suitable for LLM context.
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
formatted_parts: list[str] = ["## Web Search Results\n"]
|
||||||
|
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
formatted_parts.append(f"### [{i}] {result.title}")
|
||||||
|
formatted_parts.append(f"Source: {result.url}")
|
||||||
|
formatted_parts.append(f"{result.snippet}\n")
|
||||||
|
|
||||||
|
return "\n".join(formatted_parts)
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_web_search(
|
||||||
|
query: str,
|
||||||
|
provider: WebSearchProvider,
|
||||||
|
config: WebSearchConfig,
|
||||||
|
) -> WebSearchResponse:
|
||||||
|
"""Execute web search with configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query derived from user question.
|
||||||
|
provider: Web search provider implementation.
|
||||||
|
config: Search configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WebSearchResponse with results (empty if disabled).
|
||||||
|
"""
|
||||||
|
if not config.enabled:
|
||||||
|
logger.debug("web_search_skipped_disabled")
|
||||||
|
return WebSearchResponse(
|
||||||
|
query=query,
|
||||||
|
results=(),
|
||||||
|
total_results=0,
|
||||||
|
search_time_ms=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("web_search_executing", query=query[:100], max_results=config.max_results)
|
||||||
|
|
||||||
|
response = await provider.search(
|
||||||
|
query=query,
|
||||||
|
max_results=config.max_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"web_search_completed",
|
||||||
|
query=query[:50],
|
||||||
|
result_count=len(response.results),
|
||||||
|
search_time_ms=response.search_time_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def merge_contexts(
|
||||||
|
transcript_context: str,
|
||||||
|
web_results: WebSearchResponse,
|
||||||
|
) -> str:
|
||||||
|
"""Merge transcript segments with web search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transcript_context: Context from transcript segments.
|
||||||
|
web_results: Web search response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined context for LLM synthesis.
|
||||||
|
"""
|
||||||
|
if not web_results.has_results:
|
||||||
|
return transcript_context
|
||||||
|
|
||||||
|
web_context = format_results_for_context(web_results.results)
|
||||||
|
|
||||||
|
return f"""## Meeting Transcript Context
|
||||||
|
{transcript_context}
|
||||||
|
|
||||||
|
{web_context}
|
||||||
|
|
||||||
|
Note: Web search results are provided as supplementary context.
|
||||||
|
Prioritize information from the meeting transcript when answering questions about the meeting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def derive_search_query(question: str, meeting_context: str | None = None) -> str:
|
||||||
|
"""Derive a web search query from user question and context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
question: User's original question.
|
||||||
|
meeting_context: Optional context about the meeting topic.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optimized search query.
|
||||||
|
"""
|
||||||
|
# Simple approach: use the question directly
|
||||||
|
# A more sophisticated approach would use an LLM to generate the query
|
||||||
|
query = question.strip()
|
||||||
|
|
||||||
|
# Add meeting context keywords if available
|
||||||
|
if meeting_context:
|
||||||
|
# Extract key terms from context (simplified)
|
||||||
|
context_terms = meeting_context[:100].strip()
|
||||||
|
if context_terms:
|
||||||
|
query = f"{query} {context_terms}"
|
||||||
|
|
||||||
|
# Limit query length
|
||||||
|
max_query_length = 256
|
||||||
|
if len(query) > max_query_length:
|
||||||
|
query = query[:max_query_length].rsplit(" ", 1)[0]
|
||||||
|
|
||||||
|
return query
|
||||||
27
src/noteflow/infrastructure/ai/tools/__init__.py
Normal file
27
src/noteflow/infrastructure/ai/tools/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""Tool adapters for LangGraph workflows."""
|
||||||
|
|
||||||
|
from noteflow.infrastructure.ai.tools.retrieval import (
|
||||||
|
BatchEmbedderProtocol,
|
||||||
|
EmbedderProtocol,
|
||||||
|
RetrievalResult,
|
||||||
|
retrieve_segments,
|
||||||
|
retrieve_segments_batch,
|
||||||
|
retrieve_segments_workspace,
|
||||||
|
retrieve_segments_workspace_batch,
|
||||||
|
)
|
||||||
|
from noteflow.infrastructure.ai.tools.synthesis import (
|
||||||
|
SynthesisResult,
|
||||||
|
synthesize_answer,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BatchEmbedderProtocol",
|
||||||
|
"EmbedderProtocol",
|
||||||
|
"RetrievalResult",
|
||||||
|
"SynthesisResult",
|
||||||
|
"retrieve_segments",
|
||||||
|
"retrieve_segments_batch",
|
||||||
|
"retrieve_segments_workspace",
|
||||||
|
"retrieve_segments_workspace_batch",
|
||||||
|
"synthesize_answer",
|
||||||
|
]
|
||||||
237
src/noteflow/infrastructure/ai/tools/retrieval.py
Normal file
237
src/noteflow/infrastructure/ai/tools/retrieval.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
"""Segment retrieval tools for LangGraph workflows."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Final, Protocol, runtime_checkable
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from noteflow.domain.value_objects import MeetingId
|
||||||
|
|
||||||
|
# Limit concurrent parallel operations to prevent resource exhaustion
|
||||||
|
MAX_CONCURRENT_OPERATIONS: Final[int] = 10
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class EmbedderProtocol(Protocol):
|
||||||
|
"""Protocol for text embedding providers."""
|
||||||
|
|
||||||
|
async def embed(self, text: str) -> list[float]:
|
||||||
|
"""Embed a single text string."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class BatchEmbedderProtocol(EmbedderProtocol, Protocol):
|
||||||
|
"""Extended protocol for embedders supporting batch operations."""
|
||||||
|
|
||||||
|
async def embed_batch(self, texts: Sequence[str]) -> list[list[float]]:
|
||||||
|
"""Embed multiple texts in a single batch operation.
|
||||||
|
|
||||||
|
More efficient than calling embed() multiple times for providers
|
||||||
|
that support batching (reduces API calls, leverages GPU batching).
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentLike(Protocol):
|
||||||
|
segment_id: int
|
||||||
|
meeting_id: MeetingId | None
|
||||||
|
text: str
|
||||||
|
start_time: float
|
||||||
|
end_time: float
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentSearchProtocol(Protocol):
|
||||||
|
async def search_semantic(
|
||||||
|
self,
|
||||||
|
query_embedding: list[float],
|
||||||
|
limit: int,
|
||||||
|
meeting_id: MeetingId | None,
|
||||||
|
) -> Sequence[tuple[SegmentLike, float]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceSegmentSearchProtocol(Protocol):
|
||||||
|
async def search_semantic_workspace(
|
||||||
|
self,
|
||||||
|
query_embedding: list[float],
|
||||||
|
workspace_id: UUID,
|
||||||
|
project_id: UUID | None,
|
||||||
|
limit: int,
|
||||||
|
) -> Sequence[tuple[SegmentLike, float]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RetrievalResult:
|
||||||
|
segment_id: int
|
||||||
|
meeting_id: UUID
|
||||||
|
text: str
|
||||||
|
start_time: float
|
||||||
|
end_time: float
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
def _meeting_id_to_uuid(mid: MeetingId | None) -> UUID:
|
||||||
|
if mid is None:
|
||||||
|
msg = "meeting_id is required for RetrievalResult"
|
||||||
|
raise ValueError(msg)
|
||||||
|
return UUID(str(mid))
|
||||||
|
|
||||||
|
|
||||||
|
async def retrieve_segments(
|
||||||
|
query: str,
|
||||||
|
embedder: EmbedderProtocol,
|
||||||
|
segment_repo: SegmentSearchProtocol,
|
||||||
|
meeting_id: MeetingId | None = None,
|
||||||
|
top_k: int = 8,
|
||||||
|
) -> list[RetrievalResult]:
|
||||||
|
"""Retrieve relevant transcript segments via semantic search."""
|
||||||
|
query_embedding = await embedder.embed(query)
|
||||||
|
results = await segment_repo.search_semantic(
|
||||||
|
query_embedding=query_embedding,
|
||||||
|
limit=top_k,
|
||||||
|
meeting_id=meeting_id,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
RetrievalResult(
|
||||||
|
segment_id=segment.segment_id,
|
||||||
|
meeting_id=_meeting_id_to_uuid(segment.meeting_id),
|
||||||
|
text=segment.text,
|
||||||
|
start_time=segment.start_time,
|
||||||
|
end_time=segment.end_time,
|
||||||
|
score=score,
|
||||||
|
)
|
||||||
|
for segment, score in results
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def retrieve_segments_workspace(
|
||||||
|
query: str,
|
||||||
|
embedder: EmbedderProtocol,
|
||||||
|
segment_repo: WorkspaceSegmentSearchProtocol,
|
||||||
|
workspace_id: UUID,
|
||||||
|
project_id: UUID | None = None,
|
||||||
|
top_k: int = 20,
|
||||||
|
) -> list[RetrievalResult]:
|
||||||
|
"""Retrieve relevant transcript segments across workspace/project via semantic search."""
|
||||||
|
query_embedding = await embedder.embed(query)
|
||||||
|
results = await segment_repo.search_semantic_workspace(
|
||||||
|
query_embedding=query_embedding,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
project_id=project_id,
|
||||||
|
limit=top_k,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
RetrievalResult(
|
||||||
|
segment_id=segment.segment_id,
|
||||||
|
meeting_id=_meeting_id_to_uuid(segment.meeting_id),
|
||||||
|
text=segment.text,
|
||||||
|
start_time=segment.start_time,
|
||||||
|
end_time=segment.end_time,
|
||||||
|
score=score,
|
||||||
|
)
|
||||||
|
for segment, score in results
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def _embed_batch_fallback(
|
||||||
|
texts: Sequence[str],
|
||||||
|
embedder: EmbedderProtocol,
|
||||||
|
) -> list[list[float]]:
|
||||||
|
"""Embed multiple texts, using batch API if available or parallel fallback."""
|
||||||
|
if isinstance(embedder, BatchEmbedderProtocol):
|
||||||
|
return await embedder.embed_batch(texts)
|
||||||
|
|
||||||
|
semaphore = asyncio.Semaphore(MAX_CONCURRENT_OPERATIONS)
|
||||||
|
|
||||||
|
async def _bounded_embed(text: str) -> list[float]:
|
||||||
|
async with semaphore:
|
||||||
|
return await embedder.embed(text)
|
||||||
|
|
||||||
|
return list(await asyncio.gather(*(_bounded_embed(t) for t in texts)))
|
||||||
|
|
||||||
|
|
||||||
|
async def retrieve_segments_batch(
|
||||||
|
queries: Sequence[str],
|
||||||
|
embedder: EmbedderProtocol,
|
||||||
|
segment_repo: SegmentSearchProtocol,
|
||||||
|
meeting_id: MeetingId | None = None,
|
||||||
|
top_k: int = 8,
|
||||||
|
) -> list[list[RetrievalResult]]:
|
||||||
|
"""Retrieve segments for multiple queries in parallel.
|
||||||
|
|
||||||
|
Uses batch embedding when available, then parallel search execution.
|
||||||
|
Returns results in the same order as input queries.
|
||||||
|
"""
|
||||||
|
if not queries:
|
||||||
|
return []
|
||||||
|
embeddings = await _embed_batch_fallback(list(queries), embedder)
|
||||||
|
|
||||||
|
semaphore = asyncio.Semaphore(MAX_CONCURRENT_OPERATIONS)
|
||||||
|
|
||||||
|
async def _search(emb: list[float]) -> list[RetrievalResult]:
|
||||||
|
async with semaphore:
|
||||||
|
results = await segment_repo.search_semantic(
|
||||||
|
query_embedding=emb,
|
||||||
|
limit=top_k,
|
||||||
|
meeting_id=meeting_id,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
RetrievalResult(
|
||||||
|
segment_id=seg.segment_id,
|
||||||
|
meeting_id=_meeting_id_to_uuid(seg.meeting_id),
|
||||||
|
text=seg.text,
|
||||||
|
start_time=seg.start_time,
|
||||||
|
end_time=seg.end_time,
|
||||||
|
score=score,
|
||||||
|
)
|
||||||
|
for seg, score in results
|
||||||
|
]
|
||||||
|
|
||||||
|
search_results = await asyncio.gather(*(_search(emb) for emb in embeddings))
|
||||||
|
return list(search_results)
|
||||||
|
|
||||||
|
|
||||||
|
async def retrieve_segments_workspace_batch(
|
||||||
|
queries: Sequence[str],
|
||||||
|
embedder: EmbedderProtocol,
|
||||||
|
segment_repo: WorkspaceSegmentSearchProtocol,
|
||||||
|
workspace_id: UUID,
|
||||||
|
project_id: UUID | None = None,
|
||||||
|
top_k: int = 20,
|
||||||
|
) -> list[list[RetrievalResult]]:
|
||||||
|
"""Retrieve workspace segments for multiple queries in parallel.
|
||||||
|
|
||||||
|
Uses batch embedding when available, then parallel search execution.
|
||||||
|
Returns results in the same order as input queries.
|
||||||
|
"""
|
||||||
|
if not queries:
|
||||||
|
return []
|
||||||
|
embeddings = await _embed_batch_fallback(list(queries), embedder)
|
||||||
|
|
||||||
|
semaphore = asyncio.Semaphore(MAX_CONCURRENT_OPERATIONS)
|
||||||
|
|
||||||
|
async def _search(emb: list[float]) -> list[RetrievalResult]:
|
||||||
|
async with semaphore:
|
||||||
|
results = await segment_repo.search_semantic_workspace(
|
||||||
|
query_embedding=emb,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
project_id=project_id,
|
||||||
|
limit=top_k,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
RetrievalResult(
|
||||||
|
segment_id=seg.segment_id,
|
||||||
|
meeting_id=_meeting_id_to_uuid(seg.meeting_id),
|
||||||
|
text=seg.text,
|
||||||
|
start_time=seg.start_time,
|
||||||
|
end_time=seg.end_time,
|
||||||
|
score=score,
|
||||||
|
)
|
||||||
|
for seg, score in results
|
||||||
|
]
|
||||||
|
|
||||||
|
search_results = await asyncio.gather(*(_search(emb) for emb in embeddings))
|
||||||
|
return list(search_results)
|
||||||
60
src/noteflow/infrastructure/ai/tools/synthesis.py
Normal file
60
src/noteflow/infrastructure/ai/tools/synthesis.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""Answer synthesis tools for LangGraph workflows."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Final, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from noteflow.infrastructure.ai.tools.retrieval import RetrievalResult
|
||||||
|
|
||||||
|
|
||||||
|
class LLMProtocol(Protocol):
|
||||||
|
async def complete(self, prompt: str) -> str: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SynthesisResult:
|
||||||
|
answer: str
|
||||||
|
cited_segment_ids: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
SYNTHESIS_PROMPT_TEMPLATE: Final[
|
||||||
|
str
|
||||||
|
] = """Answer the question based on the following transcript segments.
|
||||||
|
Cite specific segments by their ID when making claims.
|
||||||
|
|
||||||
|
Question: {question}
|
||||||
|
|
||||||
|
Segments:
|
||||||
|
{segments}
|
||||||
|
|
||||||
|
Answer (cite segment IDs in brackets like [1], [3]):"""
|
||||||
|
|
||||||
|
CITATION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\[(\d+)\]")
|
||||||
|
|
||||||
|
|
||||||
|
async def synthesize_answer(
|
||||||
|
question: str,
|
||||||
|
segments: list[RetrievalResult],
|
||||||
|
llm: LLMProtocol,
|
||||||
|
) -> SynthesisResult:
|
||||||
|
"""Generate answer with segment citations using LLM."""
|
||||||
|
segment_text = "\n".join(
|
||||||
|
f"[{s.segment_id}] ({s.start_time:.1f}s-{s.end_time:.1f}s): {s.text}" for s in segments
|
||||||
|
)
|
||||||
|
prompt = SYNTHESIS_PROMPT_TEMPLATE.format(
|
||||||
|
question=question,
|
||||||
|
segments=segment_text,
|
||||||
|
)
|
||||||
|
answer = await llm.complete(prompt)
|
||||||
|
valid_ids = {s.segment_id for s in segments}
|
||||||
|
cited_ids = extract_cited_ids(answer, valid_ids)
|
||||||
|
return SynthesisResult(answer=answer, cited_segment_ids=cited_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_cited_ids(answer: str, valid_ids: set[int]) -> list[int]:
|
||||||
|
matches = CITATION_PATTERN.findall(answer)
|
||||||
|
cited = [int(m) for m in matches if int(m) in valid_ids]
|
||||||
|
return list(dict.fromkeys(cited))
|
||||||
1
tests/domain/ai/__init__.py
Normal file
1
tests/domain/ai/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for domain/ai/ module."""
|
||||||
118
tests/domain/ai/test_citations.py
Normal file
118
tests/domain/ai/test_citations.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from noteflow.domain.ai.citations import SegmentCitation
|
||||||
|
|
||||||
|
|
||||||
|
class TestSegmentCitation:
|
||||||
|
def test_creation_with_valid_values(self) -> None:
|
||||||
|
meeting_id = uuid4()
|
||||||
|
citation = SegmentCitation(
|
||||||
|
meeting_id=meeting_id,
|
||||||
|
segment_id=1,
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=5.0,
|
||||||
|
text="Test segment text",
|
||||||
|
score=0.95,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert citation.meeting_id == meeting_id
|
||||||
|
assert citation.segment_id == 1
|
||||||
|
assert citation.start_time == 0.0
|
||||||
|
assert citation.end_time == 5.0
|
||||||
|
assert citation.text == "Test segment text"
|
||||||
|
assert citation.score == 0.95
|
||||||
|
|
||||||
|
def test_duration_property(self) -> None:
|
||||||
|
citation = SegmentCitation(
|
||||||
|
meeting_id=uuid4(),
|
||||||
|
segment_id=1,
|
||||||
|
start_time=10.0,
|
||||||
|
end_time=25.0,
|
||||||
|
text="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert citation.duration == 15.0
|
||||||
|
|
||||||
|
def test_default_score_is_zero(self) -> None:
|
||||||
|
citation = SegmentCitation(
|
||||||
|
meeting_id=uuid4(),
|
||||||
|
segment_id=1,
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=1.0,
|
||||||
|
text="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert citation.score == 0.0
|
||||||
|
|
||||||
|
def test_rejects_negative_segment_id(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="segment_id must be non-negative"):
|
||||||
|
SegmentCitation(
|
||||||
|
meeting_id=uuid4(),
|
||||||
|
segment_id=-1,
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=5.0,
|
||||||
|
text="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_rejects_negative_start_time(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="start_time must be non-negative"):
|
||||||
|
SegmentCitation(
|
||||||
|
meeting_id=uuid4(),
|
||||||
|
segment_id=1,
|
||||||
|
start_time=-1.0,
|
||||||
|
end_time=5.0,
|
||||||
|
text="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_rejects_end_time_before_start_time(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="end_time must be >= start_time"):
|
||||||
|
SegmentCitation(
|
||||||
|
meeting_id=uuid4(),
|
||||||
|
segment_id=1,
|
||||||
|
start_time=10.0,
|
||||||
|
end_time=5.0,
|
||||||
|
text="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"invalid_score",
|
||||||
|
[
|
||||||
|
pytest.param(-0.1, id="negative"),
|
||||||
|
pytest.param(1.1, id="above_one"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_rejects_invalid_score(self, invalid_score: float) -> None:
|
||||||
|
with pytest.raises(ValueError, match="score must be between 0 and 1"):
|
||||||
|
SegmentCitation(
|
||||||
|
meeting_id=uuid4(),
|
||||||
|
segment_id=1,
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=5.0,
|
||||||
|
text="Test",
|
||||||
|
score=invalid_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_accepts_zero_duration(self) -> None:
|
||||||
|
citation = SegmentCitation(
|
||||||
|
meeting_id=uuid4(),
|
||||||
|
segment_id=1,
|
||||||
|
start_time=5.0,
|
||||||
|
end_time=5.0,
|
||||||
|
text="Instant moment",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert citation.duration == 0.0
|
||||||
|
|
||||||
|
def test_is_frozen(self) -> None:
|
||||||
|
citation = SegmentCitation(
|
||||||
|
meeting_id=uuid4(),
|
||||||
|
segment_id=1,
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=5.0,
|
||||||
|
text="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
citation.text = "Modified" # type: ignore[misc]
|
||||||
1
tests/infrastructure/ai/__init__.py
Normal file
1
tests/infrastructure/ai/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for infrastructure/ai/ module."""
|
||||||
268
tests/infrastructure/ai/test_retrieval.py
Normal file
268
tests/infrastructure/ai/test_retrieval.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from noteflow.infrastructure.ai.tools.retrieval import (
|
||||||
|
BatchEmbedderProtocol,
|
||||||
|
RetrievalResult,
|
||||||
|
retrieve_segments,
|
||||||
|
retrieve_segments_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockSegment:
|
||||||
|
segment_id: int
|
||||||
|
meeting_id: object
|
||||||
|
text: str
|
||||||
|
start_time: float
|
||||||
|
end_time: float
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetrieveSegments:
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embedder(self) -> AsyncMock:
|
||||||
|
embedder = AsyncMock()
|
||||||
|
embedder.embed.return_value = [0.1, 0.2, 0.3]
|
||||||
|
return embedder
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_segment_repo(self) -> AsyncMock:
|
||||||
|
return AsyncMock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_meeting_id(self) -> object:
|
||||||
|
return uuid4()
|
||||||
|
|
||||||
|
async def test_retrieve_segments_success(
|
||||||
|
self,
|
||||||
|
mock_embedder: AsyncMock,
|
||||||
|
mock_segment_repo: AsyncMock,
|
||||||
|
sample_meeting_id: object,
|
||||||
|
) -> None:
|
||||||
|
segment = MockSegment(
|
||||||
|
segment_id=1,
|
||||||
|
meeting_id=sample_meeting_id,
|
||||||
|
text="Test segment",
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=5.0,
|
||||||
|
)
|
||||||
|
mock_segment_repo.search_semantic.return_value = [(segment, 0.95)]
|
||||||
|
|
||||||
|
results = await retrieve_segments(
|
||||||
|
query="test query",
|
||||||
|
embedder=mock_embedder,
|
||||||
|
segment_repo=mock_segment_repo,
|
||||||
|
meeting_id=sample_meeting_id, # type: ignore[arg-type]
|
||||||
|
top_k=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].segment_id == 1
|
||||||
|
assert results[0].text == "Test segment"
|
||||||
|
assert results[0].score == 0.95
|
||||||
|
|
||||||
|
async def test_retrieve_segments_calls_embedder_with_query(
|
||||||
|
self,
|
||||||
|
mock_embedder: AsyncMock,
|
||||||
|
mock_segment_repo: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
mock_segment_repo.search_semantic.return_value = []
|
||||||
|
|
||||||
|
await retrieve_segments(
|
||||||
|
query="what happened in the meeting",
|
||||||
|
embedder=mock_embedder,
|
||||||
|
segment_repo=mock_segment_repo,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_embedder.embed.assert_called_once_with("what happened in the meeting")
|
||||||
|
|
||||||
|
async def test_retrieve_segments_passes_embedding_to_repo(
|
||||||
|
self,
|
||||||
|
mock_embedder: AsyncMock,
|
||||||
|
mock_segment_repo: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
mock_embedder.embed.return_value = [1.0, 2.0, 3.0]
|
||||||
|
mock_segment_repo.search_semantic.return_value = []
|
||||||
|
|
||||||
|
await retrieve_segments(
|
||||||
|
query="test",
|
||||||
|
embedder=mock_embedder,
|
||||||
|
segment_repo=mock_segment_repo,
|
||||||
|
top_k=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_segment_repo.search_semantic.assert_called_once_with(
|
||||||
|
query_embedding=[1.0, 2.0, 3.0],
|
||||||
|
meeting_id=None,
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_retrieve_segments_empty_result(
|
||||||
|
self,
|
||||||
|
mock_embedder: AsyncMock,
|
||||||
|
mock_segment_repo: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
mock_segment_repo.search_semantic.return_value = []
|
||||||
|
|
||||||
|
results = await retrieve_segments(
|
||||||
|
query="test",
|
||||||
|
embedder=mock_embedder,
|
||||||
|
segment_repo=mock_segment_repo,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
async def test_retrieval_result_is_frozen(self) -> None:
|
||||||
|
result = RetrievalResult(
|
||||||
|
segment_id=1,
|
||||||
|
meeting_id=uuid4(),
|
||||||
|
text="Test",
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=5.0,
|
||||||
|
score=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
result.text = "Modified" # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
class MockBatchEmbedder:
|
||||||
|
def __init__(self, embedding: list[float]) -> None:
|
||||||
|
self._embedding = embedding
|
||||||
|
self.embed_calls: list[str] = []
|
||||||
|
self.embed_batch_calls: list[Sequence[str]] = []
|
||||||
|
|
||||||
|
async def embed(self, text: str) -> list[float]:
|
||||||
|
self.embed_calls.append(text)
|
||||||
|
return self._embedding
|
||||||
|
|
||||||
|
async def embed_batch(self, texts: Sequence[str]) -> list[list[float]]:
|
||||||
|
self.embed_batch_calls.append(texts)
|
||||||
|
return [self._embedding for _ in texts]
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetrieveSegmentsBatch:
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embedder(self) -> AsyncMock:
|
||||||
|
embedder = AsyncMock()
|
||||||
|
embedder.embed.return_value = [0.1, 0.2, 0.3]
|
||||||
|
return embedder
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def batch_embedder(self) -> MockBatchEmbedder:
|
||||||
|
return MockBatchEmbedder([0.1, 0.2, 0.3])
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_segment_repo(self) -> AsyncMock:
|
||||||
|
return AsyncMock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_meeting_id(self) -> object:
|
||||||
|
return uuid4()
|
||||||
|
|
||||||
|
async def test_batch_returns_empty_for_no_queries(
|
||||||
|
self,
|
||||||
|
mock_embedder: AsyncMock,
|
||||||
|
mock_segment_repo: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
results = await retrieve_segments_batch(
|
||||||
|
queries=[],
|
||||||
|
embedder=mock_embedder,
|
||||||
|
segment_repo=mock_segment_repo,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
mock_embedder.embed.assert_not_called()
|
||||||
|
|
||||||
|
async def test_batch_uses_embed_batch_when_available(
|
||||||
|
self,
|
||||||
|
batch_embedder: MockBatchEmbedder,
|
||||||
|
mock_segment_repo: AsyncMock,
|
||||||
|
sample_meeting_id: object,
|
||||||
|
) -> None:
|
||||||
|
segment = MockSegment(
|
||||||
|
segment_id=1,
|
||||||
|
meeting_id=sample_meeting_id,
|
||||||
|
text="Test",
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=5.0,
|
||||||
|
)
|
||||||
|
mock_segment_repo.search_semantic.return_value = [(segment, 0.9)]
|
||||||
|
|
||||||
|
assert isinstance(batch_embedder, BatchEmbedderProtocol)
|
||||||
|
|
||||||
|
results = await retrieve_segments_batch(
|
||||||
|
queries=["query1", "query2"],
|
||||||
|
embedder=batch_embedder,
|
||||||
|
segment_repo=mock_segment_repo,
|
||||||
|
meeting_id=sample_meeting_id, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
assert len(batch_embedder.embed_batch_calls) == 1
|
||||||
|
assert list(batch_embedder.embed_batch_calls[0]) == ["query1", "query2"]
|
||||||
|
assert batch_embedder.embed_calls == []
|
||||||
|
|
||||||
|
async def test_batch_falls_back_to_parallel_embed(
|
||||||
|
self,
|
||||||
|
mock_embedder: AsyncMock,
|
||||||
|
mock_segment_repo: AsyncMock,
|
||||||
|
sample_meeting_id: object,
|
||||||
|
) -> None:
|
||||||
|
segment = MockSegment(
|
||||||
|
segment_id=1,
|
||||||
|
meeting_id=sample_meeting_id,
|
||||||
|
text="Test",
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=5.0,
|
||||||
|
)
|
||||||
|
mock_segment_repo.search_semantic.return_value = [(segment, 0.9)]
|
||||||
|
|
||||||
|
results = await retrieve_segments_batch(
|
||||||
|
queries=["query1", "query2"],
|
||||||
|
embedder=mock_embedder,
|
||||||
|
segment_repo=mock_segment_repo,
|
||||||
|
meeting_id=sample_meeting_id, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
assert mock_embedder.embed.call_count == 2
|
||||||
|
|
||||||
|
async def test_batch_preserves_query_order(
|
||||||
|
self,
|
||||||
|
mock_segment_repo: AsyncMock,
|
||||||
|
sample_meeting_id: object,
|
||||||
|
) -> None:
|
||||||
|
segment1 = MockSegment(1, sample_meeting_id, "First", 0.0, 5.0)
|
||||||
|
segment2 = MockSegment(2, sample_meeting_id, "Second", 5.0, 10.0)
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def side_effect(
|
||||||
|
query_embedding: list[float],
|
||||||
|
limit: int,
|
||||||
|
meeting_id: object,
|
||||||
|
) -> list[tuple[MockSegment, float]]:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return [(segment1, 0.9)]
|
||||||
|
return [(segment2, 0.8)]
|
||||||
|
|
||||||
|
mock_segment_repo.search_semantic.side_effect = side_effect
|
||||||
|
|
||||||
|
embedder = MockBatchEmbedder([0.1, 0.2])
|
||||||
|
results = await retrieve_segments_batch(
|
||||||
|
queries=["first", "second"],
|
||||||
|
embedder=embedder,
|
||||||
|
segment_repo=mock_segment_repo,
|
||||||
|
meeting_id=sample_meeting_id, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0][0].text == "First"
|
||||||
|
assert results[1][0].text == "Second"
|
||||||
149
tests/infrastructure/ai/test_synthesis.py
Normal file
149
tests/infrastructure/ai/test_synthesis.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
from unittest.mock import AsyncMock
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from noteflow.infrastructure.ai.tools.retrieval import RetrievalResult
|
||||||
|
from noteflow.infrastructure.ai.tools.synthesis import (
|
||||||
|
SynthesisResult,
|
||||||
|
extract_cited_ids,
|
||||||
|
synthesize_answer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSynthesizeAnswer:
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llm(self) -> AsyncMock:
|
||||||
|
return AsyncMock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_segments(self) -> list[RetrievalResult]:
|
||||||
|
meeting_id = uuid4()
|
||||||
|
return [
|
||||||
|
RetrievalResult(
|
||||||
|
segment_id=1,
|
||||||
|
meeting_id=meeting_id,
|
||||||
|
text="John discussed the project timeline",
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=5.0,
|
||||||
|
score=0.95,
|
||||||
|
),
|
||||||
|
RetrievalResult(
|
||||||
|
segment_id=3,
|
||||||
|
meeting_id=meeting_id,
|
||||||
|
text="The deadline is next Friday",
|
||||||
|
start_time=10.0,
|
||||||
|
end_time=15.0,
|
||||||
|
score=0.85,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def test_synthesize_answer_returns_result(
|
||||||
|
self,
|
||||||
|
mock_llm: AsyncMock,
|
||||||
|
sample_segments: list[RetrievalResult],
|
||||||
|
) -> None:
|
||||||
|
mock_llm.complete.return_value = "The project deadline is next Friday [3]."
|
||||||
|
|
||||||
|
result = await synthesize_answer(
|
||||||
|
question="What is the deadline?",
|
||||||
|
segments=sample_segments,
|
||||||
|
llm=mock_llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, SynthesisResult)
|
||||||
|
assert "deadline" in result.answer.lower()
|
||||||
|
|
||||||
|
async def test_synthesize_answer_extracts_citations(
|
||||||
|
self,
|
||||||
|
mock_llm: AsyncMock,
|
||||||
|
sample_segments: list[RetrievalResult],
|
||||||
|
) -> None:
|
||||||
|
mock_llm.complete.return_value = "John discussed timelines [1] and the deadline [3]."
|
||||||
|
|
||||||
|
result = await synthesize_answer(
|
||||||
|
question="What happened?",
|
||||||
|
segments=sample_segments,
|
||||||
|
llm=mock_llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.cited_segment_ids == [1, 3]
|
||||||
|
|
||||||
|
async def test_synthesize_answer_filters_invalid_citations(
|
||||||
|
self,
|
||||||
|
mock_llm: AsyncMock,
|
||||||
|
sample_segments: list[RetrievalResult],
|
||||||
|
) -> None:
|
||||||
|
mock_llm.complete.return_value = "Found [1], [99], and [3]."
|
||||||
|
|
||||||
|
result = await synthesize_answer(
|
||||||
|
question="What happened?",
|
||||||
|
segments=sample_segments,
|
||||||
|
llm=mock_llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert 99 not in result.cited_segment_ids
|
||||||
|
assert result.cited_segment_ids == [1, 3]
|
||||||
|
|
||||||
|
async def test_synthesize_answer_builds_prompt_with_segments(
|
||||||
|
self,
|
||||||
|
mock_llm: AsyncMock,
|
||||||
|
sample_segments: list[RetrievalResult],
|
||||||
|
) -> None:
|
||||||
|
mock_llm.complete.return_value = "Answer."
|
||||||
|
|
||||||
|
await synthesize_answer(
|
||||||
|
question="What is happening?",
|
||||||
|
segments=sample_segments,
|
||||||
|
llm=mock_llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = mock_llm.complete.call_args
|
||||||
|
prompt = call_args[0][0]
|
||||||
|
assert "What is happening?" in prompt
|
||||||
|
assert "[1]" in prompt
|
||||||
|
assert "[3]" in prompt
|
||||||
|
assert "John discussed" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractCitedIds:
|
||||||
|
def test_extracts_single_citation(self) -> None:
|
||||||
|
result = extract_cited_ids("The answer is here [5].", {1, 3, 5})
|
||||||
|
|
||||||
|
assert result == [5]
|
||||||
|
|
||||||
|
def test_extracts_multiple_citations(self) -> None:
|
||||||
|
result = extract_cited_ids("See [1] and [3] for details.", {1, 3, 5})
|
||||||
|
|
||||||
|
assert result == [1, 3]
|
||||||
|
|
||||||
|
def test_filters_invalid_ids(self) -> None:
|
||||||
|
result = extract_cited_ids("See [1] and [99].", {1, 3, 5})
|
||||||
|
|
||||||
|
assert result == [1]
|
||||||
|
|
||||||
|
def test_deduplicates_citations(self) -> None:
|
||||||
|
result = extract_cited_ids("See [1] and then [1] again.", {1, 3})
|
||||||
|
|
||||||
|
assert result == [1]
|
||||||
|
|
||||||
|
def test_preserves_order(self) -> None:
|
||||||
|
result = extract_cited_ids("[3] comes first, then [1].", {1, 3})
|
||||||
|
|
||||||
|
assert result == [3, 1]
|
||||||
|
|
||||||
|
def test_empty_for_no_citations(self) -> None:
|
||||||
|
result = extract_cited_ids("No citations here.", {1, 3})
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestSynthesisResult:
|
||||||
|
def test_is_frozen(self) -> None:
|
||||||
|
result = SynthesisResult(
|
||||||
|
answer="Test answer",
|
||||||
|
cited_segment_ids=[1, 2],
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
result.answer = "Modified" # type: ignore[misc]
|
||||||
1
typings/langgraph/__init__.pyi
Normal file
1
typings/langgraph/__init__.pyi
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Type stubs for langgraph
|
||||||
1
typings/langgraph/checkpoint/postgres/__init__.pyi
Normal file
1
typings/langgraph/checkpoint/postgres/__init__.pyi
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Type stubs for langgraph-checkpoint-postgres
|
||||||
10
typings/langgraph/checkpoint/postgres/aio.pyi
Normal file
10
typings/langgraph/checkpoint/postgres/aio.pyi
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# Type stubs for langgraph-checkpoint-postgres async module
|
||||||
|
|
||||||
|
class AsyncPostgresSaver:
|
||||||
|
"""Async PostgreSQL checkpointer for LangGraph.
|
||||||
|
|
||||||
|
This stub provides typing for the langgraph-checkpoint-postgres package.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, pool: object) -> None: ...
|
||||||
|
async def setup(self) -> None: ...
|
||||||
41
typings/langgraph/graph/__init__.pyi
Normal file
41
typings/langgraph/graph/__init__.pyi
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# Type stubs for langgraph.graph
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
_StateT = TypeVar("_StateT")
|
||||||
|
|
||||||
|
class START:
|
||||||
|
"""Sentinel for graph start node."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
class END:
|
||||||
|
"""Sentinel for graph end node."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
class CompiledStateGraph(Generic[_StateT]):
|
||||||
|
"""Compiled state graph that can be invoked."""
|
||||||
|
|
||||||
|
async def ainvoke(self, input: _StateT) -> _StateT: ...
|
||||||
|
def invoke(self, input: _StateT) -> _StateT: ...
|
||||||
|
|
||||||
|
class StateGraph(Generic[_StateT]):
|
||||||
|
"""State graph builder.
|
||||||
|
|
||||||
|
This stub provides typing for langgraph StateGraph.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, state_schema: type[_StateT]) -> None: ...
|
||||||
|
def add_node(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
action: Callable[[_StateT], dict[str, object]]
|
||||||
|
| Callable[[_StateT], Coroutine[object, object, dict[str, object]]],
|
||||||
|
) -> None: ...
|
||||||
|
def add_edge(
|
||||||
|
self,
|
||||||
|
start_key: str | type[START],
|
||||||
|
end_key: str | type[END],
|
||||||
|
) -> None: ...
|
||||||
|
def compile(self) -> CompiledStateGraph[_StateT]: ...
|
||||||
Reference in New Issue
Block a user