"""Pytest fixtures for integration tests."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncGenerator
|
|
from datetime import UTC, datetime
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Final
|
|
from uuid import UUID, uuid4
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
|
|
from noteflow.domain.entities import Meeting, Segment
|
|
from noteflow.domain.entities.task import Task, TaskStatus
|
|
from noteflow.domain.value_objects import MeetingId
|
|
from noteflow.infrastructure.persistence.repositories.meeting_repo import (
|
|
SqlAlchemyMeetingRepository,
|
|
)
|
|
from noteflow.infrastructure.persistence.repositories.segment_repo import (
|
|
SqlAlchemySegmentRepository,
|
|
)
|
|
from noteflow.infrastructure.persistence.repositories.task_repo import SqlAlchemyTaskRepository
|
|
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
|
|
from support.db_utils import (
|
|
cleanup_test_schema,
|
|
create_test_engine,
|
|
create_test_session_factory,
|
|
get_or_create_container,
|
|
initialize_test_schema,
|
|
stop_container,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from numpy.typing import NDArray
|
|
|
|
# ============================================================================
|
|
# Repository Fixtures (Shared across integration tests)
|
|
# ============================================================================
|
|
|
|
|
|
@pytest.fixture
|
|
def meeting_repo(session: AsyncSession) -> SqlAlchemyMeetingRepository:
|
|
"""Create meeting repository instance."""
|
|
return SqlAlchemyMeetingRepository(session)
|
|
|
|
|
|
@pytest.fixture
|
|
def segment_repo(session: AsyncSession) -> SqlAlchemySegmentRepository:
|
|
"""Create segment repository instance."""
|
|
return SqlAlchemySegmentRepository(session)
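
# Illustrative usage of the repository fixtures in a test module (a sketch, not
# part of this conftest; it assumes the meeting repository exposes an async
# create(), matching how the unit-of-work fixtures below call
# uow.meetings.create()):
#
#     async def test_meeting_round_trip(meeting_repo, session):
#         meeting = Meeting.create(title="Example")
#         await meeting_repo.create(meeting)
#         await session.commit()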


# ============================================================================
# Embedding Constants
# ============================================================================

EMBEDDING_DIM: Final[int] = 1536
"""Standard embedding dimension for pgvector tests."""


# ============================================================================
# Embedding Helper Functions (Module-Level)
# ============================================================================


def create_unit_embedding(dimension_index: int) -> list[float]:
    """Create a unit embedding with 1.0 at the specified dimension index.

    Used to create orthogonal embeddings for precise similarity testing.
    Cosine similarity between two unit embeddings at different dimensions is 0.

    Args:
        dimension_index: Index (0-based) where 1.0 should be placed.

    Returns:
        Embedding vector of length EMBEDDING_DIM with 1.0 at dimension_index.
    """
    embedding = [0.0] * EMBEDDING_DIM
    embedding[dimension_index] = 1.0
    return embedding


def create_weighted_embedding(weights: list[tuple[int, float]]) -> list[float]:
    """Create a normalized embedding with specified weights at given dimensions.

    Useful for creating embeddings with known similarity to unit embeddings.

    Args:
        weights: List of (dimension_index, weight) tuples.

    Returns:
        Normalized embedding vector.
    """
    embedding = np.zeros(EMBEDDING_DIM, dtype=np.float64)
    for dim_index, weight in weights:
        embedding[dim_index] = weight
    # Normalize to unit length for proper cosine similarity
    norm = np.linalg.norm(embedding)
    if norm > 0:
        embedding = embedding / norm
    return embedding.tolist()


def compute_cosine_similarity(embedding_a: list[float], embedding_b: list[float]) -> float:
    """Compute cosine similarity between two embeddings.

    Used to verify database similarity scores match expected values.

    Args:
        embedding_a: First embedding vector.
        embedding_b: Second embedding vector.

    Returns:
        Cosine similarity score in range [-1, 1].
    """
    vec_a = np.array(embedding_a, dtype=np.float64)
    vec_b = np.array(embedding_b, dtype=np.float64)
    dot_product = np.dot(vec_a, vec_b)
    norm_a = np.linalg.norm(vec_a)
    norm_b = np.linalg.norm(vec_b)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return float(dot_product / (norm_a * norm_b))
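
# Worked example (illustrative): unit embeddings at different indices are
# orthogonal, so their cosine similarity is exactly 0.0, while an embedding
# that weights dimensions 0 and 1 equally sits at 45 degrees to either axis
# (similarity 1/sqrt(2) ~= 0.7071):
#
#     a = create_unit_embedding(0)
#     b = create_unit_embedding(1)
#     assert compute_cosine_similarity(a, b) == 0.0
#
#     mixed = create_weighted_embedding([(0, 1.0), (1, 1.0)])
#     assert abs(compute_cosine_similarity(mixed, a) - 0.7071) < 1e-3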


# ============================================================================
# Audio Fixture Constants
# ============================================================================

SAMPLE_RATE: Final[int] = 16000
MAX_AUDIO_SECONDS: Final[float] = 10.0
MAX_AUDIO_SAMPLES: Final[int] = int(MAX_AUDIO_SECONDS * SAMPLE_RATE)
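# At 16 kHz mono, 10.0 seconds is 160_000 frames; the audio_samples fixture at
# the bottom of this module uses MAX_AUDIO_SAMPLES as its read cap.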


# ============================================================================
# Database Session Fixtures
# ============================================================================


@pytest.fixture
async def session_factory() -> AsyncGenerator[async_sessionmaker[AsyncSession], None]:
    """Create a session factory and initialize the database schema."""
    _, database_url = get_or_create_container()

    engine = create_test_engine(database_url)

    async with engine.begin() as conn:
        await initialize_test_schema(conn)

    yield create_test_session_factory(engine)

    # Cleanup - drop schema to reset for next test
    async with engine.begin() as conn:
        await cleanup_test_schema(conn)

    await engine.dispose()


@pytest.fixture
async def session(
    session_factory: async_sessionmaker[AsyncSession],
) -> AsyncGenerator[AsyncSession, None]:
    """Provide a database session for each test."""
    async with session_factory() as session:
        yield session
        # Rollback any uncommitted changes
        await session.rollback()


def pytest_sessionfinish(session: pytest.Session, exitstatus: int) -> None:
    """Clean up the container after all tests complete."""
    stop_container()


# ============================================================================
# Meeting Factory Fixtures
# ============================================================================


@pytest.fixture
async def persisted_meeting(
    session_factory: async_sessionmaker[AsyncSession], meetings_dir: Path
) -> MeetingId:
    """Create and persist a simple meeting for tests."""
    async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
        meeting = Meeting.create(title="Test Meeting")
        await uow.meetings.create(meeting)
        await uow.commit()
        return meeting.id


@pytest.fixture
async def persisted_meeting_with_segment(
    session_factory: async_sessionmaker[AsyncSession], meetings_dir: Path
) -> MeetingId:
    """Create and persist a meeting with one segment for tests."""
    async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
        meeting = Meeting.create(title="Test Meeting with Segment")
        await uow.meetings.create(meeting)
        segment = Segment(segment_id=0, text="Test segment content.", start_time=0.0, end_time=5.0)
        await uow.segments.add(meeting.id, segment)
        await uow.commit()
        return meeting.id


@pytest.fixture
async def stopped_meeting_with_segments(
    session_factory: async_sessionmaker[AsyncSession], meetings_dir: Path
) -> MeetingId:
    """Create a stopped meeting with two speaker segments for tests."""
    async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
        meeting = Meeting.create(title="Test Meeting with Speakers")
        meeting.start_recording()
        meeting.begin_stopping()
        meeting.stop_recording()
        await uow.meetings.create(meeting)
        segment_0 = Segment(
            segment_id=0, text="First speaker.", start_time=0.0, end_time=3.0, speaker_id="Alice"
        )
        segment_1 = Segment(
            segment_id=1, text="Second speaker.", start_time=3.0, end_time=6.0, speaker_id="Bob"
        )
        await uow.segments.add(meeting.id, segment_0)
        await uow.segments.add(meeting.id, segment_1)
        await uow.commit()
        return meeting.id


# ============================================================================
# Task Fixtures (for Task repository tests)
# ============================================================================

TASK_WORKSPACE_ID: Final[UUID] = UUID("00000000-0000-0000-0000-000000000099")
"""Fixed workspace ID for task integration tests."""


@pytest.fixture
async def task_workspace(session: AsyncSession) -> UUID:
    """Create and persist a workspace for task tests."""
    from noteflow.infrastructure.persistence.models.identity.identity import WorkspaceModel

    workspace = WorkspaceModel()
    workspace.id = TASK_WORKSPACE_ID
    workspace.name = "Task Test Workspace"
    workspace.slug = "task-test-workspace"
    workspace.created_at = datetime.now(UTC)
    workspace.updated_at = datetime.now(UTC)
    session.add(workspace)
    await session.flush()
    return TASK_WORKSPACE_ID


@pytest.fixture
async def task_repo(session: AsyncSession) -> SqlAlchemyTaskRepository:
    """Create a task repository instance."""
    return SqlAlchemyTaskRepository(session)


@pytest.fixture
def sample_task(task_workspace: UUID) -> Task:
    """Create a sample task entity for testing."""
    return Task(
        id=uuid4(),
        workspace_id=task_workspace,
        meeting_id=None,
        action_item_id=None,
        text="Sample task for testing",
        status=TaskStatus.OPEN,
        priority=1,
        created_at=datetime.now(UTC),
        updated_at=datetime.now(UTC),
    )


@pytest.fixture
async def persisted_task(
    session: AsyncSession,
    task_repo: SqlAlchemyTaskRepository,
    sample_task: Task,
) -> Task:
    """Create and persist a task for tests."""
    await task_repo.create(sample_task)
    await session.commit()
    return sample_task


@pytest.fixture
async def tasks_with_statuses(
    session: AsyncSession,
    task_repo: SqlAlchemyTaskRepository,
    task_workspace: UUID,
) -> tuple[Task, Task, Task]:
    """Create tasks with each status type for filtering tests."""
    open_task = Task(
        id=uuid4(),
        workspace_id=task_workspace,
        meeting_id=None,
        action_item_id=None,
        text="Open task",
        status=TaskStatus.OPEN,
        priority=1,
        created_at=datetime.now(UTC),
        updated_at=datetime.now(UTC),
    )
    done_task = Task(
        id=uuid4(),
        workspace_id=task_workspace,
        meeting_id=None,
        action_item_id=None,
        text="Done task",
        status=TaskStatus.DONE,
        priority=1,
        created_at=datetime.now(UTC),
        updated_at=datetime.now(UTC),
    )
    dismissed_task = Task(
        id=uuid4(),
        workspace_id=task_workspace,
        meeting_id=None,
        action_item_id=None,
        text="Dismissed task",
        status=TaskStatus.DISMISSED,
        priority=1,
        created_at=datetime.now(UTC),
        updated_at=datetime.now(UTC),
    )
    await task_repo.create(open_task)
    await task_repo.create(done_task)
    await task_repo.create(dismissed_task)
    await session.commit()
    return open_task, done_task, dismissed_task


# ============================================================================
# Audio Fixtures (for ASR integration tests)
# ============================================================================


@pytest.fixture
def audio_fixture_path() -> Path:
    """Path to the test audio fixture.

    Returns path to tests/fixtures/sample_discord.wav (16kHz mono PCM).
    Skips test if fixture file is not found.
    """
    path = Path(__file__).parent.parent / "fixtures" / "sample_discord.wav"
    if not path.exists():
        pytest.skip(f"Test audio fixture not found: {path}")
    return path


@pytest.fixture
def audio_samples(audio_fixture_path: Path) -> NDArray[np.float32]:
    """Load audio samples from fixture file.

    Returns first 10 seconds as float32 array normalized to [-1, 1].
    """
    import wave

    with wave.open(str(audio_fixture_path), "rb") as wav:
        assert wav.getsampwidth() == 2, "Expected 16-bit audio"
        assert wav.getnchannels() == 1, "Expected mono audio"
        assert wav.getframerate() == SAMPLE_RATE, f"Expected {SAMPLE_RATE}Hz"

        # Read limited samples for faster testing
        n_frames = min(wav.getnframes(), MAX_AUDIO_SAMPLES)
        raw_data = wav.readframes(n_frames)

        # Convert to float32 normalized
        samples = np.frombuffer(raw_data, dtype=np.int16).astype(np.float32)
        samples /= 32768.0  # Normalize int16 to [-1, 1]
        return samples