"""Integration tests for WhisperPyTorchEngine.
|
|
|
|
These tests verify that the PyTorch-based Whisper engine can actually
|
|
load models and transcribe audio. Unlike mock-based unit tests, these
|
|
tests exercise the real transcription pipeline.
|
|
|
|
Requirements:
|
|
- openai-whisper package installed
|
|
- CPU-only (no GPU required)
|
|
- Internet connection for first model download
|
|
|
|
Test audio fixture:
|
|
Uses tests/fixtures/sample_discord.wav (16kHz mono PCM)
|
|
Fixtures defined in conftest.py: audio_fixture_path, audio_samples
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Final
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from .conftest import MAX_AUDIO_SECONDS, SAMPLE_RATE
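
# The conftest fixtures used below are assumed to look roughly like this sketch
# (illustrative only; the real definitions live in this package's conftest.py):
#
#     SAMPLE_RATE: Final[int] = 16000        # matches the 16 kHz fixture WAV
#     MAX_AUDIO_SECONDS: Final[float] = ...  # upper bound used in timing asserts
#
#     @pytest.fixture
#     def audio_fixture_path() -> Path:
#         return Path(__file__).parent / "fixtures" / "sample_discord.wav"
#
#     @pytest.fixture
#     def audio_samples(audio_fixture_path: Path) -> NDArray[np.float32]:
#         ...  # decode the WAV to mono float32 samples at SAMPLE_RATE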

if TYPE_CHECKING:
    from numpy.typing import NDArray

# ============================================================================
# Test Constants
# ============================================================================

MODEL_SIZE_TINY: Final[str] = "tiny"
DEVICE_CPU: Final[str] = "cpu"
COMPUTE_TYPE_FLOAT32: Final[str] = "float32"


def _check_whisper_available() -> bool:
    """Check if openai-whisper is available.

    Note: There's a package conflict with Graphite's 'whisper' database package.
    We check for the 'load_model' attribute to verify it's the correct whisper.
    """
    try:
        import whisper

        # Verify it's OpenAI's whisper, not Graphite's whisper database
        return hasattr(whisper, "load_model")
    except ImportError:
        return False


# Provide an informative skip message
_WHISPER_SKIP_REASON = (
    "openai-whisper not installed (note: a 'whisper' package exists but it is "
    "Graphite's database, not OpenAI's speech recognition)"
)
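
# These tests are marked `slow` and `integration`; assuming those markers are
# registered in the project's pytest configuration, a typical selective run is:
#
#     pytest -m "slow and integration" <path to this test module>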


# ============================================================================
# Integration Tests - Core Functionality
# ============================================================================
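
# Engine surface exercised by these tests (inferred from the assertions below,
# not from the implementation itself):
#
#     WhisperPyTorchEngine(device=..., compute_type=...)
#         .device / .compute_type / .model_size / .is_loaded
#         .load_model(size), .unload()
#         .transcribe(samples[, language]) -> iterable of segments
#         .transcribe_file(path) -> iterable of segments
#     Each segment exposes .text, .start, .end, .language, .words;
#     each word exposes .word, .start, .end.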


@pytest.mark.slow
@pytest.mark.integration
@pytest.mark.skipif(not _check_whisper_available(), reason=_WHISPER_SKIP_REASON)
class TestWhisperPyTorchEngineIntegration:
    """Integration tests for WhisperPyTorchEngine with real model loading."""

    def test_engine_creation(self) -> None:
        """Test engine can be created with CPU device."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(
            device=DEVICE_CPU,
            compute_type=COMPUTE_TYPE_FLOAT32,
        )

        assert engine.device == DEVICE_CPU, "Expected CPU device"
        assert engine.compute_type == COMPUTE_TYPE_FLOAT32, "Expected float32 compute type"
        assert engine.model_size is None, "Expected model size to be unset before load_model"
        assert engine.is_loaded is False, "Expected engine to be unloaded initially"

    def test_model_loading(self) -> None:
        """Test tiny model can be loaded on CPU."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(
            device=DEVICE_CPU,
            compute_type=COMPUTE_TYPE_FLOAT32,
        )

        # Load model with size
        engine.load_model(MODEL_SIZE_TINY)
        assert engine.is_loaded is True, "Expected model to be loaded"
        assert engine.model_size == MODEL_SIZE_TINY, "Expected model size to match"

        # Unload model
        engine.unload()
        assert engine.is_loaded is False, "Expected engine to be unloaded"

    def test_transcription_produces_text(
        self,
        audio_samples: NDArray[np.float32],
    ) -> None:
        """Test transcription produces non-empty text from real audio."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
        engine.load_model(MODEL_SIZE_TINY)

        try:
            results = list(engine.transcribe(audio_samples))
            assert len(results) > 0, "Expected at least one transcription segment"

            first_result = results[0]
            assert hasattr(first_result, "text"), "Expected text attribute on result"
            assert hasattr(first_result, "start"), "Expected start attribute on result"
            assert hasattr(first_result, "end"), "Expected end attribute on result"
            assert hasattr(first_result, "language"), "Expected language attribute on result"
            assert (
                len(first_result.text.strip()) > 0
            ), "Expected non-empty transcription text"
            assert first_result.start >= 0.0, "Expected non-negative start time"
            assert first_result.end > first_result.start, "Expected end > start"
            assert (
                first_result.end <= MAX_AUDIO_SECONDS + 1.0
            ), "Expected end time within audio duration buffer"
        finally:
            engine.unload()

    def test_transcription_with_word_timings(
        self,
        audio_samples: NDArray[np.float32],
    ) -> None:
        """Test transcription produces word-level timings."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(
            device=DEVICE_CPU,
            compute_type=COMPUTE_TYPE_FLOAT32,
        )
        engine.load_model(MODEL_SIZE_TINY)

        try:
            results = list(engine.transcribe(audio_samples))
            assert len(results) > 0

            # Get first result with word timings
            first_result = results[0]
            assert hasattr(first_result, "words"), "Expected words attribute in result"
            assert len(first_result.words) > 0, "Expected word-level timings in first result"

            # Verify first word timing structure
            first_word = first_result.words[0]
            assert hasattr(first_word, "word"), "Expected word attribute"
            assert hasattr(first_word, "start"), "Expected start attribute"
            assert hasattr(first_word, "end"), "Expected end attribute"
            assert first_word.end >= first_word.start, "Expected end >= start"

        finally:
            engine.unload()

    def test_transcribe_file_helper(
        self,
        audio_fixture_path: Path,
    ) -> None:
        """Test transcribe_file helper method works."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(
            device=DEVICE_CPU,
            compute_type=COMPUTE_TYPE_FLOAT32,
        )
        engine.load_model(MODEL_SIZE_TINY)

        try:
            # Use transcribe_file helper
            results = list(engine.transcribe_file(audio_fixture_path))

            # Verify we got results
            assert len(results) > 0, "Expected transcription results from file"

            # Verify text was produced in first result
            first_result = results[0]
            assert len(first_result.text.strip()) > 0, "Expected non-empty transcription text"

        finally:
            engine.unload()

    def test_language_detection(
        self,
        audio_samples: NDArray[np.float32],
    ) -> None:
        """Test language is detected from audio."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(
            device=DEVICE_CPU,
            compute_type=COMPUTE_TYPE_FLOAT32,
        )
        engine.load_model(MODEL_SIZE_TINY)

        try:
            results = list(engine.transcribe(audio_samples))
            assert len(results) > 0, "Expected at least one transcription segment"

            # Verify language was detected
            first_result = results[0]
            assert first_result.language is not None, "Expected detected language"
            assert len(first_result.language) == 2, "Expected 2-letter language code"

        finally:
            engine.unload()

    def test_transcribe_without_model_raises(self) -> None:
        """Test transcribing without loading model raises RuntimeError."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(
            device=DEVICE_CPU,
            compute_type=COMPUTE_TYPE_FLOAT32,
        )

        # Don't load model
        assert engine.is_loaded is False

        # Attempt to transcribe should raise
        dummy_audio = np.zeros(SAMPLE_RATE, dtype=np.float32)
        with pytest.raises(RuntimeError, match="Model not loaded"):
            list(engine.transcribe(dummy_audio))


# ============================================================================
# Edge Case Tests
# ============================================================================


@pytest.mark.slow
@pytest.mark.integration
@pytest.mark.skipif(not _check_whisper_available(), reason=_WHISPER_SKIP_REASON)
class TestWhisperPyTorchEngineEdgeCases:
    """Edge case tests for WhisperPyTorchEngine."""

    def test_empty_audio_returns_empty_list(self) -> None:
        """Test transcribing empty audio returns empty list."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
        engine.load_model(MODEL_SIZE_TINY)

        empty_audio = np.array([], dtype=np.float32)

        # Whisper handles empty audio gracefully by returning empty results
        results = list(engine.transcribe(empty_audio))
        assert results == [], "Expected empty list for empty audio"

        engine.unload()

    def test_very_short_audio_handled(self) -> None:
        """Test transcribing very short audio (< 1 second) is handled."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
        engine.load_model(MODEL_SIZE_TINY)

        try:
            # 0.5 seconds of silence
            short_audio = np.zeros(SAMPLE_RATE // 2, dtype=np.float32)
            results = list(engine.transcribe(short_audio))

            # Should handle without crashing
            assert isinstance(results, list)
        finally:
            engine.unload()

    def test_silent_audio_produces_minimal_output(self) -> None:
        """Test transcribing silent audio produces minimal/no speech output."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
        engine.load_model(MODEL_SIZE_TINY)

        try:
            # 3 seconds of silence
            silent_audio = np.zeros(SAMPLE_RATE * 3, dtype=np.float32)
            results = list(engine.transcribe(silent_audio))

            # Silent audio should produce a valid (possibly empty) result list
            assert isinstance(results, list)
        finally:
            engine.unload()

    def test_audio_with_clipping_handled(self) -> None:
        """Test audio with extreme values (clipping) is handled."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
        engine.load_model(MODEL_SIZE_TINY)

        try:
            # Create clipped audio (values at ±1.0)
            clipped_audio = np.ones(SAMPLE_RATE * 2, dtype=np.float32)
            clipped_audio[::2] = -1.0  # Alternating +1/-1 (harsh noise)

            results = list(engine.transcribe(clipped_audio))

            # Should handle without crashing
            assert isinstance(results, list)
        finally:
            engine.unload()

    def test_audio_outside_normal_range_handled(self) -> None:
        """Test audio with values outside [-1, 1] range is handled."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
        engine.load_model(MODEL_SIZE_TINY)

        try:
            # Audio with values outside normal range
            rng = np.random.default_rng(42)
            loud_audio = rng.uniform(-5.0, 5.0, SAMPLE_RATE * 2).astype(np.float32)

            results = list(engine.transcribe(loud_audio))

            # Should handle without crashing (whisper normalizes internally)
            assert isinstance(results, list)
        finally:
            engine.unload()

    def test_nan_values_in_audio_raises_error(self) -> None:
        """Test audio containing NaN values raises ValueError during decoding."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
        engine.load_model(MODEL_SIZE_TINY)

        # Audio with NaN values causes invalid logits in whisper decoding
        audio_with_nan = np.zeros(SAMPLE_RATE, dtype=np.float32)
        audio_with_nan[100:200] = np.nan

        with pytest.raises(ValueError, match="invalid values"):
            list(engine.transcribe(audio_with_nan))

        engine.unload()


# ============================================================================
# Functional Scenario Tests
# ============================================================================


@pytest.mark.slow
@pytest.mark.integration
@pytest.mark.skipif(not _check_whisper_available(), reason=_WHISPER_SKIP_REASON)
class TestWhisperPyTorchEngineFunctionalScenarios:
    """Functional scenario tests for WhisperPyTorchEngine."""

    def test_multiple_sequential_transcriptions(
        self,
        audio_samples: NDArray[np.float32],
    ) -> None:
        """Test multiple transcriptions with same engine instance."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
        engine.load_model(MODEL_SIZE_TINY)

        try:
            # First transcription
            results1 = list(engine.transcribe(audio_samples))
            assert len(results1) > 0, "Expected results from first transcription"
            text1 = results1[0].text

            # Second transcription (should produce consistent results)
            results2 = list(engine.transcribe(audio_samples))
            assert len(results2) > 0, "Expected results from second transcription"
            text2 = results2[0].text

            # First result text should be identical for same input
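            # (Assumes the engine decodes deterministically, e.g. greedy /
            # temperature-0 decoding; with sampled decoding this equality could flake.)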
            assert text1 == text2, "Same audio should produce same transcription"

            # Third transcription with shorter audio (first 3 seconds)
            short_audio = audio_samples[: SAMPLE_RATE * 3]
            results3 = list(engine.transcribe(short_audio))

            assert isinstance(results3, list), "Expected list result from short audio"
        finally:
            engine.unload()

    def test_transcription_with_language_hint(
        self,
        audio_samples: NDArray[np.float32],
    ) -> None:
        """Test transcription with explicit language specification."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
        engine.load_model(MODEL_SIZE_TINY)

        try:
            # Transcribe with English hint
            results = list(engine.transcribe(audio_samples, language="en"))

            assert len(results) > 0, "Expected transcription results with language hint"
            # Language should match hint
            assert results[0].language == "en", "Expected language to match hint"
        finally:
            engine.unload()

    def test_model_reload_behavior(self) -> None:
        """Test loading different model sizes sequentially."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)

        # Load tiny
        engine.load_model("tiny")
        assert engine.model_size == "tiny", "Expected tiny model size"
        assert engine.is_loaded is True, "Expected model to be loaded"

        # Unload
        engine.unload()
        assert engine.is_loaded is False, "Expected model to be unloaded"

        # Load base (different size)
        engine.load_model("base")
        assert engine.model_size == "base", "Expected base model size"
        assert engine.is_loaded is True, "Expected model to be loaded"

        engine.unload()

    def test_multiple_load_unload_cycles(self) -> None:
        """Test multiple load/unload cycles don't cause issues."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)

        # Cycle 1
        engine.load_model(MODEL_SIZE_TINY)
        assert engine.is_loaded is True, "Expected model to be loaded (cycle 1)"
        engine.unload()
        assert engine.is_loaded is False, "Expected model to be unloaded (cycle 1)"

        # Cycle 2
        engine.load_model(MODEL_SIZE_TINY)
        assert engine.is_loaded is True, "Expected model to be loaded (cycle 2)"
        engine.unload()
        assert engine.is_loaded is False, "Expected model to be unloaded (cycle 2)"

        # Cycle 3
        engine.load_model(MODEL_SIZE_TINY)
        assert engine.is_loaded is True, "Expected model to be loaded (cycle 3)"
        engine.unload()
        assert engine.is_loaded is False, "Expected model to be unloaded (cycle 3)"

    def test_unload_without_load_is_safe(self) -> None:
        """Test calling unload without loading is safe."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)

        # Should not raise
        engine.unload()
        engine.unload()  # Multiple unloads should be safe

        assert engine.is_loaded is False

    def test_transcription_timing_accuracy(
        self,
        audio_samples: NDArray[np.float32],
    ) -> None:
        """Test that segment timings are well-formed (non-negative, ordered, bounded)."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
        engine.load_model(MODEL_SIZE_TINY)

        try:
            results = list(engine.transcribe(audio_samples))

            # Verify at least one result
            assert len(results) >= 1, "Expected at least one transcription segment"

            # Verify first segment has valid timing
            first_segment = results[0]
            assert first_segment.start >= 0.0, "Expected non-negative start time"
            assert first_segment.end > first_segment.start, "Expected end > start"
            assert first_segment.end <= MAX_AUDIO_SECONDS + 1.0, "Expected reasonable end time"
        finally:
            engine.unload()


# ============================================================================
# Error Handling Tests
# ============================================================================


@pytest.mark.slow
@pytest.mark.integration
@pytest.mark.skipif(not _check_whisper_available(), reason=_WHISPER_SKIP_REASON)
class TestWhisperPyTorchEngineErrorHandling:
    """Error handling tests for WhisperPyTorchEngine."""

    def test_invalid_model_size_raises(self) -> None:
        """Test loading invalid model size raises ValueError."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)

        with pytest.raises(ValueError, match="Invalid model size"):
            engine.load_model("nonexistent_model")

    def test_transcribe_file_nonexistent_raises(self) -> None:
        """Test transcribing nonexistent file raises appropriate error."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
        engine.load_model(MODEL_SIZE_TINY)

        try:
            nonexistent_path = Path("/nonexistent/path/audio.wav")
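
            # The concrete exception depends on how transcribe_file loads audio
            # (openai-whisper's own loader shells out to ffmpeg and tends to surface
            # missing files as RuntimeError), hence the tuple of accepted types.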
            with pytest.raises((FileNotFoundError, RuntimeError, OSError)):
                list(engine.transcribe_file(nonexistent_path))
        finally:
            engine.unload()

    def test_double_load_overwrites_model(self) -> None:
        """Test loading model twice overwrites previous model."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)

        engine.load_model("tiny")
        assert engine.model_size == "tiny", "Expected tiny model size"

        # Load again without unload
        engine.load_model("base")
        assert engine.model_size == "base", "Expected base model size"
        assert engine.is_loaded is True, "Expected model to be loaded"

        engine.unload()


# ============================================================================
# Compute Type Tests
# ============================================================================


@pytest.mark.slow
@pytest.mark.integration
@pytest.mark.skipif(not _check_whisper_available(), reason=_WHISPER_SKIP_REASON)
class TestWhisperPyTorchEngineComputeTypes:
    """Test different compute type configurations."""

    def test_float32_compute_type(self) -> None:
        """Test float32 compute type works on CPU."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type="float32")
        engine.load_model(MODEL_SIZE_TINY)

        assert engine.compute_type == "float32", "Expected float32 compute type"
        assert engine.is_loaded is True, "Expected model to be loaded"

        engine.unload()

    def test_int8_normalized_to_float32_on_cpu(self) -> None:
        """Test int8 is normalized to float32 on CPU."""
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        engine = WhisperPyTorchEngine(device=DEVICE_CPU, compute_type="int8")

        # int8 not supported on CPU, should normalize to float32
        assert engine.compute_type == "float32"


# ============================================================================
# Factory Integration Tests
# ============================================================================


@pytest.mark.slow
@pytest.mark.integration
@pytest.mark.skipif(not _check_whisper_available(), reason=_WHISPER_SKIP_REASON)
class TestAsrFactoryIntegration:
    """Integration tests for ASR factory with real engine creation."""

    def test_factory_creates_cpu_engine(self) -> None:
        """Test factory creates working CPU engine."""
        from noteflow.infrastructure.asr.factory import create_asr_engine

        engine = create_asr_engine(
            device=DEVICE_CPU,
            compute_type=COMPUTE_TYPE_FLOAT32,
        )

        # Factory should return a working engine
        assert engine is not None, "Expected engine instance"
        assert engine.device == DEVICE_CPU, "Expected CPU device"
        # model_size is None until load_model is called
        assert engine.model_size is None, "Expected model size to be unset before load_model"

    def test_factory_auto_device_resolves_to_cpu(self) -> None:
        """Test 'auto' device resolves to a concrete device (CPU when no GPU is available)."""
        from noteflow.infrastructure.asr.factory import create_asr_engine

        # In CI/test environment without GPU, should fall back to CPU
        engine = create_asr_engine(
            device="auto",
            compute_type=COMPUTE_TYPE_FLOAT32,
        )

        assert engine is not None, "Expected engine instance"
        # Device should be resolved (not "auto")
        assert engine.device in ("cpu", "cuda", "rocm", "mps"), "Expected resolved device"

    def test_factory_engine_can_transcribe(
        self,
        audio_samples: NDArray[np.float32],
    ) -> None:
        """Test factory-created engine can actually transcribe."""
        from noteflow.infrastructure.asr.factory import create_asr_engine

        engine = create_asr_engine(device=DEVICE_CPU, compute_type=COMPUTE_TYPE_FLOAT32)
        engine.load_model(MODEL_SIZE_TINY)

        try:
            results = list(engine.transcribe(audio_samples))

            assert len(results) > 0, "Expected transcription results"
            first_result = results[0]
            assert len(first_result.text.strip()) > 0, "Expected non-empty transcription"
        finally:
            engine.unload()