feat: enhance summarization and diarization features with memory logging and action item extraction

- Integrated memory snapshot logging throughout the summarization and diarization pipelines to improve diagnostics and performance tracking.
- Added action item extraction from meeting segments, using a predefined keyword list to identify actionable tasks.
- Refactored summary generation so cloud provider registration is handled dynamically based on application configuration.
- Introduced a bounded audio buffer class for streaming diarization audio, improving memory management and performance.
- Updated the gRPC mixins to log model status for better observability during processing stages.
2026-01-15 02:45:11 -05:00
parent ec07cb6dd4
commit a95a92ca25
18 changed files with 797 additions and 163 deletions

View File

@@ -11,10 +11,13 @@ use std::env;
use std::fs::File;
use std::sync::Mutex;
use serde_json::{json, Value};
static STREAMING_TEST_LOCK: Mutex<()> = Mutex::new(());
const TARGET_SAMPLE_RATE_HZ: u32 = 16000;
const CHUNK_SAMPLES: usize = 1600;
const CHUNK_BYTES: usize = CHUNK_SAMPLES * 4;
const SUMMARY_TIMEOUT_SECS: u64 = 90;
/// Check if integration tests should run
fn should_run_integration_tests() -> bool {
@@ -1209,6 +1212,12 @@ mod integration {
.await
.expect("Failed to connect");
let cloud_model = env::var("NOTEFLOW_CLOUD_LLM_MODEL").ok();
let cloud_base_url = env::var("NOTEFLOW_CLOUD_LLM_BASE_URL").ok();
let cloud_api_key = env::var("NOTEFLOW_CLOUD_LLM_API_KEY").ok();
let mut original_ai_config: Option<String> = None;
let mut original_cloud_consent: Option<bool> = None;
println!("\n=== Real Audio Streaming E2E Test ===\n");
println!(
"Server info: version={}, asr_model={}, asr_ready={}",
@@ -1428,12 +1437,79 @@ mod integration {
.await
.expect("Failed to reconnect after streaming");
if let (Some(model), Some(base_url), Some(api_key)) =
(cloud_model.clone(), cloud_base_url, cloud_api_key)
{
println!(
"Configuring cloud summary provider (model={}, base_url={})",
model, base_url
);
let existing = post_stream_client
.get_preferences(Some(vec!["ai_config".to_string()]))
.await
.expect("Failed to fetch ai_config preference");
original_ai_config = existing.preferences.get("ai_config").cloned();
let consent_status = post_stream_client
.get_cloud_consent_status()
.await
.expect("Failed to fetch cloud consent status");
original_cloud_consent = Some(consent_status);
let mut ai_config_value = original_ai_config
.as_deref()
.and_then(|raw| serde_json::from_str::<Value>(raw).ok())
.unwrap_or_else(|| json!({}));
if !ai_config_value.is_object() {
ai_config_value = json!({});
}
let summary_entry = ai_config_value
.as_object_mut()
.expect("ai_config should be an object")
.entry("summary".to_string())
.or_insert_with(|| json!({}));
if !summary_entry.is_object() {
*summary_entry = json!({});
}
let summary_obj = summary_entry
.as_object_mut()
.expect("summary config should be an object");
summary_obj.insert("provider".to_string(), json!("openai"));
summary_obj.insert("base_url".to_string(), json!(base_url));
summary_obj.insert("api_key".to_string(), json!(api_key));
summary_obj.insert("selected_model".to_string(), json!(model));
summary_obj.insert("model".to_string(), json!(model));
summary_obj.insert("test_status".to_string(), json!("success"));
let mut prefs_update = HashMap::new();
prefs_update.insert("ai_config".to_string(), ai_config_value.to_string());
post_stream_client
.set_preferences(prefs_update, None, None, true)
.await
.expect("Failed to update cloud preferences");
if !consent_status {
post_stream_client
.grant_cloud_consent()
.await
.expect("Failed to grant cloud consent");
}
}
// 7. Generate summary
println!("\n=== STEP 7: Generate Summary ===");
let summary = post_stream_client
.generate_summary(&meeting.id, false, None)
.await
.expect("Failed to generate summary");
let summary = match tokio::time::timeout(
std::time::Duration::from_secs(SUMMARY_TIMEOUT_SECS),
post_stream_client.generate_summary(&meeting.id, false, None),
)
.await
{
Ok(result) => result.expect("Failed to generate summary"),
Err(_) => {
panic!(
"Summary generation timed out after {} seconds",
SUMMARY_TIMEOUT_SECS
);
}
};
println!(
"✓ Summary generated ({} key points, {} action items)",
summary.key_points.len(),
@@ -1447,10 +1523,43 @@ mod integration {
!summary.key_points.is_empty(),
"Summary should include key points"
);
assert!(
!summary.action_items.is_empty(),
"Summary should include action items for task extraction"
);
const ACTION_KEYWORDS: [&str; 15] = [
"todo",
"action",
"will",
"should",
"must",
"need to",
"let's",
"lets",
"follow up",
"next step",
"next steps",
"schedule",
"send",
"share",
"review",
];
let transcript_has_tasks = final_meeting.segments.iter().any(|segment| {
let text = segment.text.to_lowercase();
ACTION_KEYWORDS.iter().any(|keyword| text.contains(keyword))
});
if transcript_has_tasks {
assert!(
!summary.action_items.is_empty(),
"Summary should include action items for task extraction"
);
} else {
println!("No action keywords detected; skipping action item expectation.");
}
if let Some(expected_model) = cloud_model.clone() {
let expected = format!("openai/{expected_model}");
assert_eq!(
summary.model_version, expected,
"Expected cloud summary model version {}, got {}",
expected, summary.model_version
);
}
let segment_id_set: HashSet<i32> = final_meeting
.segments
@@ -1558,6 +1667,26 @@ mod integration {
.expect("Failed to delete meeting");
println!("✓ Deleted meeting: {}", deleted);
let mut prefs_restore = HashMap::new();
if let Some(previous) = original_ai_config {
prefs_restore.insert("ai_config".to_string(), previous);
}
if !prefs_restore.is_empty() {
if let Err(error) = post_stream_client
.set_preferences(prefs_restore, None, None, true)
.await
{
println!("⚠ Failed to restore ai_config preference: {}", error);
}
}
if let Some(previous_consent) = original_cloud_consent {
if !previous_consent {
if let Err(error) = post_stream_client.revoke_cloud_consent().await {
println!("⚠ Failed to restore cloud consent: {}", error);
}
}
}
// Final summary
println!("\n=== AUDIO STREAMING E2E TEST SUMMARY ===");
println!(

View File

@@ -14,6 +14,7 @@ from noteflow.config.constants import ERROR_MSG_MEETING_PREFIX
from noteflow.config.settings import get_feature_flags
from noteflow.domain.entities.named_entity import NamedEntity
from noteflow.infrastructure.logging import get_logger, log_timing
from noteflow.infrastructure.metrics.memory_logger import log_memory_snapshot
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
@@ -156,6 +157,10 @@ class NerService:
self._model_helper = _ModelLifecycleHelper(ner_engine)
self._extraction_helper = _ExtractionHelper(ner_engine, self._model_helper)
def is_ready(self) -> bool:
"""Return True if the NER engine is loaded and ready."""
return self._ner_engine.is_ready()
async def extract_entities(
self,
meeting_id: MeetingId,
@@ -185,13 +190,35 @@ class NerService:
return cached_or_segments
segments = cached_or_segments
return await self._extract_and_persist_entities(
meeting_id=meeting_id,
segments=segments,
force_refresh=force_refresh,
)
# Extract and persist
async def _extract_and_persist_entities(
self,
meeting_id: MeetingId,
segments: list[tuple[int, str]],
force_refresh: bool,
) -> ExtractionResult:
"""Extract entities and persist results."""
log_memory_snapshot(
"ner_extraction_start",
meeting_id=str(meeting_id),
segment_count=len(segments),
)
entities = await self._extraction_helper.extract(segments)
for entity in entities:
entity.meeting_id = meeting_id
await self._persist_entities(meeting_id, entities, force_refresh)
log_memory_snapshot(
"ner_extraction_end",
meeting_id=str(meeting_id),
segment_count=len(segments),
entity_count=len(entities),
)
logger.info(
"Extracted %d entities from meeting %s (%d segments)",
@@ -234,14 +261,6 @@ class NerService:
await uow.commit()
async def get_entities(self, meeting_id: MeetingId) -> Sequence[NamedEntity]:
"""Get cached entities for a meeting (no extraction).
Args:
meeting_id: Meeting UUID.
Returns:
List of entities (empty if not extracted yet).
"""
async with self._uow_factory() as uow:
return await uow.entities.get_by_meeting(meeting_id)
@@ -277,14 +296,6 @@ class NerService:
return count
async def has_entities(self, meeting_id: MeetingId) -> bool:
"""Check if a meeting has extracted entities.
Args:
meeting_id: Meeting UUID.
Returns:
True if at least one entity exists.
"""
async with self._uow_factory() as uow:
return await uow.entities.exists_for_meeting(meeting_id)
@@ -300,9 +311,6 @@ class NerService:
return get_feature_flags().ner_enabled and self._ner_engine.is_ready()
# --- Module-level helper functions ---
def _check_feature_enabled() -> None:
"""Raise if NER feature is disabled."""
if not get_feature_flags().ner_enabled:

View File

@@ -0,0 +1,125 @@
"""Model status logging helpers for diagnostics."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from noteflow.infrastructure.logging import get_logger
if TYPE_CHECKING:
from noteflow.application.services.ner_service import NerService
from noteflow.application.services.summarization import SummarizationService
from noteflow.infrastructure.asr import FasterWhisperEngine
from noteflow.infrastructure.diarization.engine import DiarizationEngine
from noteflow.domain.summarization import SummarizerProvider
logger = get_logger(__name__)
class _ModelStatusHost(Protocol):
asr_engine: FasterWhisperEngine | None
diarization_engine: DiarizationEngine | None
ner_service: NerService | None
summarization_service: SummarizationService | None
def log_model_status(
host: _ModelStatusHost,
stage: str,
*,
meeting_id: str | None = None,
include_provider_details: bool = False,
) -> None:
"""Log model readiness and provider state for diagnostics."""
payload: dict[str, object] = {"stage": stage}
if meeting_id is not None:
payload["meeting_id"] = meeting_id
_append_asr_status(payload, host)
_append_diarization_status(payload, host)
_append_ner_status(payload, host)
_append_summarization_status(payload, host, include_provider_details)
logger.info("model_status_snapshot", **payload)
def _append_asr_status(payload: dict[str, object], host: _ModelStatusHost) -> None:
asr = host.asr_engine
if asr is None:
payload["asr_loaded"] = False
return
payload["asr_loaded"] = asr.is_loaded
payload["asr_model_size"] = asr.model_size or ""
payload["asr_device"] = asr.device
payload["asr_compute_type"] = asr.compute_type
def _append_diarization_status(payload: dict[str, object], host: _ModelStatusHost) -> None:
diarization = host.diarization_engine
if diarization is None:
payload["diarization_streaming_loaded"] = False
payload["diarization_offline_loaded"] = False
return
payload["diarization_streaming_loaded"] = diarization.is_streaming_loaded
payload["diarization_offline_loaded"] = diarization.is_offline_loaded
payload["diarization_device"] = diarization.device or ""
def _append_ner_status(payload: dict[str, object], host: _ModelStatusHost) -> None:
ner = host.ner_service
if ner is None:
payload["ner_ready"] = False
return
payload["ner_ready"] = ner.is_ready()
def _append_summarization_status(
payload: dict[str, object],
host: _ModelStatusHost,
include_provider_details: bool,
) -> None:
service = host.summarization_service
if service is None:
payload["summarization_ready"] = False
return
payload["summarization_ready"] = True
payload["summarization_default_mode"] = service.settings.default_mode.value
payload["summarization_cloud_consent"] = service.settings.cloud_consent_granted
payload["summarization_fallback"] = service.settings.fallback_to_local
if not include_provider_details:
payload["summarization_provider_count"] = len(service.providers)
return
summaries: list[str] = []
for mode, provider in service.providers.items():
summaries.append(_format_provider_summary(mode, provider))
payload["summarization_providers"] = summaries
def _format_provider_summary(mode: object, provider: SummarizerProvider) -> str:
mode_value = getattr(mode, "value", str(mode))
model_name = _read_provider_model(provider)
client_ready = _read_provider_client_ready(provider)
available = provider.is_available
consent = provider.requires_cloud_consent
return (
f"{mode_value}:{provider.provider_name}"
f"|available={available}"
f"|consent={consent}"
f"|client_ready={client_ready}"
f"|provider_model={model_name}"
)
def _read_provider_model(provider: SummarizerProvider) -> str:
raw_model = getattr(provider, "_model", None)
if isinstance(raw_model, str):
return raw_model
return ""
def _read_provider_client_ready(provider: SummarizerProvider) -> bool:
client = getattr(provider, "_client", None)
return client is not None
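For reference, a minimal sketch of calling log_model_status with a bare host (a hypothetical _FakeHost; the real gRPC servicer supplies these attributes through its mixins, as the diffs below show):
class _FakeHost:  # hypothetical stand-in satisfying _ModelStatusHost
    asr_engine = None
    diarization_engine = None
    ner_service = None
    summarization_service = None
log_model_status(_FakeHost(), "startup_check", meeting_id="mtg-123")
# With every engine None, the snapshot reports asr_loaded=False,
# diarization_streaming_loaded=False, diarization_offline_loaded=False,
# ner_ready=False, and summarization_ready=False.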

View File

@@ -9,6 +9,7 @@ from noteflow.infrastructure.logging import get_logger
from ...proto import noteflow_pb2
from .._types import GrpcStatusContext
from .._model_status import log_model_status
from ._jobs import JobsMixin
from ._refinement import RefinementMixin
from ._speaker import SpeakerMixin
@@ -47,6 +48,7 @@ class DiarizationMixin(
Load the full meeting audio, run offline diarization, and update
segment speaker assignments. Job state is persisted when DB available.
"""
log_model_status(self, "diarization_refine_start", meeting_id=request.meeting_id)
await self.prune_diarization_jobs()
return await self.start_diarization_job(request, context)

View File

@@ -5,23 +5,55 @@ from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from numpy.typing import NDArray
from noteflow.infrastructure.audio.reader import MeetingAudioReader
from noteflow.infrastructure.diarization import SpeakerTurn
from noteflow.infrastructure.logging import get_logger
from noteflow.infrastructure.metrics.memory_logger import log_memory_snapshot
from ..converters import parse_meeting_id_or_none
from ._speaker import apply_speaker_to_segment
if TYPE_CHECKING:
from pathlib import Path
from noteflow.domain.entities import Segment
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
from ..protocols import ServicerHost
logger = get_logger(__name__)
def _load_diarization_audio(
meeting_id: str,
meetings_dir: Path,
crypto: AesGcmCryptoBox,
) -> tuple[NDArray[np.float32], int, float]:
audio_reader = MeetingAudioReader(crypto, meetings_dir)
if not audio_reader.audio_exists(meeting_id):
raise RuntimeError("No audio file found for meeting")
logger.info("Loading audio for meeting %s", meeting_id)
try:
audio_chunks = audio_reader.load_meeting_audio(meeting_id)
except (FileNotFoundError, ValueError) as exc:
raise RuntimeError(f"Failed to load audio: {exc}") from exc
if not audio_chunks:
raise RuntimeError("No audio chunks loaded for meeting")
sample_rate = audio_reader.sample_rate
all_audio = np.concatenate([chunk.frames for chunk in audio_chunks]).astype(
np.float32,
copy=False,
)
audio_seconds = len(all_audio) / sample_rate
return all_audio, sample_rate, audio_seconds
class RefinementMixin:
"""Mixin providing offline diarization refinement functionality."""
@@ -38,25 +70,22 @@ class RefinementMixin:
logger.info("Loading offline diarization model for refinement...")
self.diarization_engine.load_offline_model()
audio_reader = MeetingAudioReader(self.crypto, self.meetings_dir)
if not audio_reader.audio_exists(meeting_id):
raise RuntimeError("No audio file found for meeting")
all_audio, sample_rate, audio_seconds = _load_diarization_audio(
meeting_id=meeting_id,
meetings_dir=self.meetings_dir,
crypto=self.crypto,
)
logger.info("Loading audio for meeting %s", meeting_id)
try:
audio_chunks = audio_reader.load_meeting_audio(meeting_id)
except (FileNotFoundError, ValueError) as exc:
raise RuntimeError(f"Failed to load audio: {exc}") from exc
if not audio_chunks:
raise RuntimeError("No audio chunks loaded for meeting")
sample_rate = audio_reader.sample_rate
all_audio = np.concatenate([chunk.frames for chunk in audio_chunks])
log_memory_snapshot(
"diarization_refinement_start",
meeting_id=meeting_id,
audio_seconds=audio_seconds,
num_speakers=num_speakers or 0,
)
logger.info(
"Running offline diarization on %.2f seconds of audio",
len(all_audio) / sample_rate,
audio_seconds,
)
turns = self.diarization_engine.diarize_full(
@@ -66,6 +95,11 @@ class RefinementMixin:
)
logger.info("Diarization found %d speaker turns", len(turns))
log_memory_snapshot(
"diarization_refinement_end",
meeting_id=meeting_id,
turn_count=len(turns),
)
return list(turns)
async def apply_diarization_turns(

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Final
from typing import Protocol as TypingProtocol
import numpy as np
@@ -23,6 +23,11 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
_SECONDS_PER_MINUTE: Final[int] = 60
_MAX_TURN_AGE_MINUTES: Final[int] = 15
_MAX_TURN_AGE_SECONDS: Final[int] = _MAX_TURN_AGE_MINUTES * _SECONDS_PER_MINUTE
_MAX_TURN_COUNT: Final[int] = 5_000
class _DiarizationEngine(TypingProtocol):
"""Protocol for diarization engine interface."""
@@ -74,6 +79,7 @@ class StreamingDiarizationMixin:
# Populate diarization_turns for compatibility with maybe_assign_speaker
state.diarization_turns.extend(new_turns)
state.diarization_stream_time = session.stream_time
_prune_diarization_turns(state, session.stream_time)
# Persist turns immediately for crash resilience (DB only)
await self.persist_streaming_turns(meeting_id, list(new_turns))
@@ -169,6 +175,19 @@ def _restore_session_state(session: DiarizationSession, state: MeetingStreamStat
session.restore(state.diarization_turns, stream_time=state.diarization_stream_time)
def _prune_diarization_turns(state: MeetingStreamState, stream_time: float) -> None:
"""Prune old diarization turns to bound in-memory growth."""
if not state.diarization_turns:
return
cutoff = stream_time - _MAX_TURN_AGE_SECONDS
if cutoff > 0:
state.diarization_turns = [turn for turn in state.diarization_turns if turn.end >= cutoff]
if len(state.diarization_turns) > _MAX_TURN_COUNT:
state.diarization_turns = state.diarization_turns[-_MAX_TURN_COUNT:]
def _convert_turns_to_streaming(turns: list[SpeakerTurn]) -> list[StreamingTurn]:
"""Convert domain SpeakerTurns to StreamingTurn for persistence."""
return [

View File

@@ -8,6 +8,7 @@ from noteflow.infrastructure.logging import get_logger
from ..proto import noteflow_pb2
from ._types import GrpcContext
from ._model_status import log_model_status
from .converters import entity_to_proto, parse_meeting_id_or_abort
from .errors import (
ENTITY_ENTITY,
@@ -50,6 +51,7 @@ class EntitiesMixin:
Delegates to NerService for extraction, caching, and persistence.
Returns cached results if available, unless force_refresh is True.
"""
log_model_status(self, "ner_extract_start", meeting_id=request.meeting_id)
meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)
ner_service = await require_ner_service(self.ner_service, context)

View File

@@ -10,9 +10,11 @@ import numpy as np
from numpy.typing import NDArray
from noteflow.infrastructure.logging import get_logger
from noteflow.infrastructure.metrics.memory_logger import log_memory_snapshot
from ...proto import noteflow_pb2
from .._types import GrpcContext
from .._model_status import log_model_status
from ..errors import abort_failed_precondition, abort_invalid_argument
from ._asr import process_audio_segment
from ._cleanup import cleanup_stream_resources
@@ -107,6 +109,7 @@ class StreamingMixin:
cleanup_meeting = stream_state.current or stream_state.initialized
if cleanup_meeting:
cleanup_stream_resources(self, cleanup_meeting)
log_memory_snapshot("stream_cleanup_end", meeting_id=cleanup_meeting)
async def process_stream_chunks(
self: ServicerHost,
@@ -193,14 +196,26 @@ class StreamingMixin:
if current_meeting_id is None:
# Track meeting_id BEFORE init to guarantee cleanup on any exception
initialized_meeting_id = meeting_id
sample_rate = chunk.sample_rate
channels = chunk.channels
logger.info(
"StreamTranscription initializing meeting %s (sample_rate=%s channels=%s)",
meeting_id,
chunk.sample_rate,
chunk.channels,
sample_rate,
channels,
)
log_memory_snapshot(
"stream_init_start",
meeting_id=meeting_id,
sample_rate=sample_rate,
channels=channels,
)
log_model_status(self, "stream_init_start", meeting_id=meeting_id)
init_result = await self.init_stream_for_meeting(meeting_id, context)
return None if init_result is None else (meeting_id, initialized_meeting_id)
if init_result is None:
return None
log_memory_snapshot("stream_init_end", meeting_id=meeting_id)
return meeting_id, initialized_meeting_id
if meeting_id != current_meeting_id:
await abort_invalid_argument(context, "Stream may only contain a single meeting_id")
raise AssertionError("unreachable") # abort is NoReturn

View File

@@ -12,7 +12,9 @@ from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.logging import get_logger
from ...proto import noteflow_pb2
from ..._startup import auto_enable_cloud_llm
from .._types import GrpcContext
from .._model_status import log_model_status
from ..converters import parse_meeting_id_or_abort, summary_to_proto
from ..errors import ENTITY_MEETING, abort_not_found
from ._summary_generation import generate_placeholder_summary, summarize_or_placeholder
@@ -47,14 +49,20 @@ class SummarizationGenerationMixin:
context: GrpcContext,
) -> noteflow_pb2.Summary:
"""Generate meeting summary using SummarizationService with fallback."""
meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)
op_context = self.get_operation_context(context)
style_instructions = build_style_prompt_from_request(request)
meeting, existing, segments = await self._load_meeting_context(meeting_id, request, context)
log_model_status(
self,
"summary_generate_start",
meeting_id=request.meeting_id,
include_provider_details=True,
)
meeting_id, op_context, style_instructions, meeting, existing, segments = (
await self._prepare_summary_request(request, context)
)
if existing and not request.force_regenerate:
return summary_to_proto(existing)
await self._ensure_cloud_provider()
async with cast(UnitOfWork, self.create_repository_provider()) as repo:
style_prompt = await resolve_template_prompt(
TemplateResolutionInputs(
@@ -85,6 +93,41 @@ class SummarizationGenerationMixin:
await self._trigger_summary_webhook(meeting, saved)
return summary_to_proto(saved)
async def _prepare_summary_request(
self,
request: noteflow_pb2.GenerateSummaryRequest,
context: GrpcContext,
) -> tuple[
MeetingId,
OperationContext,
str | None,
Meeting,
Summary | None,
list[Segment],
]:
"""Prepare summary inputs from the request and meeting data."""
meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)
op_context = self.get_operation_context(context)
style_instructions = build_style_prompt_from_request(request)
meeting, existing, segments = await self._load_meeting_context(meeting_id, request, context)
return meeting_id, op_context, style_instructions, meeting, existing, segments
async def _ensure_cloud_provider(self) -> None:
"""Register cloud provider if app config is available at runtime."""
if self.summarization_service is None:
return
from noteflow.application.services.summarization import SummarizationMode
if self.summarization_service.settings.default_mode == SummarizationMode.CLOUD:
return
async with cast(UnitOfWork, self.create_repository_provider()) as repo:
if not repo.supports_preferences:
return
provider = await auto_enable_cloud_llm(repo, self.summarization_service)
if provider is not None:
logger.info("Cloud summarization enabled for request", provider=provider)
async def _save_summary(
self,
meeting: Meeting,

View File

@@ -7,10 +7,60 @@ from noteflow.domain.entities import Segment, Summary
from noteflow.domain.summarization import ProviderUnavailableError
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.logging import get_logger
from noteflow.infrastructure.metrics.memory_logger import log_memory_snapshot
logger = get_logger(__name__)
def _log_summary_snapshot(
stage: str,
meeting_id: MeetingId,
segment_count: int,
**fields: object,
) -> None:
log_memory_snapshot(
stage,
meeting_id=str(meeting_id),
segment_count=segment_count,
**fields,
)
def _log_summary_error(
meeting_id: MeetingId,
segment_count: int,
exc: Exception,
) -> None:
_log_summary_snapshot(
"summary_error",
meeting_id,
segment_count,
error_type=type(exc).__name__,
)
async def _summarize_with_service(
summarization_service: SummarizationService,
meeting_id: MeetingId,
segments: list[Segment],
style_prompt: str | None,
) -> Summary:
result = await summarization_service.summarize(
meeting_id=meeting_id,
segments=segments,
style_prompt=style_prompt,
)
summary = result.summary
provider_name = summary.provider_name
model_name = summary.model_name
logger.info(
"Generated summary using %s/%s",
provider_name,
model_name,
)
return summary
async def summarize_or_placeholder(
summarization_service: SummarizationService | None,
meeting_id: MeetingId,
@@ -19,30 +69,39 @@ async def summarize_or_placeholder(
) -> Summary:
"""Try to summarize via service, fallback to placeholder on failure."""
if summarization_service is None:
_log_summary_snapshot(
"summary_skip",
meeting_id,
len(segments),
reason="service_unavailable",
)
logger.warning("SummarizationService not configured; using placeholder summary")
return generate_placeholder_summary(meeting_id, segments)
try:
result = await summarization_service.summarize(
meeting_id=meeting_id,
segments=segments,
style_prompt=style_prompt,
_log_summary_snapshot("summary_start", meeting_id, len(segments))
summary = await _summarize_with_service(
summarization_service,
meeting_id,
segments,
style_prompt,
)
summary = result.summary
provider_name = summary.provider_name
model_name = summary.model_name
logger.info(
"Generated summary using %s/%s",
provider_name,
model_name,
_log_summary_snapshot(
"summary_end",
meeting_id,
len(segments),
provider=summary.provider_name,
model=summary.model_name,
)
return summary
except ProviderUnavailableError as exc:
logger.warning("Summarization provider unavailable; using placeholder: %s", exc)
_log_summary_error(meeting_id, len(segments), exc)
except (TimeoutError, RuntimeError, ValueError) as exc:
logger.exception(
"Summarization failed (%s); using placeholder summary", type(exc).__name__
)
_log_summary_error(meeting_id, len(segments), exc)
return generate_placeholder_summary(meeting_id, segments)

View File

@@ -27,6 +27,7 @@ from noteflow.infrastructure.persistence.database import (
create_engine_and_session_factory,
ensure_schema_ready,
)
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
from noteflow.infrastructure.summarization import CloudBackend, CloudSummarizer
@@ -62,6 +63,8 @@ class _SummaryConfig(TypedDict, total=False):
api_key: str
test_status: str
model: str
selected_model: str
base_url: str
class AsrConfigLike(Protocol):
@@ -114,7 +117,7 @@ logger = get_logger(__name__)
async def auto_enable_cloud_llm(
uow: SqlAlchemyUnitOfWork,
uow: UnitOfWork,
summarization_service: SummarizationService,
) -> str | None:
"""Check for app-configured LLM and register cloud provider if valid.
@@ -126,33 +129,36 @@ async def auto_enable_cloud_llm(
Returns:
Provider name if cloud provider was auto-enabled, None otherwise.
"""
if not uow.supports_preferences:
return None
ai_config_value = await uow.preferences.get("ai_config")
if not isinstance(ai_config_value, dict):
return None
ai_config = cast(dict[str, object], ai_config_value)
summary_config_value = ai_config.get("summary", {})
if not isinstance(summary_config_value, dict):
return None
summary_config = cast(_SummaryConfig, summary_config_value)
provider = summary_config.get(PROVIDER, "")
api_key = summary_config.get("api_key", "")
test_status = summary_config.get("test_status", "")
model = summary_config.get("model") or None # Convert empty string to None
# Only register if configured and tested successfully
if provider not in (PROVIDER_NAME_OPENAI, "anthropic") or not api_key or test_status != "success":
model = summary_config.get("model") or summary_config.get("selected_model") or None
base_url = summary_config.get("base_url") or None
if (
provider not in (PROVIDER_NAME_OPENAI, "anthropic")
or not api_key
or test_status != "success"
):
return None
backend = CloudBackend.OPENAI if provider == PROVIDER_NAME_OPENAI else CloudBackend.ANTHROPIC
cloud_summarizer = CloudSummarizer(
backend=backend,
api_key=api_key,
model=model,
base_url=base_url if backend == CloudBackend.OPENAI else None,
)
summarization_service.register_provider(SummarizationMode.CLOUD, cloud_summarizer)
# Auto-grant consent since user explicitly configured in app
summarization_service.set_default_mode(SummarizationMode.CLOUD)
summarization_service.settings.cloud_consent_granted = True
logger.info("Auto-enabled CLOUD summarization from app config: provider=%s", provider)
return provider
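For reference, a minimal ai_config preference shape that passes the guards above (field names taken from the code; the values are illustrative, not real credentials):
ai_config = {
    "summary": {
        "provider": "openai",      # must be "openai" or "anthropic"
        "api_key": "<key>",        # must be non-empty
        "test_status": "success",  # anything else is rejected
        "model": "gpt-4o-mini",    # "selected_model" is used as a fallback
        "base_url": "https://api.openai.com/v1",  # honored for OpenAI only
    }
}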

View File

@@ -0,0 +1,107 @@
"""Audio buffer for streaming diarization."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Final
import numpy as np
from numpy.typing import NDArray
from noteflow.infrastructure.logging import get_logger
logger = get_logger(__name__)
_MAX_BUFFER_CHUNKS: Final[int] = 4
@dataclass
class DiarizationAudioBuffer:
"""Bounded audio buffer for diarization streaming chunks."""
sample_rate: int
chunk_duration: float
max_buffer_chunks: int = _MAX_BUFFER_CHUNKS
_buffer: list[NDArray[np.float32]] = field(default_factory=list, init=False)
_buffer_samples: int = field(default=0, init=False)
def append(self, audio: NDArray[np.float32], *, meeting_id: str) -> None:
"""Append audio to the buffer, enforcing size limits."""
if audio.size == 0:
return
self._buffer.append(audio)
self._buffer_samples += len(audio)
self._enforce_limit(meeting_id=meeting_id)
def extract_full_chunk(self, required_samples: int) -> NDArray[np.float32] | None:
"""Extract the next full chunk if enough samples are buffered."""
if self._buffer_samples < required_samples:
return None
pieces: list[NDArray[np.float32]] = []
remaining_samples = required_samples
new_buffer: list[NDArray[np.float32]] = []
for idx, segment in enumerate(self._buffer):
if remaining_samples <= 0:
new_buffer.extend(self._buffer[idx:])
break
segment_len = len(segment)
if segment_len <= remaining_samples:
pieces.append(segment)
remaining_samples -= segment_len
continue
pieces.append(segment[:remaining_samples])
remainder = segment[remaining_samples:]
if remainder.size:
new_buffer.append(remainder.copy())
new_buffer.extend(self._buffer[idx + 1 :])
remaining_samples = 0
break
if remaining_samples > 0 or not pieces:
return None
chunk_audio = np.concatenate(pieces) if len(pieces) > 1 else pieces[0].copy()
self._buffer = new_buffer
self._buffer_samples = sum(len(segment) for segment in new_buffer)
return chunk_audio
def clear(self) -> None:
"""Clear buffered audio."""
self._buffer.clear()
self._buffer_samples = 0
def _enforce_limit(self, *, meeting_id: str) -> None:
max_samples = int(self.sample_rate * self.chunk_duration * self.max_buffer_chunks)
if self._buffer_samples <= max_samples:
return
drop_samples = self._buffer_samples - max_samples
if drop_samples <= 0:
return
new_buffer: list[NDArray[np.float32]] = []
remaining_drop = drop_samples
for segment in self._buffer:
if remaining_drop <= 0:
new_buffer.append(segment)
continue
segment_len = len(segment)
if segment_len <= remaining_drop:
remaining_drop -= segment_len
continue
new_buffer.append(segment[remaining_drop:].copy())
remaining_drop = 0
dropped = drop_samples - remaining_drop
if dropped > 0:
logger.warning(
"diarization_buffer_overrun",
meeting_id=meeting_id,
dropped_samples=dropped,
)
self._buffer = new_buffer
self._buffer_samples = sum(len(segment) for segment in new_buffer)
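A small usage sketch of the buffer contract (illustrative values: 16 kHz mono audio with 5-second chunks):
import numpy as np
buffer = DiarizationAudioBuffer(sample_rate=16000, chunk_duration=5.0)
buffer.append(np.zeros(40_000, dtype=np.float32), meeting_id="mtg-123")
assert buffer.extract_full_chunk(80_000) is None  # not enough samples yet
buffer.append(np.zeros(40_000, dtype=np.float32), meeting_id="mtg-123")
chunk = buffer.extract_full_chunk(80_000)
assert chunk is not None and len(chunk) == 80_000  # one 5 s chunk at 16 kHz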

View File

@@ -8,17 +8,19 @@ from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Protocol
from typing import TYPE_CHECKING, Final, Protocol, cast
import numpy as np
from numpy.typing import NDArray
from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.infrastructure.diarization.audio_buffer import DiarizationAudioBuffer
from noteflow.infrastructure.diarization.dto import SpeakerTurn
from noteflow.infrastructure.logging import get_logger
if TYPE_CHECKING:
from diart import SpeakerDiarization
from pyannote.core import SlidingWindowFeature
class _TrackSegment(Protocol):
@@ -38,6 +40,10 @@ logger = get_logger(__name__)
# Default chunk duration in seconds (matches pyannote segmentation model)
DEFAULT_CHUNK_DURATION: float = 5.0
_SECONDS_PER_MINUTE: Final[int] = 60
_MAX_TURN_AGE_MINUTES: Final[int] = 15
_MAX_TURN_AGE_SECONDS: Final[int] = _MAX_TURN_AGE_MINUTES * _SECONDS_PER_MINUTE
_MAX_TURN_COUNT: Final[int] = 5_000
def _collect_turns(
@@ -93,8 +99,13 @@ class DiarizationSession:
_stream_time: float = field(default=0.0, init=False)
_turns: list[SpeakerTurn] = field(default_factory=list, init=False)
_closed: bool = field(default=False, init=False)
_audio_buffer: list[NDArray[np.float32]] = field(default_factory=list, init=False)
_buffer_samples: int = field(default=0, init=False)
_audio_buffer: DiarizationAudioBuffer = field(init=False)
def __post_init__(self) -> None:
self._audio_buffer = DiarizationAudioBuffer(
sample_rate=self._sample_rate,
chunk_duration=self._chunk_duration,
)
def process_chunk(
self,
@@ -141,8 +152,7 @@ class DiarizationSession:
"""Add audio to buffer, ensuring 1D format."""
if audio.ndim > 1:
audio = audio.flatten()
self._audio_buffer.append(audio)
self._buffer_samples += len(audio)
self._audio_buffer.append(audio, meeting_id=self.meeting_id)
def _extract_full_chunk_if_ready(
self,
@@ -150,18 +160,7 @@ class DiarizationSession:
) -> NDArray[np.float32] | None:
"""Extract a full chunk from buffer if enough samples available."""
required_samples = int(self._chunk_duration * sample_rate)
if self._buffer_samples < required_samples:
return None
full_audio = np.concatenate(self._audio_buffer)
chunk_audio = full_audio[:required_samples]
remaining = full_audio[required_samples:]
self._audio_buffer = [remaining] if len(remaining) > 0 else []
self._buffer_samples = max(len(remaining), 0)
return chunk_audio
return self._audio_buffer.extract_full_chunk(required_samples)
def _process_full_chunk(
self,
@@ -180,39 +179,49 @@ class DiarizationSession:
if self._pipeline is None:
return []
duration = len(audio) / sample_rate
current_time = self._stream_time + duration
waveform = self._build_waveform(audio, duration)
new_turns = self._run_pipeline(waveform)
if new_turns:
self._turns.extend(new_turns)
self._prune_turns(current_time)
self._stream_time = current_time
return new_turns
def _build_waveform(
self,
audio: NDArray[np.float32],
duration: float,
) -> SlidingWindowFeature:
"""Build a SlidingWindowFeature for diarization."""
from pyannote.core import SlidingWindow, SlidingWindowFeature
duration = len(audio) / sample_rate
# Reshape to (samples, channels) for diart - mono audio has 1 channel
audio_2d = audio.reshape(-1, 1)
# Create SlidingWindowFeature for diart
window = SlidingWindow(start=0.0, duration=duration, step=duration)
waveform = SlidingWindowFeature(audio_2d, window)
return SlidingWindowFeature(audio_2d, window)
def _run_pipeline(self, waveform: SlidingWindowFeature) -> list[SpeakerTurn]:
"""Run diarization pipeline and return speaker turns."""
try:
# Process through pipeline
# Note: Frame rate mismatch between segmentation-3.0 and embedding models
# may cause warnings and occasional errors, which we handle gracefully
results = self._pipeline([waveform])
# Convert results to turns with absolute time offsets
new_turns = _collect_turns(results, self._stream_time)
self._turns.extend(new_turns)
except (RuntimeError, ZeroDivisionError, ValueError) as e:
# Handle frame/weights mismatch and related errors gracefully
# Streaming diarization continues even if individual chunks fail
results: Sequence[tuple[_Annotation, object]]
if self._pipeline is None:
results = ()
else:
# pyright cannot infer diart pipeline return types; cast to expected shape.
results = cast(
Sequence[tuple[_Annotation, object]],
self._pipeline([waveform]),
)
except (RuntimeError, ZeroDivisionError, ValueError) as exc:
logger.warning(
"Diarization chunk processing failed (non-fatal): %s",
str(e),
str(exc),
exc_info=False,
)
new_turns = []
return []
self._stream_time += duration
return new_turns
return _collect_turns(results, self._stream_time)
def reset(self) -> None:
"""Reset session state for restarting diarization.
@@ -227,7 +236,6 @@ class DiarizationSession:
self._stream_time = 0.0
self._turns.clear()
self._audio_buffer.clear()
self._buffer_samples = 0
logger.debug("Session %s reset", self.meeting_id)
def restore(
@@ -252,6 +260,7 @@ class DiarizationSession:
if stream_time is None:
stream_time = max((t.end for t in turns), default=0.0)
self._stream_time = max(self._stream_time, stream_time)
self._prune_turns(self._stream_time)
logger.debug(
"Session %s restored stream_time=%.3f with %d turns",
self.meeting_id,
@@ -272,11 +281,20 @@ class DiarizationSession:
self._closed = True
self._turns.clear()
self._audio_buffer.clear()
self._buffer_samples = 0
# Explicitly release pipeline reference to allow GC and GPU memory release
self._pipeline = None
logger.info("diarization_session_closed", meeting_id=self.meeting_id)
def _prune_turns(self, current_time: float) -> None:
"""Prune old speaker turns to keep memory bounded."""
if not self._turns:
return
cutoff = current_time - _MAX_TURN_AGE_SECONDS
if cutoff > 0:
self._turns = [turn for turn in self._turns if turn.end >= cutoff]
if len(self._turns) > _MAX_TURN_COUNT:
self._turns = self._turns[-_MAX_TURN_COUNT:]
@property
def stream_time(self) -> float:
"""Current stream time position in seconds."""

View File

@@ -0,0 +1,25 @@
"""Memory snapshot logging utilities."""
from __future__ import annotations
from typing import Final
from noteflow.infrastructure.logging import get_logger
from noteflow.infrastructure.metrics.collector import get_metrics_collector
logger = get_logger(__name__)
_EVENT_NAME: Final[str] = "memory_snapshot"
def log_memory_snapshot(stage: str, **fields: object) -> None:
"""Log a structured memory snapshot for diagnostics."""
metrics = get_metrics_collector().collect_now()
logger.info(
_EVENT_NAME,
stage=stage,
process_memory_mb=metrics.process_memory_mb,
memory_percent=metrics.memory_percent,
memory_mb=metrics.memory_mb,
**fields,
)
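Call sites pass a stage name plus arbitrary structured fields, e.g. (illustrative values):
log_memory_snapshot(
    "summary_start",
    meeting_id="mtg-123",
    segment_count=42,
)
# Emits a "memory_snapshot" log event carrying process_memory_mb,
# memory_percent, and memory_mb from the metrics collector, plus the
# stage and keyword fields above.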

View File

@@ -0,0 +1,51 @@
"""Action item extraction helpers for summarization."""
from __future__ import annotations
from collections.abc import Sequence
from typing import Final
from noteflow.domain.entities import ActionItem, Segment
_ACTION_KEYWORDS: Final[tuple[str, ...]] = (
"todo",
"action",
"will",
"should",
"must",
"need to",
"let's",
"lets",
"follow up",
"next step",
"next steps",
"schedule",
"send",
"share",
"review",
"update",
"prepare",
"assign",
)
_MAX_ACTION_TEXT_LEN: Final[int] = 80
def extract_action_items_from_segments(
segments: Sequence[Segment],
max_action_items: int,
) -> list[ActionItem]:
"""Extract action items by scanning segments for action keywords."""
action_items: list[ActionItem] = []
for segment in segments:
if len(action_items) >= max_action_items:
break
text_lower = segment.text.lower()
if any(keyword in text_lower for keyword in _ACTION_KEYWORDS):
action_items.append(
ActionItem(
text=f"Action: {segment.text[:_MAX_ACTION_TEXT_LEN]}",
assignee="",
segment_ids=[segment.segment_id],
)
)
return action_items
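An illustrative call, using a hypothetical minimal stand-in for the Segment entity (the real entity likely carries more fields; only text and segment_id are read here):
from dataclasses import dataclass
@dataclass
class _FakeSegment:  # hypothetical stand-in; only the fields read above
    segment_id: int
    text: str
segments = [
    _FakeSegment(1, "We should follow up with the vendor."),
    _FakeSegment(2, "That was a great story."),
]
items = extract_action_items_from_segments(segments, max_action_items=5)
# Only the first segment matches ("should", "follow up"); its text is
# prefixed with "Action: " and truncated to 80 characters.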

View File

@@ -10,6 +10,9 @@ from noteflow.domain.constants.fields import ACTION_ITEMS, KEY_POINTS, SEGMENT_I
from noteflow.domain.entities import ActionItem, KeyPoint, Summary
from noteflow.domain.summarization import InvalidResponseError
from noteflow.infrastructure.logging import get_logger
from noteflow.infrastructure.summarization._action_items import (
extract_action_items_from_segments,
)
if TYPE_CHECKING:
from collections.abc import Sequence
@@ -270,10 +273,16 @@ def parse_llm_response(response_text: str, request: SummarizationRequest) -> Sum
_parse_key_point(kp_data, valid_ids, request.segments)
for kp_data in data.get(KEY_POINTS, [])[: request.max_key_points]
]
action_items = [
parsed_action_items = [
_parse_action_item(ai_data, valid_ids)
for ai_data in data.get(ACTION_ITEMS, [])[: request.max_action_items]
]
action_items = [item for item in parsed_action_items if item.segment_ids]
if not action_items and request.max_action_items > 0:
action_items = extract_action_items_from_segments(
request.segments,
request.max_action_items,
)
return Summary(
meeting_id=request.meeting_id,

View File

@@ -21,33 +21,34 @@ class CloudSummarizerClients:
def _get_openai_client(self) -> openai.OpenAI:
"""Get or create OpenAI client."""
try:
import openai as openai_module
except ImportError as e:
raise ProviderUnavailableError(
"openai package not installed. Install with: pip install openai"
) from e
if self._client is None:
try:
import openai
self._client = openai.OpenAI(
api_key=self._api_key,
timeout=self._timeout,
base_url=self._openai_base_url,
)
except ImportError as e:
raise ProviderUnavailableError(
"openai package not installed. Install with: pip install openai"
) from e
return cast(openai.OpenAI, self._client)
self._client = openai_module.OpenAI(
api_key=self._api_key,
timeout=self._timeout,
base_url=self._openai_base_url,
)
return cast(openai_module.OpenAI, self._client)
def _get_anthropic_client(self) -> anthropic.Anthropic:
"""Get or create Anthropic client."""
try:
import anthropic as anthropic_module
except ImportError as e:
raise ProviderUnavailableError(
"anthropic package not installed. Install with: pip install anthropic"
) from e
if self._client is None:
try:
import anthropic
self._client = anthropic.Anthropic(api_key=self._api_key, timeout=self._timeout)
except ImportError as e:
raise ProviderUnavailableError(
"anthropic package not installed. Install with: pip install anthropic"
) from e
return cast(anthropic.Anthropic, self._client)
self._client = anthropic_module.Anthropic(
api_key=self._api_key,
timeout=self._timeout,
)
return cast(anthropic_module.Anthropic, self._client)
def get_openai_client(self) -> openai.OpenAI:
"""Expose OpenAI client for integrations and testing."""

View File

@@ -2,19 +2,17 @@
import time
from datetime import UTC, datetime
from typing import Final
from noteflow.domain.entities import ActionItem, KeyPoint, Summary
from noteflow.domain.summarization import (
SummarizationRequest,
SummarizationResult,
)
from noteflow.infrastructure.summarization._action_items import (
extract_action_items_from_segments,
)
from noteflow.infrastructure.summarization._availability import AvailabilityProviderBase
_ACTION_KEYWORDS: Final[frozenset[str]] = frozenset({
"todo", "action", "will", "should", "must", "need to"
})
def _truncate_text(text: str, max_length: int) -> str:
"""Truncate text with ellipsis if needed."""
@@ -45,28 +43,11 @@ def _build_key_points(request: SummarizationRequest) -> list[KeyPoint]:
def _build_action_items(request: SummarizationRequest) -> list[ActionItem]:
"""Build action items from segments containing action keywords.
Args:
request: Summarization request with segments.
Returns:
List of action items (up to max_action_items).
"""
action_items: list[ActionItem] = []
for segment in request.segments:
if len(action_items) >= request.max_action_items:
break
text_lower = segment.text.lower()
if any(kw in text_lower for kw in _ACTION_KEYWORDS):
action_items.append(
ActionItem(
text=f"Action: {segment.text[:80]}",
assignee="",
segment_ids=[segment.segment_id],
)
)
return action_items
"""Build action items from segments containing action keywords."""
return extract_action_items_from_segments(
request.segments,
request.max_action_items,
)
class MockSummarizer(AvailabilityProviderBase):