claude-scripts/hooks/code_quality_guard.py
Travis Vasceannie 9cf5baafb4 feat: implement test quality checks with enhanced guidance and external context integration
- Added a new QualityConfig class to manage test quality check configurations.
- Implemented test quality checks for specific rules in test files, including prevention of conditionals, loops, and generic exceptions.
- Integrated external context providers (Context7 and Firecrawl) for additional guidance on test quality violations.
- Enhanced error messaging to provide detailed, actionable guidance for detected issues.
- Updated README_HOOKS.md to document new test quality features and configuration options.
- Added unit tests to verify the functionality of test quality checks and their integration with the pretooluse_hook.
2025-09-29 20:59:32 +00:00


"""Unified quality hook for Claude Code supporting both PreToolUse and PostToolUse.
Prevents writing duplicate, complex, or non-modernized code and verifies quality
after writes.
"""
import ast
import hashlib
import json
import logging
import os
import re
import subprocess
import sys
import tokenize
from collections.abc import Callable
from contextlib import suppress
from dataclasses import dataclass
from datetime import UTC, datetime
from io import StringIO
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import TypedDict, cast
# Import internal duplicate detector
sys.path.insert(0, str(Path(__file__).parent))
from internal_duplicate_detector import detect_internal_duplicates
# NOTE: QualityConfig is defined as a dataclass further down in this module;
# the annotation below is quoted so the name is not evaluated before the class
# exists at import time.
def get_external_context(rule_id: str, content: str, file_path: str, config: "QualityConfig") -> str:
"""Get additional context from external APIs for enhanced guidance."""
context_parts = []
# Context7 integration for additional context analysis
if config.context7_enabled and config.context7_api_key:
try:
# Note: This would integrate with actual Context7 API
# For now, providing placeholder for the integration
context7_context = _get_context7_analysis(rule_id, content, config.context7_api_key)
if context7_context:
context_parts.append(f"📊 Context7 Analysis: {context7_context}")
except Exception as e:
logging.debug("Context7 API call failed: %s", e)
# Firecrawl integration for web scraping additional examples
if config.firecrawl_enabled and config.firecrawl_api_key:
try:
# Note: This would integrate with actual Firecrawl API
# For now, providing placeholder for the integration
firecrawl_examples = _get_firecrawl_examples(rule_id, config.firecrawl_api_key)
if firecrawl_examples:
context_parts.append(f"🔗 Additional Examples: {firecrawl_examples}")
except Exception as e:
logging.debug("Firecrawl API call failed: %s", e)
return "\n\n".join(context_parts) if context_parts else ""
def _get_context7_analysis(rule_id: str, content: str, api_key: str) -> str:
"""Placeholder for Context7 API integration."""
# This would make actual API calls to Context7
# For demonstration, returning contextual analysis based on rule type
context_map = {
"no-conditionals-in-tests": "Consider using data-driven tests with pytest.mark.parametrize for better test isolation.",
"no-loop-in-tests": "Each iteration should be a separate test case for better failure isolation and debugging.",
"raise-specific-error": "Specific exceptions improve error handling and make tests more maintainable.",
"dont-import-test-modules": "Keep test utilities separate from production code to maintain clean architecture."
}
return context_map.get(rule_id, "Additional context analysis would be provided here.")
def _get_firecrawl_examples(rule_id: str, api_key: str) -> str:
"""Placeholder for Firecrawl API integration."""
# This would scrape web for additional examples and best practices
# For demonstration, returning relevant examples based on rule type
examples_map = {
"no-conditionals-in-tests": "See pytest documentation on parameterized tests for multiple scenarios.",
"no-loop-in-tests": "Check testing best practices guides for data-driven testing approaches.",
"raise-specific-error": "Review Python exception hierarchy and custom exception patterns.",
"dont-import-test-modules": "Explore clean architecture patterns for test organization."
}
return examples_map.get(rule_id, "Additional examples and patterns would be sourced from web resources.")
def generate_test_quality_guidance(rule_id: str, content: str, file_path: str) -> str:
"""Generate specific, actionable guidance for test quality rule violations."""
# Extract function name and context for better guidance
function_name = "test_function"
if "def " in content:
# Try to extract the test function name
import re
match = re.search(r'def\s+(\w+)\s*\(', content)
if match:
function_name = match.group(1)
guidance_map = {
"no-conditionals-in-tests": {
"title": "Conditional Logic in Test Function",
"problem": f"Test function '{function_name}' contains conditional statements (if/elif/else).",
"why": "Tests should be simple assertions that verify specific behavior. Conditionals make tests harder to understand and maintain.",
"solutions": [
"✅ Replace conditionals with parameterized test cases",
"✅ Use pytest.mark.parametrize for multiple scenarios",
"✅ Extract conditional logic into helper functions",
"✅ Use assertion libraries like assertpy for complex conditions"
],
"examples": [
"# ❌ Instead of this:",
"def test_user_access():",
" user = create_user()",
" if user.is_admin:",
" assert user.can_access_admin()",
" else:",
" assert not user.can_access_admin()",
"",
"# ✅ Do this:",
"@pytest.mark.parametrize('is_admin,can_access', [",
" (True, True),",
" (False, False)",
"])",
"def test_user_access(is_admin, can_access):",
" user = create_user(admin=is_admin)",
" assert user.can_access_admin() == can_access"
]
},
"no-loop-in-tests": {
"title": "Loop Found in Test Function",
"problem": f"Test function '{function_name}' contains loops (for/while).",
"why": "Loops in tests often indicate testing multiple scenarios that should be separate test cases.",
"solutions": [
"✅ Break loop into individual test cases",
"✅ Use parameterized tests for multiple data scenarios",
"✅ Extract loop logic into data providers or fixtures",
"✅ Use pytest fixtures for setup that requires iteration"
],
"examples": [
"# ❌ Instead of this:",
"def test_process_items():",
" for item in test_items:",
" result = process_item(item)",
" assert result.success",
"",
"# ✅ Do this:",
"@pytest.mark.parametrize('item', test_items)",
"def test_process_item(item):",
" result = process_item(item)",
" assert result.success"
]
},
"raise-specific-error": {
"title": "Generic Exception in Test",
"problem": f"Test function '{function_name}' raises generic Exception or BaseException.",
"why": "Generic exceptions make it harder to identify specific test failures and handle different error conditions appropriately.",
"solutions": [
"✅ Use specific exception types (ValueError, TypeError, etc.)",
"✅ Create custom exception classes for domain-specific errors",
"✅ Use pytest.raises() with specific exception types",
"✅ Test for expected exceptions, not generic ones"
],
"examples": [
"# ❌ Instead of this:",
"def test_invalid_input():",
" with pytest.raises(Exception):",
" process_invalid_data()",
"",
"# ✅ Do this:",
"def test_invalid_input():",
" with pytest.raises(ValueError, match='Invalid input'):",
" process_invalid_data()",
"",
"# Or create custom exceptions:",
"def test_business_logic():",
" with pytest.raises(InsufficientFundsError):",
" account.withdraw(1000)"
]
},
"dont-import-test-modules": {
"title": "Test Module Import in Non-Test Code",
"problem": f"File '{Path(file_path).name}' imports test modules outside of test directories.",
"why": "Test modules should only be imported by other test files. Production code should not depend on test utilities.",
"solutions": [
"✅ Move shared test utilities to a separate utils module",
"✅ Create production versions of test helper functions",
"✅ Use dependency injection instead of direct test imports",
"✅ Extract common logic into a shared production module"
],
"examples": [
"# ❌ Instead of this (in production code):",
"from tests.test_helpers import create_test_data",
"",
"# ✅ Do this:",
"# Create src/utils/test_data_factory.py",
"from src.utils.test_data_factory import create_test_data",
"",
"# Or use fixtures in tests:",
"@pytest.fixture",
"def test_data():",
" return create_production_data()"
]
}
}
guidance = guidance_map.get(rule_id, {
"title": f"Test Quality Issue: {rule_id}",
"problem": f"Rule '{rule_id}' violation detected in test code.",
"why": "This rule helps maintain test quality and best practices.",
"solutions": [
"✅ Review the specific rule requirements",
"✅ Refactor code to follow test best practices",
"✅ Consult testing framework documentation"
],
"examples": [
"# Review the rule documentation for specific guidance",
"# Consider the test's purpose and refactor accordingly"
]
})
# Format the guidance message
message = f"""
🚫 {guidance['title']}
📋 Problem: {guidance['problem']}
❓ Why this matters: {guidance['why']}
🛠️ How to fix it:
{chr(10).join(guidance['solutions'])}
💡 Example:
{chr(10).join(f" {line}" for line in guidance['examples'])}
🔍 File: {Path(file_path).name}
📍 Function: {function_name}
"""
return message.strip()
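# Illustrative usage (hypothetical content; output trimmed):
#
#   guidance = generate_test_quality_guidance(
#       "no-loop-in-tests",
#       "def test_items():\n    for item in items:\n        assert item",
#       "tests/test_items.py",
#   )
#   # guidance starts with "🚫 Loop Found in Test Function" and includes the
#   # problem statement, rationale, suggested fixes, and the parametrize
#   # example from the guidance map above.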
class ToolConfig(TypedDict):
"""Configuration for a type checking tool."""
args: list[str]
error_check: Callable[[subprocess.CompletedProcess[str]], bool]
error_message: str | Callable[[subprocess.CompletedProcess[str]], str]
class DuplicateLocation(TypedDict):
"""Location information for a duplicate code block."""
name: str
lines: str
class Duplicate(TypedDict):
"""Duplicate code detection result."""
similarity: float
description: str
locations: list[DuplicateLocation]
class DuplicateResults(TypedDict):
"""Results from duplicate detection analysis."""
duplicates: list[Duplicate]
class ComplexitySummary(TypedDict):
"""Summary of complexity analysis."""
average_cyclomatic_complexity: float
# Distribution of complexity levels. Functional TypedDict syntax is required
# because the "Very High" key (as read in _check_complexity_issues) contains a
# space and cannot be written as a class attribute.
ComplexityDistribution = TypedDict(
    "ComplexityDistribution",
    {"High": int, "Very High": int, "Extreme": int},
)
class ComplexityResults(TypedDict):
"""Results from complexity analysis."""
summary: ComplexitySummary
distribution: ComplexityDistribution
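# Example of the complexity payload these TypedDicts describe (values are
# illustrative; the actual report is produced by the claude-quality CLI):
#
#   {"summary": {"average_cyclomatic_complexity": 4.2},
#    "distribution": {"High": 1, "Very High": 0, "Extreme": 0}}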
class TypeCheckingResults(TypedDict):
"""Results from type checking analysis."""
issues: list[str]
class AnalysisResults(TypedDict, total=False):
"""Complete analysis results from quality checks."""
internal_duplicates: DuplicateResults
complexity: ComplexityResults
type_checking: TypeCheckingResults
modernization: dict[str, object] # JSON structure varies
# Type aliases for JSON-like structures
JsonObject = dict[str, object]
JsonValue = str | int | float | bool | None | list[object] | JsonObject
@dataclass
class QualityConfig:
"""Configuration for quality checks."""
# Core settings
duplicate_threshold: float = 0.7
duplicate_enabled: bool = True
complexity_threshold: int = 10
complexity_enabled: bool = True
modernization_enabled: bool = True
require_type_hints: bool = True
enforcement_mode: str = "strict" # strict/warn/permissive
# Type checking tools
sourcery_enabled: bool = True
basedpyright_enabled: bool = True
pyrefly_enabled: bool = True
type_check_exit_code: int = 2
# PostToolUse features
state_tracking_enabled: bool = False
cross_file_check_enabled: bool = False
verify_naming: bool = True
show_success: bool = False
# Test quality checks
test_quality_enabled: bool = True
# External context providers
context7_enabled: bool = False
context7_api_key: str = ""
firecrawl_enabled: bool = False
firecrawl_api_key: str = ""
# File patterns
skip_patterns: list[str] | None = None
def __post_init__(self) -> None:
if self.skip_patterns is None:
self.skip_patterns = ["test_", "_test.py", "/tests/", "/fixtures/"]
@classmethod
def from_env(cls) -> "QualityConfig":
"""Load config from environment variables."""
return cls(
duplicate_threshold=float(os.getenv("QUALITY_DUP_THRESHOLD", "0.7")),
duplicate_enabled=os.getenv("QUALITY_DUP_ENABLED", "true").lower()
== "true",
complexity_threshold=int(os.getenv("QUALITY_COMPLEXITY_THRESHOLD", "10")),
complexity_enabled=os.getenv("QUALITY_COMPLEXITY_ENABLED", "true").lower()
== "true",
modernization_enabled=os.getenv("QUALITY_MODERN_ENABLED", "true").lower()
== "true",
require_type_hints=os.getenv("QUALITY_REQUIRE_TYPES", "true").lower()
== "true",
enforcement_mode=os.getenv("QUALITY_ENFORCEMENT", "strict"),
state_tracking_enabled=os.getenv("QUALITY_STATE_TRACKING", "false").lower()
== "true",
cross_file_check_enabled=os.getenv(
"QUALITY_CROSS_FILE_CHECK",
"false",
).lower()
== "true",
verify_naming=os.getenv("QUALITY_VERIFY_NAMING", "true").lower() == "true",
show_success=os.getenv("QUALITY_SHOW_SUCCESS", "false").lower() == "true",
sourcery_enabled=os.getenv("QUALITY_SOURCERY_ENABLED", "true").lower()
== "true",
basedpyright_enabled=os.getenv(
"QUALITY_BASEDPYRIGHT_ENABLED",
"true",
).lower()
== "true",
pyrefly_enabled=os.getenv("QUALITY_PYREFLY_ENABLED", "true").lower()
== "true",
type_check_exit_code=int(os.getenv("QUALITY_TYPE_CHECK_EXIT_CODE", "2")),
test_quality_enabled=os.getenv("QUALITY_TEST_QUALITY_ENABLED", "true").lower()
== "true",
context7_enabled=os.getenv("QUALITY_CONTEXT7_ENABLED", "false").lower()
== "true",
context7_api_key=os.getenv("QUALITY_CONTEXT7_API_KEY", ""),
firecrawl_enabled=os.getenv("QUALITY_FIRECRAWL_ENABLED", "false").lower()
== "true",
firecrawl_api_key=os.getenv("QUALITY_FIRECRAWL_API_KEY", ""),
)
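# Illustrative environment configuration consumed by from_env() above
# (variable names come from this module; the values are examples only):
#
#   QUALITY_ENFORCEMENT=warn
#   QUALITY_DUP_THRESHOLD=0.8
#   QUALITY_COMPLEXITY_THRESHOLD=8
#   QUALITY_TEST_QUALITY_ENABLED=true
#   QUALITY_CONTEXT7_ENABLED=true
#   QUALITY_CONTEXT7_API_KEY=<key>
#   QUALITY_FIRECRAWL_ENABLED=false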
def should_skip_file(file_path: str, config: QualityConfig) -> bool:
"""Check if file should be skipped based on patterns."""
if config.skip_patterns is None:
return False
return any(pattern in file_path for pattern in config.skip_patterns)
def get_claude_quality_command() -> list[str]:
"""Return a path-resilient command for invoking claude-quality."""
repo_root = Path(__file__).resolve().parent.parent
venv_python = repo_root / ".venv/bin/python"
if venv_python.exists():
return [str(venv_python), "-m", "quality.cli.main"]
venv_cli = repo_root / ".venv/bin/claude-quality"
if venv_cli.exists():
return [str(venv_cli)]
return ["claude-quality"]
def _ensure_tool_installed(tool_name: str) -> bool:
"""Ensure a type checking tool is installed in the virtual environment."""
venv_bin = Path(__file__).parent.parent / ".venv/bin"
tool_path = venv_bin / tool_name
if tool_path.exists():
return True
# Try to install using uv if available
try:
result = subprocess.run( # noqa: S603
[str(venv_bin / "uv"), "pip", "install", tool_name],
check=False,
capture_output=True,
text=True,
timeout=60,
)
except (subprocess.TimeoutExpired, OSError):
return False
else:
return result.returncode == 0
def _run_type_checker(
tool_name: str,
file_path: str,
_config: QualityConfig,
) -> tuple[bool, str]:
"""Run a type checking tool and return success status and output."""
venv_bin = Path(__file__).parent.parent / ".venv/bin"
tool_path = venv_bin / tool_name
if not tool_path.exists() and not _ensure_tool_installed(tool_name):
return True, f"Warning: {tool_name} not available"
# Tool configuration mapping
tool_configs: dict[str, ToolConfig] = {
"basedpyright": ToolConfig(
args=["--outputjson", file_path],
error_check=lambda result: result.returncode == 1,
error_message="Type errors found",
),
"pyrefly": ToolConfig(
args=["check", file_path],
error_check=lambda result: result.returncode == 1,
error_message=lambda result: str(result.stdout).strip(),
),
"sourcery": ToolConfig(
args=["review", file_path],
error_check=lambda result: (
"issues detected" in str(result.stdout)
and "0 issues detected" not in str(result.stdout)
),
error_message=lambda result: str(result.stdout).strip(),
),
}
tool_config = tool_configs.get(tool_name)
if not tool_config:
return True, f"Warning: Unknown tool {tool_name}"
try:
cmd = [str(tool_path)] + tool_config["args"]
result = subprocess.run( # noqa: S603
cmd,
check=False,
capture_output=True,
text=True,
timeout=30,
)
# Check for tool-specific errors
error_check = tool_config["error_check"]
if error_check(result):
error_msg = tool_config["error_message"]
message = str(error_msg(result) if callable(error_msg) else error_msg)
return False, message
# Return success or warning
if result.returncode == 0:
return True, ""
except subprocess.TimeoutExpired:
return True, f"Warning: {tool_name} timeout"
except OSError:
return True, f"Warning: {tool_name} execution error"
else:
return True, f"Warning: {tool_name} error (exit {result.returncode})"
def _initialize_analysis() -> tuple[AnalysisResults, list[str]]:
"""Initialize analysis with empty results and claude-quality command."""
results: AnalysisResults = {}
claude_quality_cmd: list[str] = get_claude_quality_command()
return results, claude_quality_cmd
def run_type_checks(file_path: str, config: QualityConfig) -> list[str]:
"""Run all enabled type checking tools and return any issues."""
issues: list[str] = []
# Run Sourcery
if config.sourcery_enabled:
success, output = _run_type_checker("sourcery", file_path, config)
if not success and output:
issues.append(f"Sourcery: {output.strip()}")
# Run BasedPyright
if config.basedpyright_enabled:
success, output = _run_type_checker("basedpyright", file_path, config)
if not success and output:
issues.append(f"BasedPyright: {output.strip()}")
# Run Pyrefly
if config.pyrefly_enabled:
success, output = _run_type_checker("pyrefly", file_path, config)
if not success and output:
issues.append(f"Pyrefly: {output.strip()}")
return issues
def _run_quality_analyses(
content: str,
tmp_path: str,
config: QualityConfig,
enable_type_checks: bool,
) -> AnalysisResults:
"""Run all quality analysis checks and return results."""
results, claude_quality_cmd = _initialize_analysis()
# First check for internal duplicates within the file
if config.duplicate_enabled:
internal_duplicates_raw = detect_internal_duplicates(
content,
threshold=config.duplicate_threshold,
min_lines=4,
)
# Cast after runtime validation - function returns compatible structure
internal_duplicates = cast("DuplicateResults", internal_duplicates_raw)
if internal_duplicates.get("duplicates"):
results["internal_duplicates"] = internal_duplicates
# Run complexity analysis
if config.complexity_enabled:
cmd = [
*claude_quality_cmd,
"complexity",
tmp_path,
"--threshold",
str(config.complexity_threshold),
"--format",
"json",
]
with suppress(subprocess.TimeoutExpired):
result = subprocess.run( # noqa: S603
cmd,
check=False,
capture_output=True,
text=True,
timeout=30,
)
if result.returncode == 0:
with suppress(json.JSONDecodeError):
results["complexity"] = json.loads(result.stdout)
# Run type checking if any tool is enabled
if enable_type_checks and any(
[
config.sourcery_enabled,
config.basedpyright_enabled,
config.pyrefly_enabled,
],
):
try:
if type_issues := run_type_checks(tmp_path, config):
results["type_checking"] = {"issues": type_issues}
except Exception as e: # noqa: BLE001
logging.debug("Type checking failed: %s", e)
# Run modernization analysis
if config.modernization_enabled:
cmd = [
*claude_quality_cmd,
"modernization",
tmp_path,
"--include-type-hints" if config.require_type_hints else "",
"--format",
"json",
]
cmd = [c for c in cmd if c] # Remove empty strings
with suppress(subprocess.TimeoutExpired):
result = subprocess.run( # noqa: S603
cmd,
check=False,
capture_output=True,
text=True,
timeout=30,
)
if result.returncode == 0:
with suppress(json.JSONDecodeError):
results["modernization"] = json.loads(result.stdout)
return results
def analyze_code_quality(
content: str,
file_path: str,
config: QualityConfig,
*,
enable_type_checks: bool = True,
) -> AnalysisResults:
"""Analyze code content using claude-quality toolkit."""
suffix = Path(file_path).suffix or ".py"
with NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as tmp:
tmp.write(content)
tmp_path = tmp.name
try:
return _run_quality_analyses(content, tmp_path, config, enable_type_checks)
finally:
Path(tmp_path).unlink(missing_ok=True)
def _check_internal_duplicates(results: AnalysisResults) -> list[str]:
"""Check for internal duplicate code within the same file."""
issues: list[str] = []
if "internal_duplicates" not in results:
return issues
duplicates: list[Duplicate] = results["internal_duplicates"].get(
"duplicates",
[],
)
for dup in duplicates[:3]: # Show first 3
locations = ", ".join(
f"{loc['name']} ({loc['lines']})" for loc in dup.get("locations", [])
)
issues.append(
f"Internal duplication ({dup.get('similarity', 0):.0%} similar): "
f"{dup.get('description')} - {locations}",
)
return issues
def _check_complexity_issues(
results: AnalysisResults,
config: QualityConfig,
) -> list[str]:
"""Check for code complexity issues."""
issues: list[str] = []
if "complexity" not in results:
return issues
complexity_data = results["complexity"]
summary = complexity_data.get("summary", {})
avg_cc = summary.get("average_cyclomatic_complexity", 0.0)
if avg_cc > config.complexity_threshold:
issues.append(
f"High average complexity: CC={avg_cc:.1f} "
f"(threshold: {config.complexity_threshold})",
)
distribution = complexity_data.get("distribution", {})
high_count: int = (
distribution.get("High", 0)
+ distribution.get("Very High", 0)
+ distribution.get("Extreme", 0)
)
if high_count > 0:
issues.append(f"Found {high_count} function(s) with high complexity")
return issues
def _check_modernization_issues(
results: AnalysisResults,
config: QualityConfig,
) -> list[str]:
"""Check for code modernization issues."""
issues: list[str] = []
if "modernization" not in results:
return issues
try:
modernization_data = results["modernization"]
files = modernization_data.get("files", {})
if not isinstance(files, dict):
return issues
except (AttributeError, TypeError):
return issues
total_issues: int = 0
issue_types: set[str] = set()
files_values = cast("dict[str, object]", files).values()
for file_issues_raw in files_values:
if not isinstance(file_issues_raw, list):
continue
# Cast to proper type after runtime check
file_issues_list = cast("list[object]", file_issues_raw)
total_issues += len(file_issues_list)
for issue_raw in file_issues_list:
if not isinstance(issue_raw, dict):
continue
# Cast to proper type after runtime check
issue_dict = cast("dict[str, object]", issue_raw)
issue_type_raw = issue_dict.get("issue_type", "unknown")
if isinstance(issue_type_raw, str):
issue_types.add(issue_type_raw)
# Only flag if there are non-type-hint issues or many type hint issues
non_type_issues: int = len(
[t for t in issue_types if "type" not in t and "typing" not in t],
)
type_issues: int = total_issues - non_type_issues
if non_type_issues > 0:
non_type_list: list[str] = [
t for t in issue_types if "type" not in t and "typing" not in t
]
issues.append(
f"Modernization needed: {non_type_issues} non-type issues "
f"({', '.join(non_type_list)})",
)
elif config.require_type_hints and type_issues > 10:
issues.append(
f"Many missing type hints: {type_issues} functions/parameters "
"lacking annotations",
)
return issues
def _check_type_checking_issues(results: AnalysisResults) -> list[str]:
"""Check for type checking issues from Sourcery, BasedPyright, and Pyrefly."""
issues: list[str] = []
if "type_checking" not in results:
return issues
with suppress(AttributeError, TypeError):
type_checking_data = results["type_checking"]
type_issues = type_checking_data.get("issues", [])
issues.extend(str(issue_raw) for issue_raw in type_issues[:5])
return issues
def check_code_issues(
results: AnalysisResults,
config: QualityConfig,
) -> tuple[bool, list[str]]:
"""Check analysis results for issues that should block the operation."""
issues: list[str] = []
issues.extend(_check_internal_duplicates(results))
issues.extend(_check_complexity_issues(results, config))
issues.extend(_check_modernization_issues(results, config))
issues.extend(_check_type_checking_issues(results))
return len(issues) > 0, issues
def store_pre_state(file_path: str, content: str) -> None:
"""Store file state before modification for later comparison."""
import tempfile
cache_dir = Path(tempfile.gettempdir()) / ".quality_state"
cache_dir.mkdir(exist_ok=True, mode=0o700)
state: dict[str, str | int] = {
"file_path": file_path,
"timestamp": datetime.now(UTC).isoformat(),
"content_hash": hashlib.sha256(content.encode()).hexdigest(),
"lines": len(content.split("\n")),
"functions": content.count("def "),
"classes": content.count("class "),
}
cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]
cache_file: Path = cache_dir / f"{cache_key}_pre.json"
cache_file.write_text(json.dumps(state, indent=2))
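# Example of a cached pre-state file written above (illustrative values),
# stored at <tempdir>/.quality_state/<hash>_pre.json:
#
#   {"file_path": "src/app.py",
#    "timestamp": "2025-09-29T20:59:32+00:00",
#    "content_hash": "…",
#    "lines": 120, "functions": 6, "classes": 2}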
def check_state_changes(file_path: str) -> list[str]:
"""Check for quality changes between pre and post states."""
import tempfile
issues: list[str] = []
cache_dir: Path = Path(tempfile.gettempdir()) / ".quality_state"
cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]
pre_file: Path = cache_dir / f"{cache_key}_pre.json"
if not pre_file.exists():
return issues
try:
pre_state = json.loads(pre_file.read_text())
try:
current_content = Path(file_path).read_text()
except OSError:
return issues # Can't compare if can't read file
current_lines = len(current_content.split("\n"))
current_functions = current_content.count("def ")
# Check for significant changes
if current_functions < pre_state.get("functions", 0):
issues.append(
f"⚠️ Reduced functions: {pre_state['functions']}{current_functions}",
)
if current_lines > pre_state.get("lines", 0) * 1.5:
issues.append(
"⚠️ File size increased significantly: "
f"{pre_state['lines']}{current_lines} lines",
)
except Exception: # noqa: BLE001
logging.debug("Could not analyze state changes for %s", file_path)
return issues
def check_cross_file_duplicates(file_path: str, config: QualityConfig) -> list[str]:
"""Check for duplicates across project files."""
issues: list[str] = []
# Get project root
project_root: Path = Path(file_path).parent
while (
project_root.parent != project_root
and not (project_root / ".git").exists()
and not (project_root / "pyproject.toml").exists()
):
project_root = project_root.parent
try:
claude_quality_cmd = get_claude_quality_command()
result = subprocess.run( # noqa: S603
[
*claude_quality_cmd,
"duplicates",
str(project_root),
"--threshold",
str(config.duplicate_threshold),
"--format",
"json",
],
check=False,
capture_output=True,
text=True,
timeout=60,
)
if result.returncode == 0:
data = json.loads(result.stdout)
duplicates = data.get("duplicates", [])
if any(file_path in str(d) for d in duplicates):
issues.append("⚠️ Cross-file duplication detected")
except Exception: # noqa: BLE001
logging.debug("Could not check cross-file duplicates for %s", file_path)
return issues
def verify_naming_conventions(file_path: str) -> list[str]:
"""Verify PEP8 naming conventions."""
issues: list[str] = []
try:
content = Path(file_path).read_text()
except OSError:
return issues # Can't check naming if can't read file
# Check function names (should be snake_case)
if bad_funcs := re.findall(
r"def\s+([A-Z][a-zA-Z0-9_]*|[a-z]+[A-Z][a-zA-Z0-9_]*)\s*\(",
content,
):
issues.append(f"⚠️ Non-PEP8 function names: {', '.join(bad_funcs[:3])}")
# Check class names (should be PascalCase)
if bad_classes := re.findall(r"class\s+([a-z][a-z0-9_]*)\s*[\(:]", content):
issues.append(f"⚠️ Non-PEP8 class names: {', '.join(bad_classes[:3])}")
return issues
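# Illustrative matches for the regexes above (hypothetical names):
#   "def ProcessData(" / "def getUser("   -> flagged as non-PEP8 function names
#   "class http_client:"                  -> flagged as a non-PEP8 class name
#   "def process_data(" / "class HttpClient:" pass untouched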
def _detect_any_usage(content: str) -> list[str]:
"""Detect forbidden typing.Any usage in proposed content."""
class _AnyUsageVisitor(ast.NodeVisitor):
"""Collect line numbers where typing.Any is referenced."""
def __init__(self) -> None:
self.lines: set[int] = set()
def visit_Name(self, node: ast.Name) -> None:
if node.id == "Any":
self.lines.add(node.lineno)
self.generic_visit(node)
def visit_Attribute(self, node: ast.Attribute) -> None:
if node.attr == "Any":
self.lines.add(node.lineno)
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
for alias in node.names:
if alias.name == "Any" or alias.asname == "Any":
self.lines.add(node.lineno)
self.generic_visit(node)
def visit_Import(self, node: ast.Import) -> None:
for alias in node.names:
if alias.name == "Any" or alias.asname == "Any":
self.lines.add(node.lineno)
self.generic_visit(node)
lines_with_any: set[int] = set()
try:
tree = ast.parse(content)
except SyntaxError:
for index, line in enumerate(content.splitlines(), start=1):
code_portion = line.split("#", 1)[0]
if re.search(r"\bAny\b", code_portion):
lines_with_any.add(index)
else:
visitor = _AnyUsageVisitor()
visitor.visit(tree)
lines_with_any = visitor.lines
if not lines_with_any:
return []
sorted_lines = sorted(lines_with_any)
display_lines = ", ".join(str(num) for num in sorted_lines[:5])
if len(sorted_lines) > 5:
display_lines += ", …"
return [
"⚠️ Forbidden typing.Any usage at line(s) "
f"{display_lines}; replace with specific types"
]
def _detect_type_ignore_usage(content: str) -> list[str]:
"""Detect forbidden # type: ignore usage in proposed content."""
pattern = re.compile(r"#\s*type:\s*ignore(?:\b|\[)", re.IGNORECASE)
lines_with_type_ignore: set[int] = set()
try:
for token_type, token_string, start, _, _ in tokenize.generate_tokens(
StringIO(content).readline,
):
if token_type == tokenize.COMMENT and pattern.search(token_string):
lines_with_type_ignore.add(start[0])
except tokenize.TokenError:
for index, line in enumerate(content.splitlines(), start=1):
if pattern.search(line):
lines_with_type_ignore.add(index)
if not lines_with_type_ignore:
return []
sorted_lines = sorted(lines_with_type_ignore)
display_lines = ", ".join(str(num) for num in sorted_lines[:5])
if len(sorted_lines) > 5:
display_lines += ", …"
return [
"⚠️ Forbidden # type: ignore usage at line(s) "
f"{display_lines}; remove the suppression and fix typing issues instead"
]
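# Examples of comments the pattern above flags (illustrative):
#   x = load()       # type: ignore
#   y = parse(data)  # type: ignore[arg-type]
# Both are reported with their line numbers and rejected by the pre-check.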
def _perform_quality_check(
file_path: str,
content: str,
config: QualityConfig,
enable_type_checks: bool = True,
) -> tuple[bool, list[str]]:
"""Perform quality analysis and return issues."""
# Store state if tracking enabled
if config.state_tracking_enabled:
store_pre_state(file_path, content)
# Run quality analysis
results = analyze_code_quality(
content,
file_path,
config,
enable_type_checks=enable_type_checks,
)
return check_code_issues(results, config)
def _handle_quality_issues(
file_path: str,
issues: list[str],
config: QualityConfig,
*,
forced_permission: str | None = None,
) -> JsonObject:
"""Handle quality issues based on enforcement mode."""
# Prepare denial message
message = (
f"Code quality check failed for {Path(file_path).name}:\n"
+ "\n".join(f"{issue}" for issue in issues)
+ "\n\nFix these issues before writing the code."
)
# Make decision based on enforcement mode
if forced_permission:
return _create_hook_response("PreToolUse", forced_permission, message)
if config.enforcement_mode == "strict":
return _create_hook_response("PreToolUse", "deny", message)
if config.enforcement_mode == "warn":
return _create_hook_response("PreToolUse", "ask", message)
# permissive
warning_message = f"⚠️ Quality Warning:\n{message}"
return _create_hook_response(
"PreToolUse",
"allow",
warning_message,
warning_message,
)
def _exit_with_reason(reason: str, exit_code: int = 2) -> None:
"""Write reason to stderr and exit with specified code."""
sys.stderr.write(reason)
sys.exit(exit_code)
def _create_hook_response(
event_name: str,
permission: str = "",
reason: str = "",
system_message: str = "",
additional_context: str = "",
*,
decision: str | None = None,
) -> JsonObject:
"""Create standardized hook response."""
hook_output: dict[str, object] = {
"hookEventName": event_name,
}
if permission:
hook_output["permissionDecision"] = permission
if reason:
hook_output["permissionDecisionReason"] = reason
if additional_context:
hook_output["additionalContext"] = additional_context
response: JsonObject = {
"hookSpecificOutput": hook_output,
}
if permission:
response["permissionDecision"] = permission
if decision:
response["decision"] = decision
if reason:
response["reason"] = reason
if system_message:
response["systemMessage"] = system_message
return response
def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject:
"""Handle PreToolUse hook - analyze content before write/edit."""
tool_name = str(hook_data.get("tool_name", ""))
tool_input_raw = hook_data.get("tool_input", {})
if not isinstance(tool_input_raw, dict):
return _create_hook_response("PreToolUse", "allow")
tool_input = cast("dict[str, object]", tool_input_raw)
# Only analyze for write/edit tools
if tool_name not in ["Write", "Edit", "MultiEdit"]:
return _create_hook_response("PreToolUse", "allow")
# Extract content based on tool type
file_path = str(tool_input.get("file_path", ""))
content = ""
if tool_name == "Write":
raw_content = tool_input.get("content", "")
content = "" if raw_content is None else str(raw_content)
elif tool_name == "Edit":
new_string = tool_input.get("new_string", "")
content = "" if new_string is None else str(new_string)
elif tool_name == "MultiEdit":
edits = tool_input.get("edits", [])
if isinstance(edits, list):
edits_list = cast("list[object]", edits)
parts: list[str] = []
for edit in edits_list:
if not isinstance(edit, dict):
continue
edit_dict = cast("dict[str, object]", edit)
new_str = edit_dict.get("new_string")
parts.append("") if new_str is None else parts.append(str(new_str))
content = "\n".join(parts)
# Only analyze Python files
if not file_path or not file_path.endswith(".py") or not content:
return _create_hook_response("PreToolUse", "allow")
# Check if this is a test file and test quality checks are enabled
is_test = is_test_file(file_path)
run_test_checks = config.test_quality_enabled and is_test
# Skip analysis for configured patterns, but not if it's a test file with test checks enabled
if should_skip_file(file_path, config) and not run_test_checks:
return _create_hook_response("PreToolUse", "allow")
enable_type_checks = tool_name == "Write"
any_usage_issues = _detect_any_usage(content)
type_ignore_issues = _detect_type_ignore_usage(content)
precheck_issues = any_usage_issues + type_ignore_issues
# Run test quality checks if enabled and file is a test file
if run_test_checks:
test_quality_issues = run_test_quality_checks(content, file_path, config)
precheck_issues.extend(test_quality_issues)
try:
_has_issues, issues = _perform_quality_check(
file_path,
content,
config,
enable_type_checks=enable_type_checks,
)
all_issues = precheck_issues + issues
if not all_issues:
return _create_hook_response("PreToolUse", "allow")
if precheck_issues:
return _handle_quality_issues(
file_path,
all_issues,
config,
forced_permission="deny",
)
return _handle_quality_issues(file_path, all_issues, config)
except Exception as e: # noqa: BLE001
if precheck_issues:
return _handle_quality_issues(
file_path,
precheck_issues,
config,
forced_permission="deny",
)
return _create_hook_response(
"PreToolUse",
"allow",
f"Warning: Code quality check failed with error: {e}",
f"Warning: Code quality check failed with error: {e}",
)
def posttooluse_hook(
hook_data: JsonObject,
config: QualityConfig,
) -> JsonObject:
"""Handle PostToolUse hook - verify quality after write/edit."""
tool_name: str = str(hook_data.get("tool_name", ""))
tool_output = hook_data.get("tool_response", hook_data.get("tool_output", {}))
# Only process write/edit tools
if tool_name not in ["Write", "Edit", "MultiEdit"]:
return _create_hook_response("PostToolUse")
# Extract file path from output
file_path: str = ""
if isinstance(tool_output, dict):
tool_output_dict = cast("dict[str, object]", tool_output)
file_path = str(
tool_output_dict.get("file_path", "") or tool_output_dict.get("path", ""),
)
elif isinstance(tool_output, str) and (
match := re.search(r"([/\w\-_.]+\.py)", tool_output)
):
file_path = match[1]
if not file_path or not file_path.endswith(".py"):
return _create_hook_response("PostToolUse")
if not Path(file_path).exists():
return _create_hook_response("PostToolUse")
issues: list[str] = []
# Check state changes if tracking enabled
if config.state_tracking_enabled:
delta_issues = check_state_changes(file_path)
issues.extend(delta_issues)
# Run cross-file duplicate detection if enabled
if config.cross_file_check_enabled:
cross_file_issues = check_cross_file_duplicates(file_path, config)
issues.extend(cross_file_issues)
# Verify naming conventions if enabled
if config.verify_naming:
naming_issues = verify_naming_conventions(file_path)
issues.extend(naming_issues)
# Format response
if issues:
message = (
f"📝 Post-write quality notes for {Path(file_path).name}:\n"
+ "\n".join(issues)
)
return _create_hook_response(
"PostToolUse",
"",
message,
message,
message,
decision="block",
)
if config.show_success:
message = f"{Path(file_path).name} passed post-write verification"
return _create_hook_response(
"PostToolUse",
"",
"",
message,
"",
decision="approve",
)
return _create_hook_response("PostToolUse")
def is_test_file(file_path: str) -> bool:
"""Check if file path is in a test directory."""
path_parts = Path(file_path).parts
return any(part in ("test", "tests", "testing") for part in path_parts)
def run_test_quality_checks(content: str, file_path: str, config: QualityConfig) -> list[str]:
"""Run Sourcery with specific test-related rules and return issues with enhanced guidance."""
issues: list[str] = []
# Only run test quality checks for test files
if not is_test_file(file_path):
return issues
suffix = Path(file_path).suffix or ".py"
with NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as tmp:
tmp.write(content)
tmp_path = tmp.name
try:
# Run Sourcery with specific test-related rules
venv_bin = Path(__file__).parent.parent / ".venv/bin"
sourcery_path = venv_bin / "sourcery"
if not sourcery_path.exists():
# Try to find sourcery in PATH
import shutil
sourcery_path = shutil.which("sourcery") or str(venv_bin / "sourcery")
if not sourcery_path or not Path(sourcery_path).exists():
logging.debug("Sourcery not found at %s", sourcery_path)
return issues
# Specific rules for test quality - use correct Sourcery format
test_rules = [
"no-conditionals-in-tests",
"no-loop-in-tests",
"raise-specific-error",
"dont-import-test-modules",
]
cmd = [
sourcery_path,
"review",
tmp_path,
"--rules",
",".join(test_rules),
"--format",
"json",
]
logging.debug("Running Sourcery command: %s", " ".join(cmd))
result = subprocess.run( # noqa: S603
cmd,
check=False,
capture_output=True,
text=True,
timeout=30,
)
logging.debug("Sourcery exit code: %s", result.returncode)
logging.debug("Sourcery stdout: %s", result.stdout)
logging.debug("Sourcery stderr: %s", result.stderr)
if result.returncode == 0:
try:
sourcery_output = json.loads(result.stdout)
# Extract issues from Sourcery output - handle different JSON formats
if "files" in sourcery_output:
for file_issues in sourcery_output["files"].values():
if isinstance(file_issues, list):
for issue in file_issues:
if isinstance(issue, dict):
rule_id = issue.get("rule", "unknown")
# Generate enhanced guidance for each violation
base_guidance = generate_test_quality_guidance(rule_id, content, file_path)
# Add external context if available
external_context = get_external_context(rule_id, content, file_path, config)
if external_context:
base_guidance += f"\n\n{external_context}"
issues.append(base_guidance)
elif "violations" in sourcery_output:
# Alternative format
for violation in sourcery_output["violations"]:
if isinstance(violation, dict):
rule_id = violation.get("rule", "unknown")
base_guidance = generate_test_quality_guidance(rule_id, content, file_path)
# Add external context if available
external_context = get_external_context(rule_id, content, file_path, config)
if external_context:
base_guidance += f"\n\n{external_context}"
issues.append(base_guidance)
elif isinstance(sourcery_output, list):
# Direct list of issues
for issue in sourcery_output:
if isinstance(issue, dict):
rule_id = issue.get("rule", "unknown")
base_guidance = generate_test_quality_guidance(rule_id, content, file_path)
# Add external context if available
external_context = get_external_context(rule_id, content, file_path, config)
if external_context:
base_guidance += f"\n\n{external_context}"
issues.append(base_guidance)
except json.JSONDecodeError as e:
logging.debug("Failed to parse Sourcery JSON output: %s", e)
# If JSON parsing fails, provide general guidance with external context
base_guidance = generate_test_quality_guidance("unknown", content, file_path)
external_context = get_external_context("unknown", content, file_path, config)
if external_context:
base_guidance += f"\n\n{external_context}"
issues.append(base_guidance)
        elif result.stdout.strip() or result.stderr.strip():
            # Sourcery reported issues or errors - provide general guidance
            # (result.returncode != 0 is already implied by this elif branch)
            logging.debug("Sourcery reported: %s", (result.stdout + " " + result.stderr).strip())
            base_guidance = generate_test_quality_guidance("sourcery-error", content, file_path)
external_context = get_external_context("sourcery-error", content, file_path, config)
if external_context:
base_guidance += f"\n\n{external_context}"
issues.append(base_guidance)
except (subprocess.TimeoutExpired, OSError, json.JSONDecodeError) as e:
# If Sourcery fails, don't block the operation
logging.debug("Test quality check failed for %s: %s", file_path, e)
finally:
Path(tmp_path).unlink(missing_ok=True)
return issues
def main() -> None:
"""Main hook entry point."""
try:
# Load configuration
config = QualityConfig.from_env()
# Read hook input from stdin
try:
hook_data: JsonObject = json.load(sys.stdin)
except json.JSONDecodeError:
fallback_response: JsonObject = {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "allow",
},
}
sys.stdout.write(json.dumps(fallback_response))
return
# Detect hook type: tool_response=PostToolUse, tool_input=PreToolUse
response: JsonObject
if "tool_response" in hook_data or "tool_output" in hook_data:
# PostToolUse hook
response = posttooluse_hook(hook_data, config)
else:
# PreToolUse hook
response = pretooluse_hook(hook_data, config)
print(json.dumps(response)) # noqa: T201
# Handle exit codes based on hook output
hook_output_raw = response.get("hookSpecificOutput", {})
if not isinstance(hook_output_raw, dict):
return
hook_output = cast("dict[str, object]", hook_output_raw)
permission_decision = hook_output.get("permissionDecision")
if permission_decision == "deny":
# Exit code 2: Blocking error - stderr fed back to Claude
reason = str(
hook_output.get("permissionDecisionReason", "Permission denied"),
)
_exit_with_reason(reason)
elif permission_decision == "ask":
# Also use exit code 2 for ask decisions to ensure Claude sees the message
reason = str(
hook_output.get("permissionDecisionReason", "Permission request"),
)
_exit_with_reason(reason)
# Exit code 0: Success (default)
except Exception as e: # noqa: BLE001
# Unexpected error - use exit code 1 (non-blocking error)
sys.stderr.write(f"Hook error: {e}")
sys.exit(1)
if __name__ == "__main__":
main()
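# Manual smoke test (illustrative): feed a PreToolUse payload on stdin and
# inspect the JSON decision on stdout.
#
#   printf '%s' '{"tool_name": "Write", "tool_input": {"file_path": "demo.py", "content": "from typing import Any\nx: Any = 1\n"}}' \
#     | python claude-scripts/hooks/code_quality_guard.py
#
# The Any usage trips _detect_any_usage(), so the hook emits a deny decision
# and exits with code 2 (see main()).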