- Introduced file-based locking in `bash_command_guard.py` to prevent concurrent execution issues. - Added configuration for lock timeout and polling interval in `bash_guard_constants.py`. - Implemented a new `hook_chain.py` to unify hook execution, allowing for sequential processing of guards. - Updated `claude-code-settings.json` to support the new hook chaining mechanism. - Refactored subprocess lock handling to improve reliability and prevent deadlocks. This update improves the robustness of the hook system by ensuring that bash commands are executed in a controlled manner, reducing the risk of concurrency-related errors.
2111 lines
70 KiB
Python
2111 lines
70 KiB
Python
"""Unified quality hook for Claude Code supporting both PreToolUse and PostToolUse.
|
|
|
|
Prevents writing duplicate, complex, or non-modernized code and verifies quality
|
|
after writes.
|
|
"""
|
|
|
|
import ast
|
|
import fcntl
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import textwrap
|
|
import threading
|
|
import tokenize
|
|
from collections.abc import Callable
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from contextlib import contextmanager, suppress
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from importlib import import_module
|
|
from io import StringIO
|
|
from pathlib import Path
|
|
from tempfile import NamedTemporaryFile, gettempdir
|
|
from typing import TYPE_CHECKING, TypedDict, cast
|
|
|
|
# Import message enrichment helpers
|
|
try:
|
|
from .message_enrichment import EnhancedMessageFormatter
|
|
from .type_inference import TypeInferenceHelper
|
|
except ImportError:
|
|
# Fallback for direct execution
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
from message_enrichment import EnhancedMessageFormatter
|
|
from type_inference import TypeInferenceHelper
|
|
|
|
# Import internal duplicate detector; fall back to local path when executed directly
|
|
if TYPE_CHECKING:
|
|
from .internal_duplicate_detector import (
|
|
Duplicate,
|
|
DuplicateResults,
|
|
detect_internal_duplicates,
|
|
)
|
|
else:
|
|
try:
|
|
from .internal_duplicate_detector import (
|
|
Duplicate,
|
|
DuplicateResults,
|
|
detect_internal_duplicates,
|
|
)
|
|
except ImportError:
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
module = import_module("internal_duplicate_detector")
|
|
Duplicate = module.Duplicate
|
|
DuplicateResults = module.DuplicateResults
|
|
detect_internal_duplicates = module.detect_internal_duplicates
|
|
|
|
|
|
# File-based lock for inter-process synchronization
|
|
def _get_lock_file() -> Path:
|
|
"""Get path to lock file for subprocess serialization."""
|
|
lock_dir = Path(gettempdir()) / ".claude_hooks"
|
|
lock_dir.mkdir(exist_ok=True, mode=0o700)
|
|
return lock_dir / "subprocess.lock"
|
|
|
|
|
|
@contextmanager
|
|
def _subprocess_lock(timeout: float = 10.0):
|
|
"""Context manager for file-based subprocess locking.
|
|
|
|
Args:
|
|
timeout: Timeout in seconds for acquiring lock (not used, non-blocking).
|
|
|
|
Yields:
|
|
True if lock was acquired, False if timeout occurred.
|
|
"""
|
|
lock_file = _get_lock_file()
|
|
|
|
# Open or create lock file
|
|
with open(lock_file, "a") as f:
|
|
try:
|
|
# Try to acquire exclusive lock (non-blocking)
|
|
fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
|
|
yield True
|
|
except (IOError, OSError):
|
|
# Lock is held by another process, skip to avoid blocking
|
|
yield False
|
|
finally:
|
|
with suppress(IOError, OSError):
|
|
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
|
|
|
|
|
|
SUSPICIOUS_SUFFIXES: tuple[str, ...] = (
|
|
"enhanced",
|
|
"improved",
|
|
"better",
|
|
"new",
|
|
"updated",
|
|
"modified",
|
|
"refactored",
|
|
"optimized",
|
|
"fixed",
|
|
"clean",
|
|
"simple",
|
|
"advanced",
|
|
"basic",
|
|
"complete",
|
|
"final",
|
|
"latest",
|
|
"current",
|
|
"temp",
|
|
"temporary",
|
|
"backup",
|
|
"old",
|
|
"legacy",
|
|
"unified",
|
|
"merged",
|
|
"combined",
|
|
"integrated",
|
|
"consolidated",
|
|
"extended",
|
|
"enriched",
|
|
"augmented",
|
|
"upgraded",
|
|
"revised",
|
|
"polished",
|
|
"streamlined",
|
|
"simplified",
|
|
"modernized",
|
|
"normalized",
|
|
"sanitized",
|
|
"validated",
|
|
"verified",
|
|
"corrected",
|
|
"patched",
|
|
"stable",
|
|
"experimental",
|
|
"alpha",
|
|
"beta",
|
|
"draft",
|
|
"preliminary",
|
|
"prototype",
|
|
"working",
|
|
"test",
|
|
"debug",
|
|
"custom",
|
|
"special",
|
|
"generic",
|
|
"specific",
|
|
"general",
|
|
"detailed",
|
|
"minimal",
|
|
"full",
|
|
"partial",
|
|
"quick",
|
|
"fast",
|
|
"slow",
|
|
"smart",
|
|
"intelligent",
|
|
"auto",
|
|
"manual",
|
|
"secure",
|
|
"safe",
|
|
"robust",
|
|
"flexible",
|
|
"dynamic",
|
|
"static",
|
|
"reactive",
|
|
"async",
|
|
"sync",
|
|
"parallel",
|
|
"serial",
|
|
"distributed",
|
|
"centralized",
|
|
"decentralized",
|
|
)
|
|
|
|
FILE_SUFFIX_DUPLICATE_MSG = (
|
|
"⚠️ File '{current}' appears to be a suffixed duplicate of '{original}'. "
|
|
"Consider refactoring instead of creating variations with adjective "
|
|
"suffixes."
|
|
)
|
|
|
|
EXISTING_FILE_DUPLICATE_MSG = (
|
|
"⚠️ Creating '{current}' when '{existing}' already exists suggests "
|
|
"duplication. Consider consolidating or using a more descriptive name."
|
|
)
|
|
|
|
NAME_SUFFIX_DUPLICATE_MSG = (
|
|
"⚠️ {kind} '{name}' appears to be a suffixed duplicate of '{base}'. "
|
|
"Consider refactoring instead of creating variations."
|
|
)
|
|
|
|
|
|
def get_external_context(
|
|
rule_id: str,
|
|
content: str,
|
|
_file_path: str,
|
|
config: "QualityConfig",
|
|
) -> str:
|
|
"""Fetch additional guidance from optional integrations."""
|
|
context_parts: list[str] = []
|
|
|
|
if config.context7_enabled and config.context7_api_key:
|
|
try:
|
|
context7_context = _get_context7_analysis(
|
|
rule_id,
|
|
content,
|
|
config.context7_api_key,
|
|
)
|
|
except (OSError, RuntimeError, ValueError, TimeoutError) as exc:
|
|
logging.debug("Context7 API call failed: %s", exc)
|
|
else:
|
|
if context7_context:
|
|
context_parts.append(f"📊 Context7 Analysis: {context7_context}")
|
|
|
|
if config.firecrawl_enabled and config.firecrawl_api_key:
|
|
try:
|
|
firecrawl_examples = _get_firecrawl_examples(
|
|
rule_id,
|
|
config.firecrawl_api_key,
|
|
)
|
|
except (OSError, RuntimeError, ValueError, TimeoutError) as exc:
|
|
logging.debug("Firecrawl API call failed: %s", exc)
|
|
else:
|
|
if firecrawl_examples:
|
|
context_parts.append(f"🔗 Additional Examples: {firecrawl_examples}")
|
|
|
|
return "\n\n".join(context_parts)
|
|
|
|
|
|
def _get_context7_analysis(
|
|
rule_id: str,
|
|
_content: str,
|
|
_api_key: str,
|
|
) -> str:
|
|
"""Placeholder for Context7 API integration."""
|
|
context_map = {
|
|
"no-conditionals-in-tests": (
|
|
"Use pytest.mark.parametrize instead of branching inside tests."
|
|
),
|
|
"no-loop-in-tests": (
|
|
"Split loops into individual tests or parameterize the data for clarity."
|
|
),
|
|
"raise-specific-error": (
|
|
"Raise precise exception types (ValueError, custom errors, etc.) "
|
|
"so failures highlight the right behaviour."
|
|
),
|
|
"dont-import-test-modules": (
|
|
"Production code should not depend on test helpers; "
|
|
"extract shared logic into utility modules."
|
|
),
|
|
}
|
|
|
|
return context_map.get(
|
|
rule_id,
|
|
"Context7 guidance will appear once the integration is available.",
|
|
)
|
|
|
|
|
|
def _get_firecrawl_examples(rule_id: str, _api_key: str) -> str:
|
|
"""Placeholder for Firecrawl API integration."""
|
|
examples_map = {
|
|
"no-conditionals-in-tests": (
|
|
"Pytest parameterization tutorial: "
|
|
"docs.pytest.org/en/latest/how-to/parametrize.html"
|
|
),
|
|
"no-loop-in-tests": (
|
|
"Testing best practices: keep scenarios in separate "
|
|
"tests for precise failures."
|
|
),
|
|
"raise-specific-error": (
|
|
"Python docs on exceptions: docs.python.org/3/tutorial/errors.html"
|
|
),
|
|
"dont-import-test-modules": (
|
|
"Clean architecture tip: production modules should not import from tests."
|
|
),
|
|
}
|
|
|
|
return examples_map.get(
|
|
rule_id,
|
|
"Firecrawl examples will appear once the integration is available.",
|
|
)
|
|
|
|
|
|
def generate_test_quality_guidance(
|
|
rule_id: str,
|
|
content: str,
|
|
_file_path: str,
|
|
_config: "QualityConfig",
|
|
) -> str:
|
|
"""Return enriched guidance for test quality rule violations."""
|
|
match = re.search(r"def\s+(\w+)\s*\(", content)
|
|
function_name = match[1] if match else "test_function"
|
|
# Extract a small snippet of the violating code
|
|
code_snippet = ""
|
|
if match:
|
|
# Try to get a few lines around the function definition
|
|
lines = content.splitlines()
|
|
for i, line in enumerate(lines):
|
|
if f"def {function_name}" in line:
|
|
snippet_start = max(0, i)
|
|
snippet_end = min(len(lines), i + 10) # Show first 10 lines
|
|
code_snippet = "\n".join(lines[snippet_start:snippet_end])
|
|
break
|
|
|
|
# Use enhanced formatter for rich test quality messages
|
|
return EnhancedMessageFormatter.format_test_quality_message(
|
|
rule_id=rule_id,
|
|
function_name=function_name,
|
|
code_snippet=code_snippet,
|
|
include_examples=True,
|
|
)
|
|
|
|
|
|
class ToolConfig(TypedDict):
|
|
"""Configuration for a type checking tool."""
|
|
|
|
args: list[str]
|
|
error_check: Callable[[subprocess.CompletedProcess[str]], bool]
|
|
error_message: str | Callable[[subprocess.CompletedProcess[str]], str]
|
|
|
|
|
|
class ComplexitySummary(TypedDict):
|
|
"""Summary of complexity analysis."""
|
|
|
|
average_cyclomatic_complexity: float
|
|
|
|
|
|
class ComplexityDistribution(TypedDict):
|
|
"""Distribution of complexity levels."""
|
|
|
|
High: int
|
|
Very_High: int
|
|
Extreme: int
|
|
|
|
|
|
class ComplexityResults(TypedDict):
|
|
"""Results from complexity analysis."""
|
|
|
|
summary: ComplexitySummary
|
|
distribution: ComplexityDistribution
|
|
|
|
|
|
class TypeCheckingResults(TypedDict):
|
|
"""Results from type checking analysis."""
|
|
|
|
issues: list[str]
|
|
|
|
|
|
class AnalysisResults(TypedDict, total=False):
|
|
"""Complete analysis results from quality checks."""
|
|
|
|
internal_duplicates: DuplicateResults
|
|
complexity: ComplexityResults
|
|
type_checking: TypeCheckingResults
|
|
modernization: dict[str, object] # JSON structure varies
|
|
|
|
|
|
# Type aliases for JSON-like structures
|
|
JsonObject = dict[str, object]
|
|
JsonValue = str | int | float | bool | None | list[object] | JsonObject
|
|
|
|
|
|
@dataclass
|
|
class QualityConfig:
|
|
"""Configuration for quality checks."""
|
|
|
|
# Core settings
|
|
duplicate_threshold: float = 0.7
|
|
duplicate_enabled: bool = True
|
|
complexity_threshold: int = 10
|
|
complexity_enabled: bool = True
|
|
modernization_enabled: bool = True
|
|
require_type_hints: bool = True
|
|
enforcement_mode: str = "strict" # strict/warn/permissive
|
|
|
|
# Type checking tools
|
|
sourcery_enabled: bool = True
|
|
basedpyright_enabled: bool = True
|
|
pyrefly_enabled: bool = True
|
|
type_check_exit_code: int = 2
|
|
|
|
# PostToolUse features
|
|
state_tracking_enabled: bool = False
|
|
cross_file_check_enabled: bool = False
|
|
verify_naming: bool = True
|
|
show_success: bool = False
|
|
|
|
# Test quality checks
|
|
test_quality_enabled: bool = True
|
|
|
|
# External context providers
|
|
context7_enabled: bool = False
|
|
context7_api_key: str = ""
|
|
firecrawl_enabled: bool = False
|
|
firecrawl_api_key: str = ""
|
|
|
|
# File patterns
|
|
skip_patterns: list[str] | None = None
|
|
|
|
def __post_init__(self) -> None:
|
|
if self.skip_patterns is None:
|
|
self.skip_patterns = ["test_", "_test.py", "/tests/", "/fixtures/"]
|
|
|
|
@classmethod
|
|
def from_env(cls) -> "QualityConfig":
|
|
"""Load config from environment variables."""
|
|
return cls(
|
|
duplicate_threshold=float(os.getenv("QUALITY_DUP_THRESHOLD", "0.7")),
|
|
duplicate_enabled=os.getenv("QUALITY_DUP_ENABLED", "true").lower()
|
|
== "true",
|
|
complexity_threshold=int(os.getenv("QUALITY_COMPLEXITY_THRESHOLD", "10")),
|
|
complexity_enabled=os.getenv("QUALITY_COMPLEXITY_ENABLED", "true").lower()
|
|
== "true",
|
|
modernization_enabled=os.getenv("QUALITY_MODERN_ENABLED", "true").lower()
|
|
== "true",
|
|
require_type_hints=os.getenv("QUALITY_REQUIRE_TYPES", "true").lower()
|
|
== "true",
|
|
enforcement_mode=os.getenv("QUALITY_ENFORCEMENT", "strict"),
|
|
state_tracking_enabled=os.getenv("QUALITY_STATE_TRACKING", "false").lower()
|
|
== "true",
|
|
cross_file_check_enabled=os.getenv(
|
|
"QUALITY_CROSS_FILE_CHECK",
|
|
"false",
|
|
).lower()
|
|
== "true",
|
|
verify_naming=os.getenv("QUALITY_VERIFY_NAMING", "true").lower() == "true",
|
|
show_success=os.getenv("QUALITY_SHOW_SUCCESS", "false").lower() == "true",
|
|
sourcery_enabled=os.getenv("QUALITY_SOURCERY_ENABLED", "true").lower()
|
|
== "true",
|
|
basedpyright_enabled=os.getenv(
|
|
"QUALITY_BASEDPYRIGHT_ENABLED",
|
|
"true",
|
|
).lower()
|
|
== "true",
|
|
pyrefly_enabled=os.getenv("QUALITY_PYREFLY_ENABLED", "true").lower()
|
|
== "true",
|
|
type_check_exit_code=int(os.getenv("QUALITY_TYPE_CHECK_EXIT_CODE", "2")),
|
|
test_quality_enabled=(
|
|
os.getenv("QUALITY_TEST_QUALITY_ENABLED", "true").lower() == "true"
|
|
),
|
|
context7_enabled=(
|
|
os.getenv("QUALITY_CONTEXT7_ENABLED", "false").lower() == "true"
|
|
),
|
|
context7_api_key=os.getenv("QUALITY_CONTEXT7_API_KEY", ""),
|
|
firecrawl_enabled=(
|
|
os.getenv("QUALITY_FIRECRAWL_ENABLED", "false").lower() == "true"
|
|
),
|
|
firecrawl_api_key=os.getenv("QUALITY_FIRECRAWL_API_KEY", ""),
|
|
)
|
|
|
|
|
|
def should_skip_file(file_path: str, config: QualityConfig) -> bool:
|
|
"""Check if file should be skipped based on patterns."""
|
|
if config.skip_patterns is None:
|
|
return False
|
|
return any(pattern in file_path for pattern in config.skip_patterns)
|
|
|
|
|
|
def _module_candidate(path: Path) -> tuple[Path, list[str]]:
|
|
"""Build a module invocation candidate for a python executable."""
|
|
return path, [str(path), "-m", "quality.cli.main"]
|
|
|
|
|
|
def _cli_candidate(path: Path) -> tuple[Path, list[str]]:
|
|
"""Build a direct CLI invocation candidate."""
|
|
return path, [str(path)]
|
|
|
|
|
|
def get_claude_quality_command(repo_root: Path | None = None) -> list[str]:
|
|
"""Return a path-resilient command for invoking claude-quality."""
|
|
repo_root = repo_root or Path(__file__).resolve().parent.parent
|
|
platform_name = sys.platform
|
|
is_windows = platform_name.startswith("win")
|
|
|
|
scripts_dir = repo_root / ".venv" / ("Scripts" if is_windows else "bin")
|
|
python_names = (
|
|
["python.exe", "python3.exe"] if is_windows else ["python", "python3"]
|
|
)
|
|
cli_names = (
|
|
["claude-quality.exe", "claude-quality"] if is_windows else ["claude-quality"]
|
|
)
|
|
|
|
candidates: list[tuple[Path, list[str]]] = [
|
|
_module_candidate(scripts_dir / name) for name in python_names
|
|
]
|
|
candidates.extend(_cli_candidate(scripts_dir / name) for name in cli_names)
|
|
for candidate_path, command in candidates:
|
|
if candidate_path.exists():
|
|
return command
|
|
|
|
interpreter_fallbacks = ["python"] if is_windows else ["python3", "python"]
|
|
for interpreter in interpreter_fallbacks:
|
|
if shutil.which(interpreter):
|
|
return [interpreter, "-m", "quality.cli.main"]
|
|
|
|
if shutil.which("claude-quality"):
|
|
return ["claude-quality"]
|
|
|
|
message = (
|
|
"'claude-quality' was not found on PATH. Please ensure it is installed and "
|
|
"available."
|
|
)
|
|
raise RuntimeError(message)
|
|
|
|
|
|
def _get_project_venv_bin(file_path: str | None = None) -> Path:
|
|
"""Get the virtual environment bin directory for the current project.
|
|
|
|
Args:
|
|
file_path: Optional file path to determine project root from.
|
|
If not provided, uses current working directory.
|
|
"""
|
|
# Start from the file's directory if provided, otherwise from cwd
|
|
temp_root = Path(gettempdir()).resolve()
|
|
if file_path:
|
|
resolved_path = Path(file_path).resolve()
|
|
try:
|
|
if resolved_path.is_relative_to(temp_root):
|
|
start_path = Path.cwd()
|
|
else:
|
|
start_path = resolved_path.parent
|
|
except ValueError:
|
|
start_path = resolved_path.parent
|
|
else:
|
|
start_path = Path.cwd()
|
|
|
|
current = start_path
|
|
|
|
while current != current.parent:
|
|
venv_candidate = current / ".venv"
|
|
if venv_candidate.exists() and venv_candidate.is_dir():
|
|
bin_dir = venv_candidate / "bin"
|
|
if bin_dir.exists():
|
|
return bin_dir
|
|
current = current.parent
|
|
|
|
# Fallback to claude-scripts venv if no project venv found
|
|
return Path(__file__).parent.parent / ".venv" / "bin"
|
|
|
|
|
|
def _format_basedpyright_errors(json_output: str) -> str:
|
|
"""Format basedpyright JSON output into readable error messages."""
|
|
try:
|
|
data = json.loads(json_output)
|
|
diagnostics = data.get("generalDiagnostics", [])
|
|
|
|
if not diagnostics:
|
|
return "Type errors found (no details available)"
|
|
|
|
# Group by severity and format
|
|
errors: list[str] = []
|
|
for diag in diagnostics[:10]: # Limit to first 10 errors
|
|
severity = diag.get("severity", "error").upper()
|
|
message = diag.get("message", "Unknown error")
|
|
rule = diag.get("rule", "")
|
|
range_info = diag.get("range", {})
|
|
start = range_info.get("start", {})
|
|
line = start.get("line", 0) + 1 # Convert 0-indexed to 1-indexed
|
|
|
|
rule_text = f" [{rule}]" if rule else ""
|
|
errors.append(f" [{severity}] Line {line}: {message}{rule_text}")
|
|
|
|
count = len(diagnostics)
|
|
summary = f"Found {count} type error{'s' if count != 1 else ''}"
|
|
if count > 10:
|
|
summary += " (showing first 10)"
|
|
|
|
return f"{summary}:\n" + "\n".join(errors)
|
|
except (json.JSONDecodeError, KeyError, TypeError):
|
|
return "Type errors found (failed to parse details)"
|
|
|
|
|
|
def _format_pyrefly_errors(output: str) -> str:
|
|
"""Format pyrefly output into readable error messages."""
|
|
if not output or not output.strip():
|
|
return "Type errors found (no details available)"
|
|
|
|
# Pyrefly already has pretty good formatting, but let's clean it up
|
|
lines = output.strip().split("\n")
|
|
|
|
# Count ERROR lines to provide summary
|
|
error_count = sum(bool(line.strip().startswith("ERROR"))
|
|
for line in lines)
|
|
|
|
if error_count == 0:
|
|
return output.strip()
|
|
|
|
summary = f"Found {error_count} type error{'s' if error_count != 1 else ''}"
|
|
return f"{summary}:\n{output.strip()}"
|
|
|
|
|
|
def _format_sourcery_errors(output: str) -> str:
|
|
"""Format sourcery output into readable error messages."""
|
|
if not output or not output.strip():
|
|
return "Code quality issues found (no details available)"
|
|
|
|
# Extract issue count if present
|
|
lines = output.strip().split("\n")
|
|
|
|
# Sourcery typically outputs: "✖ X issues detected"
|
|
issue_count = 0
|
|
for line in lines:
|
|
if "issue" in line.lower() and "detected" in line.lower():
|
|
# Try to extract the number
|
|
import re
|
|
|
|
if match := re.search(r"(\d+)\s+issue", line):
|
|
issue_count = int(match[1])
|
|
break
|
|
|
|
# Format the output, removing redundant summary lines
|
|
formatted_lines: list[str] = []
|
|
for line in lines:
|
|
# Skip the summary line as we'll add our own
|
|
if "issue" in line.lower() and "detected" in line.lower():
|
|
continue
|
|
# Skip empty lines at start/end
|
|
if line.strip():
|
|
formatted_lines.append(line)
|
|
|
|
if issue_count > 0:
|
|
plural = "issue" if issue_count == 1 else "issues"
|
|
summary = f"Found {issue_count} code quality {plural}"
|
|
return f"{summary}:\n" + "\n".join(formatted_lines)
|
|
|
|
return output.strip()
|
|
|
|
|
|
def _ensure_tool_installed(tool_name: str) -> bool:
|
|
"""Ensure a type checking tool is installed in the virtual environment."""
|
|
venv_bin = _get_project_venv_bin()
|
|
tool_path = venv_bin / tool_name
|
|
|
|
if tool_path.exists():
|
|
return True
|
|
|
|
# Try to install using uv if available
|
|
try:
|
|
result = subprocess.run( # noqa: S603
|
|
[str(venv_bin / "uv"), "pip", "install", tool_name],
|
|
check=False,
|
|
capture_output=True,
|
|
text=True,
|
|
timeout=60,
|
|
)
|
|
except (subprocess.TimeoutExpired, OSError):
|
|
return False
|
|
else:
|
|
return result.returncode == 0
|
|
|
|
|
|
def _run_type_checker(
|
|
tool_name: str,
|
|
file_path: str,
|
|
_config: QualityConfig,
|
|
*,
|
|
original_file_path: str | None = None,
|
|
) -> tuple[bool, str]:
|
|
"""Run a type checking tool and return success status and output."""
|
|
venv_bin = _get_project_venv_bin(original_file_path or file_path)
|
|
tool_path = venv_bin / tool_name
|
|
|
|
if not tool_path.exists() and not _ensure_tool_installed(tool_name):
|
|
return True, f"Warning: {tool_name} not available"
|
|
|
|
# Tool configuration mapping
|
|
tool_configs: dict[str, ToolConfig] = {
|
|
"basedpyright": ToolConfig(
|
|
args=["--outputjson", file_path],
|
|
error_check=lambda result: result.returncode == 1,
|
|
error_message=lambda result: _format_basedpyright_errors(result.stdout),
|
|
),
|
|
"pyrefly": ToolConfig(
|
|
args=["check", file_path],
|
|
error_check=lambda result: result.returncode == 1,
|
|
error_message=lambda result: _format_pyrefly_errors(result.stdout),
|
|
),
|
|
"sourcery": ToolConfig(
|
|
args=["review", file_path],
|
|
error_check=lambda result: (
|
|
"issues detected" in str(result.stdout)
|
|
and "0 issues detected" not in str(result.stdout)
|
|
),
|
|
error_message=lambda result: _format_sourcery_errors(result.stdout),
|
|
),
|
|
}
|
|
|
|
tool_config = tool_configs.get(tool_name)
|
|
if not tool_config:
|
|
return True, f"Warning: Unknown tool {tool_name}"
|
|
|
|
# Acquire file-based lock to prevent subprocess concurrency issues
|
|
with _subprocess_lock(timeout=10.0) as acquired:
|
|
if not acquired:
|
|
return True, f"Warning: {tool_name} lock timeout"
|
|
|
|
try:
|
|
cmd = [str(tool_path)] + tool_config["args"]
|
|
|
|
# Activate virtual environment for the subprocess
|
|
env = os.environ.copy()
|
|
env["VIRTUAL_ENV"] = str(venv_bin.parent)
|
|
env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
|
|
# Remove any PYTHONHOME that might interfere
|
|
env.pop("PYTHONHOME", None)
|
|
|
|
# Add PYTHONPATH=src if src directory exists in project root
|
|
# This allows type checkers to resolve imports from src/
|
|
project_root = venv_bin.parent.parent # Go from .venv/bin to project root
|
|
src_dir = project_root / "src"
|
|
if src_dir.exists() and src_dir.is_dir():
|
|
if existing_pythonpath := env.get("PYTHONPATH", ""):
|
|
env["PYTHONPATH"] = f"{src_dir}:{existing_pythonpath}"
|
|
else:
|
|
env["PYTHONPATH"] = str(src_dir)
|
|
|
|
# Run type checker from project root to pick up project configuration files
|
|
result = subprocess.run( # noqa: S603
|
|
cmd,
|
|
check=False,
|
|
capture_output=True,
|
|
text=True,
|
|
timeout=30,
|
|
env=env,
|
|
cwd=str(project_root),
|
|
)
|
|
|
|
# Check for tool-specific errors
|
|
error_check = tool_config["error_check"]
|
|
if error_check(result):
|
|
error_msg = tool_config["error_message"]
|
|
message = str(error_msg(result) if callable(error_msg) else error_msg)
|
|
return False, message
|
|
|
|
# Return success or warning
|
|
if result.returncode == 0:
|
|
return True, ""
|
|
return True, f"Warning: {tool_name} error (exit {result.returncode})"
|
|
|
|
except subprocess.TimeoutExpired:
|
|
return True, f"Warning: {tool_name} timeout"
|
|
except OSError:
|
|
return True, f"Warning: {tool_name} execution error"
|
|
|
|
|
|
def _initialize_analysis() -> tuple[AnalysisResults, list[str]]:
|
|
"""Initialize analysis with empty results and claude-quality command."""
|
|
results: AnalysisResults = {}
|
|
claude_quality_cmd: list[str] = get_claude_quality_command()
|
|
return results, claude_quality_cmd
|
|
|
|
|
|
def run_type_checks(
|
|
file_path: str,
|
|
config: QualityConfig,
|
|
*,
|
|
original_file_path: str | None = None,
|
|
) -> list[str]:
|
|
"""Run all enabled type checking tools in parallel and return any issues."""
|
|
issues: list[str] = []
|
|
tools = [
|
|
("sourcery", config.sourcery_enabled),
|
|
("basedpyright", config.basedpyright_enabled),
|
|
("pyrefly", config.pyrefly_enabled),
|
|
]
|
|
|
|
enabled = [tool for tool, enabled in tools if enabled]
|
|
if not enabled:
|
|
return issues
|
|
|
|
# Run type checkers in parallel
|
|
with ThreadPoolExecutor(max_workers=min(3, len(enabled))) as executor:
|
|
futures = {
|
|
executor.submit(
|
|
_run_type_checker,
|
|
tool,
|
|
file_path,
|
|
config,
|
|
original_file_path=original_file_path,
|
|
): tool
|
|
for tool in enabled
|
|
}
|
|
|
|
for future in as_completed(futures, timeout=45):
|
|
try:
|
|
success, output = future.result()
|
|
tool = futures[future]
|
|
if not success and output:
|
|
issues.append(f"{tool.capitalize()}: {output.strip()}")
|
|
except Exception as e: # noqa: BLE001
|
|
logging.debug("Type checker failed: %s", e)
|
|
|
|
return issues
|
|
|
|
|
|
def _run_quality_analyses(
|
|
content: str,
|
|
tmp_path: str,
|
|
config: QualityConfig,
|
|
enable_type_checks: bool,
|
|
*,
|
|
original_file_path: str | None = None,
|
|
) -> AnalysisResults:
|
|
"""Run all quality analysis checks and return results."""
|
|
results, claude_quality_cmd = _initialize_analysis()
|
|
|
|
# First check for internal duplicates within the file
|
|
if config.duplicate_enabled:
|
|
internal_duplicates = detect_internal_duplicates(
|
|
content,
|
|
threshold=config.duplicate_threshold,
|
|
min_lines=4,
|
|
)
|
|
if internal_duplicates.get("duplicates"):
|
|
results["internal_duplicates"] = internal_duplicates
|
|
|
|
# Run complexity analysis
|
|
if config.complexity_enabled:
|
|
cmd = [
|
|
*claude_quality_cmd,
|
|
"complexity",
|
|
tmp_path,
|
|
"--threshold",
|
|
str(config.complexity_threshold),
|
|
"--format",
|
|
"json",
|
|
]
|
|
|
|
# Prepare virtual environment for subprocess
|
|
venv_bin = _get_project_venv_bin(original_file_path)
|
|
env = os.environ.copy()
|
|
env["VIRTUAL_ENV"] = str(venv_bin.parent)
|
|
env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
|
|
env.pop("PYTHONHOME", None)
|
|
|
|
# Acquire lock to prevent subprocess concurrency issues
|
|
with _subprocess_lock(timeout=10.0) as acquired:
|
|
if acquired:
|
|
with suppress(subprocess.TimeoutExpired):
|
|
result = subprocess.run( # noqa: S603
|
|
cmd,
|
|
check=False,
|
|
capture_output=True,
|
|
text=True,
|
|
timeout=30,
|
|
env=env,
|
|
)
|
|
if result.returncode == 0:
|
|
with suppress(json.JSONDecodeError):
|
|
results["complexity"] = json.loads(result.stdout)
|
|
|
|
# Run type checking if any tool is enabled
|
|
if enable_type_checks and any(
|
|
[
|
|
config.sourcery_enabled,
|
|
config.basedpyright_enabled,
|
|
config.pyrefly_enabled,
|
|
],
|
|
):
|
|
try:
|
|
if type_issues := run_type_checks(
|
|
tmp_path,
|
|
config,
|
|
original_file_path=original_file_path,
|
|
):
|
|
results["type_checking"] = {"issues": type_issues}
|
|
except Exception as e: # noqa: BLE001
|
|
logging.debug("Type checking failed: %s", e)
|
|
|
|
# Run modernization analysis
|
|
if config.modernization_enabled:
|
|
cmd = [
|
|
*claude_quality_cmd,
|
|
"modernization",
|
|
tmp_path,
|
|
"--include-type-hints" if config.require_type_hints else "",
|
|
"--format",
|
|
"json",
|
|
]
|
|
cmd = [c for c in cmd if c] # Remove empty strings
|
|
|
|
# Prepare virtual environment for subprocess
|
|
venv_bin = _get_project_venv_bin(original_file_path)
|
|
env = os.environ.copy()
|
|
env["VIRTUAL_ENV"] = str(venv_bin.parent)
|
|
env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
|
|
env.pop("PYTHONHOME", None)
|
|
|
|
# Acquire lock to prevent subprocess concurrency issues
|
|
with _subprocess_lock(timeout=10.0) as acquired:
|
|
if acquired:
|
|
with suppress(subprocess.TimeoutExpired):
|
|
result = subprocess.run( # noqa: S603
|
|
cmd,
|
|
check=False,
|
|
capture_output=True,
|
|
text=True,
|
|
timeout=30,
|
|
env=env,
|
|
)
|
|
if result.returncode == 0:
|
|
with suppress(json.JSONDecodeError):
|
|
results["modernization"] = json.loads(result.stdout)
|
|
|
|
return results
|
|
|
|
|
|
def _find_project_root(file_path: str) -> Path:
|
|
"""Find project root by looking for common markers."""
|
|
file_path_obj = Path(file_path).resolve()
|
|
current = file_path_obj.parent
|
|
|
|
# Look for common project markers
|
|
while current != current.parent:
|
|
if any(
|
|
(current / marker).exists()
|
|
for marker in [
|
|
".git",
|
|
"pyrightconfig.json",
|
|
"pyproject.toml",
|
|
".venv",
|
|
"setup.py",
|
|
]
|
|
):
|
|
return current
|
|
current = current.parent
|
|
|
|
# Fallback to parent directory
|
|
return file_path_obj.parent
|
|
|
|
|
|
def _get_project_tmp_dir(file_path: str) -> Path:
|
|
"""Get or create .tmp directory in project root."""
|
|
project_root = _find_project_root(file_path)
|
|
tmp_dir = project_root / ".tmp"
|
|
tmp_dir.mkdir(exist_ok=True)
|
|
|
|
# Ensure .tmp is gitignored
|
|
gitignore = project_root / ".gitignore"
|
|
if gitignore.exists():
|
|
content = gitignore.read_text()
|
|
if ".tmp/" not in content and ".tmp" not in content:
|
|
# Add .tmp/ to .gitignore
|
|
with gitignore.open("a") as f:
|
|
f.write("\n# Temporary files created by code quality hooks\n.tmp/\n")
|
|
|
|
return tmp_dir
|
|
|
|
|
|
def analyze_code_quality(
|
|
content: str,
|
|
file_path: str,
|
|
config: QualityConfig,
|
|
*,
|
|
enable_type_checks: bool = True,
|
|
) -> AnalysisResults:
|
|
"""Analyze code content using claude-quality toolkit."""
|
|
suffix = Path(file_path).suffix or ".py"
|
|
|
|
# Create temp file in project directory, not /tmp, so it inherits config files
|
|
# like pyrightconfig.json, pyproject.toml, etc.
|
|
tmp_dir = _get_project_tmp_dir(file_path)
|
|
|
|
# Create temp file in project's .tmp directory
|
|
with NamedTemporaryFile(
|
|
mode="w",
|
|
suffix=suffix,
|
|
delete=False,
|
|
dir=str(tmp_dir),
|
|
prefix="hook_validation_",
|
|
) as tmp:
|
|
tmp.write(content)
|
|
tmp_path = tmp.name
|
|
|
|
try:
|
|
return _run_quality_analyses(
|
|
content,
|
|
tmp_path,
|
|
config,
|
|
enable_type_checks,
|
|
original_file_path=file_path,
|
|
)
|
|
|
|
finally:
|
|
Path(tmp_path).unlink(missing_ok=True)
|
|
|
|
|
|
def _check_internal_duplicates(
|
|
results: AnalysisResults,
|
|
source_code: str = "",
|
|
file_path: str = "",
|
|
) -> list[str]:
|
|
"""Check for internal duplicate code within the same file."""
|
|
issues: list[str] = []
|
|
if "internal_duplicates" not in results:
|
|
return issues
|
|
|
|
duplicates: list[Duplicate] = results["internal_duplicates"].get(
|
|
"duplicates",
|
|
[],
|
|
)
|
|
for dup in duplicates[:3]: # Show first 3
|
|
# Use enhanced formatter for rich duplicate messages
|
|
duplicate_type = dup.get("type", "unknown")
|
|
similarity = dup.get("similarity", 0.0)
|
|
locations_raw = dup.get("locations", [])
|
|
# Cast to list of dicts for the formatter
|
|
locations_dicts: list[dict[str, str]] = [
|
|
{
|
|
"name": str(loc.get("name", "unknown")),
|
|
"type": str(loc.get("type", "code")),
|
|
"lines": str(loc.get("lines", "?")),
|
|
}
|
|
for loc in locations_raw
|
|
]
|
|
|
|
enriched_message = EnhancedMessageFormatter.format_duplicate_message(
|
|
duplicate_type=str(duplicate_type),
|
|
similarity=float(similarity),
|
|
locations=locations_dicts,
|
|
source_code=source_code,
|
|
include_refactoring=True,
|
|
file_path=file_path,
|
|
)
|
|
issues.append(enriched_message)
|
|
return issues
|
|
|
|
|
|
def _check_complexity_issues(
|
|
results: AnalysisResults,
|
|
config: QualityConfig,
|
|
) -> list[str]:
|
|
"""Check for code complexity issues."""
|
|
issues: list[str] = []
|
|
if "complexity" not in results:
|
|
return issues
|
|
|
|
complexity_data = results["complexity"]
|
|
summary = complexity_data.get("summary", {})
|
|
avg_cc = summary.get("average_cyclomatic_complexity", 0.0)
|
|
distribution = complexity_data.get("distribution", {})
|
|
# Only block on Very High (21-50) and Extreme (50+), not High (11-20)
|
|
# Radon's "High" category (CC 11-20) is acceptable for moderately complex functions
|
|
critical_count: int = (
|
|
distribution.get("Very High", 0)
|
|
+ distribution.get("Extreme", 0)
|
|
)
|
|
high_count: int = distribution.get("High", 0)
|
|
|
|
# Block only if average is too high OR if any functions are critically complex (CC > 20)
|
|
# This allows individual functions with CC 11-20 (Radon's "High" category)
|
|
if avg_cc > config.complexity_threshold or critical_count > 0:
|
|
# Use enhanced formatter for rich complexity messages
|
|
enriched_message = EnhancedMessageFormatter.format_complexity_message(
|
|
avg_complexity=avg_cc,
|
|
threshold=config.complexity_threshold,
|
|
high_count=critical_count + high_count, # Show total for context
|
|
)
|
|
issues.append(enriched_message)
|
|
# Note: Functions with CC 11-20 are allowed through without blocking
|
|
# Users can check these with: claude-quality complexity <file>
|
|
return issues
|
|
|
|
|
|
def _check_modernization_issues(
|
|
results: AnalysisResults,
|
|
config: QualityConfig,
|
|
) -> list[str]:
|
|
"""Check for code modernization issues."""
|
|
issues: list[str] = []
|
|
if "modernization" not in results:
|
|
return issues
|
|
|
|
try:
|
|
modernization_data = results["modernization"]
|
|
files = modernization_data.get("files", {})
|
|
if not isinstance(files, dict):
|
|
return issues
|
|
except (AttributeError, TypeError):
|
|
return issues
|
|
total_issues: int = 0
|
|
issue_types: set[str] = set()
|
|
|
|
files_values = cast("dict[str, object]", files).values()
|
|
for file_issues_raw in files_values:
|
|
if not isinstance(file_issues_raw, list):
|
|
continue
|
|
|
|
# Cast to proper type after runtime check
|
|
file_issues_list = cast("list[object]", file_issues_raw)
|
|
total_issues += len(file_issues_list)
|
|
|
|
for issue_raw in file_issues_list:
|
|
if not isinstance(issue_raw, dict):
|
|
continue
|
|
|
|
# Cast to proper type after runtime check
|
|
issue_dict = cast("dict[str, object]", issue_raw)
|
|
issue_type_raw = issue_dict.get("issue_type", "unknown")
|
|
if isinstance(issue_type_raw, str):
|
|
issue_types.add(issue_type_raw)
|
|
|
|
# Only flag if there are non-type-hint issues or many type hint issues
|
|
non_type_issues: int = len(
|
|
[t for t in issue_types if "type" not in t and "typing" not in t],
|
|
)
|
|
type_issues: int = total_issues - non_type_issues
|
|
|
|
if non_type_issues > 0:
|
|
non_type_list: list[str] = [
|
|
t for t in issue_types if "type" not in t and "typing" not in t
|
|
]
|
|
issues.append(
|
|
f"Modernization needed: {non_type_issues} non-type issues "
|
|
f"({', '.join(non_type_list)})",
|
|
)
|
|
elif config.require_type_hints and type_issues > 10:
|
|
issues.append(
|
|
f"Many missing type hints: {type_issues} functions/parameters "
|
|
"lacking annotations",
|
|
)
|
|
return issues
|
|
|
|
|
|
def _check_type_checking_issues(
|
|
results: AnalysisResults,
|
|
source_code: str = "",
|
|
) -> list[str]:
|
|
"""Check for type checking issues from Sourcery, BasedPyright, and Pyrefly."""
|
|
issues: list[str] = []
|
|
if "type_checking" not in results:
|
|
return issues
|
|
|
|
with suppress(AttributeError, TypeError):
|
|
type_checking_data = results["type_checking"]
|
|
type_issues = type_checking_data.get("issues", [])
|
|
# Group by tool and format with enhanced messages
|
|
for issue_str in type_issues[:3]: # Limit to first 3 for brevity
|
|
issue_text = str(issue_str)
|
|
# Extract tool name (usually starts with "Tool:")
|
|
tool_name = "Type Checker"
|
|
if ":" in issue_text:
|
|
potential_tool = issue_text.split(":")[0].strip()
|
|
if potential_tool in ("Sourcery", "BasedPyright", "Pyrefly"):
|
|
tool_name = potential_tool
|
|
issue_text = ":".join(issue_text.split(":")[1:]).strip()
|
|
|
|
enriched_message = EnhancedMessageFormatter.format_type_error_message(
|
|
tool_name=tool_name,
|
|
error_output=issue_text,
|
|
source_code=source_code,
|
|
)
|
|
issues.append(enriched_message)
|
|
|
|
return issues
|
|
|
|
|
|
def check_code_issues(
|
|
results: AnalysisResults,
|
|
config: QualityConfig,
|
|
source_code: str = "",
|
|
file_path: str = "",
|
|
) -> tuple[bool, list[str]]:
|
|
"""Check analysis results for issues that should block the operation."""
|
|
issues: list[str] = []
|
|
|
|
issues.extend(_check_internal_duplicates(results, source_code, file_path))
|
|
issues.extend(_check_complexity_issues(results, config))
|
|
issues.extend(_check_modernization_issues(results, config))
|
|
issues.extend(_check_type_checking_issues(results, source_code))
|
|
|
|
return len(issues) > 0, issues
|
|
|
|
|
|
def store_pre_state(file_path: str, content: str) -> None:
|
|
"""Store file state before modification for later comparison."""
|
|
import tempfile
|
|
|
|
cache_dir = Path(tempfile.gettempdir()) / ".quality_state"
|
|
cache_dir.mkdir(exist_ok=True, mode=0o700)
|
|
|
|
state: dict[str, str | int] = {
|
|
"file_path": file_path,
|
|
"timestamp": datetime.now(UTC).isoformat(),
|
|
"content_hash": hashlib.sha256(content.encode()).hexdigest(),
|
|
"lines": len(content.split("\n")),
|
|
"functions": content.count("def "),
|
|
"classes": content.count("class "),
|
|
}
|
|
|
|
cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]
|
|
cache_file: Path = cache_dir / f"{cache_key}_pre.json"
|
|
cache_file.write_text(json.dumps(state, indent=2))
|
|
|
|
|
|
def check_state_changes(file_path: str) -> list[str]:
|
|
"""Check for quality changes between pre and post states."""
|
|
import tempfile
|
|
|
|
issues: list[str] = []
|
|
cache_dir: Path = Path(tempfile.gettempdir()) / ".quality_state"
|
|
cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]
|
|
|
|
pre_file: Path = cache_dir / f"{cache_key}_pre.json"
|
|
if not pre_file.exists():
|
|
return issues
|
|
|
|
try:
|
|
pre_state = json.loads(pre_file.read_text())
|
|
try:
|
|
current_content = Path(file_path).read_text()
|
|
except OSError:
|
|
return issues # Can't compare if can't read file
|
|
|
|
current_lines = len(current_content.split("\n"))
|
|
current_functions = current_content.count("def ")
|
|
|
|
# Check for significant changes
|
|
if current_functions < pre_state.get("functions", 0):
|
|
issues.append(
|
|
f"⚠️ Reduced functions: {pre_state['functions']} → {current_functions}",
|
|
)
|
|
|
|
if current_lines > pre_state.get("lines", 0) * 1.5:
|
|
issues.append(
|
|
"⚠️ File size increased significantly: "
|
|
f"{pre_state['lines']} → {current_lines} lines",
|
|
)
|
|
except Exception: # noqa: BLE001
|
|
logging.debug("Could not analyze state changes for %s", file_path)
|
|
|
|
return issues
|
|
|
|
|
|
def check_cross_file_duplicates(file_path: str, config: QualityConfig) -> list[str]:
|
|
"""Check for duplicates across project files."""
|
|
issues: list[str] = []
|
|
|
|
# Get project root
|
|
project_root: Path = Path(file_path).parent
|
|
while (
|
|
project_root.parent != project_root
|
|
and not (project_root / ".git").exists()
|
|
and not (project_root / "pyproject.toml").exists()
|
|
):
|
|
project_root = project_root.parent
|
|
|
|
try:
|
|
claude_quality_cmd = get_claude_quality_command()
|
|
|
|
# Prepare virtual environment for subprocess
|
|
venv_bin = _get_project_venv_bin()
|
|
env = os.environ.copy()
|
|
env["VIRTUAL_ENV"] = str(venv_bin.parent)
|
|
env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
|
|
env.pop("PYTHONHOME", None)
|
|
|
|
result = subprocess.run( # noqa: S603
|
|
[
|
|
*claude_quality_cmd,
|
|
"duplicates",
|
|
str(project_root),
|
|
"--threshold",
|
|
str(config.duplicate_threshold),
|
|
"--format",
|
|
"json",
|
|
],
|
|
check=False,
|
|
capture_output=True,
|
|
text=True,
|
|
timeout=60,
|
|
env=env,
|
|
)
|
|
if result.returncode == 0:
|
|
data = json.loads(result.stdout)
|
|
duplicates = data.get("duplicates", [])
|
|
if any(file_path in str(d) for d in duplicates):
|
|
issues.append("⚠️ Cross-file duplication detected")
|
|
except Exception: # noqa: BLE001
|
|
logging.debug("Could not check cross-file duplicates for %s", file_path)
|
|
|
|
return issues
|
|
|
|
|
|
def verify_naming_conventions(file_path: str) -> list[str]:
|
|
"""Verify PEP8 naming conventions."""
|
|
issues: list[str] = []
|
|
try:
|
|
content = Path(file_path).read_text()
|
|
except OSError:
|
|
return issues # Can't check naming if can't read file
|
|
|
|
# Check function names (should be snake_case)
|
|
if bad_funcs := re.findall(
|
|
r"def\s+([A-Z][a-zA-Z0-9_]*|[a-z]+[A-Z][a-zA-Z0-9_]*)\s*\(",
|
|
content,
|
|
):
|
|
issues.append(f"⚠️ Non-PEP8 function names: {', '.join(bad_funcs[:3])}")
|
|
|
|
# Check class names (should be PascalCase)
|
|
if bad_classes := re.findall(r"class\s+([a-z][a-z0-9_]*)\s*[\(:]", content):
|
|
issues.append(f"⚠️ Non-PEP8 class names: {', '.join(bad_classes[:3])}")
|
|
|
|
return issues
|
|
|
|
|
|
def _detect_any_usage(content: str) -> list[str]:
|
|
"""Detect forbidden typing.Any usage and suggest specific types."""
|
|
# Use type inference helper to find Any usage with context
|
|
any_usages = TypeInferenceHelper.find_any_usage_with_context(content)
|
|
|
|
# Fallback to line-by-line check for syntax errors
|
|
if not any_usages:
|
|
lines_with_any: set[int] = set()
|
|
for index, line in enumerate(content.splitlines(), start=1):
|
|
code_portion = line.split("#", 1)[0]
|
|
if re.search(r"\bAny\b", code_portion):
|
|
lines_with_any.add(index)
|
|
|
|
if lines_with_any:
|
|
sorted_lines = sorted(lines_with_any)
|
|
display_lines = ", ".join(str(num) for num in sorted_lines[:5])
|
|
if len(sorted_lines) > 5:
|
|
display_lines += ", …"
|
|
return [
|
|
f"⚠️ Forbidden typing.Any usage at line(s) {display_lines}; "
|
|
"replace with specific types",
|
|
]
|
|
|
|
if not any_usages:
|
|
return []
|
|
|
|
issues: list[str] = []
|
|
|
|
# Group by context type
|
|
by_context: dict[str, list[dict[str, str | int]]] = {}
|
|
for usage in any_usages:
|
|
context = str(usage.get("context", "unknown"))
|
|
if context not in by_context:
|
|
by_context[context] = []
|
|
by_context[context].append(usage)
|
|
|
|
# Format enriched messages for each context type
|
|
for context, usages_list in by_context.items():
|
|
lines = [str(u.get("line", "?")) for u in usages_list[:5]]
|
|
line_summary = ", ".join(lines)
|
|
if len(usages_list) > 5:
|
|
line_summary += ", ..."
|
|
|
|
# Get suggestions
|
|
suggestions: list[str] = []
|
|
for usage in usages_list[:3]: # Show first 3 suggestions
|
|
element = str(usage.get("element", ""))
|
|
suggested = str(usage.get("suggested", ""))
|
|
if suggested and suggested not in {"Any", "Infer from usage"}:
|
|
suggestions.append(f" • {element}: {suggested}")
|
|
|
|
parts: list[str] = [
|
|
f"⚠️ Forbidden typing.Any Usage ({context})",
|
|
f"📍 Lines: {line_summary}",
|
|
]
|
|
|
|
if suggestions:
|
|
parts.append("💡 Suggested Types:")
|
|
parts.extend(suggestions)
|
|
else:
|
|
parts.append("💡 Tip: Replace `Any` with specific types based on usage")
|
|
|
|
parts.extend(
|
|
[
|
|
"",
|
|
"🔗 Common Replacements:",
|
|
" • dict[str, Any] → dict[str, int] (or appropriate value type)",
|
|
" • list[Any] → list[str] (or appropriate element type)",
|
|
(
|
|
" • Callable[..., Any] → Callable[[int, str], bool] "
|
|
"(with specific signature)"
|
|
),
|
|
]
|
|
)
|
|
|
|
issues.append("\n".join(parts))
|
|
|
|
return issues
|
|
|
|
|
|
def _detect_type_ignore_usage(content: str) -> list[str]:
|
|
"""Detect forbidden # type: ignore usage in proposed content."""
|
|
pattern = re.compile(r"#\s*type:\s*ignore(?:\b|\[)", re.IGNORECASE)
|
|
lines_with_type_ignore: set[int] = set()
|
|
|
|
try:
|
|
# Dedent the content to handle code fragments with leading indentation
|
|
dedented_content = textwrap.dedent(content)
|
|
for token_type, token_string, start, _, _ in tokenize.generate_tokens(
|
|
StringIO(dedented_content).readline,
|
|
):
|
|
if token_type == tokenize.COMMENT and pattern.search(token_string):
|
|
lines_with_type_ignore.add(start[0])
|
|
except (tokenize.TokenError, IndentationError):
|
|
for index, line in enumerate(content.splitlines(), start=1):
|
|
if pattern.search(line):
|
|
lines_with_type_ignore.add(index)
|
|
|
|
if not lines_with_type_ignore:
|
|
return []
|
|
|
|
sorted_lines = sorted(lines_with_type_ignore)
|
|
display_lines = ", ".join(str(num) for num in sorted_lines[:5])
|
|
if len(sorted_lines) > 5:
|
|
display_lines += ", …"
|
|
|
|
return [
|
|
"⚠️ Forbidden # type: ignore usage at line(s) "
|
|
f"{display_lines}; remove the suppression and fix typing issues instead",
|
|
]
|
|
|
|
|
|
def _detect_old_typing_patterns(content: str) -> list[str]:
|
|
"""Detect old typing patterns that should use modern syntax."""
|
|
# Old typing imports that should be replaced
|
|
old_patterns = {
|
|
r"\bfrom typing import.*\bUnion\b": (
|
|
"Use | syntax instead of Union (e.g., str | int)"
|
|
),
|
|
r"\bfrom typing import.*\bOptional\b": (
|
|
"Use | None syntax instead of Optional (e.g., str | None)"
|
|
),
|
|
r"\bfrom typing import.*\bList\b": "Use list[T] instead of List[T]",
|
|
r"\bfrom typing import.*\bDict\b": "Use dict[K, V] instead of Dict[K, V]",
|
|
r"\bfrom typing import.*\bSet\b": "Use set[T] instead of Set[T]",
|
|
r"\bfrom typing import.*\bTuple\b": (
|
|
"Use tuple[T, ...] instead of Tuple[T, ...]"
|
|
),
|
|
r"\bUnion\s*\[": "Use | syntax instead of Union (e.g., str | int)",
|
|
r"\bOptional\s*\[": "Use | None syntax instead of Optional (e.g., str | None)",
|
|
r"\bList\s*\[": "Use list[T] instead of List[T]",
|
|
r"\bDict\s*\[": "Use dict[K, V] instead of Dict[K, V]",
|
|
r"\bSet\s*\[": "Use set[T] instead of Set[T]",
|
|
r"\bTuple\s*\[": "Use tuple[T, ...] instead of Tuple[T, ...]",
|
|
}
|
|
|
|
lines = content.splitlines()
|
|
found_issues = []
|
|
|
|
for pattern, message in old_patterns.items():
|
|
lines_with_pattern = []
|
|
for i, line in enumerate(lines, 1):
|
|
# Skip comments
|
|
code_part = line.split("#")[0]
|
|
if re.search(pattern, code_part):
|
|
lines_with_pattern.append(i)
|
|
|
|
if lines_with_pattern:
|
|
display_lines: str = ", ".join(str(num) for num in lines_with_pattern[:5])
|
|
if len(lines_with_pattern) > 5:
|
|
display_lines += ", …"
|
|
issue_text: str = f"⚠️ Old typing pattern at line(s) {display_lines}: {message}"
|
|
found_issues.append(issue_text)
|
|
|
|
return found_issues
|
|
|
|
|
|
def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
|
|
"""Detect files and functions/classes with suspicious adjective/adverb suffixes."""
|
|
issues: list[str] = []
|
|
|
|
# Check file name against other files in the same directory
|
|
file_path_obj = Path(file_path)
|
|
if file_path_obj.parent.exists():
|
|
file_stem = file_path_obj.stem
|
|
file_suffix = file_path_obj.suffix
|
|
|
|
# Check if current file has suspicious suffix
|
|
for suffix in SUSPICIOUS_SUFFIXES:
|
|
for separator in ("_", "-"):
|
|
suffix_token = f"{separator}{suffix}"
|
|
if not file_stem.endswith(suffix_token):
|
|
continue
|
|
|
|
base_name = file_stem[: -len(suffix_token)]
|
|
potential_original = file_path_obj.parent / f"{base_name}{file_suffix}"
|
|
if potential_original.exists() and potential_original != file_path_obj:
|
|
message = FILE_SUFFIX_DUPLICATE_MSG.format(
|
|
current=file_path_obj.name,
|
|
original=potential_original.name,
|
|
)
|
|
issues.append(message)
|
|
break
|
|
else:
|
|
continue
|
|
break
|
|
|
|
# Check if any existing files are suffixed versions of current file
|
|
for existing_file in file_path_obj.parent.glob(f"{file_stem}_*{file_suffix}"):
|
|
if existing_file != file_path_obj:
|
|
existing_stem = existing_file.stem
|
|
if existing_stem.startswith(f"{file_stem}_"):
|
|
potential_suffix = existing_stem[len(file_stem) + 1 :]
|
|
if potential_suffix in SUSPICIOUS_SUFFIXES:
|
|
message = EXISTING_FILE_DUPLICATE_MSG.format(
|
|
current=file_path_obj.name,
|
|
existing=existing_file.name,
|
|
)
|
|
issues.append(message)
|
|
break
|
|
|
|
# Same check for dash-separated suffixes
|
|
for existing_file in file_path_obj.parent.glob(f"{file_stem}-*{file_suffix}"):
|
|
if existing_file != file_path_obj:
|
|
existing_stem = existing_file.stem
|
|
if existing_stem.startswith(f"{file_stem}-"):
|
|
potential_suffix = existing_stem[len(file_stem) + 1 :]
|
|
if potential_suffix in SUSPICIOUS_SUFFIXES:
|
|
message = EXISTING_FILE_DUPLICATE_MSG.format(
|
|
current=file_path_obj.name,
|
|
existing=existing_file.name,
|
|
)
|
|
issues.append(message)
|
|
break
|
|
|
|
# Check function and class names in content
|
|
with suppress(SyntaxError):
|
|
# Dedent the content to handle code fragments with leading indentation
|
|
tree = ast.parse(textwrap.dedent(content))
|
|
|
|
class SuffixVisitor(ast.NodeVisitor):
|
|
def __init__(self):
|
|
self.function_names: set[str] = set()
|
|
self.class_names: set[str] = set()
|
|
|
|
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
|
self.function_names.add(node.name)
|
|
self.generic_visit(node)
|
|
|
|
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
|
|
self.function_names.add(node.name)
|
|
self.generic_visit(node)
|
|
|
|
def visit_ClassDef(self, node: ast.ClassDef) -> None:
|
|
self.class_names.add(node.name)
|
|
self.generic_visit(node)
|
|
|
|
visitor = SuffixVisitor()
|
|
visitor.visit(tree)
|
|
|
|
# Check for suspicious function name patterns
|
|
for func_name in visitor.function_names:
|
|
for suffix in SUSPICIOUS_SUFFIXES:
|
|
suffix_token = f"_{suffix}"
|
|
if not func_name.endswith(suffix_token):
|
|
continue
|
|
|
|
base_name = func_name[: -len(suffix_token)]
|
|
if base_name in visitor.function_names:
|
|
message = NAME_SUFFIX_DUPLICATE_MSG.format(
|
|
kind="Function",
|
|
name=func_name,
|
|
base=base_name,
|
|
)
|
|
issues.append(message)
|
|
break
|
|
|
|
# Check for suspicious class name patterns
|
|
for class_name in visitor.class_names:
|
|
for suffix in SUSPICIOUS_SUFFIXES:
|
|
pascal_suffix = suffix.capitalize()
|
|
snake_suffix = f"_{suffix}"
|
|
|
|
potential_matches = (
|
|
(pascal_suffix, class_name[: -len(pascal_suffix)]),
|
|
(snake_suffix, class_name[: -len(snake_suffix)]),
|
|
)
|
|
|
|
for token, base_name in potential_matches:
|
|
if (
|
|
token
|
|
and class_name.endswith(token)
|
|
and base_name in visitor.class_names
|
|
):
|
|
message = NAME_SUFFIX_DUPLICATE_MSG.format(
|
|
kind="Class",
|
|
name=class_name,
|
|
base=base_name,
|
|
)
|
|
issues.append(message)
|
|
break
|
|
else:
|
|
continue
|
|
break
|
|
|
|
return issues
|
|
|
|
|
|
def _perform_quality_check(
|
|
file_path: str,
|
|
content: str,
|
|
config: QualityConfig,
|
|
enable_type_checks: bool = True,
|
|
) -> tuple[bool, list[str]]:
|
|
"""Perform quality analysis and return issues."""
|
|
# Store state if tracking enabled
|
|
if config.state_tracking_enabled:
|
|
store_pre_state(file_path, content)
|
|
|
|
# Run quality analysis
|
|
results = analyze_code_quality(
|
|
content,
|
|
file_path,
|
|
config,
|
|
enable_type_checks=enable_type_checks,
|
|
)
|
|
return check_code_issues(results, config, content, file_path)
|
|
|
|
|
|
def _handle_quality_issues(
|
|
file_path: str,
|
|
issues: list[str],
|
|
config: QualityConfig,
|
|
*,
|
|
forced_permission: str | None = None,
|
|
) -> JsonObject:
|
|
"""Handle quality issues based on enforcement mode."""
|
|
# Prepare denial message with formatted issues
|
|
formatted_issues: list[str] = []
|
|
for issue in issues:
|
|
# Add indentation to multi-line issues for better readability
|
|
if "\n" in issue:
|
|
lines = issue.split("\n")
|
|
formatted_issues.append(f"• {lines[0]}")
|
|
formatted_issues.extend(f" {line}" for line in lines[1:])
|
|
else:
|
|
formatted_issues.append(f"• {issue}")
|
|
|
|
message = (
|
|
f"Code quality check failed for {Path(file_path).name}:\n"
|
|
+ "\n".join(formatted_issues)
|
|
+ "\n\nFix these issues before writing the code."
|
|
)
|
|
|
|
# Make decision based on enforcement mode
|
|
if forced_permission:
|
|
return _create_hook_response("PreToolUse", forced_permission, message)
|
|
|
|
if config.enforcement_mode == "strict":
|
|
return _create_hook_response("PreToolUse", "deny", message)
|
|
if config.enforcement_mode == "warn":
|
|
return _create_hook_response("PreToolUse", "ask", message)
|
|
# permissive
|
|
warning_message = f"⚠️ Quality Warning:\n{message}"
|
|
return _create_hook_response(
|
|
"PreToolUse",
|
|
"allow",
|
|
warning_message,
|
|
warning_message,
|
|
)
|
|
|
|
|
|
def _exit_with_reason(reason: str, exit_code: int = 2) -> None:
|
|
"""Write reason to stderr and exit with specified code."""
|
|
sys.stderr.write(reason)
|
|
sys.exit(exit_code)
|
|
|
|
|
|
def _create_hook_response(
|
|
event_name: str,
|
|
permission: str = "",
|
|
reason: str = "",
|
|
system_message: str = "",
|
|
additional_context: str = "",
|
|
*,
|
|
decision: str | None = None,
|
|
) -> JsonObject:
|
|
"""Create standardized hook response."""
|
|
hook_output: dict[str, object] = {
|
|
"hookEventName": event_name,
|
|
}
|
|
|
|
if permission:
|
|
hook_output["permissionDecision"] = permission
|
|
if reason:
|
|
hook_output["permissionDecisionReason"] = reason
|
|
|
|
if additional_context:
|
|
hook_output["additionalContext"] = additional_context
|
|
|
|
response: JsonObject = {
|
|
"hookSpecificOutput": hook_output,
|
|
}
|
|
|
|
if permission:
|
|
response["permissionDecision"] = permission
|
|
|
|
if decision:
|
|
response["decision"] = decision
|
|
|
|
if reason:
|
|
response["reason"] = reason
|
|
|
|
if system_message:
|
|
response["systemMessage"] = system_message
|
|
|
|
return response
|
|
|
|
|
|
def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject:
|
|
"""Handle PreToolUse hook - analyze content before write/edit."""
|
|
tool_name = str(hook_data.get("tool_name", ""))
|
|
tool_input_raw = hook_data.get("tool_input", {})
|
|
if not isinstance(tool_input_raw, dict):
|
|
return _create_hook_response("PreToolUse", "allow")
|
|
tool_input = cast("dict[str, object]", tool_input_raw)
|
|
|
|
# Only analyze for write/edit tools
|
|
if tool_name not in ["Write", "Edit", "MultiEdit"]:
|
|
return _create_hook_response("PreToolUse", "allow")
|
|
|
|
# Extract content based on tool type
|
|
file_path = str(tool_input.get("file_path", ""))
|
|
content = ""
|
|
|
|
if tool_name == "Write":
|
|
raw_content = tool_input.get("content", "")
|
|
content = "" if raw_content is None else str(raw_content)
|
|
elif tool_name == "Edit":
|
|
new_string = tool_input.get("new_string", "")
|
|
content = "" if new_string is None else str(new_string)
|
|
elif tool_name == "MultiEdit":
|
|
edits = tool_input.get("edits", [])
|
|
if isinstance(edits, list):
|
|
edits_list = cast("list[object]", edits)
|
|
parts: list[str] = []
|
|
for edit in edits_list:
|
|
if not isinstance(edit, dict):
|
|
continue
|
|
edit_dict = cast("dict[str, object]", edit)
|
|
new_str = edit_dict.get("new_string")
|
|
parts.append("") if new_str is None else parts.append(str(new_str))
|
|
content = "\n".join(parts)
|
|
|
|
# Only analyze Python files
|
|
if not file_path or not file_path.endswith(".py") or not content:
|
|
return _create_hook_response("PreToolUse", "allow")
|
|
|
|
# Check if this is a test file and test quality checks are enabled
|
|
is_test = is_test_file(file_path)
|
|
run_test_checks = config.test_quality_enabled and is_test
|
|
|
|
enable_type_checks = tool_name == "Write"
|
|
|
|
# Always run core checks (Any, type: ignore, typing, duplicates) before skipping
|
|
any_usage_issues = _detect_any_usage(content)
|
|
type_ignore_issues = _detect_type_ignore_usage(content)
|
|
old_typing_issues = _detect_old_typing_patterns(content)
|
|
suffix_duplication_issues = _detect_suffix_duplication(file_path, content)
|
|
precheck_issues = (
|
|
any_usage_issues
|
|
+ type_ignore_issues
|
|
+ old_typing_issues
|
|
+ suffix_duplication_issues
|
|
)
|
|
|
|
# Run test quality checks if enabled and file is a test file
|
|
if run_test_checks:
|
|
test_quality_issues = run_test_quality_checks(content, file_path, config)
|
|
precheck_issues.extend(test_quality_issues)
|
|
|
|
# Skip detailed analysis for configured patterns unless test checks should run
|
|
# Note: Core quality checks (Any, type: ignore, duplicates) always run above
|
|
should_skip_detailed = should_skip_file(file_path, config) and not run_test_checks
|
|
|
|
try:
|
|
# Run detailed quality checks only if not skipping
|
|
if should_skip_detailed:
|
|
issues = []
|
|
else:
|
|
_has_issues, issues = _perform_quality_check(
|
|
file_path,
|
|
content,
|
|
config,
|
|
enable_type_checks=enable_type_checks,
|
|
)
|
|
|
|
all_issues = precheck_issues + issues
|
|
|
|
if not all_issues:
|
|
return _create_hook_response("PreToolUse", "allow")
|
|
|
|
if precheck_issues:
|
|
return _handle_quality_issues(
|
|
file_path,
|
|
all_issues,
|
|
config,
|
|
forced_permission="deny",
|
|
)
|
|
|
|
return _handle_quality_issues(file_path, all_issues, config)
|
|
except Exception as e: # noqa: BLE001
|
|
if precheck_issues:
|
|
return _handle_quality_issues(
|
|
file_path,
|
|
precheck_issues,
|
|
config,
|
|
forced_permission="deny",
|
|
)
|
|
return _create_hook_response(
|
|
"PreToolUse",
|
|
"allow",
|
|
f"Warning: Code quality check failed with error: {e}",
|
|
f"Warning: Code quality check failed with error: {e}",
|
|
)
|
|
|
|
|
|
def posttooluse_hook(
|
|
hook_data: JsonObject,
|
|
config: QualityConfig,
|
|
) -> JsonObject:
|
|
"""Handle PostToolUse hook - verify quality after write/edit."""
|
|
tool_name: str = str(hook_data.get("tool_name", ""))
|
|
tool_output = hook_data.get("tool_response", hook_data.get("tool_output", {}))
|
|
|
|
# Only process write/edit tools
|
|
if tool_name not in ["Write", "Edit", "MultiEdit"]:
|
|
return _create_hook_response("PostToolUse")
|
|
|
|
# Extract file path from output
|
|
file_path: str = ""
|
|
if isinstance(tool_output, dict):
|
|
tool_output_dict = cast("dict[str, object]", tool_output)
|
|
file_path = str(
|
|
tool_output_dict.get("file_path", "") or tool_output_dict.get("path", ""),
|
|
)
|
|
elif isinstance(tool_output, str) and (
|
|
match := re.search(r"([/\w\-_.]+\.py)", tool_output)
|
|
):
|
|
file_path = match[1]
|
|
|
|
if not file_path or not file_path.endswith(".py"):
|
|
return _create_hook_response("PostToolUse")
|
|
|
|
if not Path(file_path).exists():
|
|
return _create_hook_response("PostToolUse")
|
|
|
|
issues: list[str] = []
|
|
|
|
# Read entire file content for full analysis
|
|
try:
|
|
with open(file_path, encoding="utf-8") as f:
|
|
file_content = f.read()
|
|
except (OSError, UnicodeDecodeError):
|
|
# If we can't read the file, skip full content analysis
|
|
file_content = ""
|
|
|
|
# Run full file quality checks on the entire content
|
|
if file_content:
|
|
any_usage_issues = _detect_any_usage(file_content)
|
|
type_ignore_issues = _detect_type_ignore_usage(file_content)
|
|
issues.extend(any_usage_issues)
|
|
issues.extend(type_ignore_issues)
|
|
|
|
# Check state changes if tracking enabled
|
|
if config.state_tracking_enabled:
|
|
delta_issues = check_state_changes(file_path)
|
|
issues.extend(delta_issues)
|
|
|
|
# Run cross-file duplicate detection if enabled
|
|
if config.cross_file_check_enabled:
|
|
cross_file_issues = check_cross_file_duplicates(file_path, config)
|
|
issues.extend(cross_file_issues)
|
|
|
|
# Verify naming conventions if enabled
|
|
if config.verify_naming:
|
|
naming_issues = verify_naming_conventions(file_path)
|
|
issues.extend(naming_issues)
|
|
|
|
# Format response
|
|
if issues:
|
|
message = (
|
|
f"📝 Post-write quality notes for {Path(file_path).name}:\n"
|
|
+ "\n".join(issues)
|
|
)
|
|
return _create_hook_response(
|
|
"PostToolUse",
|
|
"",
|
|
message,
|
|
message,
|
|
message,
|
|
decision="block",
|
|
)
|
|
if config.show_success:
|
|
message = f"✅ {Path(file_path).name} passed post-write verification"
|
|
return _create_hook_response(
|
|
"PostToolUse",
|
|
"",
|
|
"",
|
|
message,
|
|
"",
|
|
decision="approve",
|
|
)
|
|
|
|
return _create_hook_response("PostToolUse")
|
|
|
|
|
|
def is_test_file(file_path: str) -> bool:
|
|
"""Check if file path is in a test directory."""
|
|
path_parts = Path(file_path).parts
|
|
return any(part in ("test", "tests", "testing") for part in path_parts)
|
|
|
|
|
|
def run_test_quality_checks(
|
|
content: str,
|
|
file_path: str,
|
|
config: QualityConfig,
|
|
) -> list[str]:
|
|
"""Run Sourcery's test rules and return guidance-enhanced issues."""
|
|
issues: list[str] = []
|
|
|
|
# Only run test quality checks for test files
|
|
if not is_test_file(file_path):
|
|
return issues
|
|
|
|
suffix = Path(file_path).suffix or ".py"
|
|
|
|
# Create temp file in project directory to inherit config files
|
|
tmp_dir = _get_project_tmp_dir(file_path)
|
|
|
|
with NamedTemporaryFile(
|
|
mode="w",
|
|
suffix=suffix,
|
|
delete=False,
|
|
dir=str(tmp_dir),
|
|
prefix="test_validation_",
|
|
) as tmp:
|
|
tmp.write(content)
|
|
tmp_path = tmp.name
|
|
|
|
try:
|
|
# Run Sourcery with specific test-related rules
|
|
venv_bin = _get_project_venv_bin()
|
|
sourcery_path = venv_bin / "sourcery"
|
|
|
|
if not sourcery_path.exists():
|
|
# Try to find sourcery in PATH
|
|
import shutil
|
|
|
|
sourcery_path = shutil.which("sourcery") or str(venv_bin / "sourcery")
|
|
|
|
if not sourcery_path or not Path(sourcery_path).exists():
|
|
logging.debug("Sourcery not found at %s", sourcery_path)
|
|
return issues
|
|
|
|
# Specific rules for test quality - use correct Sourcery format
|
|
test_rules = [
|
|
"no-conditionals-in-tests",
|
|
"no-loop-in-tests",
|
|
"raise-specific-error",
|
|
"dont-import-test-modules",
|
|
]
|
|
|
|
# Build command with --enable for each rule
|
|
cmd = [str(sourcery_path), "review", tmp_path]
|
|
for rule in test_rules:
|
|
cmd.extend(["--enable", rule])
|
|
cmd.append("--check") # Return exit code 1 if issues found
|
|
|
|
# Activate virtual environment for the subprocess
|
|
env = os.environ.copy()
|
|
env["VIRTUAL_ENV"] = str(venv_bin.parent)
|
|
env["PATH"] = f"{venv_bin}:{env.get('PATH', '')}"
|
|
# Remove any PYTHONHOME that might interfere
|
|
env.pop("PYTHONHOME", None)
|
|
|
|
logging.debug("Running Sourcery command: %s", " ".join(cmd))
|
|
result = subprocess.run( # noqa: S603
|
|
cmd,
|
|
check=False,
|
|
capture_output=True,
|
|
text=True,
|
|
timeout=30,
|
|
env=env,
|
|
)
|
|
|
|
logging.debug("Sourcery exit code: %s", result.returncode)
|
|
logging.debug("Sourcery stdout: %s", result.stdout)
|
|
logging.debug("Sourcery stderr: %s", result.stderr)
|
|
|
|
# Sourcery with --check returns:
|
|
# - Exit code 0: No issues found
|
|
# - Exit code 1: Issues found
|
|
# - Exit code 2: Error occurred
|
|
if result.returncode == 1:
|
|
# Issues were found - parse the output
|
|
output = result.stdout + result.stderr
|
|
|
|
# Try to extract rule names from the output. Sourcery usually includes
|
|
# rule identifiers in brackets or descriptive text.
|
|
for rule in test_rules:
|
|
if rule in output or rule.replace("-", " ") in output.lower():
|
|
base_guidance = generate_test_quality_guidance(
|
|
rule,
|
|
content,
|
|
file_path,
|
|
config,
|
|
)
|
|
if external_context := get_external_context(
|
|
rule,
|
|
content,
|
|
file_path,
|
|
config,
|
|
):
|
|
base_guidance += f"\n\n{external_context}"
|
|
issues.append(base_guidance)
|
|
break # Only add one guidance message
|
|
else:
|
|
# If no specific rule found, provide general guidance
|
|
base_guidance = generate_test_quality_guidance(
|
|
"unknown",
|
|
content,
|
|
file_path,
|
|
config,
|
|
)
|
|
if external_context := get_external_context(
|
|
"unknown",
|
|
content,
|
|
file_path,
|
|
config,
|
|
):
|
|
base_guidance += f"\n\n{external_context}"
|
|
issues.append(base_guidance)
|
|
elif result.returncode == 2:
|
|
# Error occurred
|
|
logging.debug("Sourcery error: %s", result.stderr)
|
|
# Don't block on Sourcery errors - just log them
|
|
# Exit code 0 means no issues - do nothing
|
|
|
|
except (subprocess.TimeoutExpired, OSError, json.JSONDecodeError) as e:
|
|
# If Sourcery fails, don't block the operation
|
|
logging.debug("Test quality check failed for %s: %s", file_path, e)
|
|
finally:
|
|
Path(tmp_path).unlink(missing_ok=True)
|
|
|
|
return issues
|
|
|
|
|
|
def main() -> None:
|
|
"""Main hook entry point."""
|
|
try:
|
|
# Load configuration
|
|
config = QualityConfig.from_env()
|
|
|
|
# Read hook input from stdin
|
|
try:
|
|
hook_data: JsonObject = json.load(sys.stdin)
|
|
except json.JSONDecodeError:
|
|
fallback_response: JsonObject = {
|
|
"hookSpecificOutput": {
|
|
"hookEventName": "PreToolUse",
|
|
"permissionDecision": "allow",
|
|
},
|
|
}
|
|
sys.stdout.write(json.dumps(fallback_response))
|
|
return
|
|
|
|
# Detect hook type: tool_response=PostToolUse, tool_input=PreToolUse
|
|
response: JsonObject
|
|
if "tool_response" in hook_data or "tool_output" in hook_data:
|
|
# PostToolUse hook
|
|
response = posttooluse_hook(hook_data, config)
|
|
else:
|
|
# PreToolUse hook
|
|
response = pretooluse_hook(hook_data, config)
|
|
|
|
print(json.dumps(response)) # noqa: T201
|
|
|
|
# Handle exit codes based on hook output
|
|
hook_output_raw = response.get("hookSpecificOutput", {})
|
|
if not isinstance(hook_output_raw, dict):
|
|
return
|
|
|
|
hook_output = cast("dict[str, object]", hook_output_raw)
|
|
permission_decision = hook_output.get("permissionDecision")
|
|
|
|
if permission_decision == "deny":
|
|
# Exit code 2: Blocking error - stderr fed back to Claude
|
|
reason = str(
|
|
hook_output.get("permissionDecisionReason", "Permission denied"),
|
|
)
|
|
_exit_with_reason(reason)
|
|
elif permission_decision == "ask":
|
|
# Also use exit code 2 for ask decisions to ensure Claude sees the message
|
|
reason = str(
|
|
hook_output.get("permissionDecisionReason", "Permission request"),
|
|
)
|
|
_exit_with_reason(reason)
|
|
# Exit code 0: Success (default)
|
|
|
|
except Exception as e: # noqa: BLE001
|
|
# Unexpected error - use exit code 1 (non-blocking error)
|
|
sys.stderr.write(f"Hook error: {e}")
|
|
sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|