claude-scripts/hooks/code_quality_guard.py
Commit 6a164be2e3 by Travis Vasceannie (2025-10-21 04:59:02 +00:00): fix: switch to file-based locks for inter-process subprocess synchronization
- Replace threading locks with fcntl file-based locks for proper inter-process synchronization
- Hooks run as separate processes, so threading locks don't work across invocations
- Implement non-blocking lock acquisition to prevent hook deadlocks
- Use fcntl.flock on a shared lock file in /tmp/.claude_hooks/subprocess.lock
- Simplify lock usage with context manager pattern in both hooks
- Ensure graceful fallback if lock can't be acquired (e.g., due to concurrent hooks)

This properly fixes the API Error 400 concurrency issues by serializing subprocess
operations across all hook invocations, not just within a single process.

"""Unified quality hook for Claude Code supporting both PreToolUse and PostToolUse.
Prevents writing duplicate, complex, or non-modernized code and verifies quality
after writes.
"""
import ast
import fcntl
import hashlib
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import textwrap
import tokenize
from collections.abc import Callable, Iterator
from contextlib import contextmanager, suppress
from dataclasses import dataclass
from datetime import UTC, datetime
from importlib import import_module
from io import StringIO
from pathlib import Path
from tempfile import NamedTemporaryFile, gettempdir
from typing import TYPE_CHECKING, TypedDict, cast
# Import message enrichment helpers
try:
from .message_enrichment import EnhancedMessageFormatter
from .type_inference import TypeInferenceHelper
except ImportError:
# Fallback for direct execution
sys.path.insert(0, str(Path(__file__).parent))
from message_enrichment import EnhancedMessageFormatter
from type_inference import TypeInferenceHelper
# Import internal duplicate detector; fall back to local path when executed directly
if TYPE_CHECKING:
from .internal_duplicate_detector import (
Duplicate,
DuplicateResults,
detect_internal_duplicates,
)
else:
try:
from .internal_duplicate_detector import (
Duplicate,
DuplicateResults,
detect_internal_duplicates,
)
except ImportError:
sys.path.insert(0, str(Path(__file__).parent))
module = import_module("internal_duplicate_detector")
Duplicate = module.Duplicate
DuplicateResults = module.DuplicateResults
detect_internal_duplicates = module.detect_internal_duplicates
# File-based lock for inter-process synchronization
def _get_lock_file() -> Path:
"""Get path to lock file for subprocess serialization."""
lock_dir = Path(gettempdir()) / ".claude_hooks"
lock_dir.mkdir(exist_ok=True, mode=0o700)
return lock_dir / "subprocess.lock"
@contextmanager
def _subprocess_lock(timeout: float = 10.0) -> Iterator[bool]:
"""Context manager for file-based subprocess locking.
Args:
        timeout: Accepted for API compatibility; acquisition is non-blocking,
            so this value is currently unused.
    Yields:
        True if the lock was acquired, False if another process already holds it.
"""
lock_file = _get_lock_file()
# Open or create lock file
with open(lock_file, "a") as f:
try:
# Try to acquire exclusive lock (non-blocking)
fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
yield True
        except OSError:
            # Lock is held by another process; skip to avoid blocking
            yield False
        finally:
            try:
                fcntl.flock(f.fileno(), fcntl.LOCK_UN)
            except OSError:
                pass
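# Illustrative usage of the lock above (a sketch, not executed here): callers
# treat a failed acquisition as a soft skip so concurrent hooks never block
# one another:
#
#     with _subprocess_lock() as acquired:
#         if acquired:
#             subprocess.run(...)  # serialized across hook processes
#         else:
#             ...  # another hook holds the lock; degrade gracefully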
SUSPICIOUS_SUFFIXES: tuple[str, ...] = (
"enhanced",
"improved",
"better",
"new",
"updated",
"modified",
"refactored",
"optimized",
"fixed",
"clean",
"simple",
"advanced",
"basic",
"complete",
"final",
"latest",
"current",
"temp",
"temporary",
"backup",
"old",
"legacy",
"unified",
"merged",
"combined",
"integrated",
"consolidated",
"extended",
"enriched",
"augmented",
"upgraded",
"revised",
"polished",
"streamlined",
"simplified",
"modernized",
"normalized",
"sanitized",
"validated",
"verified",
"corrected",
"patched",
"stable",
"experimental",
"alpha",
"beta",
"draft",
"preliminary",
"prototype",
"working",
"test",
"debug",
"custom",
"special",
"generic",
"specific",
"general",
"detailed",
"minimal",
"full",
"partial",
"quick",
"fast",
"slow",
"smart",
"intelligent",
"auto",
"manual",
"secure",
"safe",
"robust",
"flexible",
"dynamic",
"static",
"reactive",
"async",
"sync",
"parallel",
"serial",
"distributed",
"centralized",
"decentralized",
)
FILE_SUFFIX_DUPLICATE_MSG = (
"⚠️ File '{current}' appears to be a suffixed duplicate of '{original}'. "
"Consider refactoring instead of creating variations with adjective "
"suffixes."
)
EXISTING_FILE_DUPLICATE_MSG = (
"⚠️ Creating '{current}' when '{existing}' already exists suggests "
"duplication. Consider consolidating or using a more descriptive name."
)
NAME_SUFFIX_DUPLICATE_MSG = (
"⚠️ {kind} '{name}' appears to be a suffixed duplicate of '{base}'. "
"Consider refactoring instead of creating variations."
)
def get_external_context(
rule_id: str,
content: str,
_file_path: str,
config: "QualityConfig",
) -> str:
"""Fetch additional guidance from optional integrations."""
context_parts: list[str] = []
if config.context7_enabled and config.context7_api_key:
try:
context7_context = _get_context7_analysis(
rule_id,
content,
config.context7_api_key,
)
except (OSError, RuntimeError, ValueError, TimeoutError) as exc:
logging.debug("Context7 API call failed: %s", exc)
else:
if context7_context:
context_parts.append(f"📊 Context7 Analysis: {context7_context}")
if config.firecrawl_enabled and config.firecrawl_api_key:
try:
firecrawl_examples = _get_firecrawl_examples(
rule_id,
config.firecrawl_api_key,
)
except (OSError, RuntimeError, ValueError, TimeoutError) as exc:
logging.debug("Firecrawl API call failed: %s", exc)
else:
if firecrawl_examples:
context_parts.append(f"🔗 Additional Examples: {firecrawl_examples}")
return "\n\n".join(context_parts)
def _get_context7_analysis(
rule_id: str,
_content: str,
_api_key: str,
) -> str:
"""Placeholder for Context7 API integration."""
context_map = {
"no-conditionals-in-tests": (
"Use pytest.mark.parametrize instead of branching inside tests."
),
"no-loop-in-tests": (
"Split loops into individual tests or parameterize the data for clarity."
),
"raise-specific-error": (
"Raise precise exception types (ValueError, custom errors, etc.) "
"so failures highlight the right behaviour."
),
"dont-import-test-modules": (
"Production code should not depend on test helpers; "
"extract shared logic into utility modules."
),
}
return context_map.get(
rule_id,
"Context7 guidance will appear once the integration is available.",
)
def _get_firecrawl_examples(rule_id: str, _api_key: str) -> str:
"""Placeholder for Firecrawl API integration."""
examples_map = {
"no-conditionals-in-tests": (
"Pytest parameterization tutorial: "
"docs.pytest.org/en/latest/how-to/parametrize.html"
),
"no-loop-in-tests": (
"Testing best practices: keep scenarios in separate "
"tests for precise failures."
),
"raise-specific-error": (
"Python docs on exceptions: docs.python.org/3/tutorial/errors.html"
),
"dont-import-test-modules": (
"Clean architecture tip: production modules should not import from tests."
),
}
return examples_map.get(
rule_id,
"Firecrawl examples will appear once the integration is available.",
)
def generate_test_quality_guidance(
rule_id: str,
content: str,
_file_path: str,
_config: "QualityConfig",
) -> str:
"""Return enriched guidance for test quality rule violations."""
function_name = "test_function"
match = re.search(r"def\s+(\w+)\s*\(", content)
if match:
function_name = match.group(1)
# Extract a small snippet of the violating code
code_snippet = ""
if match:
# Try to get a few lines around the function definition
lines = content.splitlines()
for i, line in enumerate(lines):
if f"def {function_name}" in line:
snippet_start = max(0, i)
snippet_end = min(len(lines), i + 10) # Show first 10 lines
code_snippet = "\n".join(lines[snippet_start:snippet_end])
break
# Use enhanced formatter for rich test quality messages
return EnhancedMessageFormatter.format_test_quality_message(
rule_id=rule_id,
function_name=function_name,
code_snippet=code_snippet,
include_examples=True,
)
class ToolConfig(TypedDict):
"""Configuration for a type checking tool."""
args: list[str]
error_check: Callable[[subprocess.CompletedProcess[str]], bool]
error_message: str | Callable[[subprocess.CompletedProcess[str]], str]
class ComplexitySummary(TypedDict):
"""Summary of complexity analysis."""
average_cyclomatic_complexity: float
class ComplexityDistribution(TypedDict):
"""Distribution of complexity levels."""
High: int
Very_High: int
Extreme: int
class ComplexityResults(TypedDict):
"""Results from complexity analysis."""
summary: ComplexitySummary
distribution: ComplexityDistribution
class TypeCheckingResults(TypedDict):
"""Results from type checking analysis."""
issues: list[str]
class AnalysisResults(TypedDict, total=False):
"""Complete analysis results from quality checks."""
internal_duplicates: DuplicateResults
complexity: ComplexityResults
type_checking: TypeCheckingResults
modernization: dict[str, object] # JSON structure varies
# Type aliases for JSON-like structures
JsonObject = dict[str, object]
JsonValue = str | int | float | bool | None | list[object] | JsonObject
@dataclass
class QualityConfig:
"""Configuration for quality checks."""
# Core settings
duplicate_threshold: float = 0.7
duplicate_enabled: bool = True
complexity_threshold: int = 10
complexity_enabled: bool = True
modernization_enabled: bool = True
require_type_hints: bool = True
enforcement_mode: str = "strict" # strict/warn/permissive
# Type checking tools
sourcery_enabled: bool = True
basedpyright_enabled: bool = True
pyrefly_enabled: bool = True
type_check_exit_code: int = 2
# PostToolUse features
state_tracking_enabled: bool = False
cross_file_check_enabled: bool = False
verify_naming: bool = True
show_success: bool = False
# Test quality checks
test_quality_enabled: bool = True
# External context providers
context7_enabled: bool = False
context7_api_key: str = ""
firecrawl_enabled: bool = False
firecrawl_api_key: str = ""
# File patterns
skip_patterns: list[str] | None = None
def __post_init__(self) -> None:
if self.skip_patterns is None:
self.skip_patterns = ["test_", "_test.py", "/tests/", "/fixtures/"]
@classmethod
def from_env(cls) -> "QualityConfig":
"""Load config from environment variables."""
return cls(
duplicate_threshold=float(os.getenv("QUALITY_DUP_THRESHOLD", "0.7")),
duplicate_enabled=os.getenv("QUALITY_DUP_ENABLED", "true").lower()
== "true",
complexity_threshold=int(os.getenv("QUALITY_COMPLEXITY_THRESHOLD", "10")),
complexity_enabled=os.getenv("QUALITY_COMPLEXITY_ENABLED", "true").lower()
== "true",
modernization_enabled=os.getenv("QUALITY_MODERN_ENABLED", "true").lower()
== "true",
require_type_hints=os.getenv("QUALITY_REQUIRE_TYPES", "true").lower()
== "true",
enforcement_mode=os.getenv("QUALITY_ENFORCEMENT", "strict"),
state_tracking_enabled=os.getenv("QUALITY_STATE_TRACKING", "false").lower()
== "true",
cross_file_check_enabled=os.getenv(
"QUALITY_CROSS_FILE_CHECK",
"false",
).lower()
== "true",
verify_naming=os.getenv("QUALITY_VERIFY_NAMING", "true").lower() == "true",
show_success=os.getenv("QUALITY_SHOW_SUCCESS", "false").lower() == "true",
sourcery_enabled=os.getenv("QUALITY_SOURCERY_ENABLED", "true").lower()
== "true",
basedpyright_enabled=os.getenv(
"QUALITY_BASEDPYRIGHT_ENABLED",
"true",
).lower()
== "true",
pyrefly_enabled=os.getenv("QUALITY_PYREFLY_ENABLED", "true").lower()
== "true",
type_check_exit_code=int(os.getenv("QUALITY_TYPE_CHECK_EXIT_CODE", "2")),
test_quality_enabled=(
os.getenv("QUALITY_TEST_QUALITY_ENABLED", "true").lower() == "true"
),
context7_enabled=(
os.getenv("QUALITY_CONTEXT7_ENABLED", "false").lower() == "true"
),
context7_api_key=os.getenv("QUALITY_CONTEXT7_API_KEY", ""),
firecrawl_enabled=(
os.getenv("QUALITY_FIRECRAWL_ENABLED", "false").lower() == "true"
),
firecrawl_api_key=os.getenv("QUALITY_FIRECRAWL_API_KEY", ""),
)
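# Illustrative shell setup for a permissive run (the variable names mirror the
# os.getenv keys above; the values are example choices, not defaults):
#
#     QUALITY_ENFORCEMENT=permissive \
#     QUALITY_DUP_THRESHOLD=0.85 \
#     QUALITY_COMPLEXITY_THRESHOLD=15 \
#     python code_quality_guard.py < hook_input.json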
def should_skip_file(file_path: str, config: QualityConfig) -> bool:
"""Check if file should be skipped based on patterns."""
if config.skip_patterns is None:
return False
return any(pattern in file_path for pattern in config.skip_patterns)
def _module_candidate(path: Path) -> tuple[Path, list[str]]:
"""Build a module invocation candidate for a python executable."""
return path, [str(path), "-m", "quality.cli.main"]
def _cli_candidate(path: Path) -> tuple[Path, list[str]]:
"""Build a direct CLI invocation candidate."""
return path, [str(path)]
def get_claude_quality_command(repo_root: Path | None = None) -> list[str]:
"""Return a path-resilient command for invoking claude-quality."""
repo_root = repo_root or Path(__file__).resolve().parent.parent
platform_name = sys.platform
is_windows = platform_name.startswith("win")
scripts_dir = repo_root / ".venv" / ("Scripts" if is_windows else "bin")
python_names = (
["python.exe", "python3.exe"] if is_windows else ["python", "python3"]
)
cli_names = (
["claude-quality.exe", "claude-quality"] if is_windows else ["claude-quality"]
)
candidates: list[tuple[Path, list[str]]] = []
for name in python_names:
candidates.append(_module_candidate(scripts_dir / name))
for name in cli_names:
candidates.append(_cli_candidate(scripts_dir / name))
for candidate_path, command in candidates:
if candidate_path.exists():
return command
interpreter_fallbacks = ["python"] if is_windows else ["python3", "python"]
for interpreter in interpreter_fallbacks:
if shutil.which(interpreter):
return [interpreter, "-m", "quality.cli.main"]
if shutil.which("claude-quality"):
return ["claude-quality"]
message = (
"'claude-quality' was not found on PATH. Please ensure it is installed and "
"available."
)
raise RuntimeError(message)
def _get_project_venv_bin(file_path: str | None = None) -> Path:
"""Get the virtual environment bin directory for the current project.
Args:
file_path: Optional file path to determine project root from.
If not provided, uses current working directory.
"""
# Start from the file's directory if provided, otherwise from cwd
temp_root = Path(gettempdir()).resolve()
if file_path:
resolved_path = Path(file_path).resolve()
try:
if resolved_path.is_relative_to(temp_root):
start_path = Path.cwd()
else:
start_path = resolved_path.parent
except ValueError:
start_path = resolved_path.parent
else:
start_path = Path.cwd()
current = start_path
while current != current.parent:
venv_candidate = current / ".venv"
if venv_candidate.exists() and venv_candidate.is_dir():
bin_dir = venv_candidate / "bin"
if bin_dir.exists():
return bin_dir
current = current.parent
# Fallback to claude-scripts venv if no project venv found
return Path(__file__).parent.parent / ".venv" / "bin"
def _format_basedpyright_errors(json_output: str) -> str:
"""Format basedpyright JSON output into readable error messages."""
try:
data = json.loads(json_output)
diagnostics = data.get("generalDiagnostics", [])
if not diagnostics:
return "Type errors found (no details available)"
# Group by severity and format
errors: list[str] = []
for diag in diagnostics[:10]: # Limit to first 10 errors
severity = diag.get("severity", "error").upper()
message = diag.get("message", "Unknown error")
rule = diag.get("rule", "")
range_info = diag.get("range", {})
start = range_info.get("start", {})
line = start.get("line", 0) + 1 # Convert 0-indexed to 1-indexed
rule_text = f" [{rule}]" if rule else ""
errors.append(f" [{severity}] Line {line}: {message}{rule_text}")
count = len(diagnostics)
summary = f"Found {count} type error{'s' if count != 1 else ''}"
if count > 10:
summary += " (showing first 10)"
return f"{summary}:\n" + "\n".join(errors)
except (json.JSONDecodeError, KeyError, TypeError):
return "Type errors found (failed to parse details)"
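# Illustrative transformation performed above (the diagnostic values are made
# up, but the JSON shape matches what the parser expects):
#
#     {"generalDiagnostics": [{"severity": "error", "message": "x is not int",
#      "rule": "reportAssignmentType", "range": {"start": {"line": 4}}}]}
#
# is rendered as:
#
#     Found 1 type error:
#       [ERROR] Line 5: x is not int [reportAssignmentType]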
def _format_pyrefly_errors(output: str) -> str:
"""Format pyrefly output into readable error messages."""
if not output or not output.strip():
return "Type errors found (no details available)"
# Pyrefly already has pretty good formatting, but let's clean it up
lines = output.strip().split("\n")
# Count ERROR lines to provide summary
error_count = sum(1 for line in lines if line.strip().startswith("ERROR"))
if error_count == 0:
return output.strip()
summary = f"Found {error_count} type error{'s' if error_count != 1 else ''}"
return f"{summary}:\n{output.strip()}"
def _format_sourcery_errors(output: str) -> str:
"""Format sourcery output into readable error messages."""
if not output or not output.strip():
return "Code quality issues found (no details available)"
# Extract issue count if present
lines = output.strip().split("\n")
# Sourcery typically outputs: "✖ X issues detected"
issue_count = 0
for line in lines:
if "issue" in line.lower() and "detected" in line.lower():
            # Try to extract the number (re is imported at module level)
match = re.search(r"(\d+)\s+issue", line)
if match:
issue_count = int(match.group(1))
break
# Format the output, removing redundant summary lines
formatted_lines: list[str] = []
for line in lines:
# Skip the summary line as we'll add our own
if "issue" in line.lower() and "detected" in line.lower():
continue
# Skip empty lines at start/end
if line.strip():
formatted_lines.append(line)
if issue_count > 0:
plural = "issue" if issue_count == 1 else "issues"
summary = f"Found {issue_count} code quality {plural}"
return f"{summary}:\n" + "\n".join(formatted_lines)
return output.strip()
def _ensure_tool_installed(tool_name: str) -> bool:
"""Ensure a type checking tool is installed in the virtual environment."""
venv_bin = _get_project_venv_bin()
tool_path = venv_bin / tool_name
if tool_path.exists():
return True
# Try to install using uv if available
try:
result = subprocess.run( # noqa: S603
[str(venv_bin / "uv"), "pip", "install", tool_name],
check=False,
capture_output=True,
text=True,
timeout=60,
)
except (subprocess.TimeoutExpired, OSError):
return False
else:
return result.returncode == 0
def _run_type_checker(
tool_name: str,
file_path: str,
_config: QualityConfig,
*,
original_file_path: str | None = None,
) -> tuple[bool, str]:
"""Run a type checking tool and return success status and output."""
venv_bin = _get_project_venv_bin(original_file_path or file_path)
tool_path = venv_bin / tool_name
if not tool_path.exists() and not _ensure_tool_installed(tool_name):
return True, f"Warning: {tool_name} not available"
# Tool configuration mapping
tool_configs: dict[str, ToolConfig] = {
"basedpyright": ToolConfig(
args=["--outputjson", file_path],
error_check=lambda result: result.returncode == 1,
error_message=lambda result: _format_basedpyright_errors(result.stdout),
),
"pyrefly": ToolConfig(
args=["check", file_path],
error_check=lambda result: result.returncode == 1,
error_message=lambda result: _format_pyrefly_errors(result.stdout),
),
"sourcery": ToolConfig(
args=["review", file_path],
error_check=lambda result: (
"issues detected" in str(result.stdout)
and "0 issues detected" not in str(result.stdout)
),
error_message=lambda result: _format_sourcery_errors(result.stdout),
),
}
tool_config = tool_configs.get(tool_name)
if not tool_config:
return True, f"Warning: Unknown tool {tool_name}"
# Acquire file-based lock to prevent subprocess concurrency issues
with _subprocess_lock(timeout=10.0) as acquired:
if not acquired:
            return True, f"Warning: {tool_name} skipped (lock held by another hook)"
try:
cmd = [str(tool_path)] + tool_config["args"]
# Activate virtual environment for the subprocess
env = os.environ.copy()
env["VIRTUAL_ENV"] = str(venv_bin.parent)
env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
# Remove any PYTHONHOME that might interfere
env.pop("PYTHONHOME", None)
# Add PYTHONPATH=src if src directory exists in project root
# This allows type checkers to resolve imports from src/
project_root = venv_bin.parent.parent # Go from .venv/bin to project root
src_dir = project_root / "src"
if src_dir.exists() and src_dir.is_dir():
existing_pythonpath = env.get("PYTHONPATH", "")
if existing_pythonpath:
env["PYTHONPATH"] = f"{src_dir}:{existing_pythonpath}"
else:
env["PYTHONPATH"] = str(src_dir)
# Run type checker from project root to pick up project configuration files
result = subprocess.run( # noqa: S603
cmd,
check=False,
capture_output=True,
text=True,
timeout=30,
env=env,
cwd=str(project_root),
)
# Check for tool-specific errors
error_check = tool_config["error_check"]
if error_check(result):
error_msg = tool_config["error_message"]
message = str(error_msg(result) if callable(error_msg) else error_msg)
return False, message
# Return success or warning
if result.returncode == 0:
return True, ""
return True, f"Warning: {tool_name} error (exit {result.returncode})"
except subprocess.TimeoutExpired:
return True, f"Warning: {tool_name} timeout"
except OSError:
return True, f"Warning: {tool_name} execution error"
def _initialize_analysis() -> tuple[AnalysisResults, list[str]]:
"""Initialize analysis with empty results and claude-quality command."""
results: AnalysisResults = {}
claude_quality_cmd: list[str] = get_claude_quality_command()
return results, claude_quality_cmd
def run_type_checks(
file_path: str,
config: QualityConfig,
*,
original_file_path: str | None = None,
) -> list[str]:
"""Run all enabled type checking tools and return any issues."""
issues: list[str] = []
# Run Sourcery
if config.sourcery_enabled:
success, output = _run_type_checker(
"sourcery",
file_path,
config,
original_file_path=original_file_path,
)
if not success and output:
issues.append(f"Sourcery: {output.strip()}")
# Run BasedPyright
if config.basedpyright_enabled:
success, output = _run_type_checker(
"basedpyright",
file_path,
config,
original_file_path=original_file_path,
)
if not success and output:
issues.append(f"BasedPyright: {output.strip()}")
# Run Pyrefly
if config.pyrefly_enabled:
success, output = _run_type_checker(
"pyrefly",
file_path,
config,
original_file_path=original_file_path,
)
if not success and output:
issues.append(f"Pyrefly: {output.strip()}")
return issues
def _run_quality_analyses(
content: str,
tmp_path: str,
config: QualityConfig,
enable_type_checks: bool,
*,
original_file_path: str | None = None,
) -> AnalysisResults:
"""Run all quality analysis checks and return results."""
results, claude_quality_cmd = _initialize_analysis()
# First check for internal duplicates within the file
if config.duplicate_enabled:
internal_duplicates = detect_internal_duplicates(
content,
threshold=config.duplicate_threshold,
min_lines=4,
)
if internal_duplicates.get("duplicates"):
results["internal_duplicates"] = internal_duplicates
# Run complexity analysis
if config.complexity_enabled:
cmd = [
*claude_quality_cmd,
"complexity",
tmp_path,
"--threshold",
str(config.complexity_threshold),
"--format",
"json",
]
# Prepare virtual environment for subprocess
venv_bin = _get_project_venv_bin(original_file_path)
env = os.environ.copy()
env["VIRTUAL_ENV"] = str(venv_bin.parent)
env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
env.pop("PYTHONHOME", None)
# Acquire lock to prevent subprocess concurrency issues
with _subprocess_lock(timeout=10.0) as acquired:
if acquired:
with suppress(subprocess.TimeoutExpired):
result = subprocess.run( # noqa: S603
cmd,
check=False,
capture_output=True,
text=True,
timeout=30,
env=env,
)
if result.returncode == 0:
with suppress(json.JSONDecodeError):
results["complexity"] = json.loads(result.stdout)
# Run type checking if any tool is enabled
if enable_type_checks and any(
[
config.sourcery_enabled,
config.basedpyright_enabled,
config.pyrefly_enabled,
],
):
try:
if type_issues := run_type_checks(
tmp_path,
config,
original_file_path=original_file_path,
):
results["type_checking"] = {"issues": type_issues}
except Exception as e: # noqa: BLE001
logging.debug("Type checking failed: %s", e)
# Run modernization analysis
if config.modernization_enabled:
cmd = [
*claude_quality_cmd,
"modernization",
tmp_path,
"--include-type-hints" if config.require_type_hints else "",
"--format",
"json",
]
cmd = [c for c in cmd if c] # Remove empty strings
# Prepare virtual environment for subprocess
venv_bin = _get_project_venv_bin(original_file_path)
env = os.environ.copy()
env["VIRTUAL_ENV"] = str(venv_bin.parent)
env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
env.pop("PYTHONHOME", None)
# Acquire lock to prevent subprocess concurrency issues
with _subprocess_lock(timeout=10.0) as acquired:
if acquired:
with suppress(subprocess.TimeoutExpired):
result = subprocess.run( # noqa: S603
cmd,
check=False,
capture_output=True,
text=True,
timeout=30,
env=env,
)
if result.returncode == 0:
with suppress(json.JSONDecodeError):
results["modernization"] = json.loads(result.stdout)
return results
def _find_project_root(file_path: str) -> Path:
"""Find project root by looking for common markers."""
file_path_obj = Path(file_path).resolve()
current = file_path_obj.parent
# Look for common project markers
while current != current.parent:
if any(
(current / marker).exists()
for marker in [
".git",
"pyrightconfig.json",
"pyproject.toml",
".venv",
"setup.py",
]
):
return current
current = current.parent
# Fallback to parent directory
return file_path_obj.parent
def _get_project_tmp_dir(file_path: str) -> Path:
"""Get or create .tmp directory in project root."""
project_root = _find_project_root(file_path)
tmp_dir = project_root / ".tmp"
tmp_dir.mkdir(exist_ok=True)
# Ensure .tmp is gitignored
gitignore = project_root / ".gitignore"
if gitignore.exists():
content = gitignore.read_text()
if ".tmp/" not in content and ".tmp" not in content:
# Add .tmp/ to .gitignore
with gitignore.open("a") as f:
f.write("\n# Temporary files created by code quality hooks\n.tmp/\n")
return tmp_dir
def analyze_code_quality(
content: str,
file_path: str,
config: QualityConfig,
*,
enable_type_checks: bool = True,
) -> AnalysisResults:
"""Analyze code content using claude-quality toolkit."""
suffix = Path(file_path).suffix or ".py"
# Create temp file in project directory, not /tmp, so it inherits config files
# like pyrightconfig.json, pyproject.toml, etc.
tmp_dir = _get_project_tmp_dir(file_path)
# Create temp file in project's .tmp directory
with NamedTemporaryFile(
mode="w",
suffix=suffix,
delete=False,
dir=str(tmp_dir),
prefix="hook_validation_",
) as tmp:
tmp.write(content)
tmp_path = tmp.name
try:
return _run_quality_analyses(
content,
tmp_path,
config,
enable_type_checks,
original_file_path=file_path,
)
finally:
Path(tmp_path).unlink(missing_ok=True)
def _check_internal_duplicates(
results: AnalysisResults,
source_code: str = "",
file_path: str = "",
) -> list[str]:
"""Check for internal duplicate code within the same file."""
issues: list[str] = []
if "internal_duplicates" not in results:
return issues
duplicates: list[Duplicate] = results["internal_duplicates"].get(
"duplicates",
[],
)
for dup in duplicates[:3]: # Show first 3
# Use enhanced formatter for rich duplicate messages
duplicate_type = dup.get("type", "unknown")
similarity = dup.get("similarity", 0.0)
locations_raw = dup.get("locations", [])
# Cast to list of dicts for the formatter
locations_dicts: list[dict[str, str]] = [
{
"name": str(loc.get("name", "unknown")),
"type": str(loc.get("type", "code")),
"lines": str(loc.get("lines", "?")),
}
for loc in locations_raw
]
enriched_message = EnhancedMessageFormatter.format_duplicate_message(
duplicate_type=str(duplicate_type),
similarity=float(similarity),
locations=locations_dicts,
source_code=source_code,
include_refactoring=True,
file_path=file_path,
)
issues.append(enriched_message)
return issues
def _check_complexity_issues(
results: AnalysisResults,
config: QualityConfig,
) -> list[str]:
"""Check for code complexity issues."""
issues: list[str] = []
if "complexity" not in results:
return issues
complexity_data = results["complexity"]
summary = complexity_data.get("summary", {})
avg_cc = summary.get("average_cyclomatic_complexity", 0.0)
distribution = complexity_data.get("distribution", {})
# Only block on Very High (21-50) and Extreme (50+), not High (11-20)
# Radon's "High" category (CC 11-20) is acceptable for moderately complex functions
critical_count: int = (
distribution.get("Very High", 0)
+ distribution.get("Extreme", 0)
)
high_count: int = distribution.get("High", 0)
# Block only if average is too high OR if any functions are critically complex (CC > 20)
# This allows individual functions with CC 11-20 (Radon's "High" category)
if avg_cc > config.complexity_threshold or critical_count > 0:
# Use enhanced formatter for rich complexity messages
enriched_message = EnhancedMessageFormatter.format_complexity_message(
avg_complexity=avg_cc,
threshold=config.complexity_threshold,
high_count=critical_count + high_count, # Show total for context
)
issues.append(enriched_message)
# Note: Functions with CC 11-20 are allowed through without blocking
# Users can check these with: claude-quality complexity <file>
return issues
def _check_modernization_issues(
results: AnalysisResults,
config: QualityConfig,
) -> list[str]:
"""Check for code modernization issues."""
issues: list[str] = []
if "modernization" not in results:
return issues
try:
modernization_data = results["modernization"]
files = modernization_data.get("files", {})
if not isinstance(files, dict):
return issues
except (AttributeError, TypeError):
return issues
    total_issues: int = 0
    issue_types: set[str] = set()
    non_type_issues: int = 0
    files_values = cast("dict[str, object]", files).values()
    for file_issues_raw in files_values:
        if not isinstance(file_issues_raw, list):
            continue
        # Cast to proper type after runtime check
        file_issues_list = cast("list[object]", file_issues_raw)
        total_issues += len(file_issues_list)
        for issue_raw in file_issues_list:
            if not isinstance(issue_raw, dict):
                continue
            # Cast to proper type after runtime check
            issue_dict = cast("dict[str, object]", issue_raw)
            issue_type_raw = issue_dict.get("issue_type", "unknown")
            if isinstance(issue_type_raw, str):
                issue_types.add(issue_type_raw)
                # Count individual non-type issues so the totals below are
                # issue counts, not counts of distinct issue types
                if "type" not in issue_type_raw and "typing" not in issue_type_raw:
                    non_type_issues += 1
    # Only flag if there are non-type-hint issues or many type hint issues
    type_issues: int = total_issues - non_type_issues
if non_type_issues > 0:
non_type_list: list[str] = [
t for t in issue_types if "type" not in t and "typing" not in t
]
issues.append(
f"Modernization needed: {non_type_issues} non-type issues "
f"({', '.join(non_type_list)})",
)
elif config.require_type_hints and type_issues > 10:
issues.append(
f"Many missing type hints: {type_issues} functions/parameters "
"lacking annotations",
)
return issues
def _check_type_checking_issues(
results: AnalysisResults,
source_code: str = "",
) -> list[str]:
"""Check for type checking issues from Sourcery, BasedPyright, and Pyrefly."""
issues: list[str] = []
if "type_checking" not in results:
return issues
with suppress(AttributeError, TypeError):
type_checking_data = results["type_checking"]
type_issues = type_checking_data.get("issues", [])
# Group by tool and format with enhanced messages
for issue_str in type_issues[:3]: # Limit to first 3 for brevity
issue_text = str(issue_str)
# Extract tool name (usually starts with "Tool:")
tool_name = "Type Checker"
if ":" in issue_text:
potential_tool = issue_text.split(":")[0].strip()
if potential_tool in ("Sourcery", "BasedPyright", "Pyrefly"):
tool_name = potential_tool
issue_text = ":".join(issue_text.split(":")[1:]).strip()
enriched_message = EnhancedMessageFormatter.format_type_error_message(
tool_name=tool_name,
error_output=issue_text,
source_code=source_code,
)
issues.append(enriched_message)
return issues
def check_code_issues(
results: AnalysisResults,
config: QualityConfig,
source_code: str = "",
file_path: str = "",
) -> tuple[bool, list[str]]:
"""Check analysis results for issues that should block the operation."""
issues: list[str] = []
issues.extend(_check_internal_duplicates(results, source_code, file_path))
issues.extend(_check_complexity_issues(results, config))
issues.extend(_check_modernization_issues(results, config))
issues.extend(_check_type_checking_issues(results, source_code))
return len(issues) > 0, issues
def store_pre_state(file_path: str, content: str) -> None:
"""Store file state before modification for later comparison."""
    cache_dir = Path(gettempdir()) / ".quality_state"
cache_dir.mkdir(exist_ok=True, mode=0o700)
state: dict[str, str | int] = {
"file_path": file_path,
"timestamp": datetime.now(UTC).isoformat(),
"content_hash": hashlib.sha256(content.encode()).hexdigest(),
"lines": len(content.split("\n")),
"functions": content.count("def "),
"classes": content.count("class "),
}
cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]
cache_file: Path = cache_dir / f"{cache_key}_pre.json"
cache_file.write_text(json.dumps(state, indent=2))
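# Illustrative cache entry written above (hypothetical path and values, hash
# abbreviated):
#
#     /tmp/.quality_state/1a2b3c4d_pre.json
#     {"file_path": "src/app.py", "timestamp": "2025-10-21T04:59:02+00:00",
#      "content_hash": "9f86d0...", "lines": 120, "functions": 6, "classes": 1}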
def check_state_changes(file_path: str) -> list[str]:
"""Check for quality changes between pre and post states."""
    issues: list[str] = []
    cache_dir: Path = Path(gettempdir()) / ".quality_state"
cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]
pre_file: Path = cache_dir / f"{cache_key}_pre.json"
if not pre_file.exists():
return issues
try:
pre_state = json.loads(pre_file.read_text())
try:
current_content = Path(file_path).read_text()
except OSError:
return issues # Can't compare if can't read file
current_lines = len(current_content.split("\n"))
current_functions = current_content.count("def ")
# Check for significant changes
if current_functions < pre_state.get("functions", 0):
issues.append(
            f"⚠️ Reduced functions: {pre_state['functions']} → {current_functions}",
)
if current_lines > pre_state.get("lines", 0) * 1.5:
issues.append(
"⚠️ File size increased significantly: "
            f"{pre_state['lines']} → {current_lines} lines",
)
except Exception: # noqa: BLE001
logging.debug("Could not analyze state changes for %s", file_path)
return issues
def check_cross_file_duplicates(file_path: str, config: QualityConfig) -> list[str]:
"""Check for duplicates across project files."""
issues: list[str] = []
# Get project root
project_root: Path = Path(file_path).parent
while (
project_root.parent != project_root
and not (project_root / ".git").exists()
and not (project_root / "pyproject.toml").exists()
):
project_root = project_root.parent
try:
claude_quality_cmd = get_claude_quality_command()
# Prepare virtual environment for subprocess
venv_bin = _get_project_venv_bin()
env = os.environ.copy()
env["VIRTUAL_ENV"] = str(venv_bin.parent)
env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
env.pop("PYTHONHOME", None)
result = subprocess.run( # noqa: S603
[
*claude_quality_cmd,
"duplicates",
str(project_root),
"--threshold",
str(config.duplicate_threshold),
"--format",
"json",
],
check=False,
capture_output=True,
text=True,
timeout=60,
env=env,
)
if result.returncode == 0:
data = json.loads(result.stdout)
duplicates = data.get("duplicates", [])
if any(file_path in str(d) for d in duplicates):
issues.append("⚠️ Cross-file duplication detected")
except Exception: # noqa: BLE001
logging.debug("Could not check cross-file duplicates for %s", file_path)
return issues
def verify_naming_conventions(file_path: str) -> list[str]:
"""Verify PEP8 naming conventions."""
issues: list[str] = []
try:
content = Path(file_path).read_text()
except OSError:
return issues # Can't check naming if can't read file
# Check function names (should be snake_case)
if bad_funcs := re.findall(
r"def\s+([A-Z][a-zA-Z0-9_]*|[a-z]+[A-Z][a-zA-Z0-9_]*)\s*\(",
content,
):
issues.append(f"⚠️ Non-PEP8 function names: {', '.join(bad_funcs[:3])}")
# Check class names (should be PascalCase)
if bad_classes := re.findall(r"class\s+([a-z][a-z0-9_]*)\s*[\(:]", content):
issues.append(f"⚠️ Non-PEP8 class names: {', '.join(bad_classes[:3])}")
return issues
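# Illustrative matches for the regexes above (hypothetical identifiers):
#   "def FetchData(" and "def getData(" -> flagged as non-PEP8 function names
#   "class http_client:" -> flagged as a non-PEP8 class name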
def _detect_any_usage(content: str) -> list[str]:
"""Detect forbidden typing.Any usage and suggest specific types."""
# Use type inference helper to find Any usage with context
any_usages = TypeInferenceHelper.find_any_usage_with_context(content)
# Fallback to line-by-line check for syntax errors
if not any_usages:
lines_with_any: set[int] = set()
for index, line in enumerate(content.splitlines(), start=1):
code_portion = line.split("#", 1)[0]
if re.search(r"\bAny\b", code_portion):
lines_with_any.add(index)
if lines_with_any:
sorted_lines = sorted(lines_with_any)
display_lines = ", ".join(str(num) for num in sorted_lines[:5])
if len(sorted_lines) > 5:
display_lines += ", …"
return [
f"⚠️ Forbidden typing.Any usage at line(s) {display_lines}; "
"replace with specific types",
]
if not any_usages:
return []
issues: list[str] = []
# Group by context type
by_context: dict[str, list[dict[str, str | int]]] = {}
for usage in any_usages:
context = str(usage.get("context", "unknown"))
if context not in by_context:
by_context[context] = []
by_context[context].append(usage)
# Format enriched messages for each context type
for context, usages_list in by_context.items():
lines = [str(u.get("line", "?")) for u in usages_list[:5]]
line_summary = ", ".join(lines)
if len(usages_list) > 5:
line_summary += ", ..."
# Get suggestions
suggestions: list[str] = []
for usage in usages_list[:3]: # Show first 3 suggestions
element = str(usage.get("element", ""))
suggested = str(usage.get("suggested", ""))
if suggested and suggested not in {"Any", "Infer from usage"}:
suggestions.append(f"{element}: {suggested}")
parts: list[str] = [
f"⚠️ Forbidden typing.Any Usage ({context})",
f"📍 Lines: {line_summary}",
]
if suggestions:
parts.append("💡 Suggested Types:")
parts.extend(suggestions)
else:
parts.append("💡 Tip: Replace `Any` with specific types based on usage")
parts.extend(
[
"",
"🔗 Common Replacements:",
" • dict[str, Any] → dict[str, int] (or appropriate value type)",
" • list[Any] → list[str] (or appropriate element type)",
(
" • Callable[..., Any] → Callable[[int, str], bool] "
"(with specific signature)"
),
]
)
issues.append("\n".join(parts))
return issues
def _detect_type_ignore_usage(content: str) -> list[str]:
"""Detect forbidden # type: ignore usage in proposed content."""
pattern = re.compile(r"#\s*type:\s*ignore(?:\b|\[)", re.IGNORECASE)
lines_with_type_ignore: set[int] = set()
try:
# Dedent the content to handle code fragments with leading indentation
dedented_content = textwrap.dedent(content)
for token_type, token_string, start, _, _ in tokenize.generate_tokens(
StringIO(dedented_content).readline,
):
if token_type == tokenize.COMMENT and pattern.search(token_string):
lines_with_type_ignore.add(start[0])
except (tokenize.TokenError, IndentationError):
for index, line in enumerate(content.splitlines(), start=1):
if pattern.search(line):
lines_with_type_ignore.add(index)
if not lines_with_type_ignore:
return []
sorted_lines = sorted(lines_with_type_ignore)
display_lines = ", ".join(str(num) for num in sorted_lines[:5])
if len(sorted_lines) > 5:
display_lines += ", …"
return [
"⚠️ Forbidden # type: ignore usage at line(s) "
f"{display_lines}; remove the suppression and fix typing issues instead",
]
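# The pattern above catches both bare and scoped suppressions; for example,
# these hypothetical lines would each be flagged:
#     value = parse(raw)  # type: ignore
#     value = parse(raw)  # type: ignore[assignment]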
def _detect_old_typing_patterns(content: str) -> list[str]:
"""Detect old typing patterns that should use modern syntax."""
# Old typing imports that should be replaced
old_patterns = {
r"\bfrom typing import.*\bUnion\b": (
"Use | syntax instead of Union (e.g., str | int)"
),
r"\bfrom typing import.*\bOptional\b": (
"Use | None syntax instead of Optional (e.g., str | None)"
),
r"\bfrom typing import.*\bList\b": "Use list[T] instead of List[T]",
r"\bfrom typing import.*\bDict\b": "Use dict[K, V] instead of Dict[K, V]",
r"\bfrom typing import.*\bSet\b": "Use set[T] instead of Set[T]",
r"\bfrom typing import.*\bTuple\b": (
"Use tuple[T, ...] instead of Tuple[T, ...]"
),
r"\bUnion\s*\[": "Use | syntax instead of Union (e.g., str | int)",
r"\bOptional\s*\[": "Use | None syntax instead of Optional (e.g., str | None)",
r"\bList\s*\[": "Use list[T] instead of List[T]",
r"\bDict\s*\[": "Use dict[K, V] instead of Dict[K, V]",
r"\bSet\s*\[": "Use set[T] instead of Set[T]",
r"\bTuple\s*\[": "Use tuple[T, ...] instead of Tuple[T, ...]",
}
lines = content.splitlines()
found_issues = []
for pattern, message in old_patterns.items():
lines_with_pattern = []
for i, line in enumerate(lines, 1):
# Skip comments
code_part = line.split("#")[0]
if re.search(pattern, code_part):
lines_with_pattern.append(i)
if lines_with_pattern:
display_lines: str = ", ".join(str(num) for num in lines_with_pattern[:5])
if len(lines_with_pattern) > 5:
display_lines += ", …"
issue_text: str = f"⚠️ Old typing pattern at line(s) {display_lines}: {message}"
found_issues.append(issue_text)
return found_issues
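# Illustrative rewrite implied by the messages above (assumed signature):
#   Old: def load(path: Optional[str]) -> Dict[str, List[int]]: ...
#   New: def load(path: str | None) -> dict[str, list[int]]: ...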
def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
"""Detect files and functions/classes with suspicious adjective/adverb suffixes."""
issues: list[str] = []
# Check file name against other files in the same directory
file_path_obj = Path(file_path)
if file_path_obj.parent.exists():
file_stem = file_path_obj.stem
file_suffix = file_path_obj.suffix
# Check if current file has suspicious suffix
for suffix in SUSPICIOUS_SUFFIXES:
for separator in ("_", "-"):
suffix_token = f"{separator}{suffix}"
if not file_stem.endswith(suffix_token):
continue
base_name = file_stem[: -len(suffix_token)]
potential_original = file_path_obj.parent / f"{base_name}{file_suffix}"
if potential_original.exists() and potential_original != file_path_obj:
message = FILE_SUFFIX_DUPLICATE_MSG.format(
current=file_path_obj.name,
original=potential_original.name,
)
issues.append(message)
break
else:
continue
break
# Check if any existing files are suffixed versions of current file
for existing_file in file_path_obj.parent.glob(f"{file_stem}_*{file_suffix}"):
if existing_file != file_path_obj:
existing_stem = existing_file.stem
if existing_stem.startswith(f"{file_stem}_"):
potential_suffix = existing_stem[len(file_stem) + 1 :]
if potential_suffix in SUSPICIOUS_SUFFIXES:
message = EXISTING_FILE_DUPLICATE_MSG.format(
current=file_path_obj.name,
existing=existing_file.name,
)
issues.append(message)
break
# Same check for dash-separated suffixes
for existing_file in file_path_obj.parent.glob(f"{file_stem}-*{file_suffix}"):
if existing_file != file_path_obj:
existing_stem = existing_file.stem
if existing_stem.startswith(f"{file_stem}-"):
potential_suffix = existing_stem[len(file_stem) + 1 :]
if potential_suffix in SUSPICIOUS_SUFFIXES:
message = EXISTING_FILE_DUPLICATE_MSG.format(
current=file_path_obj.name,
existing=existing_file.name,
)
issues.append(message)
break
# Check function and class names in content
try:
# Dedent the content to handle code fragments with leading indentation
tree = ast.parse(textwrap.dedent(content))
class SuffixVisitor(ast.NodeVisitor):
def __init__(self):
self.function_names: set[str] = set()
self.class_names: set[str] = set()
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self.function_names.add(node.name)
self.generic_visit(node)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
self.function_names.add(node.name)
self.generic_visit(node)
def visit_ClassDef(self, node: ast.ClassDef) -> None:
self.class_names.add(node.name)
self.generic_visit(node)
visitor = SuffixVisitor()
visitor.visit(tree)
# Check for suspicious function name patterns
for func_name in visitor.function_names:
for suffix in SUSPICIOUS_SUFFIXES:
suffix_token = f"_{suffix}"
if not func_name.endswith(suffix_token):
continue
base_name = func_name[: -len(suffix_token)]
if base_name in visitor.function_names:
message = NAME_SUFFIX_DUPLICATE_MSG.format(
kind="Function",
name=func_name,
base=base_name,
)
issues.append(message)
break
# Check for suspicious class name patterns
for class_name in visitor.class_names:
for suffix in SUSPICIOUS_SUFFIXES:
pascal_suffix = suffix.capitalize()
snake_suffix = f"_{suffix}"
potential_matches = (
(pascal_suffix, class_name[: -len(pascal_suffix)]),
(snake_suffix, class_name[: -len(snake_suffix)]),
)
for token, base_name in potential_matches:
if (
token
and class_name.endswith(token)
and base_name in visitor.class_names
):
message = NAME_SUFFIX_DUPLICATE_MSG.format(
kind="Class",
name=class_name,
base=base_name,
)
issues.append(message)
break
else:
continue
break
except SyntaxError:
# If we can't parse the AST, skip function/class checks
pass
return issues
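# Illustrative triggers (hypothetical names): writing "parser_improved.py"
# next to an existing "parser.py", or defining both process_data() and
# process_data_enhanced() in the same module, yields the warnings above.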
def _perform_quality_check(
file_path: str,
content: str,
config: QualityConfig,
enable_type_checks: bool = True,
) -> tuple[bool, list[str]]:
"""Perform quality analysis and return issues."""
# Store state if tracking enabled
if config.state_tracking_enabled:
store_pre_state(file_path, content)
# Run quality analysis
results = analyze_code_quality(
content,
file_path,
config,
enable_type_checks=enable_type_checks,
)
return check_code_issues(results, config, content, file_path)
def _handle_quality_issues(
file_path: str,
issues: list[str],
config: QualityConfig,
*,
forced_permission: str | None = None,
) -> JsonObject:
"""Handle quality issues based on enforcement mode."""
# Prepare denial message with formatted issues
formatted_issues: list[str] = []
for issue in issues:
# Add indentation to multi-line issues for better readability
        if "\n" in issue:
            lines = issue.split("\n")
            formatted_issues.append(lines[0])
            formatted_issues.extend(f"  {line}" for line in lines[1:])
        else:
            formatted_issues.append(issue)
message = (
f"Code quality check failed for {Path(file_path).name}:\n"
+ "\n".join(formatted_issues)
+ "\n\nFix these issues before writing the code."
)
# Make decision based on enforcement mode
if forced_permission:
return _create_hook_response("PreToolUse", forced_permission, message)
if config.enforcement_mode == "strict":
return _create_hook_response("PreToolUse", "deny", message)
if config.enforcement_mode == "warn":
return _create_hook_response("PreToolUse", "ask", message)
# permissive
warning_message = f"⚠️ Quality Warning:\n{message}"
return _create_hook_response(
"PreToolUse",
"allow",
warning_message,
warning_message,
)
def _exit_with_reason(reason: str, exit_code: int = 2) -> None:
"""Write reason to stderr and exit with specified code."""
sys.stderr.write(reason)
sys.exit(exit_code)
def _create_hook_response(
event_name: str,
permission: str = "",
reason: str = "",
system_message: str = "",
additional_context: str = "",
*,
decision: str | None = None,
) -> JsonObject:
"""Create standardized hook response."""
hook_output: dict[str, object] = {
"hookEventName": event_name,
}
if permission:
hook_output["permissionDecision"] = permission
if reason:
hook_output["permissionDecisionReason"] = reason
if additional_context:
hook_output["additionalContext"] = additional_context
response: JsonObject = {
"hookSpecificOutput": hook_output,
}
if permission:
response["permissionDecision"] = permission
if decision:
response["decision"] = decision
if reason:
response["reason"] = reason
if system_message:
response["systemMessage"] = system_message
return response
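# Illustrative deny response assembled above (reason text is hypothetical):
#
#     {"hookSpecificOutput": {"hookEventName": "PreToolUse",
#       "permissionDecision": "deny",
#       "permissionDecisionReason": "Code quality check failed ..."},
#      "permissionDecision": "deny",
#      "reason": "Code quality check failed ..."}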
def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject:
"""Handle PreToolUse hook - analyze content before write/edit."""
tool_name = str(hook_data.get("tool_name", ""))
tool_input_raw = hook_data.get("tool_input", {})
if not isinstance(tool_input_raw, dict):
return _create_hook_response("PreToolUse", "allow")
tool_input = cast("dict[str, object]", tool_input_raw)
# Only analyze for write/edit tools
if tool_name not in ["Write", "Edit", "MultiEdit"]:
return _create_hook_response("PreToolUse", "allow")
# Extract content based on tool type
file_path = str(tool_input.get("file_path", ""))
content = ""
if tool_name == "Write":
raw_content = tool_input.get("content", "")
content = "" if raw_content is None else str(raw_content)
elif tool_name == "Edit":
new_string = tool_input.get("new_string", "")
content = "" if new_string is None else str(new_string)
elif tool_name == "MultiEdit":
edits = tool_input.get("edits", [])
if isinstance(edits, list):
edits_list = cast("list[object]", edits)
parts: list[str] = []
for edit in edits_list:
if not isinstance(edit, dict):
continue
edit_dict = cast("dict[str, object]", edit)
new_str = edit_dict.get("new_string")
                parts.append("" if new_str is None else str(new_str))
content = "\n".join(parts)
# Only analyze Python files
if not file_path or not file_path.endswith(".py") or not content:
return _create_hook_response("PreToolUse", "allow")
# Check if this is a test file and test quality checks are enabled
is_test = is_test_file(file_path)
run_test_checks = config.test_quality_enabled and is_test
enable_type_checks = tool_name == "Write"
# Always run core checks (Any, type: ignore, typing, duplicates) before skipping
any_usage_issues = _detect_any_usage(content)
type_ignore_issues = _detect_type_ignore_usage(content)
old_typing_issues = _detect_old_typing_patterns(content)
suffix_duplication_issues = _detect_suffix_duplication(file_path, content)
precheck_issues = (
any_usage_issues
+ type_ignore_issues
+ old_typing_issues
+ suffix_duplication_issues
)
# Run test quality checks if enabled and file is a test file
if run_test_checks:
test_quality_issues = run_test_quality_checks(content, file_path, config)
precheck_issues.extend(test_quality_issues)
# Skip detailed analysis for configured patterns unless test checks should run
# Note: Core quality checks (Any, type: ignore, duplicates) always run above
should_skip_detailed = should_skip_file(file_path, config) and not run_test_checks
try:
# Run detailed quality checks only if not skipping
if should_skip_detailed:
issues = []
else:
_has_issues, issues = _perform_quality_check(
file_path,
content,
config,
enable_type_checks=enable_type_checks,
)
all_issues = precheck_issues + issues
if not all_issues:
return _create_hook_response("PreToolUse", "allow")
if precheck_issues:
return _handle_quality_issues(
file_path,
all_issues,
config,
forced_permission="deny",
)
return _handle_quality_issues(file_path, all_issues, config)
except Exception as e: # noqa: BLE001
if precheck_issues:
return _handle_quality_issues(
file_path,
precheck_issues,
config,
forced_permission="deny",
)
return _create_hook_response(
"PreToolUse",
"allow",
f"Warning: Code quality check failed with error: {e}",
f"Warning: Code quality check failed with error: {e}",
)
def posttooluse_hook(
hook_data: JsonObject,
config: QualityConfig,
) -> JsonObject:
"""Handle PostToolUse hook - verify quality after write/edit."""
tool_name: str = str(hook_data.get("tool_name", ""))
tool_output = hook_data.get("tool_response", hook_data.get("tool_output", {}))
# Only process write/edit tools
if tool_name not in ["Write", "Edit", "MultiEdit"]:
return _create_hook_response("PostToolUse")
# Extract file path from output
file_path: str = ""
if isinstance(tool_output, dict):
tool_output_dict = cast("dict[str, object]", tool_output)
file_path = str(
tool_output_dict.get("file_path", "") or tool_output_dict.get("path", ""),
)
elif isinstance(tool_output, str) and (
match := re.search(r"([/\w\-_.]+\.py)", tool_output)
):
file_path = match[1]
if not file_path or not file_path.endswith(".py"):
return _create_hook_response("PostToolUse")
if not Path(file_path).exists():
return _create_hook_response("PostToolUse")
issues: list[str] = []
# Read entire file content for full analysis
try:
with open(file_path, encoding="utf-8") as f:
file_content = f.read()
except (OSError, UnicodeDecodeError):
# If we can't read the file, skip full content analysis
file_content = ""
# Run full file quality checks on the entire content
if file_content:
any_usage_issues = _detect_any_usage(file_content)
type_ignore_issues = _detect_type_ignore_usage(file_content)
issues.extend(any_usage_issues)
issues.extend(type_ignore_issues)
# Check state changes if tracking enabled
if config.state_tracking_enabled:
delta_issues = check_state_changes(file_path)
issues.extend(delta_issues)
# Run cross-file duplicate detection if enabled
if config.cross_file_check_enabled:
cross_file_issues = check_cross_file_duplicates(file_path, config)
issues.extend(cross_file_issues)
# Verify naming conventions if enabled
if config.verify_naming:
naming_issues = verify_naming_conventions(file_path)
issues.extend(naming_issues)
# Format response
if issues:
message = (
f"📝 Post-write quality notes for {Path(file_path).name}:\n"
+ "\n".join(issues)
)
return _create_hook_response(
"PostToolUse",
"",
message,
message,
message,
decision="block",
)
if config.show_success:
message = f"{Path(file_path).name} passed post-write verification"
return _create_hook_response(
"PostToolUse",
"",
"",
message,
"",
decision="approve",
)
return _create_hook_response("PostToolUse")
def is_test_file(file_path: str) -> bool:
"""Check if file path is in a test directory."""
path_parts = Path(file_path).parts
return any(part in ("test", "tests", "testing") for part in path_parts)
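# For example, is_test_file("src/tests/test_io.py") is True (the "tests" path
# part), while is_test_file("src/app/test_io.py") is False: only directory
# names are matched, not file names.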
def run_test_quality_checks(
content: str,
file_path: str,
config: QualityConfig,
) -> list[str]:
"""Run Sourcery's test rules and return guidance-enhanced issues."""
issues: list[str] = []
# Only run test quality checks for test files
if not is_test_file(file_path):
return issues
suffix = Path(file_path).suffix or ".py"
# Create temp file in project directory to inherit config files
tmp_dir = _get_project_tmp_dir(file_path)
with NamedTemporaryFile(
mode="w",
suffix=suffix,
delete=False,
dir=str(tmp_dir),
prefix="test_validation_",
) as tmp:
tmp.write(content)
tmp_path = tmp.name
try:
# Run Sourcery with specific test-related rules
venv_bin = _get_project_venv_bin()
sourcery_path = venv_bin / "sourcery"
        if not sourcery_path.exists():
            # Fall back to a PATH lookup (shutil is imported at module level)
            found = shutil.which("sourcery")
            sourcery_path = Path(found) if found else sourcery_path
            if not sourcery_path.exists():
                logging.debug("Sourcery not found at %s", sourcery_path)
                return issues
# Specific rules for test quality - use correct Sourcery format
test_rules = [
"no-conditionals-in-tests",
"no-loop-in-tests",
"raise-specific-error",
"dont-import-test-modules",
]
# Build command with --enable for each rule
cmd = [str(sourcery_path), "review", tmp_path]
for rule in test_rules:
cmd.extend(["--enable", rule])
cmd.append("--check") # Return exit code 1 if issues found
# Activate virtual environment for the subprocess
env = os.environ.copy()
env["VIRTUAL_ENV"] = str(venv_bin.parent)
env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
# Remove any PYTHONHOME that might interfere
env.pop("PYTHONHOME", None)
logging.debug("Running Sourcery command: %s", " ".join(cmd))
result = subprocess.run( # noqa: S603
cmd,
check=False,
capture_output=True,
text=True,
timeout=30,
env=env,
)
logging.debug("Sourcery exit code: %s", result.returncode)
logging.debug("Sourcery stdout: %s", result.stdout)
logging.debug("Sourcery stderr: %s", result.stderr)
# Sourcery with --check returns:
# - Exit code 0: No issues found
# - Exit code 1: Issues found
# - Exit code 2: Error occurred
if result.returncode == 1:
# Issues were found - parse the output
output = result.stdout + result.stderr
# Try to extract rule names from the output. Sourcery usually includes
# rule identifiers in brackets or descriptive text.
for rule in test_rules:
if rule in output or rule.replace("-", " ") in output.lower():
base_guidance = generate_test_quality_guidance(
rule,
content,
file_path,
config,
)
external_context = get_external_context(
rule,
content,
file_path,
config,
)
if external_context:
base_guidance += f"\n\n{external_context}"
issues.append(base_guidance)
break # Only add one guidance message
else:
# If no specific rule found, provide general guidance
base_guidance = generate_test_quality_guidance(
"unknown",
content,
file_path,
config,
)
external_context = get_external_context(
"unknown",
content,
file_path,
config,
)
if external_context:
base_guidance += f"\n\n{external_context}"
issues.append(base_guidance)
elif result.returncode == 2:
# Error occurred
logging.debug("Sourcery error: %s", result.stderr)
# Don't block on Sourcery errors - just log them
# Exit code 0 means no issues - do nothing
except (subprocess.TimeoutExpired, OSError, json.JSONDecodeError) as e:
# If Sourcery fails, don't block the operation
logging.debug("Test quality check failed for %s: %s", file_path, e)
finally:
Path(tmp_path).unlink(missing_ok=True)
return issues
def main() -> None:
"""Main hook entry point."""
try:
# Load configuration
config = QualityConfig.from_env()
# Read hook input from stdin
try:
hook_data: JsonObject = json.load(sys.stdin)
except json.JSONDecodeError:
fallback_response: JsonObject = {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "allow",
},
}
sys.stdout.write(json.dumps(fallback_response))
return
# Detect hook type: tool_response=PostToolUse, tool_input=PreToolUse
response: JsonObject
if "tool_response" in hook_data or "tool_output" in hook_data:
# PostToolUse hook
response = posttooluse_hook(hook_data, config)
else:
# PreToolUse hook
response = pretooluse_hook(hook_data, config)
print(json.dumps(response)) # noqa: T201
# Handle exit codes based on hook output
hook_output_raw = response.get("hookSpecificOutput", {})
if not isinstance(hook_output_raw, dict):
return
hook_output = cast("dict[str, object]", hook_output_raw)
permission_decision = hook_output.get("permissionDecision")
if permission_decision == "deny":
# Exit code 2: Blocking error - stderr fed back to Claude
reason = str(
hook_output.get("permissionDecisionReason", "Permission denied"),
)
_exit_with_reason(reason)
elif permission_decision == "ask":
# Also use exit code 2 for ask decisions to ensure Claude sees the message
reason = str(
hook_output.get("permissionDecisionReason", "Permission request"),
)
_exit_with_reason(reason)
# Exit code 0: Success (default)
except Exception as e: # noqa: BLE001
# Unexpected error - use exit code 1 (non-blocking error)
sys.stderr.write(f"Hook error: {e}")
sys.exit(1)
if __name__ == "__main__":
main()
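# Illustrative manual invocation (the payload shape follows pretooluse_hook's
# expectations; the file name and content are made up):
#
#     echo '{"tool_name": "Write", "tool_input": {"file_path": "demo.py",
#         "content": "x: int = 1\n"}}' | python code_quality_guard.py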