#!/usr/bin/env python3
"""Unified quality hook for Claude Code supporting both PreToolUse and PostToolUse.
Prevents writing duplicate, complex, or non-modernized code and verifies quality
after writes.
"""
import hashlib
import json
import logging
import os
import re
import subprocess
import sys
import tempfile
from contextlib import suppress
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any

# Import the internal duplicate detector that ships alongside this hook
sys.path.insert(0, str(Path(__file__).parent))
from internal_duplicate_detector import detect_internal_duplicates  # noqa: E402


@dataclass
class QualityConfig:
    """Configuration for quality checks."""

    # Core settings
    duplicate_threshold: float = 0.7
    duplicate_enabled: bool = True
    complexity_threshold: int = 10
    complexity_enabled: bool = True
    modernization_enabled: bool = True
    require_type_hints: bool = True
    enforcement_mode: str = "strict"  # strict/warn/permissive
    # PostToolUse features
    state_tracking_enabled: bool = False
    cross_file_check_enabled: bool = False
    verify_naming: bool = True
    show_success: bool = False
    # File patterns (None here; the real default is filled in below)
    skip_patterns: list[str] | None = None

    def __post_init__(self) -> None:
        if self.skip_patterns is None:
            self.skip_patterns = ["test_", "_test.py", "/tests/", "/fixtures/"]

    @classmethod
    def from_env(cls) -> "QualityConfig":
        """Load config from environment variables."""
        return cls(
            duplicate_threshold=float(os.getenv("QUALITY_DUP_THRESHOLD", "0.7")),
            duplicate_enabled=os.getenv("QUALITY_DUP_ENABLED", "true").lower()
            == "true",
            complexity_threshold=int(os.getenv("QUALITY_COMPLEXITY_THRESHOLD", "10")),
            complexity_enabled=os.getenv("QUALITY_COMPLEXITY_ENABLED", "true").lower()
            == "true",
            modernization_enabled=os.getenv("QUALITY_MODERN_ENABLED", "true").lower()
            == "true",
            require_type_hints=os.getenv("QUALITY_REQUIRE_TYPES", "true").lower()
            == "true",
            enforcement_mode=os.getenv("QUALITY_ENFORCEMENT", "strict"),
            state_tracking_enabled=os.getenv("QUALITY_STATE_TRACKING", "false").lower()
            == "true",
            cross_file_check_enabled=os.getenv(
                "QUALITY_CROSS_FILE_CHECK",
                "false",
            ).lower()
            == "true",
            verify_naming=os.getenv("QUALITY_VERIFY_NAMING", "true").lower() == "true",
            show_success=os.getenv("QUALITY_SHOW_SUCCESS", "false").lower() == "true",
        )
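
# Example of tuning the hook from the environment (illustrative values; any
# unset variable falls back to the defaults above):
#
#   export QUALITY_DUP_THRESHOLD=0.8
#   export QUALITY_ENFORCEMENT=warn         # strict | warn | permissive
#   export QUALITY_STATE_TRACKING=true      # enables store_pre_state/check_state_changes
#   export QUALITY_CROSS_FILE_CHECK=true    # enables check_cross_file_duplicates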


def should_skip_file(file_path: str, config: QualityConfig) -> bool:
    """Check if file should be skipped based on patterns."""
    return any(pattern in file_path for pattern in config.skip_patterns)


def get_claude_quality_path() -> str:
    """Get claude-quality binary path."""
    claude_quality = Path(__file__).parent.parent / ".venv/bin/claude-quality"
    return str(claude_quality) if claude_quality.exists() else "claude-quality"


def analyze_code_quality(
    content: str,
    _file_path: str,
    config: QualityConfig,
) -> dict[str, Any]:
    """Analyze code content using claude-quality toolkit."""
    with NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp:
        tmp.write(content)
        tmp_path = tmp.name
    try:
        results = {}
        claude_quality = get_claude_quality_path()
        # First check for internal duplicates within the file
        if config.duplicate_enabled:
            internal_duplicates = detect_internal_duplicates(
                content,
                threshold=config.duplicate_threshold,
                min_lines=4,
            )
            if internal_duplicates.get("duplicates"):
                results["internal_duplicates"] = internal_duplicates
        # Run complexity analysis
        if config.complexity_enabled:
            cmd = [
                claude_quality,
                "complexity",
                tmp_path,
                "--threshold",
                str(config.complexity_threshold),
                "--format",
                "json",
            ]
            try:
                result = subprocess.run(  # noqa: S603
                    cmd,
                    check=False,
                    capture_output=True,
                    text=True,
                    timeout=30,
                )
                if result.returncode == 0:
                    with suppress(json.JSONDecodeError):
                        results["complexity"] = json.loads(result.stdout)
            except subprocess.TimeoutExpired:
                pass  # Command timed out
        # Run modernization analysis
        if config.modernization_enabled:
            cmd = [
                claude_quality,
                "modernization",
                tmp_path,
                "--include-type-hints" if config.require_type_hints else "",
                "--format",
                "json",
            ]
            cmd = [c for c in cmd if c]  # Remove empty strings
            try:
                result = subprocess.run(  # noqa: S603
                    cmd,
                    check=False,
                    capture_output=True,
                    text=True,
                    timeout=30,
                )
                if result.returncode == 0:
                    with suppress(json.JSONDecodeError):
                        results["modernization"] = json.loads(result.stdout)
            except subprocess.TimeoutExpired:
                pass  # Command timed out
        return results
    finally:
        Path(tmp_path).unlink(missing_ok=True)
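
# Sketch of the results dict returned above, assuming claude-quality emits the
# JSON shape that check_code_issues() below consumes (only the keys actually
# read are shown; values are illustrative):
#
#   {
#     "internal_duplicates": {"duplicates": [{"similarity": 0.85,
#       "description": "...", "locations": [{"name": "foo", "lines": "10-24"}]}]},
#     "complexity": {"summary": {"average_cyclomatic_complexity": 4.2},
#       "distribution": {"High": 1, "Very High": 0, "Extreme": 0}},
#     "modernization": {"files": {"tmp.py": [{"issue_type": "missing_type_hint"}]}}
#   }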


def check_code_issues(
    results: dict[str, Any],
    config: QualityConfig,
) -> tuple[bool, list[str]]:
    """Check analysis results for issues that should block the operation."""
    issues = []
    # Check for internal duplicates (within the same file)
    if "internal_duplicates" in results:
        duplicates = results["internal_duplicates"].get("duplicates", [])
        for dup in duplicates[:3]:  # Show first 3
            locations = ", ".join(
                f"{loc['name']} ({loc['lines']})" for loc in dup.get("locations", [])
            )
            issues.append(
                f"Internal duplication ({dup.get('similarity', 0):.0%} similar): "
                f"{dup.get('description')} - {locations}",
            )
    # Check for complexity issues
    if "complexity" in results:
        summary = results["complexity"].get("summary", {})
        avg_cc = summary.get("average_cyclomatic_complexity", 0)
        if avg_cc > config.complexity_threshold:
            issues.append(
                f"High average complexity: CC={avg_cc:.1f} "
                f"(threshold: {config.complexity_threshold})",
            )
        distribution = results["complexity"].get("distribution", {})
        high_count = (
            distribution.get("High", 0)
            + distribution.get("Very High", 0)
            + distribution.get("Extreme", 0)
        )
        if high_count > 0:
            issues.append(f"Found {high_count} function(s) with high complexity")
    # Check for modernization issues
    if "modernization" in results:
        files = results["modernization"].get("files", {})
        total_issues = 0
        non_type_issues = 0
        non_type_types: set[str] = set()
        for _file_path, file_issues in files.items():
            if isinstance(file_issues, list):
                total_issues += len(file_issues)
                for issue in file_issues:
                    if isinstance(issue, dict):
                        issue_type = issue.get("issue_type", "unknown")
                        # Count individual issues, not distinct issue types
                        if "type" not in issue_type and "typing" not in issue_type:
                            non_type_issues += 1
                            non_type_types.add(issue_type)
        # Only flag if there are non-type-hint issues or many type hint issues
        type_issues = total_issues - non_type_issues
        if non_type_issues > 0:
            issues.append(
                f"Modernization needed: {non_type_issues} non-type issues "
                f"({', '.join(sorted(non_type_types))})",
            )
        elif config.require_type_hints and type_issues > 10:
            issues.append(
                f"Many missing type hints: {type_issues} functions/parameters "
                "lacking annotations",
            )
    return len(issues) > 0, issues


def store_pre_state(file_path: str, content: str) -> None:
    """Store file state before modification for later comparison."""
    cache_dir = Path(tempfile.gettempdir()) / ".quality_state"
    cache_dir.mkdir(exist_ok=True, mode=0o700)
    state = {
        "file_path": file_path,
        "timestamp": datetime.now(UTC).isoformat(),
        "content_hash": hashlib.sha256(content.encode()).hexdigest(),
        "lines": len(content.split("\n")),
        "functions": content.count("def "),
        "classes": content.count("class "),
    }
    cache_key = hashlib.sha256(file_path.encode()).hexdigest()[:8]
    cache_file = cache_dir / f"{cache_key}_pre.json"
    cache_file.write_text(json.dumps(state, indent=2))
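
# Each snapshot lands in <tempdir>/.quality_state/<8-char-hash>_pre.json.
# Sketch of one snapshot (illustrative values):
#
#   {
#     "file_path": "src/example.py",
#     "timestamp": "2025-09-16T21:33:47+00:00",
#     "content_hash": "3a7bd3e2...",
#     "lines": 120,
#     "functions": 6,
#     "classes": 2
#   }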


def check_state_changes(file_path: str) -> list[str]:
    """Check for quality changes between pre and post states."""
    issues = []
    cache_dir = Path(tempfile.gettempdir()) / ".quality_state"
    cache_key = hashlib.sha256(file_path.encode()).hexdigest()[:8]
    pre_file = cache_dir / f"{cache_key}_pre.json"
    if not pre_file.exists():
        return issues
    try:
        pre_state = json.loads(pre_file.read_text())
        try:
            current_content = Path(file_path).read_text()
        except (PermissionError, FileNotFoundError, OSError):
            return issues  # Can't compare if we can't read the file
        current_lines = len(current_content.split("\n"))
        current_functions = current_content.count("def ")
        _current_classes = current_content.count("class ")  # Future use
        # Check for significant changes
        if current_functions < pre_state.get("functions", 0):
            issues.append(
                f"⚠️ Reduced functions: {pre_state['functions']} → {current_functions}",
            )
        if current_lines > pre_state.get("lines", 0) * 1.5:
            issues.append(
                "⚠️ File size increased significantly: "
                f"{pre_state['lines']} → {current_lines} lines",
            )
    except Exception:  # noqa: BLE001
        logging.debug("Could not analyze state changes for %s", file_path)
    return issues


def check_cross_file_duplicates(file_path: str, config: QualityConfig) -> list[str]:
    """Check for duplicates across project files."""
    issues = []
    # Get project root
    project_root = Path(file_path).parent
    while project_root.parent != project_root:
        if (project_root / ".git").exists() or (
            project_root / "pyproject.toml"
        ).exists():
            break
        project_root = project_root.parent
    claude_quality = get_claude_quality_path()
    try:
        result = subprocess.run(  # noqa: S603
            [
                claude_quality,
                "duplicates",
                str(project_root),
                "--threshold",
                str(config.duplicate_threshold),
                "--format",
                "json",
            ],
            check=False,
            capture_output=True,
            text=True,
            timeout=60,
        )
        if result.returncode == 0:
            data = json.loads(result.stdout)
            duplicates = data.get("duplicates", [])
            if any(str(file_path) in str(d) for d in duplicates):
                issues.append("⚠️ Cross-file duplication detected")
    except Exception:  # noqa: BLE001
        logging.debug("Could not check cross-file duplicates for %s", file_path)
    return issues


def verify_naming_conventions(file_path: str) -> list[str]:
    """Verify PEP8 naming conventions."""
    issues = []
    try:
        content = Path(file_path).read_text()
    except (PermissionError, FileNotFoundError, OSError):
        return issues  # Can't check naming if we can't read the file
    # Check function names (should be snake_case)
    if bad_funcs := re.findall(
        r"def\s+([A-Z][a-zA-Z0-9_]*|[a-z]+[A-Z][a-zA-Z0-9_]*)\s*\(",
        content,
    ):
        issues.append(f"⚠️ Non-PEP8 function names: {', '.join(bad_funcs[:3])}")
    # Check class names (should be PascalCase)
    if bad_classes := re.findall(r"class\s+([a-z][a-z0-9_]*)\s*[\(:]", content):
        issues.append(f"⚠️ Non-PEP8 class names: {', '.join(bad_classes[:3])}")
    return issues
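
# A PreToolUse invocation receives a JSON payload on stdin. Sketch of the
# fields this hook actually reads (other fields may be present and are
# ignored; values are illustrative):
#
#   {
#     "tool_name": "Write",
#     "tool_input": {
#       "file_path": "src/example.py",
#       "content": "def example() -> None: ..."
#     }
#   }
#
# Edit payloads carry "new_string" instead of "content"; MultiEdit carries an
# "edits" list whose items each provide a "new_string".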


def pretooluse_hook(hook_data: dict, config: QualityConfig) -> dict:
    """Handle PreToolUse hook - analyze content before write/edit."""
    tool_name = hook_data.get("tool_name", "")
    tool_input = hook_data.get("tool_input", {})
    # Only analyze for write/edit tools
    if tool_name not in ["Write", "Edit", "MultiEdit"]:
        return {
            "hookSpecificOutput": {
                "hookEventName": "PreToolUse",
                "permissionDecision": "allow",
            },
        }
    # Extract content based on tool type
    content = None
    file_path = tool_input.get("file_path", "")
    if tool_name == "Write":
        content = tool_input.get("content", "")
    elif tool_name == "Edit":
        content = tool_input.get("new_string", "")
    elif tool_name == "MultiEdit":
        edits = tool_input.get("edits", [])
        content = "\n".join(edit.get("new_string", "") for edit in edits)
    # Only analyze Python files
    if not file_path or not file_path.endswith(".py") or not content:
        return {
            "hookSpecificOutput": {
                "hookEventName": "PreToolUse",
                "permissionDecision": "allow",
            },
        }
    # Skip analysis for configured patterns
    if should_skip_file(file_path, config):
        return {
            "hookSpecificOutput": {
                "hookEventName": "PreToolUse",
                "permissionDecision": "allow",
            },
        }
    try:
        # Store state if tracking is enabled
        if config.state_tracking_enabled:
            store_pre_state(file_path, content)
        # Run quality analysis
        results = analyze_code_quality(content, file_path, config)
        has_issues, issues = check_code_issues(results, config)
        if has_issues:
            # Prepare the denial message
            message = (
                f"Code quality check failed for {Path(file_path).name}:\n"
                + "\n".join(f"• {issue}" for issue in issues)
                + "\n\nFix these issues before writing the code."
            )
            # Make the decision based on enforcement mode
            if config.enforcement_mode == "strict":
                return {
                    "hookSpecificOutput": {
                        "hookEventName": "PreToolUse",
                        "permissionDecision": "deny",
                        "permissionDecisionReason": message,
                    },
                }
            if config.enforcement_mode == "warn":
                return {
                    "hookSpecificOutput": {
                        "hookEventName": "PreToolUse",
                        "permissionDecision": "ask",
                        "permissionDecisionReason": message,
                    },
                }
            # permissive: surface the warning but allow the write
            return {
                "systemMessage": f"⚠️ Quality Warning:\n{message}",
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "allow",
                },
            }
        return {
            "hookSpecificOutput": {
                "hookEventName": "PreToolUse",
                "permissionDecision": "allow",
            },
        }
    except Exception as e:  # noqa: BLE001
        # Fail open: analysis errors should not block the write
        return {
            "systemMessage": f"Warning: Code quality check failed with error: {e}",
            "hookSpecificOutput": {
                "hookEventName": "PreToolUse",
                "permissionDecision": "allow",
            },
        }
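
# A PostToolUse invocation is distinguished by a "tool_response" key in the
# payload (see main() below). Sketch of the fields this hook reads
# (illustrative; tool_response may also arrive as a plain string, in which
# case a .py path is regex-extracted from it):
#
#   {
#     "tool_name": "Write",
#     "tool_response": {"file_path": "src/example.py"}
#   }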


def posttooluse_hook(hook_data: dict, config: QualityConfig) -> dict:
    """Handle PostToolUse hook - verify quality after write/edit."""
    tool_name = hook_data.get("tool_name", "")
    tool_output = hook_data.get("tool_response", {})
    # Only process write/edit tools
    if tool_name not in ["Write", "Edit", "MultiEdit"]:
        return {"hookSpecificOutput": {"hookEventName": "PostToolUse"}}
    # Extract the file path from the tool response
    file_path = None
    if isinstance(tool_output, dict):
        file_path = tool_output.get("file_path", "") or tool_output.get("path", "")
    elif isinstance(tool_output, str):
        match = re.search(r"([/\w\-_.]+\.py)", tool_output)
        if match:
            file_path = match[1]
    if not file_path or not file_path.endswith(".py"):
        return {"hookSpecificOutput": {"hookEventName": "PostToolUse"}}
    if not Path(file_path).exists():
        return {"hookSpecificOutput": {"hookEventName": "PostToolUse"}}
    issues = []
    # Check state changes if tracking is enabled
    if config.state_tracking_enabled:
        issues.extend(check_state_changes(file_path))
    # Run cross-file duplicate detection if enabled
    if config.cross_file_check_enabled:
        issues.extend(check_cross_file_duplicates(file_path, config))
    # Verify naming conventions if enabled
    if config.verify_naming:
        issues.extend(verify_naming_conventions(file_path))
    # Format the response
    if issues:
        message = (
            f"📝 Post-write quality notes for {Path(file_path).name}:\n"
            + "\n".join(issues)
        )
        return {
            "systemMessage": message,
            "hookSpecificOutput": {
                "hookEventName": "PostToolUse",
                "additionalContext": message,
            },
        }
    if config.show_success:
        message = f"{Path(file_path).name} passed post-write verification"
        return {
            "systemMessage": message,
            "hookSpecificOutput": {"hookEventName": "PostToolUse"},
        }
    return {"hookSpecificOutput": {"hookEventName": "PostToolUse"}}


def main() -> None:
    """Main hook entry point."""
    try:
        # Load configuration
        config = QualityConfig.from_env()
        # Read hook input from stdin
        try:
            hook_data = json.load(sys.stdin)
        except json.JSONDecodeError:
            # Malformed input: fail open with an allow decision
            print(  # noqa: T201
                json.dumps(
                    {
                        "hookSpecificOutput": {
                            "hookEventName": "PreToolUse",
                            "permissionDecision": "allow",
                        },
                    },
                ),
            )
            return
        # Detect hook type: PostToolUse payloads carry tool_response,
        # PreToolUse payloads only carry tool_input
        if "tool_response" in hook_data:
            response = posttooluse_hook(hook_data, config)
        else:
            response = pretooluse_hook(hook_data, config)
        print(json.dumps(response))  # noqa: T201
        # Map the decision to exit codes
        hook_output = response.get("hookSpecificOutput", {})
        permission_decision = hook_output.get("permissionDecision")
        if permission_decision == "deny":
            # Exit code 2: blocking error - stderr is fed back to Claude
            reason = hook_output.get("permissionDecisionReason", "Permission denied")
            sys.stderr.write(reason)
            sys.exit(2)
        elif permission_decision == "ask":
            # Also exit 2 for ask decisions so Claude sees the message
            reason = hook_output.get("permissionDecisionReason", "Permission request")
            sys.stderr.write(reason)
            sys.exit(2)
        # Exit code 0: success (default)
    except Exception as e:  # noqa: BLE001
        # Unexpected error: exit code 1 (non-blocking error)
        sys.stderr.write(f"Hook error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
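
# Manual smoke test, assuming claude-quality and internal_duplicate_detector
# are available (illustrative payload; a clean file should print an "allow"
# decision and exit 0, while a deny/ask decision also writes the reason to
# stderr and exits 2):
#
#   echo '{"tool_name": "Write", "tool_input": {"file_path": "demo.py", "content": "import os"}}' \
#       | python code_quality_guard.py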