diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 56cbd6e..407530b 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -20,6 +20,10 @@ "DISPLAY": ":1", "XDG_RUNTIME_DIR": "/tmp/runtime-vscode" }, + "mounts": [ + "source=${localWorkspaceFolder}/.venv,target=${containerWorkspaceFolder}/.venv,type=bind,consistency=cached" + ], + "updateRemoteUserUID": true, "postCreateCommand": ".devcontainer/postCreate.sh", "remoteUser": "vscode", "customizations": { diff --git a/client/package-lock.json b/client/package-lock.json index b2fe0a1..8a78d1c 100644 --- a/client/package-lock.json +++ b/client/package-lock.json @@ -482,7 +482,6 @@ } ], "license": "MIT", - "peer": true, "engines": { "node": ">=18" }, @@ -524,7 +523,6 @@ } ], "license": "MIT", - "peer": true, "engines": { "node": ">=18" } @@ -4531,7 +4529,8 @@ "version": "5.0.4", "resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.4.tgz", "integrity": "sha512-rfT93uj5s0PRL7EzccGMs3brplhcrghnDoV26NqKhCAS1hVo+WdNsPvE/yb6ilfr5hi2MEk6d5EWJTKdxg8jVw==", - "license": "MIT" + "license": "MIT", + "peer": true }, "node_modules/@types/chai": { "version": "5.2.3", @@ -4665,7 +4664,6 @@ "integrity": "sha512-qm+G8HuG6hOHQigsi7VGuLjUVu6TtBo/F05zvX04Mw2uCg9Dv0Qxy3Qw7j41SidlTcl5D/5yg0SEZqOB+EqZnQ==", "devOptional": true, "license": "MIT", - "peer": true, "dependencies": { "undici-types": "~6.21.0" } @@ -4690,7 +4688,6 @@ "integrity": "sha512-cisd7gxkzjBKU2GgdYrTdtQx1SORymWyaAFhaxQPK9bYO9ot3Y5OikQRvY0VYQtvwjeQnizCINJAenh/V7MK2w==", "devOptional": true, "license": "MIT", - "peer": true, "dependencies": { "@types/prop-types": "*", "csstype": "^3.2.2" @@ -4702,7 +4699,6 @@ "integrity": "sha512-MEe3UeoENYVFXzoXEWsvcpg6ZvlrFNlOQ7EOsvhI3CfAXwzPfO8Qwuxd40nepsYKqyyVQnTdEfv68q91yLcKrQ==", "devOptional": true, "license": "MIT", - "peer": true, "peerDependencies": { "@types/react": "^18.0.0" } @@ -4811,7 +4807,6 @@ "integrity": "sha512-npiaib8XzbjtzS2N4HlqPvlpxpmZ14FjSJrteZpPxGUaYPlvhzlzUZ4mZyABo0EFrOWnvyd0Xxroq//hKhtAWg==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.53.0", "@typescript-eslint/types": "8.53.0", @@ -5276,7 +5271,6 @@ "integrity": "sha512-OmwPKV8c5ecLqo+EkytN7oUeYfNmRI4uOXGIR1ybP7AK5Zz+l9R0dGfoadEuwi1aZXAL0vwuhtq3p0OL3dfqHQ==", "dev": true, "license": "MIT", - "peer": true, "engines": { "node": ">=18.20.0" }, @@ -5331,7 +5325,6 @@ "integrity": "sha512-HdzDrRs+ywAqbXGKqe1i/bLtCv47plz4TvsHFH3j729OooT5VH38ctFn5aLXgECmiAKDkmH/A6kOq2Zh5DIxww==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "chalk": "^5.1.2", "loglevel": "^1.6.0", @@ -5582,7 +5575,6 @@ "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, "license": "MIT", - "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -6091,7 +6083,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "baseline-browser-mapping": "^2.9.0", "caniuse-lite": "^1.0.30001759", @@ -6852,7 +6843,6 @@ "resolved": "https://registry.npmjs.org/date-fns/-/date-fns-3.6.0.tgz", "integrity": "sha512-fRHTG8g/Gif+kSh50gaGEdToemgfj74aRX3swtiouboip5JDLAyDE9F11nHMIcvOaXeOC6D7SpNhi7uFyB7Uww==", "license": "MIT", - "peer": true, "funding": { "type": "github", "url": "https://github.com/sponsors/kossnocorp" @@ -6997,7 +6987,8 @@ "version": "0.5.16", "resolved": "https://registry.npmjs.org/dom-accessibility-api/-/dom-accessibility-api-0.5.16.tgz", "integrity": 
"sha512-X7BJ2yElsnOJ30pZF4uIIDfBEVgF4XEBxL9Bxhy6dnrm5hkzqmsWHGTiHqRiITNhMyFLyAiWndIJP7Z1NTteDg==", - "license": "MIT" + "license": "MIT", + "peer": true }, "node_modules/dom-helpers": { "version": "5.2.1", @@ -7208,8 +7199,7 @@ "version": "8.6.0", "resolved": "https://registry.npmjs.org/embla-carousel/-/embla-carousel-8.6.0.tgz", "integrity": "sha512-SjWyZBHJPbqxHOzckOfo8lHisEaJWmwd23XppYFYVh10bU66/Pn5tkVkbkCMZVdbUE5eTCI2nD8OyIP4Z+uwkA==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/embla-carousel-react": { "version": "8.6.0", @@ -7398,7 +7388,6 @@ "integrity": "sha512-LEyamqS7W5HB3ujJyvi0HQK/dtVINZvd5mAAp9eT5S/ujByGjiZLCzPcHVzuXbpJDJF/cxwHlfceVUDZ2lnSTw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.8.0", "@eslint-community/regexpp": "^4.12.1", @@ -7735,7 +7724,6 @@ "integrity": "sha512-gQHqfI6SmtYBIkTeMizpHThdpXh6ej2Hk68oKZneFM6iu99ZGXvOPnmhd8VDus3xOWhVDDdf4sLsMV2/o+X6Yg==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@vitest/snapshot": "^4.0.16", "deep-eql": "^5.0.2", @@ -9133,7 +9121,6 @@ "resolved": "https://registry.npmjs.org/jiti/-/jiti-1.21.7.tgz", "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==", "license": "MIT", - "peer": true, "bin": { "jiti": "bin/jiti.js" } @@ -9162,7 +9149,6 @@ "resolved": "https://registry.npmjs.org/jsdom/-/jsdom-27.4.0.tgz", "integrity": "sha512-mjzqwWRD9Y1J1KUi7W97Gja1bwOOM5Ug0EZ6UDK3xS7j7mndrkwozHtSblfomlzyB4NepioNt+B2sOSzczVgtQ==", "license": "MIT", - "peer": true, "dependencies": { "@acemir/cssom": "^0.9.28", "@asamuzakjp/dom-selector": "^6.7.6", @@ -9608,6 +9594,7 @@ "resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.5.0.tgz", "integrity": "sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ==", "license": "MIT", + "peer": true, "bin": { "lz-string": "bin/bin.js" } @@ -10572,7 +10559,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -10739,6 +10725,7 @@ "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-27.5.1.tgz", "integrity": "sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==", "license": "MIT", + "peer": true, "dependencies": { "ansi-regex": "^5.0.1", "ansi-styles": "^5.0.0", @@ -10907,7 +10894,6 @@ "resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz", "integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==", "license": "MIT", - "peer": true, "dependencies": { "loose-envify": "^1.1.0" }, @@ -10934,7 +10920,6 @@ "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz", "integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==", "license": "MIT", - "peer": true, "dependencies": { "loose-envify": "^1.1.0", "scheduler": "^0.23.2" @@ -10948,7 +10933,6 @@ "resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.71.1.tgz", "integrity": "sha512-9SUJKCGKo8HUSsCO+y0CtqkqI5nNuaDqTxyqPsZPqIwudpj4rCrAz/jZV+jn57bx5gtZKOh3neQu94DXMc+w5w==", "license": "MIT", - "peer": true, "engines": { "node": ">=18.0.0" }, @@ -10964,7 +10948,8 @@ "version": "17.0.2", "resolved": "https://registry.npmjs.org/react-is/-/react-is-17.0.2.tgz", "integrity": "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==", - "license": "MIT" + "license": 
"MIT", + "peer": true }, "node_modules/react-remove-scroll": { "version": "2.7.2", @@ -12204,7 +12189,6 @@ "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.4.19.tgz", "integrity": "sha512-3ofp+LL8E+pK/JuPLPggVAIaEuhvIz4qNcf3nA1Xn2o/7fb7s/TYpHhwGDv1ZU3PkBluUVaF8PyCHcm48cKLWQ==", "license": "MIT", - "peer": true, "dependencies": { "@alloc/quick-lru": "^5.2.0", "arg": "^5.0.2", @@ -13056,7 +13040,6 @@ "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", "dev": true, "license": "Apache-2.0", - "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -13286,7 +13269,6 @@ "resolved": "https://registry.npmjs.org/vite/-/vite-7.3.1.tgz", "integrity": "sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==", "license": "MIT", - "peer": true, "dependencies": { "esbuild": "^0.27.0", "fdir": "^6.5.0", @@ -14047,7 +14029,6 @@ "integrity": "sha512-Y5y4jpwHvuduUfup+gXTuCU6AROn/k6qOba3st0laFluKHY+q5SHOpQAJdS8acYLwE8caDQ2dXJhmXyxuJrm0Q==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@types/node": "^20.11.30", "@types/sinonjs__fake-timers": "^8.1.5", diff --git a/docs/sprints/phase-5-evolution/sprint-18.5-rocm-support/ARCHITECTURE.md b/docs/sprints/phase-5-evolution/sprint-18.5-rocm-support/ARCHITECTURE.md new file mode 100644 index 0000000..b63a802 --- /dev/null +++ b/docs/sprints/phase-5-evolution/sprint-18.5-rocm-support/ARCHITECTURE.md @@ -0,0 +1,1139 @@ +# ROCm Support: Architecture & Code Samples + +This document provides architecture diagrams and concrete code samples for the ROCm integration. + +Alignment note: the current codebase already defines `AsrEngine` in +`src/noteflow/infrastructure/asr/protocols.py` and uses `AsrResult` from +`src/noteflow/infrastructure/asr/dto.py`. Prefer extending those rather than +adding parallel protocol/DTO types unless we plan a broader layering refactor. 
+ +--- + +## System Architecture + +### Current State (CUDA Only) + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ gRPC Server │ +│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────────┐ │ +│ │ StreamingMixin │ │ DiarizationMixin│ │ AsrConfigMixin │ │ +│ └────────┬────────┘ └────────┬────────┘ └──────────┬──────────┘ │ +└───────────┼─────────────────────┼─────────────────────┼─────────────┘ + │ │ │ + ▼ ▼ │ +┌───────────────────────┐ ┌───────────────────────┐ │ +│ FasterWhisperEngine │ │ DiarizationEngine │ │ +│ (CUDA/CPU only) │ │ (CUDA/CPU/MPS) │ │ +│ │ │ │ │ +│ - device: "cuda"|"cpu"│ │ - device: auto-detect │ │ +│ - uses: CTranslate2 │ │ - uses: PyTorch │ │ +└───────────────────────┘ └───────────────────────┘ │ + │ │ │ + ▼ ▼ ▼ +┌───────────────────────────────────────────────────────────────────┐ +│ torch.cuda (CUDA or ROCm HIP build) │ +└───────────────────────────────────────────────────────────────────┘ +``` + +### Target State (Multi-Backend) + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ gRPC Server │ +│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────────┐ │ +│ │ StreamingMixin │ │ DiarizationMixin│ │ AsrConfigMixin │ │ +│ └────────┬────────┘ └────────┬────────┘ └──────────┬──────────┘ │ +└───────────┼─────────────────────┼─────────────────────┼─────────────┘ + │ │ │ + ▼ │ │ +┌───────────────────────────────┐ │ ┌──────────────────────────────┐ +│ AsrEngineFactory │ │ │ AsrEngineManager │ +│ ┌─────────────────────────┐ │ │ │ - detect_cuda_available() │ +│ │ create_asr_engine() │ │ │ │ - detect_rocm_available() │ ◄─┘ +│ │ - auto-detect backend │ │ │ │ - build_capabilities() │ +│ │ - fallback logic │ │ │ └──────────────────────────────┘ +│ └─────────────────────────┘ │ │ +└───────────────────────────────┘ │ + │ │ + ┌───────┴───────┬─────────────┤ + │ │ │ + ▼ ▼ ▼ +┌─────────┐ ┌─────────┐ ┌─────────────────────┐ +│Faster │ │Faster │ │WhisperPyTorch │ +│Whisper │ │Whisper │ │Engine │ +│Engine │ │RocmEng │ │(universal fallback) │ +│(CUDA/CPU)│ │(ROCm) │ │ │ +└─────────┘ └─────────┘ └─────────────────────┘ + │ │ │ + │ │ │ + ▼ ▼ ▼ +┌───────────────────────────────────────────────────┐ +│ GPU Detection Layer │ +│ ┌─────────────────────────────────────────────┐ │ +│ │ detect_gpu_backend() -> GpuBackend │ │ +│ │ - CUDA: torch.cuda + no HIP │ │ +│ │ - ROCM: torch.cuda + torch.version.hip │ │ +│ │ - MPS: torch.backends.mps │ │ +│ │ - NONE: no GPU available │ │ +│ └─────────────────────────────────────────────┘ │ +└───────────────────────────────────────────────────┘ + │ │ │ + ▼ ▼ ▼ +┌───────────────┐ ┌───────────────┐ ┌───────────────┐ +│ NVIDIA CUDA │ │ AMD ROCm/HIP │ │ CPU/fallback │ +└───────────────┘ └───────────────┘ └───────────────┘ +``` + +--- + +## Module Structure + +### New Modules + +``` +src/noteflow/ +├── domain/ +│ └── ports/ +│ ├── gpu.py # NEW: GpuBackend enum, GpuInfo +│ └── asr.py # OPTIONAL: only if relocating AsrEngine protocol +│ +├── infrastructure/ +│ ├── gpu/ # NEW: GPU detection module +│ │ ├── __init__.py +│ │ └── detection.py # detect_gpu_backend(), get_gpu_info() +│ │ +│ └── asr/ +│ ├── engine.py # EXISTING: FasterWhisperEngine (refactored) +│ ├── protocols.py # EXISTING: AsrEngine protocol (extend) +│ ├── pytorch_engine.py # NEW: WhisperPyTorchEngine +│ ├── rocm_engine.py # NEW: FasterWhisperRocmEngine +│ └── factory.py # NEW: create_asr_engine() +``` + +--- + +## Code Samples + +### 1. 
Domain Types (`domain/ports/gpu.py`) + +```python +"""GPU backend types and detection protocol.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Protocol + + +class GpuBackend(str, Enum): + """Detected GPU backend type.""" + + NONE = "none" + CUDA = "cuda" + ROCM = "rocm" + MPS = "mps" + + +@dataclass(frozen=True) +class GpuInfo: + """Information about detected GPU.""" + + backend: GpuBackend + device_name: str + vram_total_mb: int + driver_version: str + architecture: str | None = None # e.g., "gfx1100" for AMD + + +class GpuDetectionProtocol(Protocol): + """Protocol for GPU detection implementations.""" + + def detect_backend(self) -> GpuBackend: + """Detect the available GPU backend.""" + ... + + def get_info(self) -> GpuInfo | None: + """Get detailed GPU information.""" + ... + + def is_supported_for_asr(self) -> bool: + """Check if GPU is supported for ASR workloads.""" + ... +``` + +### 2. ASR Engine Protocol (`infrastructure/asr/protocols.py`) + +```python +"""ASR protocols defining contracts for ASR components.""" + +from __future__ import annotations + +from collections.abc import Iterator +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + from pathlib import Path + + import numpy as np + from numpy.typing import NDArray + from noteflow.infrastructure.asr.dto import AsrResult + + +class AsrEngine(Protocol): + """Protocol for ASR engine implementations. + + All ASR engines must implement this interface to be used + by the engine manager and gRPC handlers. + + Implementations: + - FasterWhisperEngine: CUDA/CPU via CTranslate2 + - FasterWhisperRocmEngine: ROCm via CTranslate2-ROCm fork + - WhisperPyTorchEngine: Universal via openai-whisper + """ + + @property + def device(self) -> str: + """Return the requested device ("cpu", "cuda", "rocm").""" + ... + + @property + def compute_type(self) -> str: + """Return the compute precision ("int8", "float16", "float32").""" + ... + + @property + def model_size(self) -> str | None: + """Return the loaded model size, or None if not loaded.""" + ... + + @property + def is_loaded(self) -> bool: + """Return True if model is loaded and ready for inference.""" + ... + + def load_model(self, model_size: str = "base") -> None: + """Load the specified Whisper model.""" + ... + + def unload(self) -> None: + """Unload the model and free GPU/CPU resources.""" + ... + + def transcribe( + self, + audio: NDArray[np.float32], + language: str | None = None, + ) -> Iterator[AsrResult]: + """Transcribe audio samples. + + Args: + audio: Audio samples as float32 array, 16kHz mono, normalized to [-1, 1]. + language: Optional BCP-47 language code (auto-detect if None). + + Yields: + AsrResult for each detected segment. + """ + ... + + def transcribe_file( + self, + audio_path: Path, + *, + language: str | None = None, + ) -> Iterator[AsrResult]: + """Transcribe audio file. + + Args: + audio_path: Path to audio file (WAV, MP3, FLAC, etc.) + language: Optional language code. + + Yields: + AsrResult for each detected segment. + """ + ... +``` + +### 3. 
GPU Detection (`infrastructure/gpu/detection.py`) + +```python +"""GPU backend detection utilities.""" + +from __future__ import annotations + +import os +from functools import cache + +from noteflow.domain.ports.gpu import GpuBackend, GpuInfo +from noteflow.infrastructure.logging import get_logger + +logger = get_logger(__name__) + +# Example AMD GPU architectures; keep in sync with AMD ROCm support docs +SUPPORTED_AMD_ARCHITECTURES: frozenset[str] = frozenset({ + # CDNA (Instinct) + "gfx906", # MI50 + "gfx908", # MI100 + "gfx90a", # MI210, MI250 + "gfx942", # MI300X + # RDNA 2 + "gfx1030", # RX 6800, 6900 + # RDNA 3 + "gfx1100", # RX 7900 XTX + "gfx1101", # RX 7900 XT + "gfx1102", # RX 7600 +}) + + +@cache +def detect_gpu_backend() -> GpuBackend: + """Detect the available GPU backend. + + Results are cached for performance. + + Returns: + GpuBackend enum indicating the detected backend. + """ + try: + import torch + except ImportError: + logger.debug("PyTorch not installed, no GPU backend available") + return GpuBackend.NONE + + # Check CUDA/ROCm availability + if torch.cuda.is_available(): + # Distinguish between CUDA and ROCm via HIP version + if hasattr(torch.version, "hip") and torch.version.hip: + logger.info("ROCm/HIP backend detected", version=torch.version.hip) + return GpuBackend.ROCM + + logger.info("CUDA backend detected", version=torch.version.cuda) + return GpuBackend.CUDA + + # Check Apple Metal Performance Shaders + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + logger.info("MPS backend detected") + return GpuBackend.MPS + + logger.debug("No GPU backend available, using CPU") + return GpuBackend.NONE + + +def get_gpu_info() -> GpuInfo | None: + """Get detailed GPU information. + + Returns: + GpuInfo if a GPU is available, None otherwise. + """ + backend = detect_gpu_backend() + + if backend == GpuBackend.NONE: + return None + + import torch + + if backend in (GpuBackend.CUDA, GpuBackend.ROCM): + try: + props = torch.cuda.get_device_properties(0) + vram_mb = props.total_memory // (1024 * 1024) + + # Get driver version + if backend == GpuBackend.ROCM: + driver_version = str(torch.version.hip) if torch.version.hip else "unknown" + # On ROCm, props.name may include a gfx ID; parse if present. + architecture = props.name if props.name.startswith("gfx") else None + else: + driver_version = torch.version.cuda or "unknown" + architecture = f"sm_{props.major}{props.minor}" + + return GpuInfo( + backend=backend, + device_name=props.name, + vram_total_mb=vram_mb, + driver_version=driver_version, + architecture=architecture, + ) + except RuntimeError as e: + logger.warning("Failed to get GPU properties", error=str(e)) + return None + + if backend == GpuBackend.MPS: + return GpuInfo( + backend=backend, + device_name="Apple Metal", + vram_total_mb=0, # MPS doesn't expose VRAM + driver_version="mps", + architecture=None, + ) + + return None + + +def is_rocm_architecture_supported(architecture: str | None) -> bool: + """Check if AMD GPU architecture is officially supported. + + Args: + architecture: GPU architecture string (e.g., "gfx1100") + + Returns: + True if supported, False otherwise. + """ + if architecture is None: + return False + + # Check for override (allows unofficial GPUs) + if os.environ.get("HSA_OVERRIDE_GFX_VERSION"): + return True + + return architecture in SUPPORTED_AMD_ARCHITECTURES + + +def is_ctranslate2_rocm_available() -> bool: + """Check if CTranslate2-ROCm fork is installed. + + Returns: + True if the ROCm fork is available. 
+ """ + try: + import ctranslate2 + + # The ROCm fork should have HIP support + # Check by attempting to create a HIP allocator + return hasattr(ctranslate2, "get_supported_compute_types") + except ImportError: + return False + + +def get_rocm_environment_info() -> dict[str, str]: + """Get ROCm-related environment variables for debugging. + + Returns: + Dictionary of relevant environment variables. + """ + rocm_vars = [ + "HSA_OVERRIDE_GFX_VERSION", + "HIP_VISIBLE_DEVICES", + "ROCM_PATH", + "MIOPEN_USER_DB_PATH", + "MIOPEN_FIND_MODE", + "AMD_LOG_LEVEL", + ] + + return {var: os.environ.get(var, "") for var in rocm_vars if os.environ.get(var)} +``` + +### 4. Engine Factory (`infrastructure/asr/factory.py`) + +```python +"""ASR engine factory for backend selection.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from noteflow.domain.ports.gpu import GpuBackend +from noteflow.infrastructure.gpu.detection import ( + detect_gpu_backend, + get_gpu_info, + is_ctranslate2_rocm_available, + is_rocm_architecture_supported, +) +from noteflow.infrastructure.logging import get_logger + +if TYPE_CHECKING: + from noteflow.infrastructure.asr.protocols import AsrEngine + +logger = get_logger(__name__) + + +class EngineCreationError(Exception): + """Raised when ASR engine creation fails.""" + + +def create_asr_engine( + device: str = "auto", + compute_type: str = "int8", + *, + prefer_faster_whisper: bool = True, +) -> AsrEngine: + """Create an ASR engine for the specified device. + + This factory handles: + 1. Auto-detection of available GPU backends + 2. Selection of appropriate engine implementation + 3. Fallback to PyTorch Whisper when native engines unavailable + + Args: + device: Target device ("auto", "cpu", "cuda", "rocm"). + compute_type: Compute precision ("int8", "float16", "float32"). + prefer_faster_whisper: If True, prefer faster-whisper over PyTorch Whisper. + faster-whisper uses CTranslate2 and is significantly faster. + + Returns: + An ASR engine implementing AsrEngine. + + Raises: + EngineCreationError: If engine creation fails. + + Example: + >>> engine = create_asr_engine(device="auto") + >>> engine.load_model("base") + >>> for segment in engine.transcribe(audio): + ... print(segment.text) + """ + resolved_device = _resolve_device(device) + + logger.info( + "Creating ASR engine", + requested_device=device, + resolved_device=resolved_device, + compute_type=compute_type, + prefer_faster_whisper=prefer_faster_whisper, + ) + + if resolved_device == "cpu": + return _create_cpu_engine(compute_type) + + if resolved_device == "cuda": + return _create_cuda_engine(compute_type, prefer_faster_whisper) + + if resolved_device == "rocm": + return _create_rocm_engine(compute_type, prefer_faster_whisper) + + msg = f"Unsupported device: {resolved_device}" + raise EngineCreationError(msg) + + +def _resolve_device(device: str) -> str: + """Resolve 'auto' device to actual backend. + + Args: + device: Requested device string. + + Returns: + Resolved device string ("cpu", "cuda", or "rocm"). 
+ """ + if device != "auto": + return device + + backend = detect_gpu_backend() + + if backend == GpuBackend.CUDA: + return "cuda" + + if backend == GpuBackend.ROCM: + # Check if ROCm architecture is supported for ASR + gpu_info = get_gpu_info() + if gpu_info and is_rocm_architecture_supported(gpu_info.architecture): + return "rocm" + else: + logger.warning( + "ROCm detected but architecture may not be supported, falling back to CPU", + architecture=gpu_info.architecture if gpu_info else "unknown", + ) + return "cpu" + + # MPS not supported by faster-whisper; PyTorch Whisper may work but is untested + if backend == GpuBackend.MPS: + logger.info("MPS detected but not supported for ASR, using CPU") + + return "cpu" + + +def _create_cpu_engine(compute_type: str) -> AsrEngine: + """Create CPU engine (always uses faster-whisper). + + Args: + compute_type: Requested compute type. + + Returns: + ASR engine for CPU. + """ + from noteflow.infrastructure.asr.engine import FasterWhisperEngine + + # CPU only supports int8 and float32 + if compute_type == "float16": + logger.debug("float16 not supported on CPU, using float32") + compute_type = "float32" + + return FasterWhisperEngine(device="cpu", compute_type=compute_type) + + +def _create_cuda_engine( + compute_type: str, + prefer_faster_whisper: bool, +) -> AsrEngine: + """Create CUDA engine. + + Args: + compute_type: Compute precision. + prefer_faster_whisper: Whether to prefer faster-whisper. + + Returns: + ASR engine for CUDA. + """ + if prefer_faster_whisper: + from noteflow.infrastructure.asr.engine import FasterWhisperEngine + + return FasterWhisperEngine(device="cuda", compute_type=compute_type) + + return _create_pytorch_engine("cuda", compute_type) + + +def _create_rocm_engine( + compute_type: str, + prefer_faster_whisper: bool, +) -> AsrEngine: + """Create ROCm engine. + + Attempts to use CTranslate2-ROCm fork if available, + falls back to PyTorch Whisper otherwise. + + Args: + compute_type: Compute precision. + prefer_faster_whisper: Whether to prefer faster-whisper. + + Returns: + ASR engine for ROCm. + """ + if prefer_faster_whisper and is_ctranslate2_rocm_available(): + try: + from noteflow.infrastructure.asr.rocm_engine import FasterWhisperRocmEngine + + logger.info("Using CTranslate2-ROCm for ASR") + return FasterWhisperRocmEngine(compute_type=compute_type) + except ImportError as e: + logger.warning( + "CTranslate2-ROCm import failed, falling back to PyTorch Whisper", + error=str(e), + ) + + logger.info("Using PyTorch Whisper for ROCm ASR") + return _create_pytorch_engine("cuda", compute_type) # ROCm uses "cuda" device string + + +def _create_pytorch_engine(device: str, compute_type: str) -> AsrEngine: + """Create PyTorch Whisper engine (universal fallback). + + Args: + device: Target device. + compute_type: Compute precision. + + Returns: + PyTorch-based Whisper engine. + + Raises: + EngineCreationError: If openai-whisper is not installed. + """ + try: + from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine + + return WhisperPyTorchEngine(device=device, compute_type=compute_type) + except ImportError as e: + msg = ( + "Neither CTranslate2 nor openai-whisper is available. " + "Install one of: pip install faster-whisper OR pip install openai-whisper" + ) + raise EngineCreationError(msg) from e +``` + +### 5. 
PyTorch Whisper Engine (`infrastructure/asr/pytorch_engine.py`) + +```python +"""PyTorch-based Whisper engine (universal fallback).""" + +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path +from typing import TYPE_CHECKING + +from noteflow.infrastructure.asr.dto import AsrResult, WordTiming +from noteflow.infrastructure.logging import get_logger + +if TYPE_CHECKING: + import numpy as np + from numpy.typing import NDArray + +logger = get_logger(__name__) + + +class WhisperPyTorchEngine: + """Pure PyTorch Whisper implementation. + + Uses the official openai-whisper package for transcription. + Works on any PyTorch-supported device (CPU, CUDA, ROCm via HIP). + + This engine is slower than CTranslate2-based engines but provides + universal compatibility across all GPU backends. + """ + + def __init__( + self, + device: str = "cpu", + compute_type: str = "float32", + ) -> None: + """Initialize PyTorch Whisper engine. + + Args: + device: Target device ("cpu" or "cuda"). + For ROCm, use "cuda" - HIP handles the translation. + compute_type: Compute precision. Only "float16" and "float32" + are supported. "int8" will be treated as "float32". + """ + self._device = device + self._compute_type = self._normalize_compute_type(compute_type) + self._model_size: str = "" + self._model: whisper.Whisper | None = None # type: ignore[name-defined] + + @staticmethod + def _normalize_compute_type(compute_type: str) -> str: + """Normalize compute type for PyTorch. + + PyTorch Whisper doesn't support int8, map to float32. + """ + if compute_type == "int8": + logger.debug("int8 not supported in PyTorch Whisper, using float32") + return "float32" + return compute_type + + @property + def device(self) -> str: + """Return the device this engine runs on.""" + return self._device + + @property + def compute_type(self) -> str: + """Return the compute precision.""" + return self._compute_type + + @property + def model_size(self) -> str: + """Return the loaded model size.""" + return self._model_size + + @property + def is_loaded(self) -> bool: + """Return True if model is loaded.""" + return self._model is not None + + def load_model(self, model_size: str) -> None: + """Load the specified Whisper model. + + Args: + model_size: Whisper model size (e.g., "base", "small", "large-v3"). + """ + import torch + import whisper + + logger.info( + "Loading PyTorch Whisper model", + model_size=model_size, + device=self._device, + compute_type=self._compute_type, + ) + + # Load model + self._model = whisper.load_model(model_size, device=self._device) + self._model_size = model_size + + # Apply compute type + if self._compute_type == "float16" and self._device != "cpu": + self._model = self._model.half() + + logger.info("PyTorch Whisper model loaded successfully") + + def unload(self) -> None: + """Unload the model and free resources.""" + if self._model is not None: + import gc + + import torch + + del self._model + self._model = None + self._model_size = "" + + # Force garbage collection and clear GPU cache + gc.collect() + if self._device != "cpu": + torch.cuda.empty_cache() + + logger.debug("PyTorch Whisper model unloaded") + + def transcribe( + self, + audio: NDArray[np.float32], + *, + language: str | None = None, + initial_prompt: str | None = None, + ) -> Iterator[AsrResult]: + """Transcribe audio samples. + + Args: + audio: Audio samples as float32 array, 16kHz mono. + language: Optional language code. + initial_prompt: Optional prompt for context. 
+ + Yields: + AsrResult for each detected segment. + """ + if self._model is None: + msg = "Model not loaded. Call load_model() first." + raise RuntimeError(msg) + + # Build transcription options + options: dict[str, object] = { + "word_timestamps": True, + "fp16": self._compute_type == "float16" and self._device != "cpu", + } + + if language is not None: + options["language"] = language + + if initial_prompt is not None: + options["initial_prompt"] = initial_prompt + + # Transcribe + result = self._model.transcribe(audio, **options) + + # Convert to our segment format + for segment in result["segments"]: + words = tuple( + WordTiming( + word=w["word"], + start=w["start"], + end=w["end"], + probability=w.get("probability", 0.0), + ) + for w in segment.get("words", []) + ) + + yield AsrResult( + text=segment["text"].strip(), + start=segment["start"], + end=segment["end"], + words=words, + language=result.get("language", "en"), + avg_logprob=segment.get("avg_logprob", 0.0), + no_speech_prob=segment.get("no_speech_prob", 0.0), + ) + + def transcribe_file( + self, + audio_path: Path, + *, + language: str | None = None, + ) -> Iterator[AsrResult]: + """Transcribe audio file. + + Args: + audio_path: Path to audio file. + language: Optional language code. + + Yields: + AsrResult for each detected segment. + """ + import whisper + + # Load audio using whisper's utility + audio = whisper.load_audio(str(audio_path)) + + yield from self.transcribe(audio, language=language) +``` + +### 6. Updated ASR Device Types (`application/services/asr_config/types.py`) + +```python +"""Types and constants for ASR configuration.""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from enum import Enum +from typing import Final +from uuid import UUID + + +class AsrConfigPhase(str, Enum): + """Phases of ASR reconfiguration.""" + + VALIDATING = "validating" + DOWNLOADING = "downloading" + LOADING = "loading" + COMPLETED = "completed" + FAILED = "failed" + + +class AsrDevice(str, Enum): + """Supported ASR devices.""" + + CPU = "cpu" + CUDA = "cuda" + ROCM = "rocm" # NEW + + +class AsrComputeType(str, Enum): + """Supported compute types.""" + + INT8 = "int8" + FLOAT16 = "float16" + FLOAT32 = "float32" + + +# Compute types available for each device +DEVICE_COMPUTE_TYPES: Final[dict[AsrDevice, tuple[AsrComputeType, ...]]] = { + AsrDevice.CPU: (AsrComputeType.INT8, AsrComputeType.FLOAT32), + AsrDevice.CUDA: ( + AsrComputeType.INT8, + AsrComputeType.FLOAT16, + AsrComputeType.FLOAT32, + ), + AsrDevice.ROCM: ( # NEW + AsrComputeType.INT8, + AsrComputeType.FLOAT16, + AsrComputeType.FLOAT32, + ), +} + + +@dataclass +class AsrConfigJob: + """Tracks ASR reconfiguration job state.""" + + job_id: UUID + status: str + phase: AsrConfigPhase + progress_percent: float + error_message: str + target_model_size: str + target_device: AsrDevice + target_compute_type: AsrComputeType + task: asyncio.Task[None] | None = field(default=None, repr=False) + + +@dataclass(frozen=True) +class AsrCapabilities: + """Current ASR capabilities and configuration.""" + + model_size: str | None + device: AsrDevice + compute_type: AsrComputeType + is_ready: bool + cuda_available: bool + rocm_available: bool # NEW + gpu_backend: str # NEW: "cuda", "rocm", "mps", or "none" + available_model_sizes: tuple[str, ...] + available_compute_types: tuple[AsrComputeType, ...] 
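+
+
+# Sketch only (not part of the existing module): a validation helper the config
+# service could build on the DEVICE_COMPUTE_TYPES mapping above. The name is
+# hypothetical; check_configuration() may implement this differently.
+def is_compute_type_supported(device: AsrDevice, compute_type: AsrComputeType) -> bool:
+    """Return True if the compute type is valid for the given device."""
+    return compute_type in DEVICE_COMPUTE_TYPES.get(device, ())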
+``` + +--- + +## Testing Examples + +### Unit Test for GPU Detection + +```python +"""Tests for GPU detection utilities.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from noteflow.domain.ports.gpu import GpuBackend +from noteflow.infrastructure.gpu.detection import ( + SUPPORTED_AMD_ARCHITECTURES, + detect_gpu_backend, + get_gpu_info, + is_rocm_architecture_supported, +) + + +class TestDetectGpuBackend: + """Tests for detect_gpu_backend function.""" + + def test_no_torch_returns_none(self) -> None: + """Return NONE when torch is not installed.""" + with patch.dict("sys.modules", {"torch": None}): + # Clear cache and reimport + detect_gpu_backend.cache_clear() + result = detect_gpu_backend() + assert result == GpuBackend.NONE + + def test_cuda_without_hip_returns_cuda(self) -> None: + """Return CUDA when CUDA available and no HIP.""" + mock_torch = MagicMock() + mock_torch.cuda.is_available.return_value = True + mock_torch.version.hip = None + mock_torch.version.cuda = "12.1" + + with patch.dict("sys.modules", {"torch": mock_torch}): + detect_gpu_backend.cache_clear() + result = detect_gpu_backend() + assert result == GpuBackend.CUDA + + def test_cuda_with_hip_returns_rocm(self) -> None: + """Return ROCM when HIP version is present.""" + mock_torch = MagicMock() + mock_torch.cuda.is_available.return_value = True + mock_torch.version.hip = "6.0.0" + + with patch.dict("sys.modules", {"torch": mock_torch}): + detect_gpu_backend.cache_clear() + result = detect_gpu_backend() + assert result == GpuBackend.ROCM + + def test_mps_available_returns_mps(self) -> None: + """Return MPS on Apple Silicon.""" + mock_torch = MagicMock() + mock_torch.cuda.is_available.return_value = False + mock_torch.backends.mps.is_available.return_value = True + + with patch.dict("sys.modules", {"torch": mock_torch}): + detect_gpu_backend.cache_clear() + result = detect_gpu_backend() + assert result == GpuBackend.MPS + + +class TestIsRocmArchitectureSupported: + """Tests for ROCm architecture support check.""" + + @pytest.mark.parametrize( + "architecture", + list(SUPPORTED_AMD_ARCHITECTURES), + ) + def test_supported_architectures(self, architecture: str) -> None: + """All listed architectures should be supported.""" + assert is_rocm_architecture_supported(architecture) + + @pytest.mark.parametrize( + "architecture", + ["gfx803", "gfx1010", "sm_80", None], + ) + def test_unsupported_architectures(self, architecture: str | None) -> None: + """Unsupported architectures return False.""" + assert not is_rocm_architecture_supported(architecture) + + def test_override_env_allows_unsupported(self) -> None: + """HSA_OVERRIDE_GFX_VERSION allows unsupported GPUs.""" + with patch.dict("os.environ", {"HSA_OVERRIDE_GFX_VERSION": "11.0.0"}): + assert is_rocm_architecture_supported("gfx803") +``` + +### Integration Test for Engine Factory + +```python +"""Integration tests for ASR engine factory.""" + +from __future__ import annotations + +import pytest + +from noteflow.infrastructure.asr.factory import create_asr_engine +from noteflow.infrastructure.gpu.detection import detect_gpu_backend +from noteflow.domain.ports.gpu import GpuBackend + + +class TestEngineFactory: + """Integration tests for engine factory.""" + + def test_cpu_engine_always_works(self) -> None: + """CPU engine should always be creatable.""" + engine = create_asr_engine(device="cpu", compute_type="int8") + assert engine.device == "cpu" + assert engine.compute_type in ("int8", "float32") + + def 
test_auto_device_selects_available(self) -> None: + """Auto device should select an available backend.""" + engine = create_asr_engine(device="auto") + assert engine.device in ("cpu", "cuda", "rocm") + + @pytest.mark.skipif( + detect_gpu_backend() != GpuBackend.CUDA, + reason="CUDA not available", + ) + def test_cuda_engine_with_cuda(self) -> None: + """CUDA engine works on NVIDIA hardware.""" + engine = create_asr_engine(device="cuda") + assert engine.device == "cuda" + + @pytest.mark.skipif( + detect_gpu_backend() != GpuBackend.ROCM, + reason="ROCm not available", + ) + def test_rocm_engine_with_rocm(self) -> None: + """ROCm engine works on AMD hardware.""" + engine = create_asr_engine(device="rocm") + assert engine.device in ("rocm", "cuda") # May use "cuda" string internally + + def test_pytorch_fallback_when_forced(self) -> None: + """PyTorch fallback engine works when explicitly requested.""" + engine = create_asr_engine( + device="cpu", + prefer_faster_whisper=False, + ) + assert engine.device == "cpu" +``` + +--- + +## Configuration Examples + +### Environment Variables + +```bash +# Force specific device +NOTEFLOW_ASR_DEVICE=rocm + +# Enable ROCm feature (during rollout) +NOTEFLOW_FEATURE_ROCM_ENABLED=true + +# ROCm tuning +HSA_OVERRIDE_GFX_VERSION=11.0.0 # Override for unsupported GPUs +HIP_VISIBLE_DEVICES=0 # Limit to first GPU +MIOPEN_FIND_MODE=3 # Fast kernel selection + +# Debug +AMD_LOG_LEVEL=3 # Verbose ROCm logging +``` + +### Docker Compose with ROCm + +```yaml +version: "3.8" + +services: + noteflow-rocm: + build: + context: . + dockerfile: docker/Dockerfile.rocm + devices: + - /dev/kfd + - /dev/dri + group_add: + - video + - render + environment: + - NOTEFLOW_ASR_DEVICE=auto + - NOTEFLOW_FEATURE_ROCM_ENABLED=true + volumes: + - ./data:/app/data + ports: + - "50051:50051" +``` + +--- + +## Summary + +This architecture enables: + +1. **Transparent ROCm Support**: Pure PyTorch components work unchanged +2. **Swappable ASR Engines**: Protocol pattern allows different backends +3. **Graceful Fallbacks**: PyTorch Whisper when native engines unavailable +4. **Clean Detection**: Clear distinction between CUDA and ROCm +5. **Testability**: Mock-friendly design for CI without GPUs diff --git a/docs/sprints/phase-5-evolution/sprint-18.5-rocm-support/IMPLEMENTATION_CHECKLIST.md b/docs/sprints/phase-5-evolution/sprint-18.5-rocm-support/IMPLEMENTATION_CHECKLIST.md new file mode 100644 index 0000000..89beb82 --- /dev/null +++ b/docs/sprints/phase-5-evolution/sprint-18.5-rocm-support/IMPLEMENTATION_CHECKLIST.md @@ -0,0 +1,282 @@ +# ROCm Support Implementation Checklist + +This checklist tracks the implementation progress for Sprint 18.5. 
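+
+For reference while working through Phase 4.1 below, a minimal sketch of how the
+rollout flag could be read (the helper name and accepted truthy values are
+assumptions; the real `_features.py` may expose this differently):
+
+```python
+import os
+
+
+def rocm_feature_enabled() -> bool:
+    """Return True when the NOTEFLOW_FEATURE_ROCM_ENABLED rollout flag is set."""
+    value = os.environ.get("NOTEFLOW_FEATURE_ROCM_ENABLED", "")
+    return value.strip().lower() in {"1", "true", "yes"}
+```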
+ +--- + +## Phase 1: Device Abstraction Layer + +### 1.1 GPU Detection Module + +- [ ] Create `src/noteflow/infrastructure/gpu/__init__.py` +- [ ] Create `src/noteflow/infrastructure/gpu/detection.py` + - [ ] Implement `GpuBackend` enum (NONE, CUDA, ROCM, MPS) + - [ ] Implement `GpuInfo` dataclass + - [ ] Implement `detect_gpu_backend()` function + - [ ] Implement `get_gpu_info()` function + - [ ] Add ROCm version detection via `torch.version.hip` +- [ ] Create `tests/infrastructure/gpu/test_detection.py` + - [ ] Test no-torch case + - [ ] Test CUDA detection + - [ ] Test ROCm detection (HIP check) + - [ ] Test MPS detection + - [ ] Test CPU fallback + +### 1.2 Domain Types + +- [ ] Create `src/noteflow/domain/ports/gpu.py` + - [ ] Export `GpuBackend` enum + - [ ] Export `GpuInfo` type + - [ ] Define `GpuDetectionProtocol` + +### 1.3 ASR Device Types + +- [ ] Update `src/noteflow/application/services/asr_config/types.py` + - [ ] Add `ROCM = "rocm"` to `AsrDevice` enum + - [ ] Add ROCm entry to `DEVICE_COMPUTE_TYPES` mapping + - [ ] Update `AsrCapabilities` dataclass with `rocm_available` and `gpu_backend` fields + +### 1.4 Diarization Device Mixin + +- [ ] Update `src/noteflow/infrastructure/diarization/engine/_device_mixin.py` + - [ ] Add ROCm detection in `_detect_available_device()` + - [ ] Maintain backward compatibility with "cuda" device string + +### 1.5 System Metrics + +- [ ] Update `src/noteflow/infrastructure/metrics/system_resources.py` + - [ ] Handle ROCm VRAM queries (same API as CUDA via HIP) + - [ ] Add `gpu_backend` field to metrics + +### 1.6 gRPC Proto + +- [ ] Update `src/noteflow/grpc/proto/noteflow.proto` + - [ ] Add `ASR_DEVICE_ROCM = 3` to `AsrDevice` enum + - [ ] Add `rocm_available` field to `AsrConfiguration` + - [ ] Add `gpu_backend` field to `AsrConfiguration` +- [ ] Regenerate Python stubs +- [ ] Run `scripts/patch_grpc_stubs.py` + +### 1.7 Phase 1 Tests + +- [ ] Run `pytest tests/infrastructure/gpu/` +- [ ] Run `make quality-py` +- [ ] Verify no regressions in CUDA detection + +--- + +## Phase 2: ASR Engine Protocol + +### 2.1 Engine Protocol Definition + +- [ ] Extend `src/noteflow/infrastructure/asr/protocols.py` (or relocate to `domain/ports`) + - [ ] Reuse `AsrResult` / `WordTiming` from `infrastructure/asr/dto.py` + - [ ] Add `device` property (logical device: cpu/cuda/rocm) + - [ ] Add `compute_type` property + - [ ] Confirm `model_size` + `is_loaded` already covered + - [ ] Add optional `transcribe_file()` helper (if needed) + +### 2.2 Refactor FasterWhisperEngine + +- [ ] Update `src/noteflow/infrastructure/asr/engine.py` + - [ ] Ensure compliance with `AsrEngine` + - [ ] Add explicit type annotations + - [ ] Document as CUDA/CPU backend +- [ ] Create `tests/infrastructure/asr/test_protocol_compliance.py` + - [ ] Verify `FasterWhisperEngine` implements protocol + +### 2.3 PyTorch Whisper Engine (Fallback) + +- [ ] Create `src/noteflow/infrastructure/asr/pytorch_engine.py` + - [ ] Implement `WhisperPyTorchEngine` class + - [ ] Implement all protocol methods + - [ ] Handle device placement (cuda/rocm/cpu) + - [ ] Support all compute types +- [ ] Create `tests/infrastructure/asr/test_pytorch_engine.py` + - [ ] Test model loading + - [ ] Test transcription + - [ ] Test device handling + +### 2.4 Engine Factory + +- [ ] Create `src/noteflow/infrastructure/asr/factory.py` + - [ ] Implement `create_asr_engine()` function + - [ ] Implement `_resolve_device()` helper + - [ ] Implement `_create_cpu_engine()` helper + - [ ] Implement `_create_cuda_engine()` 
helper + - [ ] Implement `_create_rocm_engine()` helper + - [ ] Define `EngineCreationError` exception +- [ ] Create `tests/infrastructure/asr/test_factory.py` + - [ ] Test auto device resolution + - [ ] Test explicit device selection + - [ ] Test fallback behavior + - [ ] Test error cases + +### 2.5 Update Engine Manager + +- [ ] Update `src/noteflow/application/services/asr_config/_engine_manager.py` + - [ ] Add `detect_rocm_available()` method + - [ ] Update `build_capabilities()` for ROCm + - [ ] Update `check_configuration()` for ROCm validation + - [ ] Use factory for engine creation in `build_engine_for_job()` +- [ ] Update `tests/application/test_asr_config_service.py` + - [ ] Add ROCm detection tests + - [ ] Add ROCm validation tests + +### 2.6 Phase 2 Tests + +- [ ] Run full ASR test suite +- [ ] Run `make quality-py` +- [ ] Verify CUDA path unchanged + +--- + +## Phase 3: ROCm-Specific Engine + +### 3.1 ROCm Engine Implementation + +- [ ] Create `src/noteflow/infrastructure/asr/rocm_engine.py` + - [ ] Implement `FasterWhisperRocmEngine` class + - [ ] Handle CTranslate2-ROCm import with fallback + - [ ] Implement all protocol methods + - [ ] Add ROCm-specific optimizations +- [ ] Create `tests/infrastructure/asr/test_rocm_engine.py` + - [ ] Test import fallback behavior + - [ ] Test engine creation (mock) + - [ ] Test protocol compliance + +### 3.2 Update Factory for ROCm + +- [ ] Update `src/noteflow/infrastructure/asr/factory.py` + - [ ] Add ROCm engine import with graceful fallback + - [ ] Log warning when falling back to PyTorch +- [ ] Update factory tests for ROCm path + +### 3.3 ROCm Installation Detection + +- [ ] Update `src/noteflow/infrastructure/gpu/detection.py` + - [ ] Add `is_ctranslate2_rocm_available()` function + - [ ] Add `get_rocm_version()` function +- [ ] Add corresponding tests + +### 3.4 Phase 3 Tests + +- [ ] Run ROCm-specific tests (skip if no ROCm) +- [ ] Run `make quality-py` +- [ ] Test on AMD hardware (if available) + +--- + +## Phase 4: Configuration & Distribution + +### 4.1 Feature Flag + +- [ ] Update `src/noteflow/config/settings/_features.py` + - [ ] Add `NOTEFLOW_FEATURE_ROCM_ENABLED` flag + - [ ] Document in settings +- [ ] Update any feature flag guards + +### 4.2 gRPC Config Handlers + +- [ ] Update `src/noteflow/grpc/mixins/asr_config.py` + - [ ] Handle ROCm device in `GetAsrConfiguration()` + - [ ] Handle ROCm device in `UpdateAsrConfiguration()` + - [ ] Add ROCm to capabilities response +- [ ] Update tests in `tests/grpc/test_asr_config.py` + +### 4.3 Dependencies + +- [ ] Update `pyproject.toml` + - [ ] Add `rocm` extras group + - [ ] Add `openai-whisper` as optional dependency + - [ ] Document ROCm installation in comments +- [ ] Create `requirements-rocm.txt` (optional) + +### 4.4 Docker ROCm Image + +- [ ] Create `docker/Dockerfile.rocm` + - [ ] Base on `rocm/pytorch` image + - [ ] Install NoteFlow with ROCm extras + - [ ] Configure for GPU access +- [ ] Update `compose.yaml` (and/or add `compose.rocm.yaml`) with ROCm profile +- [ ] Test Docker image build + +### 4.5 Documentation + +- [ ] Create `docs/installation/rocm.md` + - [ ] System requirements + - [ ] PyTorch ROCm installation + - [ ] CTranslate2-ROCm installation (optional) + - [ ] Docker usage + - [ ] Troubleshooting +- [ ] Update main README with ROCm section +- [ ] Update `CLAUDE.md` with ROCm notes + +### 4.6 Phase 4 Tests + +- [ ] Run full test suite +- [ ] Run `make quality` +- [ ] Build ROCm Docker image +- [ ] Test on AMD hardware + +--- + +## Final Validation + +### 
Quality Gates + +- [ ] `pytest tests/quality/` passes +- [ ] `make quality-py` passes +- [ ] `make quality` passes (full stack) +- [ ] Proto regenerated correctly +- [ ] No type errors (`basedpyright`) +- [ ] No lint errors (`ruff`) + +### Functional Validation + +- [ ] CUDA path works (no regression) +- [ ] CPU path works (no regression) +- [ ] ROCm detection works +- [ ] PyTorch fallback works +- [ ] gRPC configuration works +- [ ] Device switching works + +### Documentation + +- [ ] Sprint README complete +- [ ] Implementation checklist complete +- [ ] Installation guide complete +- [ ] API documentation updated + +--- + +## Notes + +### Files Created + +| File | Status | +|------|--------| +| `src/noteflow/domain/ports/gpu.py` | ❌ | +| `src/noteflow/domain/ports/asr.py` | optional (only if relocating protocol) | +| `src/noteflow/infrastructure/gpu/__init__.py` | ❌ | +| `src/noteflow/infrastructure/gpu/detection.py` | ❌ | +| `src/noteflow/infrastructure/asr/pytorch_engine.py` | ❌ | +| `src/noteflow/infrastructure/asr/rocm_engine.py` | ❌ | +| `src/noteflow/infrastructure/asr/factory.py` | ❌ | +| `docker/Dockerfile.rocm` | ❌ | +| `docs/installation/rocm.md` | ❌ | + +### Files Modified + +| File | Status | +|------|--------| +| `application/services/asr_config/types.py` | ❌ | +| `application/services/asr_config/_engine_manager.py` | ❌ | +| `infrastructure/diarization/engine/_device_mixin.py` | ❌ | +| `infrastructure/metrics/system_resources.py` | ❌ | +| `infrastructure/asr/engine.py` | ❌ | +| `infrastructure/asr/protocols.py` | ❌ | +| `grpc/proto/noteflow.proto` | ❌ | +| `grpc/mixins/asr_config.py` | ❌ | +| `config/settings/_features.py` | ❌ | +| `pyproject.toml` | ❌ | diff --git a/docs/sprints/phase-5-evolution/sprint-18.5-rocm-support/README.md b/docs/sprints/phase-5-evolution/sprint-18.5-rocm-support/README.md new file mode 100644 index 0000000..4ffb776 --- /dev/null +++ b/docs/sprints/phase-5-evolution/sprint-18.5-rocm-support/README.md @@ -0,0 +1,786 @@ +# Sprint 18.5: ROCm GPU Backend Support + +> **Size**: L | **Owner**: Backend | **Prerequisites**: None +> **Phase**: 5 - Platform Evolution + +--- + +## Validation Status (2025-01-17) + +### Research Complete — Implementation Ready + +Note: Hardware/driver compatibility and ROCm wheel availability are time-sensitive. +Re-verify against AMD ROCm compatibility matrices and PyTorch ROCm install guidance +before implementation. + +### Repo Alignment Notes (current tree) + +- ASR protocol already exists at `src/noteflow/infrastructure/asr/protocols.py` and + returns `AsrResult` from `src/noteflow/infrastructure/asr/dto.py`. Extend these + instead of adding parallel `domain/ports/asr.py` types unless we plan a broader + layering refactor. +- gRPC mixins live under `src/noteflow/grpc/mixins/` (not `_mixins`). +- Tests live under `tests/infrastructure/` and `tests/application/` (no `tests/unit/`). 
+ +| Prerequisite | Status | Impact | +|--------------|--------|--------| +| PyTorch ROCm support | ✅ Available | PyTorch HIP layer works with existing `torch.cuda` API | +| CTranslate2 ROCm support | ⚠️ Community fork | No official support; requires alternative engine strategy | +| pyannote.audio ROCm support | ✅ Available | Pure PyTorch, works out of box | +| diart ROCm support | ✅ Available | Pure PyTorch, works out of box | + +| Component | Status | Notes | +|-----------|--------|-------| +| Device abstraction layer | ❌ Not implemented | Need `GpuBackend` enum | +| ASR engine protocol | ⚠️ Partial | AsrEngine exists; extend with device/compute metadata | +| ROCm detection | ❌ Not implemented | Need `torch.version.hip` check | +| gRPC proto updates | ❌ Not implemented | Need `ASR_DEVICE_ROCM` | +| PyTorch Whisper fallback | ❌ Not implemented | Fallback for universal compatibility | + +**Action required**: Implement device abstraction layer and engine protocol pattern. + +--- + +## Objective + +Enable NoteFlow backend to run on AMD GPUs via ROCm, providing GPU-accelerated ASR and diarization for users without NVIDIA hardware. This extends hardware support while maintaining full CUDA compatibility. + +--- + +## Feasibility Analysis + +### Component Compatibility Matrix + +| Component | Library | CUDA | ROCm | Notes | +|-----------|---------|------|------|-------| +| ASR (Speech-to-Text) | faster-whisper (CTranslate2) | ✅ Official | ⚠️ Fork | [CTranslate2-ROCm](https://github.com/arlo-phoenix/CTranslate2-rocm) | +| Diarization (Streaming) | diart (PyTorch) | ✅ | ✅* | HIP layer transparent; validate in our stack | +| Diarization (Offline) | pyannote.audio (PyTorch) | ✅ | ✅ | [AMD tested](https://rocm.blogs.amd.com/artificial-intelligence/speech_models/README.html) | +| NER | spaCy transformers (PyTorch) | ✅ | ✅ | HIP layer transparent | +| Metrics/VRAM | torch.cuda.* | ✅ | ✅ | Works via HIP | + +### Hardware Support (Source of Truth) + +ROCm GPU support is version- and distro-specific. Use AMD's official compatibility +matrices and the Radeon ROCm guide as the source of truth. + +- Datacenter (Instinct/MI): Refer to the ROCm compatibility matrices for exact + ROCm release + GPU + OS combinations. +- Radeon/Workstation: Official Radeon ROCm support is narrower; check the latest + Radeon ROCm guide for the exact RX series and Linux distro support. Older Radeon + generations are often not listed and may require overrides or community builds. + +### Key Technical Insights + +1. **ROCm HIP Layer**: PyTorch's ROCm build [reuses the `torch.cuda` interface](https://docs.pytorch.org/docs/stable/notes/hip.html). Existing code using `torch.cuda.is_available()` and `torch.device("cuda")` works without modification (there is no `torch.device("rocm")`). + +2. **CTranslate2 Blocker**: faster-whisper's [CTranslate2](https://github.com/OpenNMT/CTranslate2) dependency has no official ROCm support. Options: + - Use [CTranslate2-ROCm fork](https://github.com/arlo-phoenix/CTranslate2-rocm) (community maintained) + - Implement PyTorch Whisper fallback (official, slower) + - Use [insanely-fast-whisper-rocm](https://github.com/beecave-homelab/insanely-fast-whisper-rocm) (community project) + +3. 
**Detection Logic**: ROCm can be detected via: + ```python + if torch.cuda.is_available() and hasattr(torch.version, 'hip') and torch.version.hip: + # ROCm/HIP backend + ``` + +--- + +## What Already Exists + +### Device Detection Infrastructure + +| Asset | Location | Reuse Strategy | +|-------|----------|----------------| +| CUDA detection | `application/services/asr_config/_engine_manager.py:61-72` | Extend with ROCm check | +| Device preference | `infrastructure/diarization/engine/_device_mixin.py:18-36` | Add ROCm to detection cascade | +| Device types | `application/services/asr_config/types.py:22-27` | Add `ROCM` enum value | +| Compute types | `application/services/asr_config/types.py:37-44` | Add ROCm compute types | +| ASR protocol | `infrastructure/asr/protocols.py` | Extend existing `AsrEngine` (avoid duplicate protocol) | +| ASR DTOs | `infrastructure/asr/dto.py` | Reuse `AsrResult`/`WordTiming` (avoid duplicate types) | + +### ASR Engine Architecture + +| Asset | Location | Reuse Strategy | +|-------|----------|----------------| +| FasterWhisperEngine | `infrastructure/asr/engine.py` | Keep as CUDA/CPU backend | +| AsrEngine protocol (existing) | `infrastructure/asr/protocols.py` | Extend with device/compute metadata | +| Engine manager | `application/services/asr_config/_engine_manager.py` | Extend for engine selection | + +### gRPC Configuration + +| Asset | Location | Reuse Strategy | +|-------|----------|----------------| +| ASR device enum | `grpc/proto/noteflow.proto` | Add `ASR_DEVICE_ROCM` | +| ASR config messages | `grpc/proto/noteflow.proto` | Add `rocm_available` field | +| Config RPC handlers | `grpc/mixins/asr_config.py` | Extend for ROCm | + +--- + +## Architecture Decision: Engine Protocol Pattern + +### Decision + +Extend the existing `AsrEngine` protocol (or move it to `domain/ports` if we want a +cleaner layering boundary) to abstract ASR backends and enable swappable implementations. + +### Rationale + +1. **Separation of Concerns**: Hardware-specific code isolated in engine implementations +2. **Testability**: Mock engines for testing without GPU +3. **Extensibility**: Future backends (TPU, Apple Neural Engine) follow same pattern +4. 
**Gradual Migration**: Existing code continues working while new engines added + +### Engine Hierarchy + +``` +AsrEngine (Protocol) +├── FasterWhisperEngine (CUDA/CPU) — existing, extend protocol +├── FasterWhisperRocmEngine (ROCm) — new, uses CTranslate2-ROCm fork +└── WhisperPyTorchEngine (Universal) — new, pure PyTorch fallback +``` + +--- + +## Scope + +### Phase 1: Device Abstraction Layer (M) + +| Task | Effort | Files | +|------|--------|-------| +| Add `GpuBackend` enum | S | `domain/ports/gpu.py` (new) | +| Add ROCm detection utility | S | `infrastructure/gpu/detection.py` (new) | +| Extend `AsrDevice` enum | S | `application/services/asr_config/types.py` | +| Update device mixin | S | `infrastructure/diarization/engine/_device_mixin.py` | +| Update gRPC proto | S | `grpc/proto/noteflow.proto` | + +### Phase 2: ASR Engine Protocol (M) + +| Task | Effort | Files | +|------|--------|-------| +| Extend `AsrEngine` protocol (or relocate) | M | `infrastructure/asr/protocols.py` (extend) or `domain/ports/asr.py` (move) | +| Refactor `FasterWhisperEngine` to protocol | M | `infrastructure/asr/engine.py` | +| Create `WhisperPyTorchEngine` fallback | M | `infrastructure/asr/pytorch_engine.py` (new) | +| Create engine factory | S | `infrastructure/asr/factory.py` (new) | +| Update engine manager | M | `application/services/asr_config/_engine_manager.py` | + +### Phase 3: ROCm-Specific Engine (M) + +| Task | Effort | Files | +|------|--------|-------| +| Create `FasterWhisperRocmEngine` | M | `infrastructure/asr/rocm_engine.py` (new) | +| Add conditional CTranslate2-ROCm import | S | `infrastructure/asr/rocm_engine.py` | +| Update engine factory for ROCm | S | `infrastructure/asr/factory.py` | +| Add ROCm installation detection | S | `infrastructure/gpu/detection.py` | + +### Phase 4: Configuration & Distribution (L) + +| Task | Effort | Files | +|------|--------|-------| +| Add ROCm feature flag | S | `config/settings/_features.py` | +| Update gRPC config handlers | M | `grpc/mixins/asr_config.py` | +| Create ROCm Docker image | M | `docker/Dockerfile.rocm` (new) | +| Update pyproject.toml extras | S | `pyproject.toml` | +| Add ROCm installation docs | M | `docs/installation/rocm.md` (new) | + +--- + +## Deliverables + +### New Files + +| File | Purpose | +|------|---------| +| `src/noteflow/domain/ports/gpu.py` | GPU backend types and detection protocol | +| `src/noteflow/infrastructure/gpu/__init__.py` | GPU package init | +| `src/noteflow/infrastructure/gpu/detection.py` | GPU detection utilities | +| `src/noteflow/infrastructure/asr/pytorch_engine.py` | Pure PyTorch Whisper engine | +| `src/noteflow/infrastructure/asr/rocm_engine.py` | ROCm-specific faster-whisper engine | +| `src/noteflow/infrastructure/asr/factory.py` | Engine factory for backend selection | +| `docker/Dockerfile.rocm` | ROCm-enabled Docker image | +| `docs/installation/rocm.md` | ROCm setup documentation | + +### Modified Files + +| File | Changes | +|------|---------| +| `application/services/asr_config/types.py` | Add `ROCM` to `AsrDevice`, ROCm compute types | +| `application/services/asr_config/_engine_manager.py` | Use engine factory, extend detection | +| `infrastructure/diarization/engine/_device_mixin.py` | Add ROCm to detection cascade | +| `infrastructure/metrics/system_resources.py` | Handle ROCm VRAM queries | +| `infrastructure/asr/protocols.py` | Extend `AsrEngine` with device/compute_type metadata | +| `grpc/proto/noteflow.proto` | Add `ASR_DEVICE_ROCM`, `rocm_available`, `gpu_backend` | +| 
`grpc/mixins/asr_config.py` | Map ROCm device and capabilities | +| `config/settings/_features.py` | Add `NOTEFLOW_FEATURE_ROCM_ENABLED` | +| `pyproject.toml` | Add `rocm` extras group | + +--- + +## Implementation Details + +### 1. GPU Detection Module + +**File**: `src/noteflow/infrastructure/gpu/detection.py` + +```python +"""GPU backend detection utilities.""" + +from __future__ import annotations + +from noteflow.domain.ports.gpu import GpuBackend, GpuInfo + + +def detect_gpu_backend() -> GpuBackend: + """Detect the available GPU backend. + + Returns: + GpuBackend enum indicating the detected backend. + """ + try: + import torch + except ImportError: + return GpuBackend.NONE + + if not torch.cuda.is_available(): + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + return GpuBackend.MPS + return GpuBackend.NONE + + # Check if this is ROCm masquerading as CUDA + if hasattr(torch.version, 'hip') and torch.version.hip: + return GpuBackend.ROCM + + return GpuBackend.CUDA + + +def get_gpu_info() -> GpuInfo | None: + """Get detailed GPU information. + + Returns: + GpuInfo if a GPU is available, None otherwise. + """ + backend = detect_gpu_backend() + + if backend == GpuBackend.NONE: + return None + + import torch + + if backend in (GpuBackend.CUDA, GpuBackend.ROCM): + props = torch.cuda.get_device_properties(0) + vram_mb = props.total_memory // (1024 * 1024) + + driver_version = "" + if backend == GpuBackend.ROCM: + driver_version = getattr(torch.version, 'hip', 'unknown') + else: + driver_version = torch.version.cuda or 'unknown' + + return GpuInfo( + backend=backend, + device_name=props.name, + vram_total_mb=vram_mb, + driver_version=driver_version, + ) + + if backend == GpuBackend.MPS: + return GpuInfo( + backend=backend, + device_name="Apple Metal", + vram_total_mb=0, # MPS doesn't expose VRAM + driver_version="mps", + ) + + return None +``` + +### 2. ASR Engine Protocol + +**File**: `src/noteflow/infrastructure/asr/protocols.py` (extend) + +```python +"""ASR protocols defining contracts for ASR components.""" + +from __future__ import annotations + +from collections.abc import Iterator +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + from pathlib import Path + + import numpy as np + from numpy.typing import NDArray + from noteflow.infrastructure.asr.dto import AsrResult + + +class AsrEngine(Protocol): + """Protocol for ASR transcription engine implementations.""" + + @property + def device(self) -> str: + """Return the requested device ("cpu", "cuda", "rocm").""" + ... + + @property + def compute_type(self) -> str: + """Return the compute type (int8, float16, float32).""" + ... + + @property + def model_size(self) -> str | None: + """Return the loaded model size, or None if not loaded.""" + ... + + @property + def is_loaded(self) -> bool: + """Return True if model is loaded.""" + ... + + def load_model(self, model_size: str = "base") -> None: + """Load the ASR model.""" + ... + + def transcribe( + self, + audio: NDArray[np.float32], + language: str | None = None, + ) -> Iterator[AsrResult]: + """Transcribe audio data.""" + ... + + def unload(self) -> None: + """Unload the model and free resources.""" + ... + + def transcribe_file( + self, + audio_path: Path, + *, + language: str | None = None, + ) -> Iterator[AsrResult]: + """Optional helper for file-based transcription.""" + ... +``` + +### 3. 
Engine Factory + +**File**: `src/noteflow/infrastructure/asr/factory.py` + +```python +"""ASR engine factory for backend selection.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from noteflow.infrastructure.gpu.detection import GpuBackend, detect_gpu_backend +from noteflow.infrastructure.logging import get_logger + +if TYPE_CHECKING: + from noteflow.infrastructure.asr.protocols import AsrEngine + +logger = get_logger(__name__) + + +class EngineCreationError(Exception): + """Raised when engine creation fails.""" + + +def create_asr_engine( + device: str = "auto", + compute_type: str = "int8", + *, + prefer_faster_whisper: bool = True, +) -> AsrEngine: + """Create an ASR engine for the specified device. + + Args: + device: Target device ("auto", "cpu", "cuda", "rocm"). + compute_type: Compute precision ("int8", "float16", "float32"). + prefer_faster_whisper: If True, prefer faster-whisper over PyTorch. + + Returns: + An ASR engine implementing AsrEngine. + + Raises: + EngineCreationError: If engine creation fails. + """ + resolved_device = _resolve_device(device) + + logger.info( + "Creating ASR engine", + device=device, + resolved_device=resolved_device, + compute_type=compute_type, + ) + + if resolved_device == "cpu": + return _create_cpu_engine(compute_type) + + if resolved_device == "cuda": + return _create_cuda_engine(compute_type, prefer_faster_whisper) + + if resolved_device == "rocm": + return _create_rocm_engine(compute_type, prefer_faster_whisper) + + msg = f"Unsupported device: {resolved_device}" + raise EngineCreationError(msg) + + +def _resolve_device(device: str) -> str: + """Resolve 'auto' device to actual backend.""" + if device != "auto": + return device + + backend = detect_gpu_backend() + + if backend == GpuBackend.CUDA: + return "cuda" + if backend == GpuBackend.ROCM: + return "rocm" + + return "cpu" + + +def _create_cpu_engine(compute_type: str) -> AsrEngine: + """Create CPU engine (always uses faster-whisper).""" + from noteflow.infrastructure.asr.engine import FasterWhisperEngine + + # CPU only supports int8 and float32 + if compute_type == "float16": + compute_type = "float32" + + return FasterWhisperEngine(device="cpu", compute_type=compute_type) + + +def _create_cuda_engine( + compute_type: str, + prefer_faster_whisper: bool, +) -> AsrEngine: + """Create CUDA engine.""" + if prefer_faster_whisper: + from noteflow.infrastructure.asr.engine import FasterWhisperEngine + return FasterWhisperEngine(device="cuda", compute_type=compute_type) + + from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine + return WhisperPyTorchEngine(device="cuda", compute_type=compute_type) + + +def _create_rocm_engine( + compute_type: str, + prefer_faster_whisper: bool, +) -> AsrEngine: + """Create ROCm engine. + + Attempts to use CTranslate2-ROCm fork if available, + falls back to PyTorch Whisper otherwise. + """ + if prefer_faster_whisper: + try: + from noteflow.infrastructure.asr.rocm_engine import FasterWhisperRocmEngine + return FasterWhisperRocmEngine(compute_type=compute_type) + except ImportError: + logger.warning( + "CTranslate2-ROCm not installed, falling back to PyTorch Whisper" + ) + + from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine + return WhisperPyTorchEngine(device="cuda", compute_type=compute_type) +``` + +### 4. 
Extended Device Types + +**File**: `src/noteflow/application/services/asr_config/types.py` (modifications) + +```python +class AsrDevice(str, Enum): + """Supported ASR devices.""" + + CPU = "cpu" + CUDA = "cuda" + ROCM = "rocm" + + +DEVICE_COMPUTE_TYPES: Final[dict[AsrDevice, tuple[AsrComputeType, ...]]] = { + AsrDevice.CPU: (AsrComputeType.INT8, AsrComputeType.FLOAT32), + AsrDevice.CUDA: ( + AsrComputeType.INT8, + AsrComputeType.FLOAT16, + AsrComputeType.FLOAT32, + ), + AsrDevice.ROCM: ( + AsrComputeType.INT8, + AsrComputeType.FLOAT16, + AsrComputeType.FLOAT32, + ), +} +``` + +### 5. Updated Device Mixin + +**File**: `src/noteflow/infrastructure/diarization/engine/_device_mixin.py` (modifications) + +```python +def _detect_available_device(self) -> str: + """Detect the best available device for computation.""" + if self._device_preference != "auto": + return self._device_preference + + import torch + + if torch.cuda.is_available(): + # ROCm uses the same torch.cuda device string; keep "cuda" here + # and expose backend (cuda vs rocm) via a separate detection helper. + return "cuda" + + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + return "mps" + + return "cpu" +``` + +### 6. Proto Updates + +**File**: `src/noteflow/grpc/proto/noteflow.proto` (additions) + +```protobuf +enum AsrDevice { + ASR_DEVICE_UNSPECIFIED = 0; + ASR_DEVICE_CPU = 1; + ASR_DEVICE_CUDA = 2; + ASR_DEVICE_ROCM = 3; // NEW +} + +message AsrConfiguration { + // ... existing fields ... + bool rocm_available = 8; // NEW: ROCm/HIP detected + string gpu_backend = 9; // NEW: "cuda", "rocm", "mps", or "none" +} +``` + +Note: `gpu_backend` reflects detected hardware/runtime, while `device` reflects the +configured ASR target (which may still be "cuda" even on ROCm). 
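+
+To make the mapping concrete, here is a minimal sketch of how the handler in
+`grpc/mixins/asr_config.py` could populate the new fields from the detection
+helper. Only `AsrConfiguration`, `rocm_available`, `gpu_backend`, `GpuBackend`,
+and `detect_gpu_backend` come from the plan above; the helper name and the way
+the response message is assembled are hypothetical.
+
+```python
+"""Illustrative sketch only -- not the final handler implementation."""
+
+from noteflow.infrastructure.gpu.detection import GpuBackend, detect_gpu_backend
+
+# Backend enum -> wire string for the new gpu_backend field.
+_BACKEND_NAMES = {
+    GpuBackend.CUDA: "cuda",
+    GpuBackend.ROCM: "rocm",
+    GpuBackend.MPS: "mps",
+    GpuBackend.NONE: "none",
+}
+
+
+def _populate_gpu_fields(config) -> None:
+    """Fill the new ROCm fields on an AsrConfiguration message.
+
+    Hypothetical helper; the real handler in grpc/mixins/asr_config.py may
+    assemble the response differently. This only shows the intended field
+    semantics.
+    """
+    backend = detect_gpu_backend()
+
+    # gpu_backend: detected hardware/runtime, independent of the configured device
+    config.gpu_backend = _BACKEND_NAMES[backend]
+
+    # rocm_available: True only when torch's ROCm/HIP build sees a GPU
+    config.rocm_available = backend == GpuBackend.ROCM
+```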
+ +--- + +## Quality Gates + +- [ ] `pytest tests/quality/` passes +- [ ] `make quality-py` passes +- [ ] Proto regenerated and stubs patched +- [ ] ROCm detection works on AMD hardware +- [ ] PyTorch Whisper fallback functional +- [ ] CUDA path unchanged (no regression) +- [ ] Docker ROCm image builds successfully +- [ ] Meets `docs/sprints/QUALITY_STANDARDS.md` + +--- + +## Test Plan + +### Unit Tests + +**File**: `tests/infrastructure/gpu/test_detection.py` + +```python +import pytest +from unittest.mock import patch, MagicMock + +from noteflow.infrastructure.gpu.detection import ( + GpuBackend, + detect_gpu_backend, + get_gpu_info, +) + + +class TestDetectGpuBackend: + """Tests for GPU backend detection.""" + + def test_no_torch_returns_none(self) -> None: + """Return NONE when torch not installed.""" + with patch.dict("sys.modules", {"torch": None}): + # Force reimport + from importlib import reload + from noteflow.infrastructure.gpu import detection + reload(detection) + assert detection.detect_gpu_backend() == GpuBackend.NONE + + def test_cuda_available_returns_cuda(self) -> None: + """Return CUDA when CUDA available.""" + mock_torch = MagicMock() + mock_torch.cuda.is_available.return_value = True + mock_torch.version.hip = None + + with patch.dict("sys.modules", {"torch": mock_torch}): + assert detect_gpu_backend() == GpuBackend.CUDA + + def test_hip_available_returns_rocm(self) -> None: + """Return ROCM when HIP version present.""" + mock_torch = MagicMock() + mock_torch.cuda.is_available.return_value = True + mock_torch.version.hip = "6.0.0" + + with patch.dict("sys.modules", {"torch": mock_torch}): + assert detect_gpu_backend() == GpuBackend.ROCM + + def test_mps_available_returns_mps(self) -> None: + """Return MPS on Apple Silicon.""" + mock_torch = MagicMock() + mock_torch.cuda.is_available.return_value = False + mock_torch.backends.mps.is_available.return_value = True + + with patch.dict("sys.modules", {"torch": mock_torch}): + assert detect_gpu_backend() == GpuBackend.MPS +``` + +### Integration Tests + +**File**: `tests/integration/test_engine_factory.py` + +```python +import pytest +from unittest.mock import patch + +from noteflow.infrastructure.asr.factory import create_asr_engine + + +@pytest.mark.integration +class TestEngineFactory: + """Integration tests for ASR engine factory.""" + + def test_auto_device_selects_available(self) -> None: + """Auto device resolves to available backend.""" + engine = create_asr_engine(device="auto") + # Note: ROCm may still report "cuda" at the torch device level. 
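+        # The ROCm-specific engine reports "rocm"; the PyTorch fallback created for ROCm reports "cuda".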
+ assert engine.device in ("cpu", "cuda", "rocm") + + def test_cpu_engine_always_works(self) -> None: + """CPU engine works regardless of GPU.""" + engine = create_asr_engine(device="cpu") + assert engine.device == "cpu" + + @pytest.mark.skipif( + not _rocm_available(), + reason="ROCm not available", + ) + def test_rocm_engine_creation(self) -> None: + """ROCm engine creates on AMD hardware.""" + engine = create_asr_engine(device="rocm") + assert engine.device == "rocm" + + +def _rocm_available() -> bool: + """Check if ROCm is available for testing.""" + try: + import torch + return ( + torch.cuda.is_available() + and hasattr(torch.version, 'hip') + and torch.version.hip is not None + ) + except ImportError: + return False +``` + +--- + +## Failure Modes + +| Failure | Detection | Recovery | +|---------|-----------|----------| +| CTranslate2-ROCm not installed | Import error in factory | Fall back to PyTorch Whisper | +| ROCm driver version mismatch | HIP error at runtime | Log error, suggest driver update | +| GPU OOM on large model | CUDA/HIP OOM error | Reduce model size or use CPU | +| Unsupported GPU architecture | "invalid device function" | Fall back to CPU with warning | +| PyTorch ROCm not installed | Import error | CPU-only mode | + +--- + +## Dependencies + +### Python Packages + +| Package | Version | Purpose | Notes | +|---------|---------|---------|-------| +| `torch` | >=2.0 | PyTorch with ROCm | Version must match ROCm runtime | +| `torchaudio` | >=2.0 | Audio processing | Version must match ROCm runtime | +| `faster-whisper` | >=1.0 | CUDA/CPU ASR | Existing | +| `ctranslate2-rocm` | (fork) | ROCm ASR | Optional, install from GitHub fork | +| `openai-whisper` | >=20231117 | PyTorch fallback | Optional (slower) | + +### System Requirements + +| Requirement | Minimum | Notes | +|-------------|---------|-------| +| ROCm | Per AMD docs | Must match PyTorch ROCm wheel version | +| Linux kernel | Per AMD docs | Depends on ROCm release + distro | +| AMD GPU driver | amdgpu | Included with ROCm packages | + +--- + +## Installation Guide + +### PyTorch ROCm Installation + +```bash +# Example (replace with your installed ROCm version) +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm +``` + +### CTranslate2-ROCm Installation (Optional) + +```bash +# Install from fork for faster-whisper ROCm support +pip install git+https://github.com/arlo-phoenix/CTranslate2-rocm.git +pip install faster-whisper +``` + +### Docker ROCm Image + +```dockerfile +# Example base image (pin to a ROCm/PyTorch combo that matches your host) +FROM rocm/pytorch:rocm6.x_ubuntu22.04_py3.10_pytorch_2.x.x + +WORKDIR /app + +# Install NoteFlow +COPY . . 
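+
+# The rocm/pytorch base image already ships a ROCm-enabled PyTorch build;
+# the ".[rocm]" extra (added in Phase 4) is assumed to cover only NoteFlow's
+# optional ROCm dependencies, such as the CTranslate2-ROCm fork.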
+RUN pip install -e ".[rocm]" + +CMD ["python", "-m", "noteflow.grpc.server"] +``` + +--- + +## Post-Sprint + +- ROCm performance benchmarking vs CUDA +- Multi-GPU support (device index selection) +- ROCm-specific memory optimization +- Apple Neural Engine support investigation +- Windows ROCm support (if/when officially available) + +--- + +## References + +### Official Documentation + +- [PyTorch HIP Semantics](https://docs.pytorch.org/docs/stable/notes/hip.html) +- [ROCm Compatibility Matrix](https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html) +- [AMD ROCm PyTorch Installation](https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/install/installrad/native_linux/install-pytorch.html) + +### AMD Blog Posts + +- [Speech-to-Text on AMD GPU with Whisper](https://rocm.blogs.amd.com/artificial-intelligence/whisper/README.html) +- [CTranslate2 on AMD GPUs](https://rocm.blogs.amd.com/artificial-intelligence/ctranslate2/README.html) +- [Fine-tuning Speech Models on ROCm](https://rocm.blogs.amd.com/artificial-intelligence/speech_models/README.html) + +### Community Projects + +- [CTranslate2-ROCm Fork](https://github.com/arlo-phoenix/CTranslate2-rocm) +- [insanely-fast-whisper-rocm](https://github.com/beecave-homelab/insanely-fast-whisper-rocm) +- [wyoming-faster-whisper-rocm](https://github.com/Donkey545/wyoming-faster-whisper-rocm) +- [ROCM SDK Builder](https://github.com/lamikr/rocm_sdk_builder) diff --git a/docs/sprints/phase-5-evolution/sprint-18.5-rocm-support/TECHNICAL_DEEP_DIVE.md b/docs/sprints/phase-5-evolution/sprint-18.5-rocm-support/TECHNICAL_DEEP_DIVE.md new file mode 100644 index 0000000..7559aff --- /dev/null +++ b/docs/sprints/phase-5-evolution/sprint-18.5-rocm-support/TECHNICAL_DEEP_DIVE.md @@ -0,0 +1,528 @@ +# ROCm Support: Technical Deep Dive + +This document provides detailed technical analysis of the ROCm integration, including architecture decisions, code patterns, and potential challenges. + +Alignment note: the current codebase already defines `AsrEngine` in +`src/noteflow/infrastructure/asr/protocols.py` and uses `AsrResult` from +`src/noteflow/infrastructure/asr/dto.py`. Prefer extending those instead of +adding parallel protocol/DTO types unless we plan a broader refactor. + +--- + +## Table of Contents + +1. [ROCm/HIP Architecture Overview](#rocmhip-architecture-overview) +2. [PyTorch HIP Transparency](#pytorch-hip-transparency) +3. [CTranslate2 Challenge](#ctranslate2-challenge) +4. [Engine Protocol Design](#engine-protocol-design) +5. [Detection Strategy](#detection-strategy) +6. [Memory Management](#memory-management) +7. [Performance Considerations](#performance-considerations) +8. [Migration Path](#migration-path) + +--- + +## ROCm/HIP Architecture Overview + +### What is ROCm? + +ROCm (Radeon Open Compute) is AMD's open-source software platform for GPU computing. It includes: + +- **HIP (Heterogeneous-compute Interface for Portability)**: A C++ runtime API and kernel language that allows developers to write portable code for AMD and NVIDIA GPUs. +- **rocBLAS/rocFFT/etc.**: Optimized math libraries equivalent to cuBLAS/cuFFT. +- **MIOpen**: Deep learning primitives library (equivalent to cuDNN). 
+ +### HIP Compilation Modes + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ HIP Source Code │ +└─────────────────────────────────────────────────────────────────┘ + │ + ┌───────────────┴───────────────┐ + ▼ ▼ +┌─────────────────────────┐ ┌─────────────────────────┐ +│ hipcc (AMD path) │ │ hipcc (NVIDIA path) │ +│ Compiles to ROCm │ │ Compiles to CUDA │ +└─────────────────────────┘ └─────────────────────────┘ + │ │ + ▼ ▼ +┌─────────────────────────┐ ┌─────────────────────────┐ +│ AMD GPU binary │ │ NVIDIA GPU binary │ +│ (runs on Radeon/MI) │ │ (runs on GeForce/A) │ +└─────────────────────────┘ └─────────────────────────┘ +``` + +### Key Insight for NoteFlow + +PyTorch's ROCm build uses HIP to expose the same `torch.cuda.*` API. From Python's perspective, the code is identical: + +```python +# This works on both CUDA and ROCm PyTorch builds +device = torch.device("cuda") +tensor = torch.randn(1000, 1000, device=device) +``` + +--- + +## PyTorch HIP Transparency + +### How PyTorch Achieves Compatibility + +PyTorch's ROCm build maps CUDA APIs to HIP: + +| CUDA API | HIP API (internal) | Python API (unchanged) | +|----------|-------------------|------------------------| +| `cudaMalloc` | `hipMalloc` | `torch.cuda.memory_allocated()` | +| `cudaMemcpy` | `hipMemcpy` | `tensor.cuda()` | +| `cudaDeviceGetAttribute` | `hipDeviceGetAttribute` | `torch.cuda.get_device_properties()` | +| `cudaGetDeviceCount` | `hipGetDeviceCount` | `torch.cuda.device_count()` | + +### Detection via torch.version.hip + +```python +import torch + +# Standard CUDA check (works on both CUDA and ROCm) +if torch.cuda.is_available(): + # Distinguish between CUDA and ROCm + if hasattr(torch.version, 'hip') and torch.version.hip: + print(f"ROCm {torch.version.hip}") + backend = "rocm" + else: + print(f"CUDA {torch.version.cuda}") + backend = "cuda" +``` + +### Implications for NoteFlow + +**Components that work unchanged:** + +| Component | Why It Works | +|-----------|--------------| +| pyannote.audio | Pure PyTorch, uses `torch.cuda` API | +| diart | Pure PyTorch, uses `torch.cuda` API | +| spaCy transformers | Uses PyTorch backend | +| Diarization engine | Uses `torch.device("cuda")` | +| VRAM monitoring | Uses `torch.cuda.mem_get_info()` | + +**Components that need modification:** + +| Component | Issue | Solution | +|-----------|-------|----------| +| faster-whisper | Uses CTranslate2, no ROCm | Engine abstraction + fallback | +| Device detection | Doesn't distinguish ROCm | Add `torch.version.hip` check | +| gRPC config | No ROCM device type | Add to proto enum | + +--- + +## CTranslate2 Challenge + +### Why CTranslate2 Doesn't Support ROCm + +CTranslate2 is a C++ inference engine optimized for transformer models. Its CUDA support is: + +1. **Compiled against CUDA toolkit**: Uses cuBLAS, cuDNN directly +2. **No HIP port**: Would require significant rewrite +3. 
**Performance-critical paths**: Hand-optimized CUDA kernels + +### Community Fork Analysis + +The [CTranslate2-ROCm fork](https://github.com/arlo-phoenix/CTranslate2-rocm) uses hipify to convert CUDA code: + +```bash +# How the fork was created (simplified) +hipify-perl -inplace -print-stats cuda_file.cu # Convert CUDA to HIP +hipcc -o output hip_file.cpp # Compile with hipcc +``` + +**Fork Status (as of research):** + +| Aspect | Status | +|--------|--------| +| Maintenance | Active, community-driven | +| ROCm versions | 5.7 - 6.x tested | +| Performance | Reported faster than whisper.cpp (verify on our hardware) | +| Whisper support | Full (via faster-whisper) | +| Stability | Community reports vary; validate for our workloads | + +### Why We Need the Protocol Pattern + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ AsrEngine │ +│ - transcribe(audio) -> segments │ +│ - load_model(size) │ +│ - device, compute_type, is_loaded │ +└─────────────────────────────────────────────────────────────────┘ + │ + ┌───────────────────┼───────────────────┐ + │ │ │ + ▼ ▼ ▼ +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│FasterWhisper │ │FasterWhisper │ │WhisperPyTorch │ +│Engine │ │RocmEngine │ │Engine │ +│(CUDA/CPU) │ │(ROCm fork) │ │(Universal) │ +│ │ │ │ │ │ +│Uses: ctranslate2│ │Uses: ctranslate2│ │Uses: openai- │ +│ (official) │ │ -rocm │ │ whisper │ +└─────────────────┘ └─────────────────┘ └─────────────────┘ + │ │ │ + ▼ ▼ ▼ + NVIDIA GPU AMD GPU Any PyTorch + Device +``` + +--- + +## Engine Protocol Design + +### Protocol Definition Rationale + +The `AsrEngine` protocol uses Python's `Protocol` for structural subtyping: + +```python +from typing import Protocol, runtime_checkable + +@runtime_checkable +class AsrEngine(Protocol): + """All ASR engines must implement this interface.""" + + @property + def device(self) -> str: ... + + @property + def is_loaded(self) -> bool: ... + + def transcribe(self, audio: NDArray[np.float32]) -> Iterator[AsrResult]: ... +``` + +**Why Protocol over ABC?** + +| Approach | Pros | Cons | +|----------|------|------| +| `Protocol` | Structural typing, no inheritance required | Harder to enforce at definition time | +| `ABC` | Explicit contract, IDE support | Requires inheritance, tighter coupling | + +We choose `Protocol` because: +1. `FasterWhisperEngine` already exists without base class +2. Third-party engines can implement without importing our types +3. `@runtime_checkable` enables `isinstance()` checks + +### Factory Pattern + +```python +def create_asr_engine(device: str = "auto") -> AsrEngine: + """Factory function for ASR engine creation. + + Encapsulates: + 1. Device detection (CUDA vs ROCm vs CPU) + 2. Library availability (CTranslate2-ROCm installed?) + 3. 
Fallback logic (PyTorch Whisper if no native engine) + """ +``` + +**Why factory function over class?** + +- Simpler API: `create_asr_engine("auto")` vs `AsrEngineFactory().create("auto")` +- No state needed between creations +- Easier to test with mocks + +--- + +## Detection Strategy + +### Multi-Layer Detection + +```python +def get_full_device_info() -> DeviceInfo: + """Complete device detection with all relevant information.""" + + # Layer 1: GPU Backend + backend = detect_gpu_backend() # CUDA, ROCm, MPS, or NONE + + # Layer 2: Library Availability + has_ctranslate2 = _check_ctranslate2() + has_ctranslate2_rocm = _check_ctranslate2_rocm() + has_openai_whisper = _check_openai_whisper() + + # Layer 3: Hardware Capabilities + if backend in (GpuBackend.CUDA, GpuBackend.ROCM): + props = torch.cuda.get_device_properties(0) + compute_capability = (props.major, props.minor) + vram_gb = props.total_memory / (1024**3) + + return DeviceInfo( + backend=backend, + has_ctranslate2=has_ctranslate2, + has_ctranslate2_rocm=has_ctranslate2_rocm, + has_openai_whisper=has_openai_whisper, + compute_capability=compute_capability, + vram_gb=vram_gb, + ) +``` + +### GPU Architecture Detection for ROCm + +AMD GPU architectures use "gfx" naming: + +| Architecture | gfx ID | Example GPUs | +|--------------|--------|--------------| +| RDNA 3 | gfx1100, gfx1101, gfx1102 | RX 7900 XTX, RX 7600 | +| RDNA 2 | gfx1030, gfx1031, gfx1032 | RX 6800 XT, RX 6600 | +| CDNA 2 | gfx90a | MI210, MI250 | +| CDNA 3 | gfx942 | MI300X | + +```python +def get_amd_gpu_architecture() -> str | None: + """Get AMD GPU architecture string.""" + if not _is_rocm(): + return None + + props = torch.cuda.get_device_properties(0) + # props.name on ROCm returns architecture, e.g., "gfx1100" + return props.name +``` + +### Handling Unsupported GPUs + +ROCm builds ship with kernels for specific architectures. If your GPU's gfx ID isn't supported: + +``` +RuntimeError: HIP error: invalid device function +``` + +**Detection strategy:** + +```python +OFFICIALLY_SUPPORTED_GFX = { + "gfx900", "gfx906", "gfx908", "gfx90a", # MI series + "gfx1030", "gfx1100", "gfx1101", "gfx1102", # Consumer RDNA +} + +def check_rocm_support(gfx_id: str) -> tuple[bool, str]: + """Check if GPU architecture is officially supported.""" + if gfx_id in OFFICIALLY_SUPPORTED_GFX: + return True, "Officially supported" + + # Check for compatible override (HSA_OVERRIDE_GFX_VERSION) + if os.environ.get("HSA_OVERRIDE_GFX_VERSION"): + return True, "Override enabled (may be unstable)" + + return False, f"Architecture {gfx_id} not officially supported" +``` + +--- + +## Memory Management + +### VRAM Monitoring + +ROCm exposes memory through the same PyTorch API: + +```python +def get_gpu_memory() -> tuple[int, int]: + """Get (free, total) GPU memory in bytes.""" + if not torch.cuda.is_available(): + return (0, 0) + + # Works identically for CUDA and ROCm + free, total = torch.cuda.mem_get_info(device=0) + return (free, total) +``` + +### Memory Pools + +ROCm's memory allocator differs from CUDA: + +| Aspect | CUDA | ROCm | +|--------|------|------| +| Allocator | cudaMallocAsync (11.2+) | hipMallocAsync (5.2+) | +| Memory pools | Yes | Yes (but different behavior) | +| Unified memory | Full support | Limited support | + +**Recommendation:** Use PyTorch's caching allocator abstraction rather than direct HIP calls. + +### Model Memory Requirements + +Approximate figures; validate with our model sizes, batch sizes, and backend. 
+ +| Model | VRAM (float16) | VRAM (int8) | +|-------|----------------|-------------| +| tiny | ~1 GB | ~0.5 GB | +| base | ~1.5 GB | ~0.8 GB | +| small | ~2 GB | ~1 GB | +| medium | ~5 GB | ~2.5 GB | +| large-v3 | ~10 GB | ~5 GB | + +--- + +## Performance Considerations + +### Expected Performance Comparison + +Based on community benchmarks (directional only; verify locally): + +| Configuration | Relative Speed | Notes | +|---------------|----------------|-------| +| CUDA + CTranslate2 | 1.0x (baseline) | Fastest | +| ROCm + CTranslate2-fork | 0.9x | ~10% slower | +| ROCm + PyTorch Whisper | 0.5x | Pure PyTorch overhead | +| CPU + CTranslate2 (int8) | 0.3x | Reasonable for short audio | + +### ROCm-Specific Optimizations + +1. **Kernel Compilation Cache:** + ```bash + export MIOPEN_USER_DB_PATH=/path/to/cache + export MIOPEN_FIND_MODE=3 # Fast compilation + ``` + +2. **Memory Optimization:** + ```python + # Force garbage collection between transcriptions + torch.cuda.empty_cache() + gc.collect() + ``` + +3. **Compute Type Selection:** + - Use `float16` on MI-series (native support) + - Use `float32` on RDNA if `float16` shows artifacts + +### Batch Processing Recommendations + +Example values only; tune for each GPU and model size. + +```python +# ROCm benefits from larger batches due to kernel launch overhead +BATCH_SIZES = { + "cuda": 16, + "rocm": 32, # Larger batch amortizes launch cost + "cpu": 8, +} +``` + +--- + +## Migration Path + +### Step 1: Non-Breaking Detection (v1) + +Add ROCm detection without changing behavior: + +```python +# In _engine_manager.py +def detect_gpu_backend(self) -> str: + """Detect GPU backend for logging/telemetry.""" + if not torch.cuda.is_available(): + return "none" + if hasattr(torch.version, 'hip') and torch.version.hip: + return "rocm" + return "cuda" +``` + +### Step 2: Feature-Flagged Engine Selection (v2) + +```python +# Behind NOTEFLOW_FEATURE_ROCM_ENABLED flag +if settings.feature_flags.rocm_enabled: + engine = create_asr_engine(device="auto") # Uses factory +else: + engine = FasterWhisperEngine(device=device) # Legacy path +``` + +### Step 3: Full Integration (v3) + +- Remove feature flag, factory becomes default +- Add ROCm to gRPC configuration +- Document ROCm installation + +### Rollback Strategy + +Each phase can be disabled independently: + +```python +# Emergency rollback +NOTEFLOW_FEATURE_ROCM_ENABLED=false # Disables ROCm path +NOTEFLOW_ASR_DEVICE=cpu # Forces CPU +``` + +--- + +## Testing Strategy + +### Unit Tests (No GPU Required) + +```python +class TestGpuDetection: + """Mock-based tests for detection logic.""" + + def test_rocm_detection_with_hip_version(self, mock_torch): + mock_torch.version.hip = "6.0.0" + mock_torch.cuda.is_available.return_value = True + assert detect_gpu_backend() == GpuBackend.ROCM +``` + +### Integration Tests (GPU Required) + +```python +@pytest.mark.skipif( + not _rocm_available(), + reason="ROCm not available" +) +class TestRocmEngine: + """Tests that require actual ROCm hardware.""" + + def test_model_loading(self): + engine = create_asr_engine(device="rocm") + engine.load_model("tiny") + assert engine.is_loaded +``` + +### CI/CD Considerations + +```yaml +# GitHub Actions example +jobs: + test-rocm: + runs-on: [self-hosted, rocm] # Custom runner with AMD GPU + steps: + - uses: actions/checkout@v4 + - run: pytest -m rocm tests/ +``` + +--- + +## Troubleshooting Guide + +### Common Issues + +| Error | Cause | Solution | +|-------|-------|----------| +| `HIP error: invalid device function` | GPU 
architecture not supported | Set `HSA_OVERRIDE_GFX_VERSION` or use CPU | +| `ImportError: libamdhip64.so` | ROCm not installed | Install ROCm or use CPU | +| `CUDA out of memory` (on ROCm) | VRAM exhausted | Reduce model size or batch | +| Model loads but transcription fails | Driver mismatch | Update ROCm to match PyTorch version | + +### Environment Variables + +```bash +# Debug ROCm issues +export AMD_LOG_LEVEL=3 # Verbose logging +export HIP_VISIBLE_DEVICES=0 # Limit to first GPU +export HSA_OVERRIDE_GFX_VERSION=11.0.0 # Override gfx check + +# Performance tuning +export MIOPEN_FIND_MODE=3 # Fast kernel selection +export MIOPEN_USER_DB_PATH=/cache # Kernel cache location +``` + +--- + +## References + +- [PyTorch HIP Documentation](https://docs.pytorch.org/docs/stable/notes/hip.html) +- [ROCm GitHub Repository](https://github.com/RadeonOpenCompute/ROCm) +- [CTranslate2 GitHub](https://github.com/OpenNMT/CTranslate2) +- [CTranslate2-ROCm Fork](https://github.com/arlo-phoenix/CTranslate2-rocm) +- [AMD ROCm Documentation](https://rocm.docs.amd.com/)