- Added a new QualityConfig class to manage test quality check configurations.
- Implemented test quality checks for specific rules in test files, including prevention of conditionals, loops, and generic exceptions.
- Integrated external context providers (Context7 and Firecrawl) for additional guidance on test quality violations.
- Enhanced error messaging to provide detailed, actionable guidance for detected issues.
- Updated README_HOOKS.md to document new test quality features and configuration options.
- Added unit tests to verify the functionality of test quality checks and their integration with the pretooluse_hook.
"""Unified quality hook for Claude Code supporting both PreToolUse and PostToolUse.
|
|
|
|
Prevents writing duplicate, complex, or non-modernized code and verifies quality
|
|
after writes.
|
|
"""
|
|
|
|
import ast
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
import tokenize
|
|
from collections.abc import Callable
|
|
from contextlib import suppress
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from io import StringIO
|
|
from pathlib import Path
|
|
from tempfile import NamedTemporaryFile
|
|
from typing import TypedDict, cast
|
|
|
|
# Import internal duplicate detector
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
from internal_duplicate_detector import detect_internal_duplicates
|
|
|
|
|
|
class QualityConfig:
|
|
"""Configuration for quality checks."""
|
|
|
|
def __init__(self):
|
|
self._config = config = QualityConfig.from_env()
|
|
config.skip_patterns = ["test_", "_test.py", "/tests/", "/fixtures/"]
|
|
|
|
def __getattr__(self, name):
|
|
return getattr(self._config, name, self._config[name])
|
|
|
|
@classmethod
|
|
def from_env(cls) -> "QualityConfig":
|
|
"""Load config from environment variables."""
|
|
return cls()
|
|
|
|
|
|


def get_external_context(rule_id: str, content: str, file_path: str, config: QualityConfig) -> str:
    """Get additional context from external APIs for enhanced guidance."""
    context_parts = []

    # Context7 integration for additional context analysis
    if config.context7_enabled and config.context7_api_key:
        try:
            # Note: This would integrate with actual Context7 API
            # For now, providing placeholder for the integration
            context7_context = _get_context7_analysis(rule_id, content, config.context7_api_key)
            if context7_context:
                context_parts.append(f"📊 Context7 Analysis: {context7_context}")
        except Exception as e:
            logging.debug("Context7 API call failed: %s", e)

    # Firecrawl integration for web scraping additional examples
    if config.firecrawl_enabled and config.firecrawl_api_key:
        try:
            # Note: This would integrate with actual Firecrawl API
            # For now, providing placeholder for the integration
            firecrawl_examples = _get_firecrawl_examples(rule_id, config.firecrawl_api_key)
            if firecrawl_examples:
                context_parts.append(f"🔗 Additional Examples: {firecrawl_examples}")
        except Exception as e:
            logging.debug("Firecrawl API call failed: %s", e)

    return "\n\n".join(context_parts) if context_parts else ""


def _get_context7_analysis(rule_id: str, content: str, api_key: str) -> str:
    """Placeholder for Context7 API integration."""
    # This would make actual API calls to Context7
    # For demonstration, returning contextual analysis based on rule type
    context_map = {
        "no-conditionals-in-tests": "Consider using data-driven tests with pytest.mark.parametrize for better test isolation.",
        "no-loop-in-tests": "Each iteration should be a separate test case for better failure isolation and debugging.",
        "raise-specific-error": "Specific exceptions improve error handling and make tests more maintainable.",
        "dont-import-test-modules": "Keep test utilities separate from production code to maintain clean architecture.",
    }
    return context_map.get(rule_id, "Additional context analysis would be provided here.")


def _get_firecrawl_examples(rule_id: str, api_key: str) -> str:
    """Placeholder for Firecrawl API integration."""
    # This would scrape web for additional examples and best practices
    # For demonstration, returning relevant examples based on rule type
    examples_map = {
        "no-conditionals-in-tests": "See pytest documentation on parameterized tests for multiple scenarios.",
        "no-loop-in-tests": "Check testing best practices guides for data-driven testing approaches.",
        "raise-specific-error": "Review Python exception hierarchy and custom exception patterns.",
        "dont-import-test-modules": "Explore clean architecture patterns for test organization.",
    }
    return examples_map.get(rule_id, "Additional examples and patterns would be sourced from web resources.")


def generate_test_quality_guidance(rule_id: str, content: str, file_path: str) -> str:
    """Generate specific, actionable guidance for test quality rule violations."""
    # Extract function name and context for better guidance
    function_name = "test_function"
    if "def " in content:
        # Try to extract the test function name
        match = re.search(r"def\s+(\w+)\s*\(", content)
        if match:
            function_name = match.group(1)

    guidance_map = {
        "no-conditionals-in-tests": {
            "title": "Conditional Logic in Test Function",
            "problem": f"Test function '{function_name}' contains conditional statements (if/elif/else).",
            "why": "Tests should be simple assertions that verify specific behavior. Conditionals make tests harder to understand and maintain.",
            "solutions": [
                "✅ Replace conditionals with parameterized test cases",
                "✅ Use pytest.mark.parametrize for multiple scenarios",
                "✅ Extract conditional logic into helper functions",
                "✅ Use assertion libraries like assertpy for complex conditions",
            ],
            "examples": [
                "# ❌ Instead of this:",
                "def test_user_access():",
                "    user = create_user()",
                "    if user.is_admin:",
                "        assert user.can_access_admin()",
                "    else:",
                "        assert not user.can_access_admin()",
                "",
                "# ✅ Do this:",
                "@pytest.mark.parametrize('is_admin,can_access', [",
                "    (True, True),",
                "    (False, False)",
                "])",
                "def test_user_access(is_admin, can_access):",
                "    user = create_user(admin=is_admin)",
                "    assert user.can_access_admin() == can_access",
            ],
        },
        "no-loop-in-tests": {
            "title": "Loop Found in Test Function",
            "problem": f"Test function '{function_name}' contains loops (for/while).",
            "why": "Loops in tests often indicate testing multiple scenarios that should be separate test cases.",
            "solutions": [
                "✅ Break loop into individual test cases",
                "✅ Use parameterized tests for multiple data scenarios",
                "✅ Extract loop logic into data providers or fixtures",
                "✅ Use pytest fixtures for setup that requires iteration",
            ],
            "examples": [
                "# ❌ Instead of this:",
                "def test_process_items():",
                "    for item in test_items:",
                "        result = process_item(item)",
                "        assert result.success",
                "",
                "# ✅ Do this:",
                "@pytest.mark.parametrize('item', test_items)",
                "def test_process_item(item):",
                "    result = process_item(item)",
                "    assert result.success",
            ],
        },
        "raise-specific-error": {
            "title": "Generic Exception in Test",
            "problem": f"Test function '{function_name}' raises generic Exception or BaseException.",
            "why": "Generic exceptions make it harder to identify specific test failures and handle different error conditions appropriately.",
            "solutions": [
                "✅ Use specific exception types (ValueError, TypeError, etc.)",
                "✅ Create custom exception classes for domain-specific errors",
                "✅ Use pytest.raises() with specific exception types",
                "✅ Test for expected exceptions, not generic ones",
            ],
            "examples": [
                "# ❌ Instead of this:",
                "def test_invalid_input():",
                "    with pytest.raises(Exception):",
                "        process_invalid_data()",
                "",
                "# ✅ Do this:",
                "def test_invalid_input():",
                "    with pytest.raises(ValueError, match='Invalid input'):",
                "        process_invalid_data()",
                "",
                "# Or create custom exceptions:",
                "def test_business_logic():",
                "    with pytest.raises(InsufficientFundsError):",
                "        account.withdraw(1000)",
            ],
        },
        "dont-import-test-modules": {
            "title": "Test Module Import in Non-Test Code",
            "problem": f"File '{Path(file_path).name}' imports test modules outside of test directories.",
            "why": "Test modules should only be imported by other test files. Production code should not depend on test utilities.",
            "solutions": [
                "✅ Move shared test utilities to a separate utils module",
                "✅ Create production versions of test helper functions",
                "✅ Use dependency injection instead of direct test imports",
                "✅ Extract common logic into a shared production module",
            ],
            "examples": [
                "# ❌ Instead of this (in production code):",
                "from tests.test_helpers import create_test_data",
                "",
                "# ✅ Do this:",
                "# Create src/utils/test_data_factory.py",
                "from src.utils.test_data_factory import create_test_data",
                "",
                "# Or use fixtures in tests:",
                "@pytest.fixture",
                "def test_data():",
                "    return create_production_data()",
            ],
        },
    }

    guidance = guidance_map.get(rule_id, {
        "title": f"Test Quality Issue: {rule_id}",
        "problem": f"Rule '{rule_id}' violation detected in test code.",
        "why": "This rule helps maintain test quality and best practices.",
        "solutions": [
            "✅ Review the specific rule requirements",
            "✅ Refactor code to follow test best practices",
            "✅ Consult testing framework documentation",
        ],
        "examples": [
            "# Review the rule documentation for specific guidance",
            "# Consider the test's purpose and refactor accordingly",
        ],
    })

    # Format the guidance message
    message = f"""
🚫 {guidance['title']}

📋 Problem: {guidance['problem']}

❓ Why this matters: {guidance['why']}

🛠️ How to fix it:
{chr(10).join(f" • {solution}" for solution in guidance['solutions'])}

💡 Example:
{chr(10).join(f" {line}" for line in guidance['examples'])}

🔍 File: {Path(file_path).name}
📍 Function: {function_name}
"""

    return message.strip()


class ToolConfig(TypedDict):
    """Configuration for a type checking tool."""

    args: list[str]
    error_check: Callable[[subprocess.CompletedProcess[str]], bool]
    error_message: str | Callable[[subprocess.CompletedProcess[str]], str]


class DuplicateLocation(TypedDict):
    """Location information for a duplicate code block."""

    name: str
    lines: str


class Duplicate(TypedDict):
    """Duplicate code detection result."""

    similarity: float
    description: str
    locations: list[DuplicateLocation]


class DuplicateResults(TypedDict):
    """Results from duplicate detection analysis."""

    duplicates: list[Duplicate]


class ComplexitySummary(TypedDict):
    """Summary of complexity analysis."""

    average_cyclomatic_complexity: float


# Distribution of complexity levels. Functional TypedDict syntax is used because
# the "Very High" key (as read in _check_complexity_issues) contains a space.
ComplexityDistribution = TypedDict(
    "ComplexityDistribution",
    {"High": int, "Very High": int, "Extreme": int},
)


class ComplexityResults(TypedDict):
    """Results from complexity analysis."""

    summary: ComplexitySummary
    distribution: ComplexityDistribution


class TypeCheckingResults(TypedDict):
    """Results from type checking analysis."""

    issues: list[str]


class AnalysisResults(TypedDict, total=False):
    """Complete analysis results from quality checks."""

    internal_duplicates: DuplicateResults
    complexity: ComplexityResults
    type_checking: TypeCheckingResults
    modernization: dict[str, object]  # JSON structure varies


# Type aliases for JSON-like structures
JsonObject = dict[str, object]
JsonValue = str | int | float | bool | None | list[object] | JsonObject


@dataclass
class QualityConfig:
    """Configuration for quality checks."""

    # Core settings
    duplicate_threshold: float = 0.7
    duplicate_enabled: bool = True
    complexity_threshold: int = 10
    complexity_enabled: bool = True
    modernization_enabled: bool = True
    require_type_hints: bool = True
    enforcement_mode: str = "strict"  # strict/warn/permissive

    # Type checking tools
    sourcery_enabled: bool = True
    basedpyright_enabled: bool = True
    pyrefly_enabled: bool = True
    type_check_exit_code: int = 2

    # PostToolUse features
    state_tracking_enabled: bool = False
    cross_file_check_enabled: bool = False
    verify_naming: bool = True
    show_success: bool = False

    # Test quality checks
    test_quality_enabled: bool = True

    # External context providers
    context7_enabled: bool = False
    context7_api_key: str = ""
    firecrawl_enabled: bool = False
    firecrawl_api_key: str = ""

    # File patterns
    skip_patterns: list[str] | None = None

    def __post_init__(self) -> None:
        if self.skip_patterns is None:
            self.skip_patterns = ["test_", "_test.py", "/tests/", "/fixtures/"]

    @classmethod
    def from_env(cls) -> "QualityConfig":
        """Load config from environment variables."""
        return cls(
            duplicate_threshold=float(os.getenv("QUALITY_DUP_THRESHOLD", "0.7")),
            duplicate_enabled=os.getenv("QUALITY_DUP_ENABLED", "true").lower() == "true",
            complexity_threshold=int(os.getenv("QUALITY_COMPLEXITY_THRESHOLD", "10")),
            complexity_enabled=os.getenv("QUALITY_COMPLEXITY_ENABLED", "true").lower() == "true",
            modernization_enabled=os.getenv("QUALITY_MODERN_ENABLED", "true").lower() == "true",
            require_type_hints=os.getenv("QUALITY_REQUIRE_TYPES", "true").lower() == "true",
            enforcement_mode=os.getenv("QUALITY_ENFORCEMENT", "strict"),
            state_tracking_enabled=os.getenv("QUALITY_STATE_TRACKING", "false").lower() == "true",
            cross_file_check_enabled=os.getenv("QUALITY_CROSS_FILE_CHECK", "false").lower() == "true",
            verify_naming=os.getenv("QUALITY_VERIFY_NAMING", "true").lower() == "true",
            show_success=os.getenv("QUALITY_SHOW_SUCCESS", "false").lower() == "true",
            sourcery_enabled=os.getenv("QUALITY_SOURCERY_ENABLED", "true").lower() == "true",
            basedpyright_enabled=os.getenv("QUALITY_BASEDPYRIGHT_ENABLED", "true").lower() == "true",
            pyrefly_enabled=os.getenv("QUALITY_PYREFLY_ENABLED", "true").lower() == "true",
            type_check_exit_code=int(os.getenv("QUALITY_TYPE_CHECK_EXIT_CODE", "2")),
            test_quality_enabled=os.getenv("QUALITY_TEST_QUALITY_ENABLED", "true").lower() == "true",
            context7_enabled=os.getenv("QUALITY_CONTEXT7_ENABLED", "false").lower() == "true",
            context7_api_key=os.getenv("QUALITY_CONTEXT7_API_KEY", ""),
            firecrawl_enabled=os.getenv("QUALITY_FIRECRAWL_ENABLED", "false").lower() == "true",
            firecrawl_api_key=os.getenv("QUALITY_FIRECRAWL_API_KEY", ""),
        )
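# Example shell configuration for this hook (illustrative values only; the
# variable names match the os.getenv() lookups above):
#   export QUALITY_DUP_THRESHOLD=0.8
#   export QUALITY_ENFORCEMENT=warn            # strict / warn / permissive
#   export QUALITY_TEST_QUALITY_ENABLED=true
#   export QUALITY_CONTEXT7_ENABLED=false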


def should_skip_file(file_path: str, config: QualityConfig) -> bool:
    """Check if file should be skipped based on patterns."""
    if config.skip_patterns is None:
        return False
    return any(pattern in file_path for pattern in config.skip_patterns)
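# For example, with the default skip_patterns above,
# should_skip_file("pkg/tests/test_api.py", QualityConfig()) is True, while
# should_skip_file("pkg/api.py", QualityConfig()) is False.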


def get_claude_quality_command() -> list[str]:
    """Return a path-resilient command for invoking claude-quality."""
    repo_root = Path(__file__).resolve().parent.parent
    venv_python = repo_root / ".venv/bin/python"
    if venv_python.exists():
        return [str(venv_python), "-m", "quality.cli.main"]

    venv_cli = repo_root / ".venv/bin/claude-quality"
    if venv_cli.exists():
        return [str(venv_cli)]

    return ["claude-quality"]


def _ensure_tool_installed(tool_name: str) -> bool:
    """Ensure a type checking tool is installed in the virtual environment."""
    venv_bin = Path(__file__).parent.parent / ".venv/bin"
    tool_path = venv_bin / tool_name

    if tool_path.exists():
        return True

    # Try to install using uv if available
    try:
        result = subprocess.run(  # noqa: S603
            [str(venv_bin / "uv"), "pip", "install", tool_name],
            check=False,
            capture_output=True,
            text=True,
            timeout=60,
        )
    except (subprocess.TimeoutExpired, OSError):
        return False
    else:
        return result.returncode == 0


def _run_type_checker(
    tool_name: str,
    file_path: str,
    _config: QualityConfig,
) -> tuple[bool, str]:
    """Run a type checking tool and return success status and output."""
    venv_bin = Path(__file__).parent.parent / ".venv/bin"
    tool_path = venv_bin / tool_name

    if not tool_path.exists() and not _ensure_tool_installed(tool_name):
        return True, f"Warning: {tool_name} not available"

    # Tool configuration mapping
    tool_configs: dict[str, ToolConfig] = {
        "basedpyright": ToolConfig(
            args=["--outputjson", file_path],
            error_check=lambda result: result.returncode == 1,
            error_message="Type errors found",
        ),
        "pyrefly": ToolConfig(
            args=["check", file_path],
            error_check=lambda result: result.returncode == 1,
            error_message=lambda result: str(result.stdout).strip(),
        ),
        "sourcery": ToolConfig(
            args=["review", file_path],
            error_check=lambda result: (
                "issues detected" in str(result.stdout)
                and "0 issues detected" not in str(result.stdout)
            ),
            error_message=lambda result: str(result.stdout).strip(),
        ),
    }

    tool_config = tool_configs.get(tool_name)
    if not tool_config:
        return True, f"Warning: Unknown tool {tool_name}"

    try:
        cmd = [str(tool_path)] + tool_config["args"]
        result = subprocess.run(  # noqa: S603
            cmd,
            check=False,
            capture_output=True,
            text=True,
            timeout=30,
        )

        # Check for tool-specific errors
        error_check = tool_config["error_check"]
        if error_check(result):
            error_msg = tool_config["error_message"]
            message = str(error_msg(result) if callable(error_msg) else error_msg)
            return False, message

        # Return success or warning
        if result.returncode == 0:
            return True, ""

    except subprocess.TimeoutExpired:
        return True, f"Warning: {tool_name} timeout"
    except OSError:
        return True, f"Warning: {tool_name} execution error"
    else:
        return True, f"Warning: {tool_name} error (exit {result.returncode})"


def _initialize_analysis() -> tuple[AnalysisResults, list[str]]:
    """Initialize analysis with empty results and claude-quality command."""
    results: AnalysisResults = {}
    claude_quality_cmd: list[str] = get_claude_quality_command()
    return results, claude_quality_cmd


def run_type_checks(file_path: str, config: QualityConfig) -> list[str]:
    """Run all enabled type checking tools and return any issues."""
    issues: list[str] = []

    # Run Sourcery
    if config.sourcery_enabled:
        success, output = _run_type_checker("sourcery", file_path, config)
        if not success and output:
            issues.append(f"Sourcery: {output.strip()}")

    # Run BasedPyright
    if config.basedpyright_enabled:
        success, output = _run_type_checker("basedpyright", file_path, config)
        if not success and output:
            issues.append(f"BasedPyright: {output.strip()}")

    # Run Pyrefly
    if config.pyrefly_enabled:
        success, output = _run_type_checker("pyrefly", file_path, config)
        if not success and output:
            issues.append(f"Pyrefly: {output.strip()}")

    return issues


def _run_quality_analyses(
    content: str,
    tmp_path: str,
    config: QualityConfig,
    enable_type_checks: bool,
) -> AnalysisResults:
    """Run all quality analysis checks and return results."""
    results, claude_quality_cmd = _initialize_analysis()

    # First check for internal duplicates within the file
    if config.duplicate_enabled:
        internal_duplicates_raw = detect_internal_duplicates(
            content,
            threshold=config.duplicate_threshold,
            min_lines=4,
        )
        # Cast after runtime validation - function returns compatible structure
        internal_duplicates = cast("DuplicateResults", internal_duplicates_raw)
        if internal_duplicates.get("duplicates"):
            results["internal_duplicates"] = internal_duplicates

    # Run complexity analysis
    if config.complexity_enabled:
        cmd = [
            *claude_quality_cmd,
            "complexity",
            tmp_path,
            "--threshold",
            str(config.complexity_threshold),
            "--format",
            "json",
        ]
        with suppress(subprocess.TimeoutExpired):
            result = subprocess.run(  # noqa: S603
                cmd,
                check=False,
                capture_output=True,
                text=True,
                timeout=30,
            )
            if result.returncode == 0:
                with suppress(json.JSONDecodeError):
                    results["complexity"] = json.loads(result.stdout)

    # Run type checking if any tool is enabled
    if enable_type_checks and any(
        [
            config.sourcery_enabled,
            config.basedpyright_enabled,
            config.pyrefly_enabled,
        ],
    ):
        try:
            if type_issues := run_type_checks(tmp_path, config):
                results["type_checking"] = {"issues": type_issues}
        except Exception as e:  # noqa: BLE001
            logging.debug("Type checking failed: %s", e)

    # Run modernization analysis
    if config.modernization_enabled:
        cmd = [
            *claude_quality_cmd,
            "modernization",
            tmp_path,
            "--include-type-hints" if config.require_type_hints else "",
            "--format",
            "json",
        ]
        cmd = [c for c in cmd if c]  # Remove empty strings
        with suppress(subprocess.TimeoutExpired):
            result = subprocess.run(  # noqa: S603
                cmd,
                check=False,
                capture_output=True,
                text=True,
                timeout=30,
            )
            if result.returncode == 0:
                with suppress(json.JSONDecodeError):
                    results["modernization"] = json.loads(result.stdout)

    return results


def analyze_code_quality(
    content: str,
    file_path: str,
    config: QualityConfig,
    *,
    enable_type_checks: bool = True,
) -> AnalysisResults:
    """Analyze code content using claude-quality toolkit."""
    suffix = Path(file_path).suffix or ".py"
    with NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as tmp:
        tmp.write(content)
        tmp_path = tmp.name

    try:
        return _run_quality_analyses(content, tmp_path, config, enable_type_checks)
    finally:
        Path(tmp_path).unlink(missing_ok=True)


def _check_internal_duplicates(results: AnalysisResults) -> list[str]:
    """Check for internal duplicate code within the same file."""
    issues: list[str] = []
    if "internal_duplicates" not in results:
        return issues

    duplicates: list[Duplicate] = results["internal_duplicates"].get("duplicates", [])
    for dup in duplicates[:3]:  # Show first 3
        locations = ", ".join(
            f"{loc['name']} ({loc['lines']})" for loc in dup.get("locations", [])
        )
        issues.append(
            f"Internal duplication ({dup.get('similarity', 0):.0%} similar): "
            f"{dup.get('description')} - {locations}",
        )
    return issues


def _check_complexity_issues(
    results: AnalysisResults,
    config: QualityConfig,
) -> list[str]:
    """Check for code complexity issues."""
    issues: list[str] = []
    if "complexity" not in results:
        return issues

    complexity_data = results["complexity"]
    summary = complexity_data.get("summary", {})
    avg_cc = summary.get("average_cyclomatic_complexity", 0.0)
    if avg_cc > config.complexity_threshold:
        issues.append(
            f"High average complexity: CC={avg_cc:.1f} "
            f"(threshold: {config.complexity_threshold})",
        )

    distribution = complexity_data.get("distribution", {})
    high_count: int = (
        distribution.get("High", 0)
        + distribution.get("Very High", 0)
        + distribution.get("Extreme", 0)
    )
    if high_count > 0:
        issues.append(f"Found {high_count} function(s) with high complexity")
    return issues


def _check_modernization_issues(
    results: AnalysisResults,
    config: QualityConfig,
) -> list[str]:
    """Check for code modernization issues."""
    issues: list[str] = []
    if "modernization" not in results:
        return issues

    try:
        modernization_data = results["modernization"]
        files = modernization_data.get("files", {})
        if not isinstance(files, dict):
            return issues
    except (AttributeError, TypeError):
        return issues

    total_issues: int = 0
    issue_types: set[str] = set()

    files_values = cast("dict[str, object]", files).values()
    for file_issues_raw in files_values:
        if not isinstance(file_issues_raw, list):
            continue

        # Cast to proper type after runtime check
        file_issues_list = cast("list[object]", file_issues_raw)
        total_issues += len(file_issues_list)

        for issue_raw in file_issues_list:
            if not isinstance(issue_raw, dict):
                continue

            # Cast to proper type after runtime check
            issue_dict = cast("dict[str, object]", issue_raw)
            issue_type_raw = issue_dict.get("issue_type", "unknown")
            if isinstance(issue_type_raw, str):
                issue_types.add(issue_type_raw)

    # Only flag if there are non-type-hint issues or many type hint issues
    non_type_issues: int = len(
        [t for t in issue_types if "type" not in t and "typing" not in t],
    )
    type_issues: int = total_issues - non_type_issues

    if non_type_issues > 0:
        non_type_list: list[str] = [
            t for t in issue_types if "type" not in t and "typing" not in t
        ]
        issues.append(
            f"Modernization needed: {non_type_issues} non-type issues "
            f"({', '.join(non_type_list)})",
        )
    elif config.require_type_hints and type_issues > 10:
        issues.append(
            f"Many missing type hints: {type_issues} functions/parameters "
            "lacking annotations",
        )
    return issues


def _check_type_checking_issues(results: AnalysisResults) -> list[str]:
    """Check for type checking issues from Sourcery, BasedPyright, and Pyrefly."""
    issues: list[str] = []
    if "type_checking" not in results:
        return issues

    with suppress(AttributeError, TypeError):
        type_checking_data = results["type_checking"]
        type_issues = type_checking_data.get("issues", [])
        issues.extend(str(issue_raw) for issue_raw in type_issues[:5])

    return issues


def check_code_issues(
    results: AnalysisResults,
    config: QualityConfig,
) -> tuple[bool, list[str]]:
    """Check analysis results for issues that should block the operation."""
    issues: list[str] = []

    issues.extend(_check_internal_duplicates(results))
    issues.extend(_check_complexity_issues(results, config))
    issues.extend(_check_modernization_issues(results, config))
    issues.extend(_check_type_checking_issues(results))

    return len(issues) > 0, issues


def store_pre_state(file_path: str, content: str) -> None:
    """Store file state before modification for later comparison."""
    import tempfile

    cache_dir = Path(tempfile.gettempdir()) / ".quality_state"
    cache_dir.mkdir(exist_ok=True, mode=0o700)

    state: dict[str, str | int] = {
        "file_path": file_path,
        "timestamp": datetime.now(UTC).isoformat(),
        "content_hash": hashlib.sha256(content.encode()).hexdigest(),
        "lines": len(content.split("\n")),
        "functions": content.count("def "),
        "classes": content.count("class "),
    }

    cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]
    cache_file: Path = cache_dir / f"{cache_key}_pre.json"
    cache_file.write_text(json.dumps(state, indent=2))
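# Illustrative cache entry written above, e.g. <tempdir>/.quality_state/1a2b3c4d_pre.json
# (file name and values are hypothetical; the fields match the state dict built here):
# {
#   "file_path": "src/example.py",
#   "timestamp": "2025-01-01T00:00:00+00:00",
#   "content_hash": "…",
#   "lines": 120,
#   "functions": 6,
#   "classes": 1
# }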


def check_state_changes(file_path: str) -> list[str]:
    """Check for quality changes between pre and post states."""
    import tempfile

    issues: list[str] = []
    cache_dir: Path = Path(tempfile.gettempdir()) / ".quality_state"
    cache_key: str = hashlib.sha256(file_path.encode()).hexdigest()[:8]

    pre_file: Path = cache_dir / f"{cache_key}_pre.json"
    if not pre_file.exists():
        return issues

    try:
        pre_state = json.loads(pre_file.read_text())
        try:
            current_content = Path(file_path).read_text()
        except OSError:
            return issues  # Can't compare if can't read file

        current_lines = len(current_content.split("\n"))
        current_functions = current_content.count("def ")

        # Check for significant changes
        if current_functions < pre_state.get("functions", 0):
            issues.append(
                f"⚠️ Reduced functions: {pre_state['functions']} → {current_functions}",
            )

        if current_lines > pre_state.get("lines", 0) * 1.5:
            issues.append(
                "⚠️ File size increased significantly: "
                f"{pre_state['lines']} → {current_lines} lines",
            )
    except Exception:  # noqa: BLE001
        logging.debug("Could not analyze state changes for %s", file_path)

    return issues


def check_cross_file_duplicates(file_path: str, config: QualityConfig) -> list[str]:
    """Check for duplicates across project files."""
    issues: list[str] = []

    # Get project root
    project_root: Path = Path(file_path).parent
    while (
        project_root.parent != project_root
        and not (project_root / ".git").exists()
        and not (project_root / "pyproject.toml").exists()
    ):
        project_root = project_root.parent

    try:
        claude_quality_cmd = get_claude_quality_command()
        result = subprocess.run(  # noqa: S603
            [
                *claude_quality_cmd,
                "duplicates",
                str(project_root),
                "--threshold",
                str(config.duplicate_threshold),
                "--format",
                "json",
            ],
            check=False,
            capture_output=True,
            text=True,
            timeout=60,
        )
        if result.returncode == 0:
            data = json.loads(result.stdout)
            duplicates = data.get("duplicates", [])
            if any(file_path in str(d) for d in duplicates):
                issues.append("⚠️ Cross-file duplication detected")
    except Exception:  # noqa: BLE001
        logging.debug("Could not check cross-file duplicates for %s", file_path)

    return issues


def verify_naming_conventions(file_path: str) -> list[str]:
    """Verify PEP8 naming conventions."""
    issues: list[str] = []
    try:
        content = Path(file_path).read_text()
    except OSError:
        return issues  # Can't check naming if can't read file

    # Check function names (should be snake_case)
    if bad_funcs := re.findall(
        r"def\s+([A-Z][a-zA-Z0-9_]*|[a-z]+[A-Z][a-zA-Z0-9_]*)\s*\(",
        content,
    ):
        issues.append(f"⚠️ Non-PEP8 function names: {', '.join(bad_funcs[:3])}")

    # Check class names (should be PascalCase)
    if bad_classes := re.findall(r"class\s+([a-z][a-z0-9_]*)\s*[\(:]", content):
        issues.append(f"⚠️ Non-PEP8 class names: {', '.join(bad_classes[:3])}")

    return issues


def _detect_any_usage(content: str) -> list[str]:
    """Detect forbidden typing.Any usage in proposed content."""

    class _AnyUsageVisitor(ast.NodeVisitor):
        """Collect line numbers where typing.Any is referenced."""

        def __init__(self) -> None:
            self.lines: set[int] = set()

        def visit_Name(self, node: ast.Name) -> None:
            if node.id == "Any":
                self.lines.add(node.lineno)
            self.generic_visit(node)

        def visit_Attribute(self, node: ast.Attribute) -> None:
            if node.attr == "Any":
                self.lines.add(node.lineno)
            self.generic_visit(node)

        def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
            for alias in node.names:
                if alias.name == "Any" or alias.asname == "Any":
                    self.lines.add(node.lineno)
            self.generic_visit(node)

        def visit_Import(self, node: ast.Import) -> None:
            for alias in node.names:
                if alias.name == "Any" or alias.asname == "Any":
                    self.lines.add(node.lineno)
            self.generic_visit(node)

    lines_with_any: set[int] = set()
    try:
        tree = ast.parse(content)
    except SyntaxError:
        for index, line in enumerate(content.splitlines(), start=1):
            code_portion = line.split("#", 1)[0]
            if re.search(r"\bAny\b", code_portion):
                lines_with_any.add(index)
    else:
        visitor = _AnyUsageVisitor()
        visitor.visit(tree)
        lines_with_any = visitor.lines

    if not lines_with_any:
        return []

    sorted_lines = sorted(lines_with_any)
    display_lines = ", ".join(str(num) for num in sorted_lines[:5])
    if len(sorted_lines) > 5:
        display_lines += ", …"

    return [
        "⚠️ Forbidden typing.Any usage at line(s) "
        f"{display_lines}; replace with specific types"
    ]


def _detect_type_ignore_usage(content: str) -> list[str]:
    """Detect forbidden # type: ignore usage in proposed content."""
    pattern = re.compile(r"#\s*type:\s*ignore(?:\b|\[)", re.IGNORECASE)
    lines_with_type_ignore: set[int] = set()

    try:
        for token_type, token_string, start, _, _ in tokenize.generate_tokens(
            StringIO(content).readline,
        ):
            if token_type == tokenize.COMMENT and pattern.search(token_string):
                lines_with_type_ignore.add(start[0])
    except tokenize.TokenError:
        for index, line in enumerate(content.splitlines(), start=1):
            if pattern.search(line):
                lines_with_type_ignore.add(index)

    if not lines_with_type_ignore:
        return []

    sorted_lines = sorted(lines_with_type_ignore)
    display_lines = ", ".join(str(num) for num in sorted_lines[:5])
    if len(sorted_lines) > 5:
        display_lines += ", …"

    return [
        "⚠️ Forbidden # type: ignore usage at line(s) "
        f"{display_lines}; remove the suppression and fix typing issues instead"
    ]


def _perform_quality_check(
    file_path: str,
    content: str,
    config: QualityConfig,
    enable_type_checks: bool = True,
) -> tuple[bool, list[str]]:
    """Perform quality analysis and return issues."""
    # Store state if tracking enabled
    if config.state_tracking_enabled:
        store_pre_state(file_path, content)

    # Run quality analysis
    results = analyze_code_quality(
        content,
        file_path,
        config,
        enable_type_checks=enable_type_checks,
    )
    return check_code_issues(results, config)


def _handle_quality_issues(
    file_path: str,
    issues: list[str],
    config: QualityConfig,
    *,
    forced_permission: str | None = None,
) -> JsonObject:
    """Handle quality issues based on enforcement mode."""
    # Prepare denial message
    message = (
        f"Code quality check failed for {Path(file_path).name}:\n"
        + "\n".join(f" • {issue}" for issue in issues)
        + "\n\nFix these issues before writing the code."
    )

    # Make decision based on enforcement mode
    if forced_permission:
        return _create_hook_response("PreToolUse", forced_permission, message)

    if config.enforcement_mode == "strict":
        return _create_hook_response("PreToolUse", "deny", message)
    if config.enforcement_mode == "warn":
        return _create_hook_response("PreToolUse", "ask", message)
    # permissive
    warning_message = f"⚠️ Quality Warning:\n{message}"
    return _create_hook_response(
        "PreToolUse",
        "allow",
        warning_message,
        warning_message,
    )


def _exit_with_reason(reason: str, exit_code: int = 2) -> None:
    """Write reason to stderr and exit with specified code."""
    sys.stderr.write(reason)
    sys.exit(exit_code)


def _create_hook_response(
    event_name: str,
    permission: str = "",
    reason: str = "",
    system_message: str = "",
    additional_context: str = "",
    *,
    decision: str | None = None,
) -> JsonObject:
    """Create standardized hook response."""
    hook_output: dict[str, object] = {
        "hookEventName": event_name,
    }

    if permission:
        hook_output["permissionDecision"] = permission
    if reason:
        hook_output["permissionDecisionReason"] = reason

    if additional_context:
        hook_output["additionalContext"] = additional_context

    response: JsonObject = {
        "hookSpecificOutput": hook_output,
    }

    if permission:
        response["permissionDecision"] = permission

    if decision:
        response["decision"] = decision

    if reason:
        response["reason"] = reason

    if system_message:
        response["systemMessage"] = system_message

    return response
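# Illustrative return value for a strict-mode denial (shape only, as assembled
# above; the reason text is hypothetical):
# {
#   "hookSpecificOutput": {
#     "hookEventName": "PreToolUse",
#     "permissionDecision": "deny",
#     "permissionDecisionReason": "Code quality check failed for example.py: ..."
#   },
#   "permissionDecision": "deny",
#   "reason": "Code quality check failed for example.py: ..."
# }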


def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject:
    """Handle PreToolUse hook - analyze content before write/edit."""
    tool_name = str(hook_data.get("tool_name", ""))
    tool_input_raw = hook_data.get("tool_input", {})
    if not isinstance(tool_input_raw, dict):
        return _create_hook_response("PreToolUse", "allow")
    tool_input = cast("dict[str, object]", tool_input_raw)

    # Only analyze for write/edit tools
    if tool_name not in ["Write", "Edit", "MultiEdit"]:
        return _create_hook_response("PreToolUse", "allow")

    # Extract content based on tool type
    file_path = str(tool_input.get("file_path", ""))
    content = ""

    if tool_name == "Write":
        raw_content = tool_input.get("content", "")
        content = "" if raw_content is None else str(raw_content)
    elif tool_name == "Edit":
        new_string = tool_input.get("new_string", "")
        content = "" if new_string is None else str(new_string)
    elif tool_name == "MultiEdit":
        edits = tool_input.get("edits", [])
        if isinstance(edits, list):
            edits_list = cast("list[object]", edits)
            parts: list[str] = []
            for edit in edits_list:
                if not isinstance(edit, dict):
                    continue
                edit_dict = cast("dict[str, object]", edit)
                new_str = edit_dict.get("new_string")
                parts.append("" if new_str is None else str(new_str))
            content = "\n".join(parts)

    # Only analyze Python files
    if not file_path or not file_path.endswith(".py") or not content:
        return _create_hook_response("PreToolUse", "allow")

    # Check if this is a test file and test quality checks are enabled
    is_test = is_test_file(file_path)
    run_test_checks = config.test_quality_enabled and is_test

    # Skip analysis for configured patterns, but not if it's a test file with test checks enabled
    if should_skip_file(file_path, config) and not run_test_checks:
        return _create_hook_response("PreToolUse", "allow")

    enable_type_checks = tool_name == "Write"
    any_usage_issues = _detect_any_usage(content)
    type_ignore_issues = _detect_type_ignore_usage(content)
    precheck_issues = any_usage_issues + type_ignore_issues

    # Run test quality checks if enabled and file is a test file
    if run_test_checks:
        test_quality_issues = run_test_quality_checks(content, file_path, config)
        precheck_issues.extend(test_quality_issues)

    try:
        _has_issues, issues = _perform_quality_check(
            file_path,
            content,
            config,
            enable_type_checks=enable_type_checks,
        )

        all_issues = precheck_issues + issues

        if not all_issues:
            return _create_hook_response("PreToolUse", "allow")

        if precheck_issues:
            return _handle_quality_issues(
                file_path,
                all_issues,
                config,
                forced_permission="deny",
            )

        return _handle_quality_issues(file_path, all_issues, config)
    except Exception as e:  # noqa: BLE001
        if precheck_issues:
            return _handle_quality_issues(
                file_path,
                precheck_issues,
                config,
                forced_permission="deny",
            )
        return _create_hook_response(
            "PreToolUse",
            "allow",
            f"Warning: Code quality check failed with error: {e}",
            f"Warning: Code quality check failed with error: {e}",
        )
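# Illustrative PreToolUse payload read from stdin by main() and handled above
# (key names match the lookups in this function; the values are hypothetical):
# {
#   "tool_name": "Write",
#   "tool_input": {
#     "file_path": "src/example.py",
#     "content": "def add(a: int, b: int) -> int:\n    return a + b\n"
#   }
# }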


def posttooluse_hook(
    hook_data: JsonObject,
    config: QualityConfig,
) -> JsonObject:
    """Handle PostToolUse hook - verify quality after write/edit."""
    tool_name: str = str(hook_data.get("tool_name", ""))
    tool_output = hook_data.get("tool_response", hook_data.get("tool_output", {}))

    # Only process write/edit tools
    if tool_name not in ["Write", "Edit", "MultiEdit"]:
        return _create_hook_response("PostToolUse")

    # Extract file path from output
    file_path: str = ""
    if isinstance(tool_output, dict):
        tool_output_dict = cast("dict[str, object]", tool_output)
        file_path = str(
            tool_output_dict.get("file_path", "") or tool_output_dict.get("path", ""),
        )
    elif isinstance(tool_output, str) and (
        match := re.search(r"([/\w\-_.]+\.py)", tool_output)
    ):
        file_path = match[1]

    if not file_path or not file_path.endswith(".py"):
        return _create_hook_response("PostToolUse")

    if not Path(file_path).exists():
        return _create_hook_response("PostToolUse")

    issues: list[str] = []

    # Check state changes if tracking enabled
    if config.state_tracking_enabled:
        delta_issues = check_state_changes(file_path)
        issues.extend(delta_issues)

    # Run cross-file duplicate detection if enabled
    if config.cross_file_check_enabled:
        cross_file_issues = check_cross_file_duplicates(file_path, config)
        issues.extend(cross_file_issues)

    # Verify naming conventions if enabled
    if config.verify_naming:
        naming_issues = verify_naming_conventions(file_path)
        issues.extend(naming_issues)

    # Format response
    if issues:
        message = (
            f"📝 Post-write quality notes for {Path(file_path).name}:\n"
            + "\n".join(issues)
        )
        return _create_hook_response(
            "PostToolUse",
            "",
            message,
            message,
            message,
            decision="block",
        )
    if config.show_success:
        message = f"✅ {Path(file_path).name} passed post-write verification"
        return _create_hook_response(
            "PostToolUse",
            "",
            "",
            message,
            "",
            decision="approve",
        )

    return _create_hook_response("PostToolUse")


def is_test_file(file_path: str) -> bool:
    """Check if file path is in a test directory."""
    path_parts = Path(file_path).parts
    return any(part in ("test", "tests", "testing") for part in path_parts)
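# For example, is_test_file("pkg/tests/test_api.py") is True (it has a "tests"
# path component), while is_test_file("pkg/api.py") is False. Note this checks
# directory components, not the file name itself.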


def run_test_quality_checks(content: str, file_path: str, config: QualityConfig) -> list[str]:
    """Run Sourcery with specific test-related rules and return issues with enhanced guidance."""
    issues: list[str] = []

    # Only run test quality checks for test files
    if not is_test_file(file_path):
        return issues

    suffix = Path(file_path).suffix or ".py"
    with NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as tmp:
        tmp.write(content)
        tmp_path = tmp.name

    try:
        # Run Sourcery with specific test-related rules
        venv_bin = Path(__file__).parent.parent / ".venv/bin"
        sourcery_path = venv_bin / "sourcery"

        if not sourcery_path.exists():
            # Try to find sourcery in PATH
            import shutil

            sourcery_path = shutil.which("sourcery") or str(venv_bin / "sourcery")

        if not sourcery_path or not Path(sourcery_path).exists():
            logging.debug("Sourcery not found at %s", sourcery_path)
            return issues

        # Specific rules for test quality - use correct Sourcery format
        test_rules = [
            "no-conditionals-in-tests",
            "no-loop-in-tests",
            "raise-specific-error",
            "dont-import-test-modules",
        ]

        cmd = [
            str(sourcery_path),
            "review",
            tmp_path,
            "--rules",
            ",".join(test_rules),
            "--format",
            "json",
        ]

        logging.debug("Running Sourcery command: %s", " ".join(cmd))
        result = subprocess.run(  # noqa: S603
            cmd,
            check=False,
            capture_output=True,
            text=True,
            timeout=30,
        )

        logging.debug("Sourcery exit code: %s", result.returncode)
        logging.debug("Sourcery stdout: %s", result.stdout)
        logging.debug("Sourcery stderr: %s", result.stderr)

        if result.returncode == 0:
            try:
                sourcery_output = json.loads(result.stdout)
                # Extract issues from Sourcery output - handle different JSON formats
                if "files" in sourcery_output:
                    for file_issues in sourcery_output["files"].values():
                        if isinstance(file_issues, list):
                            for issue in file_issues:
                                if isinstance(issue, dict):
                                    rule_id = issue.get("rule", "unknown")
                                    # Generate enhanced guidance for each violation
                                    base_guidance = generate_test_quality_guidance(rule_id, content, file_path)

                                    # Add external context if available
                                    external_context = get_external_context(rule_id, content, file_path, config)
                                    if external_context:
                                        base_guidance += f"\n\n{external_context}"

                                    issues.append(base_guidance)
                elif "violations" in sourcery_output:
                    # Alternative format
                    for violation in sourcery_output["violations"]:
                        if isinstance(violation, dict):
                            rule_id = violation.get("rule", "unknown")
                            base_guidance = generate_test_quality_guidance(rule_id, content, file_path)

                            # Add external context if available
                            external_context = get_external_context(rule_id, content, file_path, config)
                            if external_context:
                                base_guidance += f"\n\n{external_context}"

                            issues.append(base_guidance)
                elif isinstance(sourcery_output, list):
                    # Direct list of issues
                    for issue in sourcery_output:
                        if isinstance(issue, dict):
                            rule_id = issue.get("rule", "unknown")
                            base_guidance = generate_test_quality_guidance(rule_id, content, file_path)

                            # Add external context if available
                            external_context = get_external_context(rule_id, content, file_path, config)
                            if external_context:
                                base_guidance += f"\n\n{external_context}"

                            issues.append(base_guidance)
            except json.JSONDecodeError as e:
                logging.debug("Failed to parse Sourcery JSON output: %s", e)
                # If JSON parsing fails, provide general guidance with external context
                base_guidance = generate_test_quality_guidance("unknown", content, file_path)
                external_context = get_external_context("unknown", content, file_path, config)
                if external_context:
                    base_guidance += f"\n\n{external_context}"
                issues.append(base_guidance)
        elif result.stdout.strip() or result.stderr.strip():
            # Sourcery found issues or errors - provide general guidance
            base_guidance = generate_test_quality_guidance("sourcery-error", content, file_path)
            external_context = get_external_context("sourcery-error", content, file_path, config)
            if external_context:
                base_guidance += f"\n\n{external_context}"
            issues.append(base_guidance)

    except (subprocess.TimeoutExpired, OSError, json.JSONDecodeError) as e:
        # If Sourcery fails, don't block the operation
        logging.debug("Test quality check failed for %s: %s", file_path, e)
    finally:
        Path(tmp_path).unlink(missing_ok=True)

    return issues


def main() -> None:
    """Main hook entry point."""
    try:
        # Load configuration
        config = QualityConfig.from_env()

        # Read hook input from stdin
        try:
            hook_data: JsonObject = json.load(sys.stdin)
        except json.JSONDecodeError:
            fallback_response: JsonObject = {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "allow",
                },
            }
            sys.stdout.write(json.dumps(fallback_response))
            return

        # Detect hook type: tool_response=PostToolUse, tool_input=PreToolUse
        response: JsonObject
        if "tool_response" in hook_data or "tool_output" in hook_data:
            # PostToolUse hook
            response = posttooluse_hook(hook_data, config)
        else:
            # PreToolUse hook
            response = pretooluse_hook(hook_data, config)

        print(json.dumps(response))  # noqa: T201

        # Handle exit codes based on hook output
        hook_output_raw = response.get("hookSpecificOutput", {})
        if not isinstance(hook_output_raw, dict):
            return

        hook_output = cast("dict[str, object]", hook_output_raw)
        permission_decision = hook_output.get("permissionDecision")

        if permission_decision == "deny":
            # Exit code 2: Blocking error - stderr fed back to Claude
            reason = str(
                hook_output.get("permissionDecisionReason", "Permission denied"),
            )
            _exit_with_reason(reason)
        elif permission_decision == "ask":
            # Also use exit code 2 for ask decisions to ensure Claude sees the message
            reason = str(
                hook_output.get("permissionDecisionReason", "Permission request"),
            )
            _exit_with_reason(reason)
        # Exit code 0: Success (default)

    except Exception as e:  # noqa: BLE001
        # Unexpected error - use exit code 1 (non-blocking error)
        sys.stderr.write(f"Hook error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
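# Registration sketch (an assumption - the authoritative settings schema lives in
# README_HOOKS.md / the Claude Code hooks documentation, and the script path below
# is hypothetical). The same script serves both events because main() inspects the
# payload to choose between PreToolUse and PostToolUse handling:
# {
#   "hooks": {
#     "PreToolUse": [
#       {"matcher": "Write|Edit|MultiEdit",
#        "hooks": [{"type": "command", "command": "python .claude/hooks/quality_hook.py"}]}
#     ],
#     "PostToolUse": [
#       {"matcher": "Write|Edit|MultiEdit",
#        "hooks": [{"type": "command", "command": "python .claude/hooks/quality_hook.py"}]}
#     ]
#   }
# }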
|