ROCm Support: Architecture & Code Samples
This document provides architecture diagrams and concrete code samples for the ROCm integration.
Alignment note: the current codebase already defines AsrEngine in
src/noteflow/infrastructure/asr/protocols.py and uses AsrResult from
src/noteflow/infrastructure/asr/dto.py. Prefer extending those rather than
adding parallel protocol/DTO types unless we plan a broader layering refactor.
System Architecture
Current State (CUDA Only)
┌─────────────────────────────────────────────────────────────────────┐
│ gRPC Server │
│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────────┐ │
│ │ StreamingMixin │ │ DiarizationMixin│ │ AsrConfigMixin │ │
│ └────────┬────────┘ └────────┬────────┘ └──────────┬──────────┘ │
└───────────┼─────────────────────┼─────────────────────┼─────────────┘
│ │ │
▼ ▼ │
┌───────────────────────┐ ┌───────────────────────┐ │
│ FasterWhisperEngine │ │ DiarizationEngine │ │
│ (CUDA/CPU only) │ │ (CUDA/CPU/MPS) │ │
│ │ │ │ │
│ - device: "cuda"|"cpu"│ │ - device: auto-detect │ │
│ - uses: CTranslate2 │ │ - uses: PyTorch │ │
└───────────────────────┘ └───────────────────────┘ │
│ │ │
▼ ▼ ▼
┌───────────────────────────────────────────────────────────────────┐
│ torch.cuda (CUDA or ROCm HIP build) │
└───────────────────────────────────────────────────────────────────┘
Target State (Multi-Backend)
┌─────────────────────────────────────────────────────────────────────┐
│ gRPC Server │
│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────────┐ │
│ │ StreamingMixin │ │ DiarizationMixin│ │ AsrConfigMixin │ │
│ └────────┬────────┘ └────────┬────────┘ └──────────┬──────────┘ │
└───────────┼─────────────────────┼─────────────────────┼─────────────┘
│ │ │
▼ │ │
┌───────────────────────────────┐ │ ┌──────────────────────────────┐
│ AsrEngineFactory │ │ │ AsrEngineManager │
│ ┌─────────────────────────┐ │ │ │ - detect_cuda_available() │
│ │ create_asr_engine() │ │ │ │ - detect_rocm_available() │ ◄─┘
│ │ - auto-detect backend │ │ │ │ - build_capabilities() │
│ │ - fallback logic │ │ │ └──────────────────────────────┘
│ └─────────────────────────┘ │ │
└───────────────────────────────┘ │
│ │
┌───────┴───────┬─────────────┤
│ │ │
▼ ▼ ▼
┌─────────┐ ┌─────────┐ ┌─────────────────────┐
│Faster │ │Faster │ │WhisperPyTorch │
│Whisper │ │Whisper │ │Engine │
│Engine │ │RocmEng │ │(universal fallback) │
│(CUDA/CPU)│ │(ROCm) │ │ │
└─────────┘ └─────────┘ └─────────────────────┘
│ │ │
│ │ │
▼ ▼ ▼
┌───────────────────────────────────────────────────┐
│ GPU Detection Layer │
│ ┌─────────────────────────────────────────────┐ │
│ │ detect_gpu_backend() -> GpuBackend │ │
│ │ - CUDA: torch.cuda + no HIP │ │
│ │ - ROCM: torch.cuda + torch.version.hip │ │
│ │ - MPS: torch.backends.mps │ │
│ │ - NONE: no GPU available │ │
│ └─────────────────────────────────────────────┘ │
└───────────────────────────────────────────────────┘
│ │ │
▼ ▼ ▼
┌───────────────┐ ┌───────────────┐ ┌───────────────┐
│ NVIDIA CUDA │ │ AMD ROCm/HIP │ │ CPU/fallback │
└───────────────┘ └───────────────┘ └───────────────┘
Module Structure
New Modules
src/noteflow/
├── domain/
│ └── ports/
│ ├── gpu.py # NEW: GpuBackend enum, GpuInfo
│ └── asr.py # OPTIONAL: only if relocating AsrEngine protocol
│
├── infrastructure/
│ ├── gpu/ # NEW: GPU detection module
│ │ ├── __init__.py
│ │ └── detection.py # detect_gpu_backend(), get_gpu_info()
│ │
│ └── asr/
│ ├── engine.py # EXISTING: FasterWhisperEngine (refactored)
│ ├── protocols.py # EXISTING: AsrEngine protocol (extend)
│ ├── pytorch_engine.py # NEW: WhisperPyTorchEngine
│ ├── rocm_engine.py # NEW: FasterWhisperRocmEngine
│ └── factory.py # NEW: create_asr_engine()
Code Samples
1. Domain Types (domain/ports/gpu.py)
"""GPU backend types and detection protocol."""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import Protocol
class GpuBackend(str, Enum):
"""Detected GPU backend type."""
NONE = "none"
CUDA = "cuda"
ROCM = "rocm"
MPS = "mps"
@dataclass(frozen=True)
class GpuInfo:
"""Information about detected GPU."""
backend: GpuBackend
device_name: str
vram_total_mb: int
driver_version: str
architecture: str | None = None # e.g., "gfx1100" for AMD
class GpuDetectionProtocol(Protocol):
"""Protocol for GPU detection implementations."""
def detect_backend(self) -> GpuBackend:
"""Detect the available GPU backend."""
...
def get_info(self) -> GpuInfo | None:
"""Get detailed GPU information."""
...
def is_supported_for_asr(self) -> bool:
"""Check if GPU is supported for ASR workloads."""
...
2. ASR Engine Protocol (infrastructure/asr/protocols.py)
"""ASR protocols defining contracts for ASR components."""
from __future__ import annotations
from collections.abc import Iterator
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from pathlib import Path
import numpy as np
from numpy.typing import NDArray
from noteflow.infrastructure.asr.dto import AsrResult
class AsrEngine(Protocol):
"""Protocol for ASR engine implementations.
All ASR engines must implement this interface to be used
by the engine manager and gRPC handlers.
Implementations:
- FasterWhisperEngine: CUDA/CPU via CTranslate2
- FasterWhisperRocmEngine: ROCm via CTranslate2-ROCm fork
- WhisperPyTorchEngine: Universal via openai-whisper
"""
@property
def device(self) -> str:
"""Return the requested device ("cpu", "cuda", "rocm")."""
...
@property
def compute_type(self) -> str:
"""Return the compute precision ("int8", "float16", "float32")."""
...
@property
def model_size(self) -> str | None:
"""Return the loaded model size, or None if not loaded."""
...
@property
def is_loaded(self) -> bool:
"""Return True if model is loaded and ready for inference."""
...
def load_model(self, model_size: str = "base") -> None:
"""Load the specified Whisper model."""
...
def unload(self) -> None:
"""Unload the model and free GPU/CPU resources."""
...
    def transcribe(
        self,
        audio: NDArray[np.float32],
        *,
        language: str | None = None,
    ) -> Iterator[AsrResult]:
"""Transcribe audio samples.
Args:
audio: Audio samples as float32 array, 16kHz mono, normalized to [-1, 1].
language: Optional BCP-47 language code (auto-detect if None).
Yields:
AsrResult for each detected segment.
"""
...
def transcribe_file(
self,
audio_path: Path,
*,
language: str | None = None,
) -> Iterator[AsrResult]:
"""Transcribe audio file.
Args:
audio_path: Path to audio file (WAV, MP3, FLAC, etc.)
language: Optional language code.
Yields:
AsrResult for each detected segment.
"""
...
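A minimal consumer sketch (illustrative, not existing code): any component that depends only on the AsrEngine protocol works unchanged across backends. The helper name transcribe_to_text is hypothetical.
"""Example consumer of the AsrEngine protocol (illustrative sketch)."""
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import numpy as np
    from numpy.typing import NDArray

    from noteflow.infrastructure.asr.protocols import AsrEngine


def transcribe_to_text(engine: AsrEngine, audio: NDArray[np.float32]) -> str:
    """Join segment texts from any engine implementing the protocol."""
    # Load a default model lazily if the caller has not done so yet.
    if not engine.is_loaded:
        engine.load_model("base")
    return " ".join(result.text.strip() for result in engine.transcribe(audio))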
3. GPU Detection (infrastructure/gpu/detection.py)
"""GPU backend detection utilities."""
from __future__ import annotations
import os
from functools import cache
from noteflow.domain.ports.gpu import GpuBackend, GpuInfo
from noteflow.infrastructure.logging import get_logger
logger = get_logger(__name__)
# Example AMD GPU architectures; keep in sync with AMD ROCm support docs
SUPPORTED_AMD_ARCHITECTURES: frozenset[str] = frozenset({
# CDNA (Instinct)
"gfx906", # MI50
"gfx908", # MI100
"gfx90a", # MI210, MI250
"gfx942", # MI300X
# RDNA 2
"gfx1030", # RX 6800, 6900
# RDNA 3
"gfx1100", # RX 7900 XTX
"gfx1101", # RX 7900 XT
"gfx1102", # RX 7600
})
@cache
def detect_gpu_backend() -> GpuBackend:
"""Detect the available GPU backend.
Results are cached for performance.
Returns:
GpuBackend enum indicating the detected backend.
"""
try:
import torch
except ImportError:
logger.debug("PyTorch not installed, no GPU backend available")
return GpuBackend.NONE
# Check CUDA/ROCm availability
if torch.cuda.is_available():
# Distinguish between CUDA and ROCm via HIP version
if hasattr(torch.version, "hip") and torch.version.hip:
logger.info("ROCm/HIP backend detected", version=torch.version.hip)
return GpuBackend.ROCM
logger.info("CUDA backend detected", version=torch.version.cuda)
return GpuBackend.CUDA
# Check Apple Metal Performance Shaders
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
logger.info("MPS backend detected")
return GpuBackend.MPS
logger.debug("No GPU backend available, using CPU")
return GpuBackend.NONE
def get_gpu_info() -> GpuInfo | None:
"""Get detailed GPU information.
Returns:
GpuInfo if a GPU is available, None otherwise.
"""
backend = detect_gpu_backend()
if backend == GpuBackend.NONE:
return None
import torch
if backend in (GpuBackend.CUDA, GpuBackend.ROCM):
try:
props = torch.cuda.get_device_properties(0)
vram_mb = props.total_memory // (1024 * 1024)
# Get driver version
if backend == GpuBackend.ROCM:
driver_version = str(torch.version.hip) if torch.version.hip else "unknown"
# On ROCm, props.name may include a gfx ID; parse if present.
architecture = props.name if props.name.startswith("gfx") else None
else:
driver_version = torch.version.cuda or "unknown"
architecture = f"sm_{props.major}{props.minor}"
return GpuInfo(
backend=backend,
device_name=props.name,
vram_total_mb=vram_mb,
driver_version=driver_version,
architecture=architecture,
)
except RuntimeError as e:
logger.warning("Failed to get GPU properties", error=str(e))
return None
if backend == GpuBackend.MPS:
return GpuInfo(
backend=backend,
device_name="Apple Metal",
vram_total_mb=0, # MPS doesn't expose VRAM
driver_version="mps",
architecture=None,
)
return None
def is_rocm_architecture_supported(architecture: str | None) -> bool:
"""Check if AMD GPU architecture is officially supported.
Args:
architecture: GPU architecture string (e.g., "gfx1100")
Returns:
True if supported, False otherwise.
"""
if architecture is None:
return False
# Check for override (allows unofficial GPUs)
if os.environ.get("HSA_OVERRIDE_GFX_VERSION"):
return True
return architecture in SUPPORTED_AMD_ARCHITECTURES
def is_ctranslate2_rocm_available() -> bool:
"""Check if CTranslate2-ROCm fork is installed.
Returns:
True if the ROCm fork is available.
"""
try:
import ctranslate2
        # The ROCm fork ships the same Python API as upstream CTranslate2, so this
        # heuristic only confirms the package is importable; it cannot distinguish
        # the fork from upstream. A stricter check could query supported compute
        # types for the "cuda" device.
        return hasattr(ctranslate2, "get_supported_compute_types")
except ImportError:
return False
def get_rocm_environment_info() -> dict[str, str]:
"""Get ROCm-related environment variables for debugging.
Returns:
Dictionary of relevant environment variables.
"""
rocm_vars = [
"HSA_OVERRIDE_GFX_VERSION",
"HIP_VISIBLE_DEVICES",
"ROCM_PATH",
"MIOPEN_USER_DB_PATH",
"MIOPEN_FIND_MODE",
"AMD_LOG_LEVEL",
]
return {var: os.environ.get(var, "") for var in rocm_vars if os.environ.get(var)}
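The module-level helpers above can also be wrapped into an object that satisfies GpuDetectionProtocol from domain/ports/gpu.py. A minimal adapter sketch; the class name TorchGpuDetector is an assumption, not existing code.
"""Adapter sketch wiring detection helpers to GpuDetectionProtocol."""
from __future__ import annotations

from noteflow.domain.ports.gpu import GpuBackend, GpuInfo
from noteflow.infrastructure.gpu.detection import detect_gpu_backend, get_gpu_info


class TorchGpuDetector:
    """Satisfies GpuDetectionProtocol by delegating to the module helpers."""

    def detect_backend(self) -> GpuBackend:
        return detect_gpu_backend()

    def get_info(self) -> GpuInfo | None:
        return get_gpu_info()

    def is_supported_for_asr(self) -> bool:
        # CUDA and ROCm can run ASR natively; MPS currently falls back to CPU
        # in the factory, so it is not reported as ASR-capable here.
        return self.detect_backend() in (GpuBackend.CUDA, GpuBackend.ROCM)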
4. Engine Factory (infrastructure/asr/factory.py)
"""ASR engine factory for backend selection."""
from __future__ import annotations
from typing import TYPE_CHECKING
from noteflow.domain.ports.gpu import GpuBackend
from noteflow.infrastructure.gpu.detection import (
detect_gpu_backend,
get_gpu_info,
is_ctranslate2_rocm_available,
is_rocm_architecture_supported,
)
from noteflow.infrastructure.logging import get_logger
if TYPE_CHECKING:
from noteflow.infrastructure.asr.protocols import AsrEngine
logger = get_logger(__name__)
class EngineCreationError(Exception):
"""Raised when ASR engine creation fails."""
def create_asr_engine(
device: str = "auto",
compute_type: str = "int8",
*,
prefer_faster_whisper: bool = True,
) -> AsrEngine:
"""Create an ASR engine for the specified device.
This factory handles:
1. Auto-detection of available GPU backends
2. Selection of appropriate engine implementation
3. Fallback to PyTorch Whisper when native engines unavailable
Args:
device: Target device ("auto", "cpu", "cuda", "rocm").
compute_type: Compute precision ("int8", "float16", "float32").
prefer_faster_whisper: If True, prefer faster-whisper over PyTorch Whisper.
faster-whisper uses CTranslate2 and is significantly faster.
Returns:
An ASR engine implementing AsrEngine.
Raises:
EngineCreationError: If engine creation fails.
Example:
>>> engine = create_asr_engine(device="auto")
>>> engine.load_model("base")
>>> for segment in engine.transcribe(audio):
... print(segment.text)
"""
resolved_device = _resolve_device(device)
logger.info(
"Creating ASR engine",
requested_device=device,
resolved_device=resolved_device,
compute_type=compute_type,
prefer_faster_whisper=prefer_faster_whisper,
)
    if resolved_device == "cpu":
        return _create_cpu_engine(compute_type, prefer_faster_whisper)
if resolved_device == "cuda":
return _create_cuda_engine(compute_type, prefer_faster_whisper)
if resolved_device == "rocm":
return _create_rocm_engine(compute_type, prefer_faster_whisper)
msg = f"Unsupported device: {resolved_device}"
raise EngineCreationError(msg)
def _resolve_device(device: str) -> str:
"""Resolve 'auto' device to actual backend.
Args:
device: Requested device string.
Returns:
Resolved device string ("cpu", "cuda", or "rocm").
"""
if device != "auto":
return device
backend = detect_gpu_backend()
if backend == GpuBackend.CUDA:
return "cuda"
if backend == GpuBackend.ROCM:
# Check if ROCm architecture is supported for ASR
gpu_info = get_gpu_info()
if gpu_info and is_rocm_architecture_supported(gpu_info.architecture):
return "rocm"
else:
logger.warning(
"ROCm detected but architecture may not be supported, falling back to CPU",
architecture=gpu_info.architecture if gpu_info else "unknown",
)
return "cpu"
    # MPS is not supported by faster-whisper; PyTorch Whisper may work but is untested.
    if backend == GpuBackend.MPS:
        logger.info("MPS detected but not supported for ASR, using CPU")
        return "cpu"

    # No GPU backend detected.
    return "cpu"
def _create_cpu_engine(compute_type: str, prefer_faster_whisper: bool = True) -> AsrEngine:
    """Create CPU engine (faster-whisper by default).

    Args:
        compute_type: Requested compute type.
        prefer_faster_whisper: Whether to prefer faster-whisper over PyTorch Whisper.

    Returns:
        ASR engine for CPU.
    """
    # CPU only supports int8 and float32
    if compute_type == "float16":
        logger.debug("float16 not supported on CPU, using float32")
        compute_type = "float32"

    if not prefer_faster_whisper:
        return _create_pytorch_engine("cpu", compute_type)

    from noteflow.infrastructure.asr.engine import FasterWhisperEngine

    return FasterWhisperEngine(device="cpu", compute_type=compute_type)
def _create_cuda_engine(
compute_type: str,
prefer_faster_whisper: bool,
) -> AsrEngine:
"""Create CUDA engine.
Args:
compute_type: Compute precision.
prefer_faster_whisper: Whether to prefer faster-whisper.
Returns:
ASR engine for CUDA.
"""
if prefer_faster_whisper:
from noteflow.infrastructure.asr.engine import FasterWhisperEngine
return FasterWhisperEngine(device="cuda", compute_type=compute_type)
return _create_pytorch_engine("cuda", compute_type)
def _create_rocm_engine(
compute_type: str,
prefer_faster_whisper: bool,
) -> AsrEngine:
"""Create ROCm engine.
Attempts to use CTranslate2-ROCm fork if available,
falls back to PyTorch Whisper otherwise.
Args:
compute_type: Compute precision.
prefer_faster_whisper: Whether to prefer faster-whisper.
Returns:
ASR engine for ROCm.
"""
if prefer_faster_whisper and is_ctranslate2_rocm_available():
try:
from noteflow.infrastructure.asr.rocm_engine import FasterWhisperRocmEngine
logger.info("Using CTranslate2-ROCm for ASR")
return FasterWhisperRocmEngine(compute_type=compute_type)
except ImportError as e:
logger.warning(
"CTranslate2-ROCm import failed, falling back to PyTorch Whisper",
error=str(e),
)
logger.info("Using PyTorch Whisper for ROCm ASR")
return _create_pytorch_engine("cuda", compute_type) # ROCm uses "cuda" device string
def _create_pytorch_engine(device: str, compute_type: str) -> AsrEngine:
"""Create PyTorch Whisper engine (universal fallback).
Args:
device: Target device.
compute_type: Compute precision.
Returns:
PyTorch-based Whisper engine.
Raises:
EngineCreationError: If openai-whisper is not installed.
"""
try:
from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine
return WhisperPyTorchEngine(device=device, compute_type=compute_type)
except ImportError as e:
msg = (
"Neither CTranslate2 nor openai-whisper is available. "
"Install one of: pip install faster-whisper OR pip install openai-whisper"
)
raise EngineCreationError(msg) from e
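A defensive wrapper sketch (illustrative): server startup could try the configured device first and degrade to CPU if engine creation fails. The function name create_engine_with_cpu_fallback is an assumption, not existing code.
"""Illustrative startup wrapper around create_asr_engine()."""
from __future__ import annotations

from typing import TYPE_CHECKING

from noteflow.infrastructure.asr.factory import EngineCreationError, create_asr_engine
from noteflow.infrastructure.logging import get_logger

if TYPE_CHECKING:
    from noteflow.infrastructure.asr.protocols import AsrEngine

logger = get_logger(__name__)


def create_engine_with_cpu_fallback(device: str, compute_type: str) -> AsrEngine:
    """Try the requested device first; fall back to CPU int8 on failure."""
    try:
        return create_asr_engine(device=device, compute_type=compute_type)
    except EngineCreationError as exc:
        logger.warning("Engine creation failed, retrying on CPU", error=str(exc))
        return create_asr_engine(device="cpu", compute_type="int8")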
5. PyTorch Whisper Engine (infrastructure/asr/pytorch_engine.py)
"""PyTorch-based Whisper engine (universal fallback)."""
from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING
from noteflow.infrastructure.asr.dto import AsrResult, WordTiming
from noteflow.infrastructure.logging import get_logger
if TYPE_CHECKING:
    import numpy as np
    import whisper
    from numpy.typing import NDArray
logger = get_logger(__name__)
class WhisperPyTorchEngine:
"""Pure PyTorch Whisper implementation.
Uses the official openai-whisper package for transcription.
Works on any PyTorch-supported device (CPU, CUDA, ROCm via HIP).
This engine is slower than CTranslate2-based engines but provides
universal compatibility across all GPU backends.
"""
def __init__(
self,
device: str = "cpu",
compute_type: str = "float32",
) -> None:
"""Initialize PyTorch Whisper engine.
Args:
device: Target device ("cpu" or "cuda").
For ROCm, use "cuda" - HIP handles the translation.
compute_type: Compute precision. Only "float16" and "float32"
are supported. "int8" will be treated as "float32".
"""
self._device = device
self._compute_type = self._normalize_compute_type(compute_type)
self._model_size: str = ""
        self._model: whisper.Whisper | None = None
@staticmethod
def _normalize_compute_type(compute_type: str) -> str:
"""Normalize compute type for PyTorch.
PyTorch Whisper doesn't support int8, map to float32.
"""
if compute_type == "int8":
logger.debug("int8 not supported in PyTorch Whisper, using float32")
return "float32"
return compute_type
@property
def device(self) -> str:
"""Return the device this engine runs on."""
return self._device
@property
def compute_type(self) -> str:
"""Return the compute precision."""
return self._compute_type
@property
def model_size(self) -> str:
"""Return the loaded model size."""
return self._model_size
@property
def is_loaded(self) -> bool:
"""Return True if model is loaded."""
return self._model is not None
    def load_model(self, model_size: str = "base") -> None:
"""Load the specified Whisper model.
Args:
model_size: Whisper model size (e.g., "base", "small", "large-v3").
"""
        import whisper
logger.info(
"Loading PyTorch Whisper model",
model_size=model_size,
device=self._device,
compute_type=self._compute_type,
)
# Load model
self._model = whisper.load_model(model_size, device=self._device)
self._model_size = model_size
# Apply compute type
if self._compute_type == "float16" and self._device != "cpu":
self._model = self._model.half()
logger.info("PyTorch Whisper model loaded successfully")
def unload(self) -> None:
"""Unload the model and free resources."""
if self._model is not None:
import gc
import torch
del self._model
self._model = None
self._model_size = ""
# Force garbage collection and clear GPU cache
gc.collect()
if self._device != "cpu":
torch.cuda.empty_cache()
logger.debug("PyTorch Whisper model unloaded")
def transcribe(
self,
audio: NDArray[np.float32],
*,
language: str | None = None,
initial_prompt: str | None = None,
) -> Iterator[AsrResult]:
"""Transcribe audio samples.
Args:
audio: Audio samples as float32 array, 16kHz mono.
language: Optional language code.
initial_prompt: Optional prompt for context.
Yields:
AsrResult for each detected segment.
"""
if self._model is None:
msg = "Model not loaded. Call load_model() first."
raise RuntimeError(msg)
# Build transcription options
options: dict[str, object] = {
"word_timestamps": True,
"fp16": self._compute_type == "float16" and self._device != "cpu",
}
if language is not None:
options["language"] = language
if initial_prompt is not None:
options["initial_prompt"] = initial_prompt
# Transcribe
result = self._model.transcribe(audio, **options)
# Convert to our segment format
for segment in result["segments"]:
words = tuple(
WordTiming(
word=w["word"],
start=w["start"],
end=w["end"],
probability=w.get("probability", 0.0),
)
for w in segment.get("words", [])
)
yield AsrResult(
text=segment["text"].strip(),
start=segment["start"],
end=segment["end"],
words=words,
language=result.get("language", "en"),
avg_logprob=segment.get("avg_logprob", 0.0),
no_speech_prob=segment.get("no_speech_prob", 0.0),
)
def transcribe_file(
self,
audio_path: Path,
*,
language: str | None = None,
) -> Iterator[AsrResult]:
"""Transcribe audio file.
Args:
audio_path: Path to audio file.
language: Optional language code.
Yields:
AsrResult for each detected segment.
"""
import whisper
# Load audio using whisper's utility
audio = whisper.load_audio(str(audio_path))
yield from self.transcribe(audio, language=language)
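A quick usage sketch (illustrative): it assumes openai-whisper is installed and uses silence as a stand-in for real 16 kHz mono audio.
"""Illustrative usage of WhisperPyTorchEngine."""
from __future__ import annotations

import numpy as np

from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

# One second of 16 kHz silence as a stand-in for captured audio.
audio = np.zeros(16_000, dtype=np.float32)

engine = WhisperPyTorchEngine(device="cpu", compute_type="float32")
engine.load_model("base")
for segment in engine.transcribe(audio, language="en"):
    print(f"[{segment.start:.2f}-{segment.end:.2f}] {segment.text}")
engine.unload()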
6. Updated ASR Device Types (application/services/asr_config/types.py)
"""Types and constants for ASR configuration."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from enum import Enum
from typing import Final
from uuid import UUID
class AsrConfigPhase(str, Enum):
"""Phases of ASR reconfiguration."""
VALIDATING = "validating"
DOWNLOADING = "downloading"
LOADING = "loading"
COMPLETED = "completed"
FAILED = "failed"
class AsrDevice(str, Enum):
"""Supported ASR devices."""
CPU = "cpu"
CUDA = "cuda"
ROCM = "rocm" # NEW
class AsrComputeType(str, Enum):
"""Supported compute types."""
INT8 = "int8"
FLOAT16 = "float16"
FLOAT32 = "float32"
# Compute types available for each device
DEVICE_COMPUTE_TYPES: Final[dict[AsrDevice, tuple[AsrComputeType, ...]]] = {
AsrDevice.CPU: (AsrComputeType.INT8, AsrComputeType.FLOAT32),
AsrDevice.CUDA: (
AsrComputeType.INT8,
AsrComputeType.FLOAT16,
AsrComputeType.FLOAT32,
),
AsrDevice.ROCM: ( # NEW
AsrComputeType.INT8,
AsrComputeType.FLOAT16,
AsrComputeType.FLOAT32,
),
}
@dataclass
class AsrConfigJob:
"""Tracks ASR reconfiguration job state."""
job_id: UUID
status: str
phase: AsrConfigPhase
progress_percent: float
error_message: str
target_model_size: str
target_device: AsrDevice
target_compute_type: AsrComputeType
task: asyncio.Task[None] | None = field(default=None, repr=False)
@dataclass(frozen=True)
class AsrCapabilities:
"""Current ASR capabilities and configuration."""
model_size: str | None
device: AsrDevice
compute_type: AsrComputeType
is_ready: bool
cuda_available: bool
rocm_available: bool # NEW
gpu_backend: str # NEW: "cuda", "rocm", "mps", or "none"
available_model_sizes: tuple[str, ...]
available_compute_types: tuple[AsrComputeType, ...]
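The target diagram shows an AsrEngineManager with a build_capabilities() step. Below is a hedged sketch of how that could combine the detection layer with the constants above; the function signature and the model list are assumptions, not existing code.
"""Illustrative build_capabilities() sketch (names are assumptions)."""
from __future__ import annotations

from noteflow.application.services.asr_config.types import (
    DEVICE_COMPUTE_TYPES,
    AsrCapabilities,
    AsrComputeType,
    AsrDevice,
)
from noteflow.domain.ports.gpu import GpuBackend
from noteflow.infrastructure.gpu.detection import detect_gpu_backend


def build_capabilities(
    model_size: str | None,
    device: AsrDevice,
    compute_type: AsrComputeType,
    *,
    is_ready: bool,
) -> AsrCapabilities:
    """Combine current engine state with the detected GPU backend."""
    backend = detect_gpu_backend()
    return AsrCapabilities(
        model_size=model_size,
        device=device,
        compute_type=compute_type,
        is_ready=is_ready,
        cuda_available=backend == GpuBackend.CUDA,
        rocm_available=backend == GpuBackend.ROCM,
        gpu_backend=backend.value,
        # Model list shown for illustration only; source from real configuration.
        available_model_sizes=("tiny", "base", "small", "medium", "large-v3"),
        available_compute_types=DEVICE_COMPUTE_TYPES[device],
    )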
Testing Examples
Unit Test for GPU Detection
"""Tests for GPU detection utilities."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from noteflow.domain.ports.gpu import GpuBackend
from noteflow.infrastructure.gpu.detection import (
SUPPORTED_AMD_ARCHITECTURES,
detect_gpu_backend,
get_gpu_info,
is_rocm_architecture_supported,
)
class TestDetectGpuBackend:
"""Tests for detect_gpu_backend function."""
def test_no_torch_returns_none(self) -> None:
"""Return NONE when torch is not installed."""
with patch.dict("sys.modules", {"torch": None}):
# Clear cache and reimport
detect_gpu_backend.cache_clear()
result = detect_gpu_backend()
assert result == GpuBackend.NONE
def test_cuda_without_hip_returns_cuda(self) -> None:
"""Return CUDA when CUDA available and no HIP."""
mock_torch = MagicMock()
mock_torch.cuda.is_available.return_value = True
mock_torch.version.hip = None
mock_torch.version.cuda = "12.1"
with patch.dict("sys.modules", {"torch": mock_torch}):
detect_gpu_backend.cache_clear()
result = detect_gpu_backend()
assert result == GpuBackend.CUDA
def test_cuda_with_hip_returns_rocm(self) -> None:
"""Return ROCM when HIP version is present."""
mock_torch = MagicMock()
mock_torch.cuda.is_available.return_value = True
mock_torch.version.hip = "6.0.0"
with patch.dict("sys.modules", {"torch": mock_torch}):
detect_gpu_backend.cache_clear()
result = detect_gpu_backend()
assert result == GpuBackend.ROCM
def test_mps_available_returns_mps(self) -> None:
"""Return MPS on Apple Silicon."""
mock_torch = MagicMock()
mock_torch.cuda.is_available.return_value = False
mock_torch.backends.mps.is_available.return_value = True
with patch.dict("sys.modules", {"torch": mock_torch}):
detect_gpu_backend.cache_clear()
result = detect_gpu_backend()
assert result == GpuBackend.MPS
class TestIsRocmArchitectureSupported:
"""Tests for ROCm architecture support check."""
@pytest.mark.parametrize(
"architecture",
list(SUPPORTED_AMD_ARCHITECTURES),
)
def test_supported_architectures(self, architecture: str) -> None:
"""All listed architectures should be supported."""
assert is_rocm_architecture_supported(architecture)
@pytest.mark.parametrize(
"architecture",
["gfx803", "gfx1010", "sm_80", None],
)
def test_unsupported_architectures(self, architecture: str | None) -> None:
"""Unsupported architectures return False."""
assert not is_rocm_architecture_supported(architecture)
def test_override_env_allows_unsupported(self) -> None:
"""HSA_OVERRIDE_GFX_VERSION allows unsupported GPUs."""
with patch.dict("os.environ", {"HSA_OVERRIDE_GFX_VERSION": "11.0.0"}):
assert is_rocm_architecture_supported("gfx803")
Integration Test for Engine Factory
"""Integration tests for ASR engine factory."""
from __future__ import annotations
import pytest
from noteflow.infrastructure.asr.factory import create_asr_engine
from noteflow.infrastructure.gpu.detection import detect_gpu_backend
from noteflow.domain.ports.gpu import GpuBackend
class TestEngineFactory:
"""Integration tests for engine factory."""
def test_cpu_engine_always_works(self) -> None:
"""CPU engine should always be creatable."""
engine = create_asr_engine(device="cpu", compute_type="int8")
assert engine.device == "cpu"
assert engine.compute_type in ("int8", "float32")
def test_auto_device_selects_available(self) -> None:
"""Auto device should select an available backend."""
engine = create_asr_engine(device="auto")
assert engine.device in ("cpu", "cuda", "rocm")
@pytest.mark.skipif(
detect_gpu_backend() != GpuBackend.CUDA,
reason="CUDA not available",
)
def test_cuda_engine_with_cuda(self) -> None:
"""CUDA engine works on NVIDIA hardware."""
engine = create_asr_engine(device="cuda")
assert engine.device == "cuda"
@pytest.mark.skipif(
detect_gpu_backend() != GpuBackend.ROCM,
reason="ROCm not available",
)
def test_rocm_engine_with_rocm(self) -> None:
"""ROCm engine works on AMD hardware."""
engine = create_asr_engine(device="rocm")
assert engine.device in ("rocm", "cuda") # May use "cuda" string internally
def test_pytorch_fallback_when_forced(self) -> None:
"""PyTorch fallback engine works when explicitly requested."""
engine = create_asr_engine(
device="cpu",
prefer_faster_whisper=False,
)
assert engine.device == "cpu"
Configuration Examples
Environment Variables
# Force specific device
NOTEFLOW_ASR_DEVICE=rocm
# Enable ROCm feature (during rollout)
NOTEFLOW_FEATURE_ROCM_ENABLED=true
# ROCm tuning
HSA_OVERRIDE_GFX_VERSION=11.0.0 # Override for unsupported GPUs
HIP_VISIBLE_DEVICES=0 # Limit to first GPU
MIOPEN_FIND_MODE=3 # Fast kernel selection
# Debug
AMD_LOG_LEVEL=3 # Verbose ROCm logging
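A sketch of how these variables could drive the factory at startup. The variable names follow the list above; the helper and the flag policy are assumptions, not a settled design.
"""Illustrative startup wiring driven by the environment variables above."""
from __future__ import annotations

import os

from noteflow.infrastructure.asr.factory import create_asr_engine


def resolve_device_from_env() -> str:
    """Map NOTEFLOW_* variables to a device argument for create_asr_engine()."""
    device = os.environ.get("NOTEFLOW_ASR_DEVICE", "auto")
    rocm_enabled = os.environ.get("NOTEFLOW_FEATURE_ROCM_ENABLED", "false").lower() == "true"
    # While ROCm support is behind a feature flag, an explicit "rocm" request
    # without the flag degrades to auto-detection (illustrative policy only).
    if device == "rocm" and not rocm_enabled:
        return "auto"
    return device


engine = create_asr_engine(device=resolve_device_from_env())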
Docker Compose with ROCm
version: "3.8"
services:
noteflow-rocm:
build:
context: .
dockerfile: docker/Dockerfile.rocm
devices:
- /dev/kfd
- /dev/dri
group_add:
- video
- render
environment:
- NOTEFLOW_ASR_DEVICE=auto
- NOTEFLOW_FEATURE_ROCM_ENABLED=true
volumes:
- ./data:/app/data
ports:
- "50051:50051"
Summary
This architecture enables:
- Transparent ROCm Support: Pure PyTorch components work unchanged
- Swappable ASR Engines: Protocol pattern allows different backends
- Graceful Fallbacks: PyTorch Whisper when native engines unavailable
- Clean Detection: Clear distinction between CUDA and ROCm
- Testability: Mock-friendly design for CI without GPUs