ROCm Support: Architecture & Code Samples

This document provides architecture diagrams and concrete code samples for the ROCm integration.

Alignment note: the current codebase already defines AsrEngine in src/noteflow/infrastructure/asr/protocols.py and uses AsrResult from src/noteflow/infrastructure/asr/dto.py. Prefer extending those rather than adding parallel protocol/DTO types unless we plan a broader layering refactor.


System Architecture

Current State (CUDA Only)

┌─────────────────────────────────────────────────────────────────────┐
│                         gRPC Server                                  │
│  ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────────┐  │
│  │ StreamingMixin  │  │ DiarizationMixin│  │ AsrConfigMixin      │  │
│  └────────┬────────┘  └────────┬────────┘  └──────────┬──────────┘  │
└───────────┼─────────────────────┼─────────────────────┼─────────────┘
            │                     │                     │
            ▼                     ▼                     │
┌───────────────────────┐ ┌───────────────────────┐     │
│ FasterWhisperEngine   │ │ DiarizationEngine     │     │
│ (CUDA/CPU only)       │ │ (CUDA/CPU/MPS)        │     │
│                       │ │                       │     │
│ - device: "cuda"|"cpu"│ │ - device: auto-detect │     │
│ - uses: CTranslate2   │ │ - uses: PyTorch       │     │
└───────────────────────┘ └───────────────────────┘     │
            │                     │                     │
            ▼                     ▼                     ▼
┌───────────────────────────────────────────────────────────────────┐
│                   torch.cuda (CUDA or ROCm HIP build)              │
└───────────────────────────────────────────────────────────────────┘

Target State (Multi-Backend)

┌─────────────────────────────────────────────────────────────────────┐
│                         gRPC Server                                  │
│  ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────────┐  │
│  │ StreamingMixin  │  │ DiarizationMixin│  │ AsrConfigMixin      │  │
│  └────────┬────────┘  └────────┬────────┘  └──────────┬──────────┘  │
└───────────┼─────────────────────┼─────────────────────┼─────────────┘
            │                     │                     │
            ▼                     │                     │
┌───────────────────────────────┐ │  ┌──────────────────────────────┐
│      AsrEngineFactory         │ │  │      AsrEngineManager        │
│  ┌─────────────────────────┐  │ │  │  - detect_cuda_available()  │
│  │ create_asr_engine()     │  │ │  │  - detect_rocm_available()  │ ◄─┘
│  │ - auto-detect backend   │  │ │  │  - build_capabilities()     │
│  │ - fallback logic        │  │ │  └──────────────────────────────┘
│  └─────────────────────────┘  │ │
└───────────────────────────────┘ │
            │                     │
    ┌───────┴───────┬─────────────┤
    │               │             │
    ▼               ▼             ▼
┌──────────┐   ┌──────────┐   ┌─────────────────────┐
│Faster    │   │Faster    │   │WhisperPyTorch       │
│Whisper   │   │Whisper   │   │Engine               │
│Engine    │   │RocmEngine│   │(universal fallback) │
│(CUDA/CPU)│   │(ROCm)    │   │                     │
└──────────┘   └──────────┘   └─────────────────────┘
    │               │             │
    │               │             │
    ▼               ▼             ▼
┌───────────────────────────────────────────────────┐
│              GPU Detection Layer                   │
│  ┌─────────────────────────────────────────────┐  │
│  │ detect_gpu_backend() -> GpuBackend          │  │
│  │ - CUDA: torch.cuda + no HIP                 │  │
│  │ - ROCM: torch.cuda + torch.version.hip      │  │
│  │ - MPS:  torch.backends.mps                  │  │
│  │ - NONE: no GPU available                    │  │
│  └─────────────────────────────────────────────┘  │
└───────────────────────────────────────────────────┘
            │               │             │
            ▼               ▼             ▼
┌───────────────┐   ┌───────────────┐   ┌───────────────┐
│ NVIDIA CUDA   │   │ AMD ROCm/HIP  │   │ CPU/fallback  │
└───────────────┘   └───────────────┘   └───────────────┘

Module Structure

New Modules

src/noteflow/
├── domain/
│   └── ports/
│       ├── gpu.py          # NEW: GpuBackend enum, GpuInfo
│       └── asr.py          # OPTIONAL: only if relocating AsrEngine protocol
│
├── infrastructure/
│   ├── gpu/                # NEW: GPU detection module
│   │   ├── __init__.py
│   │   └── detection.py    # detect_gpu_backend(), get_gpu_info()
│   │
│   └── asr/
│       ├── engine.py       # EXISTING: FasterWhisperEngine (refactored)
│       ├── protocols.py    # EXISTING: AsrEngine protocol (extend)
│       ├── pytorch_engine.py  # NEW: WhisperPyTorchEngine
│       ├── rocm_engine.py     # NEW: FasterWhisperRocmEngine
│       └── factory.py         # NEW: create_asr_engine()

Code Samples

1. Domain Types (domain/ports/gpu.py)

"""GPU backend types and detection protocol."""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Protocol


class GpuBackend(str, Enum):
    """Detected GPU backend type."""

    NONE = "none"
    CUDA = "cuda"
    ROCM = "rocm"
    MPS = "mps"


@dataclass(frozen=True)
class GpuInfo:
    """Information about detected GPU."""

    backend: GpuBackend
    device_name: str
    vram_total_mb: int
    driver_version: str
    architecture: str | None = None  # e.g., "gfx1100" for AMD


class GpuDetectionProtocol(Protocol):
    """Protocol for GPU detection implementations."""

    def detect_backend(self) -> GpuBackend:
        """Detect the available GPU backend."""
        ...

    def get_info(self) -> GpuInfo | None:
        """Get detailed GPU information."""
        ...

    def is_supported_for_asr(self) -> bool:
        """Check if GPU is supported for ASR workloads."""
        ...
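
Because GpuDetectionProtocol is a structural protocol, test doubles need no inheritance. A minimal sketch of a fake detector for unit tests (the FakeGpuDetection name is illustrative, not part of the planned modules):

from noteflow.domain.ports.gpu import GpuBackend, GpuInfo


class FakeGpuDetection:
    """Stub that satisfies GpuDetectionProtocol for tests."""

    def __init__(self, backend: GpuBackend = GpuBackend.NONE) -> None:
        self._backend = backend

    def detect_backend(self) -> GpuBackend:
        return self._backend

    def get_info(self) -> GpuInfo | None:
        if self._backend is GpuBackend.NONE:
            return None
        return GpuInfo(
            backend=self._backend,
            device_name="fake-gpu",
            vram_total_mb=16384,
            driver_version="test",
            architecture="gfx1100" if self._backend is GpuBackend.ROCM else None,
        )

    def is_supported_for_asr(self) -> bool:
        return self._backend in (GpuBackend.CUDA, GpuBackend.ROCM)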

2. ASR Engine Protocol (infrastructure/asr/protocols.py)

"""ASR protocols defining contracts for ASR components."""

from __future__ import annotations

from collections.abc import Iterator
from typing import TYPE_CHECKING, Protocol

if TYPE_CHECKING:
    from pathlib import Path

    import numpy as np
    from numpy.typing import NDArray
    from noteflow.infrastructure.asr.dto import AsrResult


class AsrEngine(Protocol):
    """Protocol for ASR engine implementations.

    All ASR engines must implement this interface to be used
    by the engine manager and gRPC handlers.

    Implementations:
        - FasterWhisperEngine: CUDA/CPU via CTranslate2
        - FasterWhisperRocmEngine: ROCm via CTranslate2-ROCm fork
        - WhisperPyTorchEngine: Universal via openai-whisper
    """

    @property
    def device(self) -> str:
        """Return the requested device ("cpu", "cuda", "rocm")."""
        ...

    @property
    def compute_type(self) -> str:
        """Return the compute precision ("int8", "float16", "float32")."""
        ...

    @property
    def model_size(self) -> str | None:
        """Return the loaded model size, or None if not loaded."""
        ...

    @property
    def is_loaded(self) -> bool:
        """Return True if model is loaded and ready for inference."""
        ...

    def load_model(self, model_size: str = "base") -> None:
        """Load the specified Whisper model."""
        ...

    def unload(self) -> None:
        """Unload the model and free GPU/CPU resources."""
        ...

    def transcribe(
        self,
        audio: NDArray[np.float32],
        *,
        language: str | None = None,
    ) -> Iterator[AsrResult]:
        """Transcribe audio samples.

        Args:
            audio: Audio samples as float32 array, 16kHz mono, normalized to [-1, 1].
            language: Optional BCP-47 language code (auto-detect if None).

        Yields:
            AsrResult for each detected segment.
        """
        ...

    def transcribe_file(
        self,
        audio_path: Path,
        *,
        language: str | None = None,
    ) -> Iterator[AsrResult]:
        """Transcribe audio file.

        Args:
            audio_path: Path to audio file (WAV, MP3, FLAC, etc.)
            language: Optional language code.

        Yields:
            AsrResult for each detected segment.
        """
        ...
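
Callers should depend on this protocol rather than on a concrete engine class. An illustrative helper (hypothetical, not an existing handler) that works with any implementation:

from __future__ import annotations

import numpy as np
from numpy.typing import NDArray

from noteflow.infrastructure.asr.protocols import AsrEngine


def transcribe_to_text(engine: AsrEngine, audio: NDArray[np.float32]) -> str:
    """Run a full transcription and join the segment texts."""
    if not engine.is_loaded:
        engine.load_model("base")
    return " ".join(result.text for result in engine.transcribe(audio)).strip()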

3. GPU Detection (infrastructure/gpu/detection.py)

"""GPU backend detection utilities."""

from __future__ import annotations

import os
from functools import cache

from noteflow.domain.ports.gpu import GpuBackend, GpuInfo
from noteflow.infrastructure.logging import get_logger

logger = get_logger(__name__)

# Example AMD GPU architectures; keep in sync with AMD ROCm support docs
SUPPORTED_AMD_ARCHITECTURES: frozenset[str] = frozenset({
    # CDNA (Instinct)
    "gfx906",  # MI50
    "gfx908",  # MI100
    "gfx90a",  # MI210, MI250
    "gfx942",  # MI300X
    # RDNA 2
    "gfx1030",  # RX 6800, 6900
    # RDNA 3
    "gfx1100",  # RX 7900 XTX
    "gfx1101",  # RX 7900 XT
    "gfx1102",  # RX 7600
})


@cache
def detect_gpu_backend() -> GpuBackend:
    """Detect the available GPU backend.

    Results are cached for performance.

    Returns:
        GpuBackend enum indicating the detected backend.
    """
    try:
        import torch
    except ImportError:
        logger.debug("PyTorch not installed, no GPU backend available")
        return GpuBackend.NONE

    # Check CUDA/ROCm availability
    if torch.cuda.is_available():
        # Distinguish between CUDA and ROCm via HIP version
        if hasattr(torch.version, "hip") and torch.version.hip:
            logger.info("ROCm/HIP backend detected", version=torch.version.hip)
            return GpuBackend.ROCM

        logger.info("CUDA backend detected", version=torch.version.cuda)
        return GpuBackend.CUDA

    # Check Apple Metal Performance Shaders
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        logger.info("MPS backend detected")
        return GpuBackend.MPS

    logger.debug("No GPU backend available, using CPU")
    return GpuBackend.NONE


def get_gpu_info() -> GpuInfo | None:
    """Get detailed GPU information.

    Returns:
        GpuInfo if a GPU is available, None otherwise.
    """
    backend = detect_gpu_backend()

    if backend == GpuBackend.NONE:
        return None

    import torch

    if backend in (GpuBackend.CUDA, GpuBackend.ROCM):
        try:
            props = torch.cuda.get_device_properties(0)
            vram_mb = props.total_memory // (1024 * 1024)

            # Get driver version
            if backend == GpuBackend.ROCM:
                driver_version = str(torch.version.hip) if torch.version.hip else "unknown"
                # On ROCm, props.name may include a gfx ID; parse if present.
                architecture = props.name if props.name.startswith("gfx") else None
            else:
                driver_version = torch.version.cuda or "unknown"
                architecture = f"sm_{props.major}{props.minor}"

            return GpuInfo(
                backend=backend,
                device_name=props.name,
                vram_total_mb=vram_mb,
                driver_version=driver_version,
                architecture=architecture,
            )
        except RuntimeError as e:
            logger.warning("Failed to get GPU properties", error=str(e))
            return None

    if backend == GpuBackend.MPS:
        return GpuInfo(
            backend=backend,
            device_name="Apple Metal",
            vram_total_mb=0,  # MPS doesn't expose VRAM
            driver_version="mps",
            architecture=None,
        )

    return None


def is_rocm_architecture_supported(architecture: str | None) -> bool:
    """Check if AMD GPU architecture is officially supported.

    Args:
        architecture: GPU architecture string (e.g., "gfx1100")

    Returns:
        True if supported, False otherwise.
    """
    if architecture is None:
        return False

    # Check for override (allows unofficial GPUs)
    if os.environ.get("HSA_OVERRIDE_GFX_VERSION"):
        return True

    return architecture in SUPPORTED_AMD_ARCHITECTURES


def is_ctranslate2_rocm_available() -> bool:
    """Check if CTranslate2-ROCm fork is installed.

    Returns:
        True if the ROCm fork is available.
    """
    try:
        import ctranslate2

        # Heuristic: confirm ctranslate2 imports and exposes the expected API.
        # A stock (CUDA-only) build also passes this check, so callers must
        # still handle runtime failures when loading a model on HIP.
        return hasattr(ctranslate2, "get_supported_compute_types")
    except ImportError:
        return False


def get_rocm_environment_info() -> dict[str, str]:
    """Get ROCm-related environment variables for debugging.

    Returns:
        Dictionary of relevant environment variables.
    """
    rocm_vars = [
        "HSA_OVERRIDE_GFX_VERSION",
        "HIP_VISIBLE_DEVICES",
        "ROCM_PATH",
        "MIOPEN_USER_DB_PATH",
        "MIOPEN_FIND_MODE",
        "AMD_LOG_LEVEL",
    ]

    return {var: os.environ.get(var, "") for var in rocm_vars if os.environ.get(var)}
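
A typical consumer of this module logs the detection result once at startup. Illustrative wiring only; the log_gpu_diagnostics helper is hypothetical:

from noteflow.infrastructure.gpu.detection import (
    detect_gpu_backend,
    get_gpu_info,
    get_rocm_environment_info,
)
from noteflow.infrastructure.logging import get_logger

logger = get_logger(__name__)


def log_gpu_diagnostics() -> None:
    """Log the detected backend, device details, and any ROCm env overrides."""
    backend = detect_gpu_backend()
    info = get_gpu_info()
    logger.info("GPU diagnostics", backend=backend.value, gpu_info=info)

    rocm_env = get_rocm_environment_info()
    if rocm_env:
        logger.info("ROCm environment overrides", **rocm_env)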

4. Engine Factory (infrastructure/asr/factory.py)

"""ASR engine factory for backend selection."""

from __future__ import annotations

from typing import TYPE_CHECKING

from noteflow.domain.ports.gpu import GpuBackend
from noteflow.infrastructure.gpu.detection import (
    detect_gpu_backend,
    get_gpu_info,
    is_ctranslate2_rocm_available,
    is_rocm_architecture_supported,
)
from noteflow.infrastructure.logging import get_logger

if TYPE_CHECKING:
    from noteflow.infrastructure.asr.protocols import AsrEngine

logger = get_logger(__name__)


class EngineCreationError(Exception):
    """Raised when ASR engine creation fails."""


def create_asr_engine(
    device: str = "auto",
    compute_type: str = "int8",
    *,
    prefer_faster_whisper: bool = True,
) -> AsrEngine:
    """Create an ASR engine for the specified device.

    This factory handles:
    1. Auto-detection of available GPU backends
    2. Selection of appropriate engine implementation
    3. Fallback to PyTorch Whisper when native engines unavailable

    Args:
        device: Target device ("auto", "cpu", "cuda", "rocm").
        compute_type: Compute precision ("int8", "float16", "float32").
        prefer_faster_whisper: If True, prefer faster-whisper over PyTorch Whisper.
            faster-whisper uses CTranslate2 and is significantly faster.

    Returns:
        An ASR engine implementing AsrEngine.

    Raises:
        EngineCreationError: If engine creation fails.

    Example:
        >>> engine = create_asr_engine(device="auto")
        >>> engine.load_model("base")
        >>> for segment in engine.transcribe(audio):
        ...     print(segment.text)
    """
    resolved_device = _resolve_device(device)

    logger.info(
        "Creating ASR engine",
        requested_device=device,
        resolved_device=resolved_device,
        compute_type=compute_type,
        prefer_faster_whisper=prefer_faster_whisper,
    )

    if resolved_device == "cpu":
        return _create_cpu_engine(compute_type)

    if resolved_device == "cuda":
        return _create_cuda_engine(compute_type, prefer_faster_whisper)

    if resolved_device == "rocm":
        return _create_rocm_engine(compute_type, prefer_faster_whisper)

    msg = f"Unsupported device: {resolved_device}"
    raise EngineCreationError(msg)


def _resolve_device(device: str) -> str:
    """Resolve 'auto' device to actual backend.

    Args:
        device: Requested device string.

    Returns:
        Resolved device string ("cpu", "cuda", or "rocm").
    """
    if device != "auto":
        return device

    backend = detect_gpu_backend()

    if backend == GpuBackend.CUDA:
        return "cuda"

    if backend == GpuBackend.ROCM:
        # Check if ROCm architecture is supported for ASR
        gpu_info = get_gpu_info()
        if gpu_info and is_rocm_architecture_supported(gpu_info.architecture):
            return "rocm"
        else:
            logger.warning(
                "ROCm detected but architecture may not be supported, falling back to CPU",
                architecture=gpu_info.architecture if gpu_info else "unknown",
            )
            return "cpu"

    # MPS not supported by faster-whisper; PyTorch Whisper may work but is untested
    if backend == GpuBackend.MPS:
        logger.info("MPS detected but not supported for ASR, using CPU")

    return "cpu"


def _create_cpu_engine(compute_type: str) -> AsrEngine:
    """Create CPU engine (always uses faster-whisper).

    Args:
        compute_type: Requested compute type.

    Returns:
        ASR engine for CPU.
    """
    from noteflow.infrastructure.asr.engine import FasterWhisperEngine

    # CPU only supports int8 and float32
    if compute_type == "float16":
        logger.debug("float16 not supported on CPU, using float32")
        compute_type = "float32"

    return FasterWhisperEngine(device="cpu", compute_type=compute_type)


def _create_cuda_engine(
    compute_type: str,
    prefer_faster_whisper: bool,
) -> AsrEngine:
    """Create CUDA engine.

    Args:
        compute_type: Compute precision.
        prefer_faster_whisper: Whether to prefer faster-whisper.

    Returns:
        ASR engine for CUDA.
    """
    if prefer_faster_whisper:
        from noteflow.infrastructure.asr.engine import FasterWhisperEngine

        return FasterWhisperEngine(device="cuda", compute_type=compute_type)

    return _create_pytorch_engine("cuda", compute_type)


def _create_rocm_engine(
    compute_type: str,
    prefer_faster_whisper: bool,
) -> AsrEngine:
    """Create ROCm engine.

    Attempts to use CTranslate2-ROCm fork if available,
    falls back to PyTorch Whisper otherwise.

    Args:
        compute_type: Compute precision.
        prefer_faster_whisper: Whether to prefer faster-whisper.

    Returns:
        ASR engine for ROCm.
    """
    if prefer_faster_whisper and is_ctranslate2_rocm_available():
        try:
            from noteflow.infrastructure.asr.rocm_engine import FasterWhisperRocmEngine

            logger.info("Using CTranslate2-ROCm for ASR")
            return FasterWhisperRocmEngine(compute_type=compute_type)
        except ImportError as e:
            logger.warning(
                "CTranslate2-ROCm import failed, falling back to PyTorch Whisper",
                error=str(e),
            )

    logger.info("Using PyTorch Whisper for ROCm ASR")
    return _create_pytorch_engine("cuda", compute_type)  # ROCm uses "cuda" device string


def _create_pytorch_engine(device: str, compute_type: str) -> AsrEngine:
    """Create PyTorch Whisper engine (universal fallback).

    Args:
        device: Target device.
        compute_type: Compute precision.

    Returns:
        PyTorch-based Whisper engine.

    Raises:
        EngineCreationError: If openai-whisper is not installed.
    """
    try:
        from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

        return WhisperPyTorchEngine(device=device, compute_type=compute_type)
    except ImportError as e:
        msg = (
            "Neither CTranslate2 nor openai-whisper is available. "
            "Install one of: pip install faster-whisper OR pip install openai-whisper"
        )
        raise EngineCreationError(msg) from e
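
The FasterWhisperRocmEngine referenced above is not specified elsewhere in this document. A minimal sketch for infrastructure/asr/rocm_engine.py, assuming the CTranslate2-ROCm fork is a drop-in replacement that keeps the upstream faster-whisper API (so HIP devices are still addressed via the "cuda" device string):

"""faster-whisper engine for AMD GPUs via the CTranslate2-ROCm fork (sketch)."""

from __future__ import annotations

from typing import TYPE_CHECKING

from noteflow.infrastructure.logging import get_logger

if TYPE_CHECKING:
    from faster_whisper import WhisperModel

logger = get_logger(__name__)


class FasterWhisperRocmEngine:
    """faster-whisper on ROCm.

    Delegates to the same WhisperModel API as FasterWhisperEngine; the
    ROCm-specific parts are limited to labelling and device handling.
    """

    def __init__(self, compute_type: str = "float16") -> None:
        self._compute_type = compute_type
        self._model: WhisperModel | None = None
        self._model_size: str | None = None

    @property
    def device(self) -> str:
        # Reported to callers as "rocm"; CTranslate2 itself sees "cuda".
        return "rocm"

    @property
    def compute_type(self) -> str:
        return self._compute_type

    @property
    def model_size(self) -> str | None:
        return self._model_size

    @property
    def is_loaded(self) -> bool:
        return self._model is not None

    def load_model(self, model_size: str = "base") -> None:
        from faster_whisper import WhisperModel

        logger.info("Loading faster-whisper (ROCm)", model_size=model_size)
        self._model = WhisperModel(
            model_size,
            device="cuda",  # HIP is exposed through the CUDA device string
            compute_type=self._compute_type,
        )
        self._model_size = model_size

    def unload(self) -> None:
        self._model = None
        self._model_size = None

    # transcribe() / transcribe_file() would mirror FasterWhisperEngine,
    # converting faster-whisper segments into AsrResult values.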

5. PyTorch Whisper Engine (infrastructure/asr/pytorch_engine.py)

"""PyTorch-based Whisper engine (universal fallback)."""

from __future__ import annotations

from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING

from noteflow.infrastructure.asr.dto import AsrResult, WordTiming
from noteflow.infrastructure.logging import get_logger

if TYPE_CHECKING:
    import numpy as np
    from numpy.typing import NDArray

logger = get_logger(__name__)


class WhisperPyTorchEngine:
    """Pure PyTorch Whisper implementation.

    Uses the official openai-whisper package for transcription.
    Works on any PyTorch-supported device (CPU, CUDA, ROCm via HIP).

    This engine is slower than CTranslate2-based engines but provides
    universal compatibility across all GPU backends.
    """

    def __init__(
        self,
        device: str = "cpu",
        compute_type: str = "float32",
    ) -> None:
        """Initialize PyTorch Whisper engine.

        Args:
            device: Target device ("cpu" or "cuda").
                For ROCm, use "cuda" - HIP handles the translation.
            compute_type: Compute precision. Only "float16" and "float32"
                are supported. "int8" will be treated as "float32".
        """
        self._device = device
        self._compute_type = self._normalize_compute_type(compute_type)
        self._model_size: str = ""
        self._model: whisper.Whisper | None = None  # type: ignore[name-defined]

    @staticmethod
    def _normalize_compute_type(compute_type: str) -> str:
        """Normalize compute type for PyTorch.

        PyTorch Whisper doesn't support int8, map to float32.
        """
        if compute_type == "int8":
            logger.debug("int8 not supported in PyTorch Whisper, using float32")
            return "float32"
        return compute_type

    @property
    def device(self) -> str:
        """Return the device this engine runs on."""
        return self._device

    @property
    def compute_type(self) -> str:
        """Return the compute precision."""
        return self._compute_type

    @property
    def model_size(self) -> str:
        """Return the loaded model size."""
        return self._model_size

    @property
    def is_loaded(self) -> bool:
        """Return True if model is loaded."""
        return self._model is not None

    def load_model(self, model_size: str = "base") -> None:
        """Load the specified Whisper model.

        Args:
            model_size: Whisper model size (e.g., "base", "small", "large-v3").
        """
        import whisper

        logger.info(
            "Loading PyTorch Whisper model",
            model_size=model_size,
            device=self._device,
            compute_type=self._compute_type,
        )

        # Load model
        self._model = whisper.load_model(model_size, device=self._device)
        self._model_size = model_size

        # Apply compute type
        if self._compute_type == "float16" and self._device != "cpu":
            self._model = self._model.half()

        logger.info("PyTorch Whisper model loaded successfully")

    def unload(self) -> None:
        """Unload the model and free resources."""
        if self._model is not None:
            import gc

            import torch

            del self._model
            self._model = None
            self._model_size = ""

            # Force garbage collection and clear GPU cache
            gc.collect()
            if self._device != "cpu":
                torch.cuda.empty_cache()

            logger.debug("PyTorch Whisper model unloaded")

    def transcribe(
        self,
        audio: NDArray[np.float32],
        *,
        language: str | None = None,
        initial_prompt: str | None = None,
    ) -> Iterator[AsrResult]:
        """Transcribe audio samples.

        Args:
            audio: Audio samples as float32 array, 16kHz mono.
            language: Optional language code.
            initial_prompt: Optional prompt for context.

        Yields:
            AsrResult for each detected segment.
        """
        if self._model is None:
            msg = "Model not loaded. Call load_model() first."
            raise RuntimeError(msg)

        # Build transcription options
        options: dict[str, object] = {
            "word_timestamps": True,
            "fp16": self._compute_type == "float16" and self._device != "cpu",
        }

        if language is not None:
            options["language"] = language

        if initial_prompt is not None:
            options["initial_prompt"] = initial_prompt

        # Transcribe
        result = self._model.transcribe(audio, **options)

        # Convert to our segment format
        for segment in result["segments"]:
            words = tuple(
                WordTiming(
                    word=w["word"],
                    start=w["start"],
                    end=w["end"],
                    probability=w.get("probability", 0.0),
                )
                for w in segment.get("words", [])
            )

            yield AsrResult(
                text=segment["text"].strip(),
                start=segment["start"],
                end=segment["end"],
                words=words,
                language=result.get("language", "en"),
                avg_logprob=segment.get("avg_logprob", 0.0),
                no_speech_prob=segment.get("no_speech_prob", 0.0),
            )

    def transcribe_file(
        self,
        audio_path: Path,
        *,
        language: str | None = None,
    ) -> Iterator[AsrResult]:
        """Transcribe audio file.

        Args:
            audio_path: Path to audio file.
            language: Optional language code.

        Yields:
            AsrResult for each detected segment.
        """
        import whisper

        # Load audio using whisper's utility
        audio = whisper.load_audio(str(audio_path))

        yield from self.transcribe(audio, language=language)
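
A quick manual check of this engine (assumes openai-whisper is installed and a local sample.wav exists):

from pathlib import Path

from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

engine = WhisperPyTorchEngine(device="cpu", compute_type="float32")
engine.load_model("base")
for result in engine.transcribe_file(Path("sample.wav")):
    print(f"[{result.start:6.2f}-{result.end:6.2f}] {result.text}")
engine.unload()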

6. Updated ASR Device Types (application/services/asr_config/types.py)

"""Types and constants for ASR configuration."""

from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from enum import Enum
from typing import Final
from uuid import UUID


class AsrConfigPhase(str, Enum):
    """Phases of ASR reconfiguration."""

    VALIDATING = "validating"
    DOWNLOADING = "downloading"
    LOADING = "loading"
    COMPLETED = "completed"
    FAILED = "failed"


class AsrDevice(str, Enum):
    """Supported ASR devices."""

    CPU = "cpu"
    CUDA = "cuda"
    ROCM = "rocm"  # NEW


class AsrComputeType(str, Enum):
    """Supported compute types."""

    INT8 = "int8"
    FLOAT16 = "float16"
    FLOAT32 = "float32"


# Compute types available for each device
DEVICE_COMPUTE_TYPES: Final[dict[AsrDevice, tuple[AsrComputeType, ...]]] = {
    AsrDevice.CPU: (AsrComputeType.INT8, AsrComputeType.FLOAT32),
    AsrDevice.CUDA: (
        AsrComputeType.INT8,
        AsrComputeType.FLOAT16,
        AsrComputeType.FLOAT32,
    ),
    AsrDevice.ROCM: (  # NEW
        AsrComputeType.INT8,
        AsrComputeType.FLOAT16,
        AsrComputeType.FLOAT32,
    ),
}


@dataclass
class AsrConfigJob:
    """Tracks ASR reconfiguration job state."""

    job_id: UUID
    status: str
    phase: AsrConfigPhase
    progress_percent: float
    error_message: str
    target_model_size: str
    target_device: AsrDevice
    target_compute_type: AsrComputeType
    task: asyncio.Task[None] | None = field(default=None, repr=False)


@dataclass(frozen=True)
class AsrCapabilities:
    """Current ASR capabilities and configuration."""

    model_size: str | None
    device: AsrDevice
    compute_type: AsrComputeType
    is_ready: bool
    cuda_available: bool
    rocm_available: bool  # NEW
    gpu_backend: str  # NEW: "cuda", "rocm", "mps", or "none"
    available_model_sizes: tuple[str, ...]
    available_compute_types: tuple[AsrComputeType, ...]
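
How these types come together in the AsrEngineManager's build_capabilities() (from the target-state diagram) might look roughly as follows; a sketch only, assuming the manager holds the active engine and the list of model sizes:

from noteflow.domain.ports.gpu import GpuBackend
from noteflow.infrastructure.asr.protocols import AsrEngine
from noteflow.infrastructure.gpu.detection import detect_gpu_backend


def build_capabilities(
    engine: AsrEngine,
    available_model_sizes: tuple[str, ...],
) -> AsrCapabilities:
    """Assemble AsrCapabilities from the active engine and detected backend."""
    backend = detect_gpu_backend()
    device = AsrDevice(engine.device)
    return AsrCapabilities(
        model_size=engine.model_size,
        device=device,
        compute_type=AsrComputeType(engine.compute_type),
        is_ready=engine.is_loaded,
        cuda_available=backend is GpuBackend.CUDA,
        rocm_available=backend is GpuBackend.ROCM,
        gpu_backend=backend.value,
        available_model_sizes=available_model_sizes,
        available_compute_types=DEVICE_COMPUTE_TYPES[device],
    )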

Testing Examples

Unit Test for GPU Detection

"""Tests for GPU detection utilities."""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest

from noteflow.domain.ports.gpu import GpuBackend
from noteflow.infrastructure.gpu.detection import (
    SUPPORTED_AMD_ARCHITECTURES,
    detect_gpu_backend,
    get_gpu_info,
    is_rocm_architecture_supported,
)


class TestDetectGpuBackend:
    """Tests for detect_gpu_backend function."""

    def test_no_torch_returns_none(self) -> None:
        """Return NONE when torch is not installed."""
        with patch.dict("sys.modules", {"torch": None}):
            # Clear cache and reimport
            detect_gpu_backend.cache_clear()
            result = detect_gpu_backend()
            assert result == GpuBackend.NONE

    def test_cuda_without_hip_returns_cuda(self) -> None:
        """Return CUDA when CUDA available and no HIP."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.version.hip = None
        mock_torch.version.cuda = "12.1"

        with patch.dict("sys.modules", {"torch": mock_torch}):
            detect_gpu_backend.cache_clear()
            result = detect_gpu_backend()
            assert result == GpuBackend.CUDA

    def test_cuda_with_hip_returns_rocm(self) -> None:
        """Return ROCM when HIP version is present."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.version.hip = "6.0.0"

        with patch.dict("sys.modules", {"torch": mock_torch}):
            detect_gpu_backend.cache_clear()
            result = detect_gpu_backend()
            assert result == GpuBackend.ROCM

    def test_mps_available_returns_mps(self) -> None:
        """Return MPS on Apple Silicon."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = True

        with patch.dict("sys.modules", {"torch": mock_torch}):
            detect_gpu_backend.cache_clear()
            result = detect_gpu_backend()
            assert result == GpuBackend.MPS


class TestIsRocmArchitectureSupported:
    """Tests for ROCm architecture support check."""

    @pytest.mark.parametrize(
        "architecture",
        list(SUPPORTED_AMD_ARCHITECTURES),
    )
    def test_supported_architectures(self, architecture: str) -> None:
        """All listed architectures should be supported."""
        assert is_rocm_architecture_supported(architecture)

    @pytest.mark.parametrize(
        "architecture",
        ["gfx803", "gfx1010", "sm_80", None],
    )
    def test_unsupported_architectures(self, architecture: str | None) -> None:
        """Unsupported architectures return False."""
        assert not is_rocm_architecture_supported(architecture)

    def test_override_env_allows_unsupported(self) -> None:
        """HSA_OVERRIDE_GFX_VERSION allows unsupported GPUs."""
        with patch.dict("os.environ", {"HSA_OVERRIDE_GFX_VERSION": "11.0.0"}):
            assert is_rocm_architecture_supported("gfx803")

Integration Test for Engine Factory

"""Integration tests for ASR engine factory."""

from __future__ import annotations

import pytest

from noteflow.domain.ports.gpu import GpuBackend
from noteflow.infrastructure.asr.factory import create_asr_engine
from noteflow.infrastructure.gpu.detection import detect_gpu_backend


class TestEngineFactory:
    """Integration tests for engine factory."""

    def test_cpu_engine_always_works(self) -> None:
        """CPU engine should always be creatable."""
        engine = create_asr_engine(device="cpu", compute_type="int8")
        assert engine.device == "cpu"
        assert engine.compute_type in ("int8", "float32")

    def test_auto_device_selects_available(self) -> None:
        """Auto device should select an available backend."""
        engine = create_asr_engine(device="auto")
        assert engine.device in ("cpu", "cuda", "rocm")

    @pytest.mark.skipif(
        detect_gpu_backend() != GpuBackend.CUDA,
        reason="CUDA not available",
    )
    def test_cuda_engine_with_cuda(self) -> None:
        """CUDA engine works on NVIDIA hardware."""
        engine = create_asr_engine(device="cuda")
        assert engine.device == "cuda"

    @pytest.mark.skipif(
        detect_gpu_backend() != GpuBackend.ROCM,
        reason="ROCm not available",
    )
    def test_rocm_engine_with_rocm(self) -> None:
        """ROCm engine works on AMD hardware."""
        engine = create_asr_engine(device="rocm")
        assert engine.device in ("rocm", "cuda")  # May use "cuda" string internally

    def test_pytorch_fallback_when_forced(self) -> None:
        """PyTorch fallback engine works when explicitly requested."""
        engine = create_asr_engine(
            device="cpu",
            prefer_faster_whisper=False,
        )
        assert engine.device == "cpu"

Configuration Examples

Environment Variables

# Force specific device
NOTEFLOW_ASR_DEVICE=rocm

# Enable ROCm feature (during rollout)
NOTEFLOW_FEATURE_ROCM_ENABLED=true

# ROCm tuning
HSA_OVERRIDE_GFX_VERSION=11.0.0  # Override for unsupported GPUs
HIP_VISIBLE_DEVICES=0             # Limit to first GPU
MIOPEN_FIND_MODE=3                # Fast kernel selection

# Debug
AMD_LOG_LEVEL=3                   # Verbose ROCm logging
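
How NOTEFLOW_ASR_DEVICE might be consumed on the server side; illustrative only, since the real settings plumbing is out of scope for this document:

import os

from noteflow.infrastructure.asr.factory import create_asr_engine

device = os.environ.get("NOTEFLOW_ASR_DEVICE", "auto")
engine = create_asr_engine(device=device, compute_type="int8")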

Docker Compose with ROCm

version: "3.8"

services:
  noteflow-rocm:
    build:
      context: .
      dockerfile: docker/Dockerfile.rocm
    devices:
      - /dev/kfd
      - /dev/dri
    group_add:
      - video
      - render
    environment:
      - NOTEFLOW_ASR_DEVICE=auto
      - NOTEFLOW_FEATURE_ROCM_ENABLED=true
    volumes:
      - ./data:/app/data
    ports:
      - "50051:50051"

Summary

This architecture enables:

  1. Transparent ROCm Support: Pure PyTorch components work unchanged
  2. Swappable ASR Engines: Protocol pattern allows different backends
  3. Graceful Fallbacks: PyTorch Whisper when native engines unavailable
  4. Clean Detection: Clear distinction between CUDA and ROCm
  5. Testability: Mock-friendly design for CI without GPUs