1046 lines
33 KiB
Python
1046 lines
33 KiB
Python
"""Unified quality hook for Claude Code supporting both PreToolUse and PostToolUse.
|
|
|
|
Prevents writing duplicate, complex, or non-modernized code and verifies quality
|
|
after writes.
|
|
"""
|
|
|
|
import ast
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
from collections.abc import Callable
|
|
from contextlib import suppress
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from pathlib import Path
|
|
from tempfile import NamedTemporaryFile
|
|
from typing import TypedDict, cast
|
|
|
|
# Import internal duplicate detector
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
from internal_duplicate_detector import detect_internal_duplicates
|
|
|
|
|
|
class ToolConfig(TypedDict):
    """Configuration for a type checking tool."""

    # CLI arguments appended after the tool binary path.
    args: list[str]
    # Predicate over the completed process: True means the tool found errors.
    error_check: Callable[[subprocess.CompletedProcess[str]], bool]
    # Static message, or a callable deriving the message from the process result.
    error_message: str | Callable[[subprocess.CompletedProcess[str]], str]
|
|
|
|
|
|
class DuplicateLocation(TypedDict):
    """Location information for a duplicate code block."""

    # Name of the function/class containing the duplicate.
    name: str
    # Line-range string as produced by the duplicate detector.
    lines: str
|
|
|
|
|
|
class Duplicate(TypedDict):
    """Duplicate code detection result."""

    # Fractional similarity score (rendered with :.0% in reports).
    similarity: float
    # Human-readable description of the duplicated construct.
    description: str
    # Every place the duplicated block appears.
    locations: list[DuplicateLocation]
|
|
|
|
|
|
class DuplicateResults(TypedDict):
    """Results from duplicate detection analysis."""

    # All duplicate groups found within the analyzed content.
    duplicates: list[Duplicate]
|
|
|
|
|
|
class ComplexitySummary(TypedDict):
    """Summary of complexity analysis."""

    # Mean cyclomatic complexity across the analyzed functions.
    average_cyclomatic_complexity: float
|
|
|
|
|
|
class ComplexityDistribution(TypedDict):
    """Distribution of complexity levels.

    NOTE(review): _check_complexity_issues reads the key "Very High" (with a
    space) from this structure, while a TypedDict field can only be declared
    as the identifier Very_High. Confirm the analyzer's actual JSON key name.
    """

    High: int
    Very_High: int
    Extreme: int
|
|
|
|
|
|
class ComplexityResults(TypedDict):
    """Results from complexity analysis."""

    # File-wide aggregate metrics.
    summary: ComplexitySummary
    # Function counts bucketed by complexity level.
    distribution: ComplexityDistribution
|
|
|
|
|
|
class TypeCheckingResults(TypedDict):
    """Results from type checking analysis."""

    # Formatted "<Tool>: <details>" strings produced by run_type_checks().
    issues: list[str]
|
|
|
|
|
|
class AnalysisResults(TypedDict, total=False):
    """Complete analysis results from quality checks.

    total=False: each key is present only when the corresponding check ran
    and produced findings.
    """

    internal_duplicates: DuplicateResults
    complexity: ComplexityResults
    type_checking: TypeCheckingResults
    modernization: dict[str, object]  # JSON structure varies
|
|
|
|
|
|
# Type aliases for JSON-like structures
JsonObject = dict[str, object]
# Union of all JSON value shapes (lists are loosely typed as list[object]).
JsonValue = str | int | float | bool | None | list[object] | JsonObject
|
|
|
|
|
|
@dataclass
|
|
class QualityConfig:
|
|
"""Configuration for quality checks."""
|
|
|
|
# Core settings
|
|
duplicate_threshold: float = 0.7
|
|
duplicate_enabled: bool = True
|
|
complexity_threshold: int = 10
|
|
complexity_enabled: bool = True
|
|
modernization_enabled: bool = True
|
|
require_type_hints: bool = True
|
|
enforcement_mode: str = "strict" # strict/warn/permissive
|
|
|
|
# Type checking tools
|
|
sourcery_enabled: bool = True
|
|
basedpyright_enabled: bool = True
|
|
pyrefly_enabled: bool = True
|
|
type_check_exit_code: int = 2
|
|
|
|
# PostToolUse features
|
|
state_tracking_enabled: bool = False
|
|
cross_file_check_enabled: bool = False
|
|
verify_naming: bool = True
|
|
show_success: bool = False
|
|
|
|
# File patterns
|
|
skip_patterns: list[str] | None = None
|
|
|
|
def __post_init__(self) -> None:
|
|
if self.skip_patterns is None:
|
|
self.skip_patterns = ["test_", "_test.py", "/tests/", "/fixtures/"]
|
|
|
|
@classmethod
|
|
def from_env(cls) -> "QualityConfig":
|
|
"""Load config from environment variables."""
|
|
return cls(
|
|
duplicate_threshold=float(os.getenv("QUALITY_DUP_THRESHOLD", "0.7")),
|
|
duplicate_enabled=os.getenv("QUALITY_DUP_ENABLED", "true").lower()
|
|
== "true",
|
|
complexity_threshold=int(os.getenv("QUALITY_COMPLEXITY_THRESHOLD", "10")),
|
|
complexity_enabled=os.getenv("QUALITY_COMPLEXITY_ENABLED", "true").lower()
|
|
== "true",
|
|
modernization_enabled=os.getenv("QUALITY_MODERN_ENABLED", "true").lower()
|
|
== "true",
|
|
require_type_hints=os.getenv("QUALITY_REQUIRE_TYPES", "true").lower()
|
|
== "true",
|
|
enforcement_mode=os.getenv("QUALITY_ENFORCEMENT", "strict"),
|
|
state_tracking_enabled=os.getenv("QUALITY_STATE_TRACKING", "false").lower()
|
|
== "true",
|
|
cross_file_check_enabled=os.getenv(
|
|
"QUALITY_CROSS_FILE_CHECK",
|
|
"false",
|
|
).lower()
|
|
== "true",
|
|
verify_naming=os.getenv("QUALITY_VERIFY_NAMING", "true").lower() == "true",
|
|
show_success=os.getenv("QUALITY_SHOW_SUCCESS", "false").lower() == "true",
|
|
sourcery_enabled=os.getenv("QUALITY_SOURCERY_ENABLED", "true").lower()
|
|
== "true",
|
|
basedpyright_enabled=os.getenv(
|
|
"QUALITY_BASEDPYRIGHT_ENABLED",
|
|
"true",
|
|
).lower()
|
|
== "true",
|
|
pyrefly_enabled=os.getenv("QUALITY_PYREFLY_ENABLED", "true").lower()
|
|
== "true",
|
|
type_check_exit_code=int(os.getenv("QUALITY_TYPE_CHECK_EXIT_CODE", "2")),
|
|
)
|
|
|
|
|
|
def should_skip_file(file_path: str, config: "QualityConfig") -> bool:
    """Check if file should be skipped based on patterns."""
    patterns = config.skip_patterns
    if patterns is None:
        return False
    # Skip when any configured substring appears anywhere in the path.
    for pattern in patterns:
        if pattern in file_path:
            return True
    return False
|
|
|
|
|
|
def get_claude_quality_path() -> str:
    """Get claude-quality binary path."""
    # Prefer the project-local virtualenv binary; fall back to a PATH lookup.
    local_binary = Path(__file__).parent.parent / ".venv/bin/claude-quality"
    if local_binary.exists():
        return str(local_binary)
    return "claude-quality"
|
|
|
|
|
|
def _ensure_tool_installed(tool_name: str) -> bool:
    """Ensure a type checking tool is installed in the virtual environment."""
    venv_bin = Path(__file__).parent.parent / ".venv/bin"

    # Already present: nothing to do.
    if (venv_bin / tool_name).exists():
        return True

    # Try to install using uv if available; any failure means "not installed".
    try:
        install = subprocess.run(  # noqa: S603
            [str(venv_bin / "uv"), "pip", "install", tool_name],
            check=False,
            capture_output=True,
            text=True,
            timeout=60,
        )
    except (subprocess.TimeoutExpired, OSError):
        return False
    return install.returncode == 0
|
|
|
|
|
|
def _run_type_checker(
    tool_name: str,
    file_path: str,
    _config: QualityConfig,
) -> tuple[bool, str]:
    """Run a type checking tool and return success status and output.

    Returns (True, "") on a clean run, (False, message) when the tool
    reports problems, and (True, "Warning: ...") when the tool is missing,
    unknown, times out, or fails to execute — tooling problems never block.
    """
    venv_bin = Path(__file__).parent.parent / ".venv/bin"
    tool_path = venv_bin / tool_name

    # Missing tool that also cannot be installed: warn but do not fail.
    if not tool_path.exists() and not _ensure_tool_installed(tool_name):
        return True, f"Warning: {tool_name} not available"

    # Tool configuration mapping
    tool_configs: dict[str, ToolConfig] = {
        "basedpyright": ToolConfig(
            args=["--outputjson", file_path],
            # basedpyright exits 1 when type errors are present.
            error_check=lambda result: result.returncode == 1,
            error_message="Type errors found",
        ),
        "pyrefly": ToolConfig(
            args=["check", file_path],
            error_check=lambda result: result.returncode == 1,
            error_message=lambda result: str(result.stdout).strip(),
        ),
        "sourcery": ToolConfig(
            args=["review", file_path],
            # sourcery signals findings in its stdout text, not its exit code.
            error_check=lambda result: (
                "issues detected" in str(result.stdout)
                and "0 issues detected" not in str(result.stdout)
            ),
            error_message=lambda result: str(result.stdout).strip(),
        ),
    }

    tool_config = tool_configs.get(tool_name)
    if not tool_config:
        return True, f"Warning: Unknown tool {tool_name}"

    try:
        cmd = [str(tool_path)] + tool_config["args"]
        result = subprocess.run(  # noqa: S603
            cmd,
            check=False,
            capture_output=True,
            text=True,
            timeout=30,
        )

        # Check for tool-specific errors
        error_check = tool_config["error_check"]
        if error_check(result):
            error_msg = tool_config["error_message"]
            message = str(error_msg(result) if callable(error_msg) else error_msg)
            return False, message

        # Return success or warning
        if result.returncode == 0:
            return True, ""

    except subprocess.TimeoutExpired:
        return True, f"Warning: {tool_name} timeout"
    except OSError:
        return True, f"Warning: {tool_name} execution error"
    else:
        # No exception, non-zero exit, and the error predicate did not fire:
        # surface the odd exit code as a warning rather than a failure.
        return True, f"Warning: {tool_name} error (exit {result.returncode})"
|
|
|
|
|
|
def _initialize_analysis() -> "tuple[AnalysisResults, str]":
    """Initialize analysis with empty results and claude-quality path."""
    empty_results: AnalysisResults = {}
    return empty_results, get_claude_quality_path()
|
|
|
|
|
|
def run_type_checks(file_path: str, config: "QualityConfig") -> list[str]:
    """Run all enabled type checking tools and return any issues.

    Each enabled tool is invoked via _run_type_checker(); failures are
    collected as "<Label>: <output>" strings. Tools that succeed or merely
    warn contribute nothing. (Previously the same invoke/collect logic was
    written out three times.)
    """
    # (enabled-flag, binary name, display label) for every supported checker.
    checkers: list[tuple[bool, str, str]] = [
        (config.sourcery_enabled, "sourcery", "Sourcery"),
        (config.basedpyright_enabled, "basedpyright", "BasedPyright"),
        (config.pyrefly_enabled, "pyrefly", "Pyrefly"),
    ]

    issues: list[str] = []
    for enabled, tool, label in checkers:
        if not enabled:
            continue
        success, output = _run_type_checker(tool, file_path, config)
        if not success and output:
            issues.append(f"{label}: {output.strip()}")
    return issues
|
|
|
|
|
|
def _run_cli_json(cmd: list[str]) -> "JsonObject | None":
    """Run a claude-quality CLI command and parse its JSON stdout.

    Returns None on timeout, non-zero exit, or unparseable output.
    """
    with suppress(subprocess.TimeoutExpired):
        result = subprocess.run(  # noqa: S603
            cmd,
            check=False,
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode == 0:
            with suppress(json.JSONDecodeError):
                return cast("JsonObject", json.loads(result.stdout))
    return None


def _run_quality_analyses(
    content: str,
    tmp_path: str,
    config: "QualityConfig",
    enable_type_checks: bool,
) -> "AnalysisResults":
    """Run all quality analysis checks and return results.

    The subprocess-plus-JSON-parse pattern previously duplicated for the
    complexity and modernization passes is factored into _run_cli_json();
    the modernization command no longer relies on inserting and then
    filtering out an empty-string placeholder argument.
    """
    results, claude_quality = _initialize_analysis()

    # First check for internal duplicates within the file.
    if config.duplicate_enabled:
        internal_duplicates_raw = detect_internal_duplicates(
            content,
            threshold=config.duplicate_threshold,
            min_lines=4,
        )
        # Cast after runtime validation - function returns compatible structure
        internal_duplicates = cast("DuplicateResults", internal_duplicates_raw)
        if internal_duplicates.get("duplicates"):
            results["internal_duplicates"] = internal_duplicates

    # Run complexity analysis via the claude-quality CLI.
    if config.complexity_enabled:
        complexity = _run_cli_json(
            [
                claude_quality,
                "complexity",
                tmp_path,
                "--threshold",
                str(config.complexity_threshold),
                "--format",
                "json",
            ],
        )
        if complexity is not None:
            results["complexity"] = cast("ComplexityResults", complexity)

    # Run type checking if any tool is enabled.
    if enable_type_checks and (
        config.sourcery_enabled
        or config.basedpyright_enabled
        or config.pyrefly_enabled
    ):
        try:
            if type_issues := run_type_checks(tmp_path, config):
                results["type_checking"] = {"issues": type_issues}
        except Exception as e:  # noqa: BLE001
            logging.debug("Type checking failed: %s", e)

    # Run modernization analysis via the claude-quality CLI.
    if config.modernization_enabled:
        cmd = [claude_quality, "modernization", tmp_path]
        if config.require_type_hints:
            cmd.append("--include-type-hints")
        cmd += ["--format", "json"]
        modernization = _run_cli_json(cmd)
        if modernization is not None:
            results["modernization"] = modernization

    return results
|
|
|
|
|
|
def analyze_code_quality(
    content: str,
    file_path: str,
    config: "QualityConfig",
    *,
    enable_type_checks: bool = True,
) -> "AnalysisResults":
    """Analyze code content using claude-quality toolkit."""
    # Write the proposed content to a temp file with a matching suffix so
    # the CLI tools treat it like the target file type.
    extension = Path(file_path).suffix or ".py"
    with NamedTemporaryFile(mode="w", suffix=extension, delete=False) as handle:
        handle.write(content)
        scratch_path = handle.name

    try:
        return _run_quality_analyses(content, scratch_path, config, enable_type_checks)
    finally:
        # Always clean up the scratch file, even if analysis raises.
        Path(scratch_path).unlink(missing_ok=True)
|
|
|
|
|
|
def _check_internal_duplicates(results: "AnalysisResults") -> list[str]:
    """Check for internal duplicate code within the same file."""
    if "internal_duplicates" not in results:
        return []

    found = results["internal_duplicates"].get("duplicates", [])
    messages: list[str] = []
    # Report at most the first three duplicate groups to keep output short.
    for entry in found[:3]:
        where = ", ".join(
            f"{spot['name']} ({spot['lines']})" for spot in entry.get("locations", [])
        )
        similarity = entry.get("similarity", 0)
        messages.append(
            f"Internal duplication ({similarity:.0%} similar): "
            f"{entry.get('description')} - {where}",
        )
    return messages
|
|
|
|
|
|
def _check_complexity_issues(
    results: "AnalysisResults",
    config: "QualityConfig",
) -> list[str]:
    """Check for code complexity issues."""
    if "complexity" not in results:
        return []

    findings: list[str] = []
    data = results["complexity"]

    # Flag when the file-wide average cyclomatic complexity exceeds the limit.
    average = data.get("summary", {}).get("average_cyclomatic_complexity", 0.0)
    if average > config.complexity_threshold:
        findings.append(
            f"High average complexity: CC={average:.1f} "
            f"(threshold: {config.complexity_threshold})",
        )

    # Count functions falling into any of the elevated-complexity buckets.
    buckets = data.get("distribution", {})
    elevated = sum(buckets.get(level, 0) for level in ("High", "Very High", "Extreme"))
    if elevated > 0:
        findings.append(f"Found {elevated} function(s) with high complexity")
    return findings
|
|
|
|
|
|
def _is_type_related(issue_type: str) -> bool:
    """Return True when a modernization issue type concerns type hints."""
    return "type" in issue_type or "typing" in issue_type


def _check_modernization_issues(
    results: "AnalysisResults",
    config: "QualityConfig",
) -> list[str]:
    """Check for code modernization issues.

    Non-type-hint issues are always flagged; missing type hints are flagged
    only when require_type_hints is set and they are numerous (> 10).

    Fixes a unit mismatch in the original: it counted distinct non-type
    issue *types* but reported and subtracted that number as an issue
    *count*. Issues are now counted individually per category, and the
    reported type list is sorted for deterministic messages.
    """
    issues: list[str] = []
    if "modernization" not in results:
        return issues

    try:
        files = results["modernization"].get("files", {})
        if not isinstance(files, dict):
            return issues
    except (AttributeError, TypeError):
        return issues

    non_type_count: int = 0
    type_count: int = 0
    non_type_kinds: set[str] = set()

    for file_entries in cast("dict[str, object]", files).values():
        if not isinstance(file_entries, list):
            continue
        for entry in cast("list[object]", file_entries):
            if not isinstance(entry, dict):
                continue
            kind_raw = cast("dict[str, object]", entry).get("issue_type", "unknown")
            kind = kind_raw if isinstance(kind_raw, str) else "unknown"
            if _is_type_related(kind):
                type_count += 1
            else:
                non_type_count += 1
                non_type_kinds.add(kind)

    # Only flag if there are non-type-hint issues or many type hint issues.
    if non_type_count > 0:
        issues.append(
            f"Modernization needed: {non_type_count} non-type issues "
            f"({', '.join(sorted(non_type_kinds))})",
        )
    elif config.require_type_hints and type_count > 10:
        issues.append(
            f"Many missing type hints: {type_count} functions/parameters "
            "lacking annotations",
        )
    return issues
|
|
|
|
|
|
def _check_type_checking_issues(results: "AnalysisResults") -> list[str]:
    """Check for type checking issues from Sourcery, BasedPyright, and Pyrefly."""
    collected: list[str] = []
    if "type_checking" not in results:
        return collected

    # Defensive: tolerate malformed results rather than failing the hook.
    with suppress(AttributeError, TypeError):
        reported = results["type_checking"].get("issues", [])
        # Cap at five entries to keep the hook message readable.
        collected.extend(str(item) for item in reported[:5])

    return collected
|
|
|
|
|
|
def check_code_issues(
    results: "AnalysisResults",
    config: "QualityConfig",
) -> tuple[bool, list[str]]:
    """Check analysis results for issues that should block the operation."""
    # Aggregate findings from every checker category, in a fixed order.
    issues: list[str] = [
        *_check_internal_duplicates(results),
        *_check_complexity_issues(results, config),
        *_check_modernization_issues(results, config),
        *_check_type_checking_issues(results),
    ]
    return bool(issues), issues
|
|
|
|
|
|
def store_pre_state(file_path: str, content: str) -> None:
|
|
"""Store file state before modification for later comparison."""
|
|
import tempfile
|
|
|
|
cache_dir = Path(tempfile.gettempdir()) / ".quality_state"
|
|
cache_dir.mkdir(exist_ok=True, mode=0o700)
|
|
|
|
state: dict[str, str | int] = {
|
|
"file_path": file_path,
|
|
"timestamp": datetime.now(UTC).isoformat(),
|
|
"content_hash": hashlib.sha256(content.encode()).hexdigest(),
|
|
"lines": len(content.split("\n")),
|
|
"functions": content.count("def "),
|
|
"classes": content.count("class "),
|
|
}
|
|
|
|
cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]
|
|
cache_file: Path = cache_dir / f"{cache_key}_pre.json"
|
|
cache_file.write_text(json.dumps(state, indent=2))
|
|
|
|
|
|
def check_state_changes(file_path: str) -> list[str]:
    """Check for quality changes between pre and post states.

    Compares the snapshot written by store_pre_state() against the file's
    current on-disk content and reports regressions (fewer functions,
    drastic size growth). Returns an empty list when no snapshot exists or
    the comparison cannot be made.

    Fix: snapshot metrics are fetched once with .get() and reused in the
    messages; previously the messages indexed pre_state['functions'] /
    pre_state['lines'] directly, so a snapshot missing a key raised a
    KeyError that the broad except silently swallowed along with any
    already-detected findings.
    """
    import tempfile

    issues: list[str] = []
    cache_dir: Path = Path(tempfile.gettempdir()) / ".quality_state"
    cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]

    pre_file: Path = cache_dir / f"{cache_key}_pre.json"
    if not pre_file.exists():
        return issues

    try:
        pre_state = json.loads(pre_file.read_text())
        try:
            current_content = Path(file_path).read_text()
        except OSError:
            return issues  # Can't compare if can't read file

        current_lines = len(current_content.split("\n"))
        current_functions = current_content.count("def ")

        # Tolerate missing snapshot keys; reuse the same values everywhere.
        prior_functions = pre_state.get("functions", 0)
        prior_lines = pre_state.get("lines", 0)

        # Check for significant changes.
        if current_functions < prior_functions:
            issues.append(
                f"⚠️ Reduced functions: {prior_functions} → {current_functions}",
            )

        if current_lines > prior_lines * 1.5:
            issues.append(
                "⚠️ File size increased significantly: "
                f"{prior_lines} → {current_lines} lines",
            )
    except Exception:  # noqa: BLE001
        logging.debug("Could not analyze state changes for %s", file_path)

    return issues
|
|
|
|
|
|
def check_cross_file_duplicates(file_path: str, config: "QualityConfig") -> list[str]:
    """Check for duplicates across project files."""
    warnings: list[str] = []

    # Walk up from the file until a repo/project marker or the fs root.
    root: Path = Path(file_path).parent
    while root.parent != root:
        if (root / ".git").exists() or (root / "pyproject.toml").exists():
            break
        root = root.parent

    try:
        cmd = [
            get_claude_quality_path(),
            "duplicates",
            str(root),
            "--threshold",
            str(config.duplicate_threshold),
            "--format",
            "json",
        ]
        result = subprocess.run(  # noqa: S603
            cmd,
            check=False,
            capture_output=True,
            text=True,
            timeout=60,
        )
        if result.returncode == 0:
            report = json.loads(result.stdout)
            found = report.get("duplicates", [])
            # Only warn when a duplicate group mentions the file just written.
            if any(file_path in str(group) for group in found):
                warnings.append("⚠️ Cross-file duplication detected")
    except Exception:  # noqa: BLE001
        logging.debug("Could not check cross-file duplicates for %s", file_path)

    return warnings
|
|
|
|
|
|
def verify_naming_conventions(file_path: str) -> list[str]:
    """Verify PEP8 naming conventions."""
    try:
        source = Path(file_path).read_text()
    except OSError:
        return []  # Can't check naming if can't read file

    findings: list[str] = []

    # Functions must be snake_case: flag CapWords or camelCase definitions.
    camel_funcs = re.findall(
        r"def\s+([A-Z][a-zA-Z0-9_]*|[a-z]+[A-Z][a-zA-Z0-9_]*)\s*\(",
        source,
    )
    if camel_funcs:
        findings.append(f"⚠️ Non-PEP8 function names: {', '.join(camel_funcs[:3])}")

    # Classes must be PascalCase: flag all-lowercase class names.
    snake_classes = re.findall(r"class\s+([a-z][a-z0-9_]*)\s*[\(:]", source)
    if snake_classes:
        findings.append(f"⚠️ Non-PEP8 class names: {', '.join(snake_classes[:3])}")

    return findings
|
|
|
|
|
|
def _detect_any_usage(content: str) -> list[str]:
|
|
"""Detect forbidden typing.Any usage in proposed content."""
|
|
|
|
class _AnyUsageVisitor(ast.NodeVisitor):
|
|
"""Collect line numbers where typing.Any is referenced."""
|
|
|
|
def __init__(self) -> None:
|
|
self.lines: set[int] = set()
|
|
|
|
def visit_Name(self, node: ast.Name) -> None:
|
|
if node.id == "Any":
|
|
self.lines.add(node.lineno)
|
|
self.generic_visit(node)
|
|
|
|
def visit_Attribute(self, node: ast.Attribute) -> None:
|
|
if node.attr == "Any":
|
|
self.lines.add(node.lineno)
|
|
self.generic_visit(node)
|
|
|
|
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
|
|
for alias in node.names:
|
|
if alias.name == "Any" or alias.asname == "Any":
|
|
self.lines.add(node.lineno)
|
|
self.generic_visit(node)
|
|
|
|
def visit_Import(self, node: ast.Import) -> None:
|
|
for alias in node.names:
|
|
if alias.name == "Any" or alias.asname == "Any":
|
|
self.lines.add(node.lineno)
|
|
self.generic_visit(node)
|
|
|
|
lines_with_any: set[int] = set()
|
|
try:
|
|
tree = ast.parse(content)
|
|
except SyntaxError:
|
|
for index, line in enumerate(content.splitlines(), start=1):
|
|
code_portion = line.split("#", 1)[0]
|
|
if re.search(r"\bAny\b", code_portion):
|
|
lines_with_any.add(index)
|
|
else:
|
|
visitor = _AnyUsageVisitor()
|
|
visitor.visit(tree)
|
|
lines_with_any = visitor.lines
|
|
|
|
if not lines_with_any:
|
|
return []
|
|
|
|
sorted_lines = sorted(lines_with_any)
|
|
display_lines = ", ".join(str(num) for num in sorted_lines[:5])
|
|
if len(sorted_lines) > 5:
|
|
display_lines += ", …"
|
|
|
|
return [
|
|
"⚠️ Forbidden typing.Any usage at line(s) "
|
|
f"{display_lines}; replace with specific types",
|
|
]
|
|
|
|
|
|
def _perform_quality_check(
    file_path: str,
    content: str,
    config: "QualityConfig",
    enable_type_checks: bool = True,
) -> tuple[bool, list[str]]:
    """Perform quality analysis and return issues."""
    # Snapshot the pre-write state when tracking is on, so the PostToolUse
    # pass can compare against it later.
    if config.state_tracking_enabled:
        store_pre_state(file_path, content)

    analysis = analyze_code_quality(
        content,
        file_path,
        config,
        enable_type_checks=enable_type_checks,
    )
    return check_code_issues(analysis, config)
|
|
|
|
|
|
def _handle_quality_issues(
    file_path: str,
    issues: list[str],
    config: "QualityConfig",
    *,
    forced_permission: str | None = None,
) -> "JsonObject":
    """Handle quality issues based on enforcement mode."""
    # Build the denial/warning message once.
    bullet_list = "\n".join(f" • {issue}" for issue in issues)
    message = (
        f"Code quality check failed for {Path(file_path).name}:\n"
        f"{bullet_list}"
        "\n\nFix these issues before writing the code."
    )

    # An explicitly forced decision overrides the enforcement mode.
    if forced_permission:
        return _create_hook_response("PreToolUse", forced_permission, message)

    # Map enforcement mode to a permission decision.
    mode = config.enforcement_mode
    if mode == "strict":
        return _create_hook_response("PreToolUse", "deny", message)
    if mode == "warn":
        return _create_hook_response("PreToolUse", "ask", message)

    # Permissive: allow, but surface the problems as a warning.
    warning_message = f"⚠️ Quality Warning:\n{message}"
    return _create_hook_response(
        "PreToolUse",
        "allow",
        warning_message,
        warning_message,
    )
|
|
|
|
|
|
def _exit_with_reason(reason: str, exit_code: int = 2) -> None:
    """Write reason to stderr and exit with specified code.

    Exit code 2 is the convention for a blocking error whose stderr is fed
    back to the caller.
    """
    print(reason, end="", file=sys.stderr)
    sys.exit(exit_code)
|
|
|
|
|
|
def _create_hook_response(
|
|
event_name: str,
|
|
permission: str = "",
|
|
reason: str = "",
|
|
system_message: str = "",
|
|
additional_context: str = "",
|
|
*,
|
|
decision: str | None = None,
|
|
) -> JsonObject:
|
|
"""Create standardized hook response."""
|
|
hook_output: dict[str, object] = {
|
|
"hookEventName": event_name,
|
|
}
|
|
|
|
if permission:
|
|
hook_output["permissionDecision"] = permission
|
|
if reason:
|
|
hook_output["permissionDecisionReason"] = reason
|
|
|
|
if additional_context:
|
|
hook_output["additionalContext"] = additional_context
|
|
|
|
response: JsonObject = {
|
|
"hookSpecificOutput": hook_output,
|
|
}
|
|
|
|
if permission:
|
|
response["permissionDecision"] = permission
|
|
|
|
if decision:
|
|
response["decision"] = decision
|
|
|
|
if reason:
|
|
response["reason"] = reason
|
|
|
|
if system_message:
|
|
response["systemMessage"] = system_message
|
|
|
|
return response
|
|
|
|
|
|
def _extract_proposed_content(tool_name: str, tool_input: dict[str, object]) -> str:
    """Extract the proposed file content from a Write/Edit/MultiEdit input.

    Returns "" for unknown tools or malformed inputs. (Replaces a
    conditional expression that was abused for its append side effects.)
    """
    if tool_name == "Write":
        raw_content = tool_input.get("content", "")
        return "" if raw_content is None else str(raw_content)

    if tool_name == "Edit":
        new_string = tool_input.get("new_string", "")
        return "" if new_string is None else str(new_string)

    if tool_name == "MultiEdit":
        edits = tool_input.get("edits", [])
        if isinstance(edits, list):
            parts: list[str] = []
            for edit in cast("list[object]", edits):
                if not isinstance(edit, dict):
                    continue
                new_str = cast("dict[str, object]", edit).get("new_string")
                # One entry per edit; missing replacements become "".
                parts.append("" if new_str is None else str(new_str))
            return "\n".join(parts)

    return ""


def pretooluse_hook(hook_data: "JsonObject", config: "QualityConfig") -> "JsonObject":
    """Handle PreToolUse hook - analyze content before write/edit.

    Forbidden typing.Any usage always denies; other quality issues are
    handled according to the configured enforcement mode. Analysis errors
    fail open (allow) unless an Any violation was already detected.
    """
    tool_name = str(hook_data.get("tool_name", ""))
    tool_input_raw = hook_data.get("tool_input", {})
    if not isinstance(tool_input_raw, dict):
        return _create_hook_response("PreToolUse", "allow")
    tool_input = cast("dict[str, object]", tool_input_raw)

    # Only analyze for write/edit tools
    if tool_name not in ["Write", "Edit", "MultiEdit"]:
        return _create_hook_response("PreToolUse", "allow")

    file_path = str(tool_input.get("file_path", ""))
    content = _extract_proposed_content(tool_name, tool_input)

    # Only analyze Python files with actual content.
    if not file_path or not file_path.endswith(".py") or not content:
        return _create_hook_response("PreToolUse", "allow")

    # Skip analysis for configured patterns (tests, fixtures, ...).
    if should_skip_file(file_path, config):
        return _create_hook_response("PreToolUse", "allow")

    # Full type checking only for whole-file writes; edits are partial snippets.
    enable_type_checks = tool_name == "Write"
    any_usage_issues = _detect_any_usage(content)

    try:
        _has_issues, issues = _perform_quality_check(
            file_path,
            content,
            config,
            enable_type_checks=enable_type_checks,
        )

        all_issues = any_usage_issues + issues

        if not all_issues:
            return _create_hook_response("PreToolUse", "allow")

        # typing.Any violations always deny, regardless of enforcement mode.
        if any_usage_issues:
            return _handle_quality_issues(
                file_path,
                all_issues,
                config,
                forced_permission="deny",
            )

        return _handle_quality_issues(file_path, all_issues, config)
    except Exception as e:  # noqa: BLE001
        # Analysis failure must not block writes, but Any violations still do.
        if any_usage_issues:
            return _handle_quality_issues(
                file_path,
                any_usage_issues,
                config,
                forced_permission="deny",
            )
        return _create_hook_response(
            "PreToolUse",
            "allow",
            f"Warning: Code quality check failed with error: {e}",
            f"Warning: Code quality check failed with error: {e}",
        )
|
|
|
|
|
|
def posttooluse_hook(
    hook_data: JsonObject,
    config: QualityConfig,
) -> JsonObject:
    """Handle PostToolUse hook - verify quality after write/edit.

    Extracts the written file's path from the tool response, runs the
    enabled post-write checks (state delta, cross-file duplicates, naming),
    and returns a "block" decision when any check reports an issue.
    """
    tool_name: str = str(hook_data.get("tool_name", ""))
    # Newer payloads use "tool_response"; fall back to legacy "tool_output".
    tool_output = hook_data.get("tool_response", hook_data.get("tool_output", {}))

    # Only process write/edit tools
    if tool_name not in ["Write", "Edit", "MultiEdit"]:
        return _create_hook_response("PostToolUse")

    # Extract file path from output
    file_path: str = ""
    if isinstance(tool_output, dict):
        tool_output_dict = cast("dict[str, object]", tool_output)
        file_path = str(
            tool_output_dict.get("file_path", "") or tool_output_dict.get("path", ""),
        )
    elif isinstance(tool_output, str) and (
        # Fallback: scrape the first .py path out of a plain-text response.
        match := re.search(r"([/\w\-_.]+\.py)", tool_output)
    ):
        file_path = match[1]

    # Only Python files that actually exist on disk can be verified.
    if not file_path or not file_path.endswith(".py"):
        return _create_hook_response("PostToolUse")

    if not Path(file_path).exists():
        return _create_hook_response("PostToolUse")

    issues: list[str] = []

    # Check state changes if tracking enabled
    if config.state_tracking_enabled:
        delta_issues = check_state_changes(file_path)
        issues.extend(delta_issues)

    # Run cross-file duplicate detection if enabled
    if config.cross_file_check_enabled:
        cross_file_issues = check_cross_file_duplicates(file_path, config)
        issues.extend(cross_file_issues)

    # Verify naming conventions if enabled
    if config.verify_naming:
        naming_issues = verify_naming_conventions(file_path)
        issues.extend(naming_issues)

    # Format response
    if issues:
        message = (
            f"📝 Post-write quality notes for {Path(file_path).name}:\n"
            + "\n".join(issues)
        )
        # Positional args carry the notes as reason/systemMessage/context;
        # decision="block" asks the agent to address them.
        return _create_hook_response(
            "PostToolUse",
            "",
            message,
            message,
            message,
            decision="block",
        )
    if config.show_success:
        message = f"✅ {Path(file_path).name} passed post-write verification"
        return _create_hook_response(
            "PostToolUse",
            "",
            "",
            message,
            "",
            decision="approve",
        )

    return _create_hook_response("PostToolUse")
|
|
|
|
|
|
def main() -> None:
    """Main hook entry point.

    Reads the hook payload from stdin, dispatches to the Pre- or
    PostToolUse handler, prints the JSON response on stdout, and translates
    "deny" / "ask" decisions into exit code 2 so the caller sees the reason
    on stderr.
    """
    try:
        # Load configuration
        config = QualityConfig.from_env()

        # Read hook input from stdin
        try:
            hook_data: JsonObject = json.load(sys.stdin)
        except json.JSONDecodeError:
            # Unparseable input: fail open with a plain "allow" response.
            fallback_response: JsonObject = {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "allow",
                },
            }
            sys.stdout.write(json.dumps(fallback_response))
            return

        # Detect hook type: tool_response=PostToolUse, tool_input=PreToolUse
        response: JsonObject
        if "tool_response" in hook_data or "tool_output" in hook_data:
            # PostToolUse hook
            response = posttooluse_hook(hook_data, config)
        else:
            # PreToolUse hook
            response = pretooluse_hook(hook_data, config)

        print(json.dumps(response))  # noqa: T201

        # Handle exit codes based on hook output
        hook_output_raw = response.get("hookSpecificOutput", {})
        if not isinstance(hook_output_raw, dict):
            return

        hook_output = cast("dict[str, object]", hook_output_raw)
        permission_decision = hook_output.get("permissionDecision")

        if permission_decision == "deny":
            # Exit code 2: Blocking error - stderr fed back to Claude
            reason = str(
                hook_output.get("permissionDecisionReason", "Permission denied"),
            )
            _exit_with_reason(reason)
        elif permission_decision == "ask":
            # Also use exit code 2 for ask decisions to ensure Claude sees the message
            reason = str(
                hook_output.get("permissionDecisionReason", "Permission request"),
            )
            _exit_with_reason(reason)
        # Exit code 0: Success (default)

    except Exception as e:  # noqa: BLE001
        # Unexpected error - use exit code 1 (non-blocking error)
        sys.stderr.write(f"Hook error: {e}")
        sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|