recovery
This commit is contained in:
@@ -0,0 +1,590 @@
|
||||
# Sprint 25: LangGraph Foundation
|
||||
|
||||
> **Size**: L | **Owner**: Backend | **Phase**: 5 - Platform Evolution
|
||||
> **Effort**: ~1 sprint | **Prerequisites**: Sprint 19 (Embeddings)
|
||||
|
||||
---
|
||||
|
||||
## Objective
|
||||
|
||||
Establish LangGraph infrastructure and wrap existing summarization as proof of pattern.
|
||||
|
||||
---
|
||||
|
||||
## Current State Analysis
|
||||
|
||||
### What Exists
|
||||
|
||||
| Component | Location | Status |
|
||||
|-----------|----------|--------|
|
||||
| Summarization service | `application/services/summarization/` | ✅ Working |
|
||||
| Segment semantic search | `infrastructure/persistence/repositories/segment_repo.py` | ✅ Working |
|
||||
| Cloud consent pattern | `application/services/summarization/_consent_manager.py` | ✅ Working |
|
||||
| Usage event tracking | `application/observability/ports.py` | ✅ Working |
|
||||
| gRPC mixin pattern | `grpc/_mixins/` | ✅ Working |
|
||||
|
||||
### What's Missing
|
||||
|
||||
| Component | Target Location | Sprint Task |
|
||||
|-----------|-----------------|-------------|
|
||||
| LangGraph dependencies | `pyproject.toml` | Task 1 |
|
||||
| State schemas | `domain/ai/state.py` | Task 2 |
|
||||
| Checkpointer factory | `infrastructure/ai/checkpointer.py` | Task 3 |
|
||||
| Retrieval tools | `infrastructure/ai/tools/retrieval.py` | Task 4 |
|
||||
| Synthesis tools | `infrastructure/ai/tools/synthesis.py` | Task 5 |
|
||||
| Summarization graph | `infrastructure/ai/graphs/summarization.py` | Task 6 |
|
||||
| AssistantService | `application/services/assistant/` | Task 7 |
|
||||
|
||||
---
|
||||
|
||||
## Implementation Tasks
|
||||
|
||||
### Task 1: Add LangGraph Dependencies
|
||||
|
||||
**File**: `pyproject.toml`
|
||||
|
||||
Add new optional dependency group:
|
||||
|
||||
```toml
|
||||
langgraph = [
|
||||
"langgraph>=0.2",
|
||||
"langgraph-checkpoint-postgres>=2.0",
|
||||
"langchain-core>=0.3",
|
||||
]
|
||||
```
|
||||
|
||||
Also add to `optional` and `all` groups.
|
||||
|
||||
**Verification**: `uv pip install -e ".[langgraph]"` succeeds
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Create State Schemas
|
||||
|
||||
**Files to create**:
|
||||
- `src/noteflow/domain/ai/__init__.py`
|
||||
- `src/noteflow/domain/ai/state.py`
|
||||
- `src/noteflow/domain/ai/citations.py`
|
||||
- `src/noteflow/domain/ai/ports.py`
|
||||
|
||||
#### `state.py` - Input/Output/Internal Separation
|
||||
|
||||
```python
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
import operator
|
||||
|
||||
|
||||
class AssistantInputState(TypedDict):
|
||||
"""Public API input - what clients send."""
|
||||
question: str
|
||||
meeting_id: UUID | None
|
||||
thread_id: str | None
|
||||
allow_web: bool
|
||||
top_k: int
|
||||
|
||||
|
||||
class AssistantOutputState(TypedDict):
|
||||
"""Public API output - what clients receive."""
|
||||
answer: str
|
||||
citations: list[dict]
|
||||
suggested_annotations: list[dict]
|
||||
thread_id: str
|
||||
|
||||
|
||||
class AssistantInternalState(AssistantInputState):
|
||||
"""Internal graph state - can evolve without breaking API."""
|
||||
# Retrieval
|
||||
retrieved_segment_ids: Annotated[list[int], operator.add]
|
||||
retrieved_segments: list[dict]
|
||||
|
||||
# Synthesis
|
||||
draft_answer: str
|
||||
verification_passed: bool
|
||||
|
||||
# Tracking
|
||||
loop_count: int
|
||||
```
|
||||
|
||||
#### `citations.py` - Value Object
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SegmentCitation:
    """Immutable reference to a transcript segment used as evidence.

    Raises:
        ValueError: If ``end_time`` is earlier than ``start_time``.
    """

    meeting_id: UUID
    segment_id: int
    start_time: float
    end_time: float
    text: str
    score: float = 0.0

    def __post_init__(self) -> None:
        """Reject citations whose time range is inverted."""
        if self.end_time < self.start_time:
            raise ValueError(
                f"end_time ({self.end_time}) must not precede "
                f"start_time ({self.start_time})"
            )

    @property
    def duration(self) -> float:
        """Length of the cited span, i.e. ``end_time - start_time``."""
        return self.end_time - self.start_time
|
||||
```
|
||||
|
||||
#### `ports.py` - Protocol Definition
|
||||
|
||||
```python
|
||||
from typing import Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from noteflow.domain.ai.state import AssistantOutputState
|
||||
|
||||
|
||||
class AssistantPort(Protocol):
|
||||
"""Protocol for AI assistant operations."""
|
||||
|
||||
async def ask(
|
||||
self,
|
||||
question: str,
|
||||
meeting_id: UUID | None = None,
|
||||
thread_id: str | None = None,
|
||||
allow_web: bool = False,
|
||||
top_k: int = 8,
|
||||
) -> AssistantOutputState:
|
||||
...
|
||||
```
|
||||
|
||||
**Verification**: `basedpyright src/noteflow/domain/ai/` passes
|
||||
|
||||
---
|
||||
|
||||
### Task 3: Create Checkpointer Factory
|
||||
|
||||
**File**: `src/noteflow/infrastructure/ai/checkpointer.py`
|
||||
|
||||
```python
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
CHECKPOINTER_POOL_SIZE: Final[int] = 5
|
||||
|
||||
|
||||
async def create_checkpointer(
    database_url: str,
    pool_size: int = CHECKPOINTER_POOL_SIZE,
) -> AsyncPostgresSaver:
    """Create an async Postgres checkpointer for LangGraph state persistence.

    Opens a dedicated connection pool, wraps it in an ``AsyncPostgresSaver``
    and runs the saver's ``setup()`` before returning it.

    Args:
        database_url: Connection string passed to the pool as ``conninfo``.
        pool_size: Maximum number of pooled connections (must be >= 1).

    Returns:
        A checkpointer whose ``setup()`` has already completed.

    Raises:
        ValueError: If ``pool_size`` is smaller than 1.
    """
    # Imported lazily so this module can be imported without the optional
    # ``langgraph`` dependency group installed.
    from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
    from psycopg_pool import AsyncConnectionPool

    if pool_size < 1:
        raise ValueError(f"pool_size must be >= 1, got {pool_size}")

    # psycopg_pool rejects max_size < min_size; its documented default
    # min_size is 4, so cap it to keep small pools constructible.
    default_min_size = 4
    pool = AsyncConnectionPool(
        conninfo=database_url,
        min_size=min(pool_size, default_min_size),
        max_size=pool_size,
        # prepare_threshold=0 follows the LangGraph Postgres checkpointer docs.
        kwargs={"autocommit": True, "prepare_threshold": 0},
        # Opening an async pool from the constructor is deprecated; open it
        # explicitly below instead.
        open=False,
    )
    await pool.open()
    try:
        checkpointer = AsyncPostgresSaver(pool)
        await checkpointer.setup()
    except BaseException:
        # Includes cancellation: never leave an open pool behind on failure.
        await pool.close()
        raise
    return checkpointer
|
||||
```
|
||||
|
||||
**Verification**: Unit test with mock pool
|
||||
|
||||
---
|
||||
|
||||
### Task 4: Create Retrieval Tools
|
||||
|
||||
**File**: `src/noteflow/infrastructure/ai/tools/retrieval.py`
|
||||
|
||||
```python
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.domain.entities import Segment
|
||||
|
||||
|
||||
class EmbedderProtocol(Protocol):
|
||||
"""Protocol for text embedding."""
|
||||
async def embed(self, text: str) -> list[float]: ...
|
||||
|
||||
|
||||
class SegmentSearchProtocol(Protocol):
|
||||
"""Protocol for semantic segment search."""
|
||||
async def search_semantic(
|
||||
self,
|
||||
query_embedding: list[float],
|
||||
meeting_id: UUID | None,
|
||||
limit: int,
|
||||
) -> list[tuple[Segment, float]]: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalResult:
|
||||
"""Result from segment retrieval."""
|
||||
segment_id: int
|
||||
meeting_id: UUID
|
||||
text: str
|
||||
start_time: float
|
||||
end_time: float
|
||||
score: float
|
||||
|
||||
|
||||
async def retrieve_segments(
    query: str,
    embedder: EmbedderProtocol,
    segment_repo: SegmentSearchProtocol,
    meeting_id: UUID | None = None,
    top_k: int = 8,
) -> list[RetrievalResult]:
    """Retrieve transcript segments relevant to ``query`` via semantic search.

    Args:
        query: Free-text query to embed.
        embedder: Produces the query embedding.
        segment_repo: Repository exposing ``search_semantic``.
        meeting_id: Forwarded unchanged to the repository
            (presumably ``None`` means "all meetings" — confirm against the repo).
        top_k: Passed to the repository as ``limit``.

    Returns:
        One ``RetrievalResult`` per ``(segment, score)`` pair, in repository order.
    """
    vector = await embedder.embed(query)
    scored_segments = await segment_repo.search_semantic(
        query_embedding=vector,
        meeting_id=meeting_id,
        limit=top_k,
    )

    hits: list[RetrievalResult] = []
    for seg, similarity in scored_segments:
        hits.append(
            RetrievalResult(
                segment_id=seg.segment_id,
                meeting_id=seg.meeting_id,
                text=seg.text,
                start_time=seg.start_time,
                end_time=seg.end_time,
                score=similarity,
            )
        )
    return hits
|
||||
```
|
||||
|
||||
**Verification**: Unit test with mock embedder and repo
|
||||
|
||||
---
|
||||
|
||||
### Task 5: Create Synthesis Tools
|
||||
|
||||
**File**: `src/noteflow/infrastructure/ai/tools/synthesis.py`
|
||||
|
||||
```python
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.infrastructure.ai.tools.retrieval import RetrievalResult
|
||||
|
||||
|
||||
class LLMProtocol(Protocol):
|
||||
"""Protocol for LLM completion."""
|
||||
async def complete(self, prompt: str) -> str: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class SynthesisResult:
|
||||
"""Result from answer synthesis."""
|
||||
answer: str
|
||||
cited_segment_ids: list[int]
|
||||
|
||||
|
||||
SYNTHESIS_PROMPT_TEMPLATE = '''Answer the question based on the following transcript segments.
|
||||
Cite specific segments by their ID when making claims.
|
||||
|
||||
Question: {question}
|
||||
|
||||
Segments:
|
||||
{segments}
|
||||
|
||||
Answer (cite segment IDs in brackets like [1], [3]):'''
|
||||
|
||||
|
||||
async def synthesize_answer(
    question: str,
    segments: list[RetrievalResult],
    llm: LLMProtocol,
) -> SynthesisResult:
    """Generate an answer to ``question`` with segment citations.

    Each segment is rendered as ``[id] (start-end): text``, the prompt template
    is filled in, the LLM is called once, and bracketed IDs found in the reply
    are kept only if they belong to ``segments``.
    """
    rendered_lines = [
        f"[{s.segment_id}] ({s.start_time:.1f}s-{s.end_time:.1f}s): {s.text}"
        for s in segments
    ]
    segment_block = "\n".join(rendered_lines)
    prompt = SYNTHESIS_PROMPT_TEMPLATE.format(
        question=question,
        segments=segment_block,
    )

    answer = await llm.complete(prompt)

    known_ids = [s.segment_id for s in segments]
    return SynthesisResult(
        answer=answer,
        cited_segment_ids=_extract_cited_ids(answer, known_ids),
    )
|
||||
|
||||
|
||||
def _extract_cited_ids(answer: str, valid_ids: list[int]) -> list[int]:
|
||||
"""Extract segment IDs cited in the answer."""
|
||||
import re
|
||||
pattern = r'\[(\d+)\]'
|
||||
matches = re.findall(pattern, answer)
|
||||
cited = [int(m) for m in matches if int(m) in valid_ids]
|
||||
return list(dict.fromkeys(cited)) # Dedupe preserving order
|
||||
```
|
||||
|
||||
**Verification**: Unit test with mock LLM
|
||||
|
||||
---
|
||||
|
||||
### Task 6: Create Summarization Graph Wrapper
|
||||
|
||||
**File**: `src/noteflow/infrastructure/ai/graphs/summarization.py`
|
||||
|
||||
This wraps the existing SummarizationService in a LangGraph graph as proof of pattern.
|
||||
|
||||
```python
|
||||
from __future__ import annotations
|
||||
|
||||
# UUID, Sequence and Segment are imported at runtime on purpose: LangGraph
# resolves the state schema's annotations with typing.get_type_hints(), so
# names used in SummarizationState must exist when the graph is built.
from collections.abc import Sequence
from typing import TYPE_CHECKING, TypedDict
from uuid import UUID

from langgraph.graph import END, START, StateGraph

from noteflow.domain.entities import Segment

if TYPE_CHECKING:
    from langgraph.graph.state import CompiledStateGraph

    from noteflow.application.services.summarization import SummarizationService
|
||||
|
||||
|
||||
class SummarizationState(TypedDict):
|
||||
"""State for summarization graph."""
|
||||
meeting_id: UUID
|
||||
segments: Sequence[Segment]
|
||||
summary_text: str
|
||||
key_points: list[dict]
|
||||
action_items: list[dict]
|
||||
|
||||
|
||||
def build_summarization_graph(
    summarization_service: SummarizationService,
) -> CompiledStateGraph:
    """Build a LangGraph wrapper around the existing summarization service.

    The graph has a single node (``START -> summarize -> END``) that delegates
    to ``summarization_service.summarize`` and flattens the result into plain
    dict/list state values.

    Args:
        summarization_service: Service whose ``summarize`` coroutine does the work.

    Returns:
        The compiled graph. ``builder.compile()`` returns a compiled graph
        object, not the ``StateGraph`` builder itself.
    """

    async def summarize_node(state: SummarizationState) -> dict:
        """Run summarization and return only the state keys this node updates."""
        result = await summarization_service.summarize(
            meeting_id=state["meeting_id"],
            segments=state["segments"],
        )
        summary = result.summary
        return {
            "summary_text": summary.executive_summary,
            "key_points": [
                {"text": point.text, "segment_ids": point.segment_ids}
                for point in summary.key_points
            ],
            "action_items": [
                {
                    "text": item.text,
                    "segment_ids": item.segment_ids,
                    "assignee": item.assignee,
                }
                for item in summary.action_items
            ],
        }

    builder = StateGraph(SummarizationState)
    builder.add_node("summarize", summarize_node)
    builder.add_edge(START, "summarize")
    builder.add_edge("summarize", END)

    return builder.compile()
|
||||
```
|
||||
|
||||
**Verification**: Integration test with mock summarization service
|
||||
|
||||
---
|
||||
|
||||
### Task 7: Create AssistantService Shell
|
||||
|
||||
**Files**:
|
||||
- `src/noteflow/application/services/assistant/__init__.py`
|
||||
- `src/noteflow/application/services/assistant/assistant_service.py`
|
||||
|
||||
#### `__init__.py`
|
||||
|
||||
```python
|
||||
from noteflow.application.services.assistant.assistant_service import AssistantService
|
||||
|
||||
__all__ = ["AssistantService"]
|
||||
```
|
||||
|
||||
#### `assistant_service.py`
|
||||
|
||||
```python
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Final
|
||||
from uuid import UUID
|
||||
|
||||
from noteflow.application.observability.ports import NullUsageEventSink, UsageEventSink
|
||||
from noteflow.domain.ai.state import AssistantOutputState
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from noteflow.domain.ports.unit_of_work import UnitOfWork
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
DEFAULT_TOP_K: Final[int] = 8
|
||||
THREAD_ID_PREFIX: Final[str] = "meeting"
|
||||
|
||||
|
||||
def build_thread_id(meeting_id: UUID | None, user_id: UUID, graph_name: str) -> str:
    """Build a deterministic thread_id for checkpointing.

    The ID is ``<prefix>:<meeting or "workspace">:user:<user>:graph:<name>:v1``;
    identical inputs always yield the identical string.
    """
    if meeting_id:
        scope = str(meeting_id)
    else:
        # No meeting given: fall back to the literal "workspace" scope.
        scope = "workspace"
    owner = f"user:{user_id}"
    graph = f"graph:{graph_name}:v1"
    return f"{THREAD_ID_PREFIX}:{scope}:{owner}:{graph}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssistantService:
|
||||
"""Orchestrates AI assistant workflows via LangGraph."""
|
||||
|
||||
uow_factory: Callable[[], UnitOfWork]
|
||||
usage_events: UsageEventSink = field(default_factory=NullUsageEventSink)
|
||||
|
||||
async def ask(
|
||||
self,
|
||||
question: str,
|
||||
user_id: UUID,
|
||||
meeting_id: UUID | None = None,
|
||||
thread_id: str | None = None,
|
||||
allow_web: bool = False,
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
) -> AssistantOutputState:
|
||||
"""Ask a question about meeting transcript(s)."""
|
||||
effective_thread_id = thread_id or build_thread_id(
|
||||
meeting_id, user_id, "meeting_qa"
|
||||
)
|
||||
logger.info(
|
||||
"assistant_ask",
|
||||
question_length=len(question),
|
||||
meeting_id=str(meeting_id) if meeting_id else None,
|
||||
thread_id=effective_thread_id,
|
||||
)
|
||||
# TODO: Implement in Sprint 26
|
||||
return AssistantOutputState(
|
||||
answer="Not implemented yet",
|
||||
citations=[],
|
||||
suggested_annotations=[],
|
||||
thread_id=effective_thread_id,
|
||||
)
|
||||
```
|
||||
|
||||
**Verification**: Unit test for thread_id generation
|
||||
|
||||
---
|
||||
|
||||
## Test Plan
|
||||
|
||||
### Unit Tests (`tests/domain/ai/`)
|
||||
|
||||
```python
|
||||
# test_citations.py
|
||||
def test_segment_citation_creation():
|
||||
citation = SegmentCitation(
|
||||
meeting_id=uuid4(),
|
||||
segment_id=1,
|
||||
start_time=0.0,
|
||||
end_time=5.0,
|
||||
text="Test segment",
|
||||
score=0.95,
|
||||
)
|
||||
assert citation.duration == 5.0
|
||||
|
||||
|
||||
def test_segment_citation_invalid_times():
|
||||
with pytest.raises(ValueError):
|
||||
SegmentCitation(
|
||||
meeting_id=uuid4(),
|
||||
segment_id=1,
|
||||
start_time=10.0,
|
||||
end_time=5.0, # Invalid: end < start
|
||||
text="Test",
|
||||
)
|
||||
```
|
||||
|
||||
### Unit Tests (`tests/infrastructure/ai/`)
|
||||
|
||||
```python
|
||||
# test_retrieval.py
|
||||
async def test_retrieve_segments_success(mock_embedder, mock_segment_repo):
|
||||
mock_embedder.embed.return_value = [0.1, 0.2, 0.3]
|
||||
mock_segment_repo.search_semantic.return_value = [
|
||||
(sample_segment, 0.95),
|
||||
]
|
||||
|
||||
results = await retrieve_segments(
|
||||
query="test query",
|
||||
embedder=mock_embedder,
|
||||
segment_repo=mock_segment_repo,
|
||||
top_k=5,
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].score == 0.95
|
||||
|
||||
|
||||
# test_synthesis.py
|
||||
async def test_synthesize_answer_extracts_citations(mock_llm):
|
||||
mock_llm.complete.return_value = "The answer is X [1] and Y [3]."
|
||||
|
||||
result = await synthesize_answer(
|
||||
question="What happened?",
|
||||
segments=[...],
|
||||
llm=mock_llm,
|
||||
)
|
||||
|
||||
assert result.cited_segment_ids == [1, 3]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Acceptance Criteria
|
||||
|
||||
- [ ] `uv pip install -e ".[langgraph]"` succeeds
|
||||
- [ ] `basedpyright src/noteflow/domain/ai/` passes with 0 errors
|
||||
- [ ] `basedpyright src/noteflow/infrastructure/ai/` passes with 0 errors
|
||||
- [ ] `pytest tests/domain/ai/` passes
|
||||
- [ ] `pytest tests/infrastructure/ai/` passes
|
||||
- [ ] Existing summarization behavior unchanged (`pytest tests/application/services/test_summarization*`)
|
||||
- [ ] `make quality` passes
|
||||
|
||||
---
|
||||
|
||||
## Rollback Plan
|
||||
|
||||
If issues arise:
|
||||
1. Remove langgraph from dependencies
|
||||
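2. Delete the `domain/ai/`, `infrastructure/ai/`, and `application/services/assistant/` directories (and their tests under `tests/domain/ai/` and `tests/infrastructure/ai/`)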
|
||||
3. AssistantService is not wired to gRPC yet, so no API impact
|
||||
|
||||
---
|
||||
|
||||
## Files Created/Modified
|
||||
|
||||
| Action | Path |
|
||||
|--------|------|
|
||||
| Modified | `pyproject.toml` |
|
||||
| Created | `src/noteflow/domain/ai/__init__.py` |
|
||||
| Created | `src/noteflow/domain/ai/state.py` |
|
||||
| Created | `src/noteflow/domain/ai/citations.py` |
|
||||
| Created | `src/noteflow/domain/ai/ports.py` |
|
||||
| Created | `src/noteflow/infrastructure/ai/__init__.py` |
|
||||
| Created | `src/noteflow/infrastructure/ai/checkpointer.py` |
|
||||
| Created | `src/noteflow/infrastructure/ai/tools/__init__.py` |
|
||||
| Created | `src/noteflow/infrastructure/ai/tools/retrieval.py` |
|
||||
| Created | `src/noteflow/infrastructure/ai/tools/synthesis.py` |
|
||||
| Created | `src/noteflow/infrastructure/ai/graphs/__init__.py` |
|
||||
| Created | `src/noteflow/infrastructure/ai/graphs/summarization.py` |
|
||||
| Created | `src/noteflow/application/services/assistant/__init__.py` |
|
||||
| Created | `src/noteflow/application/services/assistant/assistant_service.py` |
|
||||
| Created | `tests/domain/ai/test_citations.py` |
|
||||
| Created | `tests/domain/ai/test_state.py` |
|
||||
| Created | `tests/infrastructure/ai/test_retrieval.py` |
|
||||
| Created | `tests/infrastructure/ai/test_synthesis.py` |
|
||||
| Created | `tests/infrastructure/ai/test_checkpointer.py` |
|
||||
@@ -0,0 +1,530 @@
|
||||
# Sprint 26: Meeting Q&A MVP
|
||||
|
||||
> **Size**: L | **Owner**: Backend + Client | **Phase**: 5 - Platform Evolution
|
||||
> **Effort**: ~1 sprint | **Prerequisites**: Sprint 25 (Foundation)
|
||||
|
||||
---
|
||||
|
||||
## Objective
|
||||
|
||||
Implement single-meeting Q&A with segment citations via gRPC API and React UI.
|
||||
|
||||
---
|
||||
|
||||
## Current State (After Sprint 25)
|
||||
|
||||
| Component | Status |
|
||||
|-----------|--------|
|
||||
| LangGraph infrastructure | ✅ Ready |
|
||||
| State schemas | ✅ Ready |
|
||||
| Retrieval tools | ✅ Ready |
|
||||
| Synthesis tools | ✅ Ready |
|
||||
| AssistantService shell | ✅ Ready |
|
||||
|
||||
---
|
||||
|
||||
## Implementation Tasks
|
||||
|
||||
### Task 1: Define MeetingQA Graph
|
||||
|
||||
**File**: `src/noteflow/infrastructure/ai/graphs/meeting_qa.py`
|
||||
|
||||
Graph flow: `retrieve → verify → synthesize`
|
||||
|
||||
```python
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph


class MeetingQAState(TypedDict):
    """State carried through the meeting Q&A graph."""

    # Input
    question: str
    meeting_id: UUID
    top_k: int

    # Internal
    retrieved_segments: list[RetrievalResult]
    verification_passed: bool

    # Output
    answer: str
    citations: list[SegmentCitation]


def build_meeting_qa_graph(
    embedder: EmbedderProtocol,
    segment_repo: SegmentSearchProtocol,
    llm: LLMProtocol,
    verifier: CitationVerifier,
    checkpointer: BaseCheckpointSaver | None = None,
) -> CompiledStateGraph:
    """Build the ``retrieve -> verify -> synthesize`` Q&A graph.

    Args:
        embedder: Embeds the question for semantic search.
        segment_repo: Repository used by ``retrieve_segments``.
        llm: Completion backend used by ``synthesize_answer``.
        verifier: Currently unused by any node.
            NOTE(review): wire this into ``verify_node`` or drop the parameter.
        checkpointer: Optional LangGraph checkpoint saver. Without one the
            compiled graph keeps no state between invocations, so a
            ``thread_id`` in the run config has nothing to resume from.

    Returns:
        The compiled graph (``builder.compile()`` does not return the builder).
    """

    async def retrieve_node(state: MeetingQAState) -> dict:
        """Fetch the top-k segments for the question."""
        results = await retrieve_segments(
            query=state["question"],
            embedder=embedder,
            segment_repo=segment_repo,
            meeting_id=state["meeting_id"],
            top_k=state["top_k"],
        )
        return {"retrieved_segments": results}

    async def verify_node(state: MeetingQAState) -> dict:
        """Pass verification only if retrieval returned at least one segment."""
        # Only non-emptiness is checked here; relevance is not assessed.
        return {"verification_passed": bool(state["retrieved_segments"])}

    async def synthesize_node(state: MeetingQAState) -> dict:
        """Produce the answer and the citations the answer actually references."""
        if not state["verification_passed"]:
            return {
                "answer": "I couldn't find relevant information in this meeting.",
                "citations": [],
            }

        result = await synthesize_answer(
            question=state["question"],
            segments=state["retrieved_segments"],
            llm=llm,
        )

        cited_ids = set(result.cited_segment_ids)
        citations = [
            SegmentCitation(
                # Use the segment's own meeting_id: the state value may be
                # absent/None when the caller did not scope to one meeting.
                meeting_id=seg.meeting_id,
                segment_id=seg.segment_id,
                start_time=seg.start_time,
                end_time=seg.end_time,
                text=seg.text,
                score=seg.score,
            )
            for seg in state["retrieved_segments"]
            if seg.segment_id in cited_ids
        ]

        return {"answer": result.answer, "citations": citations}

    builder = StateGraph(MeetingQAState)
    builder.add_node("retrieve", retrieve_node)
    builder.add_node("verify", verify_node)
    builder.add_node("synthesize", synthesize_node)

    builder.add_edge(START, "retrieve")
    builder.add_edge("retrieve", "verify")
    builder.add_edge("verify", "synthesize")
    builder.add_edge("synthesize", END)

    return builder.compile(checkpointer=checkpointer)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Create Citation Verifier Node
|
||||
|
||||
**File**: `src/noteflow/infrastructure/ai/nodes/verification.py`
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class VerificationResult:
|
||||
is_valid: bool
|
||||
invalid_citation_indices: list[int]
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
def verify_citations(
|
||||
answer: str,
|
||||
cited_ids: list[int],
|
||||
available_ids: set[int],
|
||||
) -> VerificationResult:
|
||||
"""Verify all cited segment IDs exist in available segments."""
|
||||
invalid = [i for i, cid in enumerate(cited_ids) if cid not in available_ids]
|
||||
return VerificationResult(
|
||||
is_valid=len(invalid) == 0,
|
||||
invalid_citation_indices=invalid,
|
||||
reason=f"Invalid citations: {invalid}" if invalid else None,
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: Add Proto Messages
|
||||
|
||||
**File**: `src/noteflow/grpc/proto/noteflow.proto`
|
||||
|
||||
```protobuf
|
||||
// Add to existing proto
|
||||
|
||||
message SegmentCitation {
|
||||
string meeting_id = 1;
|
||||
int32 segment_id = 2;
|
||||
float start_time = 3;
|
||||
float end_time = 4;
|
||||
string text = 5;
|
||||
float score = 6;
|
||||
}
|
||||
|
||||
message AskAssistantRequest {
|
||||
string question = 1;
|
||||
optional string meeting_id = 2;
|
||||
optional string thread_id = 3;
|
||||
bool allow_web = 4;
|
||||
int32 top_k = 5;
|
||||
}
|
||||
|
||||
message AskAssistantResponse {
|
||||
string answer = 1;
|
||||
repeated SegmentCitation citations = 2;
|
||||
repeated SuggestedAnnotation suggested_annotations = 3;
|
||||
string thread_id = 4;
|
||||
}
|
||||
|
||||
message SuggestedAnnotation {
|
||||
string text = 1;
|
||||
AnnotationType type = 2;
|
||||
repeated int32 segment_ids = 3;
|
||||
}
|
||||
|
||||
// Add to NoteFlowService
|
||||
rpc AskAssistant(AskAssistantRequest) returns (AskAssistantResponse);
|
||||
```
|
||||
|
||||
After modifying proto:
|
||||
```bash
|
||||
python -m grpc_tools.protoc -I src/noteflow/grpc/proto \
|
||||
--python_out=src/noteflow/grpc/proto \
|
||||
--grpc_python_out=src/noteflow/grpc/proto \
|
||||
src/noteflow/grpc/proto/noteflow.proto
|
||||
|
||||
python scripts/patch_grpc_stubs.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: Add gRPC Mixin
|
||||
|
||||
**File**: `src/noteflow/grpc/_mixins/assistant.py`
|
||||
|
||||
```python
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from noteflow.grpc.proto import noteflow_pb2 as pb
|
||||
from noteflow.grpc._mixins.protocols import ServicerHost
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from grpc.aio import ServicerContext
|
||||
|
||||
|
||||
class AssistantMixin:
    """gRPC mixin for AI assistant operations."""

    async def AskAssistant(
        self: ServicerHost,
        request: pb.AskAssistantRequest,
        context: ServicerContext,
    ) -> pb.AskAssistantResponse:
        """Answer a question via the assistant service.

        Aborts the RPC with ``INVALID_ARGUMENT`` when the question is blank or
        ``meeting_id`` is not a valid UUID, instead of letting a ``ValueError``
        surface as an opaque server error.
        """
        from uuid import UUID

        import grpc

        default_top_k = 8  # applied when the client leaves top_k unset (0) or negative

        if not request.question.strip():
            await context.abort(
                grpc.StatusCode.INVALID_ARGUMENT,
                "question must not be empty",
            )

        meeting_id: UUID | None = None
        if request.meeting_id:
            try:
                meeting_id = UUID(request.meeting_id)
            except ValueError:
                # abort() raises, so execution never continues past this point.
                await context.abort(
                    grpc.StatusCode.INVALID_ARGUMENT,
                    "meeting_id must be a valid UUID",
                )

        op_context = await self.get_operation_context(context)

        result = await self.assistant_service.ask(
            question=request.question,
            user_id=op_context.user.id,
            meeting_id=meeting_id,
            thread_id=request.thread_id or None,
            allow_web=request.allow_web,
            top_k=request.top_k if request.top_k > 0 else default_top_k,
        )

        # NOTE(review): result["suggested_annotations"] is not mapped onto the
        # response yet — confirm whether that is intentional for the MVP.
        return pb.AskAssistantResponse(
            answer=result["answer"],
            citations=[
                pb.SegmentCitation(
                    meeting_id=str(c["meeting_id"]),
                    segment_id=c["segment_id"],
                    start_time=c["start_time"],
                    end_time=c["end_time"],
                    text=c["text"],
                    score=c["score"],
                )
                for c in result["citations"]
            ],
            thread_id=result["thread_id"],
        )
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: Add Rust Command
|
||||
|
||||
**File**: `client/src-tauri/src/commands/assistant.rs`
|
||||
|
||||
```rust
|
||||
use crate::grpc::client::GrpcClient;
|
||||
use crate::grpc::types::assistant::{AskAssistantRequest, AskAssistantResponse};
|
||||
use tauri::State;
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn ask_assistant(
|
||||
client: State<'_, GrpcClient>,
|
||||
question: String,
|
||||
meeting_id: Option<String>,
|
||||
thread_id: Option<String>,
|
||||
allow_web: bool,
|
||||
top_k: i32,
|
||||
) -> Result<AskAssistantResponse, String> {
|
||||
let request = AskAssistantRequest {
|
||||
question,
|
||||
meeting_id,
|
||||
thread_id,
|
||||
allow_web,
|
||||
top_k,
|
||||
};
|
||||
|
||||
client
|
||||
.ask_assistant(request)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 6: Add TypeScript Adapter
|
||||
|
||||
**File**: `client/src/api/tauri-adapter.ts` (add method)
|
||||
|
||||
```typescript
|
||||
async askAssistant(params: AskAssistantParams): Promise<AskAssistantResponse> {
|
||||
return invoke('ask_assistant', {
|
||||
question: params.question,
|
||||
meetingId: params.meetingId,
|
||||
threadId: params.threadId,
|
||||
allowWeb: params.allowWeb ?? false,
|
||||
topK: params.topK ?? 8,
|
||||
});
|
||||
}
|
||||
```
|
||||
|
||||
**File**: `client/src/api/types/assistant.ts`
|
||||
|
||||
```typescript
|
||||
export interface SegmentCitation {
|
||||
meetingId: string;
|
||||
segmentId: number;
|
||||
startTime: number;
|
||||
endTime: number;
|
||||
text: string;
|
||||
score: number;
|
||||
}
|
||||
|
||||
export interface AskAssistantParams {
|
||||
question: string;
|
||||
meetingId?: string;
|
||||
threadId?: string;
|
||||
allowWeb?: boolean;
|
||||
topK?: number;
|
||||
}
|
||||
|
||||
export interface AskAssistantResponse {
|
||||
answer: string;
|
||||
citations: SegmentCitation[];
|
||||
suggestedAnnotations: SuggestedAnnotation[];
|
||||
threadId: string;
|
||||
}
|
||||
|
||||
export interface SuggestedAnnotation {
|
||||
text: string;
|
||||
type: AnnotationType;
|
||||
segmentIds: number[];
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 7: Create Ask UI Component
|
||||
|
||||
**File**: `client/src/components/meeting/AskPanel.tsx`
|
||||
|
||||
```tsx
|
||||
import { useState } from 'react';
|
||||
import { useAssistant } from '@/hooks/use-assistant';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Textarea } from '@/components/ui/textarea';
|
||||
import { Card } from '@/components/ui/card';
|
||||
|
||||
interface AskPanelProps {
|
||||
meetingId: string;
|
||||
onCitationClick?: (segmentId: number) => void;
|
||||
}
|
||||
|
||||
/**
 * Question box for a single meeting: submits the question through
 * `useAssistant` and renders the answer with clickable source citations.
 */
export function AskPanel({ meetingId, onCitationClick }: AskPanelProps) {
  const [question, setQuestion] = useState('');
  const { ask, isLoading, response, error } = useAssistant();

  const handleAsk = async () => {
    if (!question.trim()) return;
    try {
      await ask({ question, meetingId });
    } catch {
      // `ask` re-throws after the hook has stored the message in `error`;
      // catching here avoids an unhandled promise rejection from the click handler.
    }
  };

  // Only append an ellipsis when the text was actually truncated.
  const preview = (text: string) => (text.length > 50 ? `${text.slice(0, 50)}...` : text);

  return (
    <div className="flex flex-col gap-4 p-4">
      <Textarea
        placeholder="Ask a question about this meeting..."
        value={question}
        onChange={(e) => setQuestion(e.target.value)}
        disabled={isLoading}
      />
      <Button onClick={handleAsk} disabled={isLoading || !question.trim()}>
        {isLoading ? 'Thinking...' : 'Ask'}
      </Button>

      {response && (
        <Card className="p-4">
          <p className="whitespace-pre-wrap">{response.answer}</p>
          {response.citations.length > 0 && (
            <div className="mt-4 border-t pt-2">
              <p className="text-sm text-muted-foreground mb-2">Sources:</p>
              {response.citations.map((citation) => (
                <button
                  type="button"
                  key={citation.segmentId}
                  onClick={() => onCitationClick?.(citation.segmentId)}
                  className="text-sm text-blue-600 hover:underline block"
                >
                  [{citation.startTime.toFixed(1)}s] {preview(citation.text)}
                </button>
              ))}
            </div>
          )}
        </Card>
      )}

      {error && (
        <p className="text-sm text-red-600">{error}</p>
      )}
    </div>
  );
}
|
||||
```
|
||||
|
||||
**File**: `client/src/hooks/use-assistant.ts`
|
||||
|
||||
```typescript
|
||||
import { useState, useCallback } from 'react';
|
||||
import { api } from '@/api';
|
||||
import type { AskAssistantParams, AskAssistantResponse } from '@/api/types/assistant';
|
||||
|
||||
export function useAssistant() {
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [response, setResponse] = useState<AskAssistantResponse | null>(null);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const ask = useCallback(async (params: AskAssistantParams) => {
|
||||
setIsLoading(true);
|
||||
setError(null);
|
||||
try {
|
||||
const result = await api.askAssistant(params);
|
||||
setResponse(result);
|
||||
return result;
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to get answer');
|
||||
throw err;
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const reset = useCallback(() => {
|
||||
setResponse(null);
|
||||
setError(null);
|
||||
}, []);
|
||||
|
||||
return { ask, isLoading, response, error, reset };
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 8: Implement AssistantService.ask()
|
||||
|
||||
Complete the implementation in `application/services/assistant/assistant_service.py`:
|
||||
|
||||
```python
|
||||
async def ask(
    self,
    question: str,
    user_id: UUID,
    meeting_id: UUID | None = None,
    thread_id: str | None = None,
    allow_web: bool = False,
    top_k: int = DEFAULT_TOP_K,
) -> AssistantOutputState:
    """Ask a question about meeting transcript(s).

    Builds the meeting Q&A graph against a fresh unit of work, invokes it with
    the thread_id in the run config, records a usage event and returns the
    answer with citations converted to plain dicts.

    Args:
        question: The user's question.
        user_id: Used to derive the default thread_id.
        meeting_id: Meeting to query; also part of the default thread_id.
        thread_id: Explicit thread to use instead of the derived one.
        allow_web: Accepted for interface compatibility; not used here.
        top_k: Number of segments the graph should retrieve.
    """
    # Local import: `asdict` is needed below, and the module header sketched
    # for this file imports only `dataclass` and `field`.
    from dataclasses import asdict

    effective_thread_id = thread_id or build_thread_id(
        meeting_id, user_id, "meeting_qa"
    )

    async with self.uow_factory() as uow:
        # NOTE(review): self._embedder / self._llm / self._verifier are not
        # fields of the AssistantService dataclass as sketched for Sprint 25;
        # they must be added there for this method to run.
        graph = build_meeting_qa_graph(
            embedder=self._embedder,
            segment_repo=uow.segments,
            llm=self._llm,
            verifier=self._verifier,
        )

        config = {"configurable": {"thread_id": effective_thread_id}}
        result = await graph.ainvoke(
            {
                "question": question,
                "meeting_id": meeting_id,
                "top_k": top_k,
            },
            config,
        )

        # Read citations once so the usage count and the response agree.
        citations = result.get("citations", [])

        self.usage_events.record_simple(
            "assistant.ask",
            meeting_id=str(meeting_id) if meeting_id else None,
            question_length=len(question),
            citation_count=len(citations),
        )

        return AssistantOutputState(
            answer=result["answer"],
            citations=[asdict(c) for c in citations],
            suggested_annotations=[],
            thread_id=effective_thread_id,
        )
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Acceptance Criteria
|
||||
|
||||
- [ ] Q&A returns answers with valid segment citations
|
||||
- [ ] Citations link to correct timestamps in transcript
|
||||
- [ ] Feature hidden when `rag_enabled=false` in project rules
|
||||
- [ ] Thread ID persists conversation context
|
||||
- [ ] `make quality` passes
|
||||
- [ ] `pytest tests/grpc/test_assistant.py` passes
|
||||
- [ ] UI component displays answer and clickable citations
|
||||
|
||||
---
|
||||
|
||||
## Files Created/Modified
|
||||
|
||||
| Action | Path |
|
||||
|--------|------|
|
||||
| Created | `src/noteflow/infrastructure/ai/graphs/meeting_qa.py` |
|
||||
| Created | `src/noteflow/infrastructure/ai/nodes/verification.py` |
|
||||
| Modified | `src/noteflow/grpc/proto/noteflow.proto` |
|
||||
| Created | `src/noteflow/grpc/_mixins/assistant.py` |
|
||||
| Modified | `src/noteflow/grpc/service.py` (add mixin) |
|
||||
| Created | `client/src-tauri/src/commands/assistant.rs` |
|
||||
| Modified | `client/src/api/tauri-adapter.ts` |
|
||||
| Created | `client/src/api/types/assistant.ts` |
|
||||
| Created | `client/src/components/meeting/AskPanel.tsx` |
|
||||
| Created | `client/src/hooks/use-assistant.ts` |
|
||||
| Modified | `src/noteflow/application/services/assistant/assistant_service.py` |
|
||||
@@ -0,0 +1,87 @@
|
||||
# Sprint 27: Cross-Meeting RAG
|
||||
|
||||
> **Size**: M | **Owner**: Backend + Client | **Phase**: 5 - Platform Evolution
|
||||
> **Effort**: ~1 sprint | **Prerequisites**: Sprint 26 (Meeting Q&A MVP)
|
||||
|
||||
---
|
||||
|
||||
## Objective
|
||||
|
||||
Enable workspace-scoped Q&A and annotation suggestions across multiple meetings.
|
||||
|
||||
---
|
||||
|
||||
## Key Tasks
|
||||
|
||||
### Task 1: Workspace-Scoped Semantic Search
|
||||
|
||||
Extend `SegmentRepository` with workspace-scoped search:
|
||||
|
||||
```python
|
||||
async def search_semantic_workspace(
|
||||
self,
|
||||
query_embedding: list[float],
|
||||
workspace_id: UUID,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[tuple[Segment, float]]:
|
||||
"""Search segments across all meetings in workspace/project."""
|
||||
```
|
||||
|
||||
### Task 2: WorkspaceQA Graph
|
||||
|
||||
Create `infrastructure/ai/graphs/workspace_qa.py`:
|
||||
|
||||
- Similar to MeetingQA but omits `meeting_id` filter
|
||||
- Groups results by meeting for citation display
|
||||
- Returns cross-meeting citations
|
||||
|
||||
### Task 3: Annotation Suggester
|
||||
|
||||
Add annotation suggestion output to graph:
|
||||
|
||||
```python
|
||||
class SuggestedAnnotation:
|
||||
text: str
|
||||
annotation_type: AnnotationType
|
||||
segment_ids: list[int]
|
||||
confidence: float
|
||||
```
|
||||
|
||||
### Task 4: Conversation History
|
||||
|
||||
Implement thread persistence with checkpointer:
|
||||
|
||||
- Store conversation turns in graph state
|
||||
- Support follow-up questions
|
||||
- Maintain context across requests
|
||||
|
||||
### Task 5: Apply Annotation Flow
|
||||
|
||||
UI flow to apply suggested annotations:
|
||||
|
||||
1. Display suggested annotations in AskPanel
|
||||
2. User clicks "Apply" on suggestion
|
||||
3. Call existing `AddAnnotation` RPC
|
||||
4. Update UI to show applied status
|
||||
|
||||
---
|
||||
|
||||
## Acceptance Criteria
|
||||
|
||||
- [ ] Cross-meeting queries return results from multiple meetings
|
||||
- [ ] Suggested annotations can be applied with one click
|
||||
- [ ] Conversation history persists across requests
|
||||
- [ ] Follow-up questions reference previous context
|
||||
- [ ] `make quality` passes
|
||||
|
||||
---
|
||||
|
||||
## Files Created/Modified
|
||||
|
||||
| Action | Path |
|
||||
|--------|------|
|
||||
| Modified | `src/noteflow/infrastructure/persistence/repositories/segment_repo.py` |
|
||||
| Created | `src/noteflow/infrastructure/ai/graphs/workspace_qa.py` |
|
||||
| Modified | `src/noteflow/application/services/assistant/assistant_service.py` |
|
||||
| Modified | `client/src/components/meeting/AskPanel.tsx` |
|
||||
@@ -0,0 +1,146 @@
|
||||
# Sprint 28: Advanced Capabilities
|
||||
|
||||
> **Size**: L | **Owner**: Backend + Client | **Phase**: 5 - Platform Evolution
|
||||
> **Effort**: ~1 sprint | **Prerequisites**: Sprint 27 (Cross-Meeting RAG)
|
||||
|
||||
---
|
||||
|
||||
## Objective
|
||||
|
||||
Production hardening with streaming responses, caching, guardrails, and optional web search.
|
||||
|
||||
---
|
||||
|
||||
## Key Tasks
|
||||
|
||||
### Task 1: Streaming Responses
|
||||
|
||||
Implement `StreamAssistant` RPC for progressive answer generation:
|
||||
|
||||
```protobuf
|
||||
rpc StreamAssistant(AskAssistantRequest) returns (stream AskAssistantResponse);
|
||||
```
|
||||
|
||||
Use LangGraph's `astream` with `stream_mode="messages"`:
|
||||
|
||||
```python
|
||||
async for chunk in graph.astream(state, config, stream_mode="messages"):
|
||||
yield pb.AskAssistantResponse(answer=chunk.content, partial=True)
|
||||
```
|
||||
|
||||
### Task 2: Embedding Cache
|
||||
|
||||
Create `infrastructure/ai/cache.py`:
|
||||
|
||||
- LRU cache for query embeddings
|
||||
- Reduce redundant embedding calls
|
||||
- Configurable TTL and max size
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class EmbeddingCache:
|
||||
max_size: int = 1000
|
||||
ttl_seconds: int = 3600
|
||||
|
||||
async def get_or_compute(
|
||||
self,
|
||||
text: str,
|
||||
embedder: EmbedderProtocol,
|
||||
) -> list[float]:
|
||||
...
|
||||
```
|
||||
|
||||
### Task 3: Content Guardrails
|
||||
|
||||
Create `infrastructure/ai/guardrails.py`:
|
||||
|
||||
- Input validation (length, content)
|
||||
- Output filtering (PII, harmful content)
|
||||
- Configurable rules per workspace
|
||||
|
||||
```python
|
||||
class GuardrailResult:
|
||||
allowed: bool
|
||||
reason: str | None
|
||||
filtered_content: str | None
|
||||
|
||||
|
||||
async def check_input(text: str, rules: GuardrailRules) -> GuardrailResult:
|
||||
...
|
||||
|
||||
async def filter_output(text: str, rules: GuardrailRules) -> GuardrailResult:
|
||||
...
|
||||
```
|
||||
|
||||
### Task 4: Web Search Node
|
||||
|
||||
Create `infrastructure/ai/nodes/web_search.py`:
|
||||
|
||||
- Optional node triggered by `allow_web=true`
|
||||
- Integrate with web search API
|
||||
- Merge web results with transcript evidence
|
||||
|
||||
### Task 5: AGENT_PROGRESS Tauri Event
|
||||
|
||||
Emit progress events for UI feedback:
|
||||
|
||||
```rust
|
||||
// Emit during graph execution
|
||||
app.emit_all("AGENT_PROGRESS", AgentProgressPayload {
|
||||
stage: "retrieving",
|
||||
progress: 0.3,
|
||||
message: "Searching transcript...",
|
||||
})?;
|
||||
```
|
||||
|
||||
### Task 6: Interrupts for Approval
|
||||
|
||||
Implement LangGraph interrupts for:
|
||||
|
||||
- Web search approval
|
||||
- Annotation creation approval
|
||||
- Sensitive action confirmation
|
||||
|
||||
### Task 7: Performance Optimization
|
||||
|
||||
- Batch embedding requests
|
||||
- Parallel segment retrieval
|
||||
- Connection pool tuning
|
||||
|
||||
---
|
||||
|
||||
## Acceptance Criteria
|
||||
|
||||
- [ ] Streaming shows progressive answer generation in UI
|
||||
- [ ] Cache reduces embedding latency by >50% for repeated queries
|
||||
- [ ] Guardrails block inappropriate content
|
||||
- [ ] Web search gated by `allow_web` flag
|
||||
- [ ] Progress events update UI during long operations
|
||||
- [ ] `make quality` passes
|
||||
- [ ] E2E tests pass (`client/e2e/assistant.spec.ts`)
|
||||
|
||||
---
|
||||
|
||||
## Success Metrics
|
||||
|
||||
| Metric | Target |
|
||||
| ------------------ | ------ |
|
||||
| Q&A latency (p95) | < 3s |
|
||||
| Citation accuracy | > 90% |
|
||||
| Cache hit rate | > 60% |
|
||||
| Hallucination rate | < 5% |
|
||||
|
||||
---
|
||||
|
||||
## Files Created/Modified
|
||||
|
||||
| Action | Path |
|
||||
| -------- | ---------------------------------------------------- |
|
||||
| Modified | `src/noteflow/grpc/proto/noteflow.proto` |
|
||||
| Created | `src/noteflow/grpc/_mixins/assistant_streaming.py` |
|
||||
| Created | `src/noteflow/infrastructure/ai/cache.py` |
|
||||
| Created | `src/noteflow/infrastructure/ai/guardrails.py` |
|
||||
| Created | `src/noteflow/infrastructure/ai/nodes/web_search.py` |
|
||||
| Modified | `client/src-tauri/src/commands/assistant.rs` |
|
||||
| Modified | `client/src/components/meeting/AskPanel.tsx` |
|
||||
| Created | `client/e2e/assistant.spec.ts` |
|
||||
38
src/noteflow/domain/ai/__init__.py
Normal file
38
src/noteflow/domain/ai/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""AI domain types for LangGraph workflows.
|
||||
|
||||
State schemas, citations, interrupts, and protocols for AI assistant functionality.
|
||||
"""
|
||||
|
||||
from noteflow.domain.ai.citations import SegmentCitation
|
||||
from noteflow.domain.ai.interrupts import (
|
||||
InterruptAction,
|
||||
InterruptConfig,
|
||||
InterruptRequest,
|
||||
InterruptResponse,
|
||||
InterruptType,
|
||||
create_annotation_interrupt,
|
||||
create_sensitive_action_interrupt,
|
||||
create_web_search_interrupt,
|
||||
)
|
||||
from noteflow.domain.ai.ports import AssistantPort
|
||||
from noteflow.domain.ai.state import (
|
||||
AssistantInputState,
|
||||
AssistantInternalState,
|
||||
AssistantOutputState,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AssistantInputState",
|
||||
"AssistantInternalState",
|
||||
"AssistantOutputState",
|
||||
"AssistantPort",
|
||||
"InterruptAction",
|
||||
"InterruptConfig",
|
||||
"InterruptRequest",
|
||||
"InterruptResponse",
|
||||
"InterruptType",
|
||||
"SegmentCitation",
|
||||
"create_annotation_interrupt",
|
||||
"create_sensitive_action_interrupt",
|
||||
"create_web_search_interrupt",
|
||||
]
|
||||
43
src/noteflow/domain/ai/citations.py
Normal file
43
src/noteflow/domain/ai/citations.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Citation value objects for AI-generated responses."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SegmentCitation:
    """Reference to a transcript segment used as evidence.

    Links AI-generated claims to source transcript segments for verification.

    Raises:
        ValueError: If ``segment_id`` or ``start_time`` is negative, if
            ``end_time`` is earlier than ``start_time``, or if ``score`` is
            not within ``[0, 1]`` (a NaN score is rejected as well).
    """

    meeting_id: UUID
    segment_id: int
    start_time: float
    end_time: float
    text: str
    score: float = 0.0

    @property
    def duration(self) -> float:
        """Duration of the cited segment in seconds."""
        return self.end_time - self.start_time

    def __post_init__(self) -> None:
        """Validate field invariants right after construction."""
        if self.segment_id < 0:
            msg = "segment_id must be non-negative"
            raise ValueError(msg)
        if self.start_time < 0:
            msg = "start_time must be non-negative"
            raise ValueError(msg)
        if self.end_time < self.start_time:
            msg = "end_time must be >= start_time"
            raise ValueError(msg)
        # Negated range check: NaN fails every comparison, so it is rejected
        # here instead of slipping past separate `<` / `>` tests.
        if not 0 <= self.score <= 1:
            msg = "score must be between 0 and 1"
            raise ValueError(msg)
|
||||
202
src/noteflow/domain/ai/interrupts.py
Normal file
202
src/noteflow/domain/ai/interrupts.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Domain types for LangGraph human-in-the-loop interrupts.
|
||||
|
||||
Defines interrupt request/response types for approval workflows:
|
||||
- Web search approval
|
||||
- Annotation creation approval
|
||||
- Sensitive action confirmation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from typing import Final
|
||||
|
||||
|
||||
class InterruptType(StrEnum):
|
||||
"""Types of human-in-the-loop interrupts."""
|
||||
|
||||
WEB_SEARCH_APPROVAL = "web_search_approval"
|
||||
ANNOTATION_APPROVAL = "annotation_approval"
|
||||
SENSITIVE_ACTION = "sensitive_action"
|
||||
|
||||
|
||||
class InterruptAction(StrEnum):
|
||||
"""Possible user actions for an interrupt."""
|
||||
|
||||
APPROVE = "approve"
|
||||
REJECT = "reject"
|
||||
MODIFY = "modify"
|
||||
|
||||
|
||||
# Response options offered by each interrupt factory below.
DEFAULT_WEB_SEARCH_OPTIONS: Final[tuple[str, ...]] = ("approve", "reject")
DEFAULT_ANNOTATION_OPTIONS: Final[tuple[str, ...]] = ("approve", "reject", "modify")
DEFAULT_SENSITIVE_OPTIONS: Final[tuple[str, ...]] = ("approve", "reject")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class InterruptConfig:
    """Configuration for interrupt behavior.

    All flags default to the most restrictive setting; the factories in this
    module opt in per interrupt type.
    """

    allow_ignore: bool = False
    allow_modify: bool = False
    # None means no timeout is configured; how a timeout is enforced is not
    # visible in this module.
    timeout_seconds: float | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class InterruptRequest:
    """A request for human approval during graph execution.

    Sent to the client when the graph hits an interrupt point.

    Attributes:
        interrupt_type: Category of interrupt (web_search, annotation, etc.)
        message: Human-readable description of what needs approval.
        context: Additional context data for the decision (query, entities, etc.)
        options: Available response options.
        config: Interrupt configuration (allow_ignore, timeout, etc.)
        request_id: Unique identifier for this interrupt request.
    """

    interrupt_type: InterruptType
    message: str
    context: dict[str, object] = field(default_factory=dict)
    # Tuples are immutable, so a plain default is safe (no factory needed).
    options: tuple[str, ...] = ("approve", "reject")
    config: InterruptConfig = field(default_factory=InterruptConfig)
    request_id: str = ""

    def to_dict(self) -> dict[str, object]:
        """Convert to dictionary for serialization.

        Returns:
            A new dict; ``context`` is shallow-copied and ``options`` is
            converted to a list so callers cannot mutate this frozen request
            through the returned payload.
        """
        return {
            "interrupt_type": self.interrupt_type,
            "message": self.message,
            "context": dict(self.context),
            "options": list(self.options),
            "config": {
                "allow_ignore": self.config.allow_ignore,
                "allow_modify": self.config.allow_modify,
                "timeout_seconds": self.config.timeout_seconds,
            },
            "request_id": self.request_id,
        }
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class InterruptResponse:
    """User's response to an interrupt request.

    Returned from the client to resume graph execution.

    Attributes:
        action: The action taken (approve, reject, modify).
        request_id: ID of the interrupt request being responded to.
        modified_value: If action is MODIFY, the modified value.
        user_message: Optional message from the user.
    """

    action: InterruptAction
    request_id: str = ""
    modified_value: dict[str, object] | None = None
    user_message: str | None = None

    @property
    def is_approved(self) -> bool:
        """Check if the action was approved."""
        return self.action == InterruptAction.APPROVE

    @property
    def is_rejected(self) -> bool:
        """Check if the action was rejected."""
        return self.action == InterruptAction.REJECT

    @property
    def is_modified(self) -> bool:
        """Check if the action was modified."""
        return self.action == InterruptAction.MODIFY

    def to_dict(self) -> dict[str, object]:
        """Convert to dictionary for serialization.

        Returns:
            A new dict with ``action`` and ``request_id``; ``modified_value``
            (shallow-copied) and ``user_message`` are included only when set.
        """
        result: dict[str, object] = {
            "action": self.action,
            "request_id": self.request_id,
        }
        if self.modified_value is not None:
            # Copy so the payload cannot be used to mutate this frozen response.
            result["modified_value"] = dict(self.modified_value)
        if self.user_message is not None:
            result["user_message"] = self.user_message
        return result
|
||||
|
||||
|
||||
def create_web_search_interrupt(
    query: str,
    request_id: str,
    *,
    allow_modify: bool = False,
) -> InterruptRequest:
    """Create an interrupt request for web search approval.

    The message shows at most the first 100 characters of the query; the
    full query is carried in ``context["query"]``.

    Args:
        query: The search query to be executed.
        request_id: Unique identifier for this request.
        allow_modify: Whether the user can modify the query.

    Returns:
        InterruptRequest configured for web search approval.
    """
    return InterruptRequest(
        interrupt_type=InterruptType.WEB_SEARCH_APPROVAL,
        message=f"Allow web search for additional context? Query: {query[:100]}",
        context={"query": query},
        options=("approve", "reject", "modify") if allow_modify else DEFAULT_WEB_SEARCH_OPTIONS,
        config=InterruptConfig(allow_modify=allow_modify),
        request_id=request_id,
    )
|
||||
|
||||
|
||||
def create_annotation_interrupt(
    annotations: list[dict[str, object]],
    request_id: str,
) -> InterruptRequest:
    """Create an interrupt request for annotation approval.

    The request allows both modify and ignore, and its context carries the
    annotation list together with its length under ``"count"``.

    Args:
        annotations: List of suggested annotations to approve.
        request_id: Unique identifier for this request.

    Returns:
        InterruptRequest configured for annotation approval.
    """
    count = len(annotations)
    return InterruptRequest(
        interrupt_type=InterruptType.ANNOTATION_APPROVAL,
        message=f"Apply {count} suggested annotation(s)?",
        context={"annotations": annotations, "count": count},
        options=DEFAULT_ANNOTATION_OPTIONS,
        config=InterruptConfig(allow_modify=True, allow_ignore=True),
        request_id=request_id,
    )
|
||||
|
||||
|
||||
def create_sensitive_action_interrupt(
    action_name: str,
    action_description: str,
    request_id: str,
) -> InterruptRequest:
    """Create an interrupt request for sensitive action confirmation.

    The request is built with ``allow_ignore=False`` and offers only the
    approve/reject options.

    Args:
        action_name: Name of the sensitive action.
        action_description: Description of what the action does.
        request_id: Unique identifier for this request.

    Returns:
        InterruptRequest configured for sensitive action confirmation.
    """
    return InterruptRequest(
        interrupt_type=InterruptType.SENSITIVE_ACTION,
        message=f"Confirm action: {action_name}",
        context={"action_name": action_name, "description": action_description},
        options=DEFAULT_SENSITIVE_OPTIONS,
        config=InterruptConfig(allow_ignore=False),
        request_id=request_id,
    )
|
||||
26
src/noteflow/domain/ai/ports.py
Normal file
26
src/noteflow/domain/ai/ports.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Protocol definitions for AI assistant operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
from noteflow.domain.ai.state import AssistantOutputState
|
||||
|
||||
|
||||
class AssistantPort(Protocol):
    """Protocol for AI assistant operations."""

    async def ask(
        self,
        question: str,
        user_id: UUID,
        meeting_id: UUID | None = None,
        thread_id: str | None = None,
        allow_web: bool = False,
        top_k: int = 8,
    ) -> AssistantOutputState:
        """Ask a question about meeting transcript(s).

        Args:
            question: Question text.
            user_id: Identifier of the asking user.
            meeting_id: Optional meeting to ask about.
            thread_id: Optional conversation thread identifier.
            allow_web: Flag for web access; semantics are defined by the
                implementation.
            top_k: Retrieval size hint; defaults to 8.

        Returns:
            The assistant's public output state.
        """
        ...
|
||||
39
src/noteflow/domain/ai/state.py
Normal file
39
src/noteflow/domain/ai/state.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""State schemas for LangGraph AI workflows.
|
||||
|
||||
Separates Input/Output (public API) from Internal (can evolve freely).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import Annotated, TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class AssistantInputState(TypedDict):
    """Public API input - what clients send."""

    question: str
    meeting_id: UUID | None
    thread_id: str | None
    allow_web: bool
    top_k: int
|
||||
|
||||
|
||||
class AssistantOutputState(TypedDict):
    """Public API output - what clients receive."""

    answer: str
    citations: list[dict[str, object]]
    suggested_annotations: list[dict[str, object]]
    thread_id: str
|
||||
|
||||
|
||||
class AssistantInternalState(AssistantInputState):
    """Internal graph state - can evolve without breaking API."""

    # Annotated with operator.add as the reducer for this key.
    retrieved_segment_ids: Annotated[list[int], operator.add]
    retrieved_segments: list[dict[str, object]]
    draft_answer: str
    verification_passed: bool
    loop_count: int
|
||||
45
src/noteflow/infrastructure/ai/__init__.py
Normal file
45
src/noteflow/infrastructure/ai/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""AI infrastructure components for LangGraph workflows."""
|
||||
|
||||
from noteflow.infrastructure.ai.cache import (
|
||||
CachedEmbedder,
|
||||
EmbeddingCache,
|
||||
EmbeddingCacheStats,
|
||||
)
|
||||
from noteflow.infrastructure.ai.checkpointer import create_checkpointer
|
||||
from noteflow.infrastructure.ai.guardrails import (
|
||||
GuardrailResult,
|
||||
GuardrailRules,
|
||||
GuardrailViolation,
|
||||
check_input,
|
||||
create_default_rules,
|
||||
create_strict_rules,
|
||||
filter_output,
|
||||
)
|
||||
from noteflow.infrastructure.ai.interrupts import (
|
||||
InterruptHandler,
|
||||
check_annotation_approval,
|
||||
check_web_search_approval,
|
||||
create_resume_command,
|
||||
request_annotation_approval,
|
||||
request_web_search_approval,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CachedEmbedder",
|
||||
"EmbeddingCache",
|
||||
"EmbeddingCacheStats",
|
||||
"GuardrailResult",
|
||||
"GuardrailRules",
|
||||
"GuardrailViolation",
|
||||
"InterruptHandler",
|
||||
"check_annotation_approval",
|
||||
"check_input",
|
||||
"check_web_search_approval",
|
||||
"create_checkpointer",
|
||||
"create_default_rules",
|
||||
"create_resume_command",
|
||||
"create_strict_rules",
|
||||
"filter_output",
|
||||
"request_annotation_approval",
|
||||
"request_web_search_approval",
|
||||
]
|
||||
212
src/noteflow/infrastructure/ai/cache.py
Normal file
212
src/noteflow/infrastructure/ai/cache.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""Embedding cache with LRU eviction and TTL expiration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.infrastructure.ai.tools.retrieval import EmbedderProtocol
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
DEFAULT_MAX_SIZE: Final[int] = 1000
|
||||
DEFAULT_TTL_SECONDS: Final[int] = 3600
|
||||
HASH_ALGORITHM: Final[str] = "sha256"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CacheEntry:
    """Cached embedding with creation timestamp."""

    embedding: tuple[float, ...]
    created_at: float

    def is_expired(self, ttl_seconds: float, current_time: float) -> bool:
        """Check if entry has expired based on TTL.

        Args:
            ttl_seconds: Maximum allowed age.
            current_time: Timestamp on the same clock as ``created_at``.

        Returns:
            True when the entry's age is strictly greater than ``ttl_seconds``.
        """
        return (current_time - self.created_at) > ttl_seconds
|
||||
|
||||
|
||||
@dataclass
class EmbeddingCacheStats:
    """Statistics for cache performance monitoring."""

    hits: int = 0
    misses: int = 0
    evictions: int = 0
    expirations: int = 0

    @property
    def hit_rate(self) -> float:
        """Calculate cache hit rate.

        Returns:
            ``hits / (hits + misses)``, or 0.0 when nothing was recorded yet.
        """
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0.0
|
||||
|
||||
|
||||
@dataclass
class EmbeddingCache:
    """LRU cache for text embeddings with TTL expiration and deduplication.

    Entries are keyed by a hash of the text. Concurrent ``get_or_compute``
    calls for the same text share one embedder call through a per-key
    in-flight future.
    """

    max_size: int = DEFAULT_MAX_SIZE
    ttl_seconds: int = DEFAULT_TTL_SECONDS
    _cache: OrderedDict[str, CacheEntry] = field(default_factory=OrderedDict)
    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    _stats: EmbeddingCacheStats = field(default_factory=EmbeddingCacheStats)
    _in_flight: dict[str, asyncio.Future[list[float]]] = field(default_factory=dict)

    def _compute_key(self, text: str) -> str:
        """Compute cache key from text using hash."""
        return hashlib.new(HASH_ALGORITHM, text.encode("utf-8")).hexdigest()

    def _take_fresh(self, key: str) -> list[float] | None:
        """Return a copy of the live entry for ``key``, dropping it if expired.

        Must be called with ``self._lock`` held. A live entry is also marked
        as most recently used.
        """
        entry = self._cache.get(key)
        if entry is None:
            return None
        if entry.is_expired(self.ttl_seconds, time.monotonic()):
            del self._cache[key]
            self._stats.expirations += 1
            logger.debug("cache_expired", key=key[:16])
            return None
        self._cache.move_to_end(key)
        return list(entry.embedding)

    def _store(self, key: str, embedding: list[float]) -> None:
        """Insert ``embedding`` under ``key``, evicting oldest entries first.

        Must be called with ``self._lock`` held. Nothing is stored when
        ``max_size`` is not positive (caching disabled).
        """
        if self.max_size <= 0:
            return
        while len(self._cache) >= self.max_size:
            evicted_key, _ = self._cache.popitem(last=False)
            self._stats.evictions += 1
            logger.debug("cache_eviction", evicted_key=evicted_key[:16])
        # Timestamp at store time so the TTL does not include embed latency.
        self._cache[key] = CacheEntry(
            embedding=tuple(embedding),
            created_at=time.monotonic(),
        )
        logger.debug("cache_store", key=key[:16])

    async def get_or_compute(
        self,
        text: str,
        embedder: EmbedderProtocol,
    ) -> list[float]:
        """Return the embedding for ``text``, computing it at most once.

        A live cached entry is returned directly. Otherwise the first caller
        for a key runs ``embedder.embed`` and publishes the outcome; callers
        arriving meanwhile wait for that outcome instead of embedding again.

        Args:
            text: Text to embed.
            embedder: Embedder used on a cache miss.

        Returns:
            The embedding vector.

        Raises:
            Exception: Whatever ``embedder.embed`` raised; callers that
                joined the same in-flight computation receive it too.
        """
        key = self._compute_key(text)
        while True:
            # Lookup, in-flight check and registration happen in one critical
            # section so two callers can never both become the owner.
            async with self._lock:
                cached = self._take_fresh(key)
                if cached is not None:
                    self._stats.hits += 1
                    logger.debug("cache_hit", key=key[:16])
                    return cached
                pending = self._in_flight.get(key)
                owns_computation = pending is None
                if pending is None:
                    self._stats.misses += 1
                    pending = asyncio.get_running_loop().create_future()
                    self._in_flight[key] = pending

            if owns_computation:
                return await self._compute_and_publish(key, text, embedder, pending)

            logger.debug("cache_in_flight_join", key=key[:16])
            try:
                # shield(): cancelling this waiter must not cancel the future
                # shared with the owner and the other waiters.
                return list(await asyncio.shield(pending))
            except asyncio.CancelledError:
                current = asyncio.current_task()
                if not pending.cancelled() or (current is not None and current.cancelling()):
                    raise
                # Only the owner was cancelled; retry (possibly as new owner).

    async def _compute_and_publish(
        self,
        key: str,
        text: str,
        embedder: EmbedderProtocol,
        future: asyncio.Future[list[float]],
    ) -> list[float]:
        """Run the embedder as owner of ``key`` and resolve ``future``."""
        try:
            embedding = await embedder.embed(text)
            async with self._lock:
                _ = self._in_flight.pop(key, None)
                self._store(key, embedding)
        except BaseException as exc:
            # No await in this handler, so the cleanup cannot be interrupted
            # and the in-flight entry is always released.
            _ = self._in_flight.pop(key, None)
            if not future.done():
                if isinstance(exc, Exception):
                    future.set_exception(exc)
                    # Mark as retrieved so asyncio does not log an
                    # unretrieved-exception warning when nobody joined.
                    _ = future.exception()
                else:
                    # Owner cancelled (or similar): waiters detect this and retry.
                    _ = future.cancel()
            raise

        if not future.done():
            future.set_result(embedding)
        return embedding

    async def get(self, text: str) -> list[float] | None:
        """Get embedding from cache without computing.

        Args:
            text: Text to look up.

        Returns:
            Embedding if cached and not expired, None otherwise.
        """
        key = self._compute_key(text)
        async with self._lock:
            return self._take_fresh(key)

    async def clear(self) -> int:
        """Clear all cached entries.

        Returns:
            Number of entries cleared.
        """
        async with self._lock:
            count = len(self._cache)
            self._cache.clear()
            logger.info("cache_cleared", entries_cleared=count)
            return count

    async def size(self) -> int:
        """Get current number of cached entries."""
        async with self._lock:
            return len(self._cache)

    def get_stats(self) -> EmbeddingCacheStats:
        """Get cache statistics (not async - reads are atomic).

        Returns:
            A snapshot copy; later cache activity does not change it.
        """
        return EmbeddingCacheStats(
            hits=self._stats.hits,
            misses=self._stats.misses,
            evictions=self._stats.evictions,
            expirations=self._stats.expirations,
        )
|
||||
|
||||
|
||||
class CachedEmbedder:
    """Wrapper that adds caching to any EmbedderProtocol implementation.

    Example:
        base_embedder = MyEmbedder()
        cached = CachedEmbedder(base_embedder, max_size=500, ttl_seconds=1800)
        embedding = await cached.embed("hello world")
    """

    _embedder: EmbedderProtocol
    _cache: EmbeddingCache

    def __init__(
        self,
        embedder: EmbedderProtocol,
        max_size: int = DEFAULT_MAX_SIZE,
        ttl_seconds: int = DEFAULT_TTL_SECONDS,
    ) -> None:
        """Wrap ``embedder`` with a private ``EmbeddingCache``.

        Args:
            embedder: Embedder used on cache misses.
            max_size: Passed to the cache as ``max_size``.
            ttl_seconds: Passed to the cache as ``ttl_seconds``.
        """
        self._embedder = embedder
        self._cache = EmbeddingCache(max_size=max_size, ttl_seconds=ttl_seconds)

    async def embed(self, text: str) -> list[float]:
        """Embed text with caching."""
        return await self._cache.get_or_compute(text, self._embedder)

    @property
    def cache(self) -> EmbeddingCache:
        """Access underlying cache for stats/management."""
        return self._cache
|
||||
56
src/noteflow/infrastructure/ai/checkpointer.py
Normal file
56
src/noteflow/infrastructure/ai/checkpointer.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""PostgreSQL checkpointer factory for LangGraph state persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
CHECKPOINTER_POOL_SIZE: Final[int] = 5
|
||||
|
||||
|
||||
@dataclass
class CheckpointerResult:
    """Wraps checkpointer with pool lifecycle management to prevent connection leaks.

    Usable as an async context manager: entering yields the checkpointer and
    exiting closes the pool.
    """

    checkpointer: AsyncPostgresSaver
    _pool: AsyncConnectionPool

    async def close(self) -> None:
        """Close the underlying connection pool if it exposes ``close``."""
        # psycopg_pool.AsyncConnectionPool.close() exists at runtime but type stubs
        # are incomplete. Use getattr to satisfy the type checker.
        close_fn = getattr(self._pool, "close", None)
        if close_fn is not None:
            await close_fn()

    async def __aenter__(self) -> AsyncPostgresSaver:
        """Return the wrapped checkpointer."""
        return self.checkpointer

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: object,
    ) -> None:
        """Close the pool; exceptions from the body are not suppressed."""
        await self.close()
|
||||
|
||||
|
||||
async def create_checkpointer(
    database_url: str,
    pool_size: int = CHECKPOINTER_POOL_SIZE,
) -> CheckpointerResult:
    """Create async Postgres checkpointer for LangGraph state persistence.

    Args:
        database_url: Connection string passed to the pool as ``conninfo``.
        pool_size: Maximum number of pooled connections.

    Returns:
        The checkpointer bundled with its pool; the caller must close it
        (``await result.close()`` or ``async with result``).

    Raises:
        Exception: Anything raised while opening the pool or running
            ``setup()``; the pool is closed before the error propagates.
    """
    from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
    from psycopg_pool import AsyncConnectionPool

    pool = AsyncConnectionPool(
        conninfo=database_url,
        # psycopg_pool defaults min_size to 4 and rejects max_size < min_size,
        # so cap it to keep small pool sizes valid.
        min_size=min(4, pool_size),
        max_size=pool_size,
        kwargs={"autocommit": True},
        # Opening an async pool in the constructor is deprecated; open it
        # explicitly below instead.
        open=False,
    )
    result = CheckpointerResult(checkpointer=AsyncPostgresSaver(pool), _pool=pool)
    try:
        await pool.open()
        await result.checkpointer.setup()
    except BaseException:
        # Do not leak the pool when setup fails or the task is cancelled.
        await result.close()
        raise
    return result
|
||||
51
src/noteflow/infrastructure/ai/graphs/__init__.py
Normal file
51
src/noteflow/infrastructure/ai/graphs/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""LangGraph workflow definitions."""
|
||||
|
||||
from noteflow.infrastructure.ai.graphs.meeting_qa import (
|
||||
MEETING_QA_GRAPH_NAME,
|
||||
MEETING_QA_GRAPH_VERSION,
|
||||
MeetingQAConfig,
|
||||
MeetingQAInputState,
|
||||
MeetingQAInternalState,
|
||||
MeetingQAOutputState,
|
||||
build_meeting_qa_graph,
|
||||
)
|
||||
from noteflow.infrastructure.ai.graphs.summarization import (
|
||||
SUMMARIZATION_GRAPH_NAME,
|
||||
SUMMARIZATION_GRAPH_VERSION,
|
||||
SummarizationInputState,
|
||||
SummarizationOutputState,
|
||||
SummarizationState,
|
||||
build_summarization_graph,
|
||||
)
|
||||
from noteflow.infrastructure.ai.graphs.workspace_qa import (
|
||||
WORKSPACE_QA_GRAPH_NAME,
|
||||
WORKSPACE_QA_GRAPH_VERSION,
|
||||
WorkspaceQAConfig,
|
||||
WorkspaceQAInputState,
|
||||
WorkspaceQAInternalState,
|
||||
WorkspaceQAOutputState,
|
||||
build_workspace_qa_graph,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MEETING_QA_GRAPH_NAME",
|
||||
"MEETING_QA_GRAPH_VERSION",
|
||||
"MeetingQAConfig",
|
||||
"MeetingQAInputState",
|
||||
"MeetingQAInternalState",
|
||||
"MeetingQAOutputState",
|
||||
"SUMMARIZATION_GRAPH_NAME",
|
||||
"SUMMARIZATION_GRAPH_VERSION",
|
||||
"SummarizationInputState",
|
||||
"SummarizationOutputState",
|
||||
"SummarizationState",
|
||||
"WORKSPACE_QA_GRAPH_NAME",
|
||||
"WORKSPACE_QA_GRAPH_VERSION",
|
||||
"WorkspaceQAConfig",
|
||||
"WorkspaceQAInputState",
|
||||
"WorkspaceQAInternalState",
|
||||
"WorkspaceQAOutputState",
|
||||
"build_meeting_qa_graph",
|
||||
"build_summarization_graph",
|
||||
"build_workspace_qa_graph",
|
||||
]
|
||||
193
src/noteflow/infrastructure/ai/graphs/meeting_qa.py
Normal file
193
src/noteflow/infrastructure/ai/graphs/meeting_qa.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Meeting Q&A graph for single-meeting question answering with citations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Final, TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.graph import CompiledStateGraph
|
||||
|
||||
from noteflow.domain.ai.citations import SegmentCitation
|
||||
from noteflow.domain.value_objects import MeetingId
|
||||
from noteflow.infrastructure.ai.nodes.annotation_suggester import SuggestedAnnotation
|
||||
from noteflow.infrastructure.ai.nodes.web_search import WebSearchProvider
|
||||
from noteflow.infrastructure.ai.tools.retrieval import (
|
||||
EmbedderProtocol,
|
||||
RetrievalResult,
|
||||
SegmentSearchProtocol,
|
||||
)
|
||||
from noteflow.infrastructure.ai.tools.synthesis import LLMProtocol
|
||||
|
||||
# Identifier and version of the single-meeting Q&A graph.
MEETING_QA_GRAPH_NAME: Final[str] = "meeting_qa"
MEETING_QA_GRAPH_VERSION: Final[int] = 2
# Answer returned when retrieval yields no segments for the meeting.
NO_INFORMATION_ANSWER: Final[str] = "I couldn't find relevant information in this meeting."
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MeetingQAConfig:
    """Immutable feature switches for the meeting Q&A graph."""

    # Register the web-search nodes (also requires a provider at build time).
    enable_web_search: bool = False
    # Ask for approval through an interrupt before searching the web.
    require_web_approval: bool = True
    # NOTE(review): not read by build_meeting_qa_graph in this module.
    require_annotation_approval: bool = False
|
||||
|
||||
|
||||
class MeetingQAInputState(TypedDict):
    """Keys the caller supplies when invoking the meeting Q&A graph."""

    # User question; also used as the retrieval query.
    question: str
    # Meeting whose segments are searched.
    meeting_id: MeetingId
    # Number of segments requested from retrieval.
    top_k: int
|
||||
|
||||
|
||||
class MeetingQAOutputState(TypedDict):
    """Keys produced by the synthesize step of the meeting Q&A graph."""

    # Final answer text.
    answer: str
    # Citations for the retrieved segments the answer referenced.
    citations: list[SegmentCitation]
    # Annotations extracted from the answer.
    suggested_annotations: list[SuggestedAnnotation]
|
||||
|
||||
|
||||
class MeetingQAInternalState(MeetingQAInputState, MeetingQAOutputState):
    """Full graph state: input and output keys plus intermediate values."""

    # Segments returned by the retrieve node.
    retrieved_segments: list[RetrievalResult]
    # True when retrieval returned at least one segment.
    verification_passed: bool
    # Outcome of the web-search approval node.
    web_search_approved: bool
    # Formatted web results, or "" when no search ran.
    web_context: str
    # NOTE(review): not written by any node in this module.
    annotations_approved: bool
|
||||
|
||||
|
||||
def build_meeting_qa_graph(
    embedder: EmbedderProtocol,
    segment_repo: SegmentSearchProtocol,
    llm: LLMProtocol,
    *,
    web_search_provider: WebSearchProvider | None = None,
    config: MeetingQAConfig | None = None,
    checkpointer: object | None = None,
) -> CompiledStateGraph[MeetingQAInternalState]:
    """Assemble and compile the single-meeting Q&A graph.

    Without web search the pipeline is ``retrieve -> verify -> synthesize``.
    With web search it is ``retrieve -> verify -> web_search_approval ->
    web_search -> synthesize``. The web-search nodes are registered only when
    ``config.enable_web_search`` is set and ``web_search_provider`` is given.

    NOTE(review): ``web_context`` is written to state by the web-search node,
    but ``synthesize`` does not read it here; confirm whether that is intended.

    Args:
        embedder: Protocol for generating text embeddings.
        segment_repo: Protocol for semantic segment search.
        llm: Protocol for LLM text completion.
        web_search_provider: Optional web search provider for augmentation.
        config: Graph configuration; defaults to ``MeetingQAConfig()``.
        checkpointer: Optional checkpointer forwarded to ``compile``.

    Returns:
        Compiled graph that accepts question/meeting_id and returns answer/citations.
    """
    from langgraph.graph import END, START, StateGraph

    from noteflow.domain.ai.citations import SegmentCitation
    from noteflow.infrastructure.ai.interrupts import check_web_search_approval
    from noteflow.infrastructure.ai.nodes.annotation_suggester import (
        extract_annotations_from_answer,
    )
    from noteflow.infrastructure.ai.nodes.web_search import (
        WebSearchConfig,
        derive_search_query,
        execute_web_search,
        format_results_for_context,
    )
    from noteflow.infrastructure.ai.tools.retrieval import retrieve_segments
    from noteflow.infrastructure.ai.tools.synthesis import synthesize_answer

    effective_config = MeetingQAConfig() if config is None else config
    web_enabled = effective_config.enable_web_search and web_search_provider is not None

    async def run_retrieval(state: MeetingQAInternalState) -> dict[str, object]:
        # Semantic search restricted to the requested meeting.
        segments = await retrieve_segments(
            query=state["question"],
            embedder=embedder,
            segment_repo=segment_repo,
            meeting_id=state["meeting_id"],
            top_k=state["top_k"],
        )
        return {"retrieved_segments": segments}

    async def run_verification(state: MeetingQAInternalState) -> dict[str, object]:
        # Verification only checks that retrieval returned something.
        return {"verification_passed": len(state["retrieved_segments"]) > 0}

    def decide_web_search(state: MeetingQAInternalState) -> dict[str, object]:
        if web_search_provider is None or not effective_config.enable_web_search:
            return {"web_search_approved": False}
        if effective_config.require_web_approval:
            search_query = derive_search_query(state["question"])
            decision = check_web_search_approval(search_query, require_approval=True)
            return {"web_search_approved": decision}
        return {"web_search_approved": True}

    async def run_web_search(state: MeetingQAInternalState) -> dict[str, object]:
        is_approved = state.get("web_search_approved", False)
        if web_search_provider is None or not is_approved:
            return {"web_context": ""}

        search_query = derive_search_query(state["question"])
        # Approval was already settled by the previous node.
        search_config = WebSearchConfig(enabled=True, require_approval=False)
        response = await execute_web_search(search_query, web_search_provider, search_config)
        return {"web_context": format_results_for_context(response.results)}

    async def run_synthesis(state: MeetingQAInternalState) -> dict[str, object]:
        if not state["verification_passed"]:
            # Nothing retrieved: skip the LLM and return the canned answer.
            return {
                "answer": NO_INFORMATION_ANSWER,
                "citations": [],
                "suggested_annotations": [],
            }

        result = await synthesize_answer(
            question=state["question"],
            segments=state["retrieved_segments"],
            llm=llm,
        )

        # Keep only the retrieved segments the answer cited.
        citations = []
        for seg in state["retrieved_segments"]:
            if seg.segment_id not in result.cited_segment_ids:
                continue
            citations.append(
                SegmentCitation(
                    meeting_id=state["meeting_id"],
                    segment_id=seg.segment_id,
                    start_time=seg.start_time,
                    end_time=seg.end_time,
                    text=seg.text,
                    score=seg.score,
                )
            )

        suggested_annotations = extract_annotations_from_answer(
            answer=result.answer,
            cited_segment_ids=tuple(result.cited_segment_ids),
        )

        return {
            "answer": result.answer,
            "citations": citations,
            "suggested_annotations": suggested_annotations,
        }

    builder: StateGraph[MeetingQAInternalState] = StateGraph(MeetingQAInternalState)
    builder.add_node("retrieve", run_retrieval)
    builder.add_node("verify", run_verification)
    builder.add_node("synthesize", run_synthesis)

    # Linear pipeline; the web-search stages are spliced in before synthesis.
    stages = ["retrieve", "verify"]
    if web_enabled:
        builder.add_node("web_search_approval", decide_web_search)
        builder.add_node("web_search", run_web_search)
        stages += ["web_search_approval", "web_search"]
    stages.append("synthesize")

    builder.add_edge(START, stages[0])
    for upstream, downstream in zip(stages, stages[1:]):
        builder.add_edge(upstream, downstream)
    builder.add_edge(stages[-1], END)

    # compile is looked up dynamically, presumably to sidestep stub typing of
    # StateGraph.compile (confirm before simplifying).
    compiled: CompiledStateGraph[MeetingQAInternalState] = getattr(builder, "compile")(
        checkpointer=checkpointer
    )
    return compiled
|
||||
103
src/noteflow/infrastructure/ai/graphs/summarization.py
Normal file
103
src/noteflow/infrastructure/ai/graphs/summarization.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""LangGraph wrapper for existing SummarizationService.
|
||||
|
||||
Demonstrates the LangGraph integration pattern by wrapping the existing
|
||||
summarization infrastructure in a StateGraph.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Final, TypedDict, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from langgraph.graph import CompiledStateGraph
|
||||
|
||||
from noteflow.application.services.summarization import SummarizationService
|
||||
from noteflow.domain.entities import Segment
|
||||
from noteflow.domain.value_objects import MeetingId
|
||||
|
||||
# Identifier and version of the summarization graph.
SUMMARIZATION_GRAPH_NAME: Final[str] = "summarization"
SUMMARIZATION_GRAPH_VERSION: Final[int] = 1
|
||||
|
||||
|
||||
class SummarizationInputState(TypedDict):
    """Keys the caller supplies to the summarization graph."""

    # Meeting being summarized (cast to MeetingId inside the graph).
    meeting_id: UUID
    # Transcript segments forwarded to the summarization service.
    segments: Sequence[Segment]
|
||||
|
||||
|
||||
class SummarizationOutputState(TypedDict):
    """Keys the summarization graph writes back to state."""

    # Executive summary text.
    summary_text: str
    # One plain dict per key point.
    key_points: list[dict[str, object]]
    # One plain dict per action item.
    action_items: list[dict[str, object]]
    # Provider reported by the summarization service result.
    provider_used: str
    # Usage metrics copied from the summary; may be None.
    tokens_used: int | None
    latency_ms: float | None
|
||||
|
||||
|
||||
class SummarizationState(SummarizationInputState, SummarizationOutputState):
    """Complete graph state: the union of the input and output keys."""
|
||||
|
||||
|
||||
def build_summarization_graph(
    summarization_service: SummarizationService,
) -> CompiledStateGraph[SummarizationState]:
    """Wrap the existing SummarizationService in a one-node LangGraph graph.

    The graph is ``START -> summarize -> END``. Its single node delegates to
    ``summarization_service.summarize`` and flattens the result into plain
    dicts, which leaves room for later expansion (checkpointing, branching,
    human-in-the-loop) without changing the service.

    Args:
        summarization_service: The existing summarization service to wrap.

    Returns:
        A compiled StateGraph that can be invoked with meeting_id and segments.
    """
    from langgraph.graph import END, START, StateGraph

    async def summarize_node(state: SummarizationState) -> dict[str, object]:
        outcome = await summarization_service.summarize(
            # MeetingId is a NewType of UUID, so the cast is safe.
            meeting_id=cast("MeetingId", state["meeting_id"]),
            segments=state["segments"],
        )
        summary = outcome.summary

        key_points: list[dict[str, object]] = []
        for point in summary.key_points:
            key_points.append(
                {
                    "text": point.text,
                    "segment_ids": point.segment_ids,
                    "start_time": point.start_time,
                    "end_time": point.end_time,
                }
            )

        action_items: list[dict[str, object]] = []
        for item in summary.action_items:
            action_items.append(
                {
                    "text": item.text,
                    "segment_ids": item.segment_ids,
                    "assignee": item.assignee,
                    "start_time": item.start_time,
                    "end_time": item.end_time,
                }
            )

        return {
            "summary_text": summary.executive_summary,
            "key_points": key_points,
            "action_items": action_items,
            "provider_used": outcome.provider_used,
            "tokens_used": summary.tokens_used,
            "latency_ms": summary.latency_ms,
        }

    graph = StateGraph(SummarizationState)
    graph.add_node("summarize", summarize_node)
    graph.add_edge(START, "summarize")
    graph.add_edge("summarize", END)
    return graph.compile()
|
||||
198
src/noteflow/infrastructure/ai/graphs/workspace_qa.py
Normal file
198
src/noteflow/infrastructure/ai/graphs/workspace_qa.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Workspace Q&A graph for cross-meeting question answering with citations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Final, TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.graph import CompiledStateGraph
|
||||
|
||||
from noteflow.domain.ai.citations import SegmentCitation
|
||||
from noteflow.infrastructure.ai.nodes.annotation_suggester import SuggestedAnnotation
|
||||
from noteflow.infrastructure.ai.nodes.web_search import WebSearchProvider
|
||||
from noteflow.infrastructure.ai.tools.retrieval import (
|
||||
EmbedderProtocol,
|
||||
RetrievalResult,
|
||||
WorkspaceSegmentSearchProtocol,
|
||||
)
|
||||
from noteflow.infrastructure.ai.tools.synthesis import LLMProtocol
|
||||
|
||||
# Identifier and version of the cross-meeting (workspace) Q&A graph.
WORKSPACE_QA_GRAPH_NAME: Final[str] = "workspace_qa"
WORKSPACE_QA_GRAPH_VERSION: Final[int] = 2
# Answer returned when retrieval yields no segments in the workspace.
NO_INFORMATION_ANSWER: Final[str] = "I couldn't find relevant information across your meetings."
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WorkspaceQAConfig:
    """Immutable feature switches for the workspace Q&A graph."""

    # Register the web-search nodes (also requires a provider at build time).
    enable_web_search: bool = False
    # Ask for approval through an interrupt before searching the web.
    require_web_approval: bool = True
    # NOTE(review): not read by build_workspace_qa_graph in this module.
    require_annotation_approval: bool = False
|
||||
|
||||
|
||||
class WorkspaceQAInputState(TypedDict):
    """Keys the caller supplies when invoking the workspace Q&A graph."""

    # User question; also used as the retrieval query.
    question: str
    # Workspace whose meetings are searched.
    workspace_id: UUID
    # Project value forwarded to retrieval; may be None.
    project_id: UUID | None
    # Number of segments requested from retrieval.
    top_k: int
|
||||
|
||||
|
||||
class WorkspaceQAOutputState(TypedDict):
    """Keys produced by the synthesize step of the workspace Q&A graph."""

    # Final answer text.
    answer: str
    # Citations for the retrieved segments the answer referenced.
    citations: list[SegmentCitation]
    # Annotations extracted from the answer.
    suggested_annotations: list[SuggestedAnnotation]
|
||||
|
||||
|
||||
class WorkspaceQAInternalState(WorkspaceQAInputState, WorkspaceQAOutputState):
    """Full graph state: input and output keys plus intermediate values."""

    # Segments returned by the retrieve node.
    retrieved_segments: list[RetrievalResult]
    # True when retrieval returned at least one segment.
    verification_passed: bool
    # Outcome of the web-search approval node.
    web_search_approved: bool
    # Formatted web results, or "" when no search ran.
    web_context: str
|
||||
|
||||
|
||||
def build_workspace_qa_graph(
    embedder: EmbedderProtocol,
    segment_repo: WorkspaceSegmentSearchProtocol,
    llm: LLMProtocol,
    *,
    web_search_provider: WebSearchProvider | None = None,
    config: WorkspaceQAConfig | None = None,
    checkpointer: BaseCheckpointSaver[str] | None = None,
) -> CompiledStateGraph[WorkspaceQAInternalState]:
    """Assemble and compile the cross-meeting (workspace) Q&A graph.

    Without web search the pipeline is ``retrieve -> verify -> synthesize``.
    With web search it is ``retrieve -> verify -> web_search_approval ->
    web_search -> synthesize``. The web-search nodes are registered only when
    ``config.enable_web_search`` is set and ``web_search_provider`` is given.

    NOTE(review): ``web_context`` is written to state by the web-search node,
    but ``synthesize`` does not read it here; confirm whether that is intended.

    Args:
        embedder: Protocol for generating text embeddings.
        segment_repo: Protocol for workspace-scoped semantic segment search.
        llm: Protocol for LLM text completion.
        web_search_provider: Optional web search provider for augmentation.
        config: Graph configuration; defaults to ``WorkspaceQAConfig()``.
        checkpointer: Optional checkpointer forwarded to ``compile``.

    Returns:
        Compiled graph that accepts question/workspace_id and returns answer/citations.
    """
    from langgraph.graph import END, START, StateGraph

    from noteflow.domain.ai.citations import SegmentCitation
    from noteflow.infrastructure.ai.interrupts import check_web_search_approval
    from noteflow.infrastructure.ai.nodes.annotation_suggester import (
        extract_annotations_from_answer,
    )
    from noteflow.infrastructure.ai.nodes.web_search import (
        WebSearchConfig,
        derive_search_query,
        execute_web_search,
        format_results_for_context,
    )
    from noteflow.infrastructure.ai.tools.retrieval import retrieve_segments_workspace
    from noteflow.infrastructure.ai.tools.synthesis import synthesize_answer

    effective_config = WorkspaceQAConfig() if config is None else config
    web_enabled = effective_config.enable_web_search and web_search_provider is not None

    async def run_retrieval(state: WorkspaceQAInternalState) -> dict[str, object]:
        # Semantic search across the workspace, with the project value passed through.
        segments = await retrieve_segments_workspace(
            query=state["question"],
            embedder=embedder,
            segment_repo=segment_repo,
            workspace_id=state["workspace_id"],
            project_id=state["project_id"],
            top_k=state["top_k"],
        )
        return {"retrieved_segments": segments}

    async def run_verification(state: WorkspaceQAInternalState) -> dict[str, object]:
        # Verification only checks that retrieval returned something.
        return {"verification_passed": len(state["retrieved_segments"]) > 0}

    def decide_web_search(state: WorkspaceQAInternalState) -> dict[str, object]:
        if web_search_provider is None or not effective_config.enable_web_search:
            return {"web_search_approved": False}
        if effective_config.require_web_approval:
            search_query = derive_search_query(state["question"])
            decision = check_web_search_approval(search_query, require_approval=True)
            return {"web_search_approved": decision}
        return {"web_search_approved": True}

    async def run_web_search(state: WorkspaceQAInternalState) -> dict[str, object]:
        is_approved = state.get("web_search_approved", False)
        if web_search_provider is None or not is_approved:
            return {"web_context": ""}

        search_query = derive_search_query(state["question"])
        # Approval was already settled by the previous node.
        search_config = WebSearchConfig(enabled=True, require_approval=False)
        response = await execute_web_search(search_query, web_search_provider, search_config)
        return {"web_context": format_results_for_context(response.results)}

    async def run_synthesis(state: WorkspaceQAInternalState) -> dict[str, object]:
        if not state["verification_passed"]:
            # Nothing retrieved: skip the LLM and return the canned answer.
            return {
                "answer": NO_INFORMATION_ANSWER,
                "citations": [],
                "suggested_annotations": [],
            }

        result = await synthesize_answer(
            question=state["question"],
            segments=state["retrieved_segments"],
            llm=llm,
        )

        # Keep only the cited segments; each one carries its own meeting_id.
        citations = []
        for seg in state["retrieved_segments"]:
            if seg.segment_id not in result.cited_segment_ids:
                continue
            citations.append(
                SegmentCitation(
                    meeting_id=seg.meeting_id,
                    segment_id=seg.segment_id,
                    start_time=seg.start_time,
                    end_time=seg.end_time,
                    text=seg.text,
                    score=seg.score,
                )
            )

        suggested_annotations = extract_annotations_from_answer(
            answer=result.answer,
            cited_segment_ids=tuple(result.cited_segment_ids),
        )

        return {
            "answer": result.answer,
            "citations": citations,
            "suggested_annotations": suggested_annotations,
        }

    builder: StateGraph[WorkspaceQAInternalState] = StateGraph(WorkspaceQAInternalState)
    builder.add_node("retrieve", run_retrieval)
    builder.add_node("verify", run_verification)
    builder.add_node("synthesize", run_synthesis)

    # Linear pipeline; the web-search stages are spliced in before synthesis.
    stages = ["retrieve", "verify"]
    if web_enabled:
        builder.add_node("web_search_approval", decide_web_search)
        builder.add_node("web_search", run_web_search)
        stages += ["web_search_approval", "web_search"]
    stages.append("synthesize")

    builder.add_edge(START, stages[0])
    for upstream, downstream in zip(stages, stages[1:]):
        builder.add_edge(upstream, downstream)
    builder.add_edge(stages[-1], END)

    # compile is looked up dynamically, presumably to sidestep stub typing of
    # StateGraph.compile (confirm before simplifying).
    compiled: CompiledStateGraph[WorkspaceQAInternalState] = getattr(builder, "compile")(
        checkpointer=checkpointer
    )
    return compiled
|
||||
312
src/noteflow/infrastructure/ai/guardrails.py
Normal file
312
src/noteflow/infrastructure/ai/guardrails.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""Content guardrails for AI input validation and output filtering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Final
|
||||
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Input validation limits (character counts).
DEFAULT_MIN_INPUT_LENGTH: Final[int] = 3
DEFAULT_MAX_INPUT_LENGTH: Final[int] = 4000
DEFAULT_MAX_OUTPUT_LENGTH: Final[int] = 10000

# PII patterns (simplified - production would use more comprehensive detection)
# The TLD class is [A-Za-z]; a "|" inside a character class is a literal pipe,
# not alternation, so it must not appear there.
EMAIL_PATTERN: Final[re.Pattern[str]] = re.compile(
    r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
)
PHONE_PATTERN: Final[re.Pattern[str]] = re.compile(
    r"\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b"
)
SSN_PATTERN: Final[re.Pattern[str]] = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")
CREDIT_CARD_PATTERN: Final[re.Pattern[str]] = re.compile(r"\b(?:\d{4}[-\s]?){3}\d{4}\b")

# (label, pattern) pairs, applied in this order for detection and redaction.
PII_PATTERNS: Final[tuple[tuple[str, re.Pattern[str]], ...]] = (
    ("email", EMAIL_PATTERN),
    ("phone", PHONE_PATTERN),
    ("ssn", SSN_PATTERN),
    ("credit_card", CREDIT_CARD_PATTERN),
)

# Redaction placeholder substituted for every PII match.
PII_REDACTION: Final[str] = "[REDACTED]"
|
||||
|
||||
|
||||
class GuardrailViolation(str, Enum):
    """Reason codes attached to a blocked or filtered guardrail result."""

    # Length limits.
    INPUT_TOO_SHORT = "input_too_short"
    INPUT_TOO_LONG = "input_too_long"
    OUTPUT_TOO_LONG = "output_too_long"
    # Content findings.
    CONTAINS_PII = "contains_pii"
    BLOCKED_CONTENT = "blocked_content"
    INJECTION_ATTEMPT = "injection_attempt"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GuardrailResult:
    """Immutable outcome of a guardrail check.

    Use the ``ok``, ``blocked`` and ``filtered`` factories rather than the
    constructor so the field combinations stay consistent.
    """

    # Whether the content may proceed.
    allowed: bool
    # Violation category, when one was found.
    violation: GuardrailViolation | None = None
    # Human-readable explanation of the violation.
    reason: str | None = None
    # Content to use downstream (original, redacted or truncated).
    filtered_content: str | None = None

    @staticmethod
    def ok(content: str | None = None) -> GuardrailResult:
        """Build an allowed result carrying ``content`` unchanged."""
        return GuardrailResult(
            allowed=True,
            violation=None,
            reason=None,
            filtered_content=content,
        )

    @staticmethod
    def blocked(
        violation: GuardrailViolation,
        reason: str,
    ) -> GuardrailResult:
        """Build a rejected result; no content is carried."""
        return GuardrailResult(
            allowed=False,
            violation=violation,
            reason=reason,
            filtered_content=None,
        )

    @staticmethod
    def filtered(
        content: str,
        violation: GuardrailViolation,
        reason: str,
    ) -> GuardrailResult:
        """Build an allowed result whose content was modified, recording why."""
        return GuardrailResult(
            allowed=True,
            violation=violation,
            reason=reason,
            filtered_content=content,
        )
|
||||
|
||||
|
||||
@dataclass
class GuardrailRules:
    """Tunable settings consumed by ``check_input`` and ``filter_output``.

    Attributes:
        min_input_length: Shortest input accepted.
        max_input_length: Longest input accepted.
        max_output_length: Longest output returned before truncation.
        block_pii: Reject input that contains PII.
        redact_pii: Replace PII with a placeholder instead of rejecting.
        blocked_phrases: Phrases that cause content to be rejected outright.
        detect_injection: Scan input for prompt-injection patterns.
    """

    # Length limits.
    min_input_length: int = DEFAULT_MIN_INPUT_LENGTH
    max_input_length: int = DEFAULT_MAX_INPUT_LENGTH
    max_output_length: int = DEFAULT_MAX_OUTPUT_LENGTH
    # PII handling; check_input gives block_pii precedence when both are set.
    block_pii: bool = False
    redact_pii: bool = True
    # Content screening.
    blocked_phrases: frozenset[str] = field(default_factory=frozenset)
    detect_injection: bool = True
|
||||
|
||||
|
||||
# Common injection patterns, compiled case-insensitively from the raw sources
# below: ignore/disregard earlier instructions, role reassignment, memory
# wipes, and injected "new instructions:" headers.
INJECTION_PATTERNS: Final[tuple[re.Pattern[str], ...]] = tuple(
    re.compile(source, re.IGNORECASE)
    for source in (
        r"ignore\s+(?:all\s+)?(?:previous|above)\s+instructions",
        r"disregard\s+(?:all\s+)?(?:previous|prior)\s+",
        r"you\s+are\s+now\s+(?:a|an)\s+",
        r"forget\s+(?:everything|all)\s+(?:you|and)\s+",
        r"new\s+(?:system\s+)?instructions?:",
    )
)
|
||||
|
||||
|
||||
def _check_length(
    text: str,
    min_length: int,
    max_length: int,
    is_input: bool,
) -> GuardrailResult | None:
    """Return a blocking result when ``text`` violates a length limit.

    Inputs are checked against both bounds; outputs only against the upper
    bound. ``None`` means the length is acceptable.
    """
    size = len(text)

    if not is_input:
        if size > max_length:
            return GuardrailResult.blocked(
                GuardrailViolation.OUTPUT_TOO_LONG,
                f"Output exceeds {max_length} characters",
            )
        return None

    if size < min_length:
        return GuardrailResult.blocked(
            GuardrailViolation.INPUT_TOO_SHORT,
            f"Input must be at least {min_length} characters",
        )
    if size > max_length:
        return GuardrailResult.blocked(
            GuardrailViolation.INPUT_TOO_LONG,
            f"Input must be at most {max_length} characters",
        )
    return None
|
||||
|
||||
|
||||
def _check_blocked_phrases(
    text: str,
    blocked_phrases: frozenset[str],
) -> GuardrailResult | None:
    """Return a blocking result if ``text`` contains any blocked phrase.

    Matching is a case-insensitive substring test. Empty or whitespace-only
    phrases are ignored: an empty string is a substring of every text, so a
    stray blank entry would otherwise block all content.

    Args:
        text: Content to scan.
        blocked_phrases: Phrases that must not appear.

    Returns:
        A BLOCKED_CONTENT result on the first hit, otherwise ``None``.
    """
    text_lower = text.lower()
    for phrase in blocked_phrases:
        if not phrase.strip():
            # Blank entry (e.g. from a trailing separator in config): skip.
            continue
        if phrase.lower() in text_lower:
            # Log only a prefix of the phrase to keep log lines short.
            logger.warning("blocked_phrase_detected", phrase=phrase[:20])
            return GuardrailResult.blocked(
                GuardrailViolation.BLOCKED_CONTENT,
                "Content contains blocked phrase",
            )
    return None
|
||||
|
||||
|
||||
def _check_injection(text: str) -> GuardrailResult | None:
    """Return a blocking result when ``text`` matches an injection pattern."""
    matched = any(pattern.search(text) for pattern in INJECTION_PATTERNS)
    if not matched:
        return None

    logger.warning("injection_attempt_detected")
    return GuardrailResult.blocked(
        GuardrailViolation.INJECTION_ATTEMPT,
        "Potential prompt injection detected",
    )
|
||||
|
||||
|
||||
def _detect_pii(text: str) -> list[tuple[str, str]]:
    """Collect every PII match found in ``text``.

    Returns:
        ``(pii_type, matched_text)`` tuples, grouped in ``PII_PATTERNS`` order
        and, within a type, in order of appearance.
    """
    return [
        (label, hit.group())
        for label, regex in PII_PATTERNS
        for hit in regex.finditer(text)
    ]
|
||||
|
||||
|
||||
def _redact_pii(text: str) -> str:
    """Return ``text`` with every PII match replaced by the redaction marker."""
    scrubbed = text
    # Patterns run sequentially, each on the output of the previous one.
    for _label, regex in PII_PATTERNS:
        scrubbed = regex.sub(PII_REDACTION, scrubbed)
    return scrubbed
|
||||
|
||||
|
||||
async def check_input(text: str, rules: GuardrailRules) -> GuardrailResult:
    """Validate input text against guardrail rules.

    Checks run in order and the first violation wins: length, blocked
    phrases, prompt injection (if enabled), then PII. When PII is found,
    ``block_pii`` takes precedence over ``redact_pii``.

    Args:
        text: Input text to validate.
        rules: Guardrail rules to apply.

    Returns:
        GuardrailResult that is blocked, filtered (PII redacted), or ok.
    """
    verdict = _check_length(
        text,
        rules.min_input_length,
        rules.max_input_length,
        is_input=True,
    )
    if verdict is None:
        verdict = _check_blocked_phrases(text, rules.blocked_phrases)
    if verdict is None and rules.detect_injection:
        verdict = _check_injection(text)
    if verdict is not None:
        return verdict

    # PII handling only runs when one of the two PII switches is on.
    if not (rules.block_pii or rules.redact_pii):
        return GuardrailResult.ok(text)

    pii_findings = _detect_pii(text)
    if not pii_findings:
        return GuardrailResult.ok(text)

    pii_types = [kind for kind, _match in pii_findings]
    logger.info("pii_detected_in_input", pii_types=pii_types)

    if rules.block_pii:
        return GuardrailResult.blocked(
            GuardrailViolation.CONTAINS_PII,
            f"Input contains PII: {', '.join(pii_types)}",
        )

    # Redact instead of block.
    return GuardrailResult.filtered(
        _redact_pii(text),
        GuardrailViolation.CONTAINS_PII,
        f"PII redacted: {', '.join(pii_types)}",
    )
|
||||
|
||||
|
||||
async def filter_output(text: str, rules: GuardrailRules) -> GuardrailResult:
    """Filter output text, truncating and redacting sensitive content.

    Over-long output is truncated rather than blocked, and the truncated text
    still goes through the blocked-phrase check and PII redaction. Returning
    straight after truncation would let a long response bypass both.

    NOTE: truncation happens before PII detection, so a PII value cut in half
    at the boundary is not detected as PII.

    Args:
        text: Output text to filter.
        rules: Guardrail rules to apply.

    Returns:
        ``ok(text)`` when nothing changed; a blocked result when a blocked
        phrase is present; otherwise a filtered result. A truncated result
        keeps ``OUTPUT_TOO_LONG`` as its violation even if PII was also
        redacted, and the reason lists every modification applied.
    """
    content = text
    violation: GuardrailViolation | None = None
    notes: list[str] = []

    # Length check: truncate instead of blocking for output.
    length_result = _check_length(
        text,
        min_length=0,  # No minimum for output
        max_length=rules.max_output_length,
        is_input=False,
    )
    if length_result is not None:
        content = text[: rules.max_output_length]
        violation = GuardrailViolation.OUTPUT_TOO_LONG
        notes.append(f"Output truncated to {rules.max_output_length} characters")

    # Blocked phrases in the text that would actually be emitted.
    phrase_result = _check_blocked_phrases(content, rules.blocked_phrases)
    if phrase_result is not None:
        return phrase_result

    # PII redaction in output (always redact, never block output).
    if rules.redact_pii:
        pii_findings = _detect_pii(content)
        if pii_findings:
            pii_types = [f[0] for f in pii_findings]
            logger.info("pii_detected_in_output", pii_types=pii_types)
            content = _redact_pii(content)
            if violation is None:
                violation = GuardrailViolation.CONTAINS_PII
            notes.append(f"PII redacted: {', '.join(pii_types)}")

    if violation is None:
        return GuardrailResult.ok(text)
    return GuardrailResult.filtered(content, violation, "; ".join(notes))
|
||||
|
||||
|
||||
def create_default_rules() -> GuardrailRules:
    """Return a fresh ``GuardrailRules`` with every field at its default."""
    default_rules = GuardrailRules()
    return default_rules
|
||||
|
||||
|
||||
def create_strict_rules() -> GuardrailRules:
    """Return rules that reject PII outright and halve the input length cap."""
    strict_overrides = {
        "max_input_length": 2000,
        "detect_injection": True,
        # Block PII rather than redacting it.
        "block_pii": True,
        "redact_pii": False,
    }
    return GuardrailRules(**strict_overrides)
|
||||
231
src/noteflow/infrastructure/ai/interrupts.py
Normal file
231
src/noteflow/infrastructure/ai/interrupts.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""Infrastructure utilities for LangGraph human-in-the-loop interrupts.
|
||||
|
||||
Wraps LangGraph's interrupt() and Command APIs for consistent usage across graphs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Final
|
||||
from uuid import uuid4
|
||||
|
||||
from langgraph.types import Command, interrupt
|
||||
|
||||
from noteflow.domain.ai.interrupts import (
|
||||
InterruptAction,
|
||||
InterruptResponse,
|
||||
create_annotation_interrupt,
|
||||
create_web_search_interrupt,
|
||||
)
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.infrastructure.ai.nodes.annotation_suggester import SuggestedAnnotation
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# NOTE(review): neither constant is referenced by the functions in this module;
# presumably they are shared with callers that build resume payloads - confirm.
INTERRUPT_RESPONSE_KEY: Final[str] = "response"
INTERRUPT_APPROVED_VALUE: Final[str] = "approved"
|
||||
|
||||
|
||||
def request_web_search_approval(
    query: str,
    *,
    allow_modify: bool = False,
) -> InterruptResponse:
    """Pause the graph and ask the user whether a web search may run.

    A fresh UUID string is used as the request id. The interrupt request is
    built, the event is logged with the first 50 characters of the query, and
    the serialized request is passed to LangGraph's ``interrupt()``. The value
    that ``interrupt()`` returns is converted by ``_parse_interrupt_response``.

    Args:
        query: The search query to be executed.
        allow_modify: Whether the user can modify the query.

    Returns:
        InterruptResponse with user's decision.
    """
    approval_id = str(uuid4())
    pending_request = create_web_search_interrupt(
        query=query,
        request_id=approval_id,
        allow_modify=allow_modify,
    )

    logger.info(
        "interrupt_web_search_requested",
        request_id=approval_id,
        query_preview=query[:50],
    )

    user_reply = interrupt(pending_request.to_dict())
    return _parse_interrupt_response(user_reply, approval_id)
|
||||
|
||||
|
||||
def request_annotation_approval(
    annotations: list[SuggestedAnnotation],
) -> InterruptResponse:
    """Pause the graph and ask the user to approve suggested annotations.

    Each annotation is serialized with ``to_dict()`` and bundled into an
    interrupt request tagged with a fresh UUID string. The event is logged
    with the number of annotations, then the serialized request is handed to
    LangGraph's ``interrupt()`` and its return value is parsed.

    Args:
        annotations: List of suggested annotations to approve.

    Returns:
        InterruptResponse with user's decision.
    """
    approval_id = str(uuid4())
    pending_request = create_annotation_interrupt(
        annotations=[suggestion.to_dict() for suggestion in annotations],
        request_id=approval_id,
    )

    logger.info(
        "interrupt_annotation_requested",
        request_id=approval_id,
        annotation_count=len(annotations),
    )

    user_reply = interrupt(pending_request.to_dict())
    return _parse_interrupt_response(user_reply, approval_id)
|
||||
|
||||
|
||||
def _parse_interrupt_response(
    response_data: object,
    request_id: str,
) -> InterruptResponse:
    """Convert the raw value returned by ``interrupt()`` into a domain response.

    Three shapes are recognised:

    * ``str`` - the whole string is mapped to an action.
    * ``dict`` - ``"action"`` (default ``"reject"``) is mapped to an action,
      ``"modified_value"`` is kept only when it is itself a dict, and
      ``"user_message"`` is stringified when present.
    * anything else - a warning is logged and the response is a rejection.

    Args:
        response_data: Raw response from LangGraph interrupt.
        request_id: ID of the original request.

    Returns:
        Parsed InterruptResponse.
    """
    if isinstance(response_data, str):
        return InterruptResponse(
            action=_string_to_action(response_data),
            request_id=request_id,
        )

    if not isinstance(response_data, dict):
        # Unknown payload shape: record its type and fail closed.
        logger.warning(
            "interrupt_response_unknown_format",
            request_id=request_id,
            response_type=type(response_data).__name__,
        )
        return InterruptResponse(action=InterruptAction.REJECT, request_id=request_id)

    parsed_action = _string_to_action(str(response_data.get("action", "reject")))

    raw_modified = response_data.get("modified_value")
    modified_value = raw_modified if isinstance(raw_modified, dict) else None

    raw_message = response_data.get("user_message")
    user_message = None if raw_message is None else str(raw_message)

    return InterruptResponse(
        action=parsed_action,
        request_id=request_id,
        modified_value=modified_value,
        user_message=user_message,
    )
|
||||
|
||||
|
||||
def _string_to_action(value: str) -> InterruptAction:
    """Map a free-form reply string onto an ``InterruptAction``.

    The value is lower-cased and stripped first. ``approve``/``yes``/
    ``approved``/``accept`` mean APPROVE, ``modify``/``edit``/``change`` mean
    MODIFY, and every other string means REJECT.
    """
    token = value.lower().strip()
    approve_words = ("approve", "yes", "approved", "accept")
    modify_words = ("modify", "edit", "change")

    if token in approve_words:
        return InterruptAction.APPROVE
    return InterruptAction.MODIFY if token in modify_words else InterruptAction.REJECT
|
||||
|
||||
|
||||
def create_resume_command(response: InterruptResponse) -> Command[None]:
    """Create a LangGraph Command to resume execution with user response.

    Args:
        response: User's interrupt response.

    Returns:
        Command to resume graph execution; its ``resume`` payload is
        ``response.to_dict()``.
    """
    return Command(resume=response.to_dict())
|
||||
|
||||
|
||||
class InterruptHandler:
    """Handles interrupt requests and responses for a graph execution.

    Thin object wrapper around the module-level ``request_*`` helpers that
    also stores whether web searches require user approval.
    """

    # Policy flag supplied at construction time.
    _require_web_approval: bool

    def __init__(self, require_web_approval: bool = True) -> None:
        """Store the web-search approval policy.

        Args:
            require_web_approval: Value later returned by
                ``should_interrupt_for_web_search``.
        """
        self._require_web_approval = require_web_approval

    def should_interrupt_for_web_search(self) -> bool:
        """Return the ``require_web_approval`` flag given at construction."""
        return self._require_web_approval

    def request_web_search(self, query: str) -> InterruptResponse:
        """Delegate to ``request_web_search_approval`` with its default options."""
        return request_web_search_approval(query)

    def request_annotation_approval(
        self,
        annotations: list[SuggestedAnnotation],
    ) -> InterruptResponse:
        """Delegate to the module-level ``request_annotation_approval``."""
        # The bare name resolves to the module-level function, not this method.
        return request_annotation_approval(annotations)
|
||||
|
||||
|
||||
def check_web_search_approval(
    query: str,
    require_approval: bool,
) -> bool:
    """Decide whether a web search may proceed, interrupting if required.

    Args:
        query: Search query to execute.
        require_approval: Whether to interrupt for user approval.

    Returns:
        True if search should proceed, False if rejected. When approval is
        not required the answer is always True and no interrupt is raised.
    """
    if require_approval:
        return request_web_search_approval(query).is_approved
    return True
|
||||
|
||||
|
||||
def check_annotation_approval(
    annotations: list[SuggestedAnnotation],
) -> tuple[bool, list[SuggestedAnnotation]]:
    """Ask the user whether suggested annotations should be applied.

    An empty input short-circuits to ``(False, [])`` without interrupting.
    A rejected response also yields ``(False, [])``. A modified response
    whose ``modified_value["annotations"]`` is a list yields ``True`` plus
    the annotations rebuilt from the dict entries of that list (non-dict
    entries are dropped). Otherwise the original annotations are returned
    together with the response's ``is_approved`` flag.

    Args:
        annotations: Suggested annotations to approve.

    Returns:
        Tuple of (should_apply, possibly_modified_annotations).
    """
    if not annotations:
        return False, []

    decision = request_annotation_approval(annotations)
    if decision.is_rejected:
        return False, []

    if decision.is_modified and decision.modified_value:
        raw_entries = decision.modified_value.get("annotations", [])
        if isinstance(raw_entries, list):
            # Imported lazily: at module level the name is only imported
            # under TYPE_CHECKING.
            from noteflow.infrastructure.ai.nodes.annotation_suggester import (
                SuggestedAnnotation,
            )

            rebuilt: list[SuggestedAnnotation] = [
                SuggestedAnnotation.from_dict({str(key): value for key, value in entry.items()})
                for entry in raw_entries
                if isinstance(entry, dict)
            ]
            return True, rebuilt

    return decision.is_approved, annotations
|
||||
39
src/noteflow/infrastructure/ai/nodes/__init__.py
Normal file
39
src/noteflow/infrastructure/ai/nodes/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""LangGraph node implementations."""
|
||||
|
||||
from noteflow.infrastructure.ai.nodes.annotation_suggester import (
|
||||
SuggestedAnnotation,
|
||||
SuggestedAnnotationType,
|
||||
extract_annotations_from_answer,
|
||||
)
|
||||
from noteflow.infrastructure.ai.nodes.verification import (
|
||||
VerificationResult,
|
||||
verify_citations,
|
||||
)
|
||||
from noteflow.infrastructure.ai.nodes.web_search import (
|
||||
DisabledWebSearchProvider,
|
||||
WebSearchConfig,
|
||||
WebSearchProvider,
|
||||
WebSearchResponse,
|
||||
WebSearchResult,
|
||||
derive_search_query,
|
||||
execute_web_search,
|
||||
format_results_for_context,
|
||||
merge_contexts,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DisabledWebSearchProvider",
|
||||
"SuggestedAnnotation",
|
||||
"SuggestedAnnotationType",
|
||||
"VerificationResult",
|
||||
"WebSearchConfig",
|
||||
"WebSearchProvider",
|
||||
"WebSearchResponse",
|
||||
"WebSearchResult",
|
||||
"derive_search_query",
|
||||
"execute_web_search",
|
||||
"extract_annotations_from_answer",
|
||||
"format_results_for_context",
|
||||
"merge_contexts",
|
||||
"verify_citations",
|
||||
]
|
||||
121
src/noteflow/infrastructure/ai/nodes/annotation_suggester.py
Normal file
121
src/noteflow/infrastructure/ai/nodes/annotation_suggester.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Annotation suggester for extracting action items and decisions from answers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Final
|
||||
|
||||
|
||||
class SuggestedAnnotationType(str, Enum):
    """Kinds of annotation that can be suggested.

    Members are also ``str`` instances; their values are the strings stored
    under the ``"type"`` key by ``SuggestedAnnotation.to_dict``.
    """

    ACTION_ITEM = "action_item"
    DECISION = "decision"
    NOTE = "note"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SuggestedAnnotation:
    """Immutable annotation proposal derived from an answer.

    Attributes:
        text: The annotation text.
        annotation_type: Category of the annotation.
        segment_ids: IDs of the segments the annotation is attached to.
        confidence: Confidence score; defaults to 0.8.
    """

    text: str
    annotation_type: SuggestedAnnotationType
    segment_ids: tuple[int, ...]
    confidence: float = 0.8

    def to_dict(self) -> dict[str, object]:
        """Serialize to a plain dict.

        The annotation type is written under ``"type"`` as the enum value and
        ``segment_ids`` is converted to a list.
        """
        payload: dict[str, object] = {
            "text": self.text,
            "type": self.annotation_type.value,
            "segment_ids": list(self.segment_ids),
            "confidence": self.confidence,
        }
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, object]) -> SuggestedAnnotation:
        """Rebuild an annotation from a dict, tolerating missing or odd values.

        Missing ``text`` becomes ``""``; an unknown or missing ``type`` falls
        back to ``NOTE``; ``segment_ids`` keeps only int/float entries of a
        list (anything else yields an empty tuple); a non-numeric
        ``confidence`` falls back to 0.8.
        """
        text = str(data.get("text", ""))
        type_name = str(data.get("type", "note"))

        raw_ids = data.get("segment_ids", [])
        ids: tuple[int, ...] = ()
        if isinstance(raw_ids, list):
            ids = tuple(int(value) for value in raw_ids if isinstance(value, (int, float)))

        raw_confidence = data.get("confidence", 0.8)
        confidence = 0.8
        if isinstance(raw_confidence, (int, float)):
            confidence = float(raw_confidence)

        try:
            kind = SuggestedAnnotationType(type_name)
        except ValueError:
            kind = SuggestedAnnotationType.NOTE

        return cls(
            text=text,
            annotation_type=kind,
            segment_ids=ids,
            confidence=confidence,
        )
|
||||
|
||||
|
||||
# Cue-phrase patterns. Group 1 captures the text after the cue, up to the
# first period or the end of the current line.
#
# re.MULTILINE lets ``$`` match at every line end; without it a cue on a line
# that has no trailing period (typical for bullet lists) could only match on
# the very last line of the answer. The leading ``\b`` stops cue words from
# matching inside longer words (e.g. "will" inside "goodwill").
ACTION_ITEM_PATTERNS: Final[tuple[re.Pattern[str], ...]] = (
    re.compile(
        r"\b(?:need to|should|must|will|going to|has to)\s+(.+?)(?:\.|$)",
        re.IGNORECASE | re.MULTILINE,
    ),
    re.compile(r"\b(?:action item|TODO|task):\s*(.+?)(?:\.|$)", re.IGNORECASE | re.MULTILINE),
    re.compile(r"\b(?:follow[- ]up|next step):\s*(.+?)(?:\.|$)", re.IGNORECASE | re.MULTILINE),
)

DECISION_PATTERNS: Final[tuple[re.Pattern[str], ...]] = (
    re.compile(
        r"\b(?:decided to|agreed to|will go with|chose to)\s+(.+?)(?:\.|$)",
        re.IGNORECASE | re.MULTILINE,
    ),
    re.compile(r"\b(?:decision|resolution):\s*(.+?)(?:\.|$)", re.IGNORECASE | re.MULTILINE),
    re.compile(
        r"\b(?:the team|we|they) (?:decided|agreed|chose)\s+(.+?)(?:\.|$)",
        re.IGNORECASE | re.MULTILINE,
    ),
)

# Inclusive bounds on the length of a captured phrase for it to be suggested.
MIN_TEXT_LENGTH: Final[int] = 10
MAX_TEXT_LENGTH: Final[int] = 200
|
||||
|
||||
|
||||
def extract_annotations_from_answer(
    answer: str,
    cited_segment_ids: tuple[int, ...],
) -> list[SuggestedAnnotation]:
    """Scan a synthesized answer for action items and decisions.

    Every action-item pattern is applied first, then every decision pattern.
    A captured phrase is kept when its stripped length lies between
    ``MIN_TEXT_LENGTH`` and ``MAX_TEXT_LENGTH`` inclusive. Action items get
    confidence 0.7, decisions 0.75, and every suggestion carries the full
    ``cited_segment_ids`` tuple. Duplicates are removed before returning.
    """
    pattern_groups = (
        (ACTION_ITEM_PATTERNS, SuggestedAnnotationType.ACTION_ITEM, 0.7),
        (DECISION_PATTERNS, SuggestedAnnotationType.DECISION, 0.75),
    )

    candidates: list[SuggestedAnnotation] = []
    for patterns, kind, score in pattern_groups:
        for pattern in patterns:
            for hit in pattern.finditer(answer):
                phrase = hit.group(1).strip()
                if not (MIN_TEXT_LENGTH <= len(phrase) <= MAX_TEXT_LENGTH):
                    continue
                candidates.append(
                    SuggestedAnnotation(
                        text=phrase,
                        annotation_type=kind,
                        segment_ids=cited_segment_ids,
                        confidence=score,
                    )
                )

    return _dedupe_suggestions(candidates)
|
||||
|
||||
|
||||
def _dedupe_suggestions(suggestions: list[SuggestedAnnotation]) -> list[SuggestedAnnotation]:
    """Drop suggestions whose text repeats an earlier one.

    Texts are compared after lower-casing and stripping whitespace; the first
    occurrence wins and the input order is preserved.
    """
    first_by_text: dict[str, SuggestedAnnotation] = {}
    for candidate in suggestions:
        # setdefault keeps the earliest suggestion for each normalized text.
        first_by_text.setdefault(candidate.text.lower().strip(), candidate)
    return list(first_by_text.values())
|
||||
61
src/noteflow/infrastructure/ai/nodes/verification.py
Normal file
61
src/noteflow/infrastructure/ai/nodes/verification.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Citation verification for AI-generated answers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Final
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class VerificationResult:
    """Result of citation verification.

    Attributes:
        is_valid: True if all cited segment IDs exist in available segments.
        invalid_citation_indices: Positions (0-based, within the cited-ID
            list that was verified) of citations that failed validation;
            these are list positions, not segment IDs.
        reason: Human-readable explanation if validation failed.
    """

    is_valid: bool
    invalid_citation_indices: tuple[int, ...]
    reason: str | None = None
|
||||
|
||||
|
||||
# Reason used when there are no available segment IDs to verify against.
NO_SEGMENTS_REASON: Final[str] = "No segments retrieved for question"
# Prefix of the reason that lists the positions of unknown citations.
INVALID_CITATIONS_PREFIX: Final[str] = "Invalid citation indices: "
|
||||
|
||||
|
||||
def verify_citations(
    cited_ids: list[int],
    available_ids: set[int],
) -> VerificationResult:
    """Check that every cited segment ID was among the retrieved segments.

    Args:
        cited_ids: List of segment IDs cited in the answer.
        available_ids: Set of valid segment IDs from retrieval.

    Returns:
        VerificationResult with validation status and any invalid indices.
        An empty ``available_ids`` is always invalid, with no indices and
        ``NO_SEGMENTS_REASON``; otherwise the indices are the positions in
        ``cited_ids`` whose ID is not available.
    """
    if not available_ids:
        return VerificationResult(
            is_valid=False,
            invalid_citation_indices=(),
            reason=NO_SEGMENTS_REASON,
        )

    unknown_positions = tuple(
        position
        for position, segment_id in enumerate(cited_ids)
        if segment_id not in available_ids
    )

    if not unknown_positions:
        return VerificationResult(
            is_valid=True,
            invalid_citation_indices=(),
            reason=None,
        )

    return VerificationResult(
        is_valid=False,
        invalid_citation_indices=unknown_positions,
        reason=f"{INVALID_CITATIONS_PREFIX}{unknown_positions}",
    )
|
||||
226
src/noteflow/infrastructure/ai/nodes/web_search.py
Normal file
226
src/noteflow/infrastructure/ai/nodes/web_search.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""Web search node for augmenting RAG with external knowledge."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Final, Protocol
|
||||
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Default for the ``max_results`` argument/field used throughout this module.
DEFAULT_MAX_RESULTS: Final[int] = 5
# Default for ``WebSearchConfig.timeout_seconds``.
DEFAULT_TIMEOUT_SECONDS: Final[float] = 10.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WebSearchResult:
    """One hit returned by a web search.

    Attributes:
        title: Title of the result.
        url: Address of the result.
        snippet: Text excerpt of the result.
        score: Score of the result; defaults to 1.0.
    """

    title: str
    url: str
    snippet: str
    score: float = 1.0

    def to_dict(self) -> dict[str, object]:
        """Return the four fields as a plain dict keyed by field name."""
        field_names = ("title", "url", "snippet", "score")
        return {name: getattr(self, name) for name in field_names}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WebSearchResponse:
    """Outcome of a single web search query.

    Attributes:
        query: The query string that was searched.
        results: The returned hits.
        total_results: Result count reported for the query.
        search_time_ms: Reported search duration in milliseconds.
    """

    query: str
    results: tuple[WebSearchResult, ...]
    total_results: int
    search_time_ms: float

    @property
    def has_results(self) -> bool:
        """True when ``results`` contains at least one entry."""
        return bool(self.results)
|
||||
|
||||
|
||||
class WebSearchProvider(Protocol):
    """Protocol for web search providers.

    Implementations can integrate with:
    - Exa AI
    - SerpAPI
    - Brave Search API
    - Bing Web Search API
    - Google Custom Search

    ``execute_web_search`` invokes ``search`` with keyword arguments
    (``query=...``, ``max_results=...``), so implementations need to use
    these exact parameter names.
    """

    async def search(
        self,
        query: str,
        max_results: int = DEFAULT_MAX_RESULTS,
    ) -> WebSearchResponse:
        """Execute a web search query.

        Args:
            query: Search query string.
            max_results: Maximum number of results to return.

        Returns:
            WebSearchResponse with search results.
        """
        ...
|
||||
|
||||
|
||||
class DisabledWebSearchProvider:
    """Stub provider that returns empty results.

    Used when web search is not configured or disabled. The ``search``
    signature mirrors ``WebSearchProvider.search`` so that keyword calls such
    as ``provider.search(query=..., max_results=...)`` work against this stub
    as well.
    """

    async def search(
        self,
        query: str,
        max_results: int = DEFAULT_MAX_RESULTS,
    ) -> WebSearchResponse:
        """Return empty results - web search disabled.

        Args:
            query: Search query string; only its first 50 characters are logged.
            max_results: Accepted for protocol compatibility and ignored.

        Returns:
            A WebSearchResponse echoing ``query`` with no results.
        """
        # Named ``max_results`` (not ``_max_results``) so keyword callers such
        # as ``execute_web_search`` do not raise TypeError on this stub.
        del max_results
        logger.debug("web_search_disabled", query=query[:50])
        return WebSearchResponse(
            query=query,
            results=(),
            total_results=0,
            search_time_ms=0.0,
        )
|
||||
|
||||
|
||||
@dataclass
class WebSearchConfig:
    """Configuration for web search node."""

    # When False, ``execute_web_search`` returns an empty response without
    # calling the provider.
    enabled: bool = False
    # Forwarded to the provider's ``search`` as ``max_results``.
    max_results: int = DEFAULT_MAX_RESULTS
    # NOTE(review): not read anywhere in this module; presumably enforced by
    # the caller or the provider - confirm.
    timeout_seconds: float = DEFAULT_TIMEOUT_SECONDS
    # NOTE(review): not read in this module; presumably consulted by the
    # interrupt/approval flow - confirm.
    require_approval: bool = True
|
||||
|
||||
|
||||
def format_results_for_context(results: tuple[WebSearchResult, ...]) -> str:
    """Render web search results as a markdown block for an LLM prompt.

    Args:
        results: Web search results to format.

    Returns:
        An empty string when there are no results; otherwise a
        "## Web Search Results" heading followed, for each result numbered
        from 1, by a "### [n] title" line, a "Source: url" line and the
        snippet.
    """
    if not results:
        return ""

    lines: list[str] = ["## Web Search Results\n"]
    for position, hit in enumerate(results, start=1):
        lines.extend(
            (
                f"### [{position}] {hit.title}",
                f"Source: {hit.url}",
                f"{hit.snippet}\n",
            )
        )
    return "\n".join(lines)
|
||||
|
||||
|
||||
async def execute_web_search(
    query: str,
    provider: WebSearchProvider,
    config: WebSearchConfig,
) -> WebSearchResponse:
    """Run a web search through ``provider`` if the configuration allows it.

    Args:
        query: Search query derived from user question.
        provider: Web search provider implementation.
        config: Search configuration.

    Returns:
        WebSearchResponse with results (empty if disabled). When enabled, the
        provider's response is returned unchanged; the call is logged before
        and after.
    """
    if config.enabled:
        logger.info("web_search_executing", query=query[:100], max_results=config.max_results)

        provider_response = await provider.search(
            query=query,
            max_results=config.max_results,
        )

        logger.info(
            "web_search_completed",
            query=query[:50],
            result_count=len(provider_response.results),
            search_time_ms=provider_response.search_time_ms,
        )
        return provider_response

    logger.debug("web_search_skipped_disabled")
    return WebSearchResponse(
        query=query,
        results=(),
        total_results=0,
        search_time_ms=0.0,
    )
|
||||
|
||||
|
||||
def merge_contexts(
    transcript_context: str,
    web_results: WebSearchResponse,
) -> str:
    """Combine transcript context with formatted web search results.

    Args:
        transcript_context: Context from transcript segments.
        web_results: Web search response.

    Returns:
        ``transcript_context`` unchanged when the response has no results;
        otherwise a block containing the transcript context, the formatted
        web results and a note telling the model to prioritise the
        transcript.
    """
    if web_results.has_results:
        web_context = format_results_for_context(web_results.results)
        return f"""## Meeting Transcript Context
{transcript_context}

{web_context}

Note: Web search results are provided as supplementary context.
Prioritize information from the meeting transcript when answering questions about the meeting.
"""
    return transcript_context
|
||||
|
||||
|
||||
def derive_search_query(question: str, meeting_context: str | None = None) -> str:
    """Build a web search query from the question and optional context.

    The stripped question is used as-is. If ``meeting_context`` is given, its
    first 100 characters (stripped) are appended after a space when
    non-empty. A result longer than 256 characters is cut to 256 and then
    trimmed back to the last space, if there is one.

    Args:
        question: User's original question.
        meeting_context: Optional context about the meeting topic.

    Returns:
        Optimized search query.
    """
    length_cap = 256

    # Simple approach: the question itself is the query; an LLM-generated
    # query would be the more sophisticated alternative.
    pieces = [question.strip()]

    context_snippet = meeting_context[:100].strip() if meeting_context else ""
    if context_snippet:
        pieces.append(context_snippet)

    combined = " ".join(pieces)
    if len(combined) <= length_cap:
        return combined

    # Over the cap: hard-cut, then drop the trailing partial word.
    return combined[:length_cap].rsplit(" ", 1)[0]
|
||||
27
src/noteflow/infrastructure/ai/tools/__init__.py
Normal file
27
src/noteflow/infrastructure/ai/tools/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Tool adapters for LangGraph workflows."""
|
||||
|
||||
from noteflow.infrastructure.ai.tools.retrieval import (
|
||||
BatchEmbedderProtocol,
|
||||
EmbedderProtocol,
|
||||
RetrievalResult,
|
||||
retrieve_segments,
|
||||
retrieve_segments_batch,
|
||||
retrieve_segments_workspace,
|
||||
retrieve_segments_workspace_batch,
|
||||
)
|
||||
from noteflow.infrastructure.ai.tools.synthesis import (
|
||||
SynthesisResult,
|
||||
synthesize_answer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BatchEmbedderProtocol",
|
||||
"EmbedderProtocol",
|
||||
"RetrievalResult",
|
||||
"SynthesisResult",
|
||||
"retrieve_segments",
|
||||
"retrieve_segments_batch",
|
||||
"retrieve_segments_workspace",
|
||||
"retrieve_segments_workspace_batch",
|
||||
"synthesize_answer",
|
||||
]
|
||||
237
src/noteflow/infrastructure/ai/tools/retrieval.py
Normal file
237
src/noteflow/infrastructure/ai/tools/retrieval.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Segment retrieval tools for LangGraph workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Final, Protocol, runtime_checkable
|
||||
from uuid import UUID
|
||||
|
||||
from noteflow.domain.value_objects import MeetingId
|
||||
|
||||
# Limit concurrent parallel operations to prevent resource exhaustion
|
||||
MAX_CONCURRENT_OPERATIONS: Final[int] = 10
|
||||
|
||||
|
||||
@runtime_checkable
class EmbedderProtocol(Protocol):
    """Protocol for text embedding providers.

    Marked ``runtime_checkable`` so it can be used with ``isinstance``; such
    checks only test that the method exists, not its signature.
    """

    async def embed(self, text: str) -> list[float]:
        """Embed a single text string.

        Args:
            text: The text to embed.

        Returns:
            The embedding as a list of floats.
        """
        ...
|
||||
|
||||
|
||||
@runtime_checkable
class BatchEmbedderProtocol(EmbedderProtocol, Protocol):
    """Extended protocol for embedders supporting batch operations.

    ``_embed_batch_fallback`` uses an ``isinstance`` check against this
    protocol to decide whether ``embed_batch`` can be called.
    """

    async def embed_batch(self, texts: Sequence[str]) -> list[list[float]]:
        """Embed multiple texts in a single batch operation.

        More efficient than calling embed() multiple times for providers
        that support batching (reduces API calls, leverages GPU batching).

        Args:
            texts: The texts to embed.

        Returns:
            One embedding per input text.
            NOTE(review): callers pair results with inputs by position, so
            the output is presumably in input order - confirm per provider.
        """
        ...
|
||||
|
||||
|
||||
class SegmentLike(Protocol):
    """Structural type for the segment objects returned by semantic search.

    These are the attributes the retrieval helpers in this module copy into
    ``RetrievalResult``.
    """

    segment_id: int
    # May be None on the source object; ``_meeting_id_to_uuid`` raises
    # ValueError in that case.
    meeting_id: MeetingId | None
    text: str
    start_time: float
    end_time: float
|
||||
|
||||
|
||||
class SegmentSearchProtocol(Protocol):
    """Repository interface for semantic segment search."""

    async def search_semantic(
        self,
        query_embedding: list[float],
        limit: int,
        meeting_id: MeetingId | None,
    ) -> Sequence[tuple[SegmentLike, float]]:
        """Search segments by embedding.

        Args:
            query_embedding: Embedding of the query text.
            limit: Maximum number of matches (callers pass ``top_k``).
            meeting_id: Meeting to search, or None.
                NOTE(review): None presumably means "no meeting filter" -
                confirm against the repository implementation.

        Returns:
            ``(segment, score)`` pairs.
        """
        ...
|
||||
|
||||
|
||||
class WorkspaceSegmentSearchProtocol(Protocol):
    """Repository interface for workspace-wide semantic segment search."""

    async def search_semantic_workspace(
        self,
        query_embedding: list[float],
        workspace_id: UUID,
        project_id: UUID | None,
        limit: int,
    ) -> Sequence[tuple[SegmentLike, float]]:
        """Search segments across a workspace by embedding.

        Args:
            query_embedding: Embedding of the query text.
            workspace_id: Workspace to search within.
            project_id: Optional project within the workspace.
                NOTE(review): None presumably means "all projects" - confirm.
            limit: Maximum number of matches (callers pass ``top_k``).

        Returns:
            ``(segment, score)`` pairs.
        """
        ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RetrievalResult:
    """Immutable record of one segment returned by semantic search.

    Attributes:
        segment_id: ID copied from the matched segment.
        meeting_id: The segment's meeting ID converted to a ``UUID``.
        text: Text of the segment.
        start_time: Start time copied from the segment.
        end_time: End time copied from the segment.
        score: Score paired with the segment by the search repository.
    """

    segment_id: int
    meeting_id: UUID
    text: str
    start_time: float
    end_time: float
    score: float
|
||||
|
||||
|
||||
def _meeting_id_to_uuid(mid: MeetingId | None) -> UUID:
    """Convert a meeting ID to a ``UUID`` via its string form.

    Raises:
        ValueError: If ``mid`` is None (or its string form is not a valid
            UUID, as raised by ``UUID`` itself).
    """
    if mid is not None:
        return UUID(str(mid))
    msg = "meeting_id is required for RetrievalResult"
    raise ValueError(msg)
|
||||
|
||||
|
||||
async def retrieve_segments(
    query: str,
    embedder: EmbedderProtocol,
    segment_repo: SegmentSearchProtocol,
    meeting_id: MeetingId | None = None,
    top_k: int = 8,
) -> list[RetrievalResult]:
    """Embed ``query`` and return the matching transcript segments.

    Args:
        query: Text to search for.
        embedder: Produces the query embedding.
        segment_repo: Repository performing the semantic search.
        meeting_id: Passed through to the repository search.
        top_k: Passed to the repository as ``limit``.

    Returns:
        One ``RetrievalResult`` per ``(segment, score)`` pair, in the order
        the repository returned them.

    Raises:
        ValueError: If a returned segment has no meeting ID.
    """
    vector = await embedder.embed(query)
    matches = await segment_repo.search_semantic(
        query_embedding=vector,
        limit=top_k,
        meeting_id=meeting_id,
    )
    return [
        RetrievalResult(
            segment_id=hit.segment_id,
            meeting_id=_meeting_id_to_uuid(hit.meeting_id),
            text=hit.text,
            start_time=hit.start_time,
            end_time=hit.end_time,
            score=relevance,
        )
        for hit, relevance in matches
    ]
|
||||
|
||||
|
||||
async def retrieve_segments_workspace(
    query: str,
    embedder: EmbedderProtocol,
    segment_repo: WorkspaceSegmentSearchProtocol,
    workspace_id: UUID,
    project_id: UUID | None = None,
    top_k: int = 20,
) -> list[RetrievalResult]:
    """Embed ``query`` and return matching segments across a workspace.

    Args:
        query: Text to search for.
        embedder: Produces the query embedding.
        segment_repo: Repository performing the workspace search.
        workspace_id: Workspace passed to the repository.
        project_id: Optional project passed to the repository.
        top_k: Passed to the repository as ``limit``.

    Returns:
        One ``RetrievalResult`` per ``(segment, score)`` pair, in the order
        the repository returned them.

    Raises:
        ValueError: If a returned segment has no meeting ID.
    """
    vector = await embedder.embed(query)
    matches = await segment_repo.search_semantic_workspace(
        query_embedding=vector,
        workspace_id=workspace_id,
        project_id=project_id,
        limit=top_k,
    )
    return [
        RetrievalResult(
            segment_id=hit.segment_id,
            meeting_id=_meeting_id_to_uuid(hit.meeting_id),
            text=hit.text,
            start_time=hit.start_time,
            end_time=hit.end_time,
            score=relevance,
        )
        for hit, relevance in matches
    ]
|
||||
|
||||
|
||||
async def _embed_batch_fallback(
    texts: Sequence[str],
    embedder: EmbedderProtocol,
) -> list[list[float]]:
    """Embed several texts, preferring the embedder's batch API.

    If ``embedder`` satisfies ``BatchEmbedderProtocol`` a single
    ``embed_batch`` call is made. Otherwise ``embed`` is called once per
    text, with at most ``MAX_CONCURRENT_OPERATIONS`` calls in flight, and the
    results are gathered in input order.
    """
    if not isinstance(embedder, BatchEmbedderProtocol):
        gate = asyncio.Semaphore(MAX_CONCURRENT_OPERATIONS)

        async def embed_one(text: str) -> list[float]:
            async with gate:
                return await embedder.embed(text)

        vectors = await asyncio.gather(*map(embed_one, texts))
        return list(vectors)

    return await embedder.embed_batch(texts)
|
||||
|
||||
|
||||
async def retrieve_segments_batch(
    queries: Sequence[str],
    embedder: EmbedderProtocol,
    segment_repo: SegmentSearchProtocol,
    meeting_id: MeetingId | None = None,
    top_k: int = 8,
) -> list[list[RetrievalResult]]:
    """Run semantic search for several queries concurrently.

    All queries are embedded first (batched when the embedder supports it),
    then one search per embedding is run with at most
    ``MAX_CONCURRENT_OPERATIONS`` searches in flight.

    Returns:
        One result list per query, in the same order as ``queries``; an
        empty list when ``queries`` is empty.
    """
    if not queries:
        return []

    vectors = await _embed_batch_fallback(list(queries), embedder)
    limiter = asyncio.Semaphore(MAX_CONCURRENT_OPERATIONS)

    async def search_one(vector: list[float]) -> list[RetrievalResult]:
        async with limiter:
            matches = await segment_repo.search_semantic(
                query_embedding=vector,
                limit=top_k,
                meeting_id=meeting_id,
            )
            return [
                RetrievalResult(
                    segment_id=hit.segment_id,
                    meeting_id=_meeting_id_to_uuid(hit.meeting_id),
                    text=hit.text,
                    start_time=hit.start_time,
                    end_time=hit.end_time,
                    score=relevance,
                )
                for hit, relevance in matches
            ]

    per_query = await asyncio.gather(*map(search_one, vectors))
    return list(per_query)
|
||||
|
||||
|
||||
async def retrieve_segments_workspace_batch(
    queries: Sequence[str],
    embedder: EmbedderProtocol,
    segment_repo: WorkspaceSegmentSearchProtocol,
    workspace_id: UUID,
    project_id: UUID | None = None,
    top_k: int = 20,
) -> list[list[RetrievalResult]]:
    """Run workspace-wide semantic search for several queries concurrently.

    All queries are embedded first (batched when the embedder supports it),
    then one workspace search per embedding is run with at most
    ``MAX_CONCURRENT_OPERATIONS`` searches in flight.

    Returns:
        One result list per query, in the same order as ``queries``; an
        empty list when ``queries`` is empty.
    """
    if not queries:
        return []

    vectors = await _embed_batch_fallback(list(queries), embedder)
    limiter = asyncio.Semaphore(MAX_CONCURRENT_OPERATIONS)

    async def search_one(vector: list[float]) -> list[RetrievalResult]:
        async with limiter:
            matches = await segment_repo.search_semantic_workspace(
                query_embedding=vector,
                workspace_id=workspace_id,
                project_id=project_id,
                limit=top_k,
            )
            return [
                RetrievalResult(
                    segment_id=hit.segment_id,
                    meeting_id=_meeting_id_to_uuid(hit.meeting_id),
                    text=hit.text,
                    start_time=hit.start_time,
                    end_time=hit.end_time,
                    score=relevance,
                )
                for hit, relevance in matches
            ]

    per_query = await asyncio.gather(*map(search_one, vectors))
    return list(per_query)
|
||||
60
src/noteflow/infrastructure/ai/tools/synthesis.py
Normal file
60
src/noteflow/infrastructure/ai/tools/synthesis.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Answer synthesis tools for LangGraph workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Final, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.infrastructure.ai.tools.retrieval import RetrievalResult
|
||||
|
||||
|
||||
class LLMProtocol(Protocol):
    """Minimal interface of the language model used for answer synthesis."""

    async def complete(self, prompt: str) -> str:
        """Return the model's text completion for ``prompt``."""
        ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SynthesisResult:
    """Answer text together with the segment IDs it cites.

    Attributes:
        answer: The raw completion returned by the LLM.
        cited_segment_ids: Segment IDs found in bracketed citations in the
            answer, limited to IDs of the supplied segments.
    """

    answer: str
    cited_segment_ids: list[int]
|
||||
|
||||
|
||||
# Prompt sent to the LLM; filled via ``str.format`` with ``question`` and
# ``segments`` (one "[id] (start-end): text" line per segment).
SYNTHESIS_PROMPT_TEMPLATE: Final[
    str
] = """Answer the question based on the following transcript segments.
Cite specific segments by their ID when making claims.

Question: {question}

Segments:
{segments}

Answer (cite segment IDs in brackets like [1], [3]):"""

# Matches a bracketed integer citation such as "[12]"; group 1 is the number.
CITATION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\[(\d+)\]")
|
||||
|
||||
|
||||
async def synthesize_answer(
    question: str,
    segments: list[RetrievalResult],
    llm: LLMProtocol,
) -> SynthesisResult:
    """Ask the LLM to answer ``question`` from the given segments.

    Each segment is rendered as "[id] (start-end): text", the prompt template
    is filled in, and the completion is scanned for bracketed citations.
    Only IDs belonging to ``segments`` are reported as cited.

    Args:
        question: The question to answer.
        segments: Retrieved segments offered to the model as evidence.
        llm: Model used to produce the completion.

    Returns:
        The raw answer plus the cited segment IDs.
    """
    rendered_lines = [
        f"[{s.segment_id}] ({s.start_time:.1f}s-{s.end_time:.1f}s): {s.text}" for s in segments
    ]
    prompt = SYNTHESIS_PROMPT_TEMPLATE.format(
        question=question,
        segments="\n".join(rendered_lines),
    )
    completion = await llm.complete(prompt)

    known_ids = {s.segment_id for s in segments}
    return SynthesisResult(
        answer=completion,
        cited_segment_ids=extract_cited_ids(completion, known_ids),
    )
|
||||
|
||||
|
||||
def extract_cited_ids(answer: str, valid_ids: set[int]) -> list[int]:
    """Collect the segment IDs cited in ``answer``.

    Bracketed integers such as ``[3]`` are read in order of appearance; IDs
    not in ``valid_ids`` are ignored and repeats are reported once, at the
    position of their first occurrence.
    """
    ordered_ids: dict[int, None] = {}
    for digits in CITATION_PATTERN.findall(answer):
        segment_id = int(digits)
        if segment_id in valid_ids:
            # Re-assigning an existing key keeps its original position.
            ordered_ids[segment_id] = None
    return list(ordered_ids)
|
||||
1
tests/domain/ai/__init__.py
Normal file
1
tests/domain/ai/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for domain/ai/ module."""
|
||||
118
tests/domain/ai/test_citations.py
Normal file
118
tests/domain/ai/test_citations.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.domain.ai.citations import SegmentCitation
|
||||
|
||||
|
||||
class TestSegmentCitation:
    """Construction, validation and immutability checks for ``SegmentCitation``."""

    def test_creation_with_valid_values(self) -> None:
        """Every constructor argument is exposed unchanged as an attribute."""
        meeting_id = uuid4()
        citation = SegmentCitation(
            meeting_id=meeting_id,
            segment_id=1,
            start_time=0.0,
            end_time=5.0,
            text="Test segment text",
            score=0.95,
        )

        assert citation.meeting_id == meeting_id
        assert citation.segment_id == 1
        assert citation.start_time == 0.0
        assert citation.end_time == 5.0
        assert citation.text == "Test segment text"
        assert citation.score == 0.95

    def test_duration_property(self) -> None:
        """``duration`` is ``end_time - start_time``."""
        citation = SegmentCitation(
            meeting_id=uuid4(),
            segment_id=1,
            start_time=10.0,
            end_time=25.0,
            text="Test",
        )

        assert citation.duration == 15.0

    def test_default_score_is_zero(self) -> None:
        """Omitting ``score`` yields 0.0."""
        citation = SegmentCitation(
            meeting_id=uuid4(),
            segment_id=1,
            start_time=0.0,
            end_time=1.0,
            text="Test",
        )

        assert citation.score == 0.0

    def test_rejects_negative_segment_id(self) -> None:
        """A negative ``segment_id`` raises ``ValueError``."""
        with pytest.raises(ValueError, match="segment_id must be non-negative"):
            SegmentCitation(
                meeting_id=uuid4(),
                segment_id=-1,
                start_time=0.0,
                end_time=5.0,
                text="Test",
            )

    def test_rejects_negative_start_time(self) -> None:
        """A negative ``start_time`` raises ``ValueError``."""
        with pytest.raises(ValueError, match="start_time must be non-negative"):
            SegmentCitation(
                meeting_id=uuid4(),
                segment_id=1,
                start_time=-1.0,
                end_time=5.0,
                text="Test",
            )

    def test_rejects_end_time_before_start_time(self) -> None:
        """An ``end_time`` earlier than ``start_time`` raises ``ValueError``."""
        with pytest.raises(ValueError, match="end_time must be >= start_time"):
            SegmentCitation(
                meeting_id=uuid4(),
                segment_id=1,
                start_time=10.0,
                end_time=5.0,
                text="Test",
            )

    @pytest.mark.parametrize(
        "invalid_score",
        [
            pytest.param(-0.1, id="negative"),
            pytest.param(1.1, id="above_one"),
        ],
    )
    def test_rejects_invalid_score(self, invalid_score: float) -> None:
        """Scores outside the closed interval [0, 1] raise ``ValueError``."""
        with pytest.raises(ValueError, match="score must be between 0 and 1"):
            SegmentCitation(
                meeting_id=uuid4(),
                segment_id=1,
                start_time=0.0,
                end_time=5.0,
                text="Test",
                score=invalid_score,
            )

    def test_accepts_zero_duration(self) -> None:
        """Equal start and end times are valid and give a duration of 0.0."""
        citation = SegmentCitation(
            meeting_id=uuid4(),
            segment_id=1,
            start_time=5.0,
            end_time=5.0,
            text="Instant moment",
        )

        assert citation.duration == 0.0

    def test_is_frozen(self) -> None:
        """Assigning to an attribute after construction raises ``AttributeError``."""
        citation = SegmentCitation(
            meeting_id=uuid4(),
            segment_id=1,
            start_time=0.0,
            end_time=5.0,
            text="Test",
        )

        with pytest.raises(AttributeError):
            citation.text = "Modified"  # type: ignore[misc]
|
||||
1
tests/infrastructure/ai/__init__.py
Normal file
1
tests/infrastructure/ai/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for infrastructure/ai/ module."""
|
||||
268
tests/infrastructure/ai/test_retrieval.py
Normal file
268
tests/infrastructure/ai/test_retrieval.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.infrastructure.ai.tools.retrieval import (
|
||||
BatchEmbedderProtocol,
|
||||
RetrievalResult,
|
||||
retrieve_segments,
|
||||
retrieve_segments_batch,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class MockSegment:
    """Minimal stand-in for a stored transcript segment in the retrieval tests.

    Field order matters: some tests construct instances positionally.
    """

    segment_id: int
    meeting_id: object
    text: str
    start_time: float
    end_time: float
|
||||
|
||||
|
||||
class TestRetrieveSegments:
    """Tests for ``retrieve_segments`` using mocked embedder and repository."""

    @pytest.fixture
    def mock_embedder(self) -> AsyncMock:
        """Embedder whose ``embed`` coroutine returns a fixed three-element vector."""
        embedder = AsyncMock()
        embedder.embed.return_value = [0.1, 0.2, 0.3]
        return embedder

    @pytest.fixture
    def mock_segment_repo(self) -> AsyncMock:
        """Segment repository double; each test sets ``search_semantic`` itself."""
        return AsyncMock()

    @pytest.fixture
    def sample_meeting_id(self) -> object:
        """A fresh random meeting UUID."""
        return uuid4()

    async def test_retrieve_segments_success(
        self,
        mock_embedder: AsyncMock,
        mock_segment_repo: AsyncMock,
        sample_meeting_id: object,
    ) -> None:
        """A ``(segment, score)`` pair from the repo becomes one result with those values."""
        segment = MockSegment(
            segment_id=1,
            meeting_id=sample_meeting_id,
            text="Test segment",
            start_time=0.0,
            end_time=5.0,
        )
        mock_segment_repo.search_semantic.return_value = [(segment, 0.95)]

        results = await retrieve_segments(
            query="test query",
            embedder=mock_embedder,
            segment_repo=mock_segment_repo,
            meeting_id=sample_meeting_id,  # type: ignore[arg-type]
            top_k=5,
        )

        assert len(results) == 1
        assert results[0].segment_id == 1
        assert results[0].text == "Test segment"
        assert results[0].score == 0.95

    async def test_retrieve_segments_calls_embedder_with_query(
        self,
        mock_embedder: AsyncMock,
        mock_segment_repo: AsyncMock,
    ) -> None:
        """The query string is passed verbatim to ``embed``, exactly once."""
        mock_segment_repo.search_semantic.return_value = []

        await retrieve_segments(
            query="what happened in the meeting",
            embedder=mock_embedder,
            segment_repo=mock_segment_repo,
        )

        mock_embedder.embed.assert_called_once_with("what happened in the meeting")

    async def test_retrieve_segments_passes_embedding_to_repo(
        self,
        mock_embedder: AsyncMock,
        mock_segment_repo: AsyncMock,
    ) -> None:
        """The embedding, a ``None`` meeting filter and ``top_k`` (as ``limit``) reach the repo."""
        mock_embedder.embed.return_value = [1.0, 2.0, 3.0]
        mock_segment_repo.search_semantic.return_value = []

        await retrieve_segments(
            query="test",
            embedder=mock_embedder,
            segment_repo=mock_segment_repo,
            top_k=10,
        )

        mock_segment_repo.search_semantic.assert_called_once_with(
            query_embedding=[1.0, 2.0, 3.0],
            meeting_id=None,
            limit=10,
        )

    async def test_retrieve_segments_empty_result(
        self,
        mock_embedder: AsyncMock,
        mock_segment_repo: AsyncMock,
    ) -> None:
        """No repository hits yields an empty list."""
        mock_segment_repo.search_semantic.return_value = []

        results = await retrieve_segments(
            query="test",
            embedder=mock_embedder,
            segment_repo=mock_segment_repo,
        )

        assert results == []

    async def test_retrieval_result_is_frozen(self) -> None:
        """Assigning to a ``RetrievalResult`` attribute raises ``AttributeError``."""
        result = RetrievalResult(
            segment_id=1,
            meeting_id=uuid4(),
            text="Test",
            start_time=0.0,
            end_time=5.0,
            score=0.9,
        )

        with pytest.raises(AttributeError):
            result.text = "Modified"  # type: ignore[misc]
|
||||
|
||||
|
||||
class MockBatchEmbedder:
    """Hand-written embedder double that supports both single and batch embedding.

    Every call is recorded so tests can assert which entry point was used.
    """

    def __init__(self, embedding: list[float]) -> None:
        """Store the fixed ``embedding`` returned for every input text."""
        self._embedding = embedding
        self.embed_calls: list[str] = []
        self.embed_batch_calls: list[Sequence[str]] = []

    async def embed(self, text: str) -> list[float]:
        """Record ``text`` and return the fixed embedding."""
        self.embed_calls.append(text)
        return self._embedding

    async def embed_batch(self, texts: Sequence[str]) -> list[list[float]]:
        """Record ``texts`` and return one fixed embedding per input text."""
        self.embed_batch_calls.append(texts)
        return [self._embedding for _ in texts]
|
||||
|
||||
|
||||
class TestRetrieveSegmentsBatch:
    """Tests for ``retrieve_segments_batch``: embedding strategy and result ordering."""

    @pytest.fixture
    def mock_embedder(self) -> AsyncMock:
        """Embedder double whose ``embed`` returns a fixed vector."""
        # NOTE(review): a bare AsyncMock answers any attribute lookup, so
        # whether it is treated as batch-capable depends on how the code under
        # test detects ``embed_batch`` — confirm the fallback test really
        # exercises the per-query path on every supported Python version.
        embedder = AsyncMock()
        embedder.embed.return_value = [0.1, 0.2, 0.3]
        return embedder

    @pytest.fixture
    def batch_embedder(self) -> MockBatchEmbedder:
        """Batch-capable embedder that records its calls."""
        return MockBatchEmbedder([0.1, 0.2, 0.3])

    @pytest.fixture
    def mock_segment_repo(self) -> AsyncMock:
        """Segment repository double; each test sets ``search_semantic`` itself."""
        return AsyncMock()

    @pytest.fixture
    def sample_meeting_id(self) -> object:
        """A fresh random meeting UUID."""
        return uuid4()

    async def test_batch_returns_empty_for_no_queries(
        self,
        mock_embedder: AsyncMock,
        mock_segment_repo: AsyncMock,
    ) -> None:
        """An empty query list returns ``[]`` without calling ``embed``."""
        results = await retrieve_segments_batch(
            queries=[],
            embedder=mock_embedder,
            segment_repo=mock_segment_repo,
        )

        assert results == []
        mock_embedder.embed.assert_not_called()

    async def test_batch_uses_embed_batch_when_available(
        self,
        batch_embedder: MockBatchEmbedder,
        mock_segment_repo: AsyncMock,
        sample_meeting_id: object,
    ) -> None:
        """A batch-capable embedder gets one ``embed_batch`` call and no ``embed`` calls."""
        segment = MockSegment(
            segment_id=1,
            meeting_id=sample_meeting_id,
            text="Test",
            start_time=0.0,
            end_time=5.0,
        )
        mock_segment_repo.search_semantic.return_value = [(segment, 0.9)]

        # Precondition: the hand-written double satisfies the batch protocol.
        assert isinstance(batch_embedder, BatchEmbedderProtocol)

        results = await retrieve_segments_batch(
            queries=["query1", "query2"],
            embedder=batch_embedder,
            segment_repo=mock_segment_repo,
            meeting_id=sample_meeting_id,  # type: ignore[arg-type]
        )

        assert len(results) == 2
        assert len(batch_embedder.embed_batch_calls) == 1
        assert list(batch_embedder.embed_batch_calls[0]) == ["query1", "query2"]
        assert batch_embedder.embed_calls == []

    async def test_batch_falls_back_to_parallel_embed(
        self,
        mock_embedder: AsyncMock,
        mock_segment_repo: AsyncMock,
        sample_meeting_id: object,
    ) -> None:
        """With the plain mock embedder, ``embed`` is called once per query."""
        segment = MockSegment(
            segment_id=1,
            meeting_id=sample_meeting_id,
            text="Test",
            start_time=0.0,
            end_time=5.0,
        )
        mock_segment_repo.search_semantic.return_value = [(segment, 0.9)]

        results = await retrieve_segments_batch(
            queries=["query1", "query2"],
            embedder=mock_embedder,
            segment_repo=mock_segment_repo,
            meeting_id=sample_meeting_id,  # type: ignore[arg-type]
        )

        assert len(results) == 2
        assert mock_embedder.embed.call_count == 2

    async def test_batch_preserves_query_order(
        self,
        mock_segment_repo: AsyncMock,
        sample_meeting_id: object,
    ) -> None:
        """Result lists come back in the same order as the queries."""
        segment1 = MockSegment(1, sample_meeting_id, "First", 0.0, 5.0)
        segment2 = MockSegment(2, sample_meeting_id, "Second", 5.0, 10.0)

        call_count = 0

        async def side_effect(
            query_embedding: list[float],
            limit: int,
            meeting_id: object,
        ) -> list[tuple[MockSegment, float]]:
            # Both queries share one embedding, so results are keyed on call
            # order: the first search returns segment1, any later one segment2.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return [(segment1, 0.9)]
            return [(segment2, 0.8)]

        mock_segment_repo.search_semantic.side_effect = side_effect

        embedder = MockBatchEmbedder([0.1, 0.2])
        results = await retrieve_segments_batch(
            queries=["first", "second"],
            embedder=embedder,
            segment_repo=mock_segment_repo,
            meeting_id=sample_meeting_id,  # type: ignore[arg-type]
        )

        assert len(results) == 2
        assert results[0][0].text == "First"
        assert results[1][0].text == "Second"
|
||||
149
tests/infrastructure/ai/test_synthesis.py
Normal file
149
tests/infrastructure/ai/test_synthesis.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.infrastructure.ai.tools.retrieval import RetrievalResult
|
||||
from noteflow.infrastructure.ai.tools.synthesis import (
|
||||
SynthesisResult,
|
||||
extract_cited_ids,
|
||||
synthesize_answer,
|
||||
)
|
||||
|
||||
|
||||
class TestSynthesizeAnswer:
    """Tests for ``synthesize_answer`` with a mocked LLM."""

    @pytest.fixture
    def mock_llm(self) -> AsyncMock:
        """LLM double; each test sets ``complete.return_value`` itself."""
        return AsyncMock()

    @pytest.fixture
    def sample_segments(self) -> list[RetrievalResult]:
        """Two segments from one meeting, with IDs 1 and 3."""
        meeting_id = uuid4()
        return [
            RetrievalResult(
                segment_id=1,
                meeting_id=meeting_id,
                text="John discussed the project timeline",
                start_time=0.0,
                end_time=5.0,
                score=0.95,
            ),
            RetrievalResult(
                segment_id=3,
                meeting_id=meeting_id,
                text="The deadline is next Friday",
                start_time=10.0,
                end_time=15.0,
                score=0.85,
            ),
        ]

    async def test_synthesize_answer_returns_result(
        self,
        mock_llm: AsyncMock,
        sample_segments: list[RetrievalResult],
    ) -> None:
        """The call returns a ``SynthesisResult`` carrying the LLM's answer text."""
        mock_llm.complete.return_value = "The project deadline is next Friday [3]."

        result = await synthesize_answer(
            question="What is the deadline?",
            segments=sample_segments,
            llm=mock_llm,
        )

        assert isinstance(result, SynthesisResult)
        assert "deadline" in result.answer.lower()

    async def test_synthesize_answer_extracts_citations(
        self,
        mock_llm: AsyncMock,
        sample_segments: list[RetrievalResult],
    ) -> None:
        """Bracketed IDs in the answer are returned in order of appearance."""
        mock_llm.complete.return_value = "John discussed timelines [1] and the deadline [3]."

        result = await synthesize_answer(
            question="What happened?",
            segments=sample_segments,
            llm=mock_llm,
        )

        assert result.cited_segment_ids == [1, 3]

    async def test_synthesize_answer_filters_invalid_citations(
        self,
        mock_llm: AsyncMock,
        sample_segments: list[RetrievalResult],
    ) -> None:
        """A cited ID that matches no supplied segment is dropped."""
        mock_llm.complete.return_value = "Found [1], [99], and [3]."

        result = await synthesize_answer(
            question="What happened?",
            segments=sample_segments,
            llm=mock_llm,
        )

        assert 99 not in result.cited_segment_ids
        assert result.cited_segment_ids == [1, 3]

    async def test_synthesize_answer_builds_prompt_with_segments(
        self,
        mock_llm: AsyncMock,
        sample_segments: list[RetrievalResult],
    ) -> None:
        """The prompt contains the question, each segment's bracketed ID and segment text."""
        mock_llm.complete.return_value = "Answer."

        await synthesize_answer(
            question="What is happening?",
            segments=sample_segments,
            llm=mock_llm,
        )

        # The prompt is read as the first positional argument of ``complete``.
        call_args = mock_llm.complete.call_args
        prompt = call_args[0][0]
        assert "What is happening?" in prompt
        assert "[1]" in prompt
        assert "[3]" in prompt
        assert "John discussed" in prompt
|
||||
|
||||
|
||||
class TestExtractCitedIds:
    """Tests for ``extract_cited_ids``: parsing, filtering, de-duplication and ordering."""

    def test_extracts_single_citation(self) -> None:
        """One bracketed valid ID is returned."""
        result = extract_cited_ids("The answer is here [5].", {1, 3, 5})

        assert result == [5]

    def test_extracts_multiple_citations(self) -> None:
        """Several bracketed valid IDs are all returned."""
        result = extract_cited_ids("See [1] and [3] for details.", {1, 3, 5})

        assert result == [1, 3]

    def test_filters_invalid_ids(self) -> None:
        """IDs absent from ``valid_ids`` are ignored."""
        result = extract_cited_ids("See [1] and [99].", {1, 3, 5})

        assert result == [1]

    def test_deduplicates_citations(self) -> None:
        """A repeated citation appears only once."""
        result = extract_cited_ids("See [1] and then [1] again.", {1, 3})

        assert result == [1]

    def test_preserves_order(self) -> None:
        """IDs follow their order in the text, not numeric order."""
        result = extract_cited_ids("[3] comes first, then [1].", {1, 3})

        assert result == [3, 1]

    def test_empty_for_no_citations(self) -> None:
        """Text without bracketed IDs yields an empty list."""
        result = extract_cited_ids("No citations here.", {1, 3})

        assert result == []
|
||||
|
||||
|
||||
class TestSynthesisResult:
    """Tests for the ``SynthesisResult`` value object."""

    def test_is_frozen(self) -> None:
        """Assigning to an attribute after construction raises ``AttributeError``."""
        result = SynthesisResult(
            answer="Test answer",
            cited_segment_ids=[1, 2],
        )

        with pytest.raises(AttributeError):
            result.answer = "Modified"  # type: ignore[misc]
|
||||
1
typings/langgraph/__init__.pyi
Normal file
1
typings/langgraph/__init__.pyi
Normal file
@@ -0,0 +1 @@
|
||||
# Type stubs for langgraph
|
||||
1
typings/langgraph/checkpoint/postgres/__init__.pyi
Normal file
1
typings/langgraph/checkpoint/postgres/__init__.pyi
Normal file
@@ -0,0 +1 @@
|
||||
# Type stubs for langgraph-checkpoint-postgres
|
||||
10
typings/langgraph/checkpoint/postgres/aio.pyi
Normal file
10
typings/langgraph/checkpoint/postgres/aio.pyi
Normal file
@@ -0,0 +1,10 @@
|
||||
# Type stubs for langgraph-checkpoint-postgres async module
|
||||
|
||||
class AsyncPostgresSaver:
    """Async PostgreSQL checkpointer for LangGraph.

    This stub provides typing for the langgraph-checkpoint-postgres package.
    """

    # NOTE(review): the upstream constructor appears to name its first
    # parameter ``conn`` rather than ``pool`` — confirm no caller passes it by
    # keyword before relying on this name.
    def __init__(self, pool: object) -> None: ...
    async def setup(self) -> None: ...
|
||||
41
typings/langgraph/graph/__init__.pyi
Normal file
41
typings/langgraph/graph/__init__.pyi
Normal file
@@ -0,0 +1,41 @@
|
||||
# Type stubs for langgraph.graph
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
# Type variable standing for the state type a graph is parameterized over.
_StateT = TypeVar("_StateT")
|
||||
|
||||
class START:
    """Sentinel for graph start node."""
|
||||
|
||||
class END:
    """Sentinel for graph end node."""
|
||||
|
||||
class CompiledStateGraph(Generic[_StateT]):
    """Compiled state graph that can be invoked."""

    # ``input`` shadows the builtin; the name is kept so keyword callers are
    # unaffected.
    async def ainvoke(self, input: _StateT) -> _StateT: ...
    def invoke(self, input: _StateT) -> _StateT: ...
|
||||
|
||||
class StateGraph(Generic[_StateT]):
    """State graph builder.

    This stub provides typing for langgraph StateGraph.
    """

    def __init__(self, state_schema: type[_StateT]) -> None: ...
    # A node action takes the current state and returns a partial-state dict,
    # either directly or as a coroutine.
    def add_node(
        self,
        name: str,
        action: Callable[[_StateT], dict[str, object]]
        | Callable[[_StateT], Coroutine[object, object, dict[str, object]]],
    ) -> None: ...
    def add_edge(
        self,
        start_key: str | type[START],
        end_key: str | type[END],
    ) -> None: ...
    # ``checkpointer`` is optional, so existing ``compile()`` calls still
    # type-check; typed as ``object`` to avoid importing a saver type here.
    def compile(self, checkpointer: object | None = None) -> CompiledStateGraph[_StateT]: ...
|
||||
Reference in New Issue
Block a user