#!/usr/bin/env python3
"""Unified quality hook for Claude Code supporting both PreToolUse and PostToolUse.

Prevents writing duplicate, complex, or non-modernized code and verifies quality
after writes.
"""

import hashlib
import json
import logging
import os
import re
import subprocess
import sys
from contextlib import suppress
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any

# Import internal duplicate detector
sys.path.insert(0, str(Path(__file__).parent))
from internal_duplicate_detector import detect_internal_duplicates
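
# Illustrative wiring sketch (not part of this module): a hook like this is
# typically registered for both events in the project's Claude Code settings,
# e.g. .claude/settings.json. The exact schema may vary by Claude Code version,
# and the script path below is an example, not a path this file defines:
#
#   {
#     "hooks": {
#       "PreToolUse": [
#         {"matcher": "Write|Edit|MultiEdit",
#          "hooks": [{"type": "command", "command": "python3 hooks/quality_hook.py"}]}
#       ],
#       "PostToolUse": [
#         {"matcher": "Write|Edit|MultiEdit",
#          "hooks": [{"type": "command", "command": "python3 hooks/quality_hook.py"}]}
#       ]
#     }
#   }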


@dataclass
class QualityConfig:
    """Configuration for quality checks."""

    # Core settings
    duplicate_threshold: float = 0.7
    duplicate_enabled: bool = True
    complexity_threshold: int = 10
    complexity_enabled: bool = True
    modernization_enabled: bool = True
    require_type_hints: bool = True
    enforcement_mode: str = "strict"  # strict/warn/permissive

    # PostToolUse features
    state_tracking_enabled: bool = False
    cross_file_check_enabled: bool = False
    verify_naming: bool = True
    show_success: bool = False

    # File patterns (None means "use the defaults assigned in __post_init__")
    skip_patterns: list[str] | None = None

    def __post_init__(self) -> None:
        if self.skip_patterns is None:
            self.skip_patterns = ["test_", "_test.py", "/tests/", "/fixtures/"]

    @classmethod
    def from_env(cls) -> "QualityConfig":
        """Load config from environment variables."""
        return cls(
            duplicate_threshold=float(os.getenv("QUALITY_DUP_THRESHOLD", "0.7")),
            duplicate_enabled=os.getenv("QUALITY_DUP_ENABLED", "true").lower()
            == "true",
            complexity_threshold=int(os.getenv("QUALITY_COMPLEXITY_THRESHOLD", "10")),
            complexity_enabled=os.getenv("QUALITY_COMPLEXITY_ENABLED", "true").lower()
            == "true",
            modernization_enabled=os.getenv("QUALITY_MODERN_ENABLED", "true").lower()
            == "true",
            require_type_hints=os.getenv("QUALITY_REQUIRE_TYPES", "true").lower()
            == "true",
            enforcement_mode=os.getenv("QUALITY_ENFORCEMENT", "strict"),
            state_tracking_enabled=os.getenv("QUALITY_STATE_TRACKING", "false").lower()
            == "true",
            cross_file_check_enabled=os.getenv(
                "QUALITY_CROSS_FILE_CHECK",
                "false",
            ).lower()
            == "true",
            verify_naming=os.getenv("QUALITY_VERIFY_NAMING", "true").lower() == "true",
            show_success=os.getenv("QUALITY_SHOW_SUCCESS", "false").lower() == "true",
        )
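
# Example tuning via the environment variables read above (values are only
# illustrative; the defaults are the ones shown in the dataclass):
#
#   export QUALITY_DUP_THRESHOLD=0.8         # stricter duplicate-similarity cutoff
#   export QUALITY_COMPLEXITY_THRESHOLD=12   # tolerate slightly higher cyclomatic complexity
#   export QUALITY_ENFORCEMENT=warn          # strict -> deny, warn -> ask, permissive -> allow
#   export QUALITY_STATE_TRACKING=true       # enable pre/post state comparison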


def should_skip_file(file_path: str, config: QualityConfig) -> bool:
    """Check if file should be skipped based on patterns."""
    return any(pattern in file_path for pattern in config.skip_patterns)


def get_claude_quality_path() -> str:
    """Get claude-quality binary path."""
    claude_quality = Path(__file__).parent.parent / ".venv/bin/claude-quality"
    return str(claude_quality) if claude_quality.exists() else "claude-quality"


def analyze_code_quality(
    content: str,
    _file_path: str,
    config: QualityConfig,
) -> dict[str, Any]:
    """Analyze code content using claude-quality toolkit."""
    with NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp:
        tmp.write(content)
        tmp_path = tmp.name

    try:
        results = {}
        claude_quality = get_claude_quality_path()

        # First check for internal duplicates within the file
        if config.duplicate_enabled:
            internal_duplicates = detect_internal_duplicates(
                content,
                threshold=config.duplicate_threshold,
                min_lines=4,
            )
            if internal_duplicates.get("duplicates"):
                results["internal_duplicates"] = internal_duplicates

        # Run complexity analysis
        if config.complexity_enabled:
            cmd = [
                claude_quality,
                "complexity",
                tmp_path,
                "--threshold",
                str(config.complexity_threshold),
                "--format",
                "json",
            ]
            try:
                result = subprocess.run(  # noqa: S603
                    cmd,
                    check=False,
                    capture_output=True,
                    text=True,
                    timeout=30,
                )
                if result.returncode == 0:
                    with suppress(json.JSONDecodeError):
                        results["complexity"] = json.loads(result.stdout)
            except subprocess.TimeoutExpired:
                pass  # Command timed out

        # Run modernization analysis
        if config.modernization_enabled:
            cmd = [
                claude_quality,
                "modernization",
                tmp_path,
                "--include-type-hints" if config.require_type_hints else "",
                "--format",
                "json",
            ]
            cmd = [c for c in cmd if c]  # Remove empty strings
            try:
                result = subprocess.run(  # noqa: S603
                    cmd,
                    check=False,
                    capture_output=True,
                    text=True,
                    timeout=30,
                )
                if result.returncode == 0:
                    with suppress(json.JSONDecodeError):
                        results["modernization"] = json.loads(result.stdout)
            except subprocess.TimeoutExpired:
                pass  # Command timed out

        return results

    finally:
        Path(tmp_path).unlink(missing_ok=True)
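
# Rough shape of the dict returned above, inferred from how check_code_issues()
# consumes it below. Field values are illustrative, not real claude-quality output:
#
#   {
#     "internal_duplicates": {"duplicates": [{"similarity": 0.85,
#                                             "description": "near-identical helpers",
#                                             "locations": [{"name": "foo", "lines": "10-24"}]}]},
#     "complexity": {"summary": {"average_cyclomatic_complexity": 6.2},
#                    "distribution": {"Low": 4, "High": 1}},
#     "modernization": {"files": {"<tmp file>.py": [{"issue_type": "missing_type_hints"}]}},
#   }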


def check_code_issues(
    results: dict[str, Any],
    config: QualityConfig,
) -> tuple[bool, list[str]]:
    """Check analysis results for issues that should block the operation."""
    issues = []

    # Check for internal duplicates (within the same file)
    if "internal_duplicates" in results:
        duplicates = results["internal_duplicates"].get("duplicates", [])
        for dup in duplicates[:3]:  # Show first 3
            locations = ", ".join(
                f"{loc['name']} ({loc['lines']})" for loc in dup.get("locations", [])
            )
            issues.append(
                f"Internal duplication ({dup.get('similarity', 0):.0%} similar): "
                f"{dup.get('description')} - {locations}",
            )

    # Check for complexity issues
    if "complexity" in results:
        summary = results["complexity"].get("summary", {})
        avg_cc = summary.get("average_cyclomatic_complexity", 0)
        if avg_cc > config.complexity_threshold:
            issues.append(
                f"High average complexity: CC={avg_cc:.1f} "
                f"(threshold: {config.complexity_threshold})",
            )

        distribution = results["complexity"].get("distribution", {})
        high_count = (
            distribution.get("High", 0)
            + distribution.get("Very High", 0)
            + distribution.get("Extreme", 0)
        )
        if high_count > 0:
            issues.append(f"Found {high_count} function(s) with high complexity")

    # Check for modernization issues
    if "modernization" in results:
        files = results["modernization"].get("files", {})
        total_issues = 0
        issue_types = set()

        for _file_path, file_issues in files.items():
            if isinstance(file_issues, list):
                total_issues += len(file_issues)
                for issue in file_issues:
                    if isinstance(issue, dict):
                        issue_types.add(issue.get("issue_type", "unknown"))

        # Flag any non-type-hint issue kinds, or a large number of type-hint issues
        non_type_kinds = sorted(
            t for t in issue_types if "type" not in t and "typing" not in t
        )
        if non_type_kinds:
            issues.append(
                f"Modernization needed: {len(non_type_kinds)} non-type issue kind(s) "
                f"({', '.join(non_type_kinds)})",
            )
        elif config.require_type_hints and total_issues > 10:
            # No non-type kinds are present, so every counted issue is a type-hint issue.
            issues.append(
                f"Many missing type hints: {total_issues} functions/parameters "
                "lacking annotations",
            )

    return len(issues) > 0, issues


def store_pre_state(file_path: str, content: str) -> None:
    """Store file state before modification for later comparison."""
    import tempfile

    cache_dir = Path(tempfile.gettempdir()) / ".quality_state"
    cache_dir.mkdir(exist_ok=True, mode=0o700)

    state = {
        "file_path": file_path,
        "timestamp": datetime.now(UTC).isoformat(),
        "content_hash": hashlib.sha256(content.encode()).hexdigest(),
        "lines": len(content.split("\n")),
        "functions": content.count("def "),
        "classes": content.count("class "),
    }

    cache_key = hashlib.sha256(file_path.encode()).hexdigest()[:8]
    cache_file = cache_dir / f"{cache_key}_pre.json"
    cache_file.write_text(json.dumps(state, indent=2))
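
# The snapshot above lands in <tempdir>/.quality_state/<8-char sha256 prefix>_pre.json
# and looks roughly like this (illustrative values):
#
#   {"file_path": "/repo/app/service.py", "timestamp": "2025-01-01T00:00:00+00:00",
#    "content_hash": "…", "lines": 120, "functions": 6, "classes": 1}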


def check_state_changes(file_path: str) -> list[str]:
    """Check for quality changes between pre and post states."""
    import tempfile

    issues = []
    cache_dir = Path(tempfile.gettempdir()) / ".quality_state"
    cache_key = hashlib.sha256(file_path.encode()).hexdigest()[:8]

    pre_file = cache_dir / f"{cache_key}_pre.json"
    if not pre_file.exists():
        return issues

    try:
        pre_state = json.loads(pre_file.read_text())
        try:
            current_content = Path(file_path).read_text()
        except (PermissionError, FileNotFoundError, OSError):
            return issues  # Can't compare if can't read file

        current_lines = len(current_content.split("\n"))
        current_functions = current_content.count("def ")
        _current_classes = current_content.count("class ")  # Future use

        # Check for significant changes
        if current_functions < pre_state.get("functions", 0):
            issues.append(
                f"⚠️ Reduced functions: {pre_state['functions']} → {current_functions}",
            )

        if current_lines > pre_state.get("lines", 0) * 1.5:
            issues.append(
                "⚠️ File size increased significantly: "
                f"{pre_state['lines']} → {current_lines} lines",
            )
    except Exception:  # noqa: BLE001
        logging.debug("Could not analyze state changes for %s", file_path)

    return issues


def check_cross_file_duplicates(file_path: str, config: QualityConfig) -> list[str]:
    """Check for duplicates across project files."""
    issues = []

    # Get project root
    project_root = Path(file_path).parent
    while project_root.parent != project_root:
        if (project_root / ".git").exists() or (
            project_root / "pyproject.toml"
        ).exists():
            break
        project_root = project_root.parent

    claude_quality = get_claude_quality_path()

    try:
        result = subprocess.run(  # noqa: S603
            [
                claude_quality,
                "duplicates",
                str(project_root),
                "--threshold",
                str(config.duplicate_threshold),
                "--format",
                "json",
            ],
            check=False,
            capture_output=True,
            text=True,
            timeout=60,
        )
        if result.returncode == 0:
            data = json.loads(result.stdout)
            duplicates = data.get("duplicates", [])
            if any(str(file_path) in str(d) for d in duplicates):
                issues.append("⚠️ Cross-file duplication detected")
    except Exception:  # noqa: BLE001
        logging.debug("Could not check cross-file duplicates for %s", file_path)

    return issues


def verify_naming_conventions(file_path: str) -> list[str]:
    """Verify PEP8 naming conventions."""
    issues = []
    try:
        content = Path(file_path).read_text()
    except (PermissionError, FileNotFoundError, OSError):
        return issues  # Can't check naming if can't read file

    # Check function names (should be snake_case)
    if bad_funcs := re.findall(
        r"def\s+([A-Z][a-zA-Z0-9_]*|[a-z]+[A-Z][a-zA-Z0-9_]*)\s*\(",
        content,
    ):
        issues.append(f"⚠️ Non-PEP8 function names: {', '.join(bad_funcs[:3])}")

    # Check class names (should be PascalCase)
    if bad_classes := re.findall(r"class\s+([a-z][a-z0-9_]*)\s*[\(:]", content):
        issues.append(f"⚠️ Non-PEP8 class names: {', '.join(bad_classes[:3])}")

    return issues
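
# Illustrative matches for the heuristics above (not exhaustive): the function
# pattern flags names such as `def ProcessData(...)` or `def getData(...)`,
# while the class pattern flags all-lowercase names such as `class my_handler:`.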


def pretooluse_hook(hook_data: dict, config: QualityConfig) -> dict:
    """Handle PreToolUse hook - analyze content before write/edit."""
    tool_name = hook_data.get("tool_name", "")
    tool_input = hook_data.get("tool_input", {})

    # Only analyze for write/edit tools
    if tool_name not in ["Write", "Edit", "MultiEdit"]:
        return {"decision": "allow"}

    # Extract content based on tool type
    content = None
    file_path = tool_input.get("file_path", "")

    if tool_name == "Write":
        content = tool_input.get("content", "")
    elif tool_name == "Edit":
        content = tool_input.get("new_string", "")
    elif tool_name == "MultiEdit":
        edits = tool_input.get("edits", [])
        content = "\n".join(edit.get("new_string", "") for edit in edits)

    # Only analyze Python files
    if not file_path or not file_path.endswith(".py") or not content:
        return {"decision": "allow"}

    # Skip analysis for configured patterns
    if should_skip_file(file_path, config):
        return {"decision": "allow"}

    try:
        # Store state if tracking enabled
        if config.state_tracking_enabled:
            store_pre_state(file_path, content)

        # Run quality analysis
        results = analyze_code_quality(content, file_path, config)
        has_issues, issues = check_code_issues(results, config)

        if has_issues:
            # Prepare denial message
            message = (
                f"Code quality check failed for {Path(file_path).name}:\n"
                + "\n".join(f" • {issue}" for issue in issues)
                + "\n\nFix these issues before writing the code."
            )

            # Make decision based on enforcement mode
            if config.enforcement_mode == "strict":
                return {"decision": "deny", "message": message}
            if config.enforcement_mode == "warn":
                return {"decision": "ask", "message": message}
            # permissive
            return {
                "decision": "allow",
                "message": f"⚠️ Quality Warning:\n{message}",
            }
        return {"decision": "allow"}  # noqa: TRY300

    except Exception as e:  # noqa: BLE001
        return {
            "decision": "allow",
            "message": f"Warning: Code quality check failed with error: {e}",
        }
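
# Sketch of a PreToolUse exchange as this hook sees it (real payloads carry
# additional fields such as session metadata; only the keys read above matter here,
# and the file path and content are examples):
#
#   stdin:  {"tool_name": "Write",
#            "tool_input": {"file_path": "app/service.py", "content": "def f(): ..."}}
#   stdout: {"decision": "deny", "message": "Code quality check failed for service.py: ..."}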


def posttooluse_hook(hook_data: dict, config: QualityConfig) -> dict:
    """Handle PostToolUse hook - verify quality after write/edit."""
    tool_name = hook_data.get("tool_name", "")
    tool_output = hook_data.get("tool_output", {})

    # Only process write/edit tools
    if tool_name not in ["Write", "Edit", "MultiEdit"]:
        return {"decision": "allow"}

    # Extract file path from output
    file_path = None
    if isinstance(tool_output, dict):
        file_path = tool_output.get("file_path", "") or tool_output.get("path", "")
    elif isinstance(tool_output, str):
        match = re.search(r"([/\w\-_.]+\.py)", tool_output)
        if match:
            file_path = match[1]

    if not file_path or not file_path.endswith(".py"):
        return {"decision": "allow"}

    if not Path(file_path).exists():
        return {"decision": "allow"}

    issues = []

    # Check state changes if tracking enabled
    if config.state_tracking_enabled:
        delta_issues = check_state_changes(file_path)
        issues.extend(delta_issues)

    # Run cross-file duplicate detection if enabled
    if config.cross_file_check_enabled:
        cross_file_issues = check_cross_file_duplicates(file_path, config)
        issues.extend(cross_file_issues)

    # Verify naming conventions if enabled
    if config.verify_naming:
        naming_issues = verify_naming_conventions(file_path)
        issues.extend(naming_issues)

    # Format response
    if issues:
        message = (
            f"📝 Post-write quality notes for {Path(file_path).name}:\n"
            + "\n".join(issues)
        )
        return {"decision": "allow", "message": message}
    if config.show_success:
        return {
            "decision": "allow",
            "message": f"✅ {Path(file_path).name} passed post-write verification",
        }

    return {"decision": "allow"}
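
# PostToolUse input differs from PreToolUse mainly by the presence of "tool_output";
# main() below relies on that key to route the event. Illustrative input:
#
#   {"tool_name": "Edit", "tool_output": {"file_path": "app/service.py"}}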


def main() -> None:
    """Main hook entry point."""
    try:
        # Load configuration
        config = QualityConfig.from_env()

        # Read hook input from stdin
        try:
            hook_data = json.load(sys.stdin)
        except json.JSONDecodeError:
            print(json.dumps({"decision": "allow"}))  # noqa: T201
            return

        # Detect hook type based on tool_output (PostToolUse) vs tool_input (PreToolUse)
        if "tool_output" in hook_data:
            # PostToolUse hook
            response = posttooluse_hook(hook_data, config)
        else:
            # PreToolUse hook
            response = pretooluse_hook(hook_data, config)

        print(json.dumps(response))  # noqa: T201

        # Handle exit codes according to Claude Code spec
        if response.get("decision") == "deny":
            # Exit code 2: Blocking error - stderr fed back to Claude
            if "message" in response:
                sys.stderr.write(response["message"])
            sys.exit(2)
        elif response.get("decision") == "ask":
            # Also use exit code 2 for ask decisions to ensure Claude sees the message
            if "message" in response:
                sys.stderr.write(response["message"])
            sys.exit(2)
        # Exit code 0: Success (default)

    except Exception as e:  # noqa: BLE001
        # Unexpected error - use exit code 1 (non-blocking error)
        sys.stderr.write(f"Hook error: {e}")
        sys.exit(1)
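
# Quick manual check (illustrative; the script name and payload are examples):
#
#   echo '{"tool_name": "Write", "tool_input": {"file_path": "x.py", "content": "def f():\n    pass\n"}}' \
#       | python3 quality_hook.py
#
# prints a JSON decision on stdout and exits 0, 1, or 2 as handled in main() above.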


if __name__ == "__main__":
    main()