hooks/README.md
@@ -1,216 +0,0 @@
# Claude Code Quality Guard Hook

A comprehensive code quality enforcement system for Claude Code that prevents writing duplicate, complex, or non-modernized Python code.

## Features

### PreToolUse Analysis

Analyzes code **before** it's written to prevent quality issues:

- **Internal Duplicate Detection**: Detects duplicate code blocks within the same file using AST analysis
- **Complexity Analysis**: Measures cyclomatic complexity and flags overly complex functions
- **Modernization Checks**: Identifies outdated Python patterns and missing type hints
- **Configurable Enforcement**: Strict (deny), warn (ask), or permissive (allow with warning) modes

### PostToolUse Verification

Verifies code **after** it's written to track quality:

- **State Tracking**: Detects quality degradation between edits
- **Cross-File Duplicates**: Finds duplicates across the entire codebase
- **Naming Conventions**: Verifies PEP8 naming standards
- **Success Feedback**: Optional success messages for clean code
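The PEP8 naming verification above can be sketched as a small AST walk; this is illustrative only (the hook's real checker may cover more node types and naming rules):

```python
import ast
import re

# Assumed patterns: snake_case for functions, PascalCase for classes
SNAKE_CASE = re.compile(r"^_{0,2}[a-z][a-z0-9_]*$")
PASCAL_CASE = re.compile(r"^_?[A-Z][A-Za-z0-9]*$")

def naming_violations(source: str) -> list[str]:
    """Flag function and class names that break PEP8 conventions."""
    violations = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            if not SNAKE_CASE.match(node.name):
                violations.append(f"function '{node.name}' should be snake_case")
        elif isinstance(node, ast.ClassDef) and not PASCAL_CASE.match(node.name):
            violations.append(f"class '{node.name}' should be PascalCase")
    return violations

print(naming_violations("def BadName():\n    pass\n"))
# ["function 'BadName' should be snake_case"]
```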

## Installation

### Global Setup (Recommended)

Run the setup script to install the hook globally for all projects in `~/repos`:

```bash
cd ~/repos/claude-scripts
./setup_global_hook.sh
```

This creates:

- Global Claude Code configuration at `~/.claude/claude-code-settings.json`
- Configuration helper at `~/.claude/configure-quality.sh`
- Convenience alias `claude-quality` in your shell

### Quick Configuration

After installation, use the `claude-quality` command:

```bash
# Apply presets
claude-quality strict      # Strict enforcement
claude-quality moderate    # Moderate with warnings
claude-quality permissive  # Permissive suggestions
claude-quality disabled    # Disable all checks

# Check current settings
claude-quality status
```

### Per-Project Setup

Alternatively, copy the configuration to a specific project:

```bash
cp hooks/claude-code-settings.json /path/to/project/
```

## Configuration

### Environment Variables

| Variable | Description | Default |
|----------|-------------|---------|
| `QUALITY_ENFORCEMENT` | Mode: strict/warn/permissive | strict |
| `QUALITY_COMPLEXITY_THRESHOLD` | Max cyclomatic complexity | 10 |
| `QUALITY_DUP_THRESHOLD` | Duplicate similarity (0-1) | 0.7 |
| `QUALITY_DUP_ENABLED` | Enable duplicate detection | true |
| `QUALITY_COMPLEXITY_ENABLED` | Enable complexity checks | true |
| `QUALITY_MODERN_ENABLED` | Enable modernization | true |
| `QUALITY_TYPE_HINTS` | Require type hints | false |
| `QUALITY_STATE_TRACKING` | Track file changes | true |
| `QUALITY_CROSS_FILE_CHECK` | Cross-file duplicates | true |
| `QUALITY_VERIFY_NAMING` | Check PEP8 naming | true |
| `QUALITY_SHOW_SUCCESS` | Show success messages | false |

### Per-Project Overrides

Create a `.quality.env` file in your project root:

```bash
# .quality.env
QUALITY_ENFORCEMENT=moderate
QUALITY_COMPLEXITY_THRESHOLD=15
QUALITY_TYPE_HINTS=true
```

Then source it: `source .quality.env`

## How It Works

### Internal Duplicate Detection

The hook uses AST analysis to detect three types of duplicates within a file:

1. **Exact Duplicates**: Identical code blocks
2. **Structural Duplicates**: Same AST structure, different names
3. **Semantic Duplicates**: Similar logic patterns
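The structural case can be sketched by comparing AST dumps after normalizing identifiers away; this is a minimal illustration, not the hook's actual detector (which lives in `internal_duplicate_detector.py`):

```python
import ast

def structure_of(source: str) -> str:
    """Return an AST dump with function, class, and variable names blanked out."""
    tree = ast.parse(source)
    for node in ast.walk(tree):
        # Blank out names so differently-named copies compare equal
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            node.name = "_"
        elif isinstance(node, ast.Name):
            node.id = "_"
        elif isinstance(node, ast.arg):
            node.arg = "_"
    return ast.dump(tree)

a = "def calculate_tax(amount):\n    tax = amount * 0.1\n    return amount + tax\n"
b = "def calculate_fee(amount):\n    fee = amount * 0.1\n    return amount + fee\n"
print(structure_of(a) == structure_of(b))  # True: same structure, different names
```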

### Enforcement Modes

- **Strict**: Blocks (denies) code that fails quality checks
- **Warn**: Asks for user confirmation on quality issues
- **Permissive**: Allows code but shows warnings
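The mode-to-decision mapping can be sketched as below. The `"allow"`/`"ask"`/`"deny"` strings mirror the modes described above; the real hook's JSON response schema may differ:

```python
def decide(mode: str, has_issues: bool) -> dict:
    """Map an enforcement mode and analysis result to a hook decision (sketch)."""
    if not has_issues:
        return {"decision": "allow"}
    return {
        "strict": {"decision": "deny", "reason": "quality checks failed"},
        "warn": {"decision": "ask", "reason": "quality issues found"},
        "permissive": {"decision": "allow", "warning": "quality issues found"},
    }[mode]

print(decide("strict", True)["decision"])  # deny
```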

### State Tracking

Tracks quality metrics between edits to detect:

- Reduction in functions/classes
- Significant file size increases
- Quality degradation trends
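A minimal sketch of that comparison, with assumed metrics and an assumed 2x size threshold (the hook's actual metrics and thresholds may differ):

```python
import ast

def quality_metrics(source: str) -> dict:
    """Collect simple per-file metrics to compare between edits."""
    tree = ast.parse(source)
    return {
        "functions": sum(isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
                         for n in ast.walk(tree)),
        "classes": sum(isinstance(n, ast.ClassDef) for n in ast.walk(tree)),
        "lines": len(source.splitlines()),
    }

def degradations(before: dict, after: dict) -> list[str]:
    """Report the degradation signals listed above (thresholds are assumed)."""
    issues = []
    if after["functions"] + after["classes"] < before["functions"] + before["classes"]:
        issues.append("fewer functions/classes than before")
    if after["lines"] > before["lines"] * 2:  # assumed 2x growth threshold
        issues.append("significant file size increase")
    return issues
```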

## Testing

The hook comes with a comprehensive test suite:

```bash
# Run all tests
pytest tests/hooks/

# Run specific test modules
pytest tests/hooks/test_pretooluse.py
pytest tests/hooks/test_posttooluse.py
pytest tests/hooks/test_edge_cases.py
pytest tests/hooks/test_integration.py

# Run with coverage
pytest tests/hooks/ --cov=hooks
```

### Test Coverage

- 90 tests covering all functionality
- Edge cases and error handling
- Integration testing with Claude Code
- Concurrent access and thread safety

## Architecture

```
code_quality_guard.py            # Main hook implementation
├── QualityConfig                # Configuration management
├── pretooluse_hook()            # Pre-write analysis
├── posttooluse_hook()           # Post-write verification
└── analyze_code_quality()       # Quality analysis engine

internal_duplicate_detector.py   # AST-based duplicate detection
├── InternalDuplicateDetector    # Main detector class
├── extract_code_blocks()        # AST traversal
└── find_duplicates()            # Similarity algorithms

claude-code-settings.json        # Hook configuration
└── Maps both hooks to the same script
```

## Examples

### Detecting Internal Duplicates

```python
# This would be flagged as a duplicate
def calculate_tax(amount):
    tax = amount * 0.1
    total = amount + tax
    return total

def calculate_fee(amount):  # Duplicate!
    fee = amount * 0.1
    total = amount + fee
    return total
```

### Complexity Issues

```python
# This would be flagged as too complex (CC > 10)
def process_data(data):
    if data:
        if data.type == 'A':
            if data.value > 100:
                ...  # nested logic continues
```
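Cyclomatic complexity here is, roughly, one plus the number of branch points. A minimal counter in that spirit (simplified relative to the hook's `calculate_complexity`, and the exact node set is an assumption):

```python
import ast

def cyclomatic_complexity(source: str) -> int:
    """Rough cyclomatic complexity: 1 plus the number of branch points."""
    complexity = 1
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.If, ast.For, ast.While, ast.ExceptHandler,
                             ast.BoolOp, ast.IfExp)):
            complexity += 1
    return complexity

snippet = """
def process_data(data):
    if data:
        if data.type == 'A':
            if data.value > 100:
                return True
    return False
"""
print(cyclomatic_complexity(snippet))  # 4: one path plus three if-branches
```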

### Modernization Suggestions

```python
# Outdated patterns that would be flagged
d = dict()                   # Use {} instead
if x == None:                # Use 'x is None'
for i in range(len(items)):  # Use enumerate(items)
```

## Troubleshooting

### Hook Not Working

1. Verify installation: `ls ~/.claude/claude-code-settings.json`
2. Check Python: `python --version` (requires 3.8+)
3. Test directly: `echo '{"tool_name":"Read"}' | python hooks/code_quality_guard.py`
4. Check the `claude-quality` command: `which claude-quality`

### False Positives

- Adjust thresholds via environment variables
- Use `.quality-exceptions.yaml` for suppressions
- Switch to permissive mode for legacy code

### Performance Issues

- Disable cross-file checks: `QUALITY_CROSS_FILE_CHECK=false`
- Increase thresholds for large files
- Use skip patterns for generated code

## Development

### Adding New Checks

1. Add analysis logic to `analyze_code_quality()`
2. Add issue detection to `check_code_issues()`
3. Add configuration to `QualityConfig`
4. Add tests to the appropriate test module

### Contributing

1. Run tests: `pytest tests/hooks/`
2. Check types: `mypy hooks/`
3. Format code: `ruff format hooks/`
4. Submit a PR with tests

## License

Part of the Claude Scripts project. See the main LICENSE file.

@@ -1,335 +0,0 @@

# Claude Code Quality Hooks

Comprehensive quality hooks for Claude Code supporting both PreToolUse (preventive) and PostToolUse (verification) stages to ensure high-quality Python code.

## Features

### PreToolUse (Preventive)

- **Internal Duplicate Detection**: Analyzes code blocks within the same file
- **Complexity Analysis**: Prevents functions with excessive cyclomatic complexity
- **Modernization Checks**: Ensures code uses modern Python patterns and type hints
- **Test Quality Checks**: Enforces test-specific rules for files in test directories
- **Smart Filtering**: Automatically skips test files and fixtures
- **Configurable Enforcement**: Strict denial, user prompts, or warnings

### PostToolUse (Verification)

- **Cross-File Duplicate Detection**: Finds duplicates across the project
- **State Tracking**: Compares quality metrics before and after modifications
- **Naming Convention Verification**: Checks PEP8 compliance for functions and classes
- **Quality Delta Reports**: Shows improvements vs. degradations
- **Project Standards Verification**: Ensures consistency with the codebase

## Installation

### Quick Setup

```bash
# Make the setup script executable and run it
chmod +x setup_hook.sh
./setup_hook.sh
```

### Manual Setup

1. Install claude-scripts (required for analysis):

   ```bash
   pip install claude-scripts
   ```

2. Copy the hook configuration to Claude Code settings:

   ```bash
   mkdir -p ~/.config/claude
   cp claude-code-settings.json ~/.config/claude/settings.json
   ```

3. Update the paths in settings.json to match your installation location.

## Hook Versions

### Basic Hook (`code_quality_guard.py`)

- Simple deny/allow decisions
- Fixed thresholds
- Good for enforcing consistent standards

### Advanced Hook (`code_quality_guard_advanced.py`)

- Configurable via environment variables
- Multiple enforcement modes
- Detailed issue reporting

## Configuration (Advanced Hook)

Set these environment variables to customize behavior:

### Core Settings

| Variable | Default | Description |
|----------|---------|-------------|
| `QUALITY_DUP_THRESHOLD` | 0.7 | Similarity threshold for duplicate detection (0.0-1.0) |
| `QUALITY_DUP_ENABLED` | true | Enable/disable duplicate checking |
| `QUALITY_COMPLEXITY_THRESHOLD` | 10 | Maximum allowed cyclomatic complexity |
| `QUALITY_COMPLEXITY_ENABLED` | true | Enable/disable complexity checking |
| `QUALITY_MODERN_ENABLED` | true | Enable/disable modernization checking |
| `QUALITY_REQUIRE_TYPES` | true | Require type hints in code |
| `QUALITY_ENFORCEMENT` | strict | Enforcement mode: strict/warn/permissive |

### PostToolUse Features

| Variable | Default | Description |
|----------|---------|-------------|
| `QUALITY_STATE_TRACKING` | false | Compare quality metrics before/after edits |
| `QUALITY_CROSS_FILE_CHECK` | false | Check for cross-file duplicates |
| `QUALITY_VERIFY_NAMING` | true | Verify PEP8 naming conventions |
| `QUALITY_SHOW_SUCCESS` | false | Show success messages for clean files |

### Test Quality Features

| Variable | Default | Description |
|----------|---------|-------------|
| `QUALITY_TEST_QUALITY_ENABLED` | true | Enable test-specific quality checks for test files |

### External Context Providers

| Variable | Default | Description |
|----------|---------|-------------|
| `QUALITY_CONTEXT7_ENABLED` | false | Enable the Context7 API for additional context analysis |
| `QUALITY_CONTEXT7_API_KEY` | "" | API key for the Context7 service |
| `QUALITY_FIRECRAWL_ENABLED` | false | Enable the Firecrawl API for web-scraped examples |
| `QUALITY_FIRECRAWL_API_KEY` | "" | API key for the Firecrawl service |

### Enforcement Modes

- **strict**: Deny writes with critical issues, prompt for warnings
- **warn**: Always prompt the user to confirm when issues are found
- **permissive**: Allow writes but display warnings

## Enhanced Error Messaging

When test quality violations are detected, the hook provides detailed, actionable guidance instead of generic error messages.

### Rule-Specific Guidance

Each violation type includes:

- **📋 Problem Description**: Clear explanation of what was detected
- **❓ Why It Matters**: Educational context about test best practices
- **🛠️ How to Fix It**: Step-by-step remediation instructions
- **💡 Examples**: Before/after code examples showing the fix
- **🔍 Context**: File and function information for easy location

### Example Enhanced Message

```
🚫 Conditional Logic in Test Function

📋 Problem: Test function 'test_user_access' contains conditional statements (if/elif/else).

❓ Why this matters: Tests should be simple assertions that verify specific behavior. Conditionals make tests harder to understand and maintain.

🛠️ How to fix it:
• Replace conditionals with parameterized test cases
• Use pytest.mark.parametrize for multiple scenarios
• Extract conditional logic into helper functions
• Use assertion libraries like assertpy for complex conditions

💡 Example:
# ❌ Instead of this:
def test_user_access():
    user = create_user()
    if user.is_admin:
        assert user.can_access_admin()
    else:
        assert not user.can_access_admin()

# ✅ Do this:
@pytest.mark.parametrize('is_admin,can_access', [
    (True, True),
    (False, False),
])
def test_user_access(is_admin, can_access):
    user = create_user(admin=is_admin)
    assert user.can_access_admin() == can_access

🔍 File: test_user.py
📍 Function: test_user_access
```

## External Context Integration

The hook can integrate with external APIs to provide additional context and examples.

### Context7 Integration

Provides additional analysis and context for rule violations using advanced language models.

### Firecrawl Integration

Scrapes web resources for additional examples, best practices, and community solutions.

### Configuration

```bash
# Enable external context providers
export QUALITY_CONTEXT7_ENABLED=true
export QUALITY_CONTEXT7_API_KEY="your_context7_api_key"

export QUALITY_FIRECRAWL_ENABLED=true
export QUALITY_FIRECRAWL_API_KEY="your_firecrawl_api_key"
```

## Example Usage

### Setting Environment Variables

```bash
# In your shell profile (.bashrc, .zshrc, etc.)
export QUALITY_DUP_THRESHOLD=0.8
export QUALITY_COMPLEXITY_THRESHOLD=15
export QUALITY_ENFORCEMENT=warn
```

### Testing the Hook

1. Open Claude Code
2. Try to write Python code with issues:

   ```python
   # This will trigger duplicate detection
   def calculate_total(items):
       total = 0
       for item in items:
           total += item.price
       return total

   def compute_sum(products):  # Similar to the function above
       sum = 0
       for product in products:
           sum += product.price
       return sum
   ```

3. The hook will analyze the write and potentially block the operation

## Test Quality Checks

When enabled, the hook performs additional quality checks on test files using Sourcery rules specifically designed for test code:

### Test-Specific Rules

- **no-conditionals-in-tests**: Prevents conditional statements in test functions
- **no-loop-in-tests**: Prevents loops in test functions
- **raise-specific-error**: Ensures specific exceptions are raised instead of generic ones
- **dont-import-test-modules**: Prevents importing test modules in non-test code

### Configuration

Test quality checks are controlled by the `QUALITY_TEST_QUALITY_ENABLED` environment variable:

```bash
# Enable test quality checks (default)
export QUALITY_TEST_QUALITY_ENABLED=true

# Disable test quality checks
export QUALITY_TEST_QUALITY_ENABLED=false
```

### File Detection

Test files are automatically detected if they are located in directories named:

- `test/`, `tests/`, or `testing/`

Example test file paths:

- `tests/test_user.py`
- `src/tests/test_auth.py`
- `project/tests/integration/test_api.py`
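That directory-based detection can be sketched as a simple path check (illustrative; the hook's actual matcher may differ):

```python
from pathlib import PurePosixPath

TEST_DIR_NAMES = {"test", "tests", "testing"}

def is_test_file(file_path: str) -> bool:
    """True when any parent directory is named test/, tests/, or testing/."""
    directories = PurePosixPath(file_path).parts[:-1]  # drop the filename
    return any(part in TEST_DIR_NAMES for part in directories)

print(is_test_file("src/tests/test_auth.py"))  # True
print(is_test_file("src/app/auth.py"))         # False
```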

## Hook Behavior

### What Gets Checked

- ✅ Python files (`.py` extension)
- ✅ New file contents (Write tool)
- ✅ Modified content (Edit tool)
- ✅ Multiple edits (MultiEdit tool)
- ✅ Test files (when test quality checks are enabled)

### What Gets Skipped

- ❌ Non-Python files
- ❌ Test files (when test quality checks are disabled)
- ❌ Fixture files (`/fixtures/`)

## Troubleshooting

### Hook Not Triggering

1. Verify the settings location:

   ```bash
   cat ~/.config/claude/settings.json
   ```

2. Check that claude-quality is installed:

   ```bash
   claude-quality --version
   ```

3. Test the hook directly:

   ```bash
   echo '{"tool_name": "Write", "tool_input": {"file_path": "test.py", "content": "print(1)"}}' | python code_quality_guard.py
   ```

### Performance Issues

If analysis is slow:

- Increase the timeout in the hook scripts
- Disable specific checks via environment variables
- Use permissive mode for large files

### Disabling the Hook

Remove or rename the settings file:

```bash
mv ~/.config/claude/settings.json ~/.config/claude/settings.json.disabled
```

## Integration with CI/CD

These hooks complement CI/CD quality gates:

1. **Local Prevention**: Hooks prevent low-quality code at write time
2. **CI Validation**: CI/CD runs the same quality checks on commits
3. **Consistent Standards**: Both use the same claude-quality toolkit

## Advanced Customization

### Custom Skip Patterns

Modify the `skip_patterns` in `QualityConfig`:

```python
skip_patterns = [
    'test_', '_test.py', '/tests/',
    '/vendor/', '/third_party/',
    'generated_', '.proto',
]
```

### Custom Quality Rules

Extend the analysis by adding checks:

```python
# In analyze_with_quality_toolkit()
if config.custom_checks_enabled:
    # Add your custom analysis
    cmd = ['your-tool', tmp_path]
    result = subprocess.run(cmd, ...)
```

## Contributing

To improve these hooks:

1. Test changes locally
2. Update both the basic and advanced versions
3. Document new configuration options
4. Submit a PR with examples

## License

Same as the claude-scripts project (MIT).

@@ -0,0 +1,27 @@

"""Claude Code hooks subsystem with unified facade.

Provides a clean, concurrency-safe interface for all Claude Code hooks
(PreToolUse, PostToolUse, Stop) with built-in validation for bash commands
and code quality.

Quick Start:
    ```python
    import json
    import sys

    from hooks import Guards

    guards = Guards()
    payload = json.load(sys.stdin)
    response = guards.handle_pretooluse(payload)
    ```

Architecture:
- Guards: Main facade coordinating all validations
- BashCommandGuard: Validates bash commands for type safety
- CodeQualityGuard: Checks code quality (duplicates, complexity)
- LockManager: File-based inter-process synchronization
- Analyzers: Supporting analysis tools (duplicates, types, etc.)
"""

from .facade import Guards

__all__ = ["Guards"]

hooks/analyzers/__init__.py (new file)
@@ -0,0 +1,17 @@

"""Code analysis tools for hook-based quality checking."""

from .duplicate_detector import (
    Duplicate,
    DuplicateResults,
    detect_internal_duplicates,
)
from .message_enrichment import EnhancedMessageFormatter
from .type_inference import TypeInferenceHelper

__all__ = [
    "detect_internal_duplicates",
    "Duplicate",
    "DuplicateResults",
    "EnhancedMessageFormatter",
    "TypeInferenceHelper",
]
|
||||
@@ -163,7 +163,7 @@ class InternalDuplicateDetector:
|
||||
"""Analyze source code for internal duplicates."""
|
||||
try:
|
||||
# Dedent the content to handle code fragments with leading indentation
|
||||
tree = ast.parse(textwrap.dedent(source_code))
|
||||
tree: ast.Module = ast.parse(textwrap.dedent(source_code))
|
||||
except SyntaxError:
|
||||
return {
|
||||
"error": "Failed to parse code",
|
||||
@@ -172,7 +172,7 @@ class InternalDuplicateDetector:
|
||||
}
|
||||
|
||||
# Extract code blocks
|
||||
blocks = self._extract_code_blocks(tree, source_code)
|
||||
blocks: list[CodeBlock] = self._extract_code_blocks(tree, source_code)
|
||||
|
||||
# Filter blocks by size
|
||||
blocks = [
|
||||
@@ -195,18 +195,18 @@ class InternalDuplicateDetector:
|
||||
duplicate_groups: list[DuplicateGroup] = []
|
||||
|
||||
# 1. Check for exact duplicates (normalized)
|
||||
exact_groups = self._find_exact_duplicates(blocks)
|
||||
exact_groups: list[DuplicateGroup] = self._find_exact_duplicates(blocks)
|
||||
duplicate_groups.extend(exact_groups)
|
||||
|
||||
# 2. Check for structural similarity
|
||||
structural_groups = self._find_structural_duplicates(blocks)
|
||||
structural_groups: list[DuplicateGroup] = self._find_structural_duplicates(blocks)
|
||||
duplicate_groups.extend(structural_groups)
|
||||
|
||||
# 3. Check for semantic patterns
|
||||
pattern_groups = self._find_pattern_duplicates(blocks)
|
||||
pattern_groups: list[DuplicateGroup] = self._find_pattern_duplicates(blocks)
|
||||
duplicate_groups.extend(pattern_groups)
|
||||
|
||||
filtered_groups = [
|
||||
filtered_groups: list[DuplicateGroup] = [
|
||||
group
|
||||
for group in duplicate_groups
|
||||
if group.similarity_score >= self.similarity_threshold
|
||||
@@ -244,7 +244,7 @@ class InternalDuplicateDetector:
|
||||
def _extract_code_blocks(self, tree: ast.AST, source: str) -> list[CodeBlock]:
|
||||
"""Extract functions, methods, and classes from AST."""
|
||||
blocks: list[CodeBlock] = []
|
||||
lines = source.split("\n")
|
||||
lines: list[str] = source.split("\n")
|
||||
|
||||
def create_block(
|
||||
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
|
||||
@@ -252,10 +252,10 @@ class InternalDuplicateDetector:
|
||||
lines: list[str],
|
||||
) -> CodeBlock | None:
|
||||
try:
|
||||
start = node.lineno - 1
|
||||
end_lineno = getattr(node, "end_lineno", None)
|
||||
end = end_lineno - 1 if end_lineno is not None else start
|
||||
source = "\n".join(lines[start : end + 1])
|
||||
start: int = node.lineno - 1
|
||||
end_lineno: int | None = getattr(node, "end_lineno", None)
|
||||
end: int = end_lineno - 1 if end_lineno is not None else start
|
||||
source: str = "\n".join(lines[start : end + 1])
|
||||
|
||||
return CodeBlock(
|
||||
name=node.name,
|
||||
@@ -271,7 +271,7 @@ class InternalDuplicateDetector:
|
||||
|
||||
def calculate_complexity(node: ast.AST) -> int:
|
||||
"""Simple cyclomatic complexity calculation."""
|
||||
complexity = 1
|
||||
complexity: int = 1
|
||||
for child in ast.walk(node):
|
||||
if isinstance(
|
||||
child,
|
||||
@@ -297,7 +297,7 @@ class InternalDuplicateDetector:
|
||||
return
|
||||
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
block_type = (
|
||||
block_type: str = (
|
||||
"method" if isinstance(parent, ast.ClassDef) else "function"
|
||||
)
|
||||
if block := create_block(node, block_type, lines):
|
||||
@@ -311,41 +311,41 @@ class InternalDuplicateDetector:
|
||||
|
||||
def _find_exact_duplicates(self, blocks: list[CodeBlock]) -> list[DuplicateGroup]:
|
||||
"""Find exact or near-exact duplicate blocks."""
|
||||
groups = []
|
||||
processed = set()
|
||||
groups: list[DuplicateGroup] = []
|
||||
processed: set[int] = set()
|
||||
|
||||
for i, block1 in enumerate(blocks):
|
||||
if i in processed:
|
||||
continue
|
||||
|
||||
similar = [block1]
|
||||
norm1 = self._normalize_code(block1.source)
|
||||
similar: list[CodeBlock] = [block1]
|
||||
norm1: str = self._normalize_code(block1.source)
|
||||
|
||||
for j, block2 in enumerate(blocks[i + 1 :], i + 1):
|
||||
if j in processed:
|
||||
continue
|
||||
|
||||
norm2 = self._normalize_code(block2.source)
|
||||
norm2: str = self._normalize_code(block2.source)
|
||||
|
||||
# Check if normalized versions are very similar
|
||||
similarity = difflib.SequenceMatcher(None, norm1, norm2).ratio()
|
||||
similarity: float = difflib.SequenceMatcher(None, norm1, norm2).ratio()
|
||||
if similarity >= 0.85: # High threshold for "exact" duplicates
|
||||
similar.append(block2)
|
||||
processed.add(j)
|
||||
|
||||
if len(similar) > 1:
|
||||
# Calculate actual similarity on normalized code
|
||||
total_sim = 0
|
||||
count = 0
|
||||
total_sim: float = 0
|
||||
count: int = 0
|
||||
for k in range(len(similar)):
|
||||
for idx in range(k + 1, len(similar)):
|
||||
norm_k = self._normalize_code(similar[k].source)
|
||||
norm_idx = self._normalize_code(similar[idx].source)
|
||||
sim = difflib.SequenceMatcher(None, norm_k, norm_idx).ratio()
|
||||
norm_k: str = self._normalize_code(similar[k].source)
|
||||
norm_idx: str = self._normalize_code(similar[idx].source)
|
||||
sim: float = difflib.SequenceMatcher(None, norm_k, norm_idx).ratio()
|
||||
total_sim += sim
|
||||
count += 1
|
||||
|
||||
avg_similarity = total_sim / count if count > 0 else 1.0
|
||||
avg_similarity: float = total_sim / count if count > 0 else 1.0
|
||||
|
||||
groups.append(
|
||||
DuplicateGroup(
|
||||
@@ -383,28 +383,28 @@ class InternalDuplicateDetector:
|
||||
blocks: list[CodeBlock],
|
||||
) -> list[DuplicateGroup]:
|
||||
"""Find structurally similar blocks using AST comparison."""
|
||||
groups = []
|
||||
processed = set()
|
||||
groups: list[DuplicateGroup] = []
|
||||
processed: set[int] = set()
|
||||
|
||||
for i, block1 in enumerate(blocks):
|
||||
if i in processed:
|
||||
continue
|
||||
|
||||
similar_blocks = [block1]
|
||||
similar_blocks: list[CodeBlock] = [block1]
|
||||
|
||||
for j, block2 in enumerate(blocks[i + 1 :], i + 1):
|
||||
if j in processed:
|
||||
continue
|
||||
|
||||
similarity = self._ast_similarity(block1.ast_node, block2.ast_node)
|
||||
similarity: float = self._ast_similarity(block1.ast_node, block2.ast_node)
|
||||
if similarity >= self.similarity_threshold:
|
||||
similar_blocks.append(block2)
|
||||
processed.add(j)
|
||||
|
||||
if len(similar_blocks) > 1:
|
||||
# Calculate average similarity
|
||||
total_sim = 0
|
||||
count = 0
|
||||
total_sim: float = 0
|
||||
count: int = 0
|
||||
for k in range(len(similar_blocks)):
|
||||
for idx in range(k + 1, len(similar_blocks)):
|
||||
total_sim += self._ast_similarity(
|
||||
@@ -413,7 +413,7 @@ class InternalDuplicateDetector:
|
||||
)
|
||||
count += 1
|
||||
|
||||
avg_similarity = total_sim / count if count > 0 else 0
|
||||
avg_similarity: float = total_sim / count if count > 0 else 0
|
||||
|
||||
groups.append(
|
||||
DuplicateGroup(
|
||||
@@ -432,46 +432,46 @@ class InternalDuplicateDetector:
|
||||
|
||||
def get_structure(node: ast.AST) -> list[str]:
|
||||
"""Extract structural pattern from AST node."""
|
||||
structure = []
|
||||
structure: list[str] = []
|
||||
for child in ast.walk(node):
|
||||
structure.append(child.__class__.__name__)
|
||||
return structure
|
||||
|
||||
struct1 = get_structure(node1)
|
||||
struct2 = get_structure(node2)
|
||||
struct1: list[str] = get_structure(node1)
|
||||
struct2: list[str] = get_structure(node2)
|
||||
|
||||
if not struct1 or not struct2:
|
||||
return 0.0
|
||||
|
||||
# Use sequence matcher for structural similarity
|
||||
matcher = difflib.SequenceMatcher(None, struct1, struct2)
|
||||
matcher: difflib.SequenceMatcher[str] = difflib.SequenceMatcher(None, struct1, struct2)
|
||||
return matcher.ratio()
|
||||
|
||||
def _find_pattern_duplicates(self, blocks: list[CodeBlock]) -> list[DuplicateGroup]:
|
||||
"""Find blocks with similar patterns (e.g., similar loops, conditions)."""
|
||||
groups = []
|
||||
pattern_groups = defaultdict(list)
|
||||
groups: list[DuplicateGroup] = []
|
||||
pattern_groups: defaultdict[tuple[str, str], list[CodeBlock]] = defaultdict(list)
|
||||
|
||||
for block in blocks:
|
||||
patterns = self._extract_patterns(block)
|
||||
patterns: list[tuple[str, str]] = self._extract_patterns(block)
|
||||
for pattern_type, pattern_hash in patterns:
|
||||
pattern_groups[(pattern_type, pattern_hash)].append(block)
|
||||
|
||||
for (pattern_type, _), similar_blocks in pattern_groups.items():
|
||||
if len(similar_blocks) > 1:
|
||||
# Calculate token-based similarity
|
||||
total_sim = 0
|
||||
count = 0
|
||||
total_sim: float = 0
|
||||
count: int = 0
|
||||
for i in range(len(similar_blocks)):
|
||||
for j in range(i + 1, len(similar_blocks)):
|
||||
sim = self._token_similarity(
|
||||
sim: float = self._token_similarity(
|
||||
similar_blocks[i].tokens,
|
||||
similar_blocks[j].tokens,
|
||||
)
|
||||
total_sim += sim
|
||||
count += 1
|
||||
|
||||
avg_similarity = total_sim / count if count > 0 else 0.7
|
||||
avg_similarity: float = total_sim / count if count > 0 else 0.7
|
||||
|
||||
if avg_similarity >= self.similarity_threshold:
|
||||
groups.append(
|
||||
@@ -487,28 +487,28 @@ class InternalDuplicateDetector:
 
     def _extract_patterns(self, block: CodeBlock) -> list[tuple[str, str]]:
         """Extract semantic patterns from code block."""
-        patterns = []
+        patterns: list[tuple[str, str]] = []
 
         # Pattern: for-if combination
         if "for " in block.source and "if " in block.source:
-            pattern = re.sub(r"\b\w+\b", "VAR", block.source)
+            pattern: str = re.sub(r"\b\w+\b", "VAR", block.source)
             pattern = re.sub(r"\s+", "", pattern)
             patterns.append(
                 ("loop-condition", hashlib.sha256(pattern.encode()).hexdigest()[:8]),
             )
 
         # Pattern: multiple similar operations
-        operations = re.findall(r"(\w+)\s*[=+\-*/]+\s*(\w+)", block.source)
+        operations: list[tuple[str, ...]] = re.findall(r"(\w+)\s*[=+\-*/]+\s*(\w+)", block.source)
         if len(operations) > 2:
-            op_pattern = "".join(sorted(op[0] for op in operations))
+            op_pattern: str = "".join(sorted(op[0] for op in operations))
             patterns.append(
                 ("repetitive-ops", hashlib.sha256(op_pattern.encode()).hexdigest()[:8]),
             )
 
         # Pattern: similar function calls
-        calls = re.findall(r"(\w+)\s*\([^)]*\)", block.source)
+        calls: list[str] = re.findall(r"(\w+)\s*\([^)]*\)", block.source)
         if len(calls) > 2:
-            call_pattern = "".join(sorted(set(calls)))
+            call_pattern: str = "".join(sorted(set(calls)))
             patterns.append(
                 (
                     "similar-calls",
@@ -524,19 +524,19 @@ class InternalDuplicateDetector:
             return 0.0
 
         # Use Jaccard similarity on token sets
-        set1 = set(tokens1)
-        set2 = set(tokens2)
+        set1: set[str] = set(tokens1)
+        set2: set[str] = set(tokens2)
 
-        intersection = len(set1 & set2)
-        union = len(set1 | set2)
+        intersection: int = len(set1 & set2)
+        union: int = len(set1 | set2)
 
         if union == 0:
             return 0.0
 
-        jaccard = intersection / union
+        jaccard: float = intersection / union
 
         # Also consider sequence similarity
-        sequence_sim = difflib.SequenceMatcher(None, tokens1, tokens2).ratio()
+        sequence_sim: float = difflib.SequenceMatcher(None, tokens1, tokens2).ratio()
 
         # Weighted combination
         return 0.6 * jaccard + 0.4 * sequence_sim
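The metric above can be exercised standalone; a minimal sketch with a hypothetical free-function name, using only the stdlib:

```python
import difflib


def token_similarity(tokens1: list[str], tokens2: list[str]) -> float:
    """Blend set overlap (Jaccard) with order-aware sequence similarity."""
    set1, set2 = set(tokens1), set(tokens2)
    union = len(set1 | set2)
    if union == 0:
        return 0.0
    jaccard = len(set1 & set2) / union
    # SequenceMatcher accepts any hashable sequence, including token lists.
    sequence_sim = difflib.SequenceMatcher(None, tokens1, tokens2).ratio()
    return 0.6 * jaccard + 0.4 * sequence_sim


# Identical token streams score 1.0; fully disjoint ones score 0.0.
print(token_similarity(["a", "b", "c"], ["a", "b", "c"]))
print(token_similarity(["a"], ["b"]))
```

Weighting Jaccard higher rewards shared vocabulary even when statements were reordered, while the sequence term still penalizes shuffled code a little.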
@@ -548,43 +548,43 @@ class InternalDuplicateDetector:
 
         # Check for common dunder methods
         if all(block.name in COMMON_DUPLICATE_METHODS for block in group.blocks):
-            max_lines = max(
+            dunder_max_lines: int = max(
                 block.end_line - block.start_line + 1 for block in group.blocks
             )
-            max_complexity = max(block.complexity for block in group.blocks)
+            dunder_max_complexity: int = max(block.complexity for block in group.blocks)
 
             # Allow simple lifecycle dunder methods to repeat across classes.
-            if max_lines <= 12 and max_complexity <= 3:
+            if dunder_max_lines <= 12 and dunder_max_complexity <= 3:
                 return True
 
         # Check for pytest fixtures - they legitimately have repetitive structure
         if all(block.is_test_fixture() for block in group.blocks):
-            max_lines = max(
+            fixture_max_lines: int = max(
                 block.end_line - block.start_line + 1 for block in group.blocks
             )
             # Allow fixtures up to 15 lines with similar structure
-            if max_lines <= 15:
+            if fixture_max_lines <= 15:
                 return True
 
         # Check for test functions with fixture-like names (data builders, mocks, etc.)
         if all(block.has_test_pattern_name() for block in group.blocks):
-            max_lines = max(
+            pattern_max_lines: int = max(
                 block.end_line - block.start_line + 1 for block in group.blocks
             )
-            max_complexity = max(block.complexity for block in group.blocks)
+            pattern_max_complexity: int = max(block.complexity for block in group.blocks)
             # Allow test helpers that are simple and short
-            if max_lines <= 10 and max_complexity <= 4:
+            if pattern_max_lines <= 10 and pattern_max_complexity <= 4:
                 return True
 
         # Check for simple test functions with arrange-act-assert pattern
         if all(block.is_test_function() for block in group.blocks):
-            max_complexity = max(block.complexity for block in group.blocks)
-            max_lines = max(
+            test_max_complexity: int = max(block.complexity for block in group.blocks)
+            test_max_lines: int = max(
                 block.end_line - block.start_line + 1 for block in group.blocks
             )
             # Simple tests (<=15 lines) often share similar control flow.
             # Permit full similarity for those cases; duplication is acceptable.
-            if max_complexity <= 5 and max_lines <= 15:
+            if test_max_complexity <= 5 and test_max_lines <= 15:
                 return True
 
         return False
@@ -50,7 +50,7 @@ class TypeInferenceHelper:
     ) -> TypeSuggestion | None:
         """Infer the type of a variable from its usage in code."""
         try:
-            tree = ast.parse(textwrap.dedent(source_code))
+            tree: ast.Module = ast.parse(textwrap.dedent(source_code))
         except SyntaxError:
             return None
 
@@ -58,11 +58,13 @@ class TypeInferenceHelper:
         assignments: list[ast.expr] = []
         for node in ast.walk(tree):
             if isinstance(node, ast.Assign):
-                assignments.extend(
+                # Collect value nodes for matching targets
+                matching_values: list[ast.expr] = [
                     node.value
                     for target in node.targets
                     if isinstance(target, ast.Name) and target.id == variable_name
-                )
+                ]
+                assignments.extend(matching_values)
             elif (
                 isinstance(node, ast.AnnAssign)
                 and isinstance(node.target, ast.Name)
@@ -75,8 +77,8 @@ class TypeInferenceHelper:
             return None
 
         # Analyze the first assignment
-        value_node = assignments[0]
-        suggested_type = TypeInferenceHelper._infer_from_node(value_node)
+        value_node: ast.expr = assignments[0]
+        suggested_type: str = TypeInferenceHelper._infer_from_node(value_node)
 
         if suggested_type and suggested_type != "Any":
             return TypeSuggestion(
@@ -94,47 +96,50 @@ class TypeInferenceHelper:
     def _infer_from_node(node: ast.AST) -> str:
         """Infer type from an AST node."""
         if isinstance(node, ast.Constant):
-            value_type = type(node.value).__name__
-            return {
+            value_type: str = type(node.value).__name__
+            type_map: dict[str, str] = {
                 "NoneType": "None",
                 "bool": "bool",
                 "int": "int",
                 "float": "float",
                 "str": "str",
                 "bytes": "bytes",
-            }.get(value_type, "Any")
+            }
+            return type_map.get(value_type, "Any")
 
         if isinstance(node, ast.List):
             if not node.elts:
                 return "list[Any]"
             # Try to infer element type from first element
-            first_type = TypeInferenceHelper._infer_from_node(node.elts[0])
+            first_type: str = TypeInferenceHelper._infer_from_node(node.elts[0])
             return f"list[{first_type}]"
 
         if isinstance(node, ast.Dict):
             if not node.keys or not node.values:
                 return "dict[Any, Any]"
-            first_key = node.keys[0]
+            first_key: ast.expr | None = node.keys[0]
             if first_key is None:
                 return "dict[Any, Any]"
-            key_type = TypeInferenceHelper._infer_from_node(first_key)
-            value_type = TypeInferenceHelper._infer_from_node(node.values[0])
-            return f"dict[{key_type}, {value_type}]"
+            key_type: str = TypeInferenceHelper._infer_from_node(first_key)
+            dict_value_type: str = TypeInferenceHelper._infer_from_node(node.values[0])
+            return f"dict[{key_type}, {dict_value_type}]"
 
         if isinstance(node, ast.Set):
             if not node.elts:
                 return "set[Any]"
-            element_type = TypeInferenceHelper._infer_from_node(node.elts[0])
+            element_type: str = TypeInferenceHelper._infer_from_node(node.elts[0])
             return f"set[{element_type}]"
 
         if isinstance(node, ast.Tuple):
             if not node.elts:
                 return "tuple[()]"
-            types = [TypeInferenceHelper._infer_from_node(e) for e in node.elts]
+            types: list[str] = [
+                TypeInferenceHelper._infer_from_node(e) for e in node.elts
+            ]
             return f"tuple[{', '.join(types)}]"
 
         if isinstance(node, ast.Call):
-            func = node.func
+            func: ast.expr = node.func
             if isinstance(func, ast.Name):
                 # Common constructors
                 if func.id in ("list", "dict", "set", "tuple", "str", "int", "float"):
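The literal-mapping branch can be tried in isolation; a minimal sketch (the helper name is hypothetical, and only the constant and list branches are reproduced):

```python
import ast

# Same mapping as the Constant branch above.
_TYPE_MAP: dict[str, str] = {
    "NoneType": "None", "bool": "bool", "int": "int",
    "float": "float", "str": "str", "bytes": "bytes",
}


def infer_literal_type(expr: str) -> str:
    """Parse one expression and map constant/list literals to a type name."""
    node = ast.parse(expr, mode="eval").body
    if isinstance(node, ast.Constant):
        return _TYPE_MAP.get(type(node.value).__name__, "Any")
    if isinstance(node, ast.List):
        if not node.elts:
            return "list[Any]"
        # Recurse on the first element, as the real code does.
        return f"list[{infer_literal_type(ast.unparse(node.elts[0]))}]"
    return "Any"


print(infer_literal_type("42"))         # int
print(infer_literal_type("[1, 2, 3]"))  # list[int]
```

Inferring from the first element only is a deliberate trade-off: cheap, and usually right for homogeneous literals.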
@@ -166,17 +171,18 @@ class TypeInferenceHelper:
                 if node.value is None:
                     return_types.add("None")
                 else:
-                    inferred = TypeInferenceHelper._infer_from_node(node.value)
+                    inferred: str = TypeInferenceHelper._infer_from_node(node.value)
                     return_types.add(inferred)
 
         if not return_types:
             return_types.add("None")
 
         # Combine multiple return types
+        suggested: str
         if len(return_types) == 1:
             suggested = return_types.pop()
         elif "None" in return_types and len(return_types) == 2:
-            non_none = [t for t in return_types if t != "None"]
+            non_none: list[str] = [t for t in return_types if t != "None"]
             suggested = f"{non_none[0]} | None"
         else:
             suggested = " | ".join(sorted(return_types))
@@ -196,7 +202,7 @@ class TypeInferenceHelper:
         _source_code: str,
     ) -> list[TypeSuggestion]:
         """Suggest types for function parameters based on their usage."""
-        suggestions = []
+        suggestions: list[TypeSuggestion] = []
 
         for arg in function_node.args.args:
             # Skip if already annotated
@@ -208,11 +214,12 @@ class TypeInferenceHelper:
                 continue
 
             # Try to infer from usage within function
-            arg_name = arg.arg
-            if suggested_type := TypeInferenceHelper._infer_param_from_usage(
+            arg_name: str = arg.arg
+            suggested_type: str | None = TypeInferenceHelper._infer_param_from_usage(
                 arg_name,
                 function_node,
-            ):
+            )
+            if suggested_type is not None:
                 suggestions.append(
                     TypeSuggestion(
                         element_name=arg_name,
@@ -221,7 +228,7 @@ class TypeInferenceHelper:
                         confidence=0.6,
                         reason=f"Inferred from usage in {function_node.name}",
                         example=f"{arg_name}: {suggested_type}",
-                    )
+                    ),
                 )
 
         return suggestions
@@ -240,7 +247,7 @@ class TypeInferenceHelper:
                 and node.value.id == param_name
             ):
                 # Parameter has attribute access - likely an object
-                attr_name = node.attr
+                attr_name: str = node.attr
                 # Common patterns
                 if attr_name in (
                     "read",
@@ -290,7 +297,7 @@ class TypeInferenceHelper:
         Returns list of (old_import, new_import, reason) tuples.
         """
         # Patterns to detect and replace
-        patterns = {
+        patterns: dict[str, tuple[str, str, str]] = {
             r"from typing import.*\bUnion\b": (
                 "from typing import Union",
                 "# Use | operator instead (Python 3.10+)",
@@ -332,24 +339,24 @@ class TypeInferenceHelper:
     @staticmethod
     def find_any_usage_with_context(source_code: str) -> list[dict[str, str | int]]:
         """Find usage of typing.Any and provide context for better suggestions."""
-        results = []
+        results: list[dict[str, str | int]] = []
 
         try:
-            tree = ast.parse(textwrap.dedent(source_code))
+            tree: ast.Module = ast.parse(textwrap.dedent(source_code))
         except SyntaxError:
             return results
 
         for node in ast.walk(tree):
             # Find variable annotations with Any
             if isinstance(node, ast.AnnAssign) and TypeInferenceHelper._contains_any(
-                node.annotation
+                node.annotation,
             ):
-                target_name = ""
+                target_name: str = ""
                 if isinstance(node.target, ast.Name):
                     target_name = node.target.id
 
                 # Try to infer better type from value
-                better_type = "Any"
+                better_type: str = "Any"
                 if node.value:
                     better_type = TypeInferenceHelper._infer_from_node(node.value)
 
@@ -360,12 +367,12 @@ class TypeInferenceHelper:
                         "current": "Any",
                         "suggested": better_type,
                         "context": "variable annotation",
-                    }
+                    },
                 )
 
             # Find function parameters with Any
             if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
-                results.extend(
+                param_results: list[dict[str, str | int]] = [
                     {
                         "line": getattr(node, "lineno", 0),
                         "element": arg.arg,
@@ -376,14 +383,19 @@ class TypeInferenceHelper:
                     for arg in node.args.args
                     if arg.annotation
                     and TypeInferenceHelper._contains_any(arg.annotation)
-                )
+                ]
+                results.extend(param_results)
                 # Check return type
                 if node.returns and TypeInferenceHelper._contains_any(node.returns):
-                    suggestion = TypeInferenceHelper.suggest_function_return_type(
-                        node,
-                        source_code
-                    )
-                    suggested_type = suggestion.suggested_type if suggestion else "Any"
+                    suggestion: TypeSuggestion | None = (
+                        TypeInferenceHelper.suggest_function_return_type(
+                            node,
+                            source_code,
+                        )
+                    )
+                    suggested_type: str = (
+                        suggestion.suggested_type if suggestion else "Any"
+                    )
                     results.append(
                         {
                             "line": getattr(node, "lineno", 0),
@@ -391,7 +403,7 @@ class TypeInferenceHelper:
                             "current": "Any",
                             "suggested": suggested_type,
                             "context": "return type",
-                        }
+                        },
                 )
 
         return results
@@ -1,575 +0,0 @@
"""Shell command guard for Claude Code PreToolUse/PostToolUse hooks.

Prevents circumvention of type safety rules via shell commands that could inject
'Any' types or type ignore comments into Python files.
"""

import fcntl
import json
import os
import re
import subprocess
import sys
import tempfile
import time
from contextlib import contextmanager
from pathlib import Path
from shutil import which
from typing import TypedDict

# Handle both relative imports (when run as module) and direct imports (when run as script)
try:
    from .bash_guard_constants import (
        DANGEROUS_SHELL_PATTERNS,
        FORBIDDEN_PATTERNS,
        LOCK_POLL_INTERVAL_SECONDS,
        LOCK_TIMEOUT_SECONDS,
        PYTHON_FILE_PATTERNS,
    )
except ImportError:
    import bash_guard_constants

    DANGEROUS_SHELL_PATTERNS = bash_guard_constants.DANGEROUS_SHELL_PATTERNS
    FORBIDDEN_PATTERNS = bash_guard_constants.FORBIDDEN_PATTERNS
    PYTHON_FILE_PATTERNS = bash_guard_constants.PYTHON_FILE_PATTERNS
    LOCK_TIMEOUT_SECONDS = bash_guard_constants.LOCK_TIMEOUT_SECONDS
    LOCK_POLL_INTERVAL_SECONDS = bash_guard_constants.LOCK_POLL_INTERVAL_SECONDS


class JsonObject(TypedDict, total=False):
    """Type for JSON-like objects."""

    hookEventName: str
    permissionDecision: str
    permissionDecisionReason: str
    decision: str
    reason: str
    systemMessage: str
    hookSpecificOutput: dict[str, object]


# File-based lock for inter-process synchronization
def _get_lock_file() -> Path:
    """Get path to lock file for subprocess serialization."""
    lock_dir = Path(tempfile.gettempdir()) / ".claude_hooks"
    lock_dir.mkdir(exist_ok=True, mode=0o700)
    return lock_dir / "subprocess.lock"


@contextmanager
def _subprocess_lock(timeout: float = LOCK_TIMEOUT_SECONDS):
    """Context manager for file-based subprocess locking.

    Args:
        timeout: Maximum time in seconds to wait for the lock. Non-positive
            values attempt a single non-blocking acquisition.

    Yields:
        True if lock was acquired, False if timeout occurred.
    """
    lock_file = _get_lock_file()
    deadline = (
        time.monotonic() + timeout if timeout and timeout > 0 else None
    )
    acquired = False

    # Open or create lock file
    with open(lock_file, "a") as f:
        try:
            while True:
                try:
                    fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                    acquired = True
                    break
                except (IOError, OSError):
                    if deadline is None:
                        break
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                    time.sleep(min(LOCK_POLL_INTERVAL_SECONDS, remaining))

            yield acquired
        finally:
            if acquired:
                try:
                    fcntl.flock(f.fileno(), fcntl.LOCK_UN)
                except (IOError, OSError):
                    pass
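Since `flock` attaches locks to the open file description, even a second descriptor opened by the same process is refused while the first holds an exclusive lock, which is what the non-blocking loop above polls on. A minimal Unix-only demonstration (the scratch path is arbitrary):

```python
import fcntl
import os
import tempfile

# Scratch lock file; any path writable by the process works.
fd, lock_path = tempfile.mkstemp()
os.close(fd)

f1 = open(lock_path, "a")
fcntl.flock(f1.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)  # first holder wins

blocked = False
f2 = open(lock_path, "a")  # a separate open file description of the same file
try:
    fcntl.flock(f2.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
except OSError:
    blocked = True  # refused with EWOULDBLOCK while f1 holds LOCK_EX

fcntl.flock(f1.fileno(), fcntl.LOCK_UN)
f1.close()
f2.close()
os.unlink(lock_path)
```

Sleeping `min(LOCK_POLL_INTERVAL_SECONDS, remaining)` between retries, as the guard does, bounds both the poll rate and the total wait against the deadline.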


def _contains_forbidden_pattern(text: str) -> tuple[bool, str | None]:
    """Check if text contains any forbidden patterns.

    Args:
        text: The text to check for forbidden patterns.

    Returns:
        Tuple of (has_violation, matched_pattern_description)
    """
    for pattern in FORBIDDEN_PATTERNS:
        if re.search(pattern, text, re.IGNORECASE):
            if "Any" in pattern:
                return True, "typing.Any usage"
            if "type.*ignore" in pattern:
                return True, "type suppression comment"
    return False, None


def _is_dangerous_shell_command(command: str) -> tuple[bool, str | None]:
    """Check if shell command uses dangerous patterns.

    Args:
        command: The shell command to analyze.

    Returns:
        Tuple of (is_dangerous, reason)
    """
    # Check if command targets Python files
    targets_python = any(
        re.search(pattern, command) for pattern in PYTHON_FILE_PATTERNS
    )

    if not targets_python:
        return False, None

    # Allow operations on temporary files (they're not project files)
    temp_dirs = [r"/tmp/", r"/var/tmp/", r"\.tmp/", r"tempfile"]
    if any(re.search(temp_pattern, command) for temp_pattern in temp_dirs):
        return False, None

    # Check for dangerous shell patterns
    for pattern in DANGEROUS_SHELL_PATTERNS:
        if re.search(pattern, command):
            tool_match = re.search(
                r"\b(sed|awk|perl|ed|echo|printf|cat|tee|find|xargs|python|vim|nano|emacs)\b",
                pattern,
            )
            tool_name = tool_match[1] if tool_match else "shell utility"
            return True, f"Use of {tool_name} to modify Python files"

    return False, None
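The two-stage filter above (does the command touch Python files? is it under a temp path? does it use a text-mangling tool?) can be sketched with stand-in patterns. `PY_FILE`, `TEMP_DIRS`, and `DANGEROUS` below are hypothetical examples, since the real `PYTHON_FILE_PATTERNS` and `DANGEROUS_SHELL_PATTERNS` live in `bash_guard_constants`, which this diff does not show:

```python
import re

# Hypothetical stand-ins for the constants module's pattern lists.
PY_FILE = [r"\.pyi?\b"]
TEMP_DIRS = [r"/tmp/", r"/var/tmp/"]
DANGEROUS = [r"\bsed\s+-i\b", r"\btee\b", r"\becho\b.*>"]


def is_dangerous(command: str) -> bool:
    """Flag shell commands that edit Python files with text-mangling tools."""
    if not any(re.search(p, command) for p in PY_FILE):
        return False  # does not touch Python files at all
    if any(re.search(p, command) for p in TEMP_DIRS):
        return False  # scratch files are exempt, as in the guard above
    return any(re.search(p, command) for p in DANGEROUS)
```

Ordering the cheap negative checks first means most benign commands never reach the pattern scan.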


def _command_contains_forbidden_injection(command: str) -> tuple[bool, str | None]:
    """Check if command attempts to inject forbidden patterns.

    Args:
        command: The shell command to analyze.

    Returns:
        Tuple of (has_injection, violation_description)
    """
    # Check if the command itself contains forbidden patterns
    has_violation, violation_type = _contains_forbidden_pattern(command)

    if has_violation:
        return True, violation_type

    # Check for encoded or escaped patterns
    # Handle common escape sequences
    decoded_cmd = command.replace("\\n", "\n").replace("\\t", "\t")
    decoded_cmd = re.sub(r"\\\s", " ", decoded_cmd)

    has_violation, violation_type = _contains_forbidden_pattern(decoded_cmd)
    if has_violation:
        return True, f"{violation_type} (escaped)"

    return False, None


def _analyze_bash_command(command: str) -> tuple[bool, list[str]]:
    """Analyze bash command for safety violations.

    Args:
        command: The bash command to analyze.

    Returns:
        Tuple of (should_block, list_of_violations)
    """
    violations: list[str] = []

    # Check for forbidden pattern injection
    has_injection, injection_type = _command_contains_forbidden_injection(command)
    if has_injection:
        violations.append(f"⛔ Shell command attempts to inject {injection_type}")

    # Check for dangerous shell patterns on Python files
    is_dangerous, danger_reason = _is_dangerous_shell_command(command)
    if is_dangerous:
        violations.append(
            f"⛔ {danger_reason} is forbidden - use Edit/Write tools instead",
        )

    return len(violations) > 0, violations


def _create_hook_response(
    event_name: str,
    permission: str = "",
    reason: str = "",
    system_message: str = "",
    *,
    decision: str | None = None,
) -> JsonObject:
    """Create standardized hook response.

    Args:
        event_name: Name of the hook event (PreToolUse, PostToolUse, Stop).
        permission: Permission decision (allow, deny, ask).
        reason: Reason for the decision.
        system_message: System message to display.
        decision: Decision for PostToolUse/Stop hooks (approve, block).

    Returns:
        JSON response object for the hook.
    """
    hook_output: dict[str, object] = {
        "hookEventName": event_name,
    }

    if permission:
        hook_output["permissionDecision"] = permission
    if reason:
        hook_output["permissionDecisionReason"] = reason

    response: JsonObject = {
        "hookSpecificOutput": hook_output,
    }

    if permission:
        response["permissionDecision"] = permission

    if decision:
        response["decision"] = decision

    if reason:
        response["reason"] = reason

    if system_message:
        response["systemMessage"] = system_message

    return response
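For reference, the deny path of a PreToolUse hook produces JSON shaped like the helper above builds; a standalone sketch (not the module itself, mirroring only the fields shown in this diff):

```python
import json


def deny_response(reason: str) -> dict[str, object]:
    """Build a PreToolUse deny payload with the fields duplicated as above."""
    return {
        "hookSpecificOutput": {
            "hookEventName": "PreToolUse",
            "permissionDecision": "deny",
            "permissionDecisionReason": reason,
        },
        # Top-level copies, as _create_hook_response also emits them.
        "permissionDecision": "deny",
        "reason": reason,
    }


payload = json.dumps(deny_response("typing.Any usage"))
```

Serializing once and writing it to stdout with a trailing newline matches how the hook's `main()` emits responses.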


def pretooluse_bash_hook(hook_data: dict[str, object]) -> JsonObject:
    """Handle PreToolUse hook for Bash commands.

    Args:
        hook_data: Hook input data containing tool_name and tool_input.

    Returns:
        Hook response with permission decision.
    """
    tool_name = str(hook_data.get("tool_name", ""))

    # Only analyze Bash commands
    if tool_name != "Bash":
        return _create_hook_response("PreToolUse", "allow")

    tool_input_raw = hook_data.get("tool_input", {})
    if not isinstance(tool_input_raw, dict):
        return _create_hook_response("PreToolUse", "allow")

    tool_input: dict[str, object] = dict(tool_input_raw)
    command = str(tool_input.get("command", ""))

    if not command:
        return _create_hook_response("PreToolUse", "allow")

    # Analyze command for violations
    should_block, violations = _analyze_bash_command(command)

    if not should_block:
        return _create_hook_response("PreToolUse", "allow")

    # Build denial message
    violation_text = "\n".join(f" {v}" for v in violations)
    message = (
        f"🚫 Shell Command Blocked\n\n"
        f"Violations:\n{violation_text}\n\n"
        f"Command: {command[:200]}{'...' if len(command) > 200 else ''}\n\n"
        f"Use Edit/Write tools to modify Python files with proper type safety."
    )

    return _create_hook_response(
        "PreToolUse",
        "deny",
        message,
        message,
    )


def posttooluse_bash_hook(hook_data: dict[str, object]) -> JsonObject:
    """Handle PostToolUse hook for Bash commands.

    Args:
        hook_data: Hook output data containing tool_response.

    Returns:
        Hook response with decision.
    """
    tool_name = str(hook_data.get("tool_name", ""))

    # Only analyze Bash commands
    if tool_name != "Bash":
        return _create_hook_response("PostToolUse")

    # Extract command from hook data
    tool_input_raw = hook_data.get("tool_input", {})
    if not isinstance(tool_input_raw, dict):
        return _create_hook_response("PostToolUse")

    tool_input: dict[str, object] = dict(tool_input_raw)
    command = str(tool_input.get("command", ""))

    # Check if command modified any Python files
    # Look for file paths in the command
    python_files: list[str] = []
    for match in re.finditer(r"([^\s]+\.pyi?)\b", command):
        file_path = match.group(1)
        if Path(file_path).exists():
            python_files.append(file_path)

    if not python_files:
        return _create_hook_response("PostToolUse")

    # Scan modified files for violations
    violations: list[str] = []
    for file_path in python_files:
        try:
            with open(file_path, encoding="utf-8") as file_handle:
                content = file_handle.read()

            has_violation, violation_type = _contains_forbidden_pattern(content)
            if has_violation:
                violations.append(
                    f"⛔ File '{Path(file_path).name}' contains {violation_type}",
                )
        except (OSError, UnicodeDecodeError):
            # If we can't read the file, skip it
            continue

    if violations:
        violation_text = "\n".join(f" {v}" for v in violations)
        message = (
            f"🚫 Post-Execution Violation Detected\n\n"
            f"Violations:\n{violation_text}\n\n"
            f"Shell command introduced forbidden patterns. "
            f"Please revert changes and use proper typing."
        )

        return _create_hook_response(
            "PostToolUse",
            "",
            message,
            message,
            decision="block",
        )

    return _create_hook_response("PostToolUse")


def _get_staged_python_files() -> list[str]:
    """Get list of staged Python files from git.

    Returns:
        List of file paths that are staged and end with .py or .pyi
    """
    git_path = which("git")
    if git_path is None:
        return []

    try:
        # Acquire file-based lock to prevent subprocess concurrency issues
        with _subprocess_lock(timeout=LOCK_TIMEOUT_SECONDS) as acquired:
            if not acquired:
                return []

            # Safe: invokes git with fixed arguments, no user input interpolation.
            result = subprocess.run(  # noqa: S603
                [git_path, "diff", "--name-only", "--cached"],
                capture_output=True,
                text=True,
                check=False,
                timeout=10,
            )

            if result.returncode != 0:
                return []

            return [
                file_name.strip()
                for file_name in result.stdout.split("\n")
                if file_name.strip() and file_name.strip().endswith((".py", ".pyi"))
            ]
    except (OSError, subprocess.SubprocessError, TimeoutError):
        return []


def _check_files_for_violations(file_paths: list[str]) -> list[str]:
    """Scan files for forbidden patterns.

    Args:
        file_paths: List of file paths to check.

    Returns:
        List of violation messages.
    """
    violations: list[str] = []

    for file_path in file_paths:
        if not Path(file_path).exists():
            continue

        try:
            with open(file_path, encoding="utf-8") as file_handle:
                content = file_handle.read()

            has_violation, violation_type = _contains_forbidden_pattern(content)
            if has_violation:
                violations.append(f"⛔ {file_path}: {violation_type}")
        except (OSError, UnicodeDecodeError):
            continue

    return violations


def stop_hook(_hook_data: dict[str, object]) -> JsonObject:
    """Handle Stop hook - final validation before completion.

    Args:
        _hook_data: Stop hook data (unused).

    Returns:
        Hook response with decision.
    """
    # Get list of changed files from git
    try:
        changed_files = _get_staged_python_files()
        if not changed_files:
            return _create_hook_response("Stop", decision="approve")

        if violations := _check_files_for_violations(changed_files):
            violation_text = "\n".join(f" {v}" for v in violations)
            message = (
                f"🚫 Final Validation Failed\n\n"
                f"Violations:\n{violation_text}\n\n"
                f"Please remove forbidden patterns before completing."
            )

            return _create_hook_response(
                "Stop",
                "",
                message,
                message,
                decision="block",
            )

        return _create_hook_response("Stop", decision="approve")

    except (OSError, subprocess.SubprocessError, TimeoutError) as exc:
        # If validation fails, allow but warn
        return _create_hook_response(
            "Stop",
            "",
            f"Warning: Final validation error: {exc}",
            f"Warning: Final validation error: {exc}",
            decision="approve",
        )


def _handle_hook_exit_code(response: JsonObject) -> None:
    """Handle exit codes based on hook response.

    Args:
        response: Hook response object.
    """
    hook_output_raw = response.get("hookSpecificOutput", {})
    if not hook_output_raw or not isinstance(hook_output_raw, dict):
        return

    hook_output: dict[str, object] = hook_output_raw
    permission_decision = hook_output.get("permissionDecision")

    if permission_decision == "deny":
        # Exit code 2: Blocking error
        reason = str(
            hook_output.get("permissionDecisionReason", "Permission denied"),
        )
        sys.stderr.write(reason)
        sys.stderr.flush()
        sys.exit(2)

    if permission_decision == "ask":
        # Exit code 2 for ask decisions
        reason = str(
            hook_output.get("permissionDecisionReason", "Permission request"),
        )
        sys.stderr.write(reason)
        sys.stderr.flush()
        sys.exit(2)

    # Check for Stop hook block decision
    if response.get("decision") == "block":
        reason = str(response.get("reason", "Validation failed"))
        sys.stderr.write(reason)
        sys.stderr.flush()
        sys.exit(2)


def _detect_hook_type(hook_data: dict[str, object]) -> JsonObject:
    """Detect hook type and route to appropriate handler.

    Args:
        hook_data: Hook input data.

    Returns:
        Hook response object.
    """
    if "tool_response" in hook_data or "tool_output" in hook_data:
        return posttooluse_bash_hook(hook_data)

    if hook_data.get("hookEventName") == "Stop":
        return stop_hook(hook_data)

    return pretooluse_bash_hook(hook_data)


def main() -> None:
    """Main hook entry point."""
    try:
        # Read hook input from stdin
        try:
            hook_data: dict[str, object] = json.load(sys.stdin)
        except json.JSONDecodeError:
            fallback_response: JsonObject = {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "allow",
                },
            }
            sys.stdout.write(json.dumps(fallback_response))
            sys.stdout.write("\n")
            sys.stdout.flush()
            return

        # Detect hook type and get response
        response = _detect_hook_type(hook_data)

        # Write response to stdout with explicit flush
        sys.stdout.write(json.dumps(response))
        sys.stdout.write("\n")
        sys.stdout.flush()

        # Handle exit codes
        _handle_hook_exit_code(response)

    except (OSError, ValueError, subprocess.SubprocessError, TimeoutError) as exc:
        # Unexpected error - use exit code 1 (non-blocking error)
        sys.stderr.write(f"Hook error: {exc}")
        sys.stderr.flush()
        sys.exit(1)


if __name__ == "__main__":
    main()
@@ -6,7 +6,7 @@
       "hooks": [
         {
           "type": "command",
-          "command": "cd $CLAUDE_PROJECT_DIR/hooks && (python hook_chain.py --event pre 2>/dev/null || python3 hook_chain.py --event pre)"
+          "command": "cd $CLAUDE_PROJECT_DIR/hooks && python3 cli.py --event pre"
         }
       ]
     }
@@ -17,7 +17,18 @@
       "hooks": [
         {
           "type": "command",
-          "command": "cd $CLAUDE_PROJECT_DIR/hooks && (python hook_chain.py --event post 2>/dev/null || python3 hook_chain.py --event post)"
+          "command": "cd $CLAUDE_PROJECT_DIR/hooks && python3 cli.py --event post"
         }
       ]
     }
   ],
+  "Stop": [
+    {
+      "matcher": "",
+      "hooks": [
+        {
+          "type": "command",
+          "command": "cd $CLAUDE_PROJECT_DIR/hooks && python3 cli.py --event stop"
+        }
+      ]
+    }
139 hooks/cli.py Executable file
@@ -0,0 +1,139 @@
|
||||
#!/usr/bin/env python3
"""CLI entry point for Claude Code hooks.

This script serves as the single command invoked by Claude Code for all hook
events (PreToolUse, PostToolUse, Stop). It reads JSON from stdin, routes to
the appropriate handler, and outputs the response.

Usage:
    echo '{"tool_name": "Write", ...}' | python hooks/cli.py --event pre
    echo '{"tool_name": "Bash", ...}' | python hooks/cli.py --event post
"""

import argparse
import json
import sys
from pathlib import Path
from typing import TypeGuard

from pydantic import BaseModel, ValidationError

# Try relative import first (when run as module), fall back to path manipulation
try:
    from .facade import Guards
except ImportError:
    # Add parent directory to path for imports (when run as script)
    sys.path.insert(0, str(Path(__file__).parent))
    from facade import Guards


class PayloadValidator(BaseModel):
    """Validates and normalizes JSON payload at boundary."""

    tool_name: str = ""
    tool_input: dict[str, object] = {}
    tool_response: object = None
    tool_output: object = None

    class Config:
        """Pydantic config."""

        extra = "ignore"


def _is_dict(value: object) -> TypeGuard[dict[str, object]]:
    """Type guard to narrow dict values."""
    return isinstance(value, dict)


def _normalize_dict(data: object) -> dict[str, object]:
    """Normalize untyped dict to dict[str, object] using Pydantic validation.

    This converts JSON-deserialized data (which has Unknown types) to a
    strongly-typed dict using Pydantic at the boundary.
    """
    try:
        if not isinstance(data, dict):
            return {}
        validated = PayloadValidator.model_validate(data)
        return validated.model_dump(exclude_none=True)
    except ValidationError:
        return {}


def main() -> None:
    """Main CLI entry point for hook processing."""
    parser = argparse.ArgumentParser(description="Claude Code unified hook handler")
    parser.add_argument(
        "--event",
        choices={"pre", "post", "stop"},
        required=True,
        help="Hook event type to handle",
    )
    args = parser.parse_args()

    try:
        # Read hook payload from stdin
        raw_input = sys.stdin.read()
        if not raw_input.strip():
            # Empty input - return default response
            payload: dict[str, object] = {}
        else:
            try:
                parsed = json.loads(raw_input)
                payload = _normalize_dict(parsed)
            except json.JSONDecodeError:
                # Invalid JSON - return default response
                payload = {}

        # Initialize guards and route to appropriate handler
        guards = Guards()

        if args.event == "pre":
            response = guards.handle_pretooluse(payload)
        elif args.event == "post":
            response = guards.handle_posttooluse(payload)
        else:  # stop
            response = guards.handle_stop(payload)

        # Output response as JSON
        sys.stdout.write(json.dumps(response))
        sys.stdout.write("\n")
        sys.stdout.flush()

        # Check if we should exit with error code
        hook_output = response.get("hookSpecificOutput", {})
        if _is_dict(hook_output):
            permission = hook_output.get("permissionDecision")
            if permission == "deny":
                reason = hook_output.get(
                    "permissionDecisionReason", "Permission denied",
                )
                sys.stderr.write(str(reason))
                sys.stderr.flush()
                sys.exit(2)

            if permission == "ask":
                reason = hook_output.get(
                    "permissionDecisionReason", "Permission request",
                )
                sys.stderr.write(str(reason))
                sys.stderr.flush()
                sys.exit(2)

        # Check for block decision
        if response.get("decision") == "block":
            reason = response.get("reason", "Validation failed")
            sys.stderr.write(str(reason))
            sys.stderr.flush()
            sys.exit(2)

    except (KeyError, ValueError, TypeError, OSError, RuntimeError) as exc:
        # Unexpected error - log but don't block
        sys.stderr.write(f"Hook error: {exc}\n")
        sys.stderr.flush()
        sys.exit(1)


if __name__ == "__main__":
    main()
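The new cli.py validates its stdin payload with Pydantic before routing. The same boundary contract can be sketched with the standard library alone (function and key names below are illustrative, not part of this diff): malformed or non-dict JSON collapses to `{}`, and unknown keys are dropped, so downstream guards only ever see a well-shaped dict.

```python
import json

# Keys the hook payload may carry (mirrors PayloadValidator's fields).
_KNOWN_KEYS = {"tool_name", "tool_input", "tool_response", "tool_output"}


def normalize_payload(raw: str) -> dict[str, object]:
    """Parse stdin text into a payload dict, returning {} on any malformed input."""
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return {}
    if not isinstance(parsed, dict):
        return {}
    # Keep only known keys so downstream code sees a stable shape.
    payload: dict[str, object] = {k: v for k, v in parsed.items() if k in _KNOWN_KEYS}
    if not isinstance(payload.get("tool_name", ""), str):
        # Wrong type for a known field: reject the whole payload, like the validator.
        return {}
    return payload


print(normalize_payload('{"tool_name": "Bash", "extra": 1}'))
```

Either way, the guards never have to defend against a `None` or list payload.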
File diff suppressed because it is too large
215 hooks/facade.py Normal file
@@ -0,0 +1,215 @@
"""Unified facade for Claude Code hooks with zero concurrency issues.

This module provides a single, well-organized entry point for all Claude Code
hooks (PreToolUse, PostToolUse, Stop) with built-in protection against concurrency
errors through file-based locking and sequential execution.
"""

import sys
from pathlib import Path
from typing import TypeGuard

# Handle both relative (module) and absolute (script) imports
try:
    from .guards import BashCommandGuard, CodeQualityGuard
    from .lock_manager import LockManager
    from .models import HookResponse
except ImportError:
    # Fallback for script execution
    sys.path.insert(0, str(Path(__file__).parent))
    from guards import BashCommandGuard, CodeQualityGuard
    from lock_manager import LockManager
    from models import HookResponse


def _is_hook_output(value: object) -> TypeGuard[dict[str, object]]:
    """Type guard to safely narrow hook output values."""
    return isinstance(value, dict)


class Guards:
    """Unified hook system for Claude Code with concurrency-safe execution.

    This facade coordinates all guard validations through a single entry point,
    ensuring sequential execution and atomic locking to prevent race conditions.

    Example:
        ```python
        from hooks import Guards

        guards = Guards()
        payload = json.load(sys.stdin)
        response = guards.handle_pretooluse(payload)
        print(json.dumps(response))
        ```
    """

    def __init__(self) -> None:
        """Initialize guards with their dependencies."""
        self._bash_guard = BashCommandGuard()
        self._quality_guard = CodeQualityGuard()

    def handle_pretooluse(self, payload: dict[str, object]) -> HookResponse:
        """Handle PreToolUse hook events sequentially.

        Executes guards in order with file-based locking to prevent
        concurrent execution issues. Short-circuits on first denial.

        Args:
            payload: Hook payload from Claude Code containing tool metadata.

        Returns:
            Hook response with permission decision (allow/deny/ask).
        """
        # Acquire lock to prevent concurrent processing
        with LockManager.acquire(timeout=10.0) as acquired:
            if not acquired:
                # Lock timeout - return default allow to not block user
                return self._default_response("PreToolUse", "allow")

            # Execute guards sequentially
            tool_name = str(payload.get("tool_name", ""))

            # Bash commands: check for type safety violations
            if tool_name == "Bash":
                response = self._bash_guard.pretooluse(payload)
                # Short-circuit if denied
                hook_output = response.get("hookSpecificOutput")
                if _is_hook_output(hook_output):
                    decision = hook_output.get("permissionDecision")
                    if decision == "deny":
                        return response

            # Code writes: check for duplicates, complexity, modernization
            if tool_name in {"Write", "Edit", "MultiEdit"}:
                response = self._quality_guard.pretooluse(payload)
                # Short-circuit if denied
                hook_output = response.get("hookSpecificOutput")
                if _is_hook_output(hook_output):
                    decision = hook_output.get("permissionDecision")
                    if decision == "deny":
                        return response

            # All guards passed
            return self._default_response("PreToolUse", "allow")

    def handle_posttooluse(self, payload: dict[str, object]) -> HookResponse:
        """Handle PostToolUse hook events sequentially.

        Verifies code quality after writes and logs bash commands.
        Executes guards with file-based locking for safety.

        Args:
            payload: Hook payload from Claude Code containing tool results.

        Returns:
            Hook response with verification decision (approve/block).
        """
        # Acquire lock to prevent concurrent processing
        with LockManager.acquire(timeout=10.0) as acquired:
            if not acquired:
                # Lock timeout - return default approval
                return self._default_response("PostToolUse")

            tool_name = str(payload.get("tool_name", ""))

            # Bash: verify no violations were introduced + log command
            if tool_name == "Bash":
                response = self._bash_guard.posttooluse(payload)
                # Block if violations detected
                if response.get("decision") == "block":
                    return response
                # Log successful command
                self._log_bash_command(payload)

            # Code writes: verify quality post-write
            if tool_name in {"Write", "Edit", "MultiEdit"}:
                response = self._quality_guard.posttooluse(payload)
                # Block if violations detected
                if response.get("decision") == "block":
                    return response

            # All verifications passed
            return self._default_response("PostToolUse")

    def handle_stop(self, payload: dict[str, object]) -> HookResponse:
        """Handle Stop hook for final validation.

        Runs final checks before task completion with file locking.

        Args:
            payload: Stop hook payload (minimal data).

        Returns:
            Hook response with approval/block decision.
        """
        # Acquire lock for final validation
        with LockManager.acquire(timeout=10.0) as acquired:
            if not acquired:
                # Lock timeout - allow completion
                return self._default_response("Stop", decision="approve")

            # Bash guard can do final validation on staged files
            return self._bash_guard.stop(payload)

    @staticmethod
    def _default_response(
        event_name: str,
        permission: str = "",
        decision: str = "",
    ) -> HookResponse:
        """Create a default pass-through response.

        Args:
            event_name: Hook event name (PreToolUse, PostToolUse, Stop).
            permission: Permission for PreToolUse (allow/deny/ask).
            decision: Decision for PostToolUse/Stop (approve/block).

        Returns:
            Standard hook response.
        """
        hook_output: dict[str, object] = {"hookEventName": event_name}

        if permission:
            hook_output["permissionDecision"] = permission

        response: HookResponse = {"hookSpecificOutput": hook_output}

        if permission:
            response["permissionDecision"] = permission

        if decision:
            response["decision"] = decision

        return response

    @staticmethod
    def _log_bash_command(payload: dict[str, object]) -> None:
        """Log successful bash commands to audit trail.

        Args:
            payload: Hook payload containing command details.
        """
        tool_input = payload.get("tool_input")
        if not _is_hook_output(tool_input):
            return

        command = tool_input.get("command")
        if not isinstance(command, str) or not command.strip():
            return

        description_raw = tool_input.get("description")
        description = (
            description_raw
            if isinstance(description_raw, str) and description_raw.strip()
            else "No description"
        )

        log_path = Path.home() / ".claude" / "bash-command-log.txt"
        try:
            log_path.parent.mkdir(parents=True, exist_ok=True)
            with log_path.open("a", encoding="utf-8") as handle:
                handle.write(f"{command} - {description}\n")
        except OSError:
            # Logging is best-effort; ignore filesystem errors
            pass
6 hooks/guards/__init__.py Normal file
@@ -0,0 +1,6 @@
"""Guard implementations for Claude Code hook validation."""

from .bash_guard import BashCommandGuard
from .quality_guard import CodeQualityGuard

__all__ = ["BashCommandGuard", "CodeQualityGuard"]
445 hooks/guards/bash_guard.py Normal file
@@ -0,0 +1,445 @@
"""Shell command guard for Claude Code PreToolUse/PostToolUse hooks.

Prevents circumvention of type safety rules via shell commands that could inject
'Any' types or type ignore comments into Python files.
"""

import re
import subprocess
import sys
from pathlib import Path
from shutil import which

from pydantic import BaseModel, ValidationError

# Handle both relative (module) and absolute (script) imports
try:
    from ..lock_manager import LockManager
    from ..models import HookResponse
    from .bash_guard_constants import (
        DANGEROUS_SHELL_PATTERNS,
        FORBIDDEN_PATTERNS,
        PYTHON_FILE_PATTERNS,
        TEMPORARY_DIR_PATTERNS,
    )
except ImportError:
    # Fallback for script execution
    sys.path.insert(0, str(Path(__file__).parent))
    sys.path.insert(0, str(Path(__file__).parent.parent))
    from bash_guard_constants import (
        DANGEROUS_SHELL_PATTERNS,
        FORBIDDEN_PATTERNS,
        PYTHON_FILE_PATTERNS,
        TEMPORARY_DIR_PATTERNS,
    )
    from lock_manager import LockManager
    from models import HookResponse


class ToolInputValidator(BaseModel):
    """Validates and normalizes tool_input at boundary."""

    command: str = ""
    description: str = ""

    class Config:
        """Pydantic config."""

        extra = "ignore"


def _normalize_tool_input(data: object) -> dict[str, object]:
    """Normalize tool_input to dict[str, object] using Pydantic validation.

    Converts untyped dict from JSON deserialization to strongly-typed dict
    by validating structure at the boundary.
    """
    try:
        if not isinstance(data, dict):
            return {}
        validated = ToolInputValidator.model_validate(data)
        return validated.model_dump(exclude_none=True)
    except ValidationError:
        return {}


class BashCommandGuard:
    """Validates bash commands for type safety violations."""

    @staticmethod
    def _contains_forbidden_pattern(text: str) -> tuple[bool, str | None]:
        """Check if text contains any forbidden patterns.

        Args:
            text: The text to check for forbidden patterns.

        Returns:
            Tuple of (has_violation, matched_pattern_description)
        """
        for pattern in FORBIDDEN_PATTERNS:
            if re.search(pattern, text, re.IGNORECASE):
                if "Any" in pattern:
                    return True, "typing.Any usage"
                if "type.*ignore" in pattern:
                    return True, "type suppression comment"
        return False, None

    @staticmethod
    def _is_dangerous_shell_command(command: str) -> tuple[bool, str | None]:
        """Check if shell command uses dangerous patterns.

        Args:
            command: The shell command to analyze.

        Returns:
            Tuple of (is_dangerous, reason)
        """
        # Check if command targets Python files
        targets_python = any(
            re.search(pattern, command) for pattern in PYTHON_FILE_PATTERNS
        )

        if not targets_python:
            return False, None

        # Allow operations on temporary files (they're not project files)
        if any(re.search(pattern, command) for pattern in TEMPORARY_DIR_PATTERNS):
            return False, None

        # Check for dangerous shell patterns
        for pattern in DANGEROUS_SHELL_PATTERNS:
            if re.search(pattern, command):
                tool_match = re.search(
                    r"\b(sed|awk|perl|ed|echo|printf|cat|tee|find|xargs|python|vim|nano|emacs)\b",
                    pattern,
                )
                tool_name = tool_match[1] if tool_match else "shell utility"
                return True, f"Use of {tool_name} to modify Python files"

        return False, None

    @staticmethod
    def _command_contains_forbidden_injection(command: str) -> tuple[bool, str | None]:
        """Check if command attempts to inject forbidden patterns.

        Args:
            command: The shell command to analyze.

        Returns:
            Tuple of (has_injection, violation_description)
        """
        # Check if the command itself contains forbidden patterns
        has_violation, violation_type = BashCommandGuard._contains_forbidden_pattern(
            command,
        )

        if has_violation:
            return True, violation_type

        # Check for encoded or escaped patterns
        decoded_cmd = command.replace("\\n", "\n").replace("\\t", "\t")
        decoded_cmd = re.sub(r"\\\s", " ", decoded_cmd)

        has_violation, violation_type = BashCommandGuard._contains_forbidden_pattern(
            decoded_cmd,
        )
        if has_violation:
            return True, f"{violation_type} (escaped)"

        return False, None

    @staticmethod
    def _analyze_bash_command(command: str) -> tuple[bool, list[str]]:
        """Analyze bash command for safety violations.

        Args:
            command: The bash command to analyze.

        Returns:
            Tuple of (should_block, list_of_violations)
        """
        violations: list[str] = []

        # Check for forbidden pattern injection
        has_injection, injection_type = (
            BashCommandGuard._command_contains_forbidden_injection(command)
        )
        if has_injection:
            violations.append(f"⛔ Shell command attempts to inject {injection_type}")

        # Check for dangerous shell patterns on Python files
        is_dangerous, danger_reason = BashCommandGuard._is_dangerous_shell_command(
            command,
        )
        if is_dangerous:
            violations.append(
                f"⛔ {danger_reason} is forbidden - use Edit/Write tools instead",
            )

        return len(violations) > 0, violations

    @staticmethod
    def _create_hook_response(
        event_name: str,
        permission: str = "",
        reason: str = "",
        system_message: str = "",
        *,
        decision: str | None = None,
    ) -> HookResponse:
        """Create standardized hook response.

        Args:
            event_name: Name of the hook event (PreToolUse, PostToolUse, Stop).
            permission: Permission decision (allow, deny, ask).
            reason: Reason for the decision.
            system_message: System message to display.
            decision: Decision for PostToolUse/Stop hooks (approve, block).

        Returns:
            JSON response object for the hook.
        """
        hook_output: dict[str, object] = {
            "hookEventName": event_name,
        }

        if permission:
            hook_output["permissionDecision"] = permission
        if reason:
            hook_output["permissionDecisionReason"] = reason

        response: HookResponse = {
            "hookSpecificOutput": hook_output,
        }

        if permission:
            response["permissionDecision"] = permission

        if decision:
            response["decision"] = decision

        if reason:
            response["reason"] = reason

        if system_message:
            response["systemMessage"] = system_message

        return response

    def pretooluse(self, hook_data: dict[str, object]) -> HookResponse:
        """Handle PreToolUse hook for Bash commands.

        Args:
            hook_data: Hook input data containing tool_name and tool_input.

        Returns:
            Hook response with permission decision.
        """
        tool_name = str(hook_data.get("tool_name", ""))

        # Only analyze Bash commands
        if tool_name != "Bash":
            return self._create_hook_response("PreToolUse", "allow")

        tool_input_raw = hook_data.get("tool_input", {})
        tool_input = _normalize_tool_input(tool_input_raw)
        command = str(tool_input.get("command", ""))

        if not command:
            return self._create_hook_response("PreToolUse", "allow")

        # Analyze command for violations
        should_block, violations = self._analyze_bash_command(command)

        if not should_block:
            return self._create_hook_response("PreToolUse", "allow")

        # Build denial message
        violation_text = "\n".join(f"  {v}" for v in violations)
        message = (
            f"🚫 Shell Command Blocked\n\n"
            f"Violations:\n{violation_text}\n\n"
            f"Command: {command[:200]}{'...' if len(command) > 200 else ''}\n\n"
            f"Use Edit/Write tools to modify Python files with proper type safety."
        )

        return self._create_hook_response(
            "PreToolUse",
            "deny",
            message,
            message,
        )

    def posttooluse(self, hook_data: dict[str, object]) -> HookResponse:
        """Handle PostToolUse hook for Bash commands.

        Args:
            hook_data: Hook output data containing tool_response.

        Returns:
            Hook response with decision.
        """
        tool_name = str(hook_data.get("tool_name", ""))

        # Only analyze Bash commands
        if tool_name != "Bash":
            return self._create_hook_response("PostToolUse")

        # Extract command from hook data
        tool_input_raw = hook_data.get("tool_input", {})
        tool_input = _normalize_tool_input(tool_input_raw)
        command = str(tool_input.get("command", ""))

        # Check if command modified any Python files
        python_files: list[str] = []
        for match in re.finditer(r"([^\s]+\.pyi?)\b", command):
            file_path = match.group(1)
            if Path(file_path).exists():
                python_files.append(file_path)

        if not python_files:
            return self._create_hook_response("PostToolUse")

        # Scan modified files for violations
        violations: list[str] = []
        for file_path in python_files:
            try:
                with open(file_path, encoding="utf-8") as file_handle:
                    content = file_handle.read()

                has_violation, violation_type = self._contains_forbidden_pattern(
                    content,
                )
                if has_violation:
                    violations.append(
                        f"⛔ File '{Path(file_path).name}' contains {violation_type}",
                    )
            except (OSError, UnicodeDecodeError):
                continue

        if violations:
            violation_text = "\n".join(f"  {v}" for v in violations)
            message = (
                f"🚫 Post-Execution Violation Detected\n\n"
                f"Violations:\n{violation_text}\n\n"
                f"Shell command introduced forbidden patterns. "
                f"Please revert changes and use proper typing."
            )

            return self._create_hook_response(
                "PostToolUse",
                "",
                message,
                message,
                decision="block",
            )

        return self._create_hook_response("PostToolUse")

    def _get_staged_python_files(self) -> list[str]:
        """Get list of staged Python files from git.

        Returns:
            List of file paths that are staged and end with .py or .pyi
        """
        git_path = which("git")
        if git_path is None:
            return []

        try:
            # Acquire file-based lock to prevent subprocess concurrency issues
            with LockManager.acquire(timeout=10.0) as acquired:
                if not acquired:
                    return []

                # Safe: invokes git with fixed arguments, no user input interpolation.
                result = subprocess.run(  # noqa: S603
                    [git_path, "diff", "--name-only", "--cached"],
                    capture_output=True,
                    text=True,
                    check=False,
                    timeout=10,
                )

                if result.returncode != 0:
                    return []

                return [
                    file_name.strip()
                    for file_name in result.stdout.split("\n")
                    if file_name.strip() and file_name.strip().endswith((".py", ".pyi"))
                ]
        except (OSError, subprocess.SubprocessError, TimeoutError):
            return []

    def _check_files_for_violations(self, file_paths: list[str]) -> list[str]:
        """Scan files for forbidden patterns.

        Args:
            file_paths: List of file paths to check.

        Returns:
            List of violation messages.
        """
        violations: list[str] = []

        for file_path in file_paths:
            if not Path(file_path).exists():
                continue

            try:
                with open(file_path, encoding="utf-8") as file_handle:
                    content = file_handle.read()

                has_violation, violation_type = self._contains_forbidden_pattern(
                    content,
                )
                if has_violation:
                    violations.append(f"⛔ {file_path}: {violation_type}")
            except (OSError, UnicodeDecodeError):
                continue

        return violations

    def stop(self, _hook_data: dict[str, object]) -> HookResponse:
        """Handle Stop hook - final validation before completion.

        Args:
            _hook_data: Stop hook data (unused).

        Returns:
            Hook response with decision.
        """
        # Get list of changed files from git
        try:
            changed_files = self._get_staged_python_files()
            if not changed_files:
                return self._create_hook_response("Stop", decision="approve")

            if violations := self._check_files_for_violations(changed_files):
                violation_text = "\n".join(f"  {v}" for v in violations)
                message = (
                    f"🚫 Final Validation Failed\n\n"
                    f"Violations:\n{violation_text}\n\n"
                    f"Please remove forbidden patterns before completing."
                )

                return self._create_hook_response(
                    "Stop",
                    "",
                    message,
                    message,
                    decision="block",
                )

            return self._create_hook_response("Stop", decision="approve")

        except (OSError, subprocess.SubprocessError, TimeoutError) as exc:
            # If validation fails, allow but warn
            return self._create_hook_response(
                "Stop",
                "",
                f"Warning: Final validation error: {exc}",
                f"Warning: Final validation error: {exc}",
                decision="approve",
            )
@@ -51,6 +51,14 @@ PYTHON_FILE_PATTERNS = [
     r"\.pyi\b",
 ]
 
+# Regex patterns for temporary directory paths (for matching in commands, not creating)
+TEMPORARY_DIR_PATTERNS = [
+    r"tmp/",  # Match tmp directories
+    r"var/tmp/",  # Match /var/tmp directories
+    r"\.tmp/",  # Match .tmp directories
+    r"tempfile",  # Match tempfile references
+]
+
 # Pattern descriptions for error messages
 FORBIDDEN_PATTERN_DESCRIPTIONS = {
     "Any": "typing.Any usage",
58 hooks/guards/quality_guard.py Normal file
@@ -0,0 +1,58 @@
"""Code quality guard for Claude Code PreToolUse/PostToolUse hooks.

Enforces quality standards by preventing duplicate, complex, or non-modernized code.
Note: Currently provides minimal pass-through validation.
Full code quality analysis can be added by integrating the claude-quality toolkit.
"""

import sys
from pathlib import Path

# Setup path for imports - try both relative and absolute
sys.path.insert(0, str(Path(__file__).parent.parent))

from models import HookResponse


# Fallback: define minimal versions
def _pretooluse_hook(_hook_data: dict[str, object]) -> HookResponse:
    """Minimal pretooluse handler."""
    return {
        "hookSpecificOutput": {
            "hookEventName": "PreToolUse",
            "permissionDecision": "allow",
        },
    }


def _posttooluse_hook(_hook_data: dict[str, object]) -> HookResponse:
    """Minimal posttooluse handler."""
    return {
        "hookSpecificOutput": {"hookEventName": "PostToolUse"},
    }


class CodeQualityGuard:
    """Validates code quality through duplicate, complexity, modernization checks."""

    def pretooluse(self, hook_data: dict[str, object]) -> HookResponse:
        """Handle PreToolUse hook for quality analysis.

        Args:
            hook_data: Hook input data containing tool_name and tool_input.

        Returns:
            Hook response with permission decision.
        """
        return _pretooluse_hook(hook_data)

    def posttooluse(self, hook_data: dict[str, object]) -> HookResponse:
        """Handle PostToolUse hook for quality verification.

        Args:
            hook_data: Hook output data.

        Returns:
            Hook response with decision.
        """
        return _posttooluse_hook(hook_data)
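The quality guard is deliberately a pass-through stub for now, so the only contract worth checking is the response shape it emits. A small sketch of that shape (helper name is illustrative): PreToolUse responses carry a `permissionDecision`, while PostToolUse pass-throughs carry only the event name.

```python
import json


def allow_response(event: str) -> dict[str, object]:
    """Build the minimal pass-through response the stub guard returns."""
    output: dict[str, object] = {"hookEventName": event}
    if event == "PreToolUse":
        # Only PreToolUse responses carry a permission decision.
        output["permissionDecision"] = "allow"
    return {"hookSpecificOutput": output}


print(json.dumps(allow_response("PreToolUse")))
print(json.dumps(allow_response("PostToolUse")))
```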
@@ -1,159 +0,0 @@
|
||||
"""Unified hook runner that chains existing guards without parallel execution.
|
||||
|
||||
This entry point lets Claude Code invoke a single command per hook event while
|
||||
still reusing the more specialized guards. It reacts to the hook payload to
|
||||
decide which guard(s) to run and propagates their output/exit codes so Claude
|
||||
continues to see the same responses. Post-tool Bash logging is handled here so
|
||||
the old jq pipeline is no longer required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
HookPayload = dict[str, Any]
|
||||
|
||||
|
||||
def _load_payload() -> tuple[str, HookPayload | None]:
|
||||
"""Read stdin payload once and return both raw text and parsed JSON."""
|
||||
|
||||
raw = sys.stdin.read()
|
||||
if not raw:
|
||||
return raw, None
|
||||
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
return raw, None
|
||||
|
||||
return (raw, parsed) if isinstance(parsed, dict) else (raw, None)
|
||||
|
||||
|
def _default_response(event: str) -> int:
    """Emit the minimal pass-through JSON for hooks that we skip."""
    hook_event = "PreToolUse" if event == "pre" else "PostToolUse"
    response: HookPayload = {"hookSpecificOutput": {"hookEventName": hook_event}}

    if event == "pre":
        response["hookSpecificOutput"]["permissionDecision"] = "allow"

    sys.stdout.write(json.dumps(response))
    sys.stdout.write("\n")
    sys.stdout.flush()
    return 0


def _run_guard(script_name: str, payload: str) -> int:
    """Execute a sibling guard script sequentially and relay its output."""
    script_path = Path(__file__).with_name(script_name)
    if not script_path.exists():
        raise FileNotFoundError(f"Missing guard script: {script_path}")

    proc = subprocess.run(  # noqa: S603
        [sys.executable, str(script_path)],
        input=payload,
        text=True,
        capture_output=True,
        check=False,
    )

    if proc.stdout:
        sys.stdout.write(proc.stdout)
    if proc.stderr:
        sys.stderr.write(proc.stderr)

    sys.stdout.flush()
    sys.stderr.flush()
    return proc.returncode


def _log_bash_command(payload: HookPayload) -> None:
    """Append successful Bash commands to Claude's standard log file."""
    tool_input = payload.get("tool_input")
    if not isinstance(tool_input, dict):
        return

    command = tool_input.get("command")
    if not isinstance(command, str) or not command.strip():
        return

    description = tool_input.get("description")
    if not isinstance(description, str) or not description.strip():
        description = "No description"

    log_path = Path.home() / ".claude" / "bash-command-log.txt"
    try:
        log_path.parent.mkdir(parents=True, exist_ok=True)
        with log_path.open("a", encoding="utf-8") as handle:
            handle.write(f"{command} - {description}\n")
    except OSError:
        # Logging is best-effort; ignore filesystem errors.
        pass


def _pre_hook(payload: HookPayload | None, raw: str) -> int:
    """Handle PreToolUse events with sequential guard execution."""
    if payload is None:
        return _default_response("pre")

    tool_name = str(payload.get("tool_name", ""))

    if tool_name in {"Write", "Edit", "MultiEdit"}:
        return _run_guard("code_quality_guard.py", raw)
    if tool_name == "Bash":
        return _run_guard("bash_command_guard.py", raw)

    return _default_response("pre")


def _post_hook(payload: HookPayload | None, raw: str) -> int:
    """Handle PostToolUse events with sequential guard execution."""
    if payload is None:
        return _default_response("post")

    tool_name = str(payload.get("tool_name", ""))

    if tool_name in {"Write", "Edit", "MultiEdit"}:
        return _run_guard("code_quality_guard.py", raw)
    if tool_name == "Bash":
        exit_code = _run_guard("bash_command_guard.py", raw)
        if exit_code == 0:
            _log_bash_command(payload)
        return exit_code

    return _default_response("post")


def main() -> None:
    parser = argparse.ArgumentParser(description="Chain Claude Code hook guards")
    parser.add_argument(
        "--event",
        choices={"pre", "post"},
        required=True,
        help="Hook event type to handle.",
    )
    args = parser.parse_args()

    raw_payload, parsed_payload = _load_payload()

    if args.event == "pre":
        exit_code = _pre_hook(parsed_payload, raw_payload)
    else:
        exit_code = _post_hook(parsed_payload, raw_payload)

    sys.exit(exit_code)


if __name__ == "__main__":  # pragma: no cover - CLI entry
    main()
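The dispatcher above routes a JSON payload read from stdin by `tool_name` and otherwise falls through to a minimal allow response. A standalone sketch of that pass-through branch (mirroring `_default_response` above; the payload shape is taken from the code, not from any external spec):

```python
import json


def default_response(event: str) -> dict[str, dict[str, str]]:
    """Mirror of _default_response above: minimal pass-through hook JSON."""
    hook_event = "PreToolUse" if event == "pre" else "PostToolUse"
    response = {"hookSpecificOutput": {"hookEventName": hook_event}}
    if event == "pre":
        # Only the PreToolUse event carries an explicit permission decision.
        response["hookSpecificOutput"]["permissionDecision"] = "allow"
    return response


print(json.dumps(default_response("pre")))
```

Piping this JSON to stdout is what lets the hook stay invisible for tools it does not guard.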
hooks/lock_manager.py (new file, 62 lines)
@@ -0,0 +1,62 @@
"""Centralized file-based locking for inter-process synchronization."""

import fcntl
import time
from collections.abc import Generator
from contextlib import contextmanager, suppress
from pathlib import Path
from tempfile import gettempdir

# Lock configuration constants
LOCK_TIMEOUT_SECONDS: float = 10.0
LOCK_POLL_INTERVAL_SECONDS: float = 0.1


class LockManager:
    """Manages file-based locks for subprocess serialization."""

    @staticmethod
    def _get_lock_file() -> Path:
        """Get path to lock file for subprocess synchronization."""
        lock_dir = Path(gettempdir()) / ".claude_hooks"
        lock_dir.mkdir(exist_ok=True, mode=0o700)
        return lock_dir / "subprocess.lock"

    @staticmethod
    @contextmanager
    def acquire(
        timeout: float = LOCK_TIMEOUT_SECONDS,
    ) -> Generator[bool, None, None]:
        """Acquire file-based lock with timeout.

        Args:
            timeout: Maximum time in seconds to wait for lock acquisition.
                Non-positive values attempt single non-blocking acquisition.

        Yields:
            True if lock was acquired, False if timeout occurred.
        """
        lock_file = LockManager._get_lock_file()
        deadline = time.monotonic() + timeout if timeout and timeout > 0 else None
        acquired = False

        with open(lock_file, "a") as f:
            try:
                while True:
                    try:
                        fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                        acquired = True
                        break
                    except OSError:
                        if deadline is None:
                            break
                        remaining = deadline - time.monotonic()
                        if remaining <= 0:
                            break
                        time.sleep(min(LOCK_POLL_INTERVAL_SECONDS, remaining))

                yield acquired
            finally:
                if acquired:
                    with suppress(OSError):
                        fcntl.flock(f.fileno(), fcntl.LOCK_UN)
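Callers use `LockManager.acquire` as a context manager and must check the yielded flag, since a timeout yields `False` rather than raising. Because the module lives beside the hook scripts rather than on `sys.path`, this standalone snippet reproduces the same non-blocking `flock` poll loop (POSIX-only; the lock-file name here is illustrative, not the one the module uses):

```python
import fcntl
import time
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from tempfile import gettempdir


@contextmanager
def file_lock(
    name: str, timeout: float = 10.0, poll: float = 0.1
) -> Generator[bool, None, None]:
    """Same pattern as LockManager.acquire: non-blocking flock in a poll loop."""
    lock_path = Path(gettempdir()) / name
    deadline = time.monotonic() + timeout if timeout > 0 else None
    acquired = False
    with open(lock_path, "a") as f:
        try:
            while True:
                try:
                    fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                    acquired = True
                    break
                except OSError:
                    # Lock held elsewhere; retry until the deadline passes.
                    if deadline is None or deadline - time.monotonic() <= 0:
                        break
                    time.sleep(poll)
            yield acquired
        finally:
            if acquired:
                fcntl.flock(f.fileno(), fcntl.LOCK_UN)


with file_lock("claude_hooks_demo.lock") as got:
    print("acquired:", got)
```

Polling with `LOCK_NB` instead of a blocking `LOCK_EX` is what makes the timeout possible, since `flock` itself has no timeout parameter.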
(File diff suppressed because it is too large)

hooks/models.py (new file, 61 lines)
@@ -0,0 +1,61 @@
"""Shared type definitions and data models for hooks subsystem."""

from dataclasses import dataclass
from typing import TypedDict


class HookPayloadDict(TypedDict, total=False):
    """Normalized hook payload from JSON deserialization."""

    tool_name: str
    tool_input: dict[str, object]
    tool_response: object
    tool_output: object


class HookResponse(TypedDict, total=False):
    """Standard hook response structure for Claude Code."""

    hookSpecificOutput: dict[str, object]
    permissionDecision: str
    decision: str
    reason: str
    systemMessage: str


class ToolInput(TypedDict, total=False):
    """Tool input data within hook payload."""

    file_path: str
    content: str
    command: str
    description: str


class HookPayload(TypedDict, total=False):
    """Standard hook payload structure from Claude Code."""

    tool_name: str
    tool_input: ToolInput
    tool_response: object
    tool_output: object
    hookEventName: str


@dataclass
class AnalysisResult:
    """Result from code analysis operations."""

    status: str  # 'pass', 'warn', 'block'
    violations: list[str]
    message: str
    code_context: dict[str, object] | None = None


@dataclass
class GuardDecision:
    """Decision made by a guard."""

    permission: str  # 'allow', 'deny', 'ask'
    reason: str
    system_message: str = ""
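Because these are `total=False` TypedDicts, they exist only for the type checker: at runtime a payload is just a plain dict, and every field access still needs a defensive `get`/`isinstance` check, as the guard code above does. A small illustrative payload (the field names come from `HookPayload`/`ToolInput` above; the values are invented):

```python
# A dict shaped like HookPayload with a ToolInput; values are made up for illustration.
payload: dict[str, object] = {
    "tool_name": "Bash",
    "tool_input": {"command": "ls -la", "description": "List files"},
    "hookEventName": "PreToolUse",
}

# Defensive access mirrors _log_bash_command: every key is optional.
tool_input = payload.get("tool_input")
command = tool_input.get("command") if isinstance(tool_input, dict) else None
print(command)
```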
logs/status_line.json (19593 lines; diff suppressed because it is too large)
@@ -152,7 +152,7 @@ exclude_lines = [
 include = ["src", "hooks", "tests"]
 extraPaths = ["hooks"]
 pythonVersion = "3.12"
-typeCheckingMode = "standard"
+typeCheckingMode = "strict"
 reportMissingTypeStubs = false

 [dependency-groups]
@@ -2,8 +2,8 @@
   "venvPath": ".",
   "venv": ".venv",
   "pythonVersion": "3.12",
-  "typeCheckingMode": "basic",
+  "typeCheckingMode": "strict",
   "reportMissingImports": true,
-  "reportMissingTypeStubs": false,
+  "reportMissingTypeStubs": true,
   "reportMissingModuleSource": "warning"
 }
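The main practical effect of bumping pyright from `basic` to `strict` is that implicit `Any` becomes an error: unannotated signatures that basic mode silently accepts are now reported. A hypothetical before/after (function names invented for illustration):

```python
def parse(data):  # strict mode flags this: parameter and return types are unknown
    return data["key"]


def parse_typed(data: dict[str, str]) -> str:  # fully annotated; passes strict checks
    return data["key"]


print(parse_typed({"key": "value"}))
```

This is why the commit also touches the hook modules: strict mode applies to everything in `include`, so `hooks/` must be fully annotated too.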