This commit is contained in:
2026-01-22 16:15:56 +00:00
parent 19e39bed5a
commit ea0e8ee1e4
34 changed files with 4463 additions and 0 deletions

View File

@@ -0,0 +1,590 @@
# Sprint 25: LangGraph Foundation
> **Size**: L | **Owner**: Backend | **Phase**: 5 - Platform Evolution
> **Effort**: ~1 sprint | **Prerequisites**: Sprint 19 (Embeddings)
---
## Objective
Establish LangGraph infrastructure and wrap the existing summarization service as a proof of the pattern.
---
## Current State Analysis
### What Exists
| Component | Location | Status |
|-----------|----------|--------|
| Summarization service | `application/services/summarization/` | ✅ Working |
| Segment semantic search | `infrastructure/persistence/repositories/segment_repo.py` | ✅ Working |
| Cloud consent pattern | `application/services/summarization/_consent_manager.py` | ✅ Working |
| Usage event tracking | `application/observability/ports.py` | ✅ Working |
| gRPC mixin pattern | `grpc/_mixins/` | ✅ Working |
### What's Missing
| Component | Target Location | Sprint Task |
|-----------|-----------------|-------------|
| LangGraph dependencies | `pyproject.toml` | Task 1 |
| State schemas | `domain/ai/state.py` | Task 2 |
| Checkpointer factory | `infrastructure/ai/checkpointer.py` | Task 3 |
| Retrieval tools | `infrastructure/ai/tools/retrieval.py` | Task 4 |
| Synthesis tools | `infrastructure/ai/tools/synthesis.py` | Task 5 |
| Summarization graph | `infrastructure/ai/graphs/summarization.py` | Task 6 |
| AssistantService | `application/services/assistant/` | Task 7 |
---
## Implementation Tasks
### Task 1: Add LangGraph Dependencies
**File**: `pyproject.toml`
Add new optional dependency group:
```toml
langgraph = [
"langgraph>=0.2",
"langgraph-checkpoint-postgres>=2.0",
"langchain-core>=0.3",
]
```
Also add to `optional` and `all` groups.
**Verification**: `uv pip install -e ".[langgraph]"` succeeds
---
### Task 2: Create State Schemas
**Files to create**:
- `src/noteflow/domain/ai/__init__.py`
- `src/noteflow/domain/ai/state.py`
- `src/noteflow/domain/ai/citations.py`
- `src/noteflow/domain/ai/ports.py`
#### `state.py` - Input/Output/Internal Separation
```python
from __future__ import annotations
from typing import Annotated, TypedDict
from uuid import UUID
import operator
class AssistantInputState(TypedDict):
"""Public API input - what clients send."""
question: str
meeting_id: UUID | None
thread_id: str | None
allow_web: bool
top_k: int
class AssistantOutputState(TypedDict):
"""Public API output - what clients receive."""
answer: str
citations: list[dict]
suggested_annotations: list[dict]
thread_id: str
class AssistantInternalState(AssistantInputState):
"""Internal graph state - can evolve without breaking API."""
# Retrieval
retrieved_segment_ids: Annotated[list[int], operator.add]
retrieved_segments: list[dict]
# Synthesis
draft_answer: str
verification_passed: bool
# Tracking
loop_count: int
```
#### `citations.py` - Value Object
```python
from dataclasses import dataclass
from uuid import UUID
@dataclass(frozen=True)
class SegmentCitation:
"""Reference to transcript segment used as evidence."""
meeting_id: UUID
segment_id: int
start_time: float
end_time: float
text: str
    score: float = 0.0

    @property
    def duration(self) -> float:
        """Duration of the cited segment in seconds."""
        return self.end_time - self.start_time

    def __post_init__(self) -> None:
        if self.end_time < self.start_time:
            msg = "end_time must be >= start_time"
            raise ValueError(msg)
```
#### `ports.py` - Protocol Definition
```python
from typing import Protocol
from uuid import UUID
from noteflow.domain.ai.state import AssistantOutputState
class AssistantPort(Protocol):
"""Protocol for AI assistant operations."""
async def ask(
self,
question: str,
meeting_id: UUID | None = None,
thread_id: str | None = None,
allow_web: bool = False,
top_k: int = 8,
) -> AssistantOutputState:
...
```
**Verification**: `basedpyright src/noteflow/domain/ai/` passes
---
### Task 3: Create Checkpointer Factory
**File**: `src/noteflow/infrastructure/ai/checkpointer.py`
```python
from __future__ import annotations
from typing import TYPE_CHECKING, Final
if TYPE_CHECKING:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool
CHECKPOINTER_POOL_SIZE: Final[int] = 5
async def create_checkpointer(
database_url: str,
pool_size: int = CHECKPOINTER_POOL_SIZE,
) -> AsyncPostgresSaver:
"""Create async Postgres checkpointer for LangGraph state persistence."""
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool
pool = AsyncConnectionPool(
conninfo=database_url,
max_size=pool_size,
kwargs={"autocommit": True},
)
checkpointer = AsyncPostgresSaver(pool)
await checkpointer.setup()
return checkpointer
```
**Verification**: Unit test with mock pool
---
### Task 4: Create Retrieval Tools
**File**: `src/noteflow/infrastructure/ai/tools/retrieval.py`
```python
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from uuid import UUID
if TYPE_CHECKING:
from noteflow.domain.entities import Segment
class EmbedderProtocol(Protocol):
"""Protocol for text embedding."""
async def embed(self, text: str) -> list[float]: ...
class SegmentSearchProtocol(Protocol):
"""Protocol for semantic segment search."""
async def search_semantic(
self,
query_embedding: list[float],
meeting_id: UUID | None,
limit: int,
) -> list[tuple[Segment, float]]: ...
@dataclass
class RetrievalResult:
"""Result from segment retrieval."""
segment_id: int
meeting_id: UUID
text: str
start_time: float
end_time: float
score: float
async def retrieve_segments(
query: str,
embedder: EmbedderProtocol,
segment_repo: SegmentSearchProtocol,
meeting_id: UUID | None = None,
top_k: int = 8,
) -> list[RetrievalResult]:
"""Retrieve relevant transcript segments via semantic search."""
query_embedding = await embedder.embed(query)
results = await segment_repo.search_semantic(
query_embedding=query_embedding,
meeting_id=meeting_id,
limit=top_k,
)
return [
RetrievalResult(
segment_id=segment.segment_id,
meeting_id=segment.meeting_id,
text=segment.text,
start_time=segment.start_time,
end_time=segment.end_time,
score=score,
)
for segment, score in results
]
```
**Verification**: Unit test with mock embedder and repo
---
### Task 5: Create Synthesis Tools
**File**: `src/noteflow/infrastructure/ai/tools/synthesis.py`
```python
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from noteflow.infrastructure.ai.tools.retrieval import RetrievalResult
class LLMProtocol(Protocol):
"""Protocol for LLM completion."""
async def complete(self, prompt: str) -> str: ...
@dataclass
class SynthesisResult:
"""Result from answer synthesis."""
answer: str
cited_segment_ids: list[int]
SYNTHESIS_PROMPT_TEMPLATE = '''Answer the question based on the following transcript segments.
Cite specific segments by their ID when making claims.
Question: {question}
Segments:
{segments}
Answer (cite segment IDs in brackets like [1], [3]):'''
async def synthesize_answer(
question: str,
segments: list[RetrievalResult],
llm: LLMProtocol,
) -> SynthesisResult:
"""Generate answer with segment citations."""
segment_text = "\n".join(
f"[{s.segment_id}] ({s.start_time:.1f}s-{s.end_time:.1f}s): {s.text}"
for s in segments
)
prompt = SYNTHESIS_PROMPT_TEMPLATE.format(
question=question,
segments=segment_text,
)
answer = await llm.complete(prompt)
cited_ids = _extract_cited_ids(answer, [s.segment_id for s in segments])
return SynthesisResult(answer=answer, cited_segment_ids=cited_ids)
def _extract_cited_ids(answer: str, valid_ids: list[int]) -> list[int]:
"""Extract segment IDs cited in the answer."""
import re
pattern = r'\[(\d+)\]'
matches = re.findall(pattern, answer)
cited = [int(m) for m in matches if int(m) in valid_ids]
return list(dict.fromkeys(cited)) # Dedupe preserving order
```
**Verification**: Unit test with mock LLM
---
### Task 6: Create Summarization Graph Wrapper
**File**: `src/noteflow/infrastructure/ai/graphs/summarization.py`
This wraps the existing SummarizationService in a LangGraph graph as a proof of the pattern.
```python
from __future__ import annotations
from typing import TYPE_CHECKING, TypedDict
from langgraph.graph import END, START, CompiledStateGraph, StateGraph
if TYPE_CHECKING:
from uuid import UUID
from collections.abc import Sequence
from noteflow.domain.entities import Segment
from noteflow.application.services.summarization import SummarizationService
class SummarizationState(TypedDict):
"""State for summarization graph."""
meeting_id: UUID
segments: Sequence[Segment]
summary_text: str
key_points: list[dict]
action_items: list[dict]
def build_summarization_graph(
summarization_service: SummarizationService,
) -> CompiledStateGraph:
"""Build LangGraph wrapper around existing summarization service."""
async def summarize_node(state: SummarizationState) -> dict:
result = await summarization_service.summarize(
meeting_id=state["meeting_id"],
segments=state["segments"],
)
summary = result.summary
return {
"summary_text": summary.executive_summary,
"key_points": [
{"text": kp.text, "segment_ids": kp.segment_ids}
for kp in summary.key_points
],
"action_items": [
{"text": ai.text, "segment_ids": ai.segment_ids, "assignee": ai.assignee}
for ai in summary.action_items
],
}
builder = StateGraph(SummarizationState)
builder.add_node("summarize", summarize_node)
builder.add_edge(START, "summarize")
builder.add_edge("summarize", END)
return builder.compile()
```
**Verification**: Integration test with mock summarization service
---
### Task 7: Create AssistantService Shell
**Files**:
- `src/noteflow/application/services/assistant/__init__.py`
- `src/noteflow/application/services/assistant/assistant_service.py`
#### `__init__.py`
```python
from noteflow.application.services.assistant.assistant_service import AssistantService
__all__ = ["AssistantService"]
```
#### `assistant_service.py`
```python
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Final
from uuid import UUID
from noteflow.application.observability.ports import NullUsageEventSink, UsageEventSink
from noteflow.domain.ai.state import AssistantOutputState
from noteflow.infrastructure.logging import get_logger
if TYPE_CHECKING:
from collections.abc import Callable
from noteflow.domain.ports.unit_of_work import UnitOfWork
logger = get_logger(__name__)
DEFAULT_TOP_K: Final[int] = 8
THREAD_ID_PREFIX: Final[str] = "meeting"
def build_thread_id(meeting_id: UUID | None, user_id: UUID, graph_name: str) -> str:
"""Build deterministic thread_id for checkpointing."""
meeting_part = str(meeting_id) if meeting_id else "workspace"
return f"{THREAD_ID_PREFIX}:{meeting_part}:user:{user_id}:graph:{graph_name}:v1"
@dataclass
class AssistantService:
"""Orchestrates AI assistant workflows via LangGraph."""
uow_factory: Callable[[], UnitOfWork]
usage_events: UsageEventSink = field(default_factory=NullUsageEventSink)
async def ask(
self,
question: str,
user_id: UUID,
meeting_id: UUID | None = None,
thread_id: str | None = None,
allow_web: bool = False,
top_k: int = DEFAULT_TOP_K,
) -> AssistantOutputState:
"""Ask a question about meeting transcript(s)."""
effective_thread_id = thread_id or build_thread_id(
meeting_id, user_id, "meeting_qa"
)
logger.info(
"assistant_ask",
question_length=len(question),
meeting_id=str(meeting_id) if meeting_id else None,
thread_id=effective_thread_id,
)
# TODO: Implement in Sprint 26
return AssistantOutputState(
answer="Not implemented yet",
citations=[],
suggested_annotations=[],
thread_id=effective_thread_id,
)
```
**Verification**: Unit test for thread_id generation
---
## Test Plan
### Unit Tests (`tests/domain/ai/`)
```python
# test_citations.py
import pytest
from uuid import uuid4

from noteflow.domain.ai.citations import SegmentCitation
def test_segment_citation_creation():
citation = SegmentCitation(
meeting_id=uuid4(),
segment_id=1,
start_time=0.0,
end_time=5.0,
text="Test segment",
score=0.95,
)
assert citation.duration == 5.0
def test_segment_citation_invalid_times():
with pytest.raises(ValueError):
SegmentCitation(
meeting_id=uuid4(),
segment_id=1,
start_time=10.0,
end_time=5.0, # Invalid: end < start
text="Test",
)
```
### Unit Tests (`tests/infrastructure/ai/`)
```python
# test_retrieval.py
async def test_retrieve_segments_success(mock_embedder, mock_segment_repo):
mock_embedder.embed.return_value = [0.1, 0.2, 0.3]
mock_segment_repo.search_semantic.return_value = [
(sample_segment, 0.95),
]
results = await retrieve_segments(
query="test query",
embedder=mock_embedder,
segment_repo=mock_segment_repo,
top_k=5,
)
assert len(results) == 1
assert results[0].score == 0.95
# test_synthesis.py
async def test_synthesize_answer_extracts_citations(mock_llm):
mock_llm.complete.return_value = "The answer is X [1] and Y [3]."
result = await synthesize_answer(
question="What happened?",
segments=[...],
llm=mock_llm,
)
assert result.cited_segment_ids == [1, 3]
```
---
## Acceptance Criteria
- [ ] `uv pip install -e ".[langgraph]"` succeeds
- [ ] `basedpyright src/noteflow/domain/ai/` passes with 0 errors
- [ ] `basedpyright src/noteflow/infrastructure/ai/` passes with 0 errors
- [ ] `pytest tests/domain/ai/` passes
- [ ] `pytest tests/infrastructure/ai/` passes
- [ ] Existing summarization behavior unchanged (`pytest tests/application/services/test_summarization*`)
- [ ] `make quality` passes
---
## Rollback Plan
If issues arise:
1. Remove langgraph from dependencies
2. Delete `domain/ai/` and `infrastructure/ai/` directories
3. AssistantService is not wired to gRPC yet, so no API impact
---
## Files Created/Modified
| Action | Path |
|--------|------|
| Modified | `pyproject.toml` |
| Created | `src/noteflow/domain/ai/__init__.py` |
| Created | `src/noteflow/domain/ai/state.py` |
| Created | `src/noteflow/domain/ai/citations.py` |
| Created | `src/noteflow/domain/ai/ports.py` |
| Created | `src/noteflow/infrastructure/ai/__init__.py` |
| Created | `src/noteflow/infrastructure/ai/checkpointer.py` |
| Created | `src/noteflow/infrastructure/ai/tools/__init__.py` |
| Created | `src/noteflow/infrastructure/ai/tools/retrieval.py` |
| Created | `src/noteflow/infrastructure/ai/tools/synthesis.py` |
| Created | `src/noteflow/infrastructure/ai/graphs/__init__.py` |
| Created | `src/noteflow/infrastructure/ai/graphs/summarization.py` |
| Created | `src/noteflow/application/services/assistant/__init__.py` |
| Created | `src/noteflow/application/services/assistant/assistant_service.py` |
| Created | `tests/domain/ai/test_citations.py` |
| Created | `tests/domain/ai/test_state.py` |
| Created | `tests/infrastructure/ai/test_retrieval.py` |
| Created | `tests/infrastructure/ai/test_synthesis.py` |
| Created | `tests/infrastructure/ai/test_checkpointer.py` |

View File

@@ -0,0 +1,530 @@
# Sprint 26: Meeting Q&A MVP
> **Size**: L | **Owner**: Backend + Client | **Phase**: 5 - Platform Evolution
> **Effort**: ~1 sprint | **Prerequisites**: Sprint 25 (Foundation)
---
## Objective
Implement single-meeting Q&A with segment citations via gRPC API and React UI.
---
## Current State (After Sprint 25)
| Component | Status |
|-----------|--------|
| LangGraph infrastructure | ✅ Ready |
| State schemas | ✅ Ready |
| Retrieval tools | ✅ Ready |
| Synthesis tools | ✅ Ready |
| AssistantService shell | ✅ Ready |
---
## Implementation Tasks
### Task 1: Define MeetingQA Graph
**File**: `src/noteflow/infrastructure/ai/graphs/meeting_qa.py`
Graph flow: `retrieve → verify → synthesize`
```python
from __future__ import annotations

from typing import TypedDict
from uuid import UUID

from langgraph.graph import END, START, CompiledStateGraph, StateGraph

from noteflow.domain.ai.citations import SegmentCitation
from noteflow.infrastructure.ai.tools.retrieval import (
    EmbedderProtocol,
    RetrievalResult,
    SegmentSearchProtocol,
    retrieve_segments,
)
from noteflow.infrastructure.ai.tools.synthesis import LLMProtocol, synthesize_answer
class MeetingQAState(TypedDict):
# Input
question: str
meeting_id: UUID
top_k: int
# Internal
retrieved_segments: list[RetrievalResult]
verification_passed: bool
# Output
answer: str
citations: list[SegmentCitation]
def build_meeting_qa_graph(
embedder: EmbedderProtocol,
segment_repo: SegmentSearchProtocol,
llm: LLMProtocol,
verifier: CitationVerifier,
) -> CompiledStateGraph:
async def retrieve_node(state: MeetingQAState) -> dict:
results = await retrieve_segments(
query=state["question"],
embedder=embedder,
segment_repo=segment_repo,
meeting_id=state["meeting_id"],
top_k=state["top_k"],
)
return {"retrieved_segments": results}
async def verify_node(state: MeetingQAState) -> dict:
# Verify segments exist and are relevant
valid = len(state["retrieved_segments"]) > 0
return {"verification_passed": valid}
async def synthesize_node(state: MeetingQAState) -> dict:
if not state["verification_passed"]:
return {
"answer": "I couldn't find relevant information in this meeting.",
"citations": [],
}
result = await synthesize_answer(
question=state["question"],
segments=state["retrieved_segments"],
llm=llm,
)
citations = [
SegmentCitation(
meeting_id=state["meeting_id"],
segment_id=seg.segment_id,
start_time=seg.start_time,
end_time=seg.end_time,
text=seg.text,
score=seg.score,
)
for seg in state["retrieved_segments"]
if seg.segment_id in result.cited_segment_ids
]
return {"answer": result.answer, "citations": citations}
builder = StateGraph(MeetingQAState)
builder.add_node("retrieve", retrieve_node)
builder.add_node("verify", verify_node)
builder.add_node("synthesize", synthesize_node)
builder.add_edge(START, "retrieve")
builder.add_edge("retrieve", "verify")
builder.add_edge("verify", "synthesize")
builder.add_edge("synthesize", END)
return builder.compile()
```
---
### Task 2: Create Citation Verifier Node
**File**: `src/noteflow/infrastructure/ai/nodes/verification.py`
```python
from dataclasses import dataclass
@dataclass
class VerificationResult:
is_valid: bool
invalid_citation_indices: list[int]
reason: str | None = None
def verify_citations(
answer: str,
cited_ids: list[int],
available_ids: set[int],
) -> VerificationResult:
"""Verify all cited segment IDs exist in available segments."""
invalid = [i for i, cid in enumerate(cited_ids) if cid not in available_ids]
return VerificationResult(
is_valid=len(invalid) == 0,
invalid_citation_indices=invalid,
reason=f"Invalid citations: {invalid}" if invalid else None,
)
```
---
### Task 3: Add Proto Messages
**File**: `src/noteflow/grpc/proto/noteflow.proto`
```protobuf
// Add to existing proto
message SegmentCitation {
string meeting_id = 1;
int32 segment_id = 2;
float start_time = 3;
float end_time = 4;
string text = 5;
float score = 6;
}
message AskAssistantRequest {
string question = 1;
optional string meeting_id = 2;
optional string thread_id = 3;
bool allow_web = 4;
int32 top_k = 5;
}
message AskAssistantResponse {
string answer = 1;
repeated SegmentCitation citations = 2;
repeated SuggestedAnnotation suggested_annotations = 3;
string thread_id = 4;
}
message SuggestedAnnotation {
string text = 1;
AnnotationType type = 2;
repeated int32 segment_ids = 3;
}
// Add to NoteFlowService
rpc AskAssistant(AskAssistantRequest) returns (AskAssistantResponse);
```
After modifying proto:
```bash
python -m grpc_tools.protoc -I src/noteflow/grpc/proto \
--python_out=src/noteflow/grpc/proto \
--grpc_python_out=src/noteflow/grpc/proto \
src/noteflow/grpc/proto/noteflow.proto
python scripts/patch_grpc_stubs.py
```
---
### Task 4: Add gRPC Mixin
**File**: `src/noteflow/grpc/_mixins/assistant.py`
```python
from __future__ import annotations
from typing import TYPE_CHECKING
from noteflow.grpc.proto import noteflow_pb2 as pb
from noteflow.grpc._mixins.protocols import ServicerHost
if TYPE_CHECKING:
from grpc.aio import ServicerContext
class AssistantMixin:
"""gRPC mixin for AI assistant operations."""
async def AskAssistant(
self: ServicerHost,
request: pb.AskAssistantRequest,
context: ServicerContext,
) -> pb.AskAssistantResponse:
from uuid import UUID
meeting_id = UUID(request.meeting_id) if request.meeting_id else None
op_context = await self.get_operation_context(context)
result = await self.assistant_service.ask(
question=request.question,
user_id=op_context.user.id,
meeting_id=meeting_id,
thread_id=request.thread_id or None,
allow_web=request.allow_web,
top_k=request.top_k or 8,
)
return pb.AskAssistantResponse(
answer=result["answer"],
citations=[
pb.SegmentCitation(
meeting_id=str(c["meeting_id"]),
segment_id=c["segment_id"],
start_time=c["start_time"],
end_time=c["end_time"],
text=c["text"],
score=c["score"],
)
for c in result["citations"]
],
thread_id=result["thread_id"],
)
```
---
### Task 5: Add Rust Command
**File**: `client/src-tauri/src/commands/assistant.rs`
```rust
use crate::grpc::client::GrpcClient;
use crate::grpc::types::assistant::{AskAssistantRequest, AskAssistantResponse};
use tauri::State;
#[tauri::command]
pub async fn ask_assistant(
client: State<'_, GrpcClient>,
question: String,
meeting_id: Option<String>,
thread_id: Option<String>,
allow_web: bool,
top_k: i32,
) -> Result<AskAssistantResponse, String> {
let request = AskAssistantRequest {
question,
meeting_id,
thread_id,
allow_web,
top_k,
};
client
.ask_assistant(request)
.await
.map_err(|e| e.to_string())
}
```
---
### Task 6: Add TypeScript Adapter
**File**: `client/src/api/tauri-adapter.ts` (add method)
```typescript
async askAssistant(params: AskAssistantParams): Promise<AskAssistantResponse> {
return invoke('ask_assistant', {
question: params.question,
meetingId: params.meetingId,
threadId: params.threadId,
allowWeb: params.allowWeb ?? false,
topK: params.topK ?? 8,
});
}
```
**File**: `client/src/api/types/assistant.ts`
```typescript
export interface SegmentCitation {
meetingId: string;
segmentId: number;
startTime: number;
endTime: number;
text: string;
score: number;
}
export interface AskAssistantParams {
question: string;
meetingId?: string;
threadId?: string;
allowWeb?: boolean;
topK?: number;
}
export interface AskAssistantResponse {
answer: string;
citations: SegmentCitation[];
suggestedAnnotations: SuggestedAnnotation[];
threadId: string;
}
export interface SuggestedAnnotation {
text: string;
type: AnnotationType;
segmentIds: number[];
}
```
---
### Task 7: Create Ask UI Component
**File**: `client/src/components/meeting/AskPanel.tsx`
```tsx
import { useState } from 'react';
import { useAssistant } from '@/hooks/use-assistant';
import { Button } from '@/components/ui/button';
import { Textarea } from '@/components/ui/textarea';
import { Card } from '@/components/ui/card';
interface AskPanelProps {
meetingId: string;
onCitationClick?: (segmentId: number) => void;
}
export function AskPanel({ meetingId, onCitationClick }: AskPanelProps) {
const [question, setQuestion] = useState('');
const { ask, isLoading, response, error } = useAssistant();
const handleAsk = async () => {
if (!question.trim()) return;
await ask({ question, meetingId });
};
return (
<div className="flex flex-col gap-4 p-4">
<Textarea
placeholder="Ask a question about this meeting..."
value={question}
onChange={(e) => setQuestion(e.target.value)}
disabled={isLoading}
/>
<Button onClick={handleAsk} disabled={isLoading || !question.trim()}>
{isLoading ? 'Thinking...' : 'Ask'}
</Button>
{response && (
<Card className="p-4">
<p className="whitespace-pre-wrap">{response.answer}</p>
{response.citations.length > 0 && (
<div className="mt-4 border-t pt-2">
<p className="text-sm text-muted-foreground mb-2">Sources:</p>
{response.citations.map((citation) => (
<button
key={citation.segmentId}
onClick={() => onCitationClick?.(citation.segmentId)}
className="text-sm text-blue-600 hover:underline block"
>
[{citation.startTime.toFixed(1)}s] {citation.text.slice(0, 50)}...
</button>
))}
</div>
)}
</Card>
)}
{error && (
<p className="text-sm text-red-600">{error}</p>
)}
</div>
);
}
```
**File**: `client/src/hooks/use-assistant.ts`
```typescript
import { useState, useCallback } from 'react';
import { api } from '@/api';
import type { AskAssistantParams, AskAssistantResponse } from '@/api/types/assistant';
export function useAssistant() {
const [isLoading, setIsLoading] = useState(false);
const [response, setResponse] = useState<AskAssistantResponse | null>(null);
const [error, setError] = useState<string | null>(null);
const ask = useCallback(async (params: AskAssistantParams) => {
setIsLoading(true);
setError(null);
try {
const result = await api.askAssistant(params);
setResponse(result);
return result;
} catch (err) {
setError(err instanceof Error ? err.message : 'Failed to get answer');
throw err;
} finally {
setIsLoading(false);
}
}, []);
const reset = useCallback(() => {
setResponse(null);
setError(null);
}, []);
return { ask, isLoading, response, error, reset };
}
```
---
### Task 8: Implement AssistantService.ask()
Complete the implementation in `application/services/assistant/assistant_service.py`:
```python
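# Sketch of the completed method. Assumes AssistantService has gained
# `_embedder`, `_llm`, and `_verifier` fields (not in the Sprint 25 shell)
# plus imports for `asdict` (dataclasses) and `build_meeting_qa_graph`.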
async def ask(
self,
question: str,
user_id: UUID,
meeting_id: UUID | None = None,
thread_id: str | None = None,
allow_web: bool = False,
top_k: int = DEFAULT_TOP_K,
) -> AssistantOutputState:
"""Ask a question about meeting transcript(s)."""
effective_thread_id = thread_id or build_thread_id(
meeting_id, user_id, "meeting_qa"
)
async with self.uow_factory() as uow:
# Build and run graph
graph = build_meeting_qa_graph(
embedder=self._embedder,
segment_repo=uow.segments,
llm=self._llm,
verifier=self._verifier,
)
config = {"configurable": {"thread_id": effective_thread_id}}
result = await graph.ainvoke(
{
"question": question,
"meeting_id": meeting_id,
"top_k": top_k,
},
config,
)
# Record usage
self.usage_events.record_simple(
"assistant.ask",
meeting_id=str(meeting_id) if meeting_id else None,
question_length=len(question),
citation_count=len(result.get("citations", [])),
)
return AssistantOutputState(
answer=result["answer"],
citations=[asdict(c) for c in result["citations"]],
suggested_annotations=[],
thread_id=effective_thread_id,
)
```
---
## Acceptance Criteria
- [ ] Q&A returns answers with valid segment citations
- [ ] Citations link to correct timestamps in transcript
- [ ] Feature hidden when `rag_enabled=false` in project rules
- [ ] Thread ID persists conversation context
- [ ] `make quality` passes
- [ ] `pytest tests/grpc/test_assistant.py` passes
- [ ] UI component displays answer and clickable citations
---
## Files Created/Modified
| Action | Path |
|--------|------|
| Created | `src/noteflow/infrastructure/ai/graphs/meeting_qa.py` |
| Created | `src/noteflow/infrastructure/ai/nodes/verification.py` |
| Modified | `src/noteflow/grpc/proto/noteflow.proto` |
| Created | `src/noteflow/grpc/_mixins/assistant.py` |
| Modified | `src/noteflow/grpc/service.py` (add mixin) |
| Created | `client/src-tauri/src/commands/assistant.rs` |
| Modified | `client/src/api/tauri-adapter.ts` |
| Created | `client/src/api/types/assistant.ts` |
| Created | `client/src/components/meeting/AskPanel.tsx` |
| Created | `client/src/hooks/use-assistant.ts` |
| Modified | `src/noteflow/application/services/assistant/assistant_service.py` |

View File

@@ -0,0 +1,87 @@
# Sprint 27: Cross-Meeting RAG
> **Size**: M | **Owner**: Backend + Client | **Phase**: 5 - Platform Evolution
> **Effort**: ~1 sprint | **Prerequisites**: Sprint 26 (Meeting Q&A MVP)
---
## Objective
Enable workspace-scoped Q&A and annotation suggestions across multiple meetings.
---
## Key Tasks
### Task 1: Workspace-Scoped Semantic Search
Extend `SegmentRepository` with workspace-scoped search:
```python
async def search_semantic_workspace(
self,
query_embedding: list[float],
workspace_id: UUID,
project_id: UUID | None = None,
limit: int = 20,
) -> list[tuple[Segment, float]]:
"""Search segments across all meetings in workspace/project."""
```
### Task 2: WorkspaceQA Graph
Create `infrastructure/ai/graphs/workspace_qa.py`:
- Similar to MeetingQA but omits `meeting_id` filter
- Groups results by meeting for citation display
- Returns cross-meeting citations (see the grouping sketch below)
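A minimal sketch of the grouping step, assuming `SegmentCitation` from `domain/ai/citations.py`; the helper name is illustrative:
```python
from collections import defaultdict
from uuid import UUID

from noteflow.domain.ai.citations import SegmentCitation


def group_citations_by_meeting(
    citations: list[SegmentCitation],
) -> dict[UUID, list[SegmentCitation]]:
    """Group cross-meeting citations so the UI can render one block per meeting."""
    grouped: dict[UUID, list[SegmentCitation]] = defaultdict(list)
    for citation in citations:
        grouped[citation.meeting_id].append(citation)
    # Sort each meeting's citations by transcript position for stable display.
    return {
        meeting_id: sorted(group, key=lambda c: c.start_time)
        for meeting_id, group in grouped.items()
    }
```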
### Task 3: Annotation Suggester
Add annotation suggestion output to graph:
```python
from dataclasses import dataclass

@dataclass
class SuggestedAnnotation:
text: str
annotation_type: AnnotationType
segment_ids: list[int]
confidence: float
```
### Task 4: Conversation History
Implement thread persistence with the checkpointer (see the sketch after this list):
- Store conversation turns in graph state
- Support follow-up questions
- Maintain context across requests
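A wiring sketch, assuming the graph was compiled with the Sprint 25 checkpointer; `build_thread_id` comes from the assistant service, and the question strings are illustrative:
```python
from noteflow.application.services.assistant.assistant_service import build_thread_id


async def ask_with_history(graph, meeting_id, user_id) -> None:
    """Two turns on one thread_id: the checkpointer restores state between them."""
    config = {"configurable": {"thread_id": build_thread_id(meeting_id, user_id, "workspace_qa")}}
    # The first turn is checkpointed under the thread_id after the run.
    first = await graph.ainvoke({"question": "What did we decide on pricing?", "top_k": 8}, config)
    # Same config on the follow-up, so prior turns are restored and the graph
    # can resolve references like "that decision".
    followup = await graph.ainvoke({"question": "Who owns following up on that?", "top_k": 8}, config)
    print(first["answer"], followup["answer"])
```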
### Task 5: Apply Annotation Flow
UI flow to apply suggested annotations (a backend-side sketch follows the list):
1. Display suggested annotations in AskPanel
2. User clicks "Apply" on suggestion
3. Call existing `AddAnnotation` RPC
4. Update UI to show applied status
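On the backend, applying a suggestion can reuse the existing annotation path. A sketch with a hypothetical `AnnotationWriterProtocol` seam over the existing `AddAnnotation` RPC:
```python
from __future__ import annotations

from typing import Protocol
from uuid import UUID


class AnnotationWriterProtocol(Protocol):  # hypothetical seam over the AddAnnotation RPC
    async def add_annotation(
        self,
        meeting_id: UUID,
        text: str,
        annotation_type: str,
        segment_ids: list[int],
    ) -> None: ...


async def apply_suggestion(
    writer: AnnotationWriterProtocol,
    meeting_id: UUID,
    suggestion: SuggestedAnnotation,  # from Task 3
) -> None:
    """Persist an accepted suggestion through the existing annotation flow."""
    await writer.add_annotation(
        meeting_id=meeting_id,
        text=suggestion.text,
        annotation_type=str(suggestion.annotation_type),
        segment_ids=suggestion.segment_ids,
    )
```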
---
## Acceptance Criteria
- [ ] Cross-meeting queries return results from multiple meetings
- [ ] Suggested annotations can be applied with one click
- [ ] Conversation history persists across requests
- [ ] Follow-up questions reference previous context
- [ ] `make quality` passes
---
## Files Created/Modified
| Action | Path |
|--------|------|
| Modified | `src/noteflow/infrastructure/persistence/repositories/segment_repo.py` |
| Created | `src/noteflow/infrastructure/ai/graphs/workspace_qa.py` |
| Modified | `src/noteflow/application/services/assistant/assistant_service.py` |
| Modified | `client/src/components/meeting/AskPanel.tsx` |

View File

@@ -0,0 +1,146 @@
# Sprint 28: Advanced Capabilities
> **Size**: L | **Owner**: Backend + Client | **Phase**: 5 - Platform Evolution
> **Effort**: ~1 sprint | **Prerequisites**: Sprint 27 (Cross-Meeting RAG)
---
## Objective
Production hardening with streaming responses, caching, guardrails, and optional web search.
---
## Key Tasks
### Task 1: Streaming Responses
Implement `StreamAssistant` RPC for progressive answer generation:
```protobuf
rpc StreamAssistant(AskAssistantRequest) returns (stream AskAssistantResponse);
```
Use LangGraph's `astream` with `stream_mode="messages"`:
```python
# stream_mode="messages" yields (message_chunk, metadata) tuples
async for chunk, _metadata in graph.astream(state, config, stream_mode="messages"):
    yield pb.AskAssistantResponse(answer=chunk.content, partial=True)
    # assumes a new `bool partial = 5;` field on AskAssistantResponse
```
### Task 2: Embedding Cache
Create `infrastructure/ai/cache.py`:
- LRU cache for query embeddings
- Reduce redundant embedding calls
- Configurable TTL and max size
```python
@dataclass
class EmbeddingCache:
max_size: int = 1000
ttl_seconds: int = 3600
async def get_or_compute(
self,
text: str,
embedder: EmbedderProtocol,
) -> list[float]:
...
```
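Intended usage, based on the `CachedEmbedder` wrapper exported from `infrastructure/ai/__init__.py` (`base_embedder` is any `EmbedderProtocol` implementation):
```python
from noteflow.infrastructure.ai import CachedEmbedder

cached = CachedEmbedder(base_embedder, max_size=500, ttl_seconds=1800)
vec = await cached.embed("quarterly pricing decision")        # computed once
vec_again = await cached.embed("quarterly pricing decision")  # served from cache
print(cached.cache.get_stats().hit_rate)
```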
### Task 3: Content Guardrails
Create `infrastructure/ai/guardrails.py`:
- Input validation (length, content)
- Output filtering (PII, harmful content)
- Configurable rules per workspace
```python
from dataclasses import dataclass

@dataclass
class GuardrailResult:
allowed: bool
reason: str | None
filtered_content: str | None
async def check_input(text: str, rules: GuardrailRules) -> GuardrailResult:
...
async def filter_output(text: str, rules: GuardrailRules) -> GuardrailResult:
...
```
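A usage sketch built on the factories exported from `infrastructure/ai/__init__.py`; `question` and `raw_answer` are assumed local variables, and the fallback semantics of `filtered_content` are an assumption:
```python
from noteflow.infrastructure.ai import check_input, create_default_rules, filter_output

rules = create_default_rules()

checked = await check_input(question, rules)
if not checked.allowed:
    raise ValueError(checked.reason or "input blocked by guardrails")

filtered = await filter_output(raw_answer, rules)
answer = filtered.filtered_content if filtered.filtered_content is not None else raw_answer
```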
### Task 4: Web Search Node
Create `infrastructure/ai/nodes/web_search.py`:
- Optional node triggered by `allow_web=true`
- Integrate with web search API
- Merge web results with transcript evidence (see the sketch below)
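A minimal sketch of the provider seam, reusing the `WebSearchProvider` and `format_results_for_context` names the meeting-QA graph already imports; the result shape is an assumption:
```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol


@dataclass(frozen=True)
class WebSearchResult:  # assumed result shape
    title: str
    url: str
    snippet: str


class WebSearchProvider(Protocol):
    """Pluggable search backend (Tavily, SearxNG, ...)."""

    async def search(self, query: str, max_results: int) -> list[WebSearchResult]: ...


def format_results_for_context(results: list[WebSearchResult]) -> str:
    """Render web results as an extra evidence block for the synthesis prompt."""
    return "\n".join(f"[web] {r.title} ({r.url}): {r.snippet}" for r in results)
```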
### Task 5: AGENT_PROGRESS Tauri Event
Emit progress events for UI feedback:
```rust
// Emit during graph execution
app.emit_all("AGENT_PROGRESS", AgentProgressPayload {
stage: "retrieving",
progress: 0.3,
message: "Searching transcript...",
})?;
```
### Task 6: Interrupts for Approval
Implement LangGraph interrupts (see the sketch after this list) for:
- Web search approval
- Annotation creation approval
- Sensitive action confirmation
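A sketch of the approval gate, assuming a langgraph version that ships `langgraph.types.interrupt` and that the client resumes with `Command(resume={"action": ...})`; it reuses the domain types from `domain/ai/interrupts.py`:
```python
from langgraph.types import interrupt

from noteflow.domain.ai.interrupts import (
    InterruptAction,
    InterruptResponse,
    create_web_search_interrupt,
)


def web_search_approval_node(state: dict) -> dict:
    """Pause the run until the client approves or rejects the web search."""
    request = create_web_search_interrupt(query=state["question"], request_id="ws-1")
    # interrupt() suspends execution; the payload surfaces to the client, and the
    # value passed back via Command(resume=...) becomes interrupt()'s return value.
    raw = interrupt(request.to_dict())
    response = InterruptResponse(action=InterruptAction(raw["action"]))
    return {"web_search_approved": response.is_approved}
```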
### Task 7: Performance Optimization
- Batch embedding requests (sketch below)
- Parallel segment retrieval
- Connection pool tuning
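For the batching item, a sketch that fans embedding calls out concurrently; `EmbedderProtocol` is the existing retrieval protocol, so this composes with `EmbeddingCache` deduplication:
```python
import asyncio

from noteflow.infrastructure.ai.tools.retrieval import EmbedderProtocol


async def embed_batch(texts: list[str], embedder: EmbedderProtocol) -> list[list[float]]:
    """Embed many texts concurrently; pairs well with EmbeddingCache deduplication."""
    return list(await asyncio.gather(*(embedder.embed(text) for text in texts)))
```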
---
## Acceptance Criteria
- [ ] Streaming shows progressive answer generation in UI
- [ ] Cache reduces embedding latency by >50% for repeated queries
- [ ] Guardrails block inappropriate content
- [ ] Web search gated by `allow_web` flag
- [ ] Progress events update UI during long operations
- [ ] `make quality` passes
- [ ] E2E tests pass (`client/e2e/assistant.spec.ts`)
---
## Success Metrics
| Metric | Target |
| ------------------ | ------ |
| Q&A latency (p95) | < 3s |
| Citation accuracy | > 90% |
| Cache hit rate | > 60% |
| Hallucination rate | < 5% |
---
## Files Created/Modified
| Action | Path |
| -------- | ---------------------------------------------------- |
| Modified | `src/noteflow/grpc/proto/noteflow.proto` |
| Created | `src/noteflow/grpc/_mixins/assistant_streaming.py` |
| Created | `src/noteflow/infrastructure/ai/cache.py` |
| Created | `src/noteflow/infrastructure/ai/guardrails.py` |
| Created | `src/noteflow/infrastructure/ai/nodes/web_search.py` |
| Modified | `client/src-tauri/src/commands/assistant.rs` |
| Modified | `client/src/components/meeting/AskPanel.tsx` |
| Created | `client/e2e/assistant.spec.ts` |

View File

@@ -0,0 +1,38 @@
"""AI domain types for LangGraph workflows.
State schemas, citations, interrupts, and protocols for AI assistant functionality.
"""
from noteflow.domain.ai.citations import SegmentCitation
from noteflow.domain.ai.interrupts import (
InterruptAction,
InterruptConfig,
InterruptRequest,
InterruptResponse,
InterruptType,
create_annotation_interrupt,
create_sensitive_action_interrupt,
create_web_search_interrupt,
)
from noteflow.domain.ai.ports import AssistantPort
from noteflow.domain.ai.state import (
AssistantInputState,
AssistantInternalState,
AssistantOutputState,
)
__all__ = [
"AssistantInputState",
"AssistantInternalState",
"AssistantOutputState",
"AssistantPort",
"InterruptAction",
"InterruptConfig",
"InterruptRequest",
"InterruptResponse",
"InterruptType",
"SegmentCitation",
"create_annotation_interrupt",
"create_sensitive_action_interrupt",
"create_web_search_interrupt",
]

View File

@@ -0,0 +1,43 @@
"""Citation value objects for AI-generated responses."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from uuid import UUID
@dataclass(frozen=True)
class SegmentCitation:
"""Reference to a transcript segment used as evidence.
Links AI-generated claims to source transcript segments for verification.
"""
meeting_id: UUID
segment_id: int
start_time: float
end_time: float
text: str
score: float = 0.0
@property
def duration(self) -> float:
"""Duration of the cited segment in seconds."""
return self.end_time - self.start_time
def __post_init__(self) -> None:
if self.segment_id < 0:
msg = "segment_id must be non-negative"
raise ValueError(msg)
if self.start_time < 0:
msg = "start_time must be non-negative"
raise ValueError(msg)
if self.end_time < self.start_time:
msg = "end_time must be >= start_time"
raise ValueError(msg)
if self.score < 0 or self.score > 1:
msg = "score must be between 0 and 1"
raise ValueError(msg)

View File

@@ -0,0 +1,202 @@
"""Domain types for LangGraph human-in-the-loop interrupts.
Defines interrupt request/response types for approval workflows:
- Web search approval
- Annotation creation approval
- Sensitive action confirmation
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Final
class InterruptType(StrEnum):
"""Types of human-in-the-loop interrupts."""
WEB_SEARCH_APPROVAL = "web_search_approval"
ANNOTATION_APPROVAL = "annotation_approval"
SENSITIVE_ACTION = "sensitive_action"
class InterruptAction(StrEnum):
"""Possible user actions for an interrupt."""
APPROVE = "approve"
REJECT = "reject"
MODIFY = "modify"
DEFAULT_WEB_SEARCH_OPTIONS: Final[tuple[str, ...]] = ("approve", "reject")
DEFAULT_ANNOTATION_OPTIONS: Final[tuple[str, ...]] = ("approve", "reject", "modify")
DEFAULT_SENSITIVE_OPTIONS: Final[tuple[str, ...]] = ("approve", "reject")
@dataclass(frozen=True)
class InterruptConfig:
"""Configuration for interrupt behavior."""
allow_ignore: bool = False
allow_modify: bool = False
timeout_seconds: float | None = None
@dataclass(frozen=True)
class InterruptRequest:
"""A request for human approval during graph execution.
Sent to the client when the graph hits an interrupt point.
Attributes:
interrupt_type: Category of interrupt (web_search, annotation, etc.)
message: Human-readable description of what needs approval.
context: Additional context data for the decision (query, entities, etc.)
options: Available response options.
config: Interrupt configuration (allow_ignore, timeout, etc.)
request_id: Unique identifier for this interrupt request.
"""
interrupt_type: InterruptType
message: str
context: dict[str, object] = field(default_factory=dict)
options: tuple[str, ...] = field(default_factory=lambda: ("approve", "reject"))
config: InterruptConfig = field(default_factory=InterruptConfig)
request_id: str = ""
def to_dict(self) -> dict[str, object]:
"""Convert to dictionary for serialization."""
return {
"interrupt_type": self.interrupt_type,
"message": self.message,
"context": self.context,
"options": list(self.options),
"config": {
"allow_ignore": self.config.allow_ignore,
"allow_modify": self.config.allow_modify,
"timeout_seconds": self.config.timeout_seconds,
},
"request_id": self.request_id,
}
@dataclass(frozen=True)
class InterruptResponse:
"""User's response to an interrupt request.
Returned from the client to resume graph execution.
Attributes:
action: The action taken (approve, reject, modify).
request_id: ID of the interrupt request being responded to.
modified_value: If action is MODIFY, the modified value.
user_message: Optional message from the user.
"""
action: InterruptAction
request_id: str = ""
modified_value: dict[str, object] | None = None
user_message: str | None = None
@property
def is_approved(self) -> bool:
"""Check if the action was approved."""
return self.action == InterruptAction.APPROVE
@property
def is_rejected(self) -> bool:
"""Check if the action was rejected."""
return self.action == InterruptAction.REJECT
@property
def is_modified(self) -> bool:
"""Check if the action was modified."""
return self.action == InterruptAction.MODIFY
def to_dict(self) -> dict[str, object]:
"""Convert to dictionary for serialization."""
result: dict[str, object] = {
"action": self.action,
"request_id": self.request_id,
}
if self.modified_value is not None:
result["modified_value"] = self.modified_value
if self.user_message is not None:
result["user_message"] = self.user_message
return result
def create_web_search_interrupt(
query: str,
request_id: str,
*,
allow_modify: bool = False,
) -> InterruptRequest:
"""Create an interrupt request for web search approval.
Args:
query: The search query to be executed.
request_id: Unique identifier for this request.
allow_modify: Whether the user can modify the query.
Returns:
InterruptRequest configured for web search approval.
"""
return InterruptRequest(
interrupt_type=InterruptType.WEB_SEARCH_APPROVAL,
message=f"Allow web search for additional context? Query: {query[:100]}",
context={"query": query},
options=("approve", "reject", "modify") if allow_modify else DEFAULT_WEB_SEARCH_OPTIONS,
config=InterruptConfig(allow_modify=allow_modify),
request_id=request_id,
)
def create_annotation_interrupt(
annotations: list[dict[str, object]],
request_id: str,
) -> InterruptRequest:
"""Create an interrupt request for annotation approval.
Args:
annotations: List of suggested annotations to approve.
request_id: Unique identifier for this request.
Returns:
InterruptRequest configured for annotation approval.
"""
count = len(annotations)
return InterruptRequest(
interrupt_type=InterruptType.ANNOTATION_APPROVAL,
message=f"Apply {count} suggested annotation(s)?",
context={"annotations": annotations, "count": count},
options=DEFAULT_ANNOTATION_OPTIONS,
config=InterruptConfig(allow_modify=True, allow_ignore=True),
request_id=request_id,
)
def create_sensitive_action_interrupt(
action_name: str,
action_description: str,
request_id: str,
) -> InterruptRequest:
"""Create an interrupt request for sensitive action confirmation.
Args:
action_name: Name of the sensitive action.
action_description: Description of what the action does.
request_id: Unique identifier for this request.
Returns:
InterruptRequest configured for sensitive action confirmation.
"""
return InterruptRequest(
interrupt_type=InterruptType.SENSITIVE_ACTION,
message=f"Confirm action: {action_name}",
context={"action_name": action_name, "description": action_description},
options=DEFAULT_SENSITIVE_OPTIONS,
config=InterruptConfig(allow_ignore=False),
request_id=request_id,
)

View File

@@ -0,0 +1,26 @@
"""Protocol definitions for AI assistant operations."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from uuid import UUID
from noteflow.domain.ai.state import AssistantOutputState
class AssistantPort(Protocol):
"""Protocol for AI assistant operations."""
async def ask(
self,
question: str,
user_id: UUID,
meeting_id: UUID | None = None,
thread_id: str | None = None,
allow_web: bool = False,
top_k: int = 8,
) -> AssistantOutputState:
"""Ask a question about meeting transcript(s)."""
...

View File

@@ -0,0 +1,39 @@
"""State schemas for LangGraph AI workflows.
Separates Input/Output (public API) from Internal (can evolve freely).
"""
from __future__ import annotations
import operator
from typing import Annotated, TypedDict
from uuid import UUID
class AssistantInputState(TypedDict):
"""Public API input - what clients send."""
question: str
meeting_id: UUID | None
thread_id: str | None
allow_web: bool
top_k: int
class AssistantOutputState(TypedDict):
"""Public API output - what clients receive."""
answer: str
citations: list[dict[str, object]]
suggested_annotations: list[dict[str, object]]
thread_id: str
class AssistantInternalState(AssistantInputState):
"""Internal graph state - can evolve without breaking API."""
retrieved_segment_ids: Annotated[list[int], operator.add]
retrieved_segments: list[dict[str, object]]
draft_answer: str
verification_passed: bool
loop_count: int

View File

@@ -0,0 +1,45 @@
"""AI infrastructure components for LangGraph workflows."""
from noteflow.infrastructure.ai.cache import (
CachedEmbedder,
EmbeddingCache,
EmbeddingCacheStats,
)
from noteflow.infrastructure.ai.checkpointer import create_checkpointer
from noteflow.infrastructure.ai.guardrails import (
GuardrailResult,
GuardrailRules,
GuardrailViolation,
check_input,
create_default_rules,
create_strict_rules,
filter_output,
)
from noteflow.infrastructure.ai.interrupts import (
InterruptHandler,
check_annotation_approval,
check_web_search_approval,
create_resume_command,
request_annotation_approval,
request_web_search_approval,
)
__all__ = [
"CachedEmbedder",
"EmbeddingCache",
"EmbeddingCacheStats",
"GuardrailResult",
"GuardrailRules",
"GuardrailViolation",
"InterruptHandler",
"check_annotation_approval",
"check_input",
"check_web_search_approval",
"create_checkpointer",
"create_default_rules",
"create_resume_command",
"create_strict_rules",
"filter_output",
"request_annotation_approval",
"request_web_search_approval",
]

View File

@@ -0,0 +1,212 @@
"""Embedding cache with LRU eviction and TTL expiration."""
from __future__ import annotations
import asyncio
import hashlib
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Final
from noteflow.infrastructure.logging import get_logger
if TYPE_CHECKING:
from noteflow.infrastructure.ai.tools.retrieval import EmbedderProtocol
logger = get_logger(__name__)
DEFAULT_MAX_SIZE: Final[int] = 1000
DEFAULT_TTL_SECONDS: Final[int] = 3600
HASH_ALGORITHM: Final[str] = "sha256"
@dataclass(frozen=True)
class CacheEntry:
"""Cached embedding with creation timestamp."""
embedding: tuple[float, ...]
created_at: float
def is_expired(self, ttl_seconds: float, current_time: float) -> bool:
"""Check if entry has expired based on TTL."""
return (current_time - self.created_at) > ttl_seconds
@dataclass
class EmbeddingCacheStats:
"""Statistics for cache performance monitoring."""
hits: int = 0
misses: int = 0
evictions: int = 0
expirations: int = 0
@property
def hit_rate(self) -> float:
"""Calculate cache hit rate."""
total = self.hits + self.misses
return self.hits / total if total > 0 else 0.0
@dataclass
class EmbeddingCache:
"""LRU cache for text embeddings with TTL expiration and deduplication."""
max_size: int = DEFAULT_MAX_SIZE
ttl_seconds: int = DEFAULT_TTL_SECONDS
_cache: OrderedDict[str, CacheEntry] = field(default_factory=OrderedDict)
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
_stats: EmbeddingCacheStats = field(default_factory=EmbeddingCacheStats)
_in_flight: dict[str, asyncio.Future[list[float]]] = field(default_factory=dict)
def _compute_key(self, text: str) -> str:
"""Compute cache key from text using hash."""
return hashlib.new(HASH_ALGORITHM, text.encode("utf-8")).hexdigest()
async def get_or_compute(
self,
text: str,
embedder: EmbedderProtocol,
) -> list[float]:
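        """Return a cached embedding, computing it once and deduplicating concurrent requests."""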
key = self._compute_key(text)
current_time = time.monotonic()
existing_future: asyncio.Future[list[float]] | None = None
async with self._lock:
if key in self._cache:
entry = self._cache[key]
if not entry.is_expired(self.ttl_seconds, current_time):
self._cache.move_to_end(key)
self._stats.hits += 1
logger.debug("cache_hit", key=key[:16])
return list(entry.embedding)
del self._cache[key]
self._stats.expirations += 1
logger.debug("cache_expired", key=key[:16])
if key in self._in_flight:
logger.debug("cache_in_flight_join", key=key[:16])
existing_future = self._in_flight[key]
if existing_future is not None:
return list(await existing_future)
new_future: asyncio.Future[list[float]] = asyncio.get_running_loop().create_future()
async with self._lock:
if key in self._in_flight:
existing_future = self._in_flight[key]
else:
self._stats.misses += 1
self._in_flight[key] = new_future
if existing_future is not None:
return list(await existing_future)
try:
embedding = await embedder.embed(text)
except Exception:
async with self._lock:
_ = self._in_flight.pop(key, None)
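        # Waiters joined on this future see CancelledError; unlike the original
        # error, it is not logged as "exception was never retrieved" if unawaited.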
new_future.set_exception(asyncio.CancelledError())
raise
async with self._lock:
_ = self._in_flight.pop(key, None)
while len(self._cache) >= self.max_size:
evicted_key, _ = self._cache.popitem(last=False)
self._stats.evictions += 1
logger.debug("cache_eviction", evicted_key=evicted_key[:16])
self._cache[key] = CacheEntry(
embedding=tuple(embedding),
created_at=current_time,
)
logger.debug("cache_store", key=key[:16])
new_future.set_result(embedding)
return embedding
async def get(self, text: str) -> list[float] | None:
"""Get embedding from cache without computing.
Args:
text: Text to look up.
Returns:
Embedding if cached and not expired, None otherwise.
"""
key = self._compute_key(text)
current_time = time.monotonic()
async with self._lock:
if key in self._cache:
entry = self._cache[key]
if not entry.is_expired(self.ttl_seconds, current_time):
self._cache.move_to_end(key)
return list(entry.embedding)
else:
del self._cache[key]
self._stats.expirations += 1
return None
async def clear(self) -> int:
"""Clear all cached entries.
Returns:
Number of entries cleared.
"""
async with self._lock:
count = len(self._cache)
self._cache.clear()
logger.info("cache_cleared", entries_cleared=count)
return count
async def size(self) -> int:
"""Get current number of cached entries."""
async with self._lock:
return len(self._cache)
def get_stats(self) -> EmbeddingCacheStats:
"""Get cache statistics (not async - reads are atomic)."""
return EmbeddingCacheStats(
hits=self._stats.hits,
misses=self._stats.misses,
evictions=self._stats.evictions,
expirations=self._stats.expirations,
)
class CachedEmbedder:
"""Wrapper that adds caching to any EmbedderProtocol implementation.
Example:
base_embedder = MyEmbedder()
cached = CachedEmbedder(base_embedder, max_size=500, ttl_seconds=1800)
embedding = await cached.embed("hello world")
"""
_embedder: EmbedderProtocol
_cache: EmbeddingCache
def __init__(
self,
embedder: EmbedderProtocol,
max_size: int = DEFAULT_MAX_SIZE,
ttl_seconds: int = DEFAULT_TTL_SECONDS,
) -> None:
self._embedder = embedder
self._cache = EmbeddingCache(max_size=max_size, ttl_seconds=ttl_seconds)
async def embed(self, text: str) -> list[float]:
"""Embed text with caching."""
return await self._cache.get_or_compute(text, self._embedder)
@property
def cache(self) -> EmbeddingCache:
"""Access underlying cache for stats/management."""
return self._cache

View File

@@ -0,0 +1,56 @@
"""PostgreSQL checkpointer factory for LangGraph state persistence."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Final
if TYPE_CHECKING:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool
CHECKPOINTER_POOL_SIZE: Final[int] = 5
@dataclass
class CheckpointerResult:
"""Wraps checkpointer with pool lifecycle management to prevent connection leaks."""
checkpointer: AsyncPostgresSaver
_pool: AsyncConnectionPool
async def close(self) -> None:
# psycopg_pool.AsyncConnectionPool.close() exists at runtime but type stubs
# are incomplete. Use getattr to satisfy the type checker.
close_fn = getattr(self._pool, "close", None)
if close_fn is not None:
await close_fn()
async def __aenter__(self) -> AsyncPostgresSaver:
return self.checkpointer
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object,
) -> None:
await self.close()
async def create_checkpointer(
database_url: str,
pool_size: int = CHECKPOINTER_POOL_SIZE,
) -> CheckpointerResult:
"""Create async Postgres checkpointer for LangGraph state persistence."""
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool
pool = AsyncConnectionPool(
conninfo=database_url,
max_size=pool_size,
kwargs={"autocommit": True},
)
checkpointer = AsyncPostgresSaver(pool)
await checkpointer.setup()
return CheckpointerResult(checkpointer=checkpointer, _pool=pool)

View File

@@ -0,0 +1,51 @@
"""LangGraph workflow definitions."""
from noteflow.infrastructure.ai.graphs.meeting_qa import (
MEETING_QA_GRAPH_NAME,
MEETING_QA_GRAPH_VERSION,
MeetingQAConfig,
MeetingQAInputState,
MeetingQAInternalState,
MeetingQAOutputState,
build_meeting_qa_graph,
)
from noteflow.infrastructure.ai.graphs.summarization import (
SUMMARIZATION_GRAPH_NAME,
SUMMARIZATION_GRAPH_VERSION,
SummarizationInputState,
SummarizationOutputState,
SummarizationState,
build_summarization_graph,
)
from noteflow.infrastructure.ai.graphs.workspace_qa import (
WORKSPACE_QA_GRAPH_NAME,
WORKSPACE_QA_GRAPH_VERSION,
WorkspaceQAConfig,
WorkspaceQAInputState,
WorkspaceQAInternalState,
WorkspaceQAOutputState,
build_workspace_qa_graph,
)
__all__ = [
"MEETING_QA_GRAPH_NAME",
"MEETING_QA_GRAPH_VERSION",
"MeetingQAConfig",
"MeetingQAInputState",
"MeetingQAInternalState",
"MeetingQAOutputState",
"SUMMARIZATION_GRAPH_NAME",
"SUMMARIZATION_GRAPH_VERSION",
"SummarizationInputState",
"SummarizationOutputState",
"SummarizationState",
"WORKSPACE_QA_GRAPH_NAME",
"WORKSPACE_QA_GRAPH_VERSION",
"WorkspaceQAConfig",
"WorkspaceQAInputState",
"WorkspaceQAInternalState",
"WorkspaceQAOutputState",
"build_meeting_qa_graph",
"build_summarization_graph",
"build_workspace_qa_graph",
]

View File

@@ -0,0 +1,193 @@
"""Meeting Q&A graph for single-meeting question answering with citations."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Final, TypedDict
if TYPE_CHECKING:
from langgraph.graph import CompiledStateGraph
from noteflow.domain.ai.citations import SegmentCitation
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.ai.nodes.annotation_suggester import SuggestedAnnotation
from noteflow.infrastructure.ai.nodes.web_search import WebSearchProvider
from noteflow.infrastructure.ai.tools.retrieval import (
EmbedderProtocol,
RetrievalResult,
SegmentSearchProtocol,
)
from noteflow.infrastructure.ai.tools.synthesis import LLMProtocol
MEETING_QA_GRAPH_NAME: Final[str] = "meeting_qa"
MEETING_QA_GRAPH_VERSION: Final[int] = 2
NO_INFORMATION_ANSWER: Final[str] = "I couldn't find relevant information in this meeting."
@dataclass(frozen=True)
class MeetingQAConfig:
enable_web_search: bool = False
require_web_approval: bool = True
require_annotation_approval: bool = False
class MeetingQAInputState(TypedDict):
question: str
meeting_id: MeetingId
top_k: int
class MeetingQAOutputState(TypedDict):
answer: str
citations: list[SegmentCitation]
suggested_annotations: list[SuggestedAnnotation]
class MeetingQAInternalState(MeetingQAInputState, MeetingQAOutputState):
retrieved_segments: list[RetrievalResult]
verification_passed: bool
web_search_approved: bool
web_context: str
annotations_approved: bool
def build_meeting_qa_graph(
embedder: EmbedderProtocol,
segment_repo: SegmentSearchProtocol,
llm: LLMProtocol,
*,
web_search_provider: WebSearchProvider | None = None,
config: MeetingQAConfig | None = None,
checkpointer: object | None = None,
) -> CompiledStateGraph[MeetingQAInternalState]:
"""Build a Q&A graph for single-meeting questions with segment citations.
Graph flow (with web search): retrieve -> verify -> [web_search_approval] -> [web_search] -> synthesize
Graph flow (without): retrieve -> verify -> synthesize
Args:
embedder: Protocol for generating text embeddings.
segment_repo: Protocol for semantic segment search.
llm: Protocol for LLM text completion.
web_search_provider: Optional web search provider for augmentation.
config: Graph configuration for features/interrupts.
checkpointer: Optional checkpointer for interrupt support.
Returns:
Compiled graph that accepts question/meeting_id and returns answer/citations.
"""
from langgraph.graph import END, START, StateGraph
from noteflow.domain.ai.citations import SegmentCitation
from noteflow.infrastructure.ai.interrupts import check_web_search_approval
from noteflow.infrastructure.ai.nodes.annotation_suggester import (
extract_annotations_from_answer,
)
from noteflow.infrastructure.ai.nodes.web_search import (
WebSearchConfig,
derive_search_query,
execute_web_search,
format_results_for_context,
)
from noteflow.infrastructure.ai.tools.retrieval import retrieve_segments
from noteflow.infrastructure.ai.tools.synthesis import synthesize_answer
effective_config = config or MeetingQAConfig()
async def retrieve_node(state: MeetingQAInternalState) -> dict[str, object]:
results = await retrieve_segments(
query=state["question"],
embedder=embedder,
segment_repo=segment_repo,
meeting_id=state["meeting_id"],
top_k=state["top_k"],
)
return {"retrieved_segments": results}
async def verify_node(state: MeetingQAInternalState) -> dict[str, object]:
has_segments = len(state["retrieved_segments"]) > 0
return {"verification_passed": has_segments}
def web_search_approval_node(state: MeetingQAInternalState) -> dict[str, object]:
if not effective_config.enable_web_search or web_search_provider is None:
return {"web_search_approved": False}
if not effective_config.require_web_approval:
return {"web_search_approved": True}
query = derive_search_query(state["question"])
approved = check_web_search_approval(query, require_approval=True)
return {"web_search_approved": approved}
async def web_search_node(state: MeetingQAInternalState) -> dict[str, object]:
if not state.get("web_search_approved", False) or web_search_provider is None:
return {"web_context": ""}
query = derive_search_query(state["question"])
search_config = WebSearchConfig(enabled=True, require_approval=False)
response = await execute_web_search(query, web_search_provider, search_config)
context = format_results_for_context(response.results)
return {"web_context": context}
async def synthesize_node(state: MeetingQAInternalState) -> dict[str, object]:
if not state["verification_passed"]:
return {
"answer": NO_INFORMATION_ANSWER,
"citations": [],
"suggested_annotations": [],
}
result = await synthesize_answer(
question=state["question"],
segments=state["retrieved_segments"],
llm=llm,
)
citations = [
SegmentCitation(
meeting_id=state["meeting_id"],
segment_id=seg.segment_id,
start_time=seg.start_time,
end_time=seg.end_time,
text=seg.text,
score=seg.score,
)
for seg in state["retrieved_segments"]
if seg.segment_id in result.cited_segment_ids
]
suggested_annotations = extract_annotations_from_answer(
answer=result.answer,
cited_segment_ids=tuple(result.cited_segment_ids),
)
return {
"answer": result.answer,
"citations": citations,
"suggested_annotations": suggested_annotations,
}
builder: StateGraph[MeetingQAInternalState] = StateGraph(MeetingQAInternalState)
builder.add_node("retrieve", retrieve_node)
builder.add_node("verify", verify_node)
builder.add_node("synthesize", synthesize_node)
if effective_config.enable_web_search and web_search_provider is not None:
builder.add_node("web_search_approval", web_search_approval_node)
builder.add_node("web_search", web_search_node)
builder.add_edge(START, "retrieve")
builder.add_edge("retrieve", "verify")
builder.add_edge("verify", "web_search_approval")
builder.add_edge("web_search_approval", "web_search")
builder.add_edge("web_search", "synthesize")
builder.add_edge("synthesize", END)
else:
builder.add_edge(START, "retrieve")
builder.add_edge("retrieve", "verify")
builder.add_edge("verify", "synthesize")
builder.add_edge("synthesize", END)
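    # Fetch compile() via getattr to sidestep strict typing on the checkpointer
    # argument (deliberately passed as `object` here).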
compile_method = getattr(builder, "compile")
compiled: CompiledStateGraph[MeetingQAInternalState] = compile_method(checkpointer=checkpointer)
return compiled

View File

@@ -0,0 +1,103 @@
"""LangGraph wrapper for existing SummarizationService.
Demonstrates the LangGraph integration pattern by wrapping the existing
summarization infrastructure in a StateGraph.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Final, TypedDict, cast
if TYPE_CHECKING:
from collections.abc import Sequence
from uuid import UUID
from langgraph.graph import CompiledStateGraph
from noteflow.application.services.summarization import SummarizationService
from noteflow.domain.entities import Segment
from noteflow.domain.value_objects import MeetingId
SUMMARIZATION_GRAPH_NAME: Final[str] = "summarization"
SUMMARIZATION_GRAPH_VERSION: Final[int] = 1
class SummarizationInputState(TypedDict):
"""Input state for summarization graph."""
meeting_id: UUID
segments: Sequence[Segment]
class SummarizationOutputState(TypedDict):
"""Output state from summarization graph."""
summary_text: str
key_points: list[dict[str, object]]
action_items: list[dict[str, object]]
provider_used: str
tokens_used: int | None
latency_ms: float | None
class SummarizationState(SummarizationInputState, SummarizationOutputState):
"""Full internal state for summarization graph."""
def build_summarization_graph(
summarization_service: SummarizationService,
) -> CompiledStateGraph[SummarizationState]:
"""Build LangGraph wrapper around existing SummarizationService.
This demonstrates the pattern of wrapping existing services in LangGraph
graphs, enabling future expansion (checkpointing, conditional branching,
human-in-the-loop, etc.) while maintaining backward compatibility.
Args:
summarization_service: The existing summarization service to wrap.
Returns:
A compiled StateGraph that can be invoked with meeting_id and segments.
"""
from langgraph.graph import END, START, StateGraph
async def summarize_node(state: SummarizationState) -> dict[str, object]:
# MeetingId is a NewType of UUID, so the cast below is safe
meeting_id = cast("MeetingId", state["meeting_id"])
result = await summarization_service.summarize(
meeting_id=meeting_id,
segments=state["segments"],
)
summary = result.summary
return {
"summary_text": summary.executive_summary,
"key_points": [
{
"text": kp.text,
"segment_ids": kp.segment_ids,
"start_time": kp.start_time,
"end_time": kp.end_time,
}
for kp in summary.key_points
],
"action_items": [
{
"text": ai.text,
"segment_ids": ai.segment_ids,
"assignee": ai.assignee,
"start_time": ai.start_time,
"end_time": ai.end_time,
}
for ai in summary.action_items
],
"provider_used": result.provider_used,
"tokens_used": summary.tokens_used,
"latency_ms": summary.latency_ms,
}
builder = StateGraph(SummarizationState)
builder.add_node("summarize", summarize_node)
builder.add_edge(START, "summarize")
builder.add_edge("summarize", END)
return builder.compile()
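
A hedged usage sketch for the wrapper above: `service` is assumed to be a wired `SummarizationService` and `segments` a sequence of persisted `Segment` entities; neither is defined here, and the call runs inside an async context.

```python
from uuid import uuid4

from noteflow.infrastructure.ai.graphs.summarization import build_summarization_graph

graph = build_summarization_graph(summarization_service=service)

# ainvoke() runs the single summarize node and returns the merged state,
# so the output keys sit alongside the input keys.
result = await graph.ainvoke({"meeting_id": uuid4(), "segments": segments})
print(result["summary_text"], result["provider_used"])
```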

View File

@@ -0,0 +1,198 @@
"""Workspace Q&A graph for cross-meeting question answering with citations."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Final, TypedDict
if TYPE_CHECKING:
from uuid import UUID
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import CompiledStateGraph
from noteflow.domain.ai.citations import SegmentCitation
from noteflow.infrastructure.ai.nodes.annotation_suggester import SuggestedAnnotation
from noteflow.infrastructure.ai.nodes.web_search import WebSearchProvider
from noteflow.infrastructure.ai.tools.retrieval import (
EmbedderProtocol,
RetrievalResult,
WorkspaceSegmentSearchProtocol,
)
from noteflow.infrastructure.ai.tools.synthesis import LLMProtocol
WORKSPACE_QA_GRAPH_NAME: Final[str] = "workspace_qa"
WORKSPACE_QA_GRAPH_VERSION: Final[int] = 2
NO_INFORMATION_ANSWER: Final[str] = "I couldn't find relevant information across your meetings."
@dataclass(frozen=True)
class WorkspaceQAConfig:
enable_web_search: bool = False
require_web_approval: bool = True
require_annotation_approval: bool = False
class WorkspaceQAInputState(TypedDict):
question: str
workspace_id: UUID
project_id: UUID | None
top_k: int
class WorkspaceQAOutputState(TypedDict):
answer: str
citations: list[SegmentCitation]
suggested_annotations: list[SuggestedAnnotation]
class WorkspaceQAInternalState(WorkspaceQAInputState, WorkspaceQAOutputState):
retrieved_segments: list[RetrievalResult]
verification_passed: bool
web_search_approved: bool
web_context: str
def build_workspace_qa_graph(
embedder: EmbedderProtocol,
segment_repo: WorkspaceSegmentSearchProtocol,
llm: LLMProtocol,
*,
web_search_provider: WebSearchProvider | None = None,
config: WorkspaceQAConfig | None = None,
checkpointer: BaseCheckpointSaver[str] | None = None,
) -> CompiledStateGraph[WorkspaceQAInternalState]:
"""Build Q&A graph for cross-meeting questions with segment citations.
Graph flow (with web search): retrieve -> verify -> [web_search_approval] -> [web_search] -> synthesize
Graph flow (without): retrieve -> verify -> synthesize
Args:
embedder: Protocol for generating text embeddings.
segment_repo: Protocol for workspace-scoped semantic segment search.
llm: Protocol for LLM text completion.
web_search_provider: Optional web search provider for augmentation.
config: Graph configuration for features/interrupts.
checkpointer: Optional checkpointer for interrupt support.
Returns:
Compiled graph that accepts question/workspace_id and returns answer/citations.
"""
from langgraph.graph import END, START, StateGraph
from noteflow.domain.ai.citations import SegmentCitation
from noteflow.infrastructure.ai.interrupts import check_web_search_approval
from noteflow.infrastructure.ai.nodes.annotation_suggester import (
extract_annotations_from_answer,
)
from noteflow.infrastructure.ai.nodes.web_search import (
WebSearchConfig,
derive_search_query,
execute_web_search,
format_results_for_context,
)
from noteflow.infrastructure.ai.tools.retrieval import retrieve_segments_workspace
from noteflow.infrastructure.ai.tools.synthesis import synthesize_answer
effective_config = config or WorkspaceQAConfig()
async def retrieve_node(state: WorkspaceQAInternalState) -> dict[str, object]:
results = await retrieve_segments_workspace(
query=state["question"],
embedder=embedder,
segment_repo=segment_repo,
workspace_id=state["workspace_id"],
project_id=state["project_id"],
top_k=state["top_k"],
)
return {"retrieved_segments": results}
async def verify_node(state: WorkspaceQAInternalState) -> dict[str, object]:
has_segments = len(state["retrieved_segments"]) > 0
return {"verification_passed": has_segments}
def web_search_approval_node(state: WorkspaceQAInternalState) -> dict[str, object]:
if not effective_config.enable_web_search or web_search_provider is None:
return {"web_search_approved": False}
if not effective_config.require_web_approval:
return {"web_search_approved": True}
query = derive_search_query(state["question"])
approved = check_web_search_approval(query, require_approval=True)
return {"web_search_approved": approved}
async def web_search_node(state: WorkspaceQAInternalState) -> dict[str, object]:
if not state.get("web_search_approved", False) or web_search_provider is None:
return {"web_context": ""}
query = derive_search_query(state["question"])
search_config = WebSearchConfig(enabled=True, require_approval=False)
response = await execute_web_search(query, web_search_provider, search_config)
context = format_results_for_context(response.results)
return {"web_context": context}
async def synthesize_node(state: WorkspaceQAInternalState) -> dict[str, object]:
if not state["verification_passed"]:
return {
"answer": NO_INFORMATION_ANSWER,
"citations": [],
"suggested_annotations": [],
}
result = await synthesize_answer(
question=state["question"],
segments=state["retrieved_segments"],
llm=llm,
)
citations = [
SegmentCitation(
meeting_id=seg.meeting_id,
segment_id=seg.segment_id,
start_time=seg.start_time,
end_time=seg.end_time,
text=seg.text,
score=seg.score,
)
for seg in state["retrieved_segments"]
if seg.segment_id in result.cited_segment_ids
]
suggested_annotations = extract_annotations_from_answer(
answer=result.answer,
cited_segment_ids=tuple(result.cited_segment_ids),
)
return {
"answer": result.answer,
"citations": citations,
"suggested_annotations": suggested_annotations,
}
builder: StateGraph[WorkspaceQAInternalState] = StateGraph(WorkspaceQAInternalState)
builder.add_node("retrieve", retrieve_node)
builder.add_node("verify", verify_node)
builder.add_node("synthesize", synthesize_node)
if effective_config.enable_web_search and web_search_provider is not None:
builder.add_node("web_search_approval", web_search_approval_node)
builder.add_node("web_search", web_search_node)
builder.add_edge(START, "retrieve")
builder.add_edge("retrieve", "verify")
builder.add_edge("verify", "web_search_approval")
builder.add_edge("web_search_approval", "web_search")
builder.add_edge("web_search", "synthesize")
builder.add_edge("synthesize", END)
else:
builder.add_edge(START, "retrieve")
builder.add_edge("retrieve", "verify")
builder.add_edge("verify", "synthesize")
builder.add_edge("synthesize", END)
compiled: CompiledStateGraph[WorkspaceQAInternalState] = builder.compile(checkpointer=checkpointer)
return compiled
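
A hedged wiring sketch for the graph above: `embedder`, `segment_repo`, `llm`, and `workspace_id` are assumed to exist as concrete adapters and values, and the module path is assumed to mirror the summarization graph's location.

```python
from noteflow.infrastructure.ai.graphs.workspace_qa import (
    WorkspaceQAConfig,
    build_workspace_qa_graph,
)

graph = build_workspace_qa_graph(
    embedder=embedder,
    segment_repo=segment_repo,
    llm=llm,
    config=WorkspaceQAConfig(enable_web_search=False),
)

# Run inside an async context; top_k is a required input-state key.
result = await graph.ainvoke(
    {
        "question": "What did we decide about the Q3 launch?",
        "workspace_id": workspace_id,
        "project_id": None,
        "top_k": 20,
    }
)
for citation in result["citations"]:
    print(citation.meeting_id, citation.segment_id, citation.score)
```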

View File

@@ -0,0 +1,312 @@
"""Content guardrails for AI input validation and output filtering."""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Final
from noteflow.infrastructure.logging import get_logger
logger = get_logger(__name__)
# Input validation limits
DEFAULT_MIN_INPUT_LENGTH: Final[int] = 3
DEFAULT_MAX_INPUT_LENGTH: Final[int] = 4000
DEFAULT_MAX_OUTPUT_LENGTH: Final[int] = 10000
# PII patterns (simplified - production would use more comprehensive detection)
EMAIL_PATTERN: Final[re.Pattern[str]] = re.compile(
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
)
PHONE_PATTERN: Final[re.Pattern[str]] = re.compile(
r"\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b"
)
SSN_PATTERN: Final[re.Pattern[str]] = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")
CREDIT_CARD_PATTERN: Final[re.Pattern[str]] = re.compile(r"\b(?:\d{4}[-\s]?){3}\d{4}\b")
PII_PATTERNS: Final[tuple[tuple[str, re.Pattern[str]], ...]] = (
("email", EMAIL_PATTERN),
("phone", PHONE_PATTERN),
("ssn", SSN_PATTERN),
("credit_card", CREDIT_CARD_PATTERN),
)
# Redaction placeholder
PII_REDACTION: Final[str] = "[REDACTED]"
class GuardrailViolation(str, Enum):
"""Types of guardrail violations."""
INPUT_TOO_SHORT = "input_too_short"
INPUT_TOO_LONG = "input_too_long"
OUTPUT_TOO_LONG = "output_too_long"
CONTAINS_PII = "contains_pii"
BLOCKED_CONTENT = "blocked_content"
INJECTION_ATTEMPT = "injection_attempt"
@dataclass(frozen=True)
class GuardrailResult:
"""Result of a guardrail check."""
allowed: bool
violation: GuardrailViolation | None = None
reason: str | None = None
filtered_content: str | None = None
@staticmethod
def ok(content: str | None = None) -> GuardrailResult:
"""Create a passing result."""
return GuardrailResult(allowed=True, filtered_content=content)
@staticmethod
def blocked(
violation: GuardrailViolation,
reason: str,
) -> GuardrailResult:
"""Create a blocking result."""
return GuardrailResult(allowed=False, violation=violation, reason=reason)
@staticmethod
def filtered(
content: str,
violation: GuardrailViolation,
reason: str,
) -> GuardrailResult:
"""Create a result with filtered content."""
return GuardrailResult(
allowed=True,
violation=violation,
reason=reason,
filtered_content=content,
)
@dataclass
class GuardrailRules:
"""Configurable guardrail rules.
Attributes:
min_input_length: Minimum allowed input length.
max_input_length: Maximum allowed input length.
max_output_length: Maximum allowed output length.
block_pii: Whether to block content containing PII.
redact_pii: Whether to redact PII instead of blocking.
blocked_phrases: Phrases that should block content entirely.
detect_injection: Whether to detect prompt injection attempts.
"""
min_input_length: int = DEFAULT_MIN_INPUT_LENGTH
max_input_length: int = DEFAULT_MAX_INPUT_LENGTH
max_output_length: int = DEFAULT_MAX_OUTPUT_LENGTH
block_pii: bool = False
redact_pii: bool = True
blocked_phrases: frozenset[str] = field(default_factory=frozenset)
detect_injection: bool = True
# Common injection patterns
INJECTION_PATTERNS: Final[tuple[re.Pattern[str], ...]] = (
re.compile(r"ignore\s+(?:all\s+)?(?:previous|above)\s+instructions", re.IGNORECASE),
re.compile(r"disregard\s+(?:all\s+)?(?:previous|prior)\s+", re.IGNORECASE),
re.compile(r"you\s+are\s+now\s+(?:a|an)\s+", re.IGNORECASE),
re.compile(r"forget\s+(?:everything|all)\s+(?:you|and)\s+", re.IGNORECASE),
re.compile(r"new\s+(?:system\s+)?instructions?:", re.IGNORECASE),
)
def _check_length(
text: str,
min_length: int,
max_length: int,
is_input: bool,
) -> GuardrailResult | None:
"""Check text length constraints."""
if is_input and len(text) < min_length:
return GuardrailResult.blocked(
GuardrailViolation.INPUT_TOO_SHORT,
f"Input must be at least {min_length} characters",
)
if is_input and len(text) > max_length:
return GuardrailResult.blocked(
GuardrailViolation.INPUT_TOO_LONG,
f"Input must be at most {max_length} characters",
)
if not is_input and len(text) > max_length:
return GuardrailResult.blocked(
GuardrailViolation.OUTPUT_TOO_LONG,
f"Output exceeds {max_length} characters",
)
return None
def _check_blocked_phrases(
text: str,
blocked_phrases: frozenset[str],
) -> GuardrailResult | None:
"""Check for blocked phrases."""
text_lower = text.lower()
for phrase in blocked_phrases:
if phrase.lower() in text_lower:
logger.warning("blocked_phrase_detected", phrase=phrase[:20])
return GuardrailResult.blocked(
GuardrailViolation.BLOCKED_CONTENT,
"Content contains blocked phrase",
)
return None
def _check_injection(text: str) -> GuardrailResult | None:
"""Check for prompt injection attempts."""
for pattern in INJECTION_PATTERNS:
if pattern.search(text):
logger.warning("injection_attempt_detected")
return GuardrailResult.blocked(
GuardrailViolation.INJECTION_ATTEMPT,
"Potential prompt injection detected",
)
return None
def _detect_pii(text: str) -> list[tuple[str, str]]:
"""Detect PII in text.
Returns:
List of (pii_type, matched_text) tuples.
"""
findings: list[tuple[str, str]] = []
for pii_type, pattern in PII_PATTERNS:
for match in pattern.finditer(text):
findings.append((pii_type, match.group()))
return findings
def _redact_pii(text: str) -> str:
"""Redact all PII in text."""
result = text
for _, pattern in PII_PATTERNS:
result = pattern.sub(PII_REDACTION, result)
return result
async def check_input(text: str, rules: GuardrailRules) -> GuardrailResult:
"""Validate input text against guardrail rules.
Args:
text: Input text to validate.
rules: Guardrail rules to apply.
Returns:
GuardrailResult indicating if input is allowed.
"""
# Length checks
length_result = _check_length(
text,
rules.min_input_length,
rules.max_input_length,
is_input=True,
)
if length_result is not None:
return length_result
# Blocked phrases
phrase_result = _check_blocked_phrases(text, rules.blocked_phrases)
if phrase_result is not None:
return phrase_result
# Injection detection
if rules.detect_injection:
injection_result = _check_injection(text)
if injection_result is not None:
return injection_result
# PII checks
if rules.block_pii or rules.redact_pii:
pii_findings = _detect_pii(text)
if pii_findings:
pii_types = [f[0] for f in pii_findings]
logger.info("pii_detected_in_input", pii_types=pii_types)
if rules.block_pii:
return GuardrailResult.blocked(
GuardrailViolation.CONTAINS_PII,
f"Input contains PII: {', '.join(pii_types)}",
)
# Redact instead of block
redacted = _redact_pii(text)
return GuardrailResult.filtered(
redacted,
GuardrailViolation.CONTAINS_PII,
f"PII redacted: {', '.join(pii_types)}",
)
return GuardrailResult.ok(text)
async def filter_output(text: str, rules: GuardrailRules) -> GuardrailResult:
"""Filter output text, redacting sensitive content.
Args:
text: Output text to filter.
rules: Guardrail rules to apply.
Returns:
GuardrailResult with potentially filtered content.
"""
# Length check
length_result = _check_length(
text,
min_length=0, # No minimum for output
max_length=rules.max_output_length,
is_input=False,
)
if length_result is not None:
# Truncate instead of blocking for output
truncated = text[: rules.max_output_length]
return GuardrailResult.filtered(
truncated,
GuardrailViolation.OUTPUT_TOO_LONG,
f"Output truncated to {rules.max_output_length} characters",
)
# Blocked phrases in output
phrase_result = _check_blocked_phrases(text, rules.blocked_phrases)
if phrase_result is not None:
return phrase_result
# PII redaction in output (always redact, never block output)
if rules.redact_pii:
pii_findings = _detect_pii(text)
if pii_findings:
pii_types = [f[0] for f in pii_findings]
logger.info("pii_detected_in_output", pii_types=pii_types)
redacted = _redact_pii(text)
return GuardrailResult.filtered(
redacted,
GuardrailViolation.CONTAINS_PII,
f"PII redacted: {', '.join(pii_types)}",
)
return GuardrailResult.ok(text)
def create_default_rules() -> GuardrailRules:
"""Create default guardrail rules."""
return GuardrailRules()
def create_strict_rules() -> GuardrailRules:
"""Create strict guardrail rules with PII blocking."""
return GuardrailRules(
block_pii=True,
redact_pii=False,
detect_injection=True,
max_input_length=2000,
)
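
A minimal end-to-end sketch of the guardrail flow, assuming this module is importable as `noteflow.infrastructure.ai.guardrails` (the import path is not shown in this hunk). Default rules redact PII rather than block, so the email survives in masked form.

```python
import asyncio

from noteflow.infrastructure.ai.guardrails import (
    check_input,
    create_default_rules,
    filter_output,
)

async def main() -> None:
    rules = create_default_rules()
    checked = await check_input("Email jane@example.com about next steps", rules)
    if not checked.allowed:
        raise ValueError(checked.reason)
    # filtered_content carries the redacted text when PII was found.
    print(checked.filtered_content)  # Email [REDACTED] about next steps
    answer = await filter_output("Reply sent to jane@example.com", rules)
    print(answer.filtered_content)  # Reply sent to [REDACTED]

asyncio.run(main())
```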

View File

@@ -0,0 +1,231 @@
"""Infrastructure utilities for LangGraph human-in-the-loop interrupts.
Wraps LangGraph's interrupt() and Command APIs for consistent usage across graphs.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Final
from uuid import uuid4
from langgraph.types import Command, interrupt
from noteflow.domain.ai.interrupts import (
InterruptAction,
InterruptResponse,
create_annotation_interrupt,
create_web_search_interrupt,
)
from noteflow.infrastructure.logging import get_logger
if TYPE_CHECKING:
from noteflow.infrastructure.ai.nodes.annotation_suggester import SuggestedAnnotation
logger = get_logger(__name__)
INTERRUPT_RESPONSE_KEY: Final[str] = "response"
INTERRUPT_APPROVED_VALUE: Final[str] = "approved"
def request_web_search_approval(
query: str,
*,
allow_modify: bool = False,
) -> InterruptResponse:
"""Request user approval for web search via LangGraph interrupt.
Args:
query: The search query to be executed.
allow_modify: Whether the user can modify the query.
Returns:
InterruptResponse with user's decision.
"""
request_id = str(uuid4())
interrupt_request = create_web_search_interrupt(
query=query,
request_id=request_id,
allow_modify=allow_modify,
)
logger.info(
"interrupt_web_search_requested",
request_id=request_id,
query_preview=query[:50],
)
response_data = interrupt(interrupt_request.to_dict())
return _parse_interrupt_response(response_data, request_id)
def request_annotation_approval(
annotations: list[SuggestedAnnotation],
) -> InterruptResponse:
"""Request user approval for suggested annotations via LangGraph interrupt.
Args:
annotations: List of suggested annotations to approve.
Returns:
InterruptResponse with user's decision.
"""
request_id = str(uuid4())
annotation_dicts = [ann.to_dict() for ann in annotations]
interrupt_request = create_annotation_interrupt(
annotations=annotation_dicts,
request_id=request_id,
)
logger.info(
"interrupt_annotation_requested",
request_id=request_id,
annotation_count=len(annotations),
)
response_data = interrupt(interrupt_request.to_dict())
return _parse_interrupt_response(response_data, request_id)
def _parse_interrupt_response(
response_data: object,
request_id: str,
) -> InterruptResponse:
"""Parse LangGraph interrupt response into domain type.
Args:
response_data: Raw response from LangGraph interrupt.
request_id: ID of the original request.
Returns:
Parsed InterruptResponse.
"""
if isinstance(response_data, str):
action = _string_to_action(response_data)
return InterruptResponse(action=action, request_id=request_id)
if isinstance(response_data, dict):
action_str = str(response_data.get("action", "reject"))
action = _string_to_action(action_str)
modified_value = response_data.get("modified_value")
if modified_value is not None and not isinstance(modified_value, dict):
modified_value = None
user_message = response_data.get("user_message")
if user_message is not None:
user_message = str(user_message)
return InterruptResponse(
action=action,
request_id=request_id,
modified_value=modified_value,
user_message=user_message,
)
logger.warning(
"interrupt_response_unknown_format",
request_id=request_id,
response_type=type(response_data).__name__,
)
return InterruptResponse(action=InterruptAction.REJECT, request_id=request_id)
def _string_to_action(value: str) -> InterruptAction:
"""Convert string response to InterruptAction."""
normalized = value.lower().strip()
if normalized in ("approve", "yes", "approved", "accept"):
return InterruptAction.APPROVE
if normalized in ("modify", "edit", "change"):
return InterruptAction.MODIFY
return InterruptAction.REJECT
def create_resume_command(response: InterruptResponse) -> Command[None]:
"""Create a LangGraph Command to resume execution with user response.
Args:
response: User's interrupt response.
Returns:
Command to resume graph execution.
"""
return Command(resume=response.to_dict())
class InterruptHandler:
"""Handles interrupt requests and responses for a graph execution."""
_require_web_approval: bool
def __init__(self, require_web_approval: bool = True) -> None:
self._require_web_approval = require_web_approval
def should_interrupt_for_web_search(self) -> bool:
return self._require_web_approval
def request_web_search(self, query: str) -> InterruptResponse:
return request_web_search_approval(query)
def request_annotation_approval(
self,
annotations: list[SuggestedAnnotation],
) -> InterruptResponse:
return request_annotation_approval(annotations)
def check_web_search_approval(
query: str,
require_approval: bool,
) -> bool:
"""Check if web search should proceed (with optional interrupt).
Args:
query: Search query to execute.
require_approval: Whether to interrupt for user approval.
Returns:
True if search should proceed, False if rejected.
"""
if not require_approval:
return True
response = request_web_search_approval(query)
return response.is_approved
def check_annotation_approval(
annotations: list[SuggestedAnnotation],
) -> tuple[bool, list[SuggestedAnnotation]]:
"""Check if annotations should be applied (with interrupt).
Args:
annotations: Suggested annotations to approve.
Returns:
Tuple of (should_apply, possibly_modified_annotations).
"""
if not annotations:
return False, []
response = request_annotation_approval(annotations)
if response.is_rejected:
return False, []
if response.is_modified and response.modified_value:
modified_list_raw = response.modified_value.get("annotations", [])
if isinstance(modified_list_raw, list):
from noteflow.infrastructure.ai.nodes.annotation_suggester import (
SuggestedAnnotation,
)
modified_annotations: list[SuggestedAnnotation] = []
for item in modified_list_raw:
if isinstance(item, dict):
item_dict: dict[str, object] = {str(k): v for k, v in item.items()}
modified_annotations.append(SuggestedAnnotation.from_dict(item_dict))
return True, modified_annotations
return response.is_approved, annotations
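
A hedged sketch of the resume path: after a graph pauses on `request_web_search_approval`, a client collects the user's decision, wraps it in the domain `InterruptResponse`, and feeds the resulting `Command` back into the paused run. The `request_id` value and the graph invocation shown in the comment are assumptions.

```python
from noteflow.domain.ai.interrupts import InterruptAction, InterruptResponse
from noteflow.infrastructure.ai.interrupts import create_resume_command

# Decision gathered from the UI for the interrupt with this request_id.
decision = InterruptResponse(action=InterruptAction.APPROVE, request_id="req-123")
resume = create_resume_command(decision)

# The command is then passed back into the checkpointed run, e.g.:
# await graph.ainvoke(resume, config={"configurable": {"thread_id": thread_id}})
```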

View File

@@ -0,0 +1,39 @@
"""LangGraph node implementations."""
from noteflow.infrastructure.ai.nodes.annotation_suggester import (
SuggestedAnnotation,
SuggestedAnnotationType,
extract_annotations_from_answer,
)
from noteflow.infrastructure.ai.nodes.verification import (
VerificationResult,
verify_citations,
)
from noteflow.infrastructure.ai.nodes.web_search import (
DisabledWebSearchProvider,
WebSearchConfig,
WebSearchProvider,
WebSearchResponse,
WebSearchResult,
derive_search_query,
execute_web_search,
format_results_for_context,
merge_contexts,
)
__all__ = [
"DisabledWebSearchProvider",
"SuggestedAnnotation",
"SuggestedAnnotationType",
"VerificationResult",
"WebSearchConfig",
"WebSearchProvider",
"WebSearchResponse",
"WebSearchResult",
"derive_search_query",
"execute_web_search",
"extract_annotations_from_answer",
"format_results_for_context",
"merge_contexts",
"verify_citations",
]

View File

@@ -0,0 +1,121 @@
"""Annotation suggester for extracting action items and decisions from answers."""
from __future__ import annotations
import re
from dataclasses import dataclass
from enum import Enum
from typing import Final
class SuggestedAnnotationType(str, Enum):
ACTION_ITEM = "action_item"
DECISION = "decision"
NOTE = "note"
@dataclass(frozen=True)
class SuggestedAnnotation:
text: str
annotation_type: SuggestedAnnotationType
segment_ids: tuple[int, ...]
confidence: float = 0.8
def to_dict(self) -> dict[str, object]:
return {
"text": self.text,
"type": self.annotation_type.value,
"segment_ids": list(self.segment_ids),
"confidence": self.confidence,
}
@classmethod
def from_dict(cls, data: dict[str, object]) -> SuggestedAnnotation:
text = str(data.get("text", ""))
type_str = str(data.get("type", "note"))
segment_ids_raw = data.get("segment_ids", [])
if isinstance(segment_ids_raw, list):
segment_ids = tuple(
int(sid) for sid in segment_ids_raw if isinstance(sid, (int, float))
)
else:
segment_ids = ()
confidence_raw = data.get("confidence", 0.8)
confidence = float(confidence_raw) if isinstance(confidence_raw, (int, float)) else 0.8
try:
annotation_type = SuggestedAnnotationType(type_str)
except ValueError:
annotation_type = SuggestedAnnotationType.NOTE
return cls(
text=text,
annotation_type=annotation_type,
segment_ids=segment_ids,
confidence=confidence,
)
ACTION_ITEM_PATTERNS: Final[tuple[re.Pattern[str], ...]] = (
re.compile(r"(?:need to|should|must|will|going to|has to)\s+(.+?)(?:\.|$)", re.IGNORECASE),
re.compile(r"(?:action item|TODO|task):\s*(.+?)(?:\.|$)", re.IGNORECASE),
re.compile(r"(?:follow[- ]up|next step):\s*(.+?)(?:\.|$)", re.IGNORECASE),
)
DECISION_PATTERNS: Final[tuple[re.Pattern[str], ...]] = (
re.compile(r"(?:decided to|agreed to|will go with|chose to)\s+(.+?)(?:\.|$)", re.IGNORECASE),
re.compile(r"(?:decision|resolution):\s*(.+?)(?:\.|$)", re.IGNORECASE),
re.compile(r"(?:the team|we|they) (?:decided|agreed|chose)\s+(.+?)(?:\.|$)", re.IGNORECASE),
)
MIN_TEXT_LENGTH: Final[int] = 10
MAX_TEXT_LENGTH: Final[int] = 200
def extract_annotations_from_answer(
answer: str,
cited_segment_ids: tuple[int, ...],
) -> list[SuggestedAnnotation]:
"""Extract action items and decisions from synthesized answer."""
suggestions: list[SuggestedAnnotation] = []
for pattern in ACTION_ITEM_PATTERNS:
for match in pattern.finditer(answer):
text = match.group(1).strip()
if MIN_TEXT_LENGTH <= len(text) <= MAX_TEXT_LENGTH:
suggestions.append(
SuggestedAnnotation(
text=text,
annotation_type=SuggestedAnnotationType.ACTION_ITEM,
segment_ids=cited_segment_ids,
confidence=0.7,
)
)
for pattern in DECISION_PATTERNS:
for match in pattern.finditer(answer):
text = match.group(1).strip()
if MIN_TEXT_LENGTH <= len(text) <= MAX_TEXT_LENGTH:
suggestions.append(
SuggestedAnnotation(
text=text,
annotation_type=SuggestedAnnotationType.DECISION,
segment_ids=cited_segment_ids,
confidence=0.75,
)
)
return _dedupe_suggestions(suggestions)
def _dedupe_suggestions(suggestions: list[SuggestedAnnotation]) -> list[SuggestedAnnotation]:
seen_texts: set[str] = set()
deduped: list[SuggestedAnnotation] = []
for suggestion in suggestions:
normalized = suggestion.text.lower().strip()
if normalized not in seen_texts:
seen_texts.add(normalized)
deduped.append(suggestion)
return deduped
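
A quick illustration of the extractor above; the answer text and segment id are made up.

```python
from noteflow.infrastructure.ai.nodes.annotation_suggester import (
    extract_annotations_from_answer,
)

answer = (
    "The team decided to ship the beta next sprint [4]. "
    "Action item: update the rollout checklist before Friday."
)
for suggestion in extract_annotations_from_answer(answer, cited_segment_ids=(4,)):
    print(suggestion.annotation_type.value, "-", suggestion.text)
# action_item - update the rollout checklist before Friday
# decision - ship the beta next sprint [4]
```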

View File

@@ -0,0 +1,61 @@
"""Citation verification for AI-generated answers."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Final
@dataclass(frozen=True)
class VerificationResult:
"""Result of citation verification.
Attributes:
is_valid: True if all cited segment IDs exist in available segments.
invalid_citation_indices: Indices of citations that failed validation.
reason: Human-readable explanation if validation failed.
"""
is_valid: bool
invalid_citation_indices: tuple[int, ...]
reason: str | None = None
NO_SEGMENTS_REASON: Final[str] = "No segments retrieved for question"
INVALID_CITATIONS_PREFIX: Final[str] = "Invalid citation indices: "
def verify_citations(
cited_ids: list[int],
available_ids: set[int],
) -> VerificationResult:
"""Verify all cited segment IDs exist in available segments.
Args:
cited_ids: List of segment IDs cited in the answer.
available_ids: Set of valid segment IDs from retrieval.
Returns:
VerificationResult with validation status and any invalid indices.
"""
if not available_ids:
return VerificationResult(
is_valid=False,
invalid_citation_indices=(),
reason=NO_SEGMENTS_REASON,
)
invalid_indices = tuple(i for i, cid in enumerate(cited_ids) if cid not in available_ids)
if invalid_indices:
return VerificationResult(
is_valid=False,
invalid_citation_indices=invalid_indices,
reason=f"{INVALID_CITATIONS_PREFIX}{invalid_indices}",
)
return VerificationResult(
is_valid=True,
invalid_citation_indices=(),
reason=None,
)
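
A small worked example of the checker: `99` is not among the retrieved segment ids, so its index in the citation list is flagged.

```python
from noteflow.infrastructure.ai.nodes.verification import verify_citations

result = verify_citations(cited_ids=[1, 3, 99], available_ids={1, 2, 3})
assert not result.is_valid
assert result.invalid_citation_indices == (2,)
print(result.reason)  # Invalid citation indices: (2,)
```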

View File

@@ -0,0 +1,226 @@
"""Web search node for augmenting RAG with external knowledge."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Final, Protocol
from noteflow.infrastructure.logging import get_logger
logger = get_logger(__name__)
DEFAULT_MAX_RESULTS: Final[int] = 5
DEFAULT_TIMEOUT_SECONDS: Final[float] = 10.0
@dataclass(frozen=True)
class WebSearchResult:
"""A single web search result."""
title: str
url: str
snippet: str
score: float = 1.0
def to_dict(self) -> dict[str, object]:
"""Convert to dictionary for serialization."""
return {
"title": self.title,
"url": self.url,
"snippet": self.snippet,
"score": self.score,
}
@dataclass(frozen=True)
class WebSearchResponse:
"""Response from a web search query."""
query: str
results: tuple[WebSearchResult, ...]
total_results: int
search_time_ms: float
@property
def has_results(self) -> bool:
"""Check if search returned any results."""
return len(self.results) > 0
class WebSearchProvider(Protocol):
"""Protocol for web search providers.
Implementations can integrate with:
- Exa AI
- SerpAPI
- Brave Search API
- Bing Web Search API
- Google Custom Search
"""
async def search(
self,
query: str,
max_results: int = DEFAULT_MAX_RESULTS,
) -> WebSearchResponse:
"""Execute a web search query.
Args:
query: Search query string.
max_results: Maximum number of results to return.
Returns:
WebSearchResponse with search results.
"""
...
class DisabledWebSearchProvider:
"""Stub provider that returns empty results.
Used when web search is not configured or disabled.
"""
async def search(
self,
query: str,
max_results: int = DEFAULT_MAX_RESULTS,
) -> WebSearchResponse:
"""Return empty results - web search disabled."""
logger.debug("web_search_disabled", query=query[:50])
return WebSearchResponse(
query=query,
results=(),
total_results=0,
search_time_ms=0.0,
)
@dataclass
class WebSearchConfig:
"""Configuration for web search node."""
enabled: bool = False
max_results: int = DEFAULT_MAX_RESULTS
timeout_seconds: float = DEFAULT_TIMEOUT_SECONDS
require_approval: bool = True
def format_results_for_context(results: tuple[WebSearchResult, ...]) -> str:
"""Format web search results for LLM context.
Args:
results: Web search results to format.
Returns:
Formatted string suitable for LLM context.
"""
if not results:
return ""
formatted_parts: list[str] = ["## Web Search Results\n"]
for i, result in enumerate(results, 1):
formatted_parts.append(f"### [{i}] {result.title}")
formatted_parts.append(f"Source: {result.url}")
formatted_parts.append(f"{result.snippet}\n")
return "\n".join(formatted_parts)
async def execute_web_search(
query: str,
provider: WebSearchProvider,
config: WebSearchConfig,
) -> WebSearchResponse:
"""Execute web search with configuration.
Args:
query: Search query derived from user question.
provider: Web search provider implementation.
config: Search configuration.
Returns:
WebSearchResponse with results (empty if disabled).
"""
if not config.enabled:
logger.debug("web_search_skipped_disabled")
return WebSearchResponse(
query=query,
results=(),
total_results=0,
search_time_ms=0.0,
)
logger.info("web_search_executing", query=query[:100], max_results=config.max_results)
response = await provider.search(
query=query,
max_results=config.max_results,
)
logger.info(
"web_search_completed",
query=query[:50],
result_count=len(response.results),
search_time_ms=response.search_time_ms,
)
return response
def merge_contexts(
transcript_context: str,
web_results: WebSearchResponse,
) -> str:
"""Merge transcript segments with web search results.
Args:
transcript_context: Context from transcript segments.
web_results: Web search response.
Returns:
Combined context for LLM synthesis.
"""
if not web_results.has_results:
return transcript_context
web_context = format_results_for_context(web_results.results)
return f"""## Meeting Transcript Context
{transcript_context}
{web_context}
Note: Web search results are provided as supplementary context.
Prioritize information from the meeting transcript when answering questions about the meeting.
"""
def derive_search_query(question: str, meeting_context: str | None = None) -> str:
"""Derive a web search query from user question and context.
Args:
question: User's original question.
meeting_context: Optional context about the meeting topic.
Returns:
Optimized search query.
"""
# Simple approach: use the question directly
# A more sophisticated approach would use an LLM to generate the query
query = question.strip()
# Add meeting context keywords if available
if meeting_context:
# Extract key terms from context (simplified)
context_terms = meeting_context[:100].strip()
if context_terms:
query = f"{query} {context_terms}"
# Limit query length
max_query_length = 256
if len(query) > max_query_length:
query = query[:max_query_length].rsplit(" ", 1)[0]
return query
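
A minimal in-memory provider satisfying `WebSearchProvider`, useful for tests and local wiring; real adapters (Exa, Brave, etc.) would issue HTTP calls inside `search()`. It reuses the types defined above and is a sketch, not a real integration.

```python
import time

class StaticWebSearchProvider:
    """Returns a fixed result set; a test double, not a production adapter."""

    def __init__(self, results: tuple[WebSearchResult, ...]) -> None:
        self._results = results

    async def search(
        self,
        query: str,
        max_results: int = DEFAULT_MAX_RESULTS,
    ) -> WebSearchResponse:
        started = time.perf_counter()
        return WebSearchResponse(
            query=query,
            results=self._results[:max_results],
            total_results=len(self._results),
            search_time_ms=(time.perf_counter() - started) * 1000.0,
        )
```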

View File

@@ -0,0 +1,27 @@
"""Tool adapters for LangGraph workflows."""
from noteflow.infrastructure.ai.tools.retrieval import (
BatchEmbedderProtocol,
EmbedderProtocol,
RetrievalResult,
retrieve_segments,
retrieve_segments_batch,
retrieve_segments_workspace,
retrieve_segments_workspace_batch,
)
from noteflow.infrastructure.ai.tools.synthesis import (
SynthesisResult,
synthesize_answer,
)
__all__ = [
"BatchEmbedderProtocol",
"EmbedderProtocol",
"RetrievalResult",
"SynthesisResult",
"retrieve_segments",
"retrieve_segments_batch",
"retrieve_segments_workspace",
"retrieve_segments_workspace_batch",
"synthesize_answer",
]

View File

@@ -0,0 +1,237 @@
"""Segment retrieval tools for LangGraph workflows."""
from __future__ import annotations
import asyncio
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Final, Protocol, runtime_checkable
from uuid import UUID
from noteflow.domain.value_objects import MeetingId
# Limit concurrent parallel operations to prevent resource exhaustion
MAX_CONCURRENT_OPERATIONS: Final[int] = 10
@runtime_checkable
class EmbedderProtocol(Protocol):
"""Protocol for text embedding providers."""
async def embed(self, text: str) -> list[float]:
"""Embed a single text string."""
...
@runtime_checkable
class BatchEmbedderProtocol(EmbedderProtocol, Protocol):
"""Extended protocol for embedders supporting batch operations."""
async def embed_batch(self, texts: Sequence[str]) -> list[list[float]]:
"""Embed multiple texts in a single batch operation.
More efficient than calling embed() multiple times for providers
that support batching (reduces API calls, leverages GPU batching).
"""
...
class SegmentLike(Protocol):
segment_id: int
meeting_id: MeetingId | None
text: str
start_time: float
end_time: float
class SegmentSearchProtocol(Protocol):
async def search_semantic(
self,
query_embedding: list[float],
limit: int,
meeting_id: MeetingId | None,
) -> Sequence[tuple[SegmentLike, float]]: ...
class WorkspaceSegmentSearchProtocol(Protocol):
async def search_semantic_workspace(
self,
query_embedding: list[float],
workspace_id: UUID,
project_id: UUID | None,
limit: int,
) -> Sequence[tuple[SegmentLike, float]]: ...
@dataclass(frozen=True)
class RetrievalResult:
segment_id: int
meeting_id: UUID
text: str
start_time: float
end_time: float
score: float
def _meeting_id_to_uuid(mid: MeetingId | None) -> UUID:
if mid is None:
msg = "meeting_id is required for RetrievalResult"
raise ValueError(msg)
return UUID(str(mid))
async def retrieve_segments(
query: str,
embedder: EmbedderProtocol,
segment_repo: SegmentSearchProtocol,
meeting_id: MeetingId | None = None,
top_k: int = 8,
) -> list[RetrievalResult]:
"""Retrieve relevant transcript segments via semantic search."""
query_embedding = await embedder.embed(query)
results = await segment_repo.search_semantic(
query_embedding=query_embedding,
limit=top_k,
meeting_id=meeting_id,
)
return [
RetrievalResult(
segment_id=segment.segment_id,
meeting_id=_meeting_id_to_uuid(segment.meeting_id),
text=segment.text,
start_time=segment.start_time,
end_time=segment.end_time,
score=score,
)
for segment, score in results
]
async def retrieve_segments_workspace(
query: str,
embedder: EmbedderProtocol,
segment_repo: WorkspaceSegmentSearchProtocol,
workspace_id: UUID,
project_id: UUID | None = None,
top_k: int = 20,
) -> list[RetrievalResult]:
"""Retrieve relevant transcript segments across workspace/project via semantic search."""
query_embedding = await embedder.embed(query)
results = await segment_repo.search_semantic_workspace(
query_embedding=query_embedding,
workspace_id=workspace_id,
project_id=project_id,
limit=top_k,
)
return [
RetrievalResult(
segment_id=segment.segment_id,
meeting_id=_meeting_id_to_uuid(segment.meeting_id),
text=segment.text,
start_time=segment.start_time,
end_time=segment.end_time,
score=score,
)
for segment, score in results
]
async def _embed_batch_fallback(
texts: Sequence[str],
embedder: EmbedderProtocol,
) -> list[list[float]]:
"""Embed multiple texts, using batch API if available or parallel fallback."""
if isinstance(embedder, BatchEmbedderProtocol):
return await embedder.embed_batch(texts)
semaphore = asyncio.Semaphore(MAX_CONCURRENT_OPERATIONS)
async def _bounded_embed(text: str) -> list[float]:
async with semaphore:
return await embedder.embed(text)
return list(await asyncio.gather(*(_bounded_embed(t) for t in texts)))
async def retrieve_segments_batch(
queries: Sequence[str],
embedder: EmbedderProtocol,
segment_repo: SegmentSearchProtocol,
meeting_id: MeetingId | None = None,
top_k: int = 8,
) -> list[list[RetrievalResult]]:
"""Retrieve segments for multiple queries in parallel.
Uses batch embedding when available, then parallel search execution.
Returns results in the same order as input queries.
"""
if not queries:
return []
embeddings = await _embed_batch_fallback(list(queries), embedder)
semaphore = asyncio.Semaphore(MAX_CONCURRENT_OPERATIONS)
async def _search(emb: list[float]) -> list[RetrievalResult]:
async with semaphore:
results = await segment_repo.search_semantic(
query_embedding=emb,
limit=top_k,
meeting_id=meeting_id,
)
return [
RetrievalResult(
segment_id=seg.segment_id,
meeting_id=_meeting_id_to_uuid(seg.meeting_id),
text=seg.text,
start_time=seg.start_time,
end_time=seg.end_time,
score=score,
)
for seg, score in results
]
search_results = await asyncio.gather(*(_search(emb) for emb in embeddings))
return list(search_results)
async def retrieve_segments_workspace_batch(
queries: Sequence[str],
embedder: EmbedderProtocol,
segment_repo: WorkspaceSegmentSearchProtocol,
workspace_id: UUID,
project_id: UUID | None = None,
top_k: int = 20,
) -> list[list[RetrievalResult]]:
"""Retrieve workspace segments for multiple queries in parallel.
Uses batch embedding when available, then parallel search execution.
Returns results in the same order as input queries.
"""
if not queries:
return []
embeddings = await _embed_batch_fallback(list(queries), embedder)
semaphore = asyncio.Semaphore(MAX_CONCURRENT_OPERATIONS)
async def _search(emb: list[float]) -> list[RetrievalResult]:
async with semaphore:
results = await segment_repo.search_semantic_workspace(
query_embedding=emb,
workspace_id=workspace_id,
project_id=project_id,
limit=top_k,
)
return [
RetrievalResult(
segment_id=seg.segment_id,
meeting_id=_meeting_id_to_uuid(seg.meeting_id),
text=seg.text,
start_time=seg.start_time,
end_time=seg.end_time,
score=score,
)
for seg, score in results
]
search_results = await asyncio.gather(*(_search(emb) for emb in embeddings))
return list(search_results)
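
A hedged adapter sketch for `EmbedderProtocol`/`BatchEmbedderProtocol` wrapping a synchronous sentence-transformers model; the library, the model name, and the `encode()` call are assumptions about sentence-transformers, not part of this repo.

```python
import asyncio
from collections.abc import Sequence

from sentence_transformers import SentenceTransformer  # assumed dependency

class SentenceTransformerEmbedder:
    def __init__(self, model_name: str = "all-MiniLM-L6-v2") -> None:
        self._model = SentenceTransformer(model_name)

    async def embed(self, text: str) -> list[float]:
        # encode() is blocking; run it off the event loop.
        vectors = await asyncio.to_thread(self._model.encode, [text])
        return [float(x) for x in vectors[0]]

    async def embed_batch(self, texts: Sequence[str]) -> list[list[float]]:
        vectors = await asyncio.to_thread(self._model.encode, list(texts))
        return [[float(x) for x in row] for row in vectors]
```

Because this class defines `embed_batch`, `_embed_batch_fallback` above would take the batch path instead of issuing one `embed()` call per query.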

View File

@@ -0,0 +1,60 @@
"""Answer synthesis tools for LangGraph workflows."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING, Final, Protocol
if TYPE_CHECKING:
from noteflow.infrastructure.ai.tools.retrieval import RetrievalResult
class LLMProtocol(Protocol):
async def complete(self, prompt: str) -> str: ...
@dataclass(frozen=True)
class SynthesisResult:
answer: str
cited_segment_ids: list[int]
SYNTHESIS_PROMPT_TEMPLATE: Final[
str
] = """Answer the question based on the following transcript segments.
Cite specific segments by their ID when making claims.
Question: {question}
Segments:
{segments}
Answer (cite segment IDs in brackets like [1], [3]):"""
CITATION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\[(\d+)\]")
async def synthesize_answer(
question: str,
segments: list[RetrievalResult],
llm: LLMProtocol,
) -> SynthesisResult:
"""Generate answer with segment citations using LLM."""
segment_text = "\n".join(
f"[{s.segment_id}] ({s.start_time:.1f}s-{s.end_time:.1f}s): {s.text}" for s in segments
)
prompt = SYNTHESIS_PROMPT_TEMPLATE.format(
question=question,
segments=segment_text,
)
answer = await llm.complete(prompt)
valid_ids = {s.segment_id for s in segments}
cited_ids = extract_cited_ids(answer, valid_ids)
return SynthesisResult(answer=answer, cited_segment_ids=cited_ids)
def extract_cited_ids(answer: str, valid_ids: set[int]) -> list[int]:
matches = CITATION_PATTERN.findall(answer)
cited = [int(m) for m in matches if int(m) in valid_ids]
return list(dict.fromkeys(cited))
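
A hedged `LLMProtocol` adapter using the OpenAI v1 async client; the client wiring and model name are assumptions, not project code.

```python
from openai import AsyncOpenAI  # assumed dependency

class OpenAILLM:
    def __init__(self, model: str = "gpt-4o-mini") -> None:
        self._client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
        self._model = model

    async def complete(self, prompt: str) -> str:
        response = await self._client.chat.completions.create(
            model=self._model,
            messages=[{"role": "user", "content": prompt}],
        )
        return response.choices[0].message.content or ""
```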

View File

@@ -0,0 +1 @@
"""Tests for domain/ai/ module."""

View File

@@ -0,0 +1,118 @@
from uuid import uuid4
import pytest
from noteflow.domain.ai.citations import SegmentCitation
class TestSegmentCitation:
def test_creation_with_valid_values(self) -> None:
meeting_id = uuid4()
citation = SegmentCitation(
meeting_id=meeting_id,
segment_id=1,
start_time=0.0,
end_time=5.0,
text="Test segment text",
score=0.95,
)
assert citation.meeting_id == meeting_id
assert citation.segment_id == 1
assert citation.start_time == 0.0
assert citation.end_time == 5.0
assert citation.text == "Test segment text"
assert citation.score == 0.95
def test_duration_property(self) -> None:
citation = SegmentCitation(
meeting_id=uuid4(),
segment_id=1,
start_time=10.0,
end_time=25.0,
text="Test",
)
assert citation.duration == 15.0
def test_default_score_is_zero(self) -> None:
citation = SegmentCitation(
meeting_id=uuid4(),
segment_id=1,
start_time=0.0,
end_time=1.0,
text="Test",
)
assert citation.score == 0.0
def test_rejects_negative_segment_id(self) -> None:
with pytest.raises(ValueError, match="segment_id must be non-negative"):
SegmentCitation(
meeting_id=uuid4(),
segment_id=-1,
start_time=0.0,
end_time=5.0,
text="Test",
)
def test_rejects_negative_start_time(self) -> None:
with pytest.raises(ValueError, match="start_time must be non-negative"):
SegmentCitation(
meeting_id=uuid4(),
segment_id=1,
start_time=-1.0,
end_time=5.0,
text="Test",
)
def test_rejects_end_time_before_start_time(self) -> None:
with pytest.raises(ValueError, match="end_time must be >= start_time"):
SegmentCitation(
meeting_id=uuid4(),
segment_id=1,
start_time=10.0,
end_time=5.0,
text="Test",
)
@pytest.mark.parametrize(
"invalid_score",
[
pytest.param(-0.1, id="negative"),
pytest.param(1.1, id="above_one"),
],
)
def test_rejects_invalid_score(self, invalid_score: float) -> None:
with pytest.raises(ValueError, match="score must be between 0 and 1"):
SegmentCitation(
meeting_id=uuid4(),
segment_id=1,
start_time=0.0,
end_time=5.0,
text="Test",
score=invalid_score,
)
def test_accepts_zero_duration(self) -> None:
citation = SegmentCitation(
meeting_id=uuid4(),
segment_id=1,
start_time=5.0,
end_time=5.0,
text="Instant moment",
)
assert citation.duration == 0.0
def test_is_frozen(self) -> None:
citation = SegmentCitation(
meeting_id=uuid4(),
segment_id=1,
start_time=0.0,
end_time=5.0,
text="Test",
)
with pytest.raises(AttributeError):
citation.text = "Modified" # type: ignore[misc]

View File

@@ -0,0 +1 @@
"""Tests for infrastructure/ai/ module."""

View File

@@ -0,0 +1,268 @@
from collections.abc import Sequence
from dataclasses import dataclass
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from noteflow.infrastructure.ai.tools.retrieval import (
BatchEmbedderProtocol,
RetrievalResult,
retrieve_segments,
retrieve_segments_batch,
)
@dataclass
class MockSegment:
segment_id: int
meeting_id: object
text: str
start_time: float
end_time: float
class TestRetrieveSegments:
@pytest.fixture
def mock_embedder(self) -> AsyncMock:
embedder = AsyncMock()
embedder.embed.return_value = [0.1, 0.2, 0.3]
return embedder
@pytest.fixture
def mock_segment_repo(self) -> AsyncMock:
return AsyncMock()
@pytest.fixture
def sample_meeting_id(self) -> object:
return uuid4()
async def test_retrieve_segments_success(
self,
mock_embedder: AsyncMock,
mock_segment_repo: AsyncMock,
sample_meeting_id: object,
) -> None:
segment = MockSegment(
segment_id=1,
meeting_id=sample_meeting_id,
text="Test segment",
start_time=0.0,
end_time=5.0,
)
mock_segment_repo.search_semantic.return_value = [(segment, 0.95)]
results = await retrieve_segments(
query="test query",
embedder=mock_embedder,
segment_repo=mock_segment_repo,
meeting_id=sample_meeting_id, # type: ignore[arg-type]
top_k=5,
)
assert len(results) == 1
assert results[0].segment_id == 1
assert results[0].text == "Test segment"
assert results[0].score == 0.95
async def test_retrieve_segments_calls_embedder_with_query(
self,
mock_embedder: AsyncMock,
mock_segment_repo: AsyncMock,
) -> None:
mock_segment_repo.search_semantic.return_value = []
await retrieve_segments(
query="what happened in the meeting",
embedder=mock_embedder,
segment_repo=mock_segment_repo,
)
mock_embedder.embed.assert_called_once_with("what happened in the meeting")
async def test_retrieve_segments_passes_embedding_to_repo(
self,
mock_embedder: AsyncMock,
mock_segment_repo: AsyncMock,
) -> None:
mock_embedder.embed.return_value = [1.0, 2.0, 3.0]
mock_segment_repo.search_semantic.return_value = []
await retrieve_segments(
query="test",
embedder=mock_embedder,
segment_repo=mock_segment_repo,
top_k=10,
)
mock_segment_repo.search_semantic.assert_called_once_with(
query_embedding=[1.0, 2.0, 3.0],
meeting_id=None,
limit=10,
)
async def test_retrieve_segments_empty_result(
self,
mock_embedder: AsyncMock,
mock_segment_repo: AsyncMock,
) -> None:
mock_segment_repo.search_semantic.return_value = []
results = await retrieve_segments(
query="test",
embedder=mock_embedder,
segment_repo=mock_segment_repo,
)
assert results == []
async def test_retrieval_result_is_frozen(self) -> None:
result = RetrievalResult(
segment_id=1,
meeting_id=uuid4(),
text="Test",
start_time=0.0,
end_time=5.0,
score=0.9,
)
with pytest.raises(AttributeError):
result.text = "Modified" # type: ignore[misc]
class MockBatchEmbedder:
def __init__(self, embedding: list[float]) -> None:
self._embedding = embedding
self.embed_calls: list[str] = []
self.embed_batch_calls: list[Sequence[str]] = []
async def embed(self, text: str) -> list[float]:
self.embed_calls.append(text)
return self._embedding
async def embed_batch(self, texts: Sequence[str]) -> list[list[float]]:
self.embed_batch_calls.append(texts)
return [self._embedding for _ in texts]
class TestRetrieveSegmentsBatch:
@pytest.fixture
def mock_embedder(self) -> AsyncMock:
embedder = AsyncMock()
embedder.embed.return_value = [0.1, 0.2, 0.3]
return embedder
@pytest.fixture
def batch_embedder(self) -> MockBatchEmbedder:
return MockBatchEmbedder([0.1, 0.2, 0.3])
@pytest.fixture
def mock_segment_repo(self) -> AsyncMock:
return AsyncMock()
@pytest.fixture
def sample_meeting_id(self) -> object:
return uuid4()
async def test_batch_returns_empty_for_no_queries(
self,
mock_embedder: AsyncMock,
mock_segment_repo: AsyncMock,
) -> None:
results = await retrieve_segments_batch(
queries=[],
embedder=mock_embedder,
segment_repo=mock_segment_repo,
)
assert results == []
mock_embedder.embed.assert_not_called()
async def test_batch_uses_embed_batch_when_available(
self,
batch_embedder: MockBatchEmbedder,
mock_segment_repo: AsyncMock,
sample_meeting_id: object,
) -> None:
segment = MockSegment(
segment_id=1,
meeting_id=sample_meeting_id,
text="Test",
start_time=0.0,
end_time=5.0,
)
mock_segment_repo.search_semantic.return_value = [(segment, 0.9)]
assert isinstance(batch_embedder, BatchEmbedderProtocol)
results = await retrieve_segments_batch(
queries=["query1", "query2"],
embedder=batch_embedder,
segment_repo=mock_segment_repo,
meeting_id=sample_meeting_id, # type: ignore[arg-type]
)
assert len(results) == 2
assert len(batch_embedder.embed_batch_calls) == 1
assert list(batch_embedder.embed_batch_calls[0]) == ["query1", "query2"]
assert batch_embedder.embed_calls == []
async def test_batch_falls_back_to_parallel_embed(
self,
mock_embedder: AsyncMock,
mock_segment_repo: AsyncMock,
sample_meeting_id: object,
) -> None:
segment = MockSegment(
segment_id=1,
meeting_id=sample_meeting_id,
text="Test",
start_time=0.0,
end_time=5.0,
)
mock_segment_repo.search_semantic.return_value = [(segment, 0.9)]
results = await retrieve_segments_batch(
queries=["query1", "query2"],
embedder=mock_embedder,
segment_repo=mock_segment_repo,
meeting_id=sample_meeting_id, # type: ignore[arg-type]
)
assert len(results) == 2
assert mock_embedder.embed.call_count == 2
async def test_batch_preserves_query_order(
self,
mock_segment_repo: AsyncMock,
sample_meeting_id: object,
) -> None:
segment1 = MockSegment(1, sample_meeting_id, "First", 0.0, 5.0)
segment2 = MockSegment(2, sample_meeting_id, "Second", 5.0, 10.0)
call_count = 0
async def side_effect(
query_embedding: list[float],
limit: int,
meeting_id: object,
) -> list[tuple[MockSegment, float]]:
nonlocal call_count
call_count += 1
if call_count == 1:
return [(segment1, 0.9)]
return [(segment2, 0.8)]
mock_segment_repo.search_semantic.side_effect = side_effect
embedder = MockBatchEmbedder([0.1, 0.2])
results = await retrieve_segments_batch(
queries=["first", "second"],
embedder=embedder,
segment_repo=mock_segment_repo,
meeting_id=sample_meeting_id, # type: ignore[arg-type]
)
assert len(results) == 2
assert results[0][0].text == "First"
assert results[1][0].text == "Second"

View File

@@ -0,0 +1,149 @@
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from noteflow.infrastructure.ai.tools.retrieval import RetrievalResult
from noteflow.infrastructure.ai.tools.synthesis import (
SynthesisResult,
extract_cited_ids,
synthesize_answer,
)
class TestSynthesizeAnswer:
@pytest.fixture
def mock_llm(self) -> AsyncMock:
return AsyncMock()
@pytest.fixture
def sample_segments(self) -> list[RetrievalResult]:
meeting_id = uuid4()
return [
RetrievalResult(
segment_id=1,
meeting_id=meeting_id,
text="John discussed the project timeline",
start_time=0.0,
end_time=5.0,
score=0.95,
),
RetrievalResult(
segment_id=3,
meeting_id=meeting_id,
text="The deadline is next Friday",
start_time=10.0,
end_time=15.0,
score=0.85,
),
]
async def test_synthesize_answer_returns_result(
self,
mock_llm: AsyncMock,
sample_segments: list[RetrievalResult],
) -> None:
mock_llm.complete.return_value = "The project deadline is next Friday [3]."
result = await synthesize_answer(
question="What is the deadline?",
segments=sample_segments,
llm=mock_llm,
)
assert isinstance(result, SynthesisResult)
assert "deadline" in result.answer.lower()
async def test_synthesize_answer_extracts_citations(
self,
mock_llm: AsyncMock,
sample_segments: list[RetrievalResult],
) -> None:
mock_llm.complete.return_value = "John discussed timelines [1] and the deadline [3]."
result = await synthesize_answer(
question="What happened?",
segments=sample_segments,
llm=mock_llm,
)
assert result.cited_segment_ids == [1, 3]
async def test_synthesize_answer_filters_invalid_citations(
self,
mock_llm: AsyncMock,
sample_segments: list[RetrievalResult],
) -> None:
mock_llm.complete.return_value = "Found [1], [99], and [3]."
result = await synthesize_answer(
question="What happened?",
segments=sample_segments,
llm=mock_llm,
)
assert 99 not in result.cited_segment_ids
assert result.cited_segment_ids == [1, 3]
async def test_synthesize_answer_builds_prompt_with_segments(
self,
mock_llm: AsyncMock,
sample_segments: list[RetrievalResult],
) -> None:
mock_llm.complete.return_value = "Answer."
await synthesize_answer(
question="What is happening?",
segments=sample_segments,
llm=mock_llm,
)
call_args = mock_llm.complete.call_args
prompt = call_args[0][0]
assert "What is happening?" in prompt
assert "[1]" in prompt
assert "[3]" in prompt
assert "John discussed" in prompt
class TestExtractCitedIds:
def test_extracts_single_citation(self) -> None:
result = extract_cited_ids("The answer is here [5].", {1, 3, 5})
assert result == [5]
def test_extracts_multiple_citations(self) -> None:
result = extract_cited_ids("See [1] and [3] for details.", {1, 3, 5})
assert result == [1, 3]
def test_filters_invalid_ids(self) -> None:
result = extract_cited_ids("See [1] and [99].", {1, 3, 5})
assert result == [1]
def test_deduplicates_citations(self) -> None:
result = extract_cited_ids("See [1] and then [1] again.", {1, 3})
assert result == [1]
def test_preserves_order(self) -> None:
result = extract_cited_ids("[3] comes first, then [1].", {1, 3})
assert result == [3, 1]
def test_empty_for_no_citations(self) -> None:
result = extract_cited_ids("No citations here.", {1, 3})
assert result == []
class TestSynthesisResult:
def test_is_frozen(self) -> None:
result = SynthesisResult(
answer="Test answer",
cited_segment_ids=[1, 2],
)
with pytest.raises(AttributeError):
result.answer = "Modified" # type: ignore[misc]

View File

@@ -0,0 +1 @@
# Type stubs for langgraph

View File

@@ -0,0 +1 @@
# Type stubs for langgraph-checkpoint-postgres

View File

@@ -0,0 +1,10 @@
# Type stubs for langgraph-checkpoint-postgres async module
class AsyncPostgresSaver:
"""Async PostgreSQL checkpointer for LangGraph.
This stub provides typing for the langgraph-checkpoint-postgres package.
"""
def __init__(self, pool: object) -> None: ...
async def setup(self) -> None: ...
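
A minimal construction sketch matching the stub above: build the saver from an async psycopg pool and run `setup()` once at startup to create the checkpoint tables. The pool handling and import paths are assumptions about langgraph-checkpoint-postgres and psycopg_pool.

```python
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool  # assumed transitive dependency

async def create_checkpointer(dsn: str) -> AsyncPostgresSaver:
    pool = AsyncConnectionPool(conninfo=dsn, open=False)
    await pool.open()
    saver = AsyncPostgresSaver(pool)
    await saver.setup()  # idempotent schema creation
    return saver
```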

View File

@@ -0,0 +1,41 @@
# Type stubs for langgraph.graph
from collections.abc import Callable, Coroutine
from typing import Generic, TypeVar
_StateT = TypeVar("_StateT")
class START:
"""Sentinel for graph start node."""
pass
class END:
"""Sentinel for graph end node."""
pass
class CompiledStateGraph(Generic[_StateT]):
"""Compiled state graph that can be invoked."""
async def ainvoke(self, input: _StateT) -> _StateT: ...
def invoke(self, input: _StateT) -> _StateT: ...
class StateGraph(Generic[_StateT]):
"""State graph builder.
This stub provides typing for langgraph StateGraph.
"""
def __init__(self, state_schema: type[_StateT]) -> None: ...
def add_node(
self,
name: str,
action: Callable[[_StateT], dict[str, object]]
| Callable[[_StateT], Coroutine[object, object, dict[str, object]]],
) -> None: ...
def add_edge(
self,
start_key: str | type[START],
end_key: str | type[END],
) -> None: ...
def compile(
self,
checkpointer: object | None = None,
) -> CompiledStateGraph[_StateT]: ...