Files
claude-scripts/hooks/code_quality_guard.py
2025-09-17 18:55:08 +00:00

1046 lines
33 KiB
Python

"""Unified quality hook for Claude Code supporting both PreToolUse and PostToolUse.
Prevents writing duplicate, complex, or non-modernized code and verifies quality
after writes.
"""
import ast
import hashlib
import json
import logging
import os
import re
import subprocess
import sys
from collections.abc import Callable
from contextlib import suppress
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import TypedDict, cast
# Import internal duplicate detector
sys.path.insert(0, str(Path(__file__).parent))
from internal_duplicate_detector import detect_internal_duplicates
class ToolConfig(TypedDict):
"""Configuration for a type checking tool."""
args: list[str]
error_check: Callable[[subprocess.CompletedProcess[str]], bool]
error_message: str | Callable[[subprocess.CompletedProcess[str]], str]
class DuplicateLocation(TypedDict):
    """One place where a duplicated block appears."""

    # Label rendered first in the report line.
    name: str
    # Line span text rendered in parentheses after the name.
    lines: str


class Duplicate(TypedDict):
    """A group of similar code blocks found by the duplicate detector."""

    # Similarity ratio, rendered as a percentage in reports.
    similarity: float
    # Human-readable explanation of the duplication.
    description: str
    # Every location participating in the duplication.
    locations: list[DuplicateLocation]


class DuplicateResults(TypedDict):
    """Top-level payload of a duplicate detection run."""

    duplicates: list[Duplicate]
class ComplexitySummary(TypedDict):
    """Summary section of the claude-quality complexity report."""

    average_cyclomatic_complexity: float


# Functional syntax is required here: the consumer of this data
# (``_check_complexity_issues``) reads the bucket key "Very High" - with a
# space - which cannot be written as a class attribute name.
ComplexityDistribution = TypedDict(
    "ComplexityDistribution",
    {
        "High": int,
        "Very High": int,
        "Extreme": int,
    },
)


class ComplexityResults(TypedDict):
    """Parsed JSON output of ``claude-quality complexity --format json``."""

    summary: ComplexitySummary
    distribution: ComplexityDistribution
class TypeCheckingResults(TypedDict):
    """Findings gathered from the external type checkers."""

    # One "<Tool>: <output>" entry per checker that reported a failure.
    issues: list[str]
class AnalysisResults(TypedDict, total=False):
    """Aggregated quality findings; every section is optional.

    A section is filled in only when its analysis ran and produced something
    to report, so consumers must test for key presence first.
    """

    internal_duplicates: DuplicateResults
    complexity: ComplexityResults
    type_checking: TypeCheckingResults
    # Raw analyzer JSON; its structure is not fixed.
    modernization: dict[str, object]
# Type aliases for JSON-like structures
JsonObject = dict[str, object]
JsonValue = str | int | float | bool | None | list[object] | JsonObject
@dataclass
class QualityConfig:
"""Configuration for quality checks."""
# Core settings
duplicate_threshold: float = 0.7
duplicate_enabled: bool = True
complexity_threshold: int = 10
complexity_enabled: bool = True
modernization_enabled: bool = True
require_type_hints: bool = True
enforcement_mode: str = "strict" # strict/warn/permissive
# Type checking tools
sourcery_enabled: bool = True
basedpyright_enabled: bool = True
pyrefly_enabled: bool = True
type_check_exit_code: int = 2
# PostToolUse features
state_tracking_enabled: bool = False
cross_file_check_enabled: bool = False
verify_naming: bool = True
show_success: bool = False
# File patterns
skip_patterns: list[str] | None = None
def __post_init__(self) -> None:
if self.skip_patterns is None:
self.skip_patterns = ["test_", "_test.py", "/tests/", "/fixtures/"]
@classmethod
def from_env(cls) -> "QualityConfig":
"""Load config from environment variables."""
return cls(
duplicate_threshold=float(os.getenv("QUALITY_DUP_THRESHOLD", "0.7")),
duplicate_enabled=os.getenv("QUALITY_DUP_ENABLED", "true").lower()
== "true",
complexity_threshold=int(os.getenv("QUALITY_COMPLEXITY_THRESHOLD", "10")),
complexity_enabled=os.getenv("QUALITY_COMPLEXITY_ENABLED", "true").lower()
== "true",
modernization_enabled=os.getenv("QUALITY_MODERN_ENABLED", "true").lower()
== "true",
require_type_hints=os.getenv("QUALITY_REQUIRE_TYPES", "true").lower()
== "true",
enforcement_mode=os.getenv("QUALITY_ENFORCEMENT", "strict"),
state_tracking_enabled=os.getenv("QUALITY_STATE_TRACKING", "false").lower()
== "true",
cross_file_check_enabled=os.getenv(
"QUALITY_CROSS_FILE_CHECK",
"false",
).lower()
== "true",
verify_naming=os.getenv("QUALITY_VERIFY_NAMING", "true").lower() == "true",
show_success=os.getenv("QUALITY_SHOW_SUCCESS", "false").lower() == "true",
sourcery_enabled=os.getenv("QUALITY_SOURCERY_ENABLED", "true").lower()
== "true",
basedpyright_enabled=os.getenv(
"QUALITY_BASEDPYRIGHT_ENABLED",
"true",
).lower()
== "true",
pyrefly_enabled=os.getenv("QUALITY_PYREFLY_ENABLED", "true").lower()
== "true",
type_check_exit_code=int(os.getenv("QUALITY_TYPE_CHECK_EXIT_CODE", "2")),
)
def should_skip_file(file_path: str, config: QualityConfig) -> bool:
    """Return True when *file_path* contains any configured skip pattern.

    Matching is a plain substring test. A config without patterns never
    skips anything.
    """
    patterns = config.skip_patterns
    if patterns is None:
        return False
    for pattern in patterns:
        if pattern in file_path:
            return True
    return False
def get_claude_quality_path() -> str:
    """Return the claude-quality executable to invoke.

    Prefers the copy inside the ``.venv`` two directories above this file.
    Otherwise it returns the bare command name, leaving lookup to the caller's PATH.
    """
    venv_binary = Path(__file__).parent.parent.joinpath(".venv", "bin", "claude-quality")
    if venv_binary.exists():
        return str(venv_binary)
    return "claude-quality"
def _ensure_tool_installed(tool_name: str) -> bool:
"""Ensure a type checking tool is installed in the virtual environment."""
venv_bin = Path(__file__).parent.parent / ".venv/bin"
tool_path = venv_bin / tool_name
if tool_path.exists():
return True
# Try to install using uv if available
try:
result = subprocess.run( # noqa: S603
[str(venv_bin / "uv"), "pip", "install", tool_name],
check=False,
capture_output=True,
text=True,
timeout=60,
)
except (subprocess.TimeoutExpired, OSError):
return False
else:
return result.returncode == 0
def _run_type_checker(
tool_name: str,
file_path: str,
_config: QualityConfig,
) -> tuple[bool, str]:
"""Run a type checking tool and return success status and output."""
venv_bin = Path(__file__).parent.parent / ".venv/bin"
tool_path = venv_bin / tool_name
if not tool_path.exists() and not _ensure_tool_installed(tool_name):
return True, f"Warning: {tool_name} not available"
# Tool configuration mapping
tool_configs: dict[str, ToolConfig] = {
"basedpyright": ToolConfig(
args=["--outputjson", file_path],
error_check=lambda result: result.returncode == 1,
error_message="Type errors found",
),
"pyrefly": ToolConfig(
args=["check", file_path],
error_check=lambda result: result.returncode == 1,
error_message=lambda result: str(result.stdout).strip(),
),
"sourcery": ToolConfig(
args=["review", file_path],
error_check=lambda result: (
"issues detected" in str(result.stdout)
and "0 issues detected" not in str(result.stdout)
),
error_message=lambda result: str(result.stdout).strip(),
),
}
tool_config = tool_configs.get(tool_name)
if not tool_config:
return True, f"Warning: Unknown tool {tool_name}"
try:
cmd = [str(tool_path)] + tool_config["args"]
result = subprocess.run( # noqa: S603
cmd,
check=False,
capture_output=True,
text=True,
timeout=30,
)
# Check for tool-specific errors
error_check = tool_config["error_check"]
if error_check(result):
error_msg = tool_config["error_message"]
message = str(error_msg(result) if callable(error_msg) else error_msg)
return False, message
# Return success or warning
if result.returncode == 0:
return True, ""
except subprocess.TimeoutExpired:
return True, f"Warning: {tool_name} timeout"
except OSError:
return True, f"Warning: {tool_name} execution error"
else:
return True, f"Warning: {tool_name} error (exit {result.returncode})"
def _initialize_analysis() -> tuple[AnalysisResults, str]:
    """Return a fresh, empty results mapping and the claude-quality command."""
    empty: AnalysisResults = {}
    return empty, get_claude_quality_path()
def run_type_checks(file_path: str, config: QualityConfig) -> list[str]:
    """Run every enabled type checker on *file_path* and collect its failures.

    Checkers run in a fixed order: Sourcery, BasedPyright, then Pyrefly.
    A checker contributes an entry only when it reports failure with
    non-empty output. Each entry is prefixed with the checker's display name.
    """
    checkers = (
        (config.sourcery_enabled, "sourcery", "Sourcery"),
        (config.basedpyright_enabled, "basedpyright", "BasedPyright"),
        (config.pyrefly_enabled, "pyrefly", "Pyrefly"),
    )
    found: list[str] = []
    for enabled, tool, label in checkers:
        if not enabled:
            continue
        ok, output = _run_type_checker(tool, file_path, config)
        if ok or not output:
            continue
        found.append(f"{label}: {output.strip()}")
    return found
def _run_claude_quality_json(cmd: list[str]) -> object | None:
    """Run a claude-quality subcommand and return its parsed JSON stdout.

    Returns None in any of these cases:
    - the command cannot be started (``OSError``, e.g. executable not installed);
    - it exceeds the 30 second timeout;
    - it exits non-zero;
    - it prints invalid JSON or JSON ``null``.

    Returning None keeps one unavailable analyzer from aborting the others.
    """
    try:
        completed = subprocess.run(  # noqa: S603
            cmd,
            check=False,
            capture_output=True,
            text=True,
            timeout=30,
        )
    except (subprocess.TimeoutExpired, OSError):
        return None
    if completed.returncode != 0:
        return None
    try:
        parsed: object = json.loads(completed.stdout)
    except json.JSONDecodeError:
        return None
    return parsed


def _run_quality_analyses(
    content: str,
    tmp_path: str,
    config: QualityConfig,
    enable_type_checks: bool,
) -> AnalysisResults:
    """Run all quality analysis checks and return results.

    Args:
        content: Source text under review, used for in-process internal
            duplicate detection.
        tmp_path: Path of a file holding *content*, handed to the external
            analyzers.
        config: Toggles the individual analyses and supplies thresholds.
        enable_type_checks: Whether the type checkers may run at all.

    Returns:
        Only the sections that produced data:
        - internal duplicates, when any were found;
        - complexity and modernization reports, when claude-quality succeeded;
        - type checking issues, when any checker reported failures.
    """
    results, claude_quality = _initialize_analysis()

    # First check for internal duplicates within the file
    if config.duplicate_enabled:
        internal_duplicates_raw = detect_internal_duplicates(
            content,
            threshold=config.duplicate_threshold,
            min_lines=4,
        )
        # Assumes detect_internal_duplicates returns the DuplicateResults
        # shape; this is not validated at runtime.
        internal_duplicates = cast("DuplicateResults", internal_duplicates_raw)
        if internal_duplicates.get("duplicates"):
            results["internal_duplicates"] = internal_duplicates

    # Run complexity analysis
    if config.complexity_enabled:
        complexity = _run_claude_quality_json(
            [
                claude_quality,
                "complexity",
                tmp_path,
                "--threshold",
                str(config.complexity_threshold),
                "--format",
                "json",
            ],
        )
        if complexity is not None:
            results["complexity"] = cast("ComplexityResults", complexity)

    # Run type checking if any tool is enabled
    any_type_tool = (
        config.sourcery_enabled
        or config.basedpyright_enabled
        or config.pyrefly_enabled
    )
    if enable_type_checks and any_type_tool:
        try:
            if type_issues := run_type_checks(tmp_path, config):
                results["type_checking"] = {"issues": type_issues}
        except Exception as e:  # noqa: BLE001
            # Type checking is advisory here; never let it abort the analysis.
            logging.debug("Type checking failed: %s", e)

    # Run modernization analysis
    if config.modernization_enabled:
        modernization_cmd = [claude_quality, "modernization", tmp_path]
        if config.require_type_hints:
            modernization_cmd.append("--include-type-hints")
        modernization_cmd += ["--format", "json"]
        modernization = _run_claude_quality_json(modernization_cmd)
        if modernization is not None:
            results["modernization"] = cast("dict[str, object]", modernization)

    return results
def analyze_code_quality(
    content: str,
    file_path: str,
    config: QualityConfig,
    *,
    enable_type_checks: bool = True,
) -> AnalysisResults:
    """Analyze code content using claude-quality toolkit.

    *content* is written to a temporary file that keeps *file_path*'s suffix
    (``.py`` when there is none), so the external analyzers can open it by
    name. The temporary file is always removed afterwards, including when
    writing it or running the analyses fails.
    """
    suffix = Path(file_path).suffix or ".py"
    tmp_path = ""
    try:
        # delete=False: the file must outlive the ``with`` so analyzer
        # subprocesses can read it. UTF-8 is explicit because Python tooling
        # decodes source files as UTF-8 by default, whatever the locale.
        with NamedTemporaryFile(
            mode="w",
            suffix=suffix,
            delete=False,
            encoding="utf-8",
        ) as tmp:
            tmp_path = tmp.name
            tmp.write(content)
        return _run_quality_analyses(content, tmp_path, config, enable_type_checks)
    finally:
        if tmp_path:
            Path(tmp_path).unlink(missing_ok=True)
def _check_internal_duplicates(results: AnalysisResults) -> list[str]:
    """Describe up to three duplicated blocks found inside the analysed file."""
    if "internal_duplicates" not in results:
        return []
    duplicates: list[Duplicate] = results["internal_duplicates"].get(
        "duplicates",
        [],
    )

    def _describe(dup: Duplicate) -> str:
        where = ", ".join(
            f"{loc['name']} ({loc['lines']})" for loc in dup.get("locations", [])
        )
        return (
            f"Internal duplication ({dup.get('similarity', 0):.0%} similar): "
            f"{dup.get('description')} - {where}"
        )

    # Only the first three groups are reported to keep the message short.
    return [_describe(dup) for dup in duplicates[:3]]
def _check_complexity_issues(
results: AnalysisResults,
config: QualityConfig,
) -> list[str]:
"""Check for code complexity issues."""
issues: list[str] = []
if "complexity" not in results:
return issues
complexity_data = results["complexity"]
summary = complexity_data.get("summary", {})
avg_cc = summary.get("average_cyclomatic_complexity", 0.0)
if avg_cc > config.complexity_threshold:
issues.append(
f"High average complexity: CC={avg_cc:.1f} "
f"(threshold: {config.complexity_threshold})",
)
distribution = complexity_data.get("distribution", {})
high_count: int = (
distribution.get("High", 0)
+ distribution.get("Very High", 0)
+ distribution.get("Extreme", 0)
)
if high_count > 0:
issues.append(f"Found {high_count} function(s) with high complexity")
return issues
def _check_modernization_issues(
results: AnalysisResults,
config: QualityConfig,
) -> list[str]:
"""Check for code modernization issues."""
issues: list[str] = []
if "modernization" not in results:
return issues
try:
modernization_data = results["modernization"]
files = modernization_data.get("files", {})
if not isinstance(files, dict):
return issues
except (AttributeError, TypeError):
return issues
total_issues: int = 0
issue_types: set[str] = set()
files_values = cast("dict[str, object]", files).values()
for file_issues_raw in files_values:
if not isinstance(file_issues_raw, list):
continue
# Cast to proper type after runtime check
file_issues_list = cast("list[object]", file_issues_raw)
total_issues += len(file_issues_list)
for issue_raw in file_issues_list:
if not isinstance(issue_raw, dict):
continue
# Cast to proper type after runtime check
issue_dict = cast("dict[str, object]", issue_raw)
issue_type_raw = issue_dict.get("issue_type", "unknown")
if isinstance(issue_type_raw, str):
issue_types.add(issue_type_raw)
# Only flag if there are non-type-hint issues or many type hint issues
non_type_issues: int = len(
[t for t in issue_types if "type" not in t and "typing" not in t],
)
type_issues: int = total_issues - non_type_issues
if non_type_issues > 0:
non_type_list: list[str] = [
t for t in issue_types if "type" not in t and "typing" not in t
]
issues.append(
f"Modernization needed: {non_type_issues} non-type issues "
f"({', '.join(non_type_list)})",
)
elif config.require_type_hints and type_issues > 10:
issues.append(
f"Many missing type hints: {type_issues} functions/parameters "
"lacking annotations",
)
return issues
def _check_type_checking_issues(results: AnalysisResults) -> list[str]:
    """Return at most five type checker findings, each converted to a string.

    Malformed ``type_checking`` data (wrong container types) yields an empty
    list instead of raising.
    """
    if "type_checking" not in results:
        return []
    try:
        head = results["type_checking"].get("issues", [])[:5]
    except (AttributeError, TypeError):
        return []
    return [str(entry) for entry in head]
def check_code_issues(
    results: AnalysisResults,
    config: QualityConfig,
) -> tuple[bool, list[str]]:
    """Collect the issues that should block the write.

    Returns ``(has_issues, issues)``. Issues are ordered as: internal
    duplicates, complexity, modernization, then type checking.
    """
    issues = [
        *_check_internal_duplicates(results),
        *_check_complexity_issues(results, config),
        *_check_modernization_issues(results, config),
        *_check_type_checking_issues(results),
    ]
    return bool(issues), issues
def store_pre_state(file_path: str, content: str) -> None:
"""Store file state before modification for later comparison."""
import tempfile
cache_dir = Path(tempfile.gettempdir()) / ".quality_state"
cache_dir.mkdir(exist_ok=True, mode=0o700)
state: dict[str, str | int] = {
"file_path": file_path,
"timestamp": datetime.now(UTC).isoformat(),
"content_hash": hashlib.sha256(content.encode()).hexdigest(),
"lines": len(content.split("\n")),
"functions": content.count("def "),
"classes": content.count("class "),
}
cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]
cache_file: Path = cache_dir / f"{cache_key}_pre.json"
cache_file.write_text(json.dumps(state, indent=2))
def check_state_changes(file_path: str) -> list[str]:
    """Check for quality changes between pre and post states.

    Compares *file_path*'s current contents with the snapshot at
    ``<tempdir>/.quality_state/<key>_pre.json``. Two changes are reported:
    - a drop in the number of "def " occurrences;
    - growth beyond 1.5x the snapshot's line count.

    The snapshot is deleted once read, so a later write that was not
    snapshotted is never compared against stale state. Any failure is logged
    at debug level and yields no issues.
    """
    import tempfile

    issues: list[str] = []
    cache_dir: Path = Path(tempfile.gettempdir()) / ".quality_state"
    cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]
    pre_file: Path = cache_dir / f"{cache_key}_pre.json"
    if not pre_file.exists():
        return issues
    try:
        pre_state = json.loads(pre_file.read_text())
        with suppress(OSError):
            pre_file.unlink()
        try:
            current_content = Path(file_path).read_text(encoding="utf-8")
        except OSError:
            return issues  # Can't compare if can't read file
        current_lines = len(current_content.split("\n"))
        current_functions = current_content.count("def ")
        pre_functions = pre_state.get("functions", 0)
        pre_lines = pre_state.get("lines", 0)
        # Check for significant changes
        if current_functions < pre_functions:
            issues.append(
                f"⚠️ Reduced functions: {pre_functions} → {current_functions}",
            )
        if current_lines > pre_lines * 1.5:
            issues.append(
                "⚠️ File size increased significantly: "
                f"{pre_lines} → {current_lines} lines",
            )
    except Exception:  # noqa: BLE001
        logging.debug("Could not analyze state changes for %s", file_path)
    return issues
def check_cross_file_duplicates(file_path: str, config: QualityConfig) -> list[str]:
"""Check for duplicates across project files."""
issues: list[str] = []
# Get project root
project_root: Path = Path(file_path).parent
while (
project_root.parent != project_root
and not (project_root / ".git").exists()
and not (project_root / "pyproject.toml").exists()
):
project_root = project_root.parent
try:
claude_quality: str = get_claude_quality_path()
result = subprocess.run( # noqa: S603
[
claude_quality,
"duplicates",
str(project_root),
"--threshold",
str(config.duplicate_threshold),
"--format",
"json",
],
check=False,
capture_output=True,
text=True,
timeout=60,
)
if result.returncode == 0:
data = json.loads(result.stdout)
duplicates = data.get("duplicates", [])
if any(file_path in str(d) for d in duplicates):
issues.append("⚠️ Cross-file duplication detected")
except Exception: # noqa: BLE001
logging.debug("Could not check cross-file duplicates for %s", file_path)
return issues
# Function names written in PascalCase or camelCase. Anchored to the start of
# a line (after indentation and an optional ``async``), so "def name(" text
# in the middle of a line, such as inside a string, is not reported.
_NON_PEP8_FUNCTION_RE = re.compile(
    r"^[ \t]*(?:async\s+)?def\s+([A-Z][a-zA-Z0-9_]*|[a-z]+[A-Z][a-zA-Z0-9_]*)\s*\(",
    re.MULTILINE,
)
# Class names that start lowercase and contain no capitals.
_NON_PEP8_CLASS_RE = re.compile(
    r"^[ \t]*class\s+([a-z][a-z0-9_]*)\s*[\(:]",
    re.MULTILINE,
)


def verify_naming_conventions(file_path: str) -> list[str]:
    """Verify PEP8 naming conventions.

    Reports up to three function names that are not snake_case (PascalCase
    or camelCase) and up to three all-lowercase class names. Returns no
    issues when the file cannot be read or is not valid UTF-8.
    """
    try:
        content = Path(file_path).read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return []  # Can't check naming if can't read file
    issues: list[str] = []
    # Check function names (should be snake_case)
    if bad_funcs := _NON_PEP8_FUNCTION_RE.findall(content):
        issues.append(f"⚠️ Non-PEP8 function names: {', '.join(bad_funcs[:3])}")
    # Check class names (should be PascalCase)
    if bad_classes := _NON_PEP8_CLASS_RE.findall(content):
        issues.append(f"⚠️ Non-PEP8 class names: {', '.join(bad_classes[:3])}")
    return issues
def _detect_any_usage(content: str) -> list[str]:
"""Detect forbidden typing.Any usage in proposed content."""
class _AnyUsageVisitor(ast.NodeVisitor):
"""Collect line numbers where typing.Any is referenced."""
def __init__(self) -> None:
self.lines: set[int] = set()
def visit_Name(self, node: ast.Name) -> None:
if node.id == "Any":
self.lines.add(node.lineno)
self.generic_visit(node)
def visit_Attribute(self, node: ast.Attribute) -> None:
if node.attr == "Any":
self.lines.add(node.lineno)
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
for alias in node.names:
if alias.name == "Any" or alias.asname == "Any":
self.lines.add(node.lineno)
self.generic_visit(node)
def visit_Import(self, node: ast.Import) -> None:
for alias in node.names:
if alias.name == "Any" or alias.asname == "Any":
self.lines.add(node.lineno)
self.generic_visit(node)
lines_with_any: set[int] = set()
try:
tree = ast.parse(content)
except SyntaxError:
for index, line in enumerate(content.splitlines(), start=1):
code_portion = line.split("#", 1)[0]
if re.search(r"\bAny\b", code_portion):
lines_with_any.add(index)
else:
visitor = _AnyUsageVisitor()
visitor.visit(tree)
lines_with_any = visitor.lines
if not lines_with_any:
return []
sorted_lines = sorted(lines_with_any)
display_lines = ", ".join(str(num) for num in sorted_lines[:5])
if len(sorted_lines) > 5:
display_lines += ", …"
return [
"⚠️ Forbidden typing.Any usage at line(s) "
f"{display_lines}; replace with specific types",
]
def _perform_quality_check(
    file_path: str,
    content: str,
    config: QualityConfig,
    enable_type_checks: bool = True,
) -> tuple[bool, list[str]]:
    """Perform quality analysis and return issues.

    Returns:
        ``(has_issues, issues)`` for the proposed *content*.
    """
    # Store state if tracking enabled. The snapshot must describe the file
    # as it is on disk *before* the tool runs. ``content`` is the proposed
    # text: the whole new file for Write, only the replacement text for
    # Edit/MultiEdit. Snapshotting it would compare the written file with
    # itself or with a fragment. A file that does not exist yet, or cannot
    # be read, gets no snapshot.
    if config.state_tracking_enabled:
        with suppress(OSError, UnicodeDecodeError):
            store_pre_state(file_path, Path(file_path).read_text(encoding="utf-8"))

    # Run quality analysis
    results = analyze_code_quality(
        content,
        file_path,
        config,
        enable_type_checks=enable_type_checks,
    )
    return check_code_issues(results, config)
def _handle_quality_issues(
    file_path: str,
    issues: list[str],
    config: QualityConfig,
    *,
    forced_permission: str | None = None,
) -> JsonObject:
    """Turn quality issues into a PreToolUse permission response.

    A non-empty *forced_permission* always wins. Otherwise the decision
    follows ``config.enforcement_mode``:
    - ``strict`` denies;
    - ``warn`` asks;
    - any other mode allows the write and surfaces the issues as a warning.
    """
    body = "\n".join(issues)
    message = (
        f"Code quality check failed for {Path(file_path).name}:\n"
        f"{body}\n\nFix these issues before writing the code."
    )

    if forced_permission:
        permission = forced_permission
    else:
        permission = {"strict": "deny", "warn": "ask"}.get(config.enforcement_mode, "")
    if permission:
        return _create_hook_response("PreToolUse", permission, message)

    # Permissive mode: let the write through but surface the warning.
    warning = f"⚠️ Quality Warning:\n{message}"
    return _create_hook_response("PreToolUse", "allow", warning, warning)
def _exit_with_reason(reason: str, exit_code: int = 2) -> None:
"""Write reason to stderr and exit with specified code."""
sys.stderr.write(reason)
sys.exit(exit_code)
def _create_hook_response(
    event_name: str,
    permission: str = "",
    reason: str = "",
    system_message: str = "",
    additional_context: str = "",
    *,
    decision: str | None = None,
) -> JsonObject:
    """Build the JSON payload a hook prints on stdout.

    ``hookSpecificOutput`` always carries *event_name*. Every other field
    is emitted only when its argument is non-empty. The permission decision
    is also mirrored at the top level of the response.
    """
    specific: dict[str, object] = {"hookEventName": event_name}
    for key, value in (
        ("permissionDecision", permission),
        ("permissionDecisionReason", reason),
        ("additionalContext", additional_context),
    ):
        if value:
            specific[key] = value

    response: JsonObject = {"hookSpecificOutput": specific}
    for key, value in (
        ("permissionDecision", permission),
        ("decision", decision),
        ("reason", reason),
        ("systemMessage", system_message),
    ):
        if value:
            response[key] = value
    return response
def _extract_proposed_content(tool_name: str, tool_input: dict[str, object]) -> str:
    """Return the text a Write/Edit/MultiEdit call would introduce.

    - Write yields ``content``.
    - Edit yields ``new_string``.
    - MultiEdit joins every edit's ``new_string`` with newlines. Non-dict
      edits are skipped, and a ``None`` value contributes an empty line.

    Missing or ``None`` values, and other tool names, give "".
    """
    if tool_name == "Write":
        raw_content = tool_input.get("content", "")
        return "" if raw_content is None else str(raw_content)
    if tool_name == "Edit":
        new_string = tool_input.get("new_string", "")
        return "" if new_string is None else str(new_string)
    if tool_name == "MultiEdit":
        edits = tool_input.get("edits", [])
        if not isinstance(edits, list):
            return ""
        parts: list[str] = []
        for edit in cast("list[object]", edits):
            if not isinstance(edit, dict):
                continue
            new_str = cast("dict[str, object]", edit).get("new_string")
            parts.append("" if new_str is None else str(new_str))
        return "\n".join(parts)
    return ""


def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject:
    """Handle PreToolUse hook - analyze content before write/edit.

    Only Write/Edit/MultiEdit calls with non-empty content for a ``.py``
    path that is not skipped are analysed; everything else is allowed.
    Any ``typing.Any`` usage forces a deny regardless of enforcement mode.
    If the analysis itself fails, the call is allowed with a warning,
    unless ``Any`` usage was already found.
    """
    tool_name = str(hook_data.get("tool_name", ""))
    tool_input_raw = hook_data.get("tool_input", {})
    if not isinstance(tool_input_raw, dict):
        return _create_hook_response("PreToolUse", "allow")
    tool_input = cast("dict[str, object]", tool_input_raw)

    # Only analyze for write/edit tools
    if tool_name not in {"Write", "Edit", "MultiEdit"}:
        return _create_hook_response("PreToolUse", "allow")

    file_path = str(tool_input.get("file_path", ""))
    content = _extract_proposed_content(tool_name, tool_input)

    # Only analyze Python files
    if not file_path or not file_path.endswith(".py") or not content:
        return _create_hook_response("PreToolUse", "allow")
    # Skip analysis for configured patterns
    if should_skip_file(file_path, config):
        return _create_hook_response("PreToolUse", "allow")

    # Type checkers only run for Write; Edit/MultiEdit content is just the
    # replacement text, presumably not checkable as a standalone file.
    enable_type_checks = tool_name == "Write"
    any_usage_issues = _detect_any_usage(content)

    try:
        _has_issues, issues = _perform_quality_check(
            file_path,
            content,
            config,
            enable_type_checks=enable_type_checks,
        )
    except Exception as e:  # noqa: BLE001
        if any_usage_issues:
            return _handle_quality_issues(
                file_path,
                any_usage_issues,
                config,
                forced_permission="deny",
            )
        warning = f"Warning: Code quality check failed with error: {e}"
        return _create_hook_response("PreToolUse", "allow", warning, warning)

    all_issues = any_usage_issues + issues
    if not all_issues:
        return _create_hook_response("PreToolUse", "allow")
    return _handle_quality_issues(
        file_path,
        all_issues,
        config,
        forced_permission="deny" if any_usage_issues else None,
    )
def _extract_post_file_path(hook_data: JsonObject) -> str:
    """Best-effort lookup of the file a Write/Edit/MultiEdit call touched.

    Sources are tried in this order:
    1. The tool result (``tool_response``, else ``tool_output``). When it is
       a dict, its ``file_path``, ``filePath`` or ``path`` key is used. When
       it is a string, the first ``*.py``-looking token is used.
    2. ``tool_input["file_path"]``.

    Returns "" when no path is found.
    """
    tool_output = hook_data.get("tool_response", hook_data.get("tool_output", {}))
    if isinstance(tool_output, dict):
        output_dict = cast("dict[str, object]", tool_output)
        for key in ("file_path", "filePath", "path"):
            value = output_dict.get(key)
            if value:
                return str(value)
    elif isinstance(tool_output, str) and (
        match := re.search(r"([/\w\-_.]+\.py)", tool_output)
    ):
        return match[1]

    tool_input = hook_data.get("tool_input")
    if isinstance(tool_input, dict):
        candidate = cast("dict[str, object]", tool_input).get("file_path")
        if candidate:
            return str(candidate)
    return ""


def posttooluse_hook(
    hook_data: JsonObject,
    config: QualityConfig,
) -> JsonObject:
    """Handle PostToolUse hook - verify quality after write/edit.

    Only Write/Edit/MultiEdit calls on existing ``.py`` files are checked.
    Issues from the enabled checks (state comparison, cross-file duplicates,
    naming) are combined into one note returned with ``decision="block"``.
    Without issues, an approval message is returned only if
    ``config.show_success`` is set.
    """
    tool_name = str(hook_data.get("tool_name", ""))

    # Only process write/edit tools
    if tool_name not in {"Write", "Edit", "MultiEdit"}:
        return _create_hook_response("PostToolUse")

    file_path = _extract_post_file_path(hook_data)
    if not file_path or not file_path.endswith(".py"):
        return _create_hook_response("PostToolUse")
    if not Path(file_path).exists():
        return _create_hook_response("PostToolUse")

    issues: list[str] = []
    # Check state changes if tracking enabled
    if config.state_tracking_enabled:
        issues.extend(check_state_changes(file_path))
    # Run cross-file duplicate detection if enabled
    if config.cross_file_check_enabled:
        issues.extend(check_cross_file_duplicates(file_path, config))
    # Verify naming conventions if enabled
    if config.verify_naming:
        issues.extend(verify_naming_conventions(file_path))

    # Format response
    if issues:
        message = (
            f"📝 Post-write quality notes for {Path(file_path).name}:\n"
            + "\n".join(issues)
        )
        return _create_hook_response(
            "PostToolUse",
            "",
            message,
            message,
            message,
            decision="block",
        )
    if config.show_success:
        message = f"{Path(file_path).name} passed post-write verification"
        return _create_hook_response(
            "PostToolUse",
            "",
            "",
            message,
            "",
            decision="approve",
        )
    return _create_hook_response("PostToolUse")
def main() -> None:
    """Main hook entry point.

    Reads the hook payload (a JSON object) from stdin and dispatches it to
    the PreToolUse or PostToolUse handler. The JSON response is printed on
    stdout, and the exit code is chosen as follows:
    - 0: allow / no blocking decision.
    - 2: a ``deny`` or ``ask`` PreToolUse decision, with the reason on
      stderr so that it is fed back to Claude.
    - 1: an unexpected internal error (non-blocking).

    Unparseable stdin, or JSON that is not an object, yields an "allow"
    response.
    """
    try:
        # Load configuration
        config = QualityConfig.from_env()

        # Read hook input from stdin
        try:
            payload: object = json.load(sys.stdin)
        except json.JSONDecodeError:
            payload = None
        if not isinstance(payload, dict):
            fallback_response: JsonObject = {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "allow",
                },
            }
            sys.stdout.write(json.dumps(fallback_response))
            return
        hook_data = cast("JsonObject", payload)

        # Prefer the explicit event name. Otherwise fall back to the payload
        # shape: tool_response/tool_output => PostToolUse, else PreToolUse.
        event_name = hook_data.get("hook_event_name")
        if event_name in ("PreToolUse", "PostToolUse"):
            is_post_tool_use = event_name == "PostToolUse"
        else:
            is_post_tool_use = (
                "tool_response" in hook_data or "tool_output" in hook_data
            )
        response: JsonObject = (
            posttooluse_hook(hook_data, config)
            if is_post_tool_use
            else pretooluse_hook(hook_data, config)
        )

        print(json.dumps(response))  # noqa: T201

        # Handle exit codes based on hook output
        hook_output_raw = response.get("hookSpecificOutput", {})
        if not isinstance(hook_output_raw, dict):
            return
        hook_output = cast("dict[str, object]", hook_output_raw)
        permission_decision = hook_output.get("permissionDecision")
        if permission_decision == "deny":
            # Exit code 2: Blocking error - stderr fed back to Claude
            reason = str(
                hook_output.get("permissionDecisionReason", "Permission denied"),
            )
            _exit_with_reason(reason)
        elif permission_decision == "ask":
            # Also use exit code 2 for ask decisions to ensure Claude sees the message
            reason = str(
                hook_output.get("permissionDecisionReason", "Permission request"),
            )
            _exit_with_reason(reason)
        # Exit code 0: Success (default)
    except Exception as e:  # noqa: BLE001
        # Unexpected error - use exit code 1 (non-blocking error).
        # SystemExit from _exit_with_reason is not an Exception, so it passes through.
        sys.stderr.write(f"Hook error: {e}")
        sys.exit(1)
# Script entry point: run the hook only when this file is executed directly.
if __name__ == "__main__":
    main()