chore: configure devcontainer Python venv persistence and normalize package-lock peer dependencies

- Added bind mount for the .venv directory in the devcontainer to persist the Python virtual environment across container rebuilds
- Enabled updateRemoteUserUID for proper file permissions in the devcontainer
- Normalized peer dependency flags in package-lock.json (removed inconsistent "peer": true from core dependencies, added it to test-only dependencies)
- Added empty codex file placeholder
- Created comprehensive …
.devcontainer/devcontainer.json:

@@ -20,6 +20,10 @@
     "DISPLAY": ":1",
     "XDG_RUNTIME_DIR": "/tmp/runtime-vscode"
   },
+  "mounts": [
+    "source=${localWorkspaceFolder}/.venv,target=${containerWorkspaceFolder}/.venv,type=bind,consistency=cached"
+  ],
+  "updateRemoteUserUID": true,
   "postCreateCommand": ".devcontainer/postCreate.sh",
   "remoteUser": "vscode",
   "customizations": {
client/package-lock.json (generated; 37 lines changed):

@@ -482,7 +482,6 @@
        }
      ],
      "license": "MIT",
-     "peer": true,
      "engines": {
        "node": ">=18"
      },

@@ -524,7 +523,6 @@
        }
      ],
      "license": "MIT",
-     "peer": true,
      "engines": {
        "node": ">=18"
      }

@@ -4531,7 +4529,8 @@
       "version": "5.0.4",
       "resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.4.tgz",
       "integrity": "sha512-rfT93uj5s0PRL7EzccGMs3brplhcrghnDoV26NqKhCAS1hVo+WdNsPvE/yb6ilfr5hi2MEk6d5EWJTKdxg8jVw==",
-      "license": "MIT"
+      "license": "MIT",
+      "peer": true
     },
     "node_modules/@types/chai": {
       "version": "5.2.3",

@@ -4665,7 +4664,6 @@
       "integrity": "sha512-qm+G8HuG6hOHQigsi7VGuLjUVu6TtBo/F05zvX04Mw2uCg9Dv0Qxy3Qw7j41SidlTcl5D/5yg0SEZqOB+EqZnQ==",
       "devOptional": true,
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "undici-types": "~6.21.0"
       }

@@ -4690,7 +4688,6 @@
       "integrity": "sha512-cisd7gxkzjBKU2GgdYrTdtQx1SORymWyaAFhaxQPK9bYO9ot3Y5OikQRvY0VYQtvwjeQnizCINJAenh/V7MK2w==",
       "devOptional": true,
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "@types/prop-types": "*",
         "csstype": "^3.2.2"

@@ -4702,7 +4699,6 @@
       "integrity": "sha512-MEe3UeoENYVFXzoXEWsvcpg6ZvlrFNlOQ7EOsvhI3CfAXwzPfO8Qwuxd40nepsYKqyyVQnTdEfv68q91yLcKrQ==",
       "devOptional": true,
       "license": "MIT",
-      "peer": true,
       "peerDependencies": {
         "@types/react": "^18.0.0"
       }

@@ -4811,7 +4807,6 @@
       "integrity": "sha512-npiaib8XzbjtzS2N4HlqPvlpxpmZ14FjSJrteZpPxGUaYPlvhzlzUZ4mZyABo0EFrOWnvyd0Xxroq//hKhtAWg==",
       "dev": true,
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "@typescript-eslint/scope-manager": "8.53.0",
         "@typescript-eslint/types": "8.53.0",

@@ -5276,7 +5271,6 @@
       "integrity": "sha512-OmwPKV8c5ecLqo+EkytN7oUeYfNmRI4uOXGIR1ybP7AK5Zz+l9R0dGfoadEuwi1aZXAL0vwuhtq3p0OL3dfqHQ==",
       "dev": true,
       "license": "MIT",
-      "peer": true,
       "engines": {
         "node": ">=18.20.0"
       },

@@ -5331,7 +5325,6 @@
       "integrity": "sha512-HdzDrRs+ywAqbXGKqe1i/bLtCv47plz4TvsHFH3j729OooT5VH38ctFn5aLXgECmiAKDkmH/A6kOq2Zh5DIxww==",
       "dev": true,
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "chalk": "^5.1.2",
         "loglevel": "^1.6.0",

@@ -5582,7 +5575,6 @@
       "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
       "dev": true,
       "license": "MIT",
-      "peer": true,
       "bin": {
         "acorn": "bin/acorn"
       },

@@ -6091,7 +6083,6 @@
        }
      ],
      "license": "MIT",
-     "peer": true,
      "dependencies": {
        "baseline-browser-mapping": "^2.9.0",
        "caniuse-lite": "^1.0.30001759",

@@ -6852,7 +6843,6 @@
       "resolved": "https://registry.npmjs.org/date-fns/-/date-fns-3.6.0.tgz",
       "integrity": "sha512-fRHTG8g/Gif+kSh50gaGEdToemgfj74aRX3swtiouboip5JDLAyDE9F11nHMIcvOaXeOC6D7SpNhi7uFyB7Uww==",
       "license": "MIT",
-      "peer": true,
       "funding": {
         "type": "github",
         "url": "https://github.com/sponsors/kossnocorp"

@@ -6997,7 +6987,8 @@
       "version": "0.5.16",
       "resolved": "https://registry.npmjs.org/dom-accessibility-api/-/dom-accessibility-api-0.5.16.tgz",
       "integrity": "sha512-X7BJ2yElsnOJ30pZF4uIIDfBEVgF4XEBxL9Bxhy6dnrm5hkzqmsWHGTiHqRiITNhMyFLyAiWndIJP7Z1NTteDg==",
-      "license": "MIT"
+      "license": "MIT",
+      "peer": true
     },
     "node_modules/dom-helpers": {
       "version": "5.2.1",

@@ -7208,8 +7199,7 @@
       "version": "8.6.0",
       "resolved": "https://registry.npmjs.org/embla-carousel/-/embla-carousel-8.6.0.tgz",
       "integrity": "sha512-SjWyZBHJPbqxHOzckOfo8lHisEaJWmwd23XppYFYVh10bU66/Pn5tkVkbkCMZVdbUE5eTCI2nD8OyIP4Z+uwkA==",
-      "license": "MIT",
-      "peer": true
+      "license": "MIT"
     },
     "node_modules/embla-carousel-react": {
       "version": "8.6.0",

@@ -7398,7 +7388,6 @@
       "integrity": "sha512-LEyamqS7W5HB3ujJyvi0HQK/dtVINZvd5mAAp9eT5S/ujByGjiZLCzPcHVzuXbpJDJF/cxwHlfceVUDZ2lnSTw==",
       "dev": true,
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "@eslint-community/eslint-utils": "^4.8.0",
         "@eslint-community/regexpp": "^4.12.1",

@@ -7735,7 +7724,6 @@
       "integrity": "sha512-gQHqfI6SmtYBIkTeMizpHThdpXh6ej2Hk68oKZneFM6iu99ZGXvOPnmhd8VDus3xOWhVDDdf4sLsMV2/o+X6Yg==",
       "dev": true,
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "@vitest/snapshot": "^4.0.16",
         "deep-eql": "^5.0.2",

@@ -9133,7 +9121,6 @@
       "resolved": "https://registry.npmjs.org/jiti/-/jiti-1.21.7.tgz",
       "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==",
       "license": "MIT",
-      "peer": true,
       "bin": {
         "jiti": "bin/jiti.js"
       }

@@ -9162,7 +9149,6 @@
       "resolved": "https://registry.npmjs.org/jsdom/-/jsdom-27.4.0.tgz",
       "integrity": "sha512-mjzqwWRD9Y1J1KUi7W97Gja1bwOOM5Ug0EZ6UDK3xS7j7mndrkwozHtSblfomlzyB4NepioNt+B2sOSzczVgtQ==",
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "@acemir/cssom": "^0.9.28",
         "@asamuzakjp/dom-selector": "^6.7.6",

@@ -9608,6 +9594,7 @@
       "resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.5.0.tgz",
       "integrity": "sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ==",
       "license": "MIT",
+      "peer": true,
       "bin": {
         "lz-string": "bin/bin.js"
       }

@@ -10572,7 +10559,6 @@
        }
      ],
      "license": "MIT",
-     "peer": true,
      "dependencies": {
        "nanoid": "^3.3.11",
        "picocolors": "^1.1.1",

@@ -10739,6 +10725,7 @@
       "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-27.5.1.tgz",
       "integrity": "sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==",
       "license": "MIT",
+      "peer": true,
       "dependencies": {
         "ansi-regex": "^5.0.1",
         "ansi-styles": "^5.0.0",

@@ -10907,7 +10894,6 @@
       "resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz",
       "integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==",
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "loose-envify": "^1.1.0"
       },

@@ -10934,7 +10920,6 @@
       "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz",
       "integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==",
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "loose-envify": "^1.1.0",
         "scheduler": "^0.23.2"

@@ -10948,7 +10933,6 @@
       "resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.71.1.tgz",
       "integrity": "sha512-9SUJKCGKo8HUSsCO+y0CtqkqI5nNuaDqTxyqPsZPqIwudpj4rCrAz/jZV+jn57bx5gtZKOh3neQu94DXMc+w5w==",
       "license": "MIT",
-      "peer": true,
       "engines": {
         "node": ">=18.0.0"
       },

@@ -10964,7 +10948,8 @@
       "version": "17.0.2",
       "resolved": "https://registry.npmjs.org/react-is/-/react-is-17.0.2.tgz",
       "integrity": "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==",
-      "license": "MIT"
+      "license": "MIT",
+      "peer": true
     },
     "node_modules/react-remove-scroll": {
       "version": "2.7.2",

@@ -12204,7 +12189,6 @@
       "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.4.19.tgz",
       "integrity": "sha512-3ofp+LL8E+pK/JuPLPggVAIaEuhvIz4qNcf3nA1Xn2o/7fb7s/TYpHhwGDv1ZU3PkBluUVaF8PyCHcm48cKLWQ==",
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "@alloc/quick-lru": "^5.2.0",
         "arg": "^5.0.2",

@@ -13056,7 +13040,6 @@
       "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
       "dev": true,
       "license": "Apache-2.0",
-      "peer": true,
       "bin": {
         "tsc": "bin/tsc",
         "tsserver": "bin/tsserver"

@@ -13286,7 +13269,6 @@
       "resolved": "https://registry.npmjs.org/vite/-/vite-7.3.1.tgz",
       "integrity": "sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==",
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "esbuild": "^0.27.0",
         "fdir": "^6.5.0",

@@ -14047,7 +14029,6 @@
       "integrity": "sha512-Y5y4jpwHvuduUfup+gXTuCU6AROn/k6qOba3st0laFluKHY+q5SHOpQAJdS8acYLwE8caDQ2dXJhmXyxuJrm0Q==",
       "dev": true,
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "@types/node": "^20.11.30",
         "@types/sinonjs__fake-timers": "^8.1.5",

(Diff for one additional file suppressed because it is too large.)
@@ -0,0 +1,282 @@
# ROCm Support Implementation Checklist

This checklist tracks the implementation progress for Sprint 18.5.

---

## Phase 1: Device Abstraction Layer

### 1.1 GPU Detection Module

- [ ] Create `src/noteflow/infrastructure/gpu/__init__.py`
- [ ] Create `src/noteflow/infrastructure/gpu/detection.py`
- [ ] Implement `GpuBackend` enum (NONE, CUDA, ROCM, MPS)
- [ ] Implement `GpuInfo` dataclass
- [ ] Implement `detect_gpu_backend()` function
- [ ] Implement `get_gpu_info()` function
- [ ] Add ROCm version detection via `torch.version.hip`
- [ ] Create `tests/infrastructure/gpu/test_detection.py`
- [ ] Test no-torch case
- [ ] Test CUDA detection
- [ ] Test ROCm detection (HIP check)
- [ ] Test MPS detection
- [ ] Test CPU fallback

### 1.2 Domain Types

- [ ] Create `src/noteflow/domain/ports/gpu.py`
- [ ] Export `GpuBackend` enum
- [ ] Export `GpuInfo` type
- [ ] Define `GpuDetectionProtocol` (see the sketch below)
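
A minimal sketch of what `domain/ports/gpu.py` could export, assuming the names listed above; the exact `GpuInfo` fields mirror how `get_gpu_info()` is used in the sprint doc and are otherwise an assumption:

```python
"""GPU backend types and detection protocol (sketch)."""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Protocol


class GpuBackend(str, Enum):
    """Detected GPU runtime backend."""

    NONE = "none"
    CUDA = "cuda"
    ROCM = "rocm"
    MPS = "mps"


@dataclass(frozen=True)
class GpuInfo:
    """Static information about the detected GPU."""

    backend: GpuBackend
    device_name: str
    vram_total_mb: int
    driver_version: str


class GpuDetectionProtocol(Protocol):
    """Contract implemented by the infrastructure detection module."""

    def detect_gpu_backend(self) -> GpuBackend: ...

    def get_gpu_info(self) -> GpuInfo | None: ...
```
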
### 1.3 ASR Device Types

- [ ] Update `src/noteflow/application/services/asr_config/types.py`
- [ ] Add `ROCM = "rocm"` to `AsrDevice` enum
- [ ] Add ROCm entry to `DEVICE_COMPUTE_TYPES` mapping
- [ ] Update `AsrCapabilities` dataclass with `rocm_available` and `gpu_backend` fields

### 1.4 Diarization Device Mixin

- [ ] Update `src/noteflow/infrastructure/diarization/engine/_device_mixin.py`
- [ ] Add ROCm detection in `_detect_available_device()`
- [ ] Maintain backward compatibility with "cuda" device string

### 1.5 System Metrics

- [ ] Update `src/noteflow/infrastructure/metrics/system_resources.py`
- [ ] Handle ROCm VRAM queries (same API as CUDA via HIP; see the sketch below)
- [ ] Add `gpu_backend` field to metrics
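
The VRAM query itself needs no ROCm branch, because the HIP layer maps the same `torch.cuda` API. A small sketch (the helper name `query_vram_mb` is hypothetical):

```python
import torch


def query_vram_mb() -> tuple[int, int]:
    """Return (free, total) GPU memory in MiB.

    torch.cuda.mem_get_info() works identically on CUDA and ROCm
    PyTorch builds, so system_resources.py can call it unconditionally
    whenever a CUDA/ROCm device is available.
    """
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    mib = 1024 * 1024
    return free_bytes // mib, total_bytes // mib
```
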
### 1.6 gRPC Proto

- [ ] Update `src/noteflow/grpc/proto/noteflow.proto`
- [ ] Add `ASR_DEVICE_ROCM = 3` to `AsrDevice` enum
- [ ] Add `rocm_available` field to `AsrConfiguration`
- [ ] Add `gpu_backend` field to `AsrConfiguration`
- [ ] Regenerate Python stubs
- [ ] Run `scripts/patch_grpc_stubs.py`

### 1.7 Phase 1 Tests

- [ ] Run `pytest tests/infrastructure/gpu/`
- [ ] Run `make quality-py`
- [ ] Verify no regressions in CUDA detection

---

## Phase 2: ASR Engine Protocol

### 2.1 Engine Protocol Definition

- [ ] Extend `src/noteflow/infrastructure/asr/protocols.py` (or relocate to `domain/ports`)
- [ ] Reuse `AsrResult` / `WordTiming` from `infrastructure/asr/dto.py`
- [ ] Add `device` property (logical device: cpu/cuda/rocm)
- [ ] Add `compute_type` property
- [ ] Confirm `model_size` + `is_loaded` already covered
- [ ] Add optional `transcribe_file()` helper (if needed)

### 2.2 Refactor FasterWhisperEngine

- [ ] Update `src/noteflow/infrastructure/asr/engine.py`
- [ ] Ensure compliance with `AsrEngine`
- [ ] Add explicit type annotations
- [ ] Document as CUDA/CPU backend
- [ ] Create `tests/infrastructure/asr/test_protocol_compliance.py`
- [ ] Verify `FasterWhisperEngine` implements protocol

### 2.3 PyTorch Whisper Engine (Fallback)

- [ ] Create `src/noteflow/infrastructure/asr/pytorch_engine.py`
- [ ] Implement `WhisperPyTorchEngine` class
- [ ] Implement all protocol methods
- [ ] Handle device placement (cuda/rocm/cpu)
- [ ] Support all compute types
- [ ] Create `tests/infrastructure/asr/test_pytorch_engine.py`
- [ ] Test model loading
- [ ] Test transcription
- [ ] Test device handling

### 2.4 Engine Factory

- [ ] Create `src/noteflow/infrastructure/asr/factory.py`
- [ ] Implement `create_asr_engine()` function
- [ ] Implement `_resolve_device()` helper
- [ ] Implement `_create_cpu_engine()` helper
- [ ] Implement `_create_cuda_engine()` helper
- [ ] Implement `_create_rocm_engine()` helper
- [ ] Define `EngineCreationError` exception
- [ ] Create `tests/infrastructure/asr/test_factory.py`
- [ ] Test auto device resolution
- [ ] Test explicit device selection
- [ ] Test fallback behavior
- [ ] Test error cases

### 2.5 Update Engine Manager

- [ ] Update `src/noteflow/application/services/asr_config/_engine_manager.py`
- [ ] Add `detect_rocm_available()` method
- [ ] Update `build_capabilities()` for ROCm
- [ ] Update `check_configuration()` for ROCm validation
- [ ] Use factory for engine creation in `build_engine_for_job()`
- [ ] Update `tests/application/test_asr_config_service.py`
- [ ] Add ROCm detection tests
- [ ] Add ROCm validation tests

### 2.6 Phase 2 Tests

- [ ] Run full ASR test suite
- [ ] Run `make quality-py`
- [ ] Verify CUDA path unchanged

---

## Phase 3: ROCm-Specific Engine

### 3.1 ROCm Engine Implementation

- [ ] Create `src/noteflow/infrastructure/asr/rocm_engine.py`
- [ ] Implement `FasterWhisperRocmEngine` class
- [ ] Handle CTranslate2-ROCm import with fallback
- [ ] Implement all protocol methods
- [ ] Add ROCm-specific optimizations
- [ ] Create `tests/infrastructure/asr/test_rocm_engine.py`
- [ ] Test import fallback behavior
- [ ] Test engine creation (mock)
- [ ] Test protocol compliance

### 3.2 Update Factory for ROCm

- [ ] Update `src/noteflow/infrastructure/asr/factory.py`
- [ ] Add ROCm engine import with graceful fallback
- [ ] Log warning when falling back to PyTorch
- [ ] Update factory tests for ROCm path

### 3.3 ROCm Installation Detection

- [ ] Update `src/noteflow/infrastructure/gpu/detection.py`
- [ ] Add `is_ctranslate2_rocm_available()` function
- [ ] Add `get_rocm_version()` function
- [ ] Add corresponding tests

### 3.4 Phase 3 Tests

- [ ] Run ROCm-specific tests (skip if no ROCm)
- [ ] Run `make quality-py`
- [ ] Test on AMD hardware (if available)

---

## Phase 4: Configuration & Distribution

### 4.1 Feature Flag

- [ ] Update `src/noteflow/config/settings/_features.py`
- [ ] Add `NOTEFLOW_FEATURE_ROCM_ENABLED` flag (see the sketch below)
- [ ] Document in settings
- [ ] Update any feature flag guards
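
A minimal sketch of the flag read (the helper name is hypothetical; the real implementation should follow whatever pattern `_features.py` already uses for other flags):

```python
import os


def rocm_feature_enabled() -> bool:
    """Return True when the ROCm feature flag is switched on via env."""
    raw = os.environ.get("NOTEFLOW_FEATURE_ROCM_ENABLED", "false")
    return raw.strip().lower() in {"1", "true", "yes", "on"}
```
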
### 4.2 gRPC Config Handlers

- [ ] Update `src/noteflow/grpc/mixins/asr_config.py`
- [ ] Handle ROCm device in `GetAsrConfiguration()`
- [ ] Handle ROCm device in `UpdateAsrConfiguration()`
- [ ] Add ROCm to capabilities response
- [ ] Update tests in `tests/grpc/test_asr_config.py`

### 4.3 Dependencies

- [ ] Update `pyproject.toml`
- [ ] Add `rocm` extras group (see the sketch below)
- [ ] Add `openai-whisper` as optional dependency
- [ ] Document ROCm installation in comments
- [ ] Create `requirements-rocm.txt` (optional)
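
A sketch of the extras group (layout and version pin are assumptions; the CTranslate2-ROCm fork is installed from GitHub, so it cannot be expressed as a plain PyPI extra):

```toml
# pyproject.toml — sketch of a possible "rocm" extras group
[project.optional-dependencies]
rocm = [
    # Pure-PyTorch Whisper fallback used when CTranslate2-ROCm is absent
    "openai-whisper>=20231117",
]
# Note: install the CTranslate2-ROCm fork separately, e.g.
#   pip install git+https://github.com/arlo-phoenix/CTranslate2-rocm.git
```
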
### 4.4 Docker ROCm Image

- [ ] Create `docker/Dockerfile.rocm`
- [ ] Base on `rocm/pytorch` image
- [ ] Install NoteFlow with ROCm extras
- [ ] Configure for GPU access
- [ ] Update `compose.yaml` (and/or add `compose.rocm.yaml`) with a ROCm profile (see the sketch below)
- [ ] Test Docker image build
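
A sketch of what the ROCm compose profile could look like (service name and build context are assumptions; the device/group settings follow AMD's standard container requirements):

```yaml
# compose.rocm.yaml — sketch
services:
  noteflow:
    build:
      context: .
      dockerfile: docker/Dockerfile.rocm
    devices:
      - /dev/kfd   # ROCm compute interface
      - /dev/dri   # GPU render nodes
    group_add:
      - video      # grants the container user GPU access
```
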
### 4.5 Documentation

- [ ] Create `docs/installation/rocm.md`
- [ ] System requirements
- [ ] PyTorch ROCm installation
- [ ] CTranslate2-ROCm installation (optional)
- [ ] Docker usage
- [ ] Troubleshooting
- [ ] Update main README with ROCm section
- [ ] Update `CLAUDE.md` with ROCm notes

### 4.6 Phase 4 Tests

- [ ] Run full test suite
- [ ] Run `make quality`
- [ ] Build ROCm Docker image
- [ ] Test on AMD hardware

---

## Final Validation

### Quality Gates

- [ ] `pytest tests/quality/` passes
- [ ] `make quality-py` passes
- [ ] `make quality` passes (full stack)
- [ ] Proto regenerated correctly
- [ ] No type errors (`basedpyright`)
- [ ] No lint errors (`ruff`)

### Functional Validation

- [ ] CUDA path works (no regression)
- [ ] CPU path works (no regression)
- [ ] ROCm detection works
- [ ] PyTorch fallback works
- [ ] gRPC configuration works
- [ ] Device switching works

### Documentation

- [ ] Sprint README complete
- [ ] Implementation checklist complete
- [ ] Installation guide complete
- [ ] API documentation updated

---

## Notes

### Files Created

| File | Status |
|------|--------|
| `src/noteflow/domain/ports/gpu.py` | ❌ |
| `src/noteflow/domain/ports/asr.py` | optional (only if relocating protocol) |
| `src/noteflow/infrastructure/gpu/__init__.py` | ❌ |
| `src/noteflow/infrastructure/gpu/detection.py` | ❌ |
| `src/noteflow/infrastructure/asr/pytorch_engine.py` | ❌ |
| `src/noteflow/infrastructure/asr/rocm_engine.py` | ❌ |
| `src/noteflow/infrastructure/asr/factory.py` | ❌ |
| `docker/Dockerfile.rocm` | ❌ |
| `docs/installation/rocm.md` | ❌ |

### Files Modified

| File | Status |
|------|--------|
| `application/services/asr_config/types.py` | ❌ |
| `application/services/asr_config/_engine_manager.py` | ❌ |
| `infrastructure/diarization/engine/_device_mixin.py` | ❌ |
| `infrastructure/metrics/system_resources.py` | ❌ |
| `infrastructure/asr/engine.py` | ❌ |
| `infrastructure/asr/protocols.py` | ❌ |
| `grpc/proto/noteflow.proto` | ❌ |
| `grpc/mixins/asr_config.py` | ❌ |
| `config/settings/_features.py` | ❌ |
| `pyproject.toml` | ❌ |
@@ -0,0 +1,786 @@
# Sprint 18.5: ROCm GPU Backend Support

> **Size**: L | **Owner**: Backend | **Prerequisites**: None
> **Phase**: 5 - Platform Evolution

---

## Validation Status (2025-01-17)

### Research Complete — Implementation Ready

Note: Hardware/driver compatibility and ROCm wheel availability are time-sensitive.
Re-verify against AMD ROCm compatibility matrices and PyTorch ROCm install guidance
before implementation.

### Repo Alignment Notes (current tree)

- ASR protocol already exists at `src/noteflow/infrastructure/asr/protocols.py` and
  returns `AsrResult` from `src/noteflow/infrastructure/asr/dto.py`. Extend these
  instead of adding parallel `domain/ports/asr.py` types unless we plan a broader
  layering refactor.
- gRPC mixins live under `src/noteflow/grpc/mixins/` (not `_mixins`).
- Tests live under `tests/infrastructure/` and `tests/application/` (no `tests/unit/`).

| Prerequisite | Status | Impact |
|--------------|--------|--------|
| PyTorch ROCm support | ✅ Available | PyTorch HIP layer works with existing `torch.cuda` API |
| CTranslate2 ROCm support | ⚠️ Community fork | No official support; requires alternative engine strategy |
| pyannote.audio ROCm support | ✅ Available | Pure PyTorch, works out of the box |
| diart ROCm support | ✅ Available | Pure PyTorch, works out of the box |

| Component | Status | Notes |
|-----------|--------|-------|
| Device abstraction layer | ❌ Not implemented | Need `GpuBackend` enum |
| ASR engine protocol | ⚠️ Partial | `AsrEngine` exists; extend with device/compute metadata |
| ROCm detection | ❌ Not implemented | Need `torch.version.hip` check |
| gRPC proto updates | ❌ Not implemented | Need `ASR_DEVICE_ROCM` |
| PyTorch Whisper fallback | ❌ Not implemented | Fallback for universal compatibility |

**Action required**: Implement device abstraction layer and engine protocol pattern.

---

## Objective

Enable the NoteFlow backend to run on AMD GPUs via ROCm, providing GPU-accelerated ASR and diarization for users without NVIDIA hardware. This extends hardware support while maintaining full CUDA compatibility.

---

## Feasibility Analysis

### Component Compatibility Matrix

| Component | Library | CUDA | ROCm | Notes |
|-----------|---------|------|------|-------|
| ASR (Speech-to-Text) | faster-whisper (CTranslate2) | ✅ Official | ⚠️ Fork | [CTranslate2-ROCm](https://github.com/arlo-phoenix/CTranslate2-rocm) |
| Diarization (Streaming) | diart (PyTorch) | ✅ | ✅* | HIP layer transparent; validate in our stack |
| Diarization (Offline) | pyannote.audio (PyTorch) | ✅ | ✅ | [AMD tested](https://rocm.blogs.amd.com/artificial-intelligence/speech_models/README.html) |
| NER | spaCy transformers (PyTorch) | ✅ | ✅ | HIP layer transparent |
| Metrics/VRAM | torch.cuda.* | ✅ | ✅ | Works via HIP |

### Hardware Support (Source of Truth)

ROCm GPU support is version- and distro-specific. Use AMD's official compatibility
matrices and the Radeon ROCm guide as the source of truth.

- Datacenter (Instinct/MI): Refer to the ROCm compatibility matrices for exact
  ROCm release + GPU + OS combinations.
- Radeon/Workstation: Official Radeon ROCm support is narrower; check the latest
  Radeon ROCm guide for the exact RX series and Linux distro support. Older Radeon
  generations are often not listed and may require overrides or community builds.

### Key Technical Insights

1. **ROCm HIP Layer**: PyTorch's ROCm build [reuses the `torch.cuda` interface](https://docs.pytorch.org/docs/stable/notes/hip.html). Existing code using `torch.cuda.is_available()` and `torch.device("cuda")` works without modification (there is no `torch.device("rocm")`).

2. **CTranslate2 Blocker**: faster-whisper's [CTranslate2](https://github.com/OpenNMT/CTranslate2) dependency has no official ROCm support. Options:
   - Use the [CTranslate2-ROCm fork](https://github.com/arlo-phoenix/CTranslate2-rocm) (community maintained)
   - Implement a PyTorch Whisper fallback (official, slower)
   - Use [insanely-fast-whisper-rocm](https://github.com/beecave-homelab/insanely-fast-whisper-rocm) (community project)

3. **Detection Logic**: ROCm can be detected via:

   ```python
   if torch.cuda.is_available() and hasattr(torch.version, 'hip') and torch.version.hip:
       # ROCm/HIP backend
   ```

---

## What Already Exists

### Device Detection Infrastructure

| Asset | Location | Reuse Strategy |
|-------|----------|----------------|
| CUDA detection | `application/services/asr_config/_engine_manager.py:61-72` | Extend with ROCm check |
| Device preference | `infrastructure/diarization/engine/_device_mixin.py:18-36` | Add ROCm to detection cascade |
| Device types | `application/services/asr_config/types.py:22-27` | Add `ROCM` enum value |
| Compute types | `application/services/asr_config/types.py:37-44` | Add ROCm compute types |
| ASR protocol | `infrastructure/asr/protocols.py` | Extend existing `AsrEngine` (avoid duplicate protocol) |
| ASR DTOs | `infrastructure/asr/dto.py` | Reuse `AsrResult`/`WordTiming` (avoid duplicate types) |

### ASR Engine Architecture

| Asset | Location | Reuse Strategy |
|-------|----------|----------------|
| FasterWhisperEngine | `infrastructure/asr/engine.py` | Keep as CUDA/CPU backend |
| AsrEngine protocol (existing) | `infrastructure/asr/protocols.py` | Extend with device/compute metadata |
| Engine manager | `application/services/asr_config/_engine_manager.py` | Extend for engine selection |

### gRPC Configuration

| Asset | Location | Reuse Strategy |
|-------|----------|----------------|
| ASR device enum | `grpc/proto/noteflow.proto` | Add `ASR_DEVICE_ROCM` |
| ASR config messages | `grpc/proto/noteflow.proto` | Add `rocm_available` field |
| Config RPC handlers | `grpc/mixins/asr_config.py` | Extend for ROCm |

---

## Architecture Decision: Engine Protocol Pattern

### Decision

Extend the existing `AsrEngine` protocol (or move it to `domain/ports` if we want a
cleaner layering boundary) to abstract ASR backends and enable swappable implementations.

### Rationale

1. **Separation of Concerns**: Hardware-specific code isolated in engine implementations
2. **Testability**: Mock engines for testing without a GPU
3. **Extensibility**: Future backends (TPU, Apple Neural Engine) follow the same pattern
4. **Gradual Migration**: Existing code continues working while new engines are added

### Engine Hierarchy

```
AsrEngine (Protocol)
├── FasterWhisperEngine (CUDA/CPU) — existing, extend protocol
├── FasterWhisperRocmEngine (ROCm) — new, uses CTranslate2-ROCm fork
└── WhisperPyTorchEngine (Universal) — new, pure PyTorch fallback
```

---

## Scope

### Phase 1: Device Abstraction Layer (M)

| Task | Effort | Files |
|------|--------|-------|
| Add `GpuBackend` enum | S | `domain/ports/gpu.py` (new) |
| Add ROCm detection utility | S | `infrastructure/gpu/detection.py` (new) |
| Extend `AsrDevice` enum | S | `application/services/asr_config/types.py` |
| Update device mixin | S | `infrastructure/diarization/engine/_device_mixin.py` |
| Update gRPC proto | S | `grpc/proto/noteflow.proto` |

### Phase 2: ASR Engine Protocol (M)

| Task | Effort | Files |
|------|--------|-------|
| Extend `AsrEngine` protocol (or relocate) | M | `infrastructure/asr/protocols.py` (extend) or `domain/ports/asr.py` (move) |
| Refactor `FasterWhisperEngine` to protocol | M | `infrastructure/asr/engine.py` |
| Create `WhisperPyTorchEngine` fallback | M | `infrastructure/asr/pytorch_engine.py` (new) |
| Create engine factory | S | `infrastructure/asr/factory.py` (new) |
| Update engine manager | M | `application/services/asr_config/_engine_manager.py` |

### Phase 3: ROCm-Specific Engine (M)

| Task | Effort | Files |
|------|--------|-------|
| Create `FasterWhisperRocmEngine` | M | `infrastructure/asr/rocm_engine.py` (new) |
| Add conditional CTranslate2-ROCm import | S | `infrastructure/asr/rocm_engine.py` |
| Update engine factory for ROCm | S | `infrastructure/asr/factory.py` |
| Add ROCm installation detection | S | `infrastructure/gpu/detection.py` |

### Phase 4: Configuration & Distribution (L)

| Task | Effort | Files |
|------|--------|-------|
| Add ROCm feature flag | S | `config/settings/_features.py` |
| Update gRPC config handlers | M | `grpc/mixins/asr_config.py` |
| Create ROCm Docker image | M | `docker/Dockerfile.rocm` (new) |
| Update pyproject.toml extras | S | `pyproject.toml` |
| Add ROCm installation docs | M | `docs/installation/rocm.md` (new) |

---

## Deliverables

### New Files

| File | Purpose |
|------|---------|
| `src/noteflow/domain/ports/gpu.py` | GPU backend types and detection protocol |
| `src/noteflow/infrastructure/gpu/__init__.py` | GPU package init |
| `src/noteflow/infrastructure/gpu/detection.py` | GPU detection utilities |
| `src/noteflow/infrastructure/asr/pytorch_engine.py` | Pure PyTorch Whisper engine |
| `src/noteflow/infrastructure/asr/rocm_engine.py` | ROCm-specific faster-whisper engine |
| `src/noteflow/infrastructure/asr/factory.py` | Engine factory for backend selection |
| `docker/Dockerfile.rocm` | ROCm-enabled Docker image |
| `docs/installation/rocm.md` | ROCm setup documentation |

### Modified Files

| File | Changes |
|------|---------|
| `application/services/asr_config/types.py` | Add `ROCM` to `AsrDevice`, ROCm compute types |
| `application/services/asr_config/_engine_manager.py` | Use engine factory, extend detection |
| `infrastructure/diarization/engine/_device_mixin.py` | Add ROCm to detection cascade |
| `infrastructure/metrics/system_resources.py` | Handle ROCm VRAM queries |
| `infrastructure/asr/protocols.py` | Extend `AsrEngine` with device/compute_type metadata |
| `grpc/proto/noteflow.proto` | Add `ASR_DEVICE_ROCM`, `rocm_available`, `gpu_backend` |
| `grpc/mixins/asr_config.py` | Map ROCm device and capabilities |
| `config/settings/_features.py` | Add `NOTEFLOW_FEATURE_ROCM_ENABLED` |
| `pyproject.toml` | Add `rocm` extras group |

---

## Implementation Details

### 1. GPU Detection Module

**File**: `src/noteflow/infrastructure/gpu/detection.py`

```python
"""GPU backend detection utilities."""

from __future__ import annotations

from noteflow.domain.ports.gpu import GpuBackend, GpuInfo


def detect_gpu_backend() -> GpuBackend:
    """Detect the available GPU backend.

    Returns:
        GpuBackend enum indicating the detected backend.
    """
    try:
        import torch
    except ImportError:
        return GpuBackend.NONE

    if not torch.cuda.is_available():
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return GpuBackend.MPS
        return GpuBackend.NONE

    # Check if this is ROCm masquerading as CUDA
    if hasattr(torch.version, 'hip') and torch.version.hip:
        return GpuBackend.ROCM

    return GpuBackend.CUDA


def get_gpu_info() -> GpuInfo | None:
    """Get detailed GPU information.

    Returns:
        GpuInfo if a GPU is available, None otherwise.
    """
    backend = detect_gpu_backend()

    if backend == GpuBackend.NONE:
        return None

    import torch

    if backend in (GpuBackend.CUDA, GpuBackend.ROCM):
        props = torch.cuda.get_device_properties(0)
        vram_mb = props.total_memory // (1024 * 1024)

        driver_version = ""
        if backend == GpuBackend.ROCM:
            driver_version = getattr(torch.version, 'hip', 'unknown')
        else:
            driver_version = torch.version.cuda or 'unknown'

        return GpuInfo(
            backend=backend,
            device_name=props.name,
            vram_total_mb=vram_mb,
            driver_version=driver_version,
        )

    if backend == GpuBackend.MPS:
        return GpuInfo(
            backend=backend,
            device_name="Apple Metal",
            vram_total_mb=0,  # MPS doesn't expose VRAM
            driver_version="mps",
        )

    return None
```

### 2. ASR Engine Protocol

**File**: `src/noteflow/infrastructure/asr/protocols.py` (extend)

```python
"""ASR protocols defining contracts for ASR components."""

from __future__ import annotations

from collections.abc import Iterator
from typing import TYPE_CHECKING, Protocol

if TYPE_CHECKING:
    from pathlib import Path

    import numpy as np
    from numpy.typing import NDArray

    from noteflow.infrastructure.asr.dto import AsrResult


class AsrEngine(Protocol):
    """Protocol for ASR transcription engine implementations."""

    @property
    def device(self) -> str:
        """Return the requested device ("cpu", "cuda", "rocm")."""
        ...

    @property
    def compute_type(self) -> str:
        """Return the compute type (int8, float16, float32)."""
        ...

    @property
    def model_size(self) -> str | None:
        """Return the loaded model size, or None if not loaded."""
        ...

    @property
    def is_loaded(self) -> bool:
        """Return True if model is loaded."""
        ...

    def load_model(self, model_size: str = "base") -> None:
        """Load the ASR model."""
        ...

    def transcribe(
        self,
        audio: NDArray[np.float32],
        language: str | None = None,
    ) -> Iterator[AsrResult]:
        """Transcribe audio data."""
        ...

    def unload(self) -> None:
        """Unload the model and free resources."""
        ...

    def transcribe_file(
        self,
        audio_path: Path,
        *,
        language: str | None = None,
    ) -> Iterator[AsrResult]:
        """Optional helper for file-based transcription."""
        ...
```

### 3. Engine Factory

**File**: `src/noteflow/infrastructure/asr/factory.py`

```python
"""ASR engine factory for backend selection."""

from __future__ import annotations

from typing import TYPE_CHECKING

from noteflow.infrastructure.gpu.detection import GpuBackend, detect_gpu_backend
from noteflow.infrastructure.logging import get_logger

if TYPE_CHECKING:
    from noteflow.infrastructure.asr.protocols import AsrEngine

logger = get_logger(__name__)


class EngineCreationError(Exception):
    """Raised when engine creation fails."""


def create_asr_engine(
    device: str = "auto",
    compute_type: str = "int8",
    *,
    prefer_faster_whisper: bool = True,
) -> AsrEngine:
    """Create an ASR engine for the specified device.

    Args:
        device: Target device ("auto", "cpu", "cuda", "rocm").
        compute_type: Compute precision ("int8", "float16", "float32").
        prefer_faster_whisper: If True, prefer faster-whisper over PyTorch.

    Returns:
        An ASR engine implementing AsrEngine.

    Raises:
        EngineCreationError: If engine creation fails.
    """
    resolved_device = _resolve_device(device)

    logger.info(
        "Creating ASR engine",
        device=device,
        resolved_device=resolved_device,
        compute_type=compute_type,
    )

    if resolved_device == "cpu":
        return _create_cpu_engine(compute_type)

    if resolved_device == "cuda":
        return _create_cuda_engine(compute_type, prefer_faster_whisper)

    if resolved_device == "rocm":
        return _create_rocm_engine(compute_type, prefer_faster_whisper)

    msg = f"Unsupported device: {resolved_device}"
    raise EngineCreationError(msg)


def _resolve_device(device: str) -> str:
    """Resolve 'auto' device to actual backend."""
    if device != "auto":
        return device

    backend = detect_gpu_backend()

    if backend == GpuBackend.CUDA:
        return "cuda"
    if backend == GpuBackend.ROCM:
        return "rocm"

    return "cpu"


def _create_cpu_engine(compute_type: str) -> AsrEngine:
    """Create CPU engine (always uses faster-whisper)."""
    from noteflow.infrastructure.asr.engine import FasterWhisperEngine

    # CPU only supports int8 and float32
    if compute_type == "float16":
        compute_type = "float32"

    return FasterWhisperEngine(device="cpu", compute_type=compute_type)


def _create_cuda_engine(
    compute_type: str,
    prefer_faster_whisper: bool,
) -> AsrEngine:
    """Create CUDA engine."""
    if prefer_faster_whisper:
        from noteflow.infrastructure.asr.engine import FasterWhisperEngine

        return FasterWhisperEngine(device="cuda", compute_type=compute_type)

    from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

    return WhisperPyTorchEngine(device="cuda", compute_type=compute_type)


def _create_rocm_engine(
    compute_type: str,
    prefer_faster_whisper: bool,
) -> AsrEngine:
    """Create ROCm engine.

    Attempts to use the CTranslate2-ROCm fork if available,
    falls back to PyTorch Whisper otherwise.
    """
    if prefer_faster_whisper:
        try:
            from noteflow.infrastructure.asr.rocm_engine import FasterWhisperRocmEngine

            return FasterWhisperRocmEngine(compute_type=compute_type)
        except ImportError:
            logger.warning(
                "CTranslate2-ROCm not installed, falling back to PyTorch Whisper"
            )

    from noteflow.infrastructure.asr.pytorch_engine import WhisperPyTorchEngine

    # ROCm PyTorch builds still use the "cuda" device string (HIP layer).
    return WhisperPyTorchEngine(device="cuda", compute_type=compute_type)
```

### 4. Extended Device Types

**File**: `src/noteflow/application/services/asr_config/types.py` (modifications)

```python
class AsrDevice(str, Enum):
    """Supported ASR devices."""

    CPU = "cpu"
    CUDA = "cuda"
    ROCM = "rocm"


DEVICE_COMPUTE_TYPES: Final[dict[AsrDevice, tuple[AsrComputeType, ...]]] = {
    AsrDevice.CPU: (AsrComputeType.INT8, AsrComputeType.FLOAT32),
    AsrDevice.CUDA: (
        AsrComputeType.INT8,
        AsrComputeType.FLOAT16,
        AsrComputeType.FLOAT32,
    ),
    AsrDevice.ROCM: (
        AsrComputeType.INT8,
        AsrComputeType.FLOAT16,
        AsrComputeType.FLOAT32,
    ),
}
```

### 5. Updated Device Mixin

**File**: `src/noteflow/infrastructure/diarization/engine/_device_mixin.py` (modifications)

```python
def _detect_available_device(self) -> str:
    """Detect the best available device for computation."""
    if self._device_preference != "auto":
        return self._device_preference

    import torch

    if torch.cuda.is_available():
        # ROCm uses the same torch.cuda device string; keep "cuda" here
        # and expose backend (cuda vs rocm) via a separate detection helper.
        return "cuda"

    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return "mps"

    return "cpu"
```

### 6. Proto Updates

**File**: `src/noteflow/grpc/proto/noteflow.proto` (additions)

```protobuf
enum AsrDevice {
  ASR_DEVICE_UNSPECIFIED = 0;
  ASR_DEVICE_CPU = 1;
  ASR_DEVICE_CUDA = 2;
  ASR_DEVICE_ROCM = 3;  // NEW
}

message AsrConfiguration {
  // ... existing fields ...
  bool rocm_available = 8;  // NEW: ROCm/HIP detected
  string gpu_backend = 9;   // NEW: "cuda", "rocm", "mps", or "none"
}
```

Note: `gpu_backend` reflects the detected hardware/runtime, while `device` reflects the
configured ASR target (which may still be "cuda" even on ROCm).
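
A minimal sketch of how a config handler might populate the two new fields from detection (the handler shape and helper name are assumptions; `get_gpu_info` and `GpuBackend` come from the detection module above):

```python
from noteflow.domain.ports.gpu import GpuBackend
from noteflow.infrastructure.gpu.detection import get_gpu_info


def populate_gpu_fields(config) -> None:
    """Fill AsrConfiguration.rocm_available / gpu_backend from detection."""
    info = get_gpu_info()
    backend = info.backend if info is not None else GpuBackend.NONE
    config.gpu_backend = backend.value       # "cuda", "rocm", "mps", or "none"
    config.rocm_available = backend is GpuBackend.ROCM
```
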
---

## Quality Gates

- [ ] `pytest tests/quality/` passes
- [ ] `make quality-py` passes
- [ ] Proto regenerated and stubs patched
- [ ] ROCm detection works on AMD hardware
- [ ] PyTorch Whisper fallback functional
- [ ] CUDA path unchanged (no regression)
- [ ] Docker ROCm image builds successfully
- [ ] Meets `docs/sprints/QUALITY_STANDARDS.md`

---

## Test Plan

### Unit Tests

**File**: `tests/infrastructure/gpu/test_detection.py`

```python
import pytest
from unittest.mock import patch, MagicMock

from noteflow.infrastructure.gpu.detection import (
    GpuBackend,
    detect_gpu_backend,
    get_gpu_info,
)


class TestDetectGpuBackend:
    """Tests for GPU backend detection."""

    def test_no_torch_returns_none(self) -> None:
        """Return NONE when torch not installed."""
        with patch.dict("sys.modules", {"torch": None}):
            # Force reimport
            from importlib import reload
            from noteflow.infrastructure.gpu import detection
            reload(detection)
            assert detection.detect_gpu_backend() == GpuBackend.NONE

    def test_cuda_available_returns_cuda(self) -> None:
        """Return CUDA when CUDA available."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.version.hip = None

        with patch.dict("sys.modules", {"torch": mock_torch}):
            assert detect_gpu_backend() == GpuBackend.CUDA

    def test_hip_available_returns_rocm(self) -> None:
        """Return ROCM when HIP version present."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.version.hip = "6.0.0"

        with patch.dict("sys.modules", {"torch": mock_torch}):
            assert detect_gpu_backend() == GpuBackend.ROCM

    def test_mps_available_returns_mps(self) -> None:
        """Return MPS on Apple Silicon."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = True

        with patch.dict("sys.modules", {"torch": mock_torch}):
            assert detect_gpu_backend() == GpuBackend.MPS
```

### Integration Tests

**File**: `tests/integration/test_engine_factory.py`

```python
import pytest

from noteflow.infrastructure.asr.factory import create_asr_engine


def _rocm_available() -> bool:
    """Check if ROCm is available for testing.

    Defined before the test class: the skipif decorator below evaluates
    this at import time, so it must already exist.
    """
    try:
        import torch
        return (
            torch.cuda.is_available()
            and hasattr(torch.version, 'hip')
            and torch.version.hip is not None
        )
    except ImportError:
        return False


@pytest.mark.integration
class TestEngineFactory:
    """Integration tests for ASR engine factory."""

    def test_auto_device_selects_available(self) -> None:
        """Auto device resolves to available backend."""
        engine = create_asr_engine(device="auto")
        # Note: ROCm may still report "cuda" at the torch device level.
        assert engine.device in ("cpu", "cuda", "rocm")

    def test_cpu_engine_always_works(self) -> None:
        """CPU engine works regardless of GPU."""
        engine = create_asr_engine(device="cpu")
        assert engine.device == "cpu"

    @pytest.mark.skipif(
        not _rocm_available(),
        reason="ROCm not available",
    )
    def test_rocm_engine_creation(self) -> None:
        """ROCm engine creates on AMD hardware."""
        engine = create_asr_engine(device="rocm")
        assert engine.device == "rocm"
```

---

## Failure Modes

| Failure | Detection | Recovery |
|---------|-----------|----------|
| CTranslate2-ROCm not installed | Import error in factory | Fall back to PyTorch Whisper |
| ROCm driver version mismatch | HIP error at runtime | Log error, suggest driver update |
| GPU OOM on large model | CUDA/HIP OOM error | Reduce model size or use CPU |
| Unsupported GPU architecture | "invalid device function" | Fall back to CPU with warning |
| PyTorch ROCm not installed | Import error | CPU-only mode |

---

## Dependencies

### Python Packages

| Package | Version | Purpose | Notes |
|---------|---------|---------|-------|
| `torch` | >=2.0 | PyTorch with ROCm | Version must match ROCm runtime |
| `torchaudio` | >=2.0 | Audio processing | Version must match ROCm runtime |
| `faster-whisper` | >=1.0 | CUDA/CPU ASR | Existing |
| `ctranslate2-rocm` | (fork) | ROCm ASR | Optional, install from GitHub fork |
| `openai-whisper` | >=20231117 | PyTorch fallback | Optional (slower) |

### System Requirements

| Requirement | Minimum | Notes |
|-------------|---------|-------|
| ROCm | Per AMD docs | Must match PyTorch ROCm wheel version |
| Linux kernel | Per AMD docs | Depends on ROCm release + distro |
| AMD GPU driver | amdgpu | Included with ROCm packages |

---

## Installation Guide

### PyTorch ROCm Installation

```bash
# Example (replace <rocm-version> with your installed ROCm version)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm<rocm-version>
```

### CTranslate2-ROCm Installation (Optional)

```bash
# Install from the fork for faster-whisper ROCm support
pip install git+https://github.com/arlo-phoenix/CTranslate2-rocm.git
pip install faster-whisper
```

### Docker ROCm Image

```dockerfile
# Example base image (pin to a ROCm/PyTorch combo that matches your host)
FROM rocm/pytorch:rocm6.x_ubuntu22.04_py3.10_pytorch_2.x.x

WORKDIR /app

# Install NoteFlow
COPY . .
RUN pip install -e ".[rocm]"

CMD ["python", "-m", "noteflow.grpc.server"]
```

---

## Post-Sprint

- ROCm performance benchmarking vs CUDA
- Multi-GPU support (device index selection)
- ROCm-specific memory optimization
- Apple Neural Engine support investigation
- Windows ROCm support (if/when officially available)

---

## References

### Official Documentation

- [PyTorch HIP Semantics](https://docs.pytorch.org/docs/stable/notes/hip.html)
- [ROCm Compatibility Matrix](https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html)
- [AMD ROCm PyTorch Installation](https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/install/installrad/native_linux/install-pytorch.html)

### AMD Blog Posts

- [Speech-to-Text on AMD GPU with Whisper](https://rocm.blogs.amd.com/artificial-intelligence/whisper/README.html)
- [CTranslate2 on AMD GPUs](https://rocm.blogs.amd.com/artificial-intelligence/ctranslate2/README.html)
- [Fine-tuning Speech Models on ROCm](https://rocm.blogs.amd.com/artificial-intelligence/speech_models/README.html)

### Community Projects

- [CTranslate2-ROCm Fork](https://github.com/arlo-phoenix/CTranslate2-rocm)
- [insanely-fast-whisper-rocm](https://github.com/beecave-homelab/insanely-fast-whisper-rocm)
- [wyoming-faster-whisper-rocm](https://github.com/Donkey545/wyoming-faster-whisper-rocm)
- [ROCM SDK Builder](https://github.com/lamikr/rocm_sdk_builder)
@@ -0,0 +1,528 @@
|
||||
# ROCm Support: Technical Deep Dive
|
||||
|
||||
This document provides detailed technical analysis of the ROCm integration, including architecture decisions, code patterns, and potential challenges.
|
||||
|
||||
Alignment note: the current codebase already defines `AsrEngine` in
|
||||
`src/noteflow/infrastructure/asr/protocols.py` and uses `AsrResult` from
|
||||
`src/noteflow/infrastructure/asr/dto.py`. Prefer extending those instead of
|
||||
adding parallel protocol/DTO types unless we plan a broader refactor.
|
||||
|
||||
---
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [ROCm/HIP Architecture Overview](#rocmhip-architecture-overview)
|
||||
2. [PyTorch HIP Transparency](#pytorch-hip-transparency)
|
||||
3. [CTranslate2 Challenge](#ctranslate2-challenge)
|
||||
4. [Engine Protocol Design](#engine-protocol-design)
|
||||
5. [Detection Strategy](#detection-strategy)
|
||||
6. [Memory Management](#memory-management)
|
||||
7. [Performance Considerations](#performance-considerations)
|
||||
8. [Migration Path](#migration-path)
|
||||
|
||||
---
|
||||
|
||||
## ROCm/HIP Architecture Overview
|
||||
|
||||
### What is ROCm?
|
||||
|
||||
ROCm (Radeon Open Compute) is AMD's open-source software platform for GPU computing. It includes:
|
||||
|
||||
- **HIP (Heterogeneous-compute Interface for Portability)**: A C++ runtime API and kernel language that allows developers to write portable code for AMD and NVIDIA GPUs.
|
||||
- **rocBLAS/rocFFT/etc.**: Optimized math libraries equivalent to cuBLAS/cuFFT.
|
||||
- **MIOpen**: Deep learning primitives library (equivalent to cuDNN).
|
||||
|
||||
### HIP Compilation Modes
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ HIP Source Code │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
┌───────────────┴───────────────┐
|
||||
▼ ▼
|
||||
┌─────────────────────────┐ ┌─────────────────────────┐
|
||||
│ hipcc (AMD path) │ │ hipcc (NVIDIA path) │
|
||||
│ Compiles to ROCm │ │ Compiles to CUDA │
|
||||
└─────────────────────────┘ └─────────────────────────┘
|
||||
│ │
|
||||
▼ ▼
|
||||
┌─────────────────────────┐ ┌─────────────────────────┐
|
||||
│ AMD GPU binary │ │ NVIDIA GPU binary │
|
||||
│ (runs on Radeon/MI) │ │ (runs on GeForce/A) │
|
||||
└─────────────────────────┘ └─────────────────────────┘
|
||||
```
|
||||
|
||||
### Key Insight for NoteFlow
|
||||
|
||||
PyTorch's ROCm build uses HIP to expose the same `torch.cuda.*` API. From Python's perspective, the code is identical:
|
||||
|
||||
```python
|
||||
# This works on both CUDA and ROCm PyTorch builds
|
||||
device = torch.device("cuda")
|
||||
tensor = torch.randn(1000, 1000, device=device)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## PyTorch HIP Transparency
|
||||
|
||||
### How PyTorch Achieves Compatibility
|
||||
|
||||
PyTorch's ROCm build maps CUDA APIs to HIP:
|
||||
|
||||
| CUDA API | HIP API (internal) | Python API (unchanged) |
|
||||
|----------|-------------------|------------------------|
|
||||
| `cudaMalloc` | `hipMalloc` | `torch.cuda.memory_allocated()` |
|
||||
| `cudaMemcpy` | `hipMemcpy` | `tensor.cuda()` |
|
||||
| `cudaDeviceGetAttribute` | `hipDeviceGetAttribute` | `torch.cuda.get_device_properties()` |
|
||||
| `cudaGetDeviceCount` | `hipGetDeviceCount` | `torch.cuda.device_count()` |
|
||||
|
||||
### Detection via torch.version.hip
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Standard CUDA check (works on both CUDA and ROCm)
|
||||
if torch.cuda.is_available():
|
||||
# Distinguish between CUDA and ROCm
|
||||
if hasattr(torch.version, 'hip') and torch.version.hip:
|
||||
print(f"ROCm {torch.version.hip}")
|
||||
backend = "rocm"
|
||||
else:
|
||||
print(f"CUDA {torch.version.cuda}")
|
||||
backend = "cuda"
|
||||
```
|
||||
|
||||
### Implications for NoteFlow
|
||||
|
||||
**Components that work unchanged:**
|
||||
|
||||
| Component | Why It Works |
|
||||
|-----------|--------------|
|
||||
| pyannote.audio | Pure PyTorch, uses `torch.cuda` API |
|
||||
| diart | Pure PyTorch, uses `torch.cuda` API |
|
||||
| spaCy transformers | Uses PyTorch backend |
|
||||
| Diarization engine | Uses `torch.device("cuda")` |
|
||||
| VRAM monitoring | Uses `torch.cuda.mem_get_info()` |
|
||||
|
||||
**Components that need modification:**
|
||||
|
||||
| Component | Issue | Solution |
|
||||
|-----------|-------|----------|
|
||||
| faster-whisper | Uses CTranslate2, no ROCm | Engine abstraction + fallback |
|
||||
| Device detection | Doesn't distinguish ROCm | Add `torch.version.hip` check |
|
||||
| gRPC config | No ROCM device type | Add to proto enum |
|
||||
|
||||
---
|
||||
|
||||
## CTranslate2 Challenge
|
||||
|
||||
### Why CTranslate2 Doesn't Support ROCm
|
||||
|
||||
CTranslate2 is a C++ inference engine optimized for transformer models. Its CUDA support is:
|
||||
|
||||
1. **Compiled against CUDA toolkit**: Uses cuBLAS, cuDNN directly
|
||||
2. **No HIP port**: Would require significant rewrite
|
||||
3. **Performance-critical paths**: Hand-optimized CUDA kernels
|
||||
|
||||
### Community Fork Analysis
|
||||
|
||||
The [CTranslate2-ROCm fork](https://github.com/arlo-phoenix/CTranslate2-rocm) uses hipify to convert CUDA code:
|
||||
|
||||
```bash
|
||||
# How the fork was created (simplified)
|
||||
hipify-perl -inplace -print-stats cuda_file.cu # Convert CUDA to HIP
|
||||
hipcc -o output hip_file.cpp # Compile with hipcc
|
||||
```
|
||||
|
||||
**Fork Status (as of research):**
|
||||
|
||||
| Aspect | Status |
|
||||
|--------|--------|
|
||||
| Maintenance | Active, community-driven |
|
||||
| ROCm versions | 5.7 - 6.x tested |
|
||||
| Performance | Reported faster than whisper.cpp (verify on our hardware) |
|
||||
| Whisper support | Full (via faster-whisper) |
|
||||
| Stability | Community reports vary; validate for our workloads |
|
||||
|
||||

### Why We Need the Protocol Pattern

```
┌─────────────────────────────────────────────────────────┐
│                        AsrEngine                         │
│  - transcribe(audio) -> segments                         │
│  - load_model(size)                                      │
│  - device, compute_type, is_loaded                       │
└─────────────────────────────────────────────────────────┘
                             │
         ┌───────────────────┼───────────────────┐
         │                   │                   │
         ▼                   ▼                   ▼
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│FasterWhisper    │ │FasterWhisper    │ │WhisperPyTorch   │
│Engine           │ │RocmEngine       │ │Engine           │
│(CUDA/CPU)       │ │(ROCm fork)      │ │(Universal)      │
│                 │ │                 │ │                 │
│Uses: ctranslate2│ │Uses: ctranslate2│ │Uses: openai-    │
│  (official)     │ │  -rocm          │ │  whisper        │
└─────────────────┘ └─────────────────┘ └─────────────────┘
         │                   │                   │
         ▼                   ▼                   ▼
    NVIDIA GPU            AMD GPU           Any PyTorch
                                              Device
```

---

## Engine Protocol Design

### Protocol Definition Rationale

The `AsrEngine` protocol uses Python's `Protocol` for structural subtyping:

```python
from collections.abc import Iterator
from typing import Protocol, runtime_checkable

import numpy as np
from numpy.typing import NDArray

# AsrResult is the project's transcription result type, defined alongside
# the protocol.


@runtime_checkable
class AsrEngine(Protocol):
    """All ASR engines must implement this interface."""

    @property
    def device(self) -> str: ...

    @property
    def is_loaded(self) -> bool: ...

    def transcribe(self, audio: NDArray[np.float32]) -> Iterator[AsrResult]: ...
```

**Why Protocol over ABC?**

| Approach | Pros | Cons |
|----------|------|------|
| `Protocol` | Structural typing, no inheritance required | Harder to enforce at definition time |
| `ABC` | Explicit contract, IDE support | Requires inheritance, tighter coupling |

We choose `Protocol` because:

1. `FasterWhisperEngine` already exists without a base class
2. Third-party engines can implement the interface without importing our types
3. `@runtime_checkable` enables `isinstance()` checks (demonstrated below)
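
A minimal sketch of what structural subtyping buys us, using a hypothetical `DummyEngine` that is not part of the codebase; it neither imports nor inherits anything, yet satisfies the protocol:

```python
class DummyEngine:
    """Satisfies AsrEngine structurally; note the absence of a base class."""

    @property
    def device(self) -> str:
        return "cpu"

    @property
    def is_loaded(self) -> bool:
        return True

    def transcribe(self, audio):
        yield from ()  # no segments


# Passes because isinstance() on a @runtime_checkable protocol checks
# member presence (not signatures).
assert isinstance(DummyEngine(), AsrEngine)
```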

### Factory Pattern

```python
def create_asr_engine(device: str = "auto") -> AsrEngine:
    """Factory function for ASR engine creation.

    Encapsulates:
    1. Device detection (CUDA vs ROCm vs CPU)
    2. Library availability (CTranslate2-ROCm installed?)
    3. Fallback logic (PyTorch Whisper if no native engine)
    """
    ...
```

**Why a factory function over a class?**

- Simpler API: `create_asr_engine("auto")` vs `AsrEngineFactory().create("auto")`
- No state needed between creations
- Easier to test with mocks
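
Concretely, a sketch of the selection logic the factory encapsulates. Assumptions: the engine class names (`FasterWhisperEngine`, `FasterWhisperRocmEngine`, `WhisperPyTorchEngine`) follow the diagram above, and `get_full_device_info()` is the detection helper described in the next section; the real signatures may differ.

```python
def create_asr_engine(device: str = "auto") -> AsrEngine:
    info = get_full_device_info()

    # NVIDIA path: official CTranslate2 build
    if device in ("auto", "cuda") and info.backend is GpuBackend.CUDA and info.has_ctranslate2:
        return FasterWhisperEngine(device="cuda")

    # AMD path: prefer the CTranslate2-ROCm fork, else PyTorch Whisper
    if device in ("auto", "rocm") and info.backend is GpuBackend.ROCM:
        if info.has_ctranslate2_rocm:
            return FasterWhisperRocmEngine(device="cuda")  # ROCm presents itself as "cuda"
        if info.has_openai_whisper:
            return WhisperPyTorchEngine(device="cuda")

    # Last resort: CPU
    return FasterWhisperEngine(device="cpu")
```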

---

## Detection Strategy

### Multi-Layer Detection

```python
def get_full_device_info() -> DeviceInfo:
    """Complete device detection with all relevant information."""
    # Layer 1: GPU Backend
    backend = detect_gpu_backend()  # CUDA, ROCm, MPS, or NONE

    # Layer 2: Library Availability
    has_ctranslate2 = _check_ctranslate2()
    has_ctranslate2_rocm = _check_ctranslate2_rocm()
    has_openai_whisper = _check_openai_whisper()

    # Layer 3: Hardware Capabilities (only defined when a GPU is present)
    compute_capability = None
    vram_gb = None
    if backend in (GpuBackend.CUDA, GpuBackend.ROCM):
        props = torch.cuda.get_device_properties(0)
        compute_capability = (props.major, props.minor)
        vram_gb = props.total_memory / (1024**3)

    return DeviceInfo(
        backend=backend,
        has_ctranslate2=has_ctranslate2,
        has_ctranslate2_rocm=has_ctranslate2_rocm,
        has_openai_whisper=has_openai_whisper,
        compute_capability=compute_capability,
        vram_gb=vram_gb,
    )
```
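
The Layer 2 probes can be plain import checks; a sketch using the helper names from the listing above. One caveat: the ROCm fork installs under the same `ctranslate2` module name as the official build, so `_check_ctranslate2_rocm()` would need an extra signal (e.g. inspecting installed-package metadata), which this sketch leaves out.

```python
import importlib.util


def _check_ctranslate2() -> bool:
    """True if a ctranslate2 build (official or fork) is importable."""
    return importlib.util.find_spec("ctranslate2") is not None


def _check_openai_whisper() -> bool:
    """True if the openai-whisper package is importable."""
    return importlib.util.find_spec("whisper") is not None
```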

### GPU Architecture Detection for ROCm

AMD GPU architectures use "gfx" naming:

| Architecture | gfx ID | Example GPUs |
|--------------|--------|--------------|
| RDNA 3 | gfx1100, gfx1101, gfx1102 | RX 7900 XTX, RX 7600 |
| RDNA 2 | gfx1030, gfx1031, gfx1032 | RX 6800 XT, RX 6600 |
| CDNA 2 | gfx90a | MI210, MI250 |
| CDNA 3 | gfx942 | MI300X |

```python
def get_amd_gpu_architecture() -> str | None:
    """Get the AMD GPU architecture string, e.g. "gfx1100"."""
    if not _is_rocm():
        return None

    props = torch.cuda.get_device_properties(0)
    # On ROCm builds of PyTorch, gcnArchName carries the architecture,
    # possibly with feature suffixes (e.g. "gfx90a:sramecc+:xnack-");
    # props.name is the marketing name of the card.
    return props.gcnArchName.split(":")[0]
```

### Handling Unsupported GPUs

ROCm builds ship with kernels for specific architectures. If your GPU's gfx ID isn't supported:

```
RuntimeError: HIP error: invalid device function
```

**Detection strategy:**

```python
import os

OFFICIALLY_SUPPORTED_GFX = {
    "gfx900", "gfx906", "gfx908", "gfx90a",      # MI series
    "gfx1030", "gfx1100", "gfx1101", "gfx1102",  # Consumer RDNA
}


def check_rocm_support(gfx_id: str) -> tuple[bool, str]:
    """Check if a GPU architecture is officially supported."""
    if gfx_id in OFFICIALLY_SUPPORTED_GFX:
        return True, "Officially supported"

    # Check for a compatibility override (HSA_OVERRIDE_GFX_VERSION)
    if os.environ.get("HSA_OVERRIDE_GFX_VERSION"):
        return True, "Override enabled (may be unstable)"

    return False, f"Architecture {gfx_id} not officially supported"
```
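
Putting the two helpers together, a startup check might look like this (names as defined above):

```python
arch = get_amd_gpu_architecture()
if arch is None:
    print("No ROCm device detected")
else:
    supported, reason = check_rocm_support(arch)
    print(f"{arch}: {reason}")
    if not supported:
        # Fall back rather than crash later with
        # "HIP error: invalid device function"
        backend = "cpu"
```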

---

## Memory Management

### VRAM Monitoring

ROCm exposes memory through the same PyTorch API:

```python
import torch


def get_gpu_memory() -> tuple[int, int]:
    """Get (free, total) GPU memory in bytes."""
    if not torch.cuda.is_available():
        return (0, 0)

    # Works identically for CUDA and ROCm
    free, total = torch.cuda.mem_get_info(device=0)
    return (free, total)
```

### Memory Pools

ROCm's memory allocator differs from CUDA's:

| Aspect | CUDA | ROCm |
|--------|------|------|
| Allocator | cudaMallocAsync (11.2+) | hipMallocAsync (5.2+) |
| Memory pools | Yes | Yes (but different behavior) |
| Unified memory | Full support | Limited support |

**Recommendation:** Use PyTorch's caching allocator abstraction rather than direct HIP calls; a short sketch follows.
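
What "stay on the PyTorch abstraction" means in practice; these are standard `torch.cuda` calls that dispatch to HIP on ROCm builds:

```python
import torch

# Allocate through PyTorch: the caching allocator issues the underlying
# cudaMalloc/hipMalloc calls and reuses freed blocks internally.
x = torch.empty(1024, 1024, device="cuda")

print(torch.cuda.memory_allocated())  # bytes held by live tensors
print(torch.cuda.memory_reserved())   # bytes held by the allocator cache

del x
torch.cuda.empty_cache()  # return cached blocks to the driver
```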

### Model Memory Requirements

Approximate figures; validate with our model sizes, batch sizes, and backend. A selection sketch based on these figures follows the table.

| Model | VRAM (float16) | VRAM (int8) |
|-------|----------------|-------------|
| tiny | ~1 GB | ~0.5 GB |
| base | ~1.5 GB | ~0.8 GB |
| small | ~2 GB | ~1 GB |
| medium | ~5 GB | ~2.5 GB |
| large-v3 | ~10 GB | ~5 GB |
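
A hypothetical helper (not in the codebase) that picks the largest model whose float16 footprint fits the detected VRAM, leaving headroom for activations:

```python
# float16 column of the table above, smallest to largest
FLOAT16_VRAM_GB = {
    "tiny": 1.0,
    "base": 1.5,
    "small": 2.0,
    "medium": 5.0,
    "large-v3": 10.0,
}


def pick_model_for_vram(vram_gb: float, headroom_gb: float = 1.0) -> str:
    """Return the largest model that fits in vram_gb with headroom."""
    usable = vram_gb - headroom_gb
    best = "tiny"
    for name, need in FLOAT16_VRAM_GB.items():  # insertion order: small to large
        if need <= usable:
            best = name
    return best
```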

---

## Performance Considerations

### Expected Performance Comparison

Based on community benchmarks (directional only; verify locally, e.g. with the measurement sketch after the table):

| Configuration | Relative Speed | Notes |
|---------------|----------------|-------|
| CUDA + CTranslate2 | 1.0x (baseline) | Fastest |
| ROCm + CTranslate2-fork | 0.9x | ~10% slower |
| ROCm + PyTorch Whisper | 0.5x | Pure PyTorch overhead |
| CPU + CTranslate2 (int8) | 0.3x | Reasonable for short audio |
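
A hypothetical measurement helper (ours, not an existing utility) for checking these ratios on local hardware with any `AsrEngine`:

```python
import time

import numpy as np


def realtime_factor(engine, audio: np.ndarray, audio_seconds: float) -> float:
    """Transcription time divided by audio duration; lower is faster."""
    start = time.perf_counter()
    list(engine.transcribe(audio))  # drain the segment iterator
    return (time.perf_counter() - start) / audio_seconds
```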

### ROCm-Specific Optimizations

1. **Kernel Compilation Cache:**

   ```bash
   export MIOPEN_USER_DB_PATH=/path/to/cache
   export MIOPEN_FIND_MODE=3  # Fast compilation
   ```

2. **Memory Optimization:**

   ```python
   import gc

   import torch

   # Force garbage collection between transcriptions
   torch.cuda.empty_cache()
   gc.collect()
   ```

3. **Compute Type Selection:**
   - Use `float16` on MI-series (native support)
   - Use `float32` on RDNA if `float16` shows artifacts

### Batch Processing Recommendations

Example values only; tune for each GPU and model size.

```python
# ROCm benefits from larger batches due to kernel launch overhead
BATCH_SIZES = {
    "cuda": 16,
    "rocm": 32,  # Larger batch amortizes launch cost
    "cpu": 8,
}
```

---

## Migration Path

### Step 1: Non-Breaking Detection (v1)

Add ROCm detection without changing behavior:

```python
# In _engine_manager.py
def detect_gpu_backend(self) -> str:
    """Detect GPU backend for logging/telemetry."""
    if not torch.cuda.is_available():
        return "none"
    if hasattr(torch.version, 'hip') and torch.version.hip:
        return "rocm"
    return "cuda"
```

### Step 2: Feature-Flagged Engine Selection (v2)

```python
# Behind the NOTEFLOW_FEATURE_ROCM_ENABLED flag
if settings.feature_flags.rocm_enabled:
    engine = create_asr_engine(device="auto")    # Uses factory
else:
    engine = FasterWhisperEngine(device=device)  # Legacy path
```

### Step 3: Full Integration (v3)

- Remove feature flag, factory becomes default
- Add ROCm to gRPC configuration
- Document ROCm installation

### Rollback Strategy

Each phase can be disabled independently:

```bash
# Emergency rollback
NOTEFLOW_FEATURE_ROCM_ENABLED=false  # Disables ROCm path
NOTEFLOW_ASR_DEVICE=cpu              # Forces CPU
```

---

## Testing Strategy

### Unit Tests (No GPU Required)

```python
class TestGpuDetection:
    """Mock-based tests for detection logic."""

    def test_rocm_detection_with_hip_version(self, mock_torch):
        # mock_torch: project fixture that patches the torch module
        mock_torch.version.hip = "6.0.0"
        mock_torch.cuda.is_available.return_value = True
        assert detect_gpu_backend() == GpuBackend.ROCM
```

### Integration Tests (GPU Required)

```python
import pytest


@pytest.mark.skipif(
    not _rocm_available(),
    reason="ROCm not available",
)
class TestRocmEngine:
    """Tests that require actual ROCm hardware."""

    def test_model_loading(self):
        engine = create_asr_engine(device="rocm")
        engine.load_model("tiny")
        assert engine.is_loaded
```

### CI/CD Considerations

```yaml
# GitHub Actions example
jobs:
  test-rocm:
    runs-on: [self-hosted, rocm]  # Custom runner with AMD GPU
    steps:
      - uses: actions/checkout@v4
      - run: pytest -m rocm tests/
```

---

## Troubleshooting Guide

### Common Issues

| Error | Cause | Solution |
|-------|-------|----------|
| `HIP error: invalid device function` | GPU architecture not supported | Set `HSA_OVERRIDE_GFX_VERSION` or use CPU |
| `ImportError: libamdhip64.so` | ROCm not installed | Install ROCm or use CPU |
| `CUDA out of memory` (on ROCm) | VRAM exhausted | Reduce model size or batch |
| Model loads but transcription fails | Driver mismatch | Update ROCm to match PyTorch version |

### Environment Variables

```bash
# Debug ROCm issues
export AMD_LOG_LEVEL=3                  # Verbose logging
export HIP_VISIBLE_DEVICES=0            # Limit to first GPU
export HSA_OVERRIDE_GFX_VERSION=11.0.0  # Override gfx check

# Performance tuning
export MIOPEN_FIND_MODE=3               # Fast kernel selection
export MIOPEN_USER_DB_PATH=/cache       # Kernel cache location
```

---

## References

- [PyTorch HIP Documentation](https://docs.pytorch.org/docs/stable/notes/hip.html)
- [ROCm GitHub Repository](https://github.com/RadeonOpenCompute/ROCm)
- [CTranslate2 GitHub](https://github.com/OpenNMT/CTranslate2)
- [CTranslate2-ROCm Fork](https://github.com/arlo-phoenix/CTranslate2-rocm)
- [AMD ROCm Documentation](https://rocm.docs.amd.com/)