- Replace threading locks with fcntl file-based locks for proper inter-process synchronization
- Hooks run as separate processes, so threading locks don't work across invocations
- Implement non-blocking lock acquisition to prevent hook deadlocks
- Use fcntl.flock on a shared lock file in /tmp/.claude_hooks/subprocess.lock
- Simplify lock usage with a context manager pattern in both hooks
- Ensure graceful fallback if the lock can't be acquired (e.g., due to concurrent hooks)

This properly fixes the API Error 400 concurrency issues by serializing subprocess
operations across all hook invocations, not just within a single process.
2126 lines · 70 KiB · Python
"""Unified quality hook for Claude Code supporting both PreToolUse and PostToolUse.
|
|
|
|
Prevents writing duplicate, complex, or non-modernized code and verifies quality
|
|
after writes.
|
|
"""
|
|
|
|
import ast
|
|
import fcntl
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import textwrap
|
|
import tokenize
|
|
from collections.abc import Callable
|
|
from contextlib import contextmanager, suppress
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from importlib import import_module
|
|
from io import StringIO
|
|
from pathlib import Path
|
|
from tempfile import NamedTemporaryFile, gettempdir
|
|
from typing import TYPE_CHECKING, TypedDict, cast
|
|
|
|
# Import message enrichment helpers
|
|
try:
|
|
from .message_enrichment import EnhancedMessageFormatter
|
|
from .type_inference import TypeInferenceHelper
|
|
except ImportError:
|
|
# Fallback for direct execution
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
from message_enrichment import EnhancedMessageFormatter
|
|
from type_inference import TypeInferenceHelper
|
|
|
|
# Import internal duplicate detector; fall back to local path when executed directly
if TYPE_CHECKING:
    from .internal_duplicate_detector import (
        Duplicate,
        DuplicateResults,
        detect_internal_duplicates,
    )
else:
    try:
        from .internal_duplicate_detector import (
            Duplicate,
            DuplicateResults,
            detect_internal_duplicates,
        )
    except ImportError:
        sys.path.insert(0, str(Path(__file__).parent))
        module = import_module("internal_duplicate_detector")
        Duplicate = module.Duplicate
        DuplicateResults = module.DuplicateResults
        detect_internal_duplicates = module.detect_internal_duplicates


# File-based lock for inter-process synchronization
def _get_lock_file() -> Path:
    """Get path to lock file for subprocess serialization."""
    lock_dir = Path(gettempdir()) / ".claude_hooks"
    lock_dir.mkdir(exist_ok=True, mode=0o700)
    return lock_dir / "subprocess.lock"


@contextmanager
def _subprocess_lock(timeout: float = 10.0):
    """Context manager for file-based subprocess locking.

    Args:
        timeout: Reserved for a future blocking mode; acquisition is currently
            non-blocking and the value is unused.

    Yields:
        True if the lock was acquired, False if another process already holds it.
    """
    lock_file = _get_lock_file()

    # Open or create lock file
    with open(lock_file, "a") as f:
        try:
            # Try to acquire exclusive lock (non-blocking)
            fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
            yield True
        except OSError:
            # Lock is held by another process; yield False instead of blocking
            yield False
        finally:
            try:
                fcntl.flock(f.fileno(), fcntl.LOCK_UN)
            except OSError:
                pass


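# Usage sketch (illustrative comment, not executed): callers treat a held
# lock as a soft failure and skip the serialized work instead of blocking.
#
#     with _subprocess_lock() as acquired:
#         if not acquired:
#             return True, "Warning: lock held by another hook, skipping"
#         result = subprocess.run(cmd, check=False, capture_output=True)

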
SUSPICIOUS_SUFFIXES: tuple[str, ...] = (
    "enhanced", "improved", "better", "new", "updated", "modified",
    "refactored", "optimized", "fixed", "clean", "simple", "advanced",
    "basic", "complete", "final", "latest", "current", "temp", "temporary",
    "backup", "old", "legacy", "unified", "merged", "combined", "integrated",
    "consolidated", "extended", "enriched", "augmented", "upgraded",
    "revised", "polished", "streamlined", "simplified", "modernized",
    "normalized", "sanitized", "validated", "verified", "corrected",
    "patched", "stable", "experimental", "alpha", "beta", "draft",
    "preliminary", "prototype", "working", "test", "debug", "custom",
    "special", "generic", "specific", "general", "detailed", "minimal",
    "full", "partial", "quick", "fast", "slow", "smart", "intelligent",
    "auto", "manual", "secure", "safe", "robust", "flexible", "dynamic",
    "static", "reactive", "async", "sync", "parallel", "serial",
    "distributed", "centralized", "decentralized",
)

FILE_SUFFIX_DUPLICATE_MSG = (
    "⚠️ File '{current}' appears to be a suffixed duplicate of '{original}'. "
    "Consider refactoring instead of creating variations with adjective "
    "suffixes."
)

EXISTING_FILE_DUPLICATE_MSG = (
    "⚠️ Creating '{current}' when '{existing}' already exists suggests "
    "duplication. Consider consolidating or using a more descriptive name."
)

NAME_SUFFIX_DUPLICATE_MSG = (
    "⚠️ {kind} '{name}' appears to be a suffixed duplicate of '{base}'. "
    "Consider refactoring instead of creating variations."
)


def get_external_context(
    rule_id: str,
    content: str,
    _file_path: str,
    config: "QualityConfig",
) -> str:
    """Fetch additional guidance from optional integrations."""
    context_parts: list[str] = []

    if config.context7_enabled and config.context7_api_key:
        try:
            context7_context = _get_context7_analysis(
                rule_id,
                content,
                config.context7_api_key,
            )
        except (OSError, RuntimeError, ValueError, TimeoutError) as exc:
            logging.debug("Context7 API call failed: %s", exc)
        else:
            if context7_context:
                context_parts.append(f"📊 Context7 Analysis: {context7_context}")

    if config.firecrawl_enabled and config.firecrawl_api_key:
        try:
            firecrawl_examples = _get_firecrawl_examples(
                rule_id,
                config.firecrawl_api_key,
            )
        except (OSError, RuntimeError, ValueError, TimeoutError) as exc:
            logging.debug("Firecrawl API call failed: %s", exc)
        else:
            if firecrawl_examples:
                context_parts.append(f"🔗 Additional Examples: {firecrawl_examples}")

    return "\n\n".join(context_parts)


def _get_context7_analysis(
    rule_id: str,
    _content: str,
    _api_key: str,
) -> str:
    """Placeholder for Context7 API integration."""
    context_map = {
        "no-conditionals-in-tests": (
            "Use pytest.mark.parametrize instead of branching inside tests."
        ),
        "no-loop-in-tests": (
            "Split loops into individual tests or parameterize the data for clarity."
        ),
        "raise-specific-error": (
            "Raise precise exception types (ValueError, custom errors, etc.) "
            "so failures highlight the right behaviour."
        ),
        "dont-import-test-modules": (
            "Production code should not depend on test helpers; "
            "extract shared logic into utility modules."
        ),
    }

    return context_map.get(
        rule_id,
        "Context7 guidance will appear once the integration is available.",
    )


def _get_firecrawl_examples(rule_id: str, _api_key: str) -> str:
    """Placeholder for Firecrawl API integration."""
    examples_map = {
        "no-conditionals-in-tests": (
            "Pytest parameterization tutorial: "
            "docs.pytest.org/en/latest/how-to/parametrize.html"
        ),
        "no-loop-in-tests": (
            "Testing best practices: keep scenarios in separate "
            "tests for precise failures."
        ),
        "raise-specific-error": (
            "Python docs on exceptions: docs.python.org/3/tutorial/errors.html"
        ),
        "dont-import-test-modules": (
            "Clean architecture tip: production modules should not import from tests."
        ),
    }

    return examples_map.get(
        rule_id,
        "Firecrawl examples will appear once the integration is available.",
    )


def generate_test_quality_guidance(
    rule_id: str,
    content: str,
    _file_path: str,
    _config: "QualityConfig",
) -> str:
    """Return enriched guidance for test quality rule violations."""
    function_name = "test_function"
    match = re.search(r"def\s+(\w+)\s*\(", content)
    if match:
        function_name = match.group(1)

    # Extract a small snippet of the violating code
    code_snippet = ""
    if match:
        # Show the first 10 lines starting at the function definition
        lines = content.splitlines()
        for i, line in enumerate(lines):
            if f"def {function_name}" in line:
                snippet_start = max(0, i)
                snippet_end = min(len(lines), i + 10)
                code_snippet = "\n".join(lines[snippet_start:snippet_end])
                break

    # Use enhanced formatter for rich test quality messages
    return EnhancedMessageFormatter.format_test_quality_message(
        rule_id=rule_id,
        function_name=function_name,
        code_snippet=code_snippet,
        include_examples=True,
    )


class ToolConfig(TypedDict):
    """Configuration for a type checking tool."""

    args: list[str]
    error_check: Callable[[subprocess.CompletedProcess[str]], bool]
    error_message: str | Callable[[subprocess.CompletedProcess[str]], str]


class ComplexitySummary(TypedDict):
    """Summary of complexity analysis."""

    average_cyclomatic_complexity: float


# Distribution of complexity levels. "Very High" contains a space in the
# analyzer's JSON output, so the functional TypedDict syntax is required here.
ComplexityDistribution = TypedDict(
    "ComplexityDistribution",
    {"High": int, "Very High": int, "Extreme": int},
)


class ComplexityResults(TypedDict):
    """Results from complexity analysis."""

    summary: ComplexitySummary
    distribution: ComplexityDistribution


class TypeCheckingResults(TypedDict):
    """Results from type checking analysis."""

    issues: list[str]


class AnalysisResults(TypedDict, total=False):
    """Complete analysis results from quality checks."""

    internal_duplicates: DuplicateResults
    complexity: ComplexityResults
    type_checking: TypeCheckingResults
    modernization: dict[str, object]  # JSON structure varies


# Type aliases for JSON-like structures
JsonObject = dict[str, object]
JsonValue = str | int | float | bool | None | list[object] | JsonObject


@dataclass
class QualityConfig:
    """Configuration for quality checks."""

    # Core settings
    duplicate_threshold: float = 0.7
    duplicate_enabled: bool = True
    complexity_threshold: int = 10
    complexity_enabled: bool = True
    modernization_enabled: bool = True
    require_type_hints: bool = True
    enforcement_mode: str = "strict"  # strict/warn/permissive

    # Type checking tools
    sourcery_enabled: bool = True
    basedpyright_enabled: bool = True
    pyrefly_enabled: bool = True
    type_check_exit_code: int = 2

    # PostToolUse features
    state_tracking_enabled: bool = False
    cross_file_check_enabled: bool = False
    verify_naming: bool = True
    show_success: bool = False

    # Test quality checks
    test_quality_enabled: bool = True

    # External context providers
    context7_enabled: bool = False
    context7_api_key: str = ""
    firecrawl_enabled: bool = False
    firecrawl_api_key: str = ""

    # File patterns
    skip_patterns: list[str] | None = None

    def __post_init__(self) -> None:
        if self.skip_patterns is None:
            self.skip_patterns = ["test_", "_test.py", "/tests/", "/fixtures/"]

    @classmethod
    def from_env(cls) -> "QualityConfig":
        """Load config from environment variables."""
        return cls(
            duplicate_threshold=float(os.getenv("QUALITY_DUP_THRESHOLD", "0.7")),
            duplicate_enabled=os.getenv("QUALITY_DUP_ENABLED", "true").lower()
            == "true",
            complexity_threshold=int(os.getenv("QUALITY_COMPLEXITY_THRESHOLD", "10")),
            complexity_enabled=os.getenv("QUALITY_COMPLEXITY_ENABLED", "true").lower()
            == "true",
            modernization_enabled=os.getenv("QUALITY_MODERN_ENABLED", "true").lower()
            == "true",
            require_type_hints=os.getenv("QUALITY_REQUIRE_TYPES", "true").lower()
            == "true",
            enforcement_mode=os.getenv("QUALITY_ENFORCEMENT", "strict"),
            state_tracking_enabled=os.getenv("QUALITY_STATE_TRACKING", "false").lower()
            == "true",
            cross_file_check_enabled=os.getenv(
                "QUALITY_CROSS_FILE_CHECK",
                "false",
            ).lower()
            == "true",
            verify_naming=os.getenv("QUALITY_VERIFY_NAMING", "true").lower() == "true",
            show_success=os.getenv("QUALITY_SHOW_SUCCESS", "false").lower() == "true",
            sourcery_enabled=os.getenv("QUALITY_SOURCERY_ENABLED", "true").lower()
            == "true",
            basedpyright_enabled=os.getenv(
                "QUALITY_BASEDPYRIGHT_ENABLED",
                "true",
            ).lower()
            == "true",
            pyrefly_enabled=os.getenv("QUALITY_PYREFLY_ENABLED", "true").lower()
            == "true",
            type_check_exit_code=int(os.getenv("QUALITY_TYPE_CHECK_EXIT_CODE", "2")),
            test_quality_enabled=(
                os.getenv("QUALITY_TEST_QUALITY_ENABLED", "true").lower() == "true"
            ),
            context7_enabled=(
                os.getenv("QUALITY_CONTEXT7_ENABLED", "false").lower() == "true"
            ),
            context7_api_key=os.getenv("QUALITY_CONTEXT7_API_KEY", ""),
            firecrawl_enabled=(
                os.getenv("QUALITY_FIRECRAWL_ENABLED", "false").lower() == "true"
            ),
            firecrawl_api_key=os.getenv("QUALITY_FIRECRAWL_API_KEY", ""),
        )


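# Configuration sketch (shell values are illustrative): the hook reads only
# the QUALITY_* variables above, so a typical override looks like:
#
#     export QUALITY_ENFORCEMENT=warn
#     export QUALITY_COMPLEXITY_THRESHOLD=15
#     export QUALITY_CROSS_FILE_CHECK=true
#
# after which QualityConfig.from_env() picks the values up on the next run.

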
def should_skip_file(file_path: str, config: QualityConfig) -> bool:
    """Check if file should be skipped based on patterns."""
    if config.skip_patterns is None:
        return False
    return any(pattern in file_path for pattern in config.skip_patterns)


def _module_candidate(path: Path) -> tuple[Path, list[str]]:
    """Build a module invocation candidate for a python executable."""
    return path, [str(path), "-m", "quality.cli.main"]


def _cli_candidate(path: Path) -> tuple[Path, list[str]]:
    """Build a direct CLI invocation candidate."""
    return path, [str(path)]


def get_claude_quality_command(repo_root: Path | None = None) -> list[str]:
    """Return a path-resilient command for invoking claude-quality."""
    repo_root = repo_root or Path(__file__).resolve().parent.parent
    platform_name = sys.platform
    is_windows = platform_name.startswith("win")

    scripts_dir = repo_root / ".venv" / ("Scripts" if is_windows else "bin")
    python_names = (
        ["python.exe", "python3.exe"] if is_windows else ["python", "python3"]
    )
    cli_names = (
        ["claude-quality.exe", "claude-quality"] if is_windows else ["claude-quality"]
    )

    candidates: list[tuple[Path, list[str]]] = []
    for name in python_names:
        candidates.append(_module_candidate(scripts_dir / name))
    for name in cli_names:
        candidates.append(_cli_candidate(scripts_dir / name))

    for candidate_path, command in candidates:
        if candidate_path.exists():
            return command

    interpreter_fallbacks = ["python"] if is_windows else ["python3", "python"]
    for interpreter in interpreter_fallbacks:
        if shutil.which(interpreter):
            return [interpreter, "-m", "quality.cli.main"]

    if shutil.which("claude-quality"):
        return ["claude-quality"]

    message = (
        "'claude-quality' was not found on PATH. Please ensure it is installed and "
        "available."
    )
    raise RuntimeError(message)


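# Resolution-order sketch (returned paths are illustrative and depend on the
# machine): venv python via -m first, then a venv console script, then PATH.
#
#     get_claude_quality_command()
#     # -> ["/repo/.venv/bin/python", "-m", "quality.cli.main"]  (venv exists)
#     # -> ["python3", "-m", "quality.cli.main"]                 (PATH fallback)
#     # -> ["claude-quality"]                                    (last resort)

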
def _get_project_venv_bin(file_path: str | None = None) -> Path:
    """Get the virtual environment bin directory for the current project.

    Args:
        file_path: Optional file path to determine project root from.
            If not provided, uses current working directory.
    """
    # Start from the file's directory if provided, otherwise from cwd
    temp_root = Path(gettempdir()).resolve()
    if file_path:
        resolved_path = Path(file_path).resolve()
        try:
            if resolved_path.is_relative_to(temp_root):
                start_path = Path.cwd()
            else:
                start_path = resolved_path.parent
        except ValueError:
            start_path = resolved_path.parent
    else:
        start_path = Path.cwd()

    current = start_path

    while current != current.parent:
        venv_candidate = current / ".venv"
        if venv_candidate.exists() and venv_candidate.is_dir():
            bin_dir = venv_candidate / "bin"
            if bin_dir.exists():
                return bin_dir
        current = current.parent

    # Fallback to claude-scripts venv if no project venv found
    return Path(__file__).parent.parent / ".venv" / "bin"


def _format_basedpyright_errors(json_output: str) -> str:
    """Format basedpyright JSON output into readable error messages."""
    try:
        data = json.loads(json_output)
        diagnostics = data.get("generalDiagnostics", [])

        if not diagnostics:
            return "Type errors found (no details available)"

        # Group by severity and format
        errors: list[str] = []
        for diag in diagnostics[:10]:  # Limit to first 10 errors
            severity = diag.get("severity", "error").upper()
            message = diag.get("message", "Unknown error")
            rule = diag.get("rule", "")
            range_info = diag.get("range", {})
            start = range_info.get("start", {})
            line = start.get("line", 0) + 1  # Convert 0-indexed to 1-indexed

            rule_text = f" [{rule}]" if rule else ""
            errors.append(f"  [{severity}] Line {line}: {message}{rule_text}")

        count = len(diagnostics)
        summary = f"Found {count} type error{'s' if count != 1 else ''}"
        if count > 10:
            summary += " (showing first 10)"

        return f"{summary}:\n" + "\n".join(errors)
    except (json.JSONDecodeError, KeyError, TypeError):
        return "Type errors found (failed to parse details)"


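# Input/output sketch (the diagnostic content is invented for illustration;
# the "generalDiagnostics" shape matches what the parser above expects):
#
#     _format_basedpyright_errors(
#         '{"generalDiagnostics": [{"severity": "error", "message": '
#         '"Expression has type None", "rule": "reportGeneralTypeIssues", '
#         '"range": {"start": {"line": 4}}}]}'
#     )
#     # -> "Found 1 type error:\n  [ERROR] Line 5: Expression has type None
#     #     [reportGeneralTypeIssues]"

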
def _format_pyrefly_errors(output: str) -> str:
    """Format pyrefly output into readable error messages."""
    if not output or not output.strip():
        return "Type errors found (no details available)"

    # Pyrefly's own formatting is already readable, so only add a summary line
    lines = output.strip().split("\n")

    # Count ERROR lines to provide a summary
    error_count = sum(1 for line in lines if line.strip().startswith("ERROR"))

    if error_count == 0:
        return output.strip()

    summary = f"Found {error_count} type error{'s' if error_count != 1 else ''}"
    return f"{summary}:\n{output.strip()}"


def _format_sourcery_errors(output: str) -> str:
    """Format sourcery output into readable error messages."""
    if not output or not output.strip():
        return "Code quality issues found (no details available)"

    # Extract issue count if present
    lines = output.strip().split("\n")

    # Sourcery typically outputs: "✖ X issues detected"
    issue_count = 0
    for line in lines:
        if "issue" in line.lower() and "detected" in line.lower():
            # Try to extract the number
            match = re.search(r"(\d+)\s+issue", line)
            if match:
                issue_count = int(match.group(1))
                break

    # Format the output, removing redundant summary lines
    formatted_lines: list[str] = []
    for line in lines:
        # Skip the summary line as we'll add our own
        if "issue" in line.lower() and "detected" in line.lower():
            continue
        # Skip empty lines at start/end
        if line.strip():
            formatted_lines.append(line)

    if issue_count > 0:
        plural = "issue" if issue_count == 1 else "issues"
        summary = f"Found {issue_count} code quality {plural}"
        return f"{summary}:\n" + "\n".join(formatted_lines)

    return output.strip()


def _ensure_tool_installed(tool_name: str) -> bool:
    """Ensure a type checking tool is installed in the virtual environment."""
    venv_bin = _get_project_venv_bin()
    tool_path = venv_bin / tool_name

    if tool_path.exists():
        return True

    # Try to install using uv if available
    try:
        result = subprocess.run(  # noqa: S603
            [str(venv_bin / "uv"), "pip", "install", tool_name],
            check=False,
            capture_output=True,
            text=True,
            timeout=60,
        )
    except (subprocess.TimeoutExpired, OSError):
        return False
    else:
        return result.returncode == 0


def _run_type_checker(
    tool_name: str,
    file_path: str,
    _config: QualityConfig,
    *,
    original_file_path: str | None = None,
) -> tuple[bool, str]:
    """Run a type checking tool and return success status and output."""
    venv_bin = _get_project_venv_bin(original_file_path or file_path)
    tool_path = venv_bin / tool_name

    if not tool_path.exists() and not _ensure_tool_installed(tool_name):
        return True, f"Warning: {tool_name} not available"

    # Tool configuration mapping
    tool_configs: dict[str, ToolConfig] = {
        "basedpyright": ToolConfig(
            args=["--outputjson", file_path],
            error_check=lambda result: result.returncode == 1,
            error_message=lambda result: _format_basedpyright_errors(result.stdout),
        ),
        "pyrefly": ToolConfig(
            args=["check", file_path],
            error_check=lambda result: result.returncode == 1,
            error_message=lambda result: _format_pyrefly_errors(result.stdout),
        ),
        "sourcery": ToolConfig(
            args=["review", file_path],
            error_check=lambda result: (
                "issues detected" in str(result.stdout)
                and "0 issues detected" not in str(result.stdout)
            ),
            error_message=lambda result: _format_sourcery_errors(result.stdout),
        ),
    }

    tool_config = tool_configs.get(tool_name)
    if not tool_config:
        return True, f"Warning: Unknown tool {tool_name}"

    # Acquire file-based lock to prevent subprocess concurrency issues
    with _subprocess_lock(timeout=10.0) as acquired:
        if not acquired:
            return True, f"Warning: {tool_name} skipped, lock held by another hook"

        try:
            cmd = [str(tool_path)] + tool_config["args"]

            # Activate virtual environment for the subprocess
            env = os.environ.copy()
            env["VIRTUAL_ENV"] = str(venv_bin.parent)
            env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
            # Remove any PYTHONHOME that might interfere
            env.pop("PYTHONHOME", None)

            # Add PYTHONPATH=src if a src directory exists in the project
            # root, so type checkers can resolve imports from src/
            project_root = venv_bin.parent.parent  # .venv/bin -> project root
            src_dir = project_root / "src"
            if src_dir.exists() and src_dir.is_dir():
                existing_pythonpath = env.get("PYTHONPATH", "")
                if existing_pythonpath:
                    env["PYTHONPATH"] = f"{src_dir}:{existing_pythonpath}"
                else:
                    env["PYTHONPATH"] = str(src_dir)

            # Run from the project root to pick up project configuration files
            result = subprocess.run(  # noqa: S603
                cmd,
                check=False,
                capture_output=True,
                text=True,
                timeout=30,
                env=env,
                cwd=str(project_root),
            )

            # Check for tool-specific errors
            error_check = tool_config["error_check"]
            if error_check(result):
                error_msg = tool_config["error_message"]
                message = str(error_msg(result) if callable(error_msg) else error_msg)
                return False, message

            # Return success or warning
            if result.returncode == 0:
                return True, ""
            return True, f"Warning: {tool_name} error (exit {result.returncode})"

        except subprocess.TimeoutExpired:
            return True, f"Warning: {tool_name} timeout"
        except OSError:
            return True, f"Warning: {tool_name} execution error"


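# Call sketch (return values illustrative): the (ok, message) contract treats
# tool unavailability and lock contention as soft successes, so a missing
# checker never blocks a write.
#
#     ok, message = _run_type_checker("basedpyright", "src/app.py", config)
#     # (False, "Found 2 type errors: ...")           -> blocks in strict mode
#     # (True, "Warning: basedpyright not available") -> never blocks

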
def _initialize_analysis() -> tuple[AnalysisResults, list[str]]:
    """Initialize analysis with empty results and claude-quality command."""
    results: AnalysisResults = {}
    claude_quality_cmd: list[str] = get_claude_quality_command()
    return results, claude_quality_cmd


def run_type_checks(
    file_path: str,
    config: QualityConfig,
    *,
    original_file_path: str | None = None,
) -> list[str]:
    """Run all enabled type checking tools and return any issues."""
    issues: list[str] = []

    # Run Sourcery
    if config.sourcery_enabled:
        success, output = _run_type_checker(
            "sourcery",
            file_path,
            config,
            original_file_path=original_file_path,
        )
        if not success and output:
            issues.append(f"Sourcery: {output.strip()}")

    # Run BasedPyright
    if config.basedpyright_enabled:
        success, output = _run_type_checker(
            "basedpyright",
            file_path,
            config,
            original_file_path=original_file_path,
        )
        if not success and output:
            issues.append(f"BasedPyright: {output.strip()}")

    # Run Pyrefly
    if config.pyrefly_enabled:
        success, output = _run_type_checker(
            "pyrefly",
            file_path,
            config,
            original_file_path=original_file_path,
        )
        if not success and output:
            issues.append(f"Pyrefly: {output.strip()}")

    return issues


def _run_quality_analyses(
    content: str,
    tmp_path: str,
    config: QualityConfig,
    enable_type_checks: bool,
    *,
    original_file_path: str | None = None,
) -> AnalysisResults:
    """Run all quality analysis checks and return results."""
    results, claude_quality_cmd = _initialize_analysis()

    # First check for internal duplicates within the file
    if config.duplicate_enabled:
        internal_duplicates = detect_internal_duplicates(
            content,
            threshold=config.duplicate_threshold,
            min_lines=4,
        )
        if internal_duplicates.get("duplicates"):
            results["internal_duplicates"] = internal_duplicates

    # Run complexity analysis
    if config.complexity_enabled:
        cmd = [
            *claude_quality_cmd,
            "complexity",
            tmp_path,
            "--threshold",
            str(config.complexity_threshold),
            "--format",
            "json",
        ]

        # Prepare virtual environment for subprocess
        venv_bin = _get_project_venv_bin(original_file_path)
        env = os.environ.copy()
        env["VIRTUAL_ENV"] = str(venv_bin.parent)
        env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
        env.pop("PYTHONHOME", None)

        # Acquire lock to prevent subprocess concurrency issues
        with _subprocess_lock(timeout=10.0) as acquired:
            if acquired:
                with suppress(subprocess.TimeoutExpired):
                    result = subprocess.run(  # noqa: S603
                        cmd,
                        check=False,
                        capture_output=True,
                        text=True,
                        timeout=30,
                        env=env,
                    )
                    if result.returncode == 0:
                        with suppress(json.JSONDecodeError):
                            results["complexity"] = json.loads(result.stdout)

    # Run type checking if any tool is enabled
    if enable_type_checks and any(
        [
            config.sourcery_enabled,
            config.basedpyright_enabled,
            config.pyrefly_enabled,
        ],
    ):
        try:
            if type_issues := run_type_checks(
                tmp_path,
                config,
                original_file_path=original_file_path,
            ):
                results["type_checking"] = {"issues": type_issues}
        except Exception as e:  # noqa: BLE001
            logging.debug("Type checking failed: %s", e)

    # Run modernization analysis
    if config.modernization_enabled:
        cmd = [
            *claude_quality_cmd,
            "modernization",
            tmp_path,
            "--include-type-hints" if config.require_type_hints else "",
            "--format",
            "json",
        ]
        cmd = [c for c in cmd if c]  # Remove empty strings

        # Prepare virtual environment for subprocess
        venv_bin = _get_project_venv_bin(original_file_path)
        env = os.environ.copy()
        env["VIRTUAL_ENV"] = str(venv_bin.parent)
        env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
        env.pop("PYTHONHOME", None)

        # Acquire lock to prevent subprocess concurrency issues
        with _subprocess_lock(timeout=10.0) as acquired:
            if acquired:
                with suppress(subprocess.TimeoutExpired):
                    result = subprocess.run(  # noqa: S603
                        cmd,
                        check=False,
                        capture_output=True,
                        text=True,
                        timeout=30,
                        env=env,
                    )
                    if result.returncode == 0:
                        with suppress(json.JSONDecodeError):
                            results["modernization"] = json.loads(result.stdout)

    return results


def _find_project_root(file_path: str) -> Path:
    """Find project root by looking for common markers."""
    file_path_obj = Path(file_path).resolve()
    current = file_path_obj.parent

    # Look for common project markers
    while current != current.parent:
        if any(
            (current / marker).exists()
            for marker in [
                ".git",
                "pyrightconfig.json",
                "pyproject.toml",
                ".venv",
                "setup.py",
            ]
        ):
            return current
        current = current.parent

    # Fallback to parent directory
    return file_path_obj.parent


def _get_project_tmp_dir(file_path: str) -> Path:
    """Get or create .tmp directory in project root."""
    project_root = _find_project_root(file_path)
    tmp_dir = project_root / ".tmp"
    tmp_dir.mkdir(exist_ok=True)

    # Ensure .tmp is gitignored
    gitignore = project_root / ".gitignore"
    if gitignore.exists():
        content = gitignore.read_text()
        if ".tmp/" not in content and ".tmp" not in content:
            # Add .tmp/ to .gitignore
            with gitignore.open("a") as f:
                f.write("\n# Temporary files created by code quality hooks\n.tmp/\n")

    return tmp_dir


def analyze_code_quality(
    content: str,
    file_path: str,
    config: QualityConfig,
    *,
    enable_type_checks: bool = True,
) -> AnalysisResults:
    """Analyze code content using the claude-quality toolkit."""
    suffix = Path(file_path).suffix or ".py"

    # Create the temp file in the project's .tmp directory, not /tmp, so it
    # inherits config files like pyrightconfig.json and pyproject.toml
    tmp_dir = _get_project_tmp_dir(file_path)

    with NamedTemporaryFile(
        mode="w",
        suffix=suffix,
        delete=False,
        dir=str(tmp_dir),
        prefix="hook_validation_",
    ) as tmp:
        tmp.write(content)
        tmp_path = tmp.name

    try:
        return _run_quality_analyses(
            content,
            tmp_path,
            config,
            enable_type_checks,
            original_file_path=file_path,
        )
    finally:
        Path(tmp_path).unlink(missing_ok=True)


def _check_internal_duplicates(
    results: AnalysisResults,
    source_code: str = "",
    file_path: str = "",
) -> list[str]:
    """Check for internal duplicate code within the same file."""
    issues: list[str] = []
    if "internal_duplicates" not in results:
        return issues

    duplicates: list[Duplicate] = results["internal_duplicates"].get(
        "duplicates",
        [],
    )
    for dup in duplicates[:3]:  # Show first 3
        # Use enhanced formatter for rich duplicate messages
        duplicate_type = dup.get("type", "unknown")
        similarity = dup.get("similarity", 0.0)
        locations_raw = dup.get("locations", [])
        # Cast to list of dicts for the formatter
        locations_dicts: list[dict[str, str]] = [
            {
                "name": str(loc.get("name", "unknown")),
                "type": str(loc.get("type", "code")),
                "lines": str(loc.get("lines", "?")),
            }
            for loc in locations_raw
        ]

        enriched_message = EnhancedMessageFormatter.format_duplicate_message(
            duplicate_type=str(duplicate_type),
            similarity=float(similarity),
            locations=locations_dicts,
            source_code=source_code,
            include_refactoring=True,
            file_path=file_path,
        )
        issues.append(enriched_message)
    return issues


def _check_complexity_issues(
    results: AnalysisResults,
    config: QualityConfig,
) -> list[str]:
    """Check for code complexity issues."""
    issues: list[str] = []
    if "complexity" not in results:
        return issues

    complexity_data = results["complexity"]
    summary = complexity_data.get("summary", {})
    avg_cc = summary.get("average_cyclomatic_complexity", 0.0)
    distribution = complexity_data.get("distribution", {})
    # Only block on Very High (CC 21-50) and Extreme (CC 50+), not High:
    # radon's "High" band (CC 11-20) is acceptable for moderately complex
    # functions
    critical_count: int = (
        distribution.get("Very High", 0) + distribution.get("Extreme", 0)
    )
    high_count: int = distribution.get("High", 0)

    # Block only if the average is too high or any function is critically
    # complex (CC > 20); individual functions with CC 11-20 pass through
    if avg_cc > config.complexity_threshold or critical_count > 0:
        # Use enhanced formatter for rich complexity messages
        enriched_message = EnhancedMessageFormatter.format_complexity_message(
            avg_complexity=avg_cc,
            threshold=config.complexity_threshold,
            high_count=critical_count + high_count,  # Show total for context
        )
        issues.append(enriched_message)
    # Functions with CC 11-20 can still be inspected manually with:
    # claude-quality complexity <file>
    return issues


def _check_modernization_issues(
    results: AnalysisResults,
    config: QualityConfig,
) -> list[str]:
    """Check for code modernization issues."""
    issues: list[str] = []
    if "modernization" not in results:
        return issues

    try:
        modernization_data = results["modernization"]
        files = modernization_data.get("files", {})
        if not isinstance(files, dict):
            return issues
    except (AttributeError, TypeError):
        return issues
    total_issues: int = 0
    issue_types: set[str] = set()

    files_values = cast("dict[str, object]", files).values()
    for file_issues_raw in files_values:
        if not isinstance(file_issues_raw, list):
            continue

        # Cast to proper type after runtime check
        file_issues_list = cast("list[object]", file_issues_raw)
        total_issues += len(file_issues_list)

        for issue_raw in file_issues_list:
            if not isinstance(issue_raw, dict):
                continue

            # Cast to proper type after runtime check
            issue_dict = cast("dict[str, object]", issue_raw)
            issue_type_raw = issue_dict.get("issue_type", "unknown")
            if isinstance(issue_type_raw, str):
                issue_types.add(issue_type_raw)

    # Only flag if there are non-type-hint issues or many type hint issues
    non_type_issues: int = len(
        [t for t in issue_types if "type" not in t and "typing" not in t],
    )
    type_issues: int = total_issues - non_type_issues

    if non_type_issues > 0:
        non_type_list: list[str] = [
            t for t in issue_types if "type" not in t and "typing" not in t
        ]
        issues.append(
            f"Modernization needed: {non_type_issues} non-type issues "
            f"({', '.join(non_type_list)})",
        )
    elif config.require_type_hints and type_issues > 10:
        issues.append(
            f"Many missing type hints: {type_issues} functions/parameters "
            "lacking annotations",
        )
    return issues


def _check_type_checking_issues(
    results: AnalysisResults,
    source_code: str = "",
) -> list[str]:
    """Check for type checking issues from Sourcery, BasedPyright, and Pyrefly."""
    issues: list[str] = []
    if "type_checking" not in results:
        return issues

    with suppress(AttributeError, TypeError):
        type_checking_data = results["type_checking"]
        type_issues = type_checking_data.get("issues", [])
        # Group by tool and format with enhanced messages
        for issue_str in type_issues[:3]:  # Limit to first 3 for brevity
            issue_text = str(issue_str)
            # Extract the tool name (issues are formatted as "Tool: details")
            tool_name = "Type Checker"
            if ":" in issue_text:
                potential_tool = issue_text.split(":")[0].strip()
                if potential_tool in ("Sourcery", "BasedPyright", "Pyrefly"):
                    tool_name = potential_tool
                    issue_text = ":".join(issue_text.split(":")[1:]).strip()

            enriched_message = EnhancedMessageFormatter.format_type_error_message(
                tool_name=tool_name,
                error_output=issue_text,
                source_code=source_code,
            )
            issues.append(enriched_message)

    return issues


def check_code_issues(
    results: AnalysisResults,
    config: QualityConfig,
    source_code: str = "",
    file_path: str = "",
) -> tuple[bool, list[str]]:
    """Check analysis results for issues that should block the operation."""
    issues: list[str] = []

    issues.extend(_check_internal_duplicates(results, source_code, file_path))
    issues.extend(_check_complexity_issues(results, config))
    issues.extend(_check_modernization_issues(results, config))
    issues.extend(_check_type_checking_issues(results, source_code))

    return len(issues) > 0, issues


def store_pre_state(file_path: str, content: str) -> None:
    """Store file state before modification for later comparison."""
    # gettempdir is already imported at module level
    cache_dir = Path(gettempdir()) / ".quality_state"
    cache_dir.mkdir(exist_ok=True, mode=0o700)

    state: dict[str, str | int] = {
        "file_path": file_path,
        "timestamp": datetime.now(UTC).isoformat(),
        "content_hash": hashlib.sha256(content.encode()).hexdigest(),
        "lines": len(content.split("\n")),
        "functions": content.count("def "),
        "classes": content.count("class "),
    }

    cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]
    cache_file: Path = cache_dir / f"{cache_key}_pre.json"
    cache_file.write_text(json.dumps(state, indent=2))


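# Snapshot sketch (hash and counts are invented for illustration): each file
# gets one small JSON document under <tmpdir>/.quality_state/:
#
#     {
#       "file_path": "src/app.py",
#       "timestamp": "2024-01-01T00:00:00+00:00",
#       "content_hash": "3f2a...",
#       "lines": 120,
#       "functions": 6,
#       "classes": 2
#     }

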
def check_state_changes(file_path: str) -> list[str]:
    """Check for quality changes between pre and post states."""
    issues: list[str] = []
    cache_dir: Path = Path(gettempdir()) / ".quality_state"
    cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]

    pre_file: Path = cache_dir / f"{cache_key}_pre.json"
    if not pre_file.exists():
        return issues

    try:
        pre_state = json.loads(pre_file.read_text())
        try:
            current_content = Path(file_path).read_text()
        except OSError:
            return issues  # Can't compare if the file can't be read

        current_lines = len(current_content.split("\n"))
        current_functions = current_content.count("def ")

        # Check for significant changes
        if current_functions < pre_state.get("functions", 0):
            issues.append(
                f"⚠️ Reduced functions: {pre_state['functions']} → {current_functions}",
            )

        if current_lines > pre_state.get("lines", 0) * 1.5:
            issues.append(
                "⚠️ File size increased significantly: "
                f"{pre_state['lines']} → {current_lines} lines",
            )
    except Exception:  # noqa: BLE001
        logging.debug("Could not analyze state changes for %s", file_path)

    return issues


def check_cross_file_duplicates(file_path: str, config: QualityConfig) -> list[str]:
    """Check for duplicates across project files."""
    issues: list[str] = []

    # Walk upward until a project marker is found
    project_root: Path = Path(file_path).parent
    while (
        project_root.parent != project_root
        and not (project_root / ".git").exists()
        and not (project_root / "pyproject.toml").exists()
    ):
        project_root = project_root.parent

    try:
        claude_quality_cmd = get_claude_quality_command()

        # Prepare virtual environment for subprocess
        venv_bin = _get_project_venv_bin()
        env = os.environ.copy()
        env["VIRTUAL_ENV"] = str(venv_bin.parent)
        env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
        env.pop("PYTHONHOME", None)

        result = subprocess.run(  # noqa: S603
            [
                *claude_quality_cmd,
                "duplicates",
                str(project_root),
                "--threshold",
                str(config.duplicate_threshold),
                "--format",
                "json",
            ],
            check=False,
            capture_output=True,
            text=True,
            timeout=60,
            env=env,
        )
        if result.returncode == 0:
            data = json.loads(result.stdout)
            duplicates = data.get("duplicates", [])
            if any(file_path in str(d) for d in duplicates):
                issues.append("⚠️ Cross-file duplication detected")
    except Exception:  # noqa: BLE001
        logging.debug("Could not check cross-file duplicates for %s", file_path)

    return issues


def verify_naming_conventions(file_path: str) -> list[str]:
    """Verify PEP8 naming conventions."""
    issues: list[str] = []
    try:
        content = Path(file_path).read_text()
    except OSError:
        return issues  # Can't check naming if the file can't be read

    # Check function names (should be snake_case)
    if bad_funcs := re.findall(
        r"def\s+([A-Z][a-zA-Z0-9_]*|[a-z]+[A-Z][a-zA-Z0-9_]*)\s*\(",
        content,
    ):
        issues.append(f"⚠️ Non-PEP8 function names: {', '.join(bad_funcs[:3])}")

    # Check class names (should be PascalCase)
    if bad_classes := re.findall(r"class\s+([a-z][a-z0-9_]*)\s*[\(:]", content):
        issues.append(f"⚠️ Non-PEP8 class names: {', '.join(bad_classes[:3])}")

    return issues


def _detect_any_usage(content: str) -> list[str]:
    """Detect forbidden typing.Any usage and suggest specific types."""
    # Use the type inference helper to find Any usage with context
    any_usages = TypeInferenceHelper.find_any_usage_with_context(content)

    # Fall back to a line-by-line check when nothing was found (e.g., the
    # content has syntax errors and could not be parsed)
    if not any_usages:
        lines_with_any: set[int] = set()
        for index, line in enumerate(content.splitlines(), start=1):
            code_portion = line.split("#", 1)[0]
            if re.search(r"\bAny\b", code_portion):
                lines_with_any.add(index)

        if lines_with_any:
            sorted_lines = sorted(lines_with_any)
            display_lines = ", ".join(str(num) for num in sorted_lines[:5])
            if len(sorted_lines) > 5:
                display_lines += ", …"
            return [
                f"⚠️ Forbidden typing.Any usage at line(s) {display_lines}; "
                "replace with specific types",
            ]

        return []

    issues: list[str] = []

    # Group by context type
    by_context: dict[str, list[dict[str, str | int]]] = {}
    for usage in any_usages:
        context = str(usage.get("context", "unknown"))
        if context not in by_context:
            by_context[context] = []
        by_context[context].append(usage)

    # Format enriched messages for each context type
    for context, usages_list in by_context.items():
        lines = [str(u.get("line", "?")) for u in usages_list[:5]]
        line_summary = ", ".join(lines)
        if len(usages_list) > 5:
            line_summary += ", ..."

        # Get suggestions
        suggestions: list[str] = []
        for usage in usages_list[:3]:  # Show first 3 suggestions
            element = str(usage.get("element", ""))
            suggested = str(usage.get("suggested", ""))
            if suggested and suggested not in {"Any", "Infer from usage"}:
                suggestions.append(f"  • {element}: {suggested}")

        parts: list[str] = [
            f"⚠️ Forbidden typing.Any Usage ({context})",
            f"📍 Lines: {line_summary}",
        ]

        if suggestions:
            parts.append("💡 Suggested Types:")
            parts.extend(suggestions)
        else:
            parts.append("💡 Tip: Replace `Any` with specific types based on usage")

        parts.extend(
            [
                "",
                "🔗 Common Replacements:",
                "  • dict[str, Any] → dict[str, int] (or appropriate value type)",
                "  • list[Any] → list[str] (or appropriate element type)",
                (
                    "  • Callable[..., Any] → Callable[[int, str], bool] "
                    "(with specific signature)"
                ),
            ]
        )

        issues.append("\n".join(parts))

    return issues


def _detect_type_ignore_usage(content: str) -> list[str]:
    """Detect forbidden # type: ignore usage in proposed content."""
    pattern = re.compile(r"#\s*type:\s*ignore(?:\b|\[)", re.IGNORECASE)
    lines_with_type_ignore: set[int] = set()

    try:
        # Dedent the content to handle code fragments with leading indentation
        dedented_content = textwrap.dedent(content)
        for token_type, token_string, start, _, _ in tokenize.generate_tokens(
            StringIO(dedented_content).readline,
        ):
            if token_type == tokenize.COMMENT and pattern.search(token_string):
                lines_with_type_ignore.add(start[0])
    except (tokenize.TokenError, IndentationError):
        # Fall back to a plain text scan when tokenization fails
        for index, line in enumerate(content.splitlines(), start=1):
            if pattern.search(line):
                lines_with_type_ignore.add(index)

    if not lines_with_type_ignore:
        return []

    sorted_lines = sorted(lines_with_type_ignore)
    display_lines = ", ".join(str(num) for num in sorted_lines[:5])
    if len(sorted_lines) > 5:
        display_lines += ", …"

    return [
        "⚠️ Forbidden # type: ignore usage at line(s) "
        f"{display_lines}; remove the suppression and fix typing issues instead",
    ]


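# Pattern sketch (inputs invented for illustration): the regex matches bare
# and rule-scoped suppressions, case-insensitively; when the content
# tokenizes, only COMMENT tokens are checked, so string literals are not
# flagged (the plain-text fallback is less precise).
#
#     x = f()  # type: ignore              -> flagged
#     y = g()  # type: ignore[assignment]  -> flagged
#     z = h()  # TYPE: IGNORE              -> flagged (re.IGNORECASE)
#     s = "# type: ignore"                 -> not flagged (string, not comment)

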
def _detect_old_typing_patterns(content: str) -> list[str]:
    """Detect old typing patterns that should use modern syntax."""
    # Old typing imports and annotations that should be replaced
    old_patterns = {
        r"\bfrom typing import.*\bUnion\b": (
            "Use | syntax instead of Union (e.g., str | int)"
        ),
        r"\bfrom typing import.*\bOptional\b": (
            "Use | None syntax instead of Optional (e.g., str | None)"
        ),
        r"\bfrom typing import.*\bList\b": "Use list[T] instead of List[T]",
        r"\bfrom typing import.*\bDict\b": "Use dict[K, V] instead of Dict[K, V]",
        r"\bfrom typing import.*\bSet\b": "Use set[T] instead of Set[T]",
        r"\bfrom typing import.*\bTuple\b": (
            "Use tuple[T, ...] instead of Tuple[T, ...]"
        ),
        r"\bUnion\s*\[": "Use | syntax instead of Union (e.g., str | int)",
        r"\bOptional\s*\[": "Use | None syntax instead of Optional (e.g., str | None)",
        r"\bList\s*\[": "Use list[T] instead of List[T]",
        r"\bDict\s*\[": "Use dict[K, V] instead of Dict[K, V]",
        r"\bSet\s*\[": "Use set[T] instead of Set[T]",
        r"\bTuple\s*\[": "Use tuple[T, ...] instead of Tuple[T, ...]",
    }

    lines = content.splitlines()
    found_issues = []

    for pattern, message in old_patterns.items():
        lines_with_pattern = []
        for i, line in enumerate(lines, 1):
            # Skip comments
            code_part = line.split("#")[0]
            if re.search(pattern, code_part):
                lines_with_pattern.append(i)

        if lines_with_pattern:
            display_lines: str = ", ".join(str(num) for num in lines_with_pattern[:5])
            if len(lines_with_pattern) > 5:
                display_lines += ", …"
            issue_text: str = (
                f"⚠️ Old typing pattern at line(s) {display_lines}: {message}"
            )
            found_issues.append(issue_text)

    return found_issues


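# Before/after sketch (illustrative): each flagged pattern has a direct
# PEP 604 / PEP 585 replacement, so no typing import is needed at all.
#
#     from typing import Dict, Optional                    # flagged
#     def load(path: str) -> Optional[Dict[str, int]]: ... # flagged
#
#     def load(path: str) -> dict[str, int] | None: ...    # modern equivalent

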
def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
    """Detect files and functions/classes with suspicious adjective/adverb suffixes."""
    issues: list[str] = []

    # Check the file name against other files in the same directory
    file_path_obj = Path(file_path)
    if file_path_obj.parent.exists():
        file_stem = file_path_obj.stem
        file_suffix = file_path_obj.suffix

        # Check if the current file has a suspicious suffix
        for suffix in SUSPICIOUS_SUFFIXES:
            for separator in ("_", "-"):
                suffix_token = f"{separator}{suffix}"
                if not file_stem.endswith(suffix_token):
                    continue

                base_name = file_stem[: -len(suffix_token)]
                potential_original = file_path_obj.parent / f"{base_name}{file_suffix}"
                if potential_original.exists() and potential_original != file_path_obj:
                    message = FILE_SUFFIX_DUPLICATE_MSG.format(
                        current=file_path_obj.name,
                        original=potential_original.name,
                    )
                    issues.append(message)
                    break
            else:
                continue
            break

        # Check if any existing files are suffixed versions of the current file
        for existing_file in file_path_obj.parent.glob(f"{file_stem}_*{file_suffix}"):
            if existing_file != file_path_obj:
                existing_stem = existing_file.stem
                if existing_stem.startswith(f"{file_stem}_"):
                    potential_suffix = existing_stem[len(file_stem) + 1 :]
                    if potential_suffix in SUSPICIOUS_SUFFIXES:
                        message = EXISTING_FILE_DUPLICATE_MSG.format(
                            current=file_path_obj.name,
                            existing=existing_file.name,
                        )
                        issues.append(message)
                        break

        # Same check for dash-separated suffixes
        for existing_file in file_path_obj.parent.glob(f"{file_stem}-*{file_suffix}"):
            if existing_file != file_path_obj:
                existing_stem = existing_file.stem
                if existing_stem.startswith(f"{file_stem}-"):
                    potential_suffix = existing_stem[len(file_stem) + 1 :]
                    if potential_suffix in SUSPICIOUS_SUFFIXES:
                        message = EXISTING_FILE_DUPLICATE_MSG.format(
                            current=file_path_obj.name,
                            existing=existing_file.name,
                        )
                        issues.append(message)
                        break

    # Check function and class names in content
    try:
        # Dedent the content to handle code fragments with leading indentation
        tree = ast.parse(textwrap.dedent(content))

        class SuffixVisitor(ast.NodeVisitor):
            def __init__(self) -> None:
                self.function_names: set[str] = set()
                self.class_names: set[str] = set()

            def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
                self.function_names.add(node.name)
                self.generic_visit(node)

            def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
                self.function_names.add(node.name)
                self.generic_visit(node)

            def visit_ClassDef(self, node: ast.ClassDef) -> None:
                self.class_names.add(node.name)
                self.generic_visit(node)

        visitor = SuffixVisitor()
        visitor.visit(tree)

        # Check for suspicious function name patterns
        for func_name in visitor.function_names:
            for suffix in SUSPICIOUS_SUFFIXES:
                suffix_token = f"_{suffix}"
                if not func_name.endswith(suffix_token):
                    continue

                base_name = func_name[: -len(suffix_token)]
                if base_name in visitor.function_names:
                    message = NAME_SUFFIX_DUPLICATE_MSG.format(
                        kind="Function",
                        name=func_name,
                        base=base_name,
                    )
                    issues.append(message)
                    break

        # Check for suspicious class name patterns
        for class_name in visitor.class_names:
            for suffix in SUSPICIOUS_SUFFIXES:
                pascal_suffix = suffix.capitalize()
                snake_suffix = f"_{suffix}"

                potential_matches = (
                    (pascal_suffix, class_name[: -len(pascal_suffix)]),
                    (snake_suffix, class_name[: -len(snake_suffix)]),
                )

                for token, base_name in potential_matches:
                    if (
                        token
                        and class_name.endswith(token)
                        and base_name in visitor.class_names
                    ):
                        message = NAME_SUFFIX_DUPLICATE_MSG.format(
                            kind="Class",
                            name=class_name,
                            base=base_name,
                        )
                        issues.append(message)
                        break
                else:
                    continue
                break

    except SyntaxError:
        # If the content doesn't parse, skip the function/class checks
        pass

    return issues


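# Detection sketch (names invented for illustration): a suffixed sibling of
# an existing definition is flagged, while an unrelated name is not.
#
#     def parse(data): ...
#     def parse_improved(data): ...  # flagged: suffixed duplicate of parse()
#     def parse_json(data): ...      # fine: "json" is not a suspicious suffix

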
def _perform_quality_check(
    file_path: str,
    content: str,
    config: QualityConfig,
    enable_type_checks: bool = True,
) -> tuple[bool, list[str]]:
    """Perform quality analysis and return issues."""
    # Store state if tracking enabled
    if config.state_tracking_enabled:
        store_pre_state(file_path, content)

    # Run quality analysis
    results = analyze_code_quality(
        content,
        file_path,
        config,
        enable_type_checks=enable_type_checks,
    )
    return check_code_issues(results, config, content, file_path)


def _handle_quality_issues(
    file_path: str,
    issues: list[str],
    config: QualityConfig,
    *,
    forced_permission: str | None = None,
) -> JsonObject:
    """Handle quality issues based on enforcement mode."""
    # Prepare the denial message with formatted issues
    formatted_issues: list[str] = []
    for issue in issues:
        # Indent multi-line issues for better readability
        if "\n" in issue:
            lines = issue.split("\n")
            formatted_issues.append(f"• {lines[0]}")
            formatted_issues.extend(f"  {line}" for line in lines[1:])
        else:
            formatted_issues.append(f"• {issue}")

    message = (
        f"Code quality check failed for {Path(file_path).name}:\n"
        + "\n".join(formatted_issues)
        + "\n\nFix these issues before writing the code."
    )

    # Make the decision based on enforcement mode
    if forced_permission:
        return _create_hook_response("PreToolUse", forced_permission, message)

    if config.enforcement_mode == "strict":
        return _create_hook_response("PreToolUse", "deny", message)
    if config.enforcement_mode == "warn":
        return _create_hook_response("PreToolUse", "ask", message)
    # permissive
    warning_message = f"⚠️ Quality Warning:\n{message}"
    return _create_hook_response(
        "PreToolUse",
        "allow",
        warning_message,
        warning_message,
    )


def _exit_with_reason(reason: str, exit_code: int = 2) -> None:
    """Write reason to stderr and exit with the specified code."""
    sys.stderr.write(reason)
    sys.exit(exit_code)


def _create_hook_response(
    event_name: str,
    permission: str = "",
    reason: str = "",
    system_message: str = "",
    additional_context: str = "",
    *,
    decision: str | None = None,
) -> JsonObject:
    """Create a standardized hook response."""
    hook_output: dict[str, object] = {
        "hookEventName": event_name,
    }

    if permission:
        hook_output["permissionDecision"] = permission
    if reason:
        hook_output["permissionDecisionReason"] = reason

    if additional_context:
        hook_output["additionalContext"] = additional_context

    response: JsonObject = {
        "hookSpecificOutput": hook_output,
    }

    if permission:
        response["permissionDecision"] = permission

    if decision:
        response["decision"] = decision

    if reason:
        response["reason"] = reason

    if system_message:
        response["systemMessage"] = system_message

    return response


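# Response sketch (field values illustrative): a strict-mode denial serializes
# to the dual top-level/hookSpecificOutput shape built above.
#
#     _create_hook_response("PreToolUse", "deny", "Code quality check failed")
#     # {
#     #   "hookSpecificOutput": {
#     #     "hookEventName": "PreToolUse",
#     #     "permissionDecision": "deny",
#     #     "permissionDecisionReason": "Code quality check failed"
#     #   },
#     #   "permissionDecision": "deny",
#     #   "reason": "Code quality check failed"
#     # }

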
def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject:
|
|
"""Handle PreToolUse hook - analyze content before write/edit."""
|
|
tool_name = str(hook_data.get("tool_name", ""))
|
|
tool_input_raw = hook_data.get("tool_input", {})
|
|
if not isinstance(tool_input_raw, dict):
|
|
return _create_hook_response("PreToolUse", "allow")
|
|
tool_input = cast("dict[str, object]", tool_input_raw)
|
|
|
|
# Only analyze for write/edit tools
|
|
if tool_name not in ["Write", "Edit", "MultiEdit"]:
|
|
return _create_hook_response("PreToolUse", "allow")
|
|
|
|
# Extract content based on tool type
|
|
file_path = str(tool_input.get("file_path", ""))
|
|
content = ""
|
|
|
|
if tool_name == "Write":
|
|
raw_content = tool_input.get("content", "")
|
|
content = "" if raw_content is None else str(raw_content)
|
|
elif tool_name == "Edit":
|
|
new_string = tool_input.get("new_string", "")
|
|
content = "" if new_string is None else str(new_string)
|
|
elif tool_name == "MultiEdit":
|
|
edits = tool_input.get("edits", [])
|
|
if isinstance(edits, list):
|
|
edits_list = cast("list[object]", edits)
|
|
parts: list[str] = []
|
|
for edit in edits_list:
|
|
if not isinstance(edit, dict):
|
|
continue
|
|
edit_dict = cast("dict[str, object]", edit)
|
|
new_str = edit_dict.get("new_string")
|
|
parts.append("") if new_str is None else parts.append(str(new_str))
|
|
content = "\n".join(parts)

    # Only analyze Python files
    if not file_path or not file_path.endswith(".py") or not content:
        return _create_hook_response("PreToolUse", "allow")

    # Check if this is a test file and test quality checks are enabled
    is_test = is_test_file(file_path)
    run_test_checks = config.test_quality_enabled and is_test

    enable_type_checks = tool_name == "Write"

    # Always run core checks (Any, type: ignore, typing, duplicates) before skipping
    any_usage_issues = _detect_any_usage(content)
    type_ignore_issues = _detect_type_ignore_usage(content)
    old_typing_issues = _detect_old_typing_patterns(content)
    suffix_duplication_issues = _detect_suffix_duplication(file_path, content)
    precheck_issues = (
        any_usage_issues
        + type_ignore_issues
        + old_typing_issues
        + suffix_duplication_issues
    )

    # Run test quality checks if enabled and file is a test file
    if run_test_checks:
        test_quality_issues = run_test_quality_checks(content, file_path, config)
        precheck_issues.extend(test_quality_issues)

    # Skip detailed analysis for configured patterns unless test checks should run
    # Note: Core quality checks (Any, type: ignore, duplicates) always run above
    should_skip_detailed = should_skip_file(file_path, config) and not run_test_checks

    try:
        # Run detailed quality checks only if not skipping
        if should_skip_detailed:
            issues = []
        else:
            _has_issues, issues = _perform_quality_check(
                file_path,
                content,
                config,
                enable_type_checks=enable_type_checks,
            )

        all_issues = precheck_issues + issues

        if not all_issues:
            return _create_hook_response("PreToolUse", "allow")

        if precheck_issues:
            return _handle_quality_issues(
                file_path,
                all_issues,
                config,
                forced_permission="deny",
            )

        return _handle_quality_issues(file_path, all_issues, config)
    except Exception as e:  # noqa: BLE001
        if precheck_issues:
            return _handle_quality_issues(
                file_path,
                precheck_issues,
                config,
                forced_permission="deny",
            )
        return _create_hook_response(
            "PreToolUse",
            "allow",
            f"Warning: Code quality check failed with error: {e}",
            f"Warning: Code quality check failed with error: {e}",
        )


def posttooluse_hook(
    hook_data: JsonObject,
    config: QualityConfig,
) -> JsonObject:
    """Handle PostToolUse hook - verify quality after write/edit."""
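    # Illustrative only - a minimal PostToolUse payload (hypothetical values);
    # "tool_response" is preferred, "tool_output" is accepted as a fallback:
    #   {
    #       "tool_name": "Write",
    #       "tool_response": {"file_path": "src/example.py"}
    #   }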
    tool_name: str = str(hook_data.get("tool_name", ""))
    tool_output = hook_data.get("tool_response", hook_data.get("tool_output", {}))

    # Only process write/edit tools
    if tool_name not in ["Write", "Edit", "MultiEdit"]:
        return _create_hook_response("PostToolUse")

    # Extract file path from output
    file_path: str = ""
    if isinstance(tool_output, dict):
        tool_output_dict = cast("dict[str, object]", tool_output)
        file_path = str(
            tool_output_dict.get("file_path", "") or tool_output_dict.get("path", ""),
        )
    elif isinstance(tool_output, str) and (
        match := re.search(r"([/\w\-_.]+\.py)", tool_output)
    ):
        file_path = match[1]
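        # Illustrative: for a hypothetical string output such as
        # "Wrote /tmp/demo.py successfully", the regex captures "/tmp/demo.py".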

    if not file_path or not file_path.endswith(".py"):
        return _create_hook_response("PostToolUse")

    if not Path(file_path).exists():
        return _create_hook_response("PostToolUse")

    issues: list[str] = []

    # Read entire file content for full analysis
    try:
        with open(file_path, encoding="utf-8") as f:
            file_content = f.read()
    except (OSError, UnicodeDecodeError):
        # If we can't read the file, skip full content analysis
        file_content = ""

    # Run full file quality checks on the entire content
    if file_content:
        any_usage_issues = _detect_any_usage(file_content)
        type_ignore_issues = _detect_type_ignore_usage(file_content)
        issues.extend(any_usage_issues)
        issues.extend(type_ignore_issues)

    # Check state changes if tracking enabled
    if config.state_tracking_enabled:
        delta_issues = check_state_changes(file_path)
        issues.extend(delta_issues)

    # Run cross-file duplicate detection if enabled
    if config.cross_file_check_enabled:
        cross_file_issues = check_cross_file_duplicates(file_path, config)
        issues.extend(cross_file_issues)

    # Verify naming conventions if enabled
    if config.verify_naming:
        naming_issues = verify_naming_conventions(file_path)
        issues.extend(naming_issues)

    # Format response
    if issues:
        message = (
            f"📝 Post-write quality notes for {Path(file_path).name}:\n"
            + "\n".join(issues)
        )
        return _create_hook_response(
            "PostToolUse",
            "",
            message,
            message,
            message,
            decision="block",
        )
    if config.show_success:
        message = f"✅ {Path(file_path).name} passed post-write verification"
        return _create_hook_response(
            "PostToolUse",
            "",
            "",
            message,
            "",
            decision="approve",
        )

    return _create_hook_response("PostToolUse")


def is_test_file(file_path: str) -> bool:
    """Check if file path is in a test directory."""
    path_parts = Path(file_path).parts
    return any(part in ("test", "tests", "testing") for part in path_parts)
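    # Illustrative: only directory components count, not file names, so
    #   is_test_file("pkg/tests/test_foo.py") -> True
    #   is_test_file("pkg/core/test_utils.py") -> False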


def run_test_quality_checks(
    content: str,
    file_path: str,
    config: QualityConfig,
) -> list[str]:
    """Run Sourcery's test rules and return guidance-enhanced issues."""
    issues: list[str] = []

    # Only run test quality checks for test files
    if not is_test_file(file_path):
        return issues

    suffix = Path(file_path).suffix or ".py"

    # Create temp file in project directory to inherit config files
    tmp_dir = _get_project_tmp_dir(file_path)

    with NamedTemporaryFile(
        mode="w",
        suffix=suffix,
        delete=False,
        dir=str(tmp_dir),
        prefix="test_validation_",
    ) as tmp:
        tmp.write(content)
        tmp_path = tmp.name

    try:
        # Run Sourcery with specific test-related rules
        venv_bin = _get_project_venv_bin()
        sourcery_path = venv_bin / "sourcery"

        if not sourcery_path.exists():
            # Fall back to a sourcery binary on PATH (shutil is imported at
            # module level); keep sourcery_path a Path throughout
            found = shutil.which("sourcery")
            if found:
                sourcery_path = Path(found)

        if not sourcery_path.exists():
            logging.debug("Sourcery not found at %s", sourcery_path)
            return issues

        # Specific rules for test quality - use correct Sourcery format
        test_rules = [
            "no-conditionals-in-tests",
            "no-loop-in-tests",
            "raise-specific-error",
            "dont-import-test-modules",
        ]

        # Build command with --enable for each rule
        cmd = [str(sourcery_path), "review", tmp_path]
        for rule in test_rules:
            cmd.extend(["--enable", rule])
        cmd.append("--check")  # Return exit code 1 if issues found
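        # Illustrative resulting command (temp file name is hypothetical):
        #   sourcery review /proj/tmp/test_validation_abc.py \
        #       --enable no-conditionals-in-tests --enable no-loop-in-tests \
        #       --enable raise-specific-error --enable dont-import-test-modules \
        #       --check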

        # Activate virtual environment for the subprocess
        env = os.environ.copy()
        env["VIRTUAL_ENV"] = str(venv_bin.parent)
        env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
        # Remove any PYTHONHOME that might interfere
        env.pop("PYTHONHOME", None)

        logging.debug("Running Sourcery command: %s", " ".join(cmd))
        result = subprocess.run(  # noqa: S603
            cmd,
            check=False,
            capture_output=True,
            text=True,
            timeout=30,
            env=env,
        )

        logging.debug("Sourcery exit code: %s", result.returncode)
        logging.debug("Sourcery stdout: %s", result.stdout)
        logging.debug("Sourcery stderr: %s", result.stderr)

        # Sourcery with --check returns:
        # - Exit code 0: No issues found
        # - Exit code 1: Issues found
        # - Exit code 2: Error occurred
        if result.returncode == 1:
            # Issues were found - parse the output
            output = result.stdout + result.stderr

            # Try to extract a rule name from the output. Sourcery usually
            # includes rule identifiers in brackets or descriptive text; fall
            # back to general "unknown" guidance if no specific rule matches.
            matched_rule = next(
                (
                    rule
                    for rule in test_rules
                    if rule in output or rule.replace("-", " ") in output.lower()
                ),
                "unknown",
            )
            base_guidance = generate_test_quality_guidance(
                matched_rule,
                content,
                file_path,
                config,
            )
            external_context = get_external_context(
                matched_rule,
                content,
                file_path,
                config,
            )
            if external_context:
                base_guidance += f"\n\n{external_context}"
            # Only add one guidance message per run
            issues.append(base_guidance)
        elif result.returncode == 2:
            # Error occurred - don't block on Sourcery errors, just log them
            logging.debug("Sourcery error: %s", result.stderr)
        # Exit code 0 means no issues - do nothing

    except (subprocess.TimeoutExpired, OSError, json.JSONDecodeError) as e:
        # If Sourcery fails, don't block the operation
        logging.debug("Test quality check failed for %s: %s", file_path, e)
    finally:
        Path(tmp_path).unlink(missing_ok=True)

    return issues


def main() -> None:
    """Main hook entry point."""
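    # Illustrative invocation (script name and payload shape are hypothetical):
    #   echo '{"tool_name": "Write", "tool_input": {...}}' | python quality_hook.py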
    try:
        # Load configuration
        config = QualityConfig.from_env()

        # Read hook input from stdin
        try:
            hook_data: JsonObject = json.load(sys.stdin)
        except json.JSONDecodeError:
            fallback_response: JsonObject = {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "allow",
                },
            }
            sys.stdout.write(json.dumps(fallback_response))
            return

        # Detect hook type: tool_response=PostToolUse, tool_input=PreToolUse
        response: JsonObject
        if "tool_response" in hook_data or "tool_output" in hook_data:
            # PostToolUse hook
            response = posttooluse_hook(hook_data, config)
        else:
            # PreToolUse hook
            response = pretooluse_hook(hook_data, config)

        print(json.dumps(response))  # noqa: T201

        # Handle exit codes based on hook output
        hook_output_raw = response.get("hookSpecificOutput", {})
        if not isinstance(hook_output_raw, dict):
            return

        hook_output = cast("dict[str, object]", hook_output_raw)
        permission_decision = hook_output.get("permissionDecision")

        if permission_decision == "deny":
            # Exit code 2: Blocking error - stderr fed back to Claude
            reason = str(
                hook_output.get("permissionDecisionReason", "Permission denied"),
            )
            _exit_with_reason(reason)
        elif permission_decision == "ask":
            # Also use exit code 2 for ask decisions to ensure Claude sees the message
            reason = str(
                hook_output.get("permissionDecisionReason", "Permission request"),
            )
            _exit_with_reason(reason)
        # Exit code 0: Success (default)

    except Exception as e:  # noqa: BLE001
        # Unexpected error - use exit code 1 (non-blocking error)
        sys.stderr.write(f"Hook error: {e}\n")
        sys.exit(1)


if __name__ == "__main__":
    main()