feat: enhance summarization and diarization features with memory logging and action item extraction
- Integrated memory snapshot logging throughout the summarization and diarization processes to improve diagnostics and performance tracking.
- Added action item extraction capabilities from meeting segments, utilizing predefined keywords to identify actionable tasks.
- Refactored summarization generation to ensure cloud provider registration is handled dynamically based on application configuration.
- Introduced a new audio buffer class for managing streaming audio data in the diarization process, enhancing memory management and performance.
- Updated various mixins to incorporate model status logging for better observability during processing stages.
@@ -11,10 +11,13 @@ use std::env;
use std::fs::File;
use std::sync::Mutex;

use serde_json::{json, Value};

static STREAMING_TEST_LOCK: Mutex<()> = Mutex::new(());
const TARGET_SAMPLE_RATE_HZ: u32 = 16000;
const CHUNK_SAMPLES: usize = 1600;
const CHUNK_BYTES: usize = CHUNK_SAMPLES * 4;
const SUMMARY_TIMEOUT_SECS: u64 = 90;

/// Check if integration tests should run
fn should_run_integration_tests() -> bool {
@@ -1209,6 +1212,12 @@ mod integration {
        .await
        .expect("Failed to connect");

    let cloud_model = env::var("NOTEFLOW_CLOUD_LLM_MODEL").ok();
    let cloud_base_url = env::var("NOTEFLOW_CLOUD_LLM_BASE_URL").ok();
    let cloud_api_key = env::var("NOTEFLOW_CLOUD_LLM_API_KEY").ok();
    let mut original_ai_config: Option<String> = None;
    let mut original_cloud_consent: Option<bool> = None;

    println!("\n=== Real Audio Streaming E2E Test ===\n");
    println!(
        "Server info: version={}, asr_model={}, asr_ready={}",
@@ -1428,12 +1437,79 @@ mod integration {
        .await
        .expect("Failed to reconnect after streaming");

    if let (Some(model), Some(base_url), Some(api_key)) =
        (cloud_model.clone(), cloud_base_url, cloud_api_key)
    {
        println!(
            "Configuring cloud summary provider (model={}, base_url={})",
            model, base_url
        );
        let existing = post_stream_client
            .get_preferences(Some(vec!["ai_config".to_string()]))
            .await
            .expect("Failed to fetch ai_config preference");
        original_ai_config = existing.preferences.get("ai_config").cloned();
        let consent_status = post_stream_client
            .get_cloud_consent_status()
            .await
            .expect("Failed to fetch cloud consent status");
        original_cloud_consent = Some(consent_status);

        let mut ai_config_value = original_ai_config
            .as_deref()
            .and_then(|raw| serde_json::from_str::<Value>(raw).ok())
            .unwrap_or_else(|| json!({}));
        if !ai_config_value.is_object() {
            ai_config_value = json!({});
        }
        let summary_entry = ai_config_value
            .as_object_mut()
            .expect("ai_config should be an object")
            .entry("summary".to_string())
            .or_insert_with(|| json!({}));
        if !summary_entry.is_object() {
            *summary_entry = json!({});
        }
        let summary_obj = summary_entry
            .as_object_mut()
            .expect("summary config should be an object");
        summary_obj.insert("provider".to_string(), json!("openai"));
        summary_obj.insert("base_url".to_string(), json!(base_url));
        summary_obj.insert("api_key".to_string(), json!(api_key));
        summary_obj.insert("selected_model".to_string(), json!(model));
        summary_obj.insert("model".to_string(), json!(model));
        summary_obj.insert("test_status".to_string(), json!("success"));

        let mut prefs_update = HashMap::new();
        prefs_update.insert("ai_config".to_string(), ai_config_value.to_string());
        post_stream_client
            .set_preferences(prefs_update, None, None, true)
            .await
            .expect("Failed to update cloud preferences");
        if !consent_status {
            post_stream_client
                .grant_cloud_consent()
                .await
                .expect("Failed to grant cloud consent");
        }
    }

    // 7. Generate summary
    println!("\n=== STEP 7: Generate Summary ===");
    let summary = post_stream_client
        .generate_summary(&meeting.id, false, None)
        .await
        .expect("Failed to generate summary");
    let summary = match tokio::time::timeout(
        std::time::Duration::from_secs(SUMMARY_TIMEOUT_SECS),
        post_stream_client.generate_summary(&meeting.id, false, None),
    )
    .await
    {
        Ok(result) => result.expect("Failed to generate summary"),
        Err(_) => {
            panic!(
                "Summary generation timed out after {} seconds",
                SUMMARY_TIMEOUT_SECS
            );
        }
    };
    println!(
        "✓ Summary generated ({} key points, {} action items)",
        summary.key_points.len(),
@@ -1447,10 +1523,43 @@ mod integration {
        !summary.key_points.is_empty(),
        "Summary should include key points"
    );
    assert!(
        !summary.action_items.is_empty(),
        "Summary should include action items for task extraction"
    );
    const ACTION_KEYWORDS: [&str; 15] = [
        "todo",
        "action",
        "will",
        "should",
        "must",
        "need to",
        "let's",
        "lets",
        "follow up",
        "next step",
        "next steps",
        "schedule",
        "send",
        "share",
        "review",
    ];
    let transcript_has_tasks = final_meeting.segments.iter().any(|segment| {
        let text = segment.text.to_lowercase();
        ACTION_KEYWORDS.iter().any(|keyword| text.contains(keyword))
    });
    if transcript_has_tasks {
        assert!(
            !summary.action_items.is_empty(),
            "Summary should include action items for task extraction"
        );
    } else {
        println!("No action keywords detected; skipping action item expectation.");
    }
    if let Some(expected_model) = cloud_model.clone() {
        let expected = format!("openai/{expected_model}");
        assert_eq!(
            summary.model_version, expected,
            "Expected cloud summary model version {}, got {}",
            expected, summary.model_version
        );
    }

    let segment_id_set: HashSet<i32> = final_meeting
        .segments
@@ -1558,6 +1667,26 @@ mod integration {
        .expect("Failed to delete meeting");
    println!("✓ Deleted meeting: {}", deleted);

    let mut prefs_restore = HashMap::new();
    if let Some(previous) = original_ai_config {
        prefs_restore.insert("ai_config".to_string(), previous);
    }
    if !prefs_restore.is_empty() {
        if let Err(error) = post_stream_client
            .set_preferences(prefs_restore, None, None, true)
            .await
        {
            println!("⚠ Failed to restore ai_config preference: {}", error);
        }
    }
    if let Some(previous_consent) = original_cloud_consent {
        if !previous_consent {
            if let Err(error) = post_stream_client.revoke_cloud_consent().await {
                println!("⚠ Failed to restore cloud consent: {}", error);
            }
        }
    }

    // Final summary
    println!("\n=== AUDIO STREAMING E2E TEST SUMMARY ===");
    println!(

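Note: the test now bounds summary generation with tokio::time::timeout so a stalled provider fails fast instead of hanging CI. A rough Python analogue of the same guard (a sketch; the client object and its generate_summary call are illustrative stand-ins, not an API from this repository):

import asyncio

SUMMARY_TIMEOUT_SECS = 90  # mirrors the Rust constant above

async def generate_summary_with_timeout(client, meeting_id: str):
    """Fail fast instead of hanging when the summary backend stalls."""
    try:
        return await asyncio.wait_for(
            client.generate_summary(meeting_id),  # hypothetical client call
            timeout=SUMMARY_TIMEOUT_SECS,
        )
    except asyncio.TimeoutError:
        raise RuntimeError(
            f"Summary generation timed out after {SUMMARY_TIMEOUT_SECS} seconds"
        )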
@@ -14,6 +14,7 @@ from noteflow.config.constants import ERROR_MSG_MEETING_PREFIX
from noteflow.config.settings import get_feature_flags
from noteflow.domain.entities.named_entity import NamedEntity
from noteflow.infrastructure.logging import get_logger, log_timing
from noteflow.infrastructure.metrics.memory_logger import log_memory_snapshot

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence
@@ -156,6 +157,10 @@ class NerService:
        self._model_helper = _ModelLifecycleHelper(ner_engine)
        self._extraction_helper = _ExtractionHelper(ner_engine, self._model_helper)

    def is_ready(self) -> bool:
        """Return True if the NER engine is loaded and ready."""
        return self._ner_engine.is_ready()

    async def extract_entities(
        self,
        meeting_id: MeetingId,
@@ -185,13 +190,35 @@ class NerService:
            return cached_or_segments

        segments = cached_or_segments
        return await self._extract_and_persist_entities(
            meeting_id=meeting_id,
            segments=segments,
            force_refresh=force_refresh,
        )

    # Extract and persist
    async def _extract_and_persist_entities(
        self,
        meeting_id: MeetingId,
        segments: list[tuple[int, str]],
        force_refresh: bool,
    ) -> ExtractionResult:
        """Extract entities and persist results."""
        log_memory_snapshot(
            "ner_extraction_start",
            meeting_id=str(meeting_id),
            segment_count=len(segments),
        )
        entities = await self._extraction_helper.extract(segments)
        for entity in entities:
            entity.meeting_id = meeting_id

        await self._persist_entities(meeting_id, entities, force_refresh)
        log_memory_snapshot(
            "ner_extraction_end",
            meeting_id=str(meeting_id),
            segment_count=len(segments),
            entity_count=len(entities),
        )

        logger.info(
            "Extracted %d entities from meeting %s (%d segments)",
@@ -234,14 +261,6 @@ class NerService:
            await uow.commit()

    async def get_entities(self, meeting_id: MeetingId) -> Sequence[NamedEntity]:
        """Get cached entities for a meeting (no extraction).

        Args:
            meeting_id: Meeting UUID.

        Returns:
            List of entities (empty if not extracted yet).
        """
        async with self._uow_factory() as uow:
            return await uow.entities.get_by_meeting(meeting_id)

@@ -277,14 +296,6 @@ class NerService:
        return count

    async def has_entities(self, meeting_id: MeetingId) -> bool:
        """Check if a meeting has extracted entities.

        Args:
            meeting_id: Meeting UUID.

        Returns:
            True if at least one entity exists.
        """
        async with self._uow_factory() as uow:
            return await uow.entities.exists_for_meeting(meeting_id)

@@ -300,9 +311,6 @@ class NerService:
        return get_feature_flags().ner_enabled and self._ner_engine.is_ready()


# --- Module-level helper functions ---


def _check_feature_enabled() -> None:
    """Raise if NER feature is disabled."""
    if not get_feature_flags().ner_enabled:

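The paired ner_extraction_start/ner_extraction_end snapshots above follow a pattern that recurs throughout this commit. A minimal sketch of factoring it into a reusable helper (hypothetical; not part of this commit), built on the log_memory_snapshot introduced here:

from contextlib import asynccontextmanager

from noteflow.infrastructure.metrics.memory_logger import log_memory_snapshot

@asynccontextmanager
async def memory_snapshot_span(stage: str, **fields: object):
    """Emit <stage>_start and <stage>_end snapshots around a block."""
    log_memory_snapshot(f"{stage}_start", **fields)
    try:
        yield
    finally:
        log_memory_snapshot(f"{stage}_end", **fields)

# Usage sketch:
#   async with memory_snapshot_span("ner_extraction", meeting_id=str(meeting_id)):
#       entities = await extraction_helper.extract(segments)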
src/noteflow/grpc/_mixins/_model_status.py (new file, 125 lines)
@@ -0,0 +1,125 @@
"""Model status logging helpers for diagnostics."""

from __future__ import annotations

from typing import TYPE_CHECKING, Protocol

from noteflow.infrastructure.logging import get_logger

if TYPE_CHECKING:
    from noteflow.application.services.ner_service import NerService
    from noteflow.application.services.summarization import SummarizationService
    from noteflow.infrastructure.asr import FasterWhisperEngine
    from noteflow.infrastructure.diarization.engine import DiarizationEngine
    from noteflow.domain.summarization import SummarizerProvider

logger = get_logger(__name__)


class _ModelStatusHost(Protocol):
    asr_engine: FasterWhisperEngine | None
    diarization_engine: DiarizationEngine | None
    ner_service: NerService | None
    summarization_service: SummarizationService | None


def log_model_status(
    host: _ModelStatusHost,
    stage: str,
    *,
    meeting_id: str | None = None,
    include_provider_details: bool = False,
) -> None:
    """Log model readiness and provider state for diagnostics."""
    payload: dict[str, object] = {"stage": stage}
    if meeting_id is not None:
        payload["meeting_id"] = meeting_id

    _append_asr_status(payload, host)
    _append_diarization_status(payload, host)
    _append_ner_status(payload, host)
    _append_summarization_status(payload, host, include_provider_details)

    logger.info("model_status_snapshot", **payload)


def _append_asr_status(payload: dict[str, object], host: _ModelStatusHost) -> None:
    asr = host.asr_engine
    if asr is None:
        payload["asr_loaded"] = False
        return
    payload["asr_loaded"] = asr.is_loaded
    payload["asr_model_size"] = asr.model_size or ""
    payload["asr_device"] = asr.device
    payload["asr_compute_type"] = asr.compute_type


def _append_diarization_status(payload: dict[str, object], host: _ModelStatusHost) -> None:
    diarization = host.diarization_engine
    if diarization is None:
        payload["diarization_streaming_loaded"] = False
        payload["diarization_offline_loaded"] = False
        return
    payload["diarization_streaming_loaded"] = diarization.is_streaming_loaded
    payload["diarization_offline_loaded"] = diarization.is_offline_loaded
    payload["diarization_device"] = diarization.device or ""


def _append_ner_status(payload: dict[str, object], host: _ModelStatusHost) -> None:
    ner = host.ner_service
    if ner is None:
        payload["ner_ready"] = False
        return
    payload["ner_ready"] = ner.is_ready()


def _append_summarization_status(
    payload: dict[str, object],
    host: _ModelStatusHost,
    include_provider_details: bool,
) -> None:
    service = host.summarization_service
    if service is None:
        payload["summarization_ready"] = False
        return

    payload["summarization_ready"] = True
    payload["summarization_default_mode"] = service.settings.default_mode.value
    payload["summarization_cloud_consent"] = service.settings.cloud_consent_granted
    payload["summarization_fallback"] = service.settings.fallback_to_local

    if not include_provider_details:
        payload["summarization_provider_count"] = len(service.providers)
        return

    summaries: list[str] = []
    for mode, provider in service.providers.items():
        summaries.append(_format_provider_summary(mode, provider))
    payload["summarization_providers"] = summaries


def _format_provider_summary(mode: object, provider: SummarizerProvider) -> str:
    mode_value = getattr(mode, "value", str(mode))
    model_name = _read_provider_model(provider)
    client_ready = _read_provider_client_ready(provider)
    available = provider.is_available
    consent = provider.requires_cloud_consent
    return (
        f"{mode_value}:{provider.provider_name}"
        f"|available={available}"
        f"|consent={consent}"
        f"|client_ready={client_ready}"
        f"|provider_model={model_name}"
    )


def _read_provider_model(provider: SummarizerProvider) -> str:
    raw_model = getattr(provider, "_model", None)
    if isinstance(raw_model, str):
        return raw_model
    return ""


def _read_provider_client_ready(provider: SummarizerProvider) -> bool:
    client = getattr(provider, "_client", None)
    return client is not None
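Because _ModelStatusHost is a structural Protocol, any object exposing the four attributes satisfies it, which keeps the helper easy to exercise in isolation. A minimal sketch (the stub and its values are illustrative):

from dataclasses import dataclass

from noteflow.grpc._mixins._model_status import log_model_status

@dataclass
class _StubHost:
    # None engines/services exercise the *_loaded/*_ready = False branches
    asr_engine: object | None = None
    diarization_engine: object | None = None
    ner_service: object | None = None
    summarization_service: object | None = None

log_model_status(_StubHost(), "unit_test", meeting_id="m-123")
# -> one "model_status_snapshot" event with stage, meeting_id, asr_loaded=False,
#    diarization_*_loaded=False, ner_ready=False, summarization_ready=False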
@@ -9,6 +9,7 @@ from noteflow.infrastructure.logging import get_logger

from ...proto import noteflow_pb2
from .._types import GrpcStatusContext
from .._model_status import log_model_status
from ._jobs import JobsMixin
from ._refinement import RefinementMixin
from ._speaker import SpeakerMixin
@@ -47,6 +48,7 @@ class DiarizationMixin(
        Load the full meeting audio, run offline diarization, and update
        segment speaker assignments. Job state is persisted when DB available.
        """
        log_model_status(self, "diarization_refine_start", meeting_id=request.meeting_id)
        await self.prune_diarization_jobs()
        return await self.start_diarization_job(request, context)

@@ -5,23 +5,55 @@ from __future__ import annotations
from typing import TYPE_CHECKING

import numpy as np
from numpy.typing import NDArray

from noteflow.infrastructure.audio.reader import MeetingAudioReader
from noteflow.infrastructure.diarization import SpeakerTurn
from noteflow.infrastructure.logging import get_logger
from noteflow.infrastructure.metrics.memory_logger import log_memory_snapshot

from ..converters import parse_meeting_id_or_none
from ._speaker import apply_speaker_to_segment

if TYPE_CHECKING:
    from pathlib import Path

    from noteflow.domain.entities import Segment
    from noteflow.domain.ports.unit_of_work import UnitOfWork
    from noteflow.infrastructure.security.crypto import AesGcmCryptoBox

    from ..protocols import ServicerHost

logger = get_logger(__name__)


def _load_diarization_audio(
    meeting_id: str,
    meetings_dir: Path,
    crypto: AesGcmCryptoBox,
) -> tuple[NDArray[np.float32], int, float]:
    audio_reader = MeetingAudioReader(crypto, meetings_dir)
    if not audio_reader.audio_exists(meeting_id):
        raise RuntimeError("No audio file found for meeting")

    logger.info("Loading audio for meeting %s", meeting_id)
    try:
        audio_chunks = audio_reader.load_meeting_audio(meeting_id)
    except (FileNotFoundError, ValueError) as exc:
        raise RuntimeError(f"Failed to load audio: {exc}") from exc

    if not audio_chunks:
        raise RuntimeError("No audio chunks loaded for meeting")

    sample_rate = audio_reader.sample_rate
    all_audio = np.concatenate([chunk.frames for chunk in audio_chunks]).astype(
        np.float32,
        copy=False,
    )
    audio_seconds = len(all_audio) / sample_rate
    return all_audio, sample_rate, audio_seconds


class RefinementMixin:
    """Mixin providing offline diarization refinement functionality."""

@@ -38,25 +70,22 @@ class RefinementMixin:
        logger.info("Loading offline diarization model for refinement...")
        self.diarization_engine.load_offline_model()

        audio_reader = MeetingAudioReader(self.crypto, self.meetings_dir)
        if not audio_reader.audio_exists(meeting_id):
            raise RuntimeError("No audio file found for meeting")
        all_audio, sample_rate, audio_seconds = _load_diarization_audio(
            meeting_id=meeting_id,
            meetings_dir=self.meetings_dir,
            crypto=self.crypto,
        )

        logger.info("Loading audio for meeting %s", meeting_id)
        try:
            audio_chunks = audio_reader.load_meeting_audio(meeting_id)
        except (FileNotFoundError, ValueError) as exc:
            raise RuntimeError(f"Failed to load audio: {exc}") from exc

        if not audio_chunks:
            raise RuntimeError("No audio chunks loaded for meeting")

        sample_rate = audio_reader.sample_rate
        all_audio = np.concatenate([chunk.frames for chunk in audio_chunks])
        log_memory_snapshot(
            "diarization_refinement_start",
            meeting_id=meeting_id,
            audio_seconds=audio_seconds,
            num_speakers=num_speakers or 0,
        )

        logger.info(
            "Running offline diarization on %.2f seconds of audio",
            len(all_audio) / sample_rate,
            audio_seconds,
        )

        turns = self.diarization_engine.diarize_full(
@@ -66,6 +95,11 @@ class RefinementMixin:
        )

        logger.info("Diarization found %d speaker turns", len(turns))
        log_memory_snapshot(
            "diarization_refinement_end",
            meeting_id=meeting_id,
            turn_count=len(turns),
        )
        return list(turns)

    async def apply_diarization_turns(

@@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Final
from typing import Protocol as TypingProtocol

import numpy as np
@@ -23,6 +23,11 @@ if TYPE_CHECKING:

logger = get_logger(__name__)

_SECONDS_PER_MINUTE: Final[int] = 60
_MAX_TURN_AGE_MINUTES: Final[int] = 15
_MAX_TURN_AGE_SECONDS: Final[int] = _MAX_TURN_AGE_MINUTES * _SECONDS_PER_MINUTE
_MAX_TURN_COUNT: Final[int] = 5_000


class _DiarizationEngine(TypingProtocol):
    """Protocol for diarization engine interface."""
@@ -74,6 +79,7 @@ class StreamingDiarizationMixin:
        # Populate diarization_turns for compatibility with maybe_assign_speaker
        state.diarization_turns.extend(new_turns)
        state.diarization_stream_time = session.stream_time
        _prune_diarization_turns(state, session.stream_time)

        # Persist turns immediately for crash resilience (DB only)
        await self.persist_streaming_turns(meeting_id, list(new_turns))
@@ -169,6 +175,19 @@ def _restore_session_state(session: DiarizationSession, state: MeetingStreamStat
    session.restore(state.diarization_turns, stream_time=state.diarization_stream_time)


def _prune_diarization_turns(state: MeetingStreamState, stream_time: float) -> None:
    """Prune old diarization turns to bound in-memory growth."""
    if not state.diarization_turns:
        return

    cutoff = stream_time - _MAX_TURN_AGE_SECONDS
    if cutoff > 0:
        state.diarization_turns = [turn for turn in state.diarization_turns if turn.end >= cutoff]

    if len(state.diarization_turns) > _MAX_TURN_COUNT:
        state.diarization_turns = state.diarization_turns[-_MAX_TURN_COUNT:]


def _convert_turns_to_streaming(turns: list[SpeakerTurn]) -> list[StreamingTurn]:
    """Convert domain SpeakerTurns to StreamingTurn for persistence."""
    return [

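The pruning above bounds memory two ways: turns older than 15 minutes of stream time are dropped, then the list is capped at the newest 5,000 entries. The same logic as a standalone sketch:

_MAX_TURN_AGE_SECONDS = 15 * 60
_MAX_TURN_COUNT = 5_000

def prune_turns(turns: list, stream_time: float) -> list:
    """Apply the same age and count bounds as _prune_diarization_turns."""
    cutoff = stream_time - _MAX_TURN_AGE_SECONDS
    if cutoff > 0:
        turns = [turn for turn in turns if turn.end >= cutoff]
    return turns[-_MAX_TURN_COUNT:]  # slicing a shorter list keeps it whole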
@@ -8,6 +8,7 @@ from noteflow.infrastructure.logging import get_logger

from ..proto import noteflow_pb2
from ._types import GrpcContext
from ._model_status import log_model_status
from .converters import entity_to_proto, parse_meeting_id_or_abort
from .errors import (
    ENTITY_ENTITY,
@@ -50,6 +51,7 @@ class EntitiesMixin:
        Delegates to NerService for extraction, caching, and persistence.
        Returns cached results if available, unless force_refresh is True.
        """
        log_model_status(self, "ner_extract_start", meeting_id=request.meeting_id)
        meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)
        ner_service = await require_ner_service(self.ner_service, context)

@@ -10,9 +10,11 @@ import numpy as np
from numpy.typing import NDArray

from noteflow.infrastructure.logging import get_logger
from noteflow.infrastructure.metrics.memory_logger import log_memory_snapshot

from ...proto import noteflow_pb2
from .._types import GrpcContext
from .._model_status import log_model_status
from ..errors import abort_failed_precondition, abort_invalid_argument
from ._asr import process_audio_segment
from ._cleanup import cleanup_stream_resources
@@ -107,6 +109,7 @@ class StreamingMixin:
        cleanup_meeting = stream_state.current or stream_state.initialized
        if cleanup_meeting:
            cleanup_stream_resources(self, cleanup_meeting)
            log_memory_snapshot("stream_cleanup_end", meeting_id=cleanup_meeting)

    async def process_stream_chunks(
        self: ServicerHost,
@@ -193,14 +196,26 @@ class StreamingMixin:
        if current_meeting_id is None:
            # Track meeting_id BEFORE init to guarantee cleanup on any exception
            initialized_meeting_id = meeting_id
            sample_rate = chunk.sample_rate
            channels = chunk.channels
            logger.info(
                "StreamTranscription initializing meeting %s (sample_rate=%s channels=%s)",
                meeting_id,
                chunk.sample_rate,
                chunk.channels,
                sample_rate,
                channels,
            )
            log_memory_snapshot(
                "stream_init_start",
                meeting_id=meeting_id,
                sample_rate=sample_rate,
                channels=channels,
            )
            log_model_status(self, "stream_init_start", meeting_id=meeting_id)
            init_result = await self.init_stream_for_meeting(meeting_id, context)
            return None if init_result is None else (meeting_id, initialized_meeting_id)
            if init_result is None:
                return None
            log_memory_snapshot("stream_init_end", meeting_id=meeting_id)
            return meeting_id, initialized_meeting_id
        if meeting_id != current_meeting_id:
            await abort_invalid_argument(context, "Stream may only contain a single meeting_id")
            raise AssertionError("unreachable")  # abort is NoReturn

@@ -12,7 +12,9 @@ from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.logging import get_logger

from ...proto import noteflow_pb2
from ..._startup import auto_enable_cloud_llm
from .._types import GrpcContext
from .._model_status import log_model_status
from ..converters import parse_meeting_id_or_abort, summary_to_proto
from ..errors import ENTITY_MEETING, abort_not_found
from ._summary_generation import generate_placeholder_summary, summarize_or_placeholder
@@ -47,14 +49,20 @@ class SummarizationGenerationMixin:
        context: GrpcContext,
    ) -> noteflow_pb2.Summary:
        """Generate meeting summary using SummarizationService with fallback."""
        meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)
        op_context = self.get_operation_context(context)
        style_instructions = build_style_prompt_from_request(request)

        meeting, existing, segments = await self._load_meeting_context(meeting_id, request, context)
        log_model_status(
            self,
            "summary_generate_start",
            meeting_id=request.meeting_id,
            include_provider_details=True,
        )
        meeting_id, op_context, style_instructions, meeting, existing, segments = (
            await self._prepare_summary_request(request, context)
        )
        if existing and not request.force_regenerate:
            return summary_to_proto(existing)

        await self._ensure_cloud_provider()

        async with cast(UnitOfWork, self.create_repository_provider()) as repo:
            style_prompt = await resolve_template_prompt(
                TemplateResolutionInputs(
@@ -85,6 +93,41 @@ class SummarizationGenerationMixin:
        await self._trigger_summary_webhook(meeting, saved)
        return summary_to_proto(saved)

    async def _prepare_summary_request(
        self,
        request: noteflow_pb2.GenerateSummaryRequest,
        context: GrpcContext,
    ) -> tuple[
        MeetingId,
        OperationContext,
        str | None,
        Meeting,
        Summary | None,
        list[Segment],
    ]:
        """Prepare summary inputs from the request and meeting data."""
        meeting_id = await parse_meeting_id_or_abort(request.meeting_id, context)
        op_context = self.get_operation_context(context)
        style_instructions = build_style_prompt_from_request(request)
        meeting, existing, segments = await self._load_meeting_context(meeting_id, request, context)
        return meeting_id, op_context, style_instructions, meeting, existing, segments

    async def _ensure_cloud_provider(self) -> None:
        """Register cloud provider if app config is available at runtime."""
        if self.summarization_service is None:
            return
        from noteflow.application.services.summarization import SummarizationMode

        if self.summarization_service.settings.default_mode == SummarizationMode.CLOUD:
            return

        async with cast(UnitOfWork, self.create_repository_provider()) as repo:
            if not repo.supports_preferences:
                return
            provider = await auto_enable_cloud_llm(repo, self.summarization_service)
            if provider is not None:
                logger.info("Cloud summarization enabled for request", provider=provider)

    async def _save_summary(
        self,
        meeting: Meeting,

@@ -7,10 +7,60 @@ from noteflow.domain.entities import Segment, Summary
from noteflow.domain.summarization import ProviderUnavailableError
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.logging import get_logger
from noteflow.infrastructure.metrics.memory_logger import log_memory_snapshot

logger = get_logger(__name__)


def _log_summary_snapshot(
    stage: str,
    meeting_id: MeetingId,
    segment_count: int,
    **fields: object,
) -> None:
    log_memory_snapshot(
        stage,
        meeting_id=str(meeting_id),
        segment_count=segment_count,
        **fields,
    )


def _log_summary_error(
    meeting_id: MeetingId,
    segment_count: int,
    exc: Exception,
) -> None:
    _log_summary_snapshot(
        "summary_error",
        meeting_id,
        segment_count,
        error_type=type(exc).__name__,
    )


async def _summarize_with_service(
    summarization_service: SummarizationService,
    meeting_id: MeetingId,
    segments: list[Segment],
    style_prompt: str | None,
) -> Summary:
    result = await summarization_service.summarize(
        meeting_id=meeting_id,
        segments=segments,
        style_prompt=style_prompt,
    )
    summary = result.summary
    provider_name = summary.provider_name
    model_name = summary.model_name
    logger.info(
        "Generated summary using %s/%s",
        provider_name,
        model_name,
    )
    return summary


async def summarize_or_placeholder(
    summarization_service: SummarizationService | None,
    meeting_id: MeetingId,
@@ -19,30 +69,39 @@ async def summarize_or_placeholder(
) -> Summary:
    """Try to summarize via service, fallback to placeholder on failure."""
    if summarization_service is None:
        _log_summary_snapshot(
            "summary_skip",
            meeting_id,
            len(segments),
            reason="service_unavailable",
        )
        logger.warning("SummarizationService not configured; using placeholder summary")
        return generate_placeholder_summary(meeting_id, segments)

    try:
        result = await summarization_service.summarize(
            meeting_id=meeting_id,
            segments=segments,
            style_prompt=style_prompt,
        _log_summary_snapshot("summary_start", meeting_id, len(segments))
        summary = await _summarize_with_service(
            summarization_service,
            meeting_id,
            segments,
            style_prompt,
        )
        summary = result.summary
        provider_name = summary.provider_name
        model_name = summary.model_name
        logger.info(
            "Generated summary using %s/%s",
            provider_name,
            model_name,
        _log_summary_snapshot(
            "summary_end",
            meeting_id,
            len(segments),
            provider=summary.provider_name,
            model=summary.model_name,
        )
        return summary
    except ProviderUnavailableError as exc:
        logger.warning("Summarization provider unavailable; using placeholder: %s", exc)
        _log_summary_error(meeting_id, len(segments), exc)
    except (TimeoutError, RuntimeError, ValueError) as exc:
        logger.exception(
            "Summarization failed (%s); using placeholder summary", type(exc).__name__
        )
        _log_summary_error(meeting_id, len(segments), exc)
    return generate_placeholder_summary(meeting_id, segments)

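After this refactor the contract is: callers always receive a Summary. A missing service, an unavailable provider, or a timeout all degrade to the placeholder, with a memory snapshot recording which path was taken. Calling sketch (argument construction is illustrative; keyword names are inferred from the body above):

# Sketch: no service configured -> summary_skip snapshot, placeholder returned.
summary = await summarize_or_placeholder(
    summarization_service=None,
    meeting_id=meeting_id,
    segments=segments,
    style_prompt=None,
)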
@@ -27,6 +27,7 @@ from noteflow.infrastructure.persistence.database import (
    create_engine_and_session_factory,
    ensure_schema_ready,
)
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
from noteflow.infrastructure.summarization import CloudBackend, CloudSummarizer

@@ -62,6 +63,8 @@ class _SummaryConfig(TypedDict, total=False):
    api_key: str
    test_status: str
    model: str
    selected_model: str
    base_url: str


class AsrConfigLike(Protocol):
@@ -114,7 +117,7 @@ logger = get_logger(__name__)


async def auto_enable_cloud_llm(
    uow: SqlAlchemyUnitOfWork,
    uow: UnitOfWork,
    summarization_service: SummarizationService,
) -> str | None:
    """Check for app-configured LLM and register cloud provider if valid.
@@ -126,33 +129,36 @@ async def auto_enable_cloud_llm(
    Returns:
        Provider name if cloud provider was auto-enabled, None otherwise.
    """
    if not uow.supports_preferences:
        return None
    ai_config_value = await uow.preferences.get("ai_config")
    if not isinstance(ai_config_value, dict):
        return None

    ai_config = cast(dict[str, object], ai_config_value)
    summary_config_value = ai_config.get("summary", {})
    if not isinstance(summary_config_value, dict):
        return None

    summary_config = cast(_SummaryConfig, summary_config_value)
    provider = summary_config.get(PROVIDER, "")
    api_key = summary_config.get("api_key", "")
    test_status = summary_config.get("test_status", "")
    model = summary_config.get("model") or None  # Convert empty string to None

    # Only register if configured and tested successfully
    if provider not in (PROVIDER_NAME_OPENAI, "anthropic") or not api_key or test_status != "success":
    model = summary_config.get("model") or summary_config.get("selected_model") or None
    base_url = summary_config.get("base_url") or None
    if (
        provider not in (PROVIDER_NAME_OPENAI, "anthropic")
        or not api_key
        or test_status != "success"
    ):
        return None

    backend = CloudBackend.OPENAI if provider == PROVIDER_NAME_OPENAI else CloudBackend.ANTHROPIC
    cloud_summarizer = CloudSummarizer(
        backend=backend,
        api_key=api_key,
        model=model,
        base_url=base_url if backend == CloudBackend.OPENAI else None,
    )
    summarization_service.register_provider(SummarizationMode.CLOUD, cloud_summarizer)
    # Auto-grant consent since user explicitly configured in app
    summarization_service.set_default_mode(SummarizationMode.CLOUD)
    summarization_service.settings.cloud_consent_granted = True
    logger.info("Auto-enabled CLOUD summarization from app config: provider=%s", provider)
    return provider
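For reference, the ai_config preference shape that now triggers auto-enablement; model falls back to selected_model when empty, and base_url is honored for the OpenAI backend only (the values below are illustrative, the keys come from _SummaryConfig above):

ai_config = {
    "summary": {
        "provider": "openai",            # or "anthropic"
        "api_key": "sk-...",             # required
        "test_status": "success",        # must be "success" to register
        "model": "",                     # empty -> falls back to selected_model
        "selected_model": "gpt-4o-mini", # illustrative model name
        "base_url": "https://llm.example.internal/v1",  # OpenAI backend only
    }
}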
src/noteflow/infrastructure/diarization/audio_buffer.py (new file, 107 lines)
@@ -0,0 +1,107 @@
"""Audio buffer for streaming diarization."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Final

import numpy as np
from numpy.typing import NDArray

from noteflow.infrastructure.logging import get_logger

logger = get_logger(__name__)

_MAX_BUFFER_CHUNKS: Final[int] = 4


@dataclass
class DiarizationAudioBuffer:
    """Bounded audio buffer for diarization streaming chunks."""

    sample_rate: int
    chunk_duration: float
    max_buffer_chunks: int = _MAX_BUFFER_CHUNKS
    _buffer: list[NDArray[np.float32]] = field(default_factory=list, init=False)
    _buffer_samples: int = field(default=0, init=False)

    def append(self, audio: NDArray[np.float32], *, meeting_id: str) -> None:
        """Append audio to the buffer, enforcing size limits."""
        if audio.size == 0:
            return
        self._buffer.append(audio)
        self._buffer_samples += len(audio)
        self._enforce_limit(meeting_id=meeting_id)

    def extract_full_chunk(self, required_samples: int) -> NDArray[np.float32] | None:
        """Extract the next full chunk if enough samples are buffered."""
        if self._buffer_samples < required_samples:
            return None

        pieces: list[NDArray[np.float32]] = []
        remaining_samples = required_samples
        new_buffer: list[NDArray[np.float32]] = []

        for idx, segment in enumerate(self._buffer):
            if remaining_samples <= 0:
                new_buffer.extend(self._buffer[idx:])
                break
            segment_len = len(segment)
            if segment_len <= remaining_samples:
                pieces.append(segment)
                remaining_samples -= segment_len
                continue

            pieces.append(segment[:remaining_samples])
            remainder = segment[remaining_samples:]
            if remainder.size:
                new_buffer.append(remainder.copy())
            new_buffer.extend(self._buffer[idx + 1 :])
            remaining_samples = 0
            break

        if remaining_samples > 0 or not pieces:
            return None

        chunk_audio = np.concatenate(pieces) if len(pieces) > 1 else pieces[0].copy()
        self._buffer = new_buffer
        self._buffer_samples = sum(len(segment) for segment in new_buffer)
        return chunk_audio

    def clear(self) -> None:
        """Clear buffered audio."""
        self._buffer.clear()
        self._buffer_samples = 0

    def _enforce_limit(self, *, meeting_id: str) -> None:
        max_samples = int(self.sample_rate * self.chunk_duration * self.max_buffer_chunks)
        if self._buffer_samples <= max_samples:
            return

        drop_samples = self._buffer_samples - max_samples
        if drop_samples <= 0:
            return

        new_buffer: list[NDArray[np.float32]] = []
        remaining_drop = drop_samples
        for segment in self._buffer:
            if remaining_drop <= 0:
                new_buffer.append(segment)
                continue
            segment_len = len(segment)
            if segment_len <= remaining_drop:
                remaining_drop -= segment_len
                continue
            new_buffer.append(segment[remaining_drop:].copy())
            remaining_drop = 0

        dropped = drop_samples - remaining_drop
        if dropped > 0:
            logger.warning(
                "diarization_buffer_overrun",
                meeting_id=meeting_id,
                dropped_samples=dropped,
            )

        self._buffer = new_buffer
        self._buffer_samples = sum(len(segment) for segment in new_buffer)
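Usage sketch for the new buffer: it holds at most max_buffer_chunks (default 4) chunks of audio, dropping the oldest samples with a diarization_buffer_overrun warning when streaming outpaces extraction:

import numpy as np

from noteflow.infrastructure.diarization.audio_buffer import DiarizationAudioBuffer

buf = DiarizationAudioBuffer(sample_rate=16_000, chunk_duration=5.0)
required = int(5.0 * 16_000)  # 80_000 samples per diarization chunk

buf.append(np.zeros(40_000, dtype=np.float32), meeting_id="m-123")
print(buf.extract_full_chunk(required))  # None: only 2.5 s buffered

buf.append(np.zeros(48_000, dtype=np.float32), meeting_id="m-123")
chunk = buf.extract_full_chunk(required)  # 80_000 samples; 8_000 stay buffered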
@@ -8,17 +8,19 @@ from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Protocol
from typing import TYPE_CHECKING, Final, Protocol, cast

import numpy as np
from numpy.typing import NDArray

from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.infrastructure.diarization.audio_buffer import DiarizationAudioBuffer
from noteflow.infrastructure.diarization.dto import SpeakerTurn
from noteflow.infrastructure.logging import get_logger

if TYPE_CHECKING:
    from diart import SpeakerDiarization
    from pyannote.core import SlidingWindowFeature


class _TrackSegment(Protocol):
@@ -38,6 +40,10 @@ logger = get_logger(__name__)

# Default chunk duration in seconds (matches pyannote segmentation model)
DEFAULT_CHUNK_DURATION: float = 5.0
_SECONDS_PER_MINUTE: Final[int] = 60
_MAX_TURN_AGE_MINUTES: Final[int] = 15
_MAX_TURN_AGE_SECONDS: Final[int] = _MAX_TURN_AGE_MINUTES * _SECONDS_PER_MINUTE
_MAX_TURN_COUNT: Final[int] = 5_000


def _collect_turns(
@@ -93,8 +99,13 @@ class DiarizationSession:
    _stream_time: float = field(default=0.0, init=False)
    _turns: list[SpeakerTurn] = field(default_factory=list, init=False)
    _closed: bool = field(default=False, init=False)
    _audio_buffer: list[NDArray[np.float32]] = field(default_factory=list, init=False)
    _buffer_samples: int = field(default=0, init=False)
    _audio_buffer: DiarizationAudioBuffer = field(init=False)

    def __post_init__(self) -> None:
        self._audio_buffer = DiarizationAudioBuffer(
            sample_rate=self._sample_rate,
            chunk_duration=self._chunk_duration,
        )

    def process_chunk(
        self,
@@ -141,8 +152,7 @@ class DiarizationSession:
        """Add audio to buffer, ensuring 1D format."""
        if audio.ndim > 1:
            audio = audio.flatten()
        self._audio_buffer.append(audio)
        self._buffer_samples += len(audio)
        self._audio_buffer.append(audio, meeting_id=self.meeting_id)

    def _extract_full_chunk_if_ready(
        self,
@@ -150,18 +160,7 @@ class DiarizationSession:
    ) -> NDArray[np.float32] | None:
        """Extract a full chunk from buffer if enough samples available."""
        required_samples = int(self._chunk_duration * sample_rate)

        if self._buffer_samples < required_samples:
            return None

        full_audio = np.concatenate(self._audio_buffer)
        chunk_audio = full_audio[:required_samples]

        remaining = full_audio[required_samples:]
        self._audio_buffer = [remaining] if len(remaining) > 0 else []
        self._buffer_samples = max(len(remaining), 0)

        return chunk_audio
        return self._audio_buffer.extract_full_chunk(required_samples)

    def _process_full_chunk(
        self,
@@ -180,39 +179,49 @@ class DiarizationSession:
        if self._pipeline is None:
            return []

        duration = len(audio) / sample_rate
        current_time = self._stream_time + duration
        waveform = self._build_waveform(audio, duration)
        new_turns = self._run_pipeline(waveform)
        if new_turns:
            self._turns.extend(new_turns)
            self._prune_turns(current_time)
        self._stream_time = current_time
        return new_turns

    def _build_waveform(
        self,
        audio: NDArray[np.float32],
        duration: float,
    ) -> SlidingWindowFeature:
        """Build a SlidingWindowFeature for diarization."""
        from pyannote.core import SlidingWindow, SlidingWindowFeature

        duration = len(audio) / sample_rate

        # Reshape to (samples, channels) for diart - mono audio has 1 channel
        audio_2d = audio.reshape(-1, 1)

        # Create SlidingWindowFeature for diart
        window = SlidingWindow(start=0.0, duration=duration, step=duration)
        waveform = SlidingWindowFeature(audio_2d, window)
        return SlidingWindowFeature(audio_2d, window)

    def _run_pipeline(self, waveform: SlidingWindowFeature) -> list[SpeakerTurn]:
        """Run diarization pipeline and return speaker turns."""
        try:
            # Process through pipeline
            # Note: Frame rate mismatch between segmentation-3.0 and embedding models
            # may cause warnings and occasional errors, which we handle gracefully
            results = self._pipeline([waveform])

            # Convert results to turns with absolute time offsets
            new_turns = _collect_turns(results, self._stream_time)
            self._turns.extend(new_turns)

        except (RuntimeError, ZeroDivisionError, ValueError) as e:
            # Handle frame/weights mismatch and related errors gracefully
            # Streaming diarization continues even if individual chunks fail
            results: Sequence[tuple[_Annotation, object]]
            if self._pipeline is None:
                results = ()
            else:
                # pyright cannot infer diart pipeline return types; cast to expected shape.
                results = cast(
                    Sequence[tuple[_Annotation, object]],
                    self._pipeline([waveform]),
                )
        except (RuntimeError, ZeroDivisionError, ValueError) as exc:
            logger.warning(
                "Diarization chunk processing failed (non-fatal): %s",
                str(e),
                str(exc),
                exc_info=False,
            )
            new_turns = []
            return []

        self._stream_time += duration
        return new_turns
        return _collect_turns(results, self._stream_time)

    def reset(self) -> None:
        """Reset session state for restarting diarization.

@@ -227,7 +236,6 @@ class DiarizationSession:
        self._stream_time = 0.0
        self._turns.clear()
        self._audio_buffer.clear()
        self._buffer_samples = 0
        logger.debug("Session %s reset", self.meeting_id)

    def restore(
@@ -252,6 +260,7 @@ class DiarizationSession:
        if stream_time is None:
            stream_time = max((t.end for t in turns), default=0.0)
        self._stream_time = max(self._stream_time, stream_time)
        self._prune_turns(self._stream_time)
        logger.debug(
            "Session %s restored stream_time=%.3f with %d turns",
            self.meeting_id,
@@ -272,11 +281,20 @@ class DiarizationSession:
        self._closed = True
        self._turns.clear()
        self._audio_buffer.clear()
        self._buffer_samples = 0
        # Explicitly release pipeline reference to allow GC and GPU memory release
        self._pipeline = None
        logger.info("diarization_session_closed", meeting_id=self.meeting_id)

    def _prune_turns(self, current_time: float) -> None:
        """Prune old speaker turns to keep memory bounded."""
        if not self._turns:
            return
        cutoff = current_time - _MAX_TURN_AGE_SECONDS
        if cutoff > 0:
            self._turns = [turn for turn in self._turns if turn.end >= cutoff]
        if len(self._turns) > _MAX_TURN_COUNT:
            self._turns = self._turns[-_MAX_TURN_COUNT:]

    @property
    def stream_time(self) -> float:
        """Current stream time position in seconds."""

src/noteflow/infrastructure/metrics/memory_logger.py (new file, 25 lines)
@@ -0,0 +1,25 @@
"""Memory snapshot logging utilities."""

from __future__ import annotations

from typing import Final

from noteflow.infrastructure.logging import get_logger
from noteflow.infrastructure.metrics.collector import get_metrics_collector

logger = get_logger(__name__)

_EVENT_NAME: Final[str] = "memory_snapshot"


def log_memory_snapshot(stage: str, **fields: object) -> None:
    """Log a structured memory snapshot for diagnostics."""
    metrics = get_metrics_collector().collect_now()
    logger.info(
        _EVENT_NAME,
        stage=stage,
        process_memory_mb=metrics.process_memory_mb,
        memory_percent=metrics.memory_percent,
        memory_mb=metrics.memory_mb,
        **fields,
    )
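Usage sketch: every call merges current process metrics with caller-supplied fields into a single structured "memory_snapshot" event, so stages can be correlated in logs:

from noteflow.infrastructure.metrics.memory_logger import log_memory_snapshot

log_memory_snapshot(
    "stream_init_start",
    meeting_id="m-123",  # illustrative value
    sample_rate=16_000,
    channels=1,
)
# emits: memory_snapshot with stage="stream_init_start", the three memory
# metrics from the collector, plus the caller fields above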
src/noteflow/infrastructure/summarization/_action_items.py (new file, 51 lines)
@@ -0,0 +1,51 @@
"""Action item extraction helpers for summarization."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Final

from noteflow.domain.entities import ActionItem, Segment

_ACTION_KEYWORDS: Final[tuple[str, ...]] = (
    "todo",
    "action",
    "will",
    "should",
    "must",
    "need to",
    "let's",
    "lets",
    "follow up",
    "next step",
    "next steps",
    "schedule",
    "send",
    "share",
    "review",
    "update",
    "prepare",
    "assign",
)
_MAX_ACTION_TEXT_LEN: Final[int] = 80


def extract_action_items_from_segments(
    segments: Sequence[Segment],
    max_action_items: int,
) -> list[ActionItem]:
    """Extract action items by scanning segments for action keywords."""
    action_items: list[ActionItem] = []
    for segment in segments:
        if len(action_items) >= max_action_items:
            break
        text_lower = segment.text.lower()
        if any(keyword in text_lower for keyword in _ACTION_KEYWORDS):
            action_items.append(
                ActionItem(
                    text=f"Action: {segment.text[:_MAX_ACTION_TEXT_LEN]}",
                    assignee="",
                    segment_ids=[segment.segment_id],
                )
            )
    return action_items
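The extraction is a plain keyword scan: the first segments containing any keyword become "Action: ..." items, truncated to 80 characters and capped at max_action_items. A usage sketch (Segment construction is illustrative; the real entity may require more fields):

from noteflow.domain.entities import Segment
from noteflow.infrastructure.summarization._action_items import (
    extract_action_items_from_segments,
)

segments = [
    Segment(segment_id=1, text="We discussed the roadmap."),           # no keyword
    Segment(segment_id=2, text="Alice will send the notes tomorrow."), # "will", "send"
]
items = extract_action_items_from_segments(segments, max_action_items=5)
# -> [ActionItem(text="Action: Alice will send the notes tomorrow.",
#                assignee="", segment_ids=[2])]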
@@ -10,6 +10,9 @@ from noteflow.domain.constants.fields import ACTION_ITEMS, KEY_POINTS, SEGMENT_I
from noteflow.domain.entities import ActionItem, KeyPoint, Summary
from noteflow.domain.summarization import InvalidResponseError
from noteflow.infrastructure.logging import get_logger
from noteflow.infrastructure.summarization._action_items import (
    extract_action_items_from_segments,
)

if TYPE_CHECKING:
    from collections.abc import Sequence
@@ -270,10 +273,16 @@ def parse_llm_response(response_text: str, request: SummarizationRequest) -> Sum
        _parse_key_point(kp_data, valid_ids, request.segments)
        for kp_data in data.get(KEY_POINTS, [])[: request.max_key_points]
    ]
    action_items = [
    parsed_action_items = [
        _parse_action_item(ai_data, valid_ids)
        for ai_data in data.get(ACTION_ITEMS, [])[: request.max_action_items]
    ]
    action_items = [item for item in parsed_action_items if item.segment_ids]
    if not action_items and request.max_action_items > 0:
        action_items = extract_action_items_from_segments(
            request.segments,
            request.max_action_items,
        )

    return Summary(
        meeting_id=request.meeting_id,

@@ -21,33 +21,34 @@ class CloudSummarizerClients:

    def _get_openai_client(self) -> openai.OpenAI:
        """Get or create OpenAI client."""
        try:
            import openai as openai_module
        except ImportError as e:
            raise ProviderUnavailableError(
                "openai package not installed. Install with: pip install openai"
            ) from e
        if self._client is None:
            try:
                import openai

                self._client = openai.OpenAI(
                    api_key=self._api_key,
                    timeout=self._timeout,
                    base_url=self._openai_base_url,
                )
            except ImportError as e:
                raise ProviderUnavailableError(
                    "openai package not installed. Install with: pip install openai"
                ) from e
        return cast(openai.OpenAI, self._client)
            self._client = openai_module.OpenAI(
                api_key=self._api_key,
                timeout=self._timeout,
                base_url=self._openai_base_url,
            )
        return cast(openai_module.OpenAI, self._client)

    def _get_anthropic_client(self) -> anthropic.Anthropic:
        """Get or create Anthropic client."""
        try:
            import anthropic as anthropic_module
        except ImportError as e:
            raise ProviderUnavailableError(
                "anthropic package not installed. Install with: pip install anthropic"
            ) from e
        if self._client is None:
            try:
                import anthropic

                self._client = anthropic.Anthropic(api_key=self._api_key, timeout=self._timeout)
            except ImportError as e:
                raise ProviderUnavailableError(
                    "anthropic package not installed. Install with: pip install anthropic"
                ) from e
        return cast(anthropic.Anthropic, self._client)
            self._client = anthropic_module.Anthropic(
                api_key=self._api_key,
                timeout=self._timeout,
            )
        return cast(anthropic_module.Anthropic, self._client)

    def get_openai_client(self) -> openai.OpenAI:
        """Expose OpenAI client for integrations and testing."""

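The refactor hoists the optional-dependency import to the top of each accessor, so a missing package raises ProviderUnavailableError on every call rather than only before the first successful construction; the built client is still cached on self._client. The pattern in isolation (a sketch of the shape, not a drop-in method):

def _get_client(self):
    """Import eagerly, construct lazily, cache on the instance."""
    try:
        import openai as openai_module  # fails consistently when not installed
    except ImportError as e:
        raise ProviderUnavailableError(
            "openai package not installed. Install with: pip install openai"
        ) from e
    if self._client is None:
        self._client = openai_module.OpenAI(
            api_key=self._api_key,
            timeout=self._timeout,
            base_url=self._openai_base_url,
        )
    return self._client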
@@ -2,19 +2,17 @@

import time
from datetime import UTC, datetime
from typing import Final

from noteflow.domain.entities import ActionItem, KeyPoint, Summary
from noteflow.domain.summarization import (
    SummarizationRequest,
    SummarizationResult,
)
from noteflow.infrastructure.summarization._action_items import (
    extract_action_items_from_segments,
)
from noteflow.infrastructure.summarization._availability import AvailabilityProviderBase

_ACTION_KEYWORDS: Final[frozenset[str]] = frozenset({
    "todo", "action", "will", "should", "must", "need to"
})


def _truncate_text(text: str, max_length: int) -> str:
    """Truncate text with ellipsis if needed."""
@@ -45,28 +43,11 @@ def _build_key_points(request: SummarizationRequest) -> list[KeyPoint]:


def _build_action_items(request: SummarizationRequest) -> list[ActionItem]:
    """Build action items from segments containing action keywords.

    Args:
        request: Summarization request with segments.

    Returns:
        List of action items (up to max_action_items).
    """
    action_items: list[ActionItem] = []
    for segment in request.segments:
        if len(action_items) >= request.max_action_items:
            break
        text_lower = segment.text.lower()
        if any(kw in text_lower for kw in _ACTION_KEYWORDS):
            action_items.append(
                ActionItem(
                    text=f"Action: {segment.text[:80]}",
                    assignee="",
                    segment_ids=[segment.segment_id],
                )
            )
    return action_items
    """Build action items from segments containing action keywords."""
    return extract_action_items_from_segments(
        request.segments,
        request.max_action_items,
    )


class MockSummarizer(AvailabilityProviderBase):