feat: enhance test coverage and improve code quality checks

- Updated test files to improve coverage of Any usage, type: ignore comments, and legacy typing patterns, ensuring these patterns are reliably blocked.
- Refactored the test structure for clarity and maintainability, introducing fixtures and tightening assertions.
- Enhanced error handling and messaging in the hook system to provide clearer feedback on violations.
- Improved integration tests to validate hook behavior across a range of scenarios (a sketch of the hook I/O contract follows below).
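A minimal sketch of the PreToolUse exchange the new bash guard implements, assuming the hooks directory is importable as the hooks package (as the tests in this commit assume); the sample sed command is illustrative only:

import json

from hooks.bash_command_guard import pretooluse_bash_hook

# A shell command that tries to rewrite an annotation to typing.Any in a .py file.
hook_data = {
    "tool_name": "Bash",
    "tool_input": {"command": "sed -i 's/x: int/x: Any/' module.py"},
}
response = pretooluse_bash_hook(hook_data)
print(json.dumps(response, indent=2))
# Expected: permissionDecision == "deny" with the attempted injection listed in the reason.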
2025-10-08 09:10:32 +00:00
parent 3e2e2dfbc1
commit f3832bdf3d
18 changed files with 2429 additions and 15669 deletions

hooks/bash_command_guard.py

@@ -0,0 +1,450 @@
"""Shell command guard for Claude Code PreToolUse/PostToolUse hooks.
Prevents circumvention of type safety rules via shell commands that could inject
'Any' types or type ignore comments into Python files.
"""
import json
import re
import subprocess
import sys
from pathlib import Path
from shutil import which
from typing import TypedDict
from .bash_guard_constants import (
DANGEROUS_SHELL_PATTERNS,
FORBIDDEN_PATTERNS,
PYTHON_FILE_PATTERNS,
)
class JsonObject(TypedDict, total=False):
"""Type for JSON-like objects."""
hookEventName: str
permissionDecision: str
permissionDecisionReason: str
decision: str
reason: str
systemMessage: str
hookSpecificOutput: dict[str, object]
def _contains_forbidden_pattern(text: str) -> tuple[bool, str | None]:
"""Check if text contains any forbidden patterns.
Args:
text: The text to check for forbidden patterns.
Returns:
Tuple of (has_violation, matched_pattern_description)
"""
    for pattern in FORBIDDEN_PATTERNS:
        if re.search(pattern, text, re.IGNORECASE):
            if "Any" in pattern:
                return True, "typing.Any usage"
            if "ignore" in pattern:
                return True, "type suppression comment"
            # Matched a forbidden pattern that has no specific description.
            return True, "forbidden pattern"
    return False, None
def _is_dangerous_shell_command(command: str) -> tuple[bool, str | None]:
"""Check if shell command uses dangerous patterns.
Args:
command: The shell command to analyze.
Returns:
Tuple of (is_dangerous, reason)
"""
# Check if command targets Python files
targets_python = any(
re.search(pattern, command) for pattern in PYTHON_FILE_PATTERNS
)
if not targets_python:
return False, None
# Check for dangerous shell patterns
for pattern in DANGEROUS_SHELL_PATTERNS:
if re.search(pattern, command):
tool_match = re.search(
r"\b(sed|awk|perl|ed|echo|printf|cat|tee|find|xargs|python|vim|nano|emacs)\b",
pattern,
)
tool_name = tool_match.group(1) if tool_match else "shell utility"
return True, f"Use of {tool_name} to modify Python files"
return False, None
def _command_contains_forbidden_injection(command: str) -> tuple[bool, str | None]:
"""Check if command attempts to inject forbidden patterns.
Args:
command: The shell command to analyze.
Returns:
Tuple of (has_injection, violation_description)
"""
# Check if the command itself contains forbidden patterns
has_violation, violation_type = _contains_forbidden_pattern(command)
if has_violation:
return True, violation_type
# Check for encoded or escaped patterns
# Handle common escape sequences
decoded_cmd = command.replace("\\n", "\n").replace("\\t", "\t")
decoded_cmd = re.sub(r"\\\s", " ", decoded_cmd)
has_violation, violation_type = _contains_forbidden_pattern(decoded_cmd)
if has_violation:
return True, f"{violation_type} (escaped)"
return False, None
def _analyze_bash_command(command: str) -> tuple[bool, list[str]]:
"""Analyze bash command for safety violations.
Args:
command: The bash command to analyze.
Returns:
Tuple of (should_block, list_of_violations)
"""
violations: list[str] = []
# Check for forbidden pattern injection
has_injection, injection_type = _command_contains_forbidden_injection(command)
if has_injection:
violations.append(f"⛔ Shell command attempts to inject {injection_type}")
# Check for dangerous shell patterns on Python files
is_dangerous, danger_reason = _is_dangerous_shell_command(command)
if is_dangerous:
violations.append(
f"⛔ {danger_reason} is forbidden - use Edit/Write tools instead",
)
return len(violations) > 0, violations
def _create_hook_response(
event_name: str,
permission: str = "",
reason: str = "",
system_message: str = "",
*,
decision: str | None = None,
) -> JsonObject:
"""Create standardized hook response.
Args:
event_name: Name of the hook event (PreToolUse, PostToolUse, Stop).
permission: Permission decision (allow, deny, ask).
reason: Reason for the decision.
system_message: System message to display.
decision: Decision for PostToolUse/Stop hooks (approve, block).
Returns:
JSON response object for the hook.
"""
hook_output: dict[str, object] = {
"hookEventName": event_name,
}
if permission:
hook_output["permissionDecision"] = permission
if reason:
hook_output["permissionDecisionReason"] = reason
response: JsonObject = {
"hookSpecificOutput": hook_output,
}
if permission:
response["permissionDecision"] = permission
if decision:
response["decision"] = decision
if reason:
response["reason"] = reason
if system_message:
response["systemMessage"] = system_message
return response
def pretooluse_bash_hook(hook_data: dict[str, object]) -> JsonObject:
"""Handle PreToolUse hook for Bash commands.
Args:
hook_data: Hook input data containing tool_name and tool_input.
Returns:
Hook response with permission decision.
"""
tool_name = str(hook_data.get("tool_name", ""))
# Only analyze Bash commands
if tool_name != "Bash":
return _create_hook_response("PreToolUse", "allow")
tool_input_raw = hook_data.get("tool_input", {})
if not isinstance(tool_input_raw, dict):
return _create_hook_response("PreToolUse", "allow")
tool_input: dict[str, object] = dict(tool_input_raw)
command = str(tool_input.get("command", ""))
if not command:
return _create_hook_response("PreToolUse", "allow")
# Analyze command for violations
should_block, violations = _analyze_bash_command(command)
if not should_block:
return _create_hook_response("PreToolUse", "allow")
# Build denial message
violation_text = "\n".join(f" {v}" for v in violations)
message = (
f"🚫 Shell Command Blocked\n\n"
f"Violations:\n{violation_text}\n\n"
f"Command: {command[:200]}{'...' if len(command) > 200 else ''}\n\n"
f"Use Edit/Write tools to modify Python files with proper type safety."
)
return _create_hook_response(
"PreToolUse",
"deny",
message,
message,
)
def posttooluse_bash_hook(hook_data: dict[str, object]) -> JsonObject:
"""Handle PostToolUse hook for Bash commands.
Args:
hook_data: Hook output data containing tool_response.
Returns:
Hook response with decision.
"""
tool_name = str(hook_data.get("tool_name", ""))
# Only analyze Bash commands
if tool_name != "Bash":
return _create_hook_response("PostToolUse")
# Extract command from hook data
tool_input_raw = hook_data.get("tool_input", {})
if not isinstance(tool_input_raw, dict):
return _create_hook_response("PostToolUse")
tool_input: dict[str, object] = dict(tool_input_raw)
command = str(tool_input.get("command", ""))
# Check if command modified any Python files
# Look for file paths in the command
python_files: list[str] = []
for match in re.finditer(r"([^\s]+\.pyi?)\b", command):
file_path = match.group(1)
if Path(file_path).exists():
python_files.append(file_path)
if not python_files:
return _create_hook_response("PostToolUse")
# Scan modified files for violations
violations: list[str] = []
for file_path in python_files:
try:
with open(file_path, encoding="utf-8") as file_handle:
content = file_handle.read()
has_violation, violation_type = _contains_forbidden_pattern(content)
if has_violation:
violations.append(
f"⛔ File '{Path(file_path).name}' contains {violation_type}",
)
except (OSError, UnicodeDecodeError):
# If we can't read the file, skip it
continue
if violations:
violation_text = "\n".join(f" {v}" for v in violations)
message = (
f"🚫 Post-Execution Violation Detected\n\n"
f"Violations:\n{violation_text}\n\n"
f"Shell command introduced forbidden patterns. "
f"Please revert changes and use proper typing."
)
return _create_hook_response(
"PostToolUse",
"",
message,
message,
decision="block",
)
return _create_hook_response("PostToolUse")
def stop_hook(_hook_data: dict[str, object]) -> JsonObject:
"""Handle Stop hook - final validation before completion.
Args:
_hook_data: Stop hook data (unused).
Returns:
Hook response with decision.
"""
# Get list of changed files from git
try:
git_path = which("git")
if git_path is None:
return _create_hook_response("Stop", decision="approve")
# Safe: invokes git with fixed arguments, no user input interpolation.
result = subprocess.run( # noqa: S603
[git_path, "diff", "--name-only", "--cached"],
capture_output=True,
text=True,
check=False,
timeout=10,
)
if result.returncode != 0:
# git failed (e.g. not a repository); skip final validation
return _create_hook_response("Stop", decision="approve")
changed_files = [
file_name.strip()
for file_name in result.stdout.split("\n")
if file_name.strip() and file_name.strip().endswith((".py", ".pyi"))
]
if not changed_files:
return _create_hook_response("Stop", decision="approve")
# Scan all changed Python files
violations: list[str] = []
for file_path in changed_files:
if not Path(file_path).exists():
continue
try:
with open(file_path, encoding="utf-8") as file_handle:
content = file_handle.read()
has_violation, violation_type = _contains_forbidden_pattern(content)
if has_violation:
violations.append(f"⛔ {file_path}: {violation_type}")
except (OSError, UnicodeDecodeError):
continue
if violations:
violation_text = "\n".join(f" {v}" for v in violations)
message = (
f"🚫 Final Validation Failed\n\n"
f"Violations:\n{violation_text}\n\n"
f"Please remove forbidden patterns before completing."
)
return _create_hook_response(
"Stop",
"",
message,
message,
decision="block",
)
return _create_hook_response("Stop", decision="approve")
except (OSError, subprocess.SubprocessError, TimeoutError) as exc:
# If validation fails, allow but warn
return _create_hook_response(
"Stop",
"",
f"Warning: Final validation error: {exc}",
f"Warning: Final validation error: {exc}",
decision="approve",
)
def main() -> None:
"""Main hook entry point."""
try:
# Read hook input from stdin
try:
hook_data: dict[str, object] = json.load(sys.stdin)
except json.JSONDecodeError:
fallback_response: JsonObject = {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "allow",
},
}
sys.stdout.write(json.dumps(fallback_response))
return
# Detect hook type
response: JsonObject
if "tool_response" in hook_data or "tool_output" in hook_data:
# PostToolUse hook
response = posttooluse_bash_hook(hook_data)
elif hook_data.get("hookEventName") == "Stop":
# Stop hook
response = stop_hook(hook_data)
else:
# PreToolUse hook
response = pretooluse_bash_hook(hook_data)
sys.stdout.write(json.dumps(response))
# Handle exit codes based on hook output
hook_output_raw = response.get("hookSpecificOutput", {})
if not hook_output_raw or not isinstance(hook_output_raw, dict):
return
hook_output: dict[str, object] = hook_output_raw
permission_decision = hook_output.get("permissionDecision")
if permission_decision == "deny":
# Exit code 2: Blocking error
reason = str(
hook_output.get("permissionDecisionReason", "Permission denied"),
)
sys.stderr.write(reason)
sys.exit(2)
elif permission_decision == "ask":
# Exit code 2 for ask decisions
reason = str(
hook_output.get("permissionDecisionReason", "Permission request"),
)
sys.stderr.write(reason)
sys.exit(2)
# Check for Stop hook block decision
if response.get("decision") == "block":
reason = str(response.get("reason", "Validation failed"))
sys.stderr.write(reason)
sys.exit(2)
except (OSError, ValueError, subprocess.SubprocessError, TimeoutError) as exc:
# Unexpected error - use exit code 1 (non-blocking error)
sys.stderr.write(f"Hook error: {exc}")
sys.exit(1)
if __name__ == "__main__":
main()
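For reference, a sketch of driving the guard end to end over stdin/stdout, mirroring the exit-code contract in main() (exit 2 signals a blocking deny). It assumes the repository root is on the import path so hooks.bash_command_guard resolves as a module; the echoed command is illustrative:

import json
import subprocess
import sys

payload = json.dumps({
    "tool_name": "Bash",
    "tool_input": {"command": "echo 'x: Any' >> module.py"},
})
proc = subprocess.run(
    [sys.executable, "-m", "hooks.bash_command_guard"],
    input=payload,
    capture_output=True,
    text=True,
    check=False,
)
print(proc.returncode)  # expected 2: blocking deny
print(proc.stdout)      # JSON response with hookSpecificOutput.permissionDecision == "deny"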


@@ -0,0 +1,72 @@
"""Shared constants for bash command guard functionality.
This module contains patterns and constants used to detect forbidden code patterns
and dangerous shell commands that could compromise type safety.
"""
# Forbidden patterns that should never appear in Python code
FORBIDDEN_PATTERNS = [
r"\bfrom\s+typing\s+import\s+.*\bAny\b", # from typing import Any
r"\bimport\s+typing\.Any\b", # import typing.Any
r"\btyping\.Any\b", # typing.Any reference
r"\b:\s*Any\b", # Type annotation with Any
r"->\s*Any\b", # Return type Any
r"#\s*type:\s*ignore", # type suppression comment
]
# Shell command patterns that can modify files
DANGEROUS_SHELL_PATTERNS = [
# Direct file writes
r"\becho\s+.*>",
r"\bprintf\s+.*>",
r"\bcat\s+>",
r"\btee\s+",
# Stream editors and text processors
r"\bsed\s+",
r"\bawk\s+",
r"\bperl\s+",
r"\bed\s+",
# Mass operations
r"\bfind\s+.*-exec",
r"\bxargs\s+",
r"\bgrep\s+.*\|\s*xargs",
# Python execution with file operations
r"\bpython\s+-c\s+.*open\(",
r"\bpython\s+-c\s+.*write\(",
r"\bpython3\s+-c\s+.*open\(",
r"\bpython3\s+-c\s+.*write\(",
# Editor batch modes
r"\bvim\s+-c",
r"\bnano\s+--tempfile",
r"\bemacs\s+--batch",
]
# Python file patterns to protect
PYTHON_FILE_PATTERNS = [
r"\.py\b",
r"\.pyi\b",
]
# Pattern descriptions for error messages
FORBIDDEN_PATTERN_DESCRIPTIONS = {
"Any": "typing.Any usage",
"type.*ignore": "type suppression comment",
}
# Tool names extracted from dangerous patterns
DANGEROUS_TOOLS = [
"sed",
"awk",
"perl",
"ed",
"echo",
"printf",
"cat",
"tee",
"find",
"xargs",
"python",
"vim",
"nano",
"emacs",
]
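A quick illustrative check, not part of the shipped module, of how FORBIDDEN_PATTERNS classifies sample lines using the same case-insensitive matching as the guard; the hooks.bash_guard_constants import path is assumed from the package layout:

import re

from hooks.bash_guard_constants import FORBIDDEN_PATTERNS

samples = [
    "from typing import Any",            # expected: blocked (Any import)
    "value = call()  # type: ignore",    # expected: blocked (suppression comment)
    "def ok(value: str | int) -> str:",  # expected: allowed (modern syntax)
]
for line in samples:
    flagged = any(re.search(pattern, line, re.IGNORECASE) for pattern in FORBIDDEN_PATTERNS)
    print(f"{'BLOCK' if flagged else 'ALLOW'}: {line}")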


@@ -13,6 +13,8 @@ import re
import shutil
import subprocess
import sys
from importlib import import_module
import textwrap
import tokenize
from collections.abc import Callable
from contextlib import suppress
@@ -20,237 +22,261 @@ from dataclasses import dataclass
from datetime import UTC, datetime
from io import StringIO
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import TypedDict, cast
from tempfile import NamedTemporaryFile, gettempdir
from typing import TYPE_CHECKING, TypedDict, cast
# Import internal duplicate detector
sys.path.insert(0, str(Path(__file__).parent))
from internal_duplicate_detector import detect_internal_duplicates
# Import internal duplicate detector; fall back to local path when executed directly
if TYPE_CHECKING:
from .internal_duplicate_detector import (
Duplicate,
DuplicateResults,
detect_internal_duplicates,
)
else:
try:
from .internal_duplicate_detector import (
Duplicate,
DuplicateResults,
detect_internal_duplicates,
)
except ImportError:
sys.path.insert(0, str(Path(__file__).parent))
module = import_module("internal_duplicate_detector")
Duplicate = module.Duplicate
DuplicateResults = module.DuplicateResults
detect_internal_duplicates = module.detect_internal_duplicates
SUSPICIOUS_SUFFIXES: tuple[str, ...] = (
"enhanced",
"improved",
"better",
"new",
"updated",
"modified",
"refactored",
"optimized",
"fixed",
"clean",
"simple",
"advanced",
"basic",
"complete",
"final",
"latest",
"current",
"temp",
"temporary",
"backup",
"old",
"legacy",
"unified",
"merged",
"combined",
"integrated",
"consolidated",
"extended",
"enriched",
"augmented",
"upgraded",
"revised",
"polished",
"streamlined",
"simplified",
"modernized",
"normalized",
"sanitized",
"validated",
"verified",
"corrected",
"patched",
"stable",
"experimental",
"alpha",
"beta",
"draft",
"preliminary",
"prototype",
"working",
"test",
"debug",
"custom",
"special",
"generic",
"specific",
"general",
"detailed",
"minimal",
"full",
"partial",
"quick",
"fast",
"slow",
"smart",
"intelligent",
"auto",
"manual",
"secure",
"safe",
"robust",
"flexible",
"dynamic",
"static",
"reactive",
"async",
"sync",
"parallel",
"serial",
"distributed",
"centralized",
"decentralized",
)
FILE_SUFFIX_DUPLICATE_MSG = (
"⚠️ File '{current}' appears to be a suffixed duplicate of '{original}'. "
"Consider refactoring instead of creating variations with adjective "
"suffixes."
)
EXISTING_FILE_DUPLICATE_MSG = (
"⚠️ Creating '{current}' when '{existing}' already exists suggests "
"duplication. Consider consolidating or using a more descriptive name."
)
NAME_SUFFIX_DUPLICATE_MSG = (
"⚠️ {kind} '{name}' appears to be a suffixed duplicate of '{base}'. "
"Consider refactoring instead of creating variations."
)
class QualityConfig:
"""Configuration for quality checks."""
def get_external_context(
rule_id: str,
content: str,
_file_path: str,
config: "QualityConfig",
) -> str:
"""Fetch additional guidance from optional integrations."""
context_parts: list[str] = []
def __init__(self):
self._config = config = QualityConfig.from_env()
config.skip_patterns = ["test_", "_test.py", "/tests/", "/fixtures/"]
def __getattr__(self, name):
return getattr(self._config, name, self._config[name])
@classmethod
def from_env(cls) -> "QualityConfig":
"""Load config from environment variables."""
return cls()
def get_external_context(rule_id: str, content: str, file_path: str, config: QualityConfig) -> str:
"""Get additional context from external APIs for enhanced guidance."""
context_parts = []
# Context7 integration for additional context analysis
if config.context7_enabled and config.context7_api_key:
try:
# Note: This would integrate with actual Context7 API
# For now, providing placeholder for the integration
context7_context = _get_context7_analysis(rule_id, content, config.context7_api_key)
context7_context = _get_context7_analysis(
rule_id,
content,
config.context7_api_key,
)
except (OSError, RuntimeError, ValueError, TimeoutError) as exc:
logging.debug("Context7 API call failed: %s", exc)
else:
if context7_context:
context_parts.append(f"📊 Context7 Analysis: {context7_context}")
except Exception as e:
logging.debug("Context7 API call failed: %s", e)
# Firecrawl integration for web scraping additional examples
if config.firecrawl_enabled and config.firecrawl_api_key:
try:
# Note: This would integrate with actual Firecrawl API
# For now, providing placeholder for the integration
firecrawl_examples = _get_firecrawl_examples(rule_id, config.firecrawl_api_key)
firecrawl_examples = _get_firecrawl_examples(
rule_id,
config.firecrawl_api_key,
)
except (OSError, RuntimeError, ValueError, TimeoutError) as exc:
logging.debug("Firecrawl API call failed: %s", exc)
else:
if firecrawl_examples:
context_parts.append(f"🔗 Additional Examples: {firecrawl_examples}")
except Exception as e:
logging.debug("Firecrawl API call failed: %s", e)
return "\n\n".join(context_parts) if context_parts else ""
return "\n\n".join(context_parts)
def _get_context7_analysis(rule_id: str, content: str, api_key: str) -> str:
def _get_context7_analysis(
rule_id: str,
_content: str,
_api_key: str,
) -> str:
"""Placeholder for Context7 API integration."""
# This would make actual API calls to Context7
# For demonstration, returning contextual analysis based on rule type
context_map = {
"no-conditionals-in-tests": "Consider using data-driven tests with pytest.mark.parametrize for better test isolation.",
"no-loop-in-tests": "Each iteration should be a separate test case for better failure isolation and debugging.",
"raise-specific-error": "Specific exceptions improve error handling and make tests more maintainable.",
"dont-import-test-modules": "Keep test utilities separate from production code to maintain clean architecture."
"no-conditionals-in-tests": (
"Use pytest.mark.parametrize instead of branching inside tests."
),
"no-loop-in-tests": (
"Split loops into individual tests or parameterize the data for clarity."
),
"raise-specific-error": (
"Raise precise exception types (ValueError, custom errors, etc.) "
"so failures highlight the right behaviour."
),
"dont-import-test-modules": (
"Production code should not depend on test helpers; "
"extract shared logic into utility modules."
),
}
return context_map.get(rule_id, "Additional context analysis would be provided here.")
return context_map.get(
rule_id,
"Context7 guidance will appear once the integration is available.",
)
def _get_firecrawl_examples(rule_id: str, api_key: str) -> str:
def _get_firecrawl_examples(rule_id: str, _api_key: str) -> str:
"""Placeholder for Firecrawl API integration."""
# This would scrape web for additional examples and best practices
# For demonstration, returning relevant examples based on rule type
examples_map = {
"no-conditionals-in-tests": "See pytest documentation on parameterized tests for multiple scenarios.",
"no-loop-in-tests": "Check testing best practices guides for data-driven testing approaches.",
"raise-specific-error": "Review Python exception hierarchy and custom exception patterns.",
"dont-import-test-modules": "Explore clean architecture patterns for test organization."
"no-conditionals-in-tests": (
"Pytest parameterization tutorial: "
"docs.pytest.org/en/latest/how-to/parametrize.html"
),
"no-loop-in-tests": (
"Testing best practices: keep scenarios in separate "
"tests for precise failures."
),
"raise-specific-error": (
"Python docs on exceptions: docs.python.org/3/tutorial/errors.html"
),
"dont-import-test-modules": (
"Clean architecture tip: production modules should not import from tests."
),
}
return examples_map.get(rule_id, "Additional examples and patterns would be sourced from web resources.")
def generate_test_quality_guidance(rule_id: str, content: str, file_path: str, config: QualityConfig) -> str:
"""Generate specific, actionable guidance for test quality rule violations."""
return examples_map.get(
rule_id,
"Firecrawl examples will appear once the integration is available.",
)
# Extract function name and context for better guidance
def generate_test_quality_guidance(
rule_id: str,
content: str,
file_path: str,
_config: "QualityConfig",
) -> str:
"""Return concise guidance for test quality rule violations."""
function_name = "test_function"
if "def " in content:
# Try to extract the test function name
import re
match = re.search(r'def\s+(\w+)\s*\(', content)
if match:
function_name = match.group(1)
match = re.search(r"def\s+(\w+)\s*\(", content)
if match:
function_name = match.group(1)
file_name = Path(file_path).name
guidance_map = {
"no-conditionals-in-tests": {
"title": "Conditional Logic in Test Function",
"problem": f"Test function '{function_name}' contains conditional statements (if/elif/else).",
"why": "Tests should be simple assertions that verify specific behavior. Conditionals make tests harder to understand and maintain.",
"solutions": [
"✅ Replace conditionals with parameterized test cases",
"✅ Use pytest.mark.parametrize for multiple scenarios",
"✅ Extract conditional logic into helper functions",
"✅ Use assertion libraries like assertpy for complex conditions"
],
"examples": [
"# ❌ Instead of this:",
"def test_user_access():",
" user = create_user()",
" if user.is_admin:",
" assert user.can_access_admin()",
" else:",
" assert not user.can_access_admin()",
"",
"# ✅ Do this:",
"@pytest.mark.parametrize('is_admin,can_access', [",
" (True, True),",
" (False, False)",
"])",
"def test_user_access(is_admin, can_access):",
" user = create_user(admin=is_admin)",
" assert user.can_access_admin() == can_access"
]
},
"no-loop-in-tests": {
"title": "Loop Found in Test Function",
"problem": f"Test function '{function_name}' contains loops (for/while).",
"why": "Loops in tests often indicate testing multiple scenarios that should be separate test cases.",
"solutions": [
"✅ Break loop into individual test cases",
"✅ Use parameterized tests for multiple data scenarios",
"✅ Extract loop logic into data providers or fixtures",
"✅ Use pytest fixtures for setup that requires iteration"
],
"examples": [
"# ❌ Instead of this:",
"def test_process_items():",
" for item in test_items:",
" result = process_item(item)",
" assert result.success",
"",
"# ✅ Do this:",
"@pytest.mark.parametrize('item', test_items)",
"def test_process_item(item):",
" result = process_item(item)",
" assert result.success"
]
},
"raise-specific-error": {
"title": "Generic Exception in Test",
"problem": f"Test function '{function_name}' raises generic Exception or BaseException.",
"why": "Generic exceptions make it harder to identify specific test failures and handle different error conditions appropriately.",
"solutions": [
"✅ Use specific exception types (ValueError, TypeError, etc.)",
"✅ Create custom exception classes for domain-specific errors",
"✅ Use pytest.raises() with specific exception types",
"✅ Test for expected exceptions, not generic ones"
],
"examples": [
"# ❌ Instead of this:",
"def test_invalid_input():",
" with pytest.raises(Exception):",
" process_invalid_data()",
"",
"# ✅ Do this:",
"def test_invalid_input():",
" with pytest.raises(ValueError, match='Invalid input'):",
" process_invalid_data()",
"",
"# Or create custom exceptions:",
"def test_business_logic():",
" with pytest.raises(InsufficientFundsError):",
" account.withdraw(1000)"
]
},
"dont-import-test-modules": {
"title": "Test Module Import in Non-Test Code",
"problem": f"File '{Path(file_path).name}' imports test modules outside of test directories.",
"why": "Test modules should only be imported by other test files. Production code should not depend on test utilities.",
"solutions": [
"✅ Move shared test utilities to a separate utils module",
"✅ Create production versions of test helper functions",
"✅ Use dependency injection instead of direct test imports",
"✅ Extract common logic into a shared production module"
],
"examples": [
"# ❌ Instead of this (in production code):",
"from tests.test_helpers import create_test_data",
"",
"# ✅ Do this:",
"# Create src/utils/test_data_factory.py",
"from src.utils.test_data_factory import create_test_data",
"",
"# Or use fixtures in tests:",
"@pytest.fixture",
"def test_data():",
" return create_production_data()"
]
}
"no-conditionals-in-tests": (
f"Test {function_name} contains conditional logic. "
"Parameterize or split tests so each scenario is explicit."
),
"no-loop-in-tests": (
f"Test {function_name} iterates over data. Break the loop into "
"separate tests or use pytest parameterization."
),
"raise-specific-error": (
f"Test {function_name} asserts generic exceptions. "
"Assert specific types to document the expected behaviour."
),
"dont-import-test-modules": (
f"File {file_name} imports from tests. Move shared helpers into a "
"production module or provide them via fixtures."
),
}
guidance = guidance_map.get(rule_id, {
"title": f"Test Quality Issue: {rule_id}",
"problem": f"Rule '{rule_id}' violation detected in test code.",
"why": "This rule helps maintain test quality and best practices.",
"solutions": [
"✅ Review the specific rule requirements",
"✅ Refactor code to follow test best practices",
"✅ Consult testing framework documentation"
],
"examples": [
"# Review the rule documentation for specific guidance",
"# Consider the test's purpose and refactor accordingly"
]
})
# Format the guidance message
message = f"""
🚫 {guidance['title']}
📋 Problem: {guidance['problem']}
❓ Why this matters: {guidance['why']}
🛠️ How to fix it:
{chr(10).join(f"{solution}" for solution in guidance['solutions'])}
💡 Example:
{chr(10).join(f" {line}" for line in guidance['examples'])}
🔍 File: {Path(file_path).name}
📍 Function: {function_name}
"""
return message.strip()
return guidance_map.get(
rule_id,
"Keep tests behaviour-focused: avoid conditionals, loops, generic exceptions, "
"and production dependencies on test helpers.",
)
class ToolConfig(TypedDict):
@@ -261,27 +287,6 @@ class ToolConfig(TypedDict):
error_message: str | Callable[[subprocess.CompletedProcess[str]], str]
class DuplicateLocation(TypedDict):
"""Location information for a duplicate code block."""
name: str
lines: str
class Duplicate(TypedDict):
"""Duplicate code detection result."""
similarity: float
description: str
locations: list[DuplicateLocation]
class DuplicateResults(TypedDict):
"""Results from duplicate detection analysis."""
duplicates: list[Duplicate]
class ComplexitySummary(TypedDict):
"""Summary of complexity analysis."""
@@ -398,13 +403,17 @@ class QualityConfig:
pyrefly_enabled=os.getenv("QUALITY_PYREFLY_ENABLED", "true").lower()
== "true",
type_check_exit_code=int(os.getenv("QUALITY_TYPE_CHECK_EXIT_CODE", "2")),
test_quality_enabled=os.getenv("QUALITY_TEST_QUALITY_ENABLED", "true").lower()
== "true",
context7_enabled=os.getenv("QUALITY_CONTEXT7_ENABLED", "false").lower()
== "true",
test_quality_enabled=(
os.getenv("QUALITY_TEST_QUALITY_ENABLED", "true").lower()
== "true"
),
context7_enabled=(
os.getenv("QUALITY_CONTEXT7_ENABLED", "false").lower() == "true"
),
context7_api_key=os.getenv("QUALITY_CONTEXT7_API_KEY", ""),
firecrawl_enabled=os.getenv("QUALITY_FIRECRAWL_ENABLED", "false").lower()
== "true",
firecrawl_enabled=(
os.getenv("QUALITY_FIRECRAWL_ENABLED", "false").lower() == "true"
),
firecrawl_api_key=os.getenv("QUALITY_FIRECRAWL_API_KEY", ""),
)
@@ -418,26 +427,29 @@ def should_skip_file(file_path: str, config: QualityConfig) -> bool:
def _module_candidate(path: Path) -> tuple[Path, list[str]]:
"""Build a module invocation candidate for a python executable."""
return path, [str(path), "-m", "quality.cli.main"]
def _cli_candidate(path: Path) -> tuple[Path, list[str]]:
"""Build a direct CLI invocation candidate."""
return path, [str(path)]
def get_claude_quality_command(repo_root: Path | None = None) -> list[str]:
"""Return a path-resilient command for invoking claude-quality."""
repo_root = repo_root or Path(__file__).resolve().parent.parent
platform_name = sys.platform
is_windows = platform_name.startswith("win")
scripts_dir = repo_root / ".venv" / ("Scripts" if is_windows else "bin")
python_names = ["python.exe", "python3.exe"] if is_windows else ["python", "python3"]
cli_names = ["claude-quality.exe", "claude-quality"] if is_windows else ["claude-quality"]
python_names = (
["python.exe", "python3.exe"] if is_windows else ["python", "python3"]
)
cli_names = (
["claude-quality.exe", "claude-quality"]
if is_windows
else ["claude-quality"]
)
candidates: list[tuple[Path, list[str]]] = []
for name in python_names:
@@ -457,9 +469,11 @@ def get_claude_quality_command(repo_root: Path | None = None) -> list[str]:
if shutil.which("claude-quality"):
return ["claude-quality"]
raise RuntimeError(
"'claude-quality' was not found on PATH. Please ensure it is installed and available."
message = (
"'claude-quality' was not found on PATH. Please ensure it is installed and "
"available."
)
raise RuntimeError(message)
def _get_project_venv_bin(file_path: str | None = None) -> Path:
@@ -470,8 +484,16 @@ def _get_project_venv_bin(file_path: str | None = None) -> Path:
If not provided, uses current working directory.
"""
# Start from the file's directory if provided, otherwise from cwd
if file_path and not file_path.startswith("/tmp"):
start_path = Path(file_path).resolve().parent
temp_root = Path(gettempdir()).resolve()
if file_path:
resolved_path = Path(file_path).resolve()
try:
if resolved_path.is_relative_to(temp_root):
start_path = Path.cwd()
else:
start_path = resolved_path.parent
except ValueError:
start_path = resolved_path.parent
else:
start_path = Path.cwd()
@@ -501,7 +523,7 @@ def _format_basedpyright_errors(json_output: str) -> str:
# Group by severity and format
errors = []
for diag in diagnostics[:10]: # Limit to first 10 errors
severity = diag.get("severity", "error")
severity = diag.get("severity", "error").upper()
message = diag.get("message", "Unknown error")
rule = diag.get("rule", "")
range_info = diag.get("range", {})
@@ -509,7 +531,7 @@ def _format_basedpyright_errors(json_output: str) -> str:
line = start.get("line", 0) + 1 # Convert 0-indexed to 1-indexed
rule_text = f" [{rule}]" if rule else ""
errors.append(f" Line {line}: {message}{rule_text}")
errors.append(f" [{severity}] Line {line}: {message}{rule_text}")
count = len(diagnostics)
summary = f"Found {count} type error{'s' if count != 1 else ''}"
@@ -569,7 +591,8 @@ def _format_sourcery_errors(output: str) -> str:
formatted_lines.append(line)
if issue_count > 0:
summary = f"Found {issue_count} code quality issue{'s' if issue_count != 1 else ''}"
plural = "issue" if issue_count == 1 else "issues"
summary = f"Found {issue_count} code quality {plural}"
return f"{summary}:\n" + "\n".join(formatted_lines)
return output.strip()
@@ -659,7 +682,7 @@ def _run_type_checker(
else:
env["PYTHONPATH"] = str(src_dir)
# Run type checker from project root so it finds pyrightconfig.json and other configs
# Run type checker from project root to pick up project configuration files
result = subprocess.run( # noqa: S603
cmd,
check=False,
@@ -754,13 +777,11 @@ def _run_quality_analyses(
# First check for internal duplicates within the file
if config.duplicate_enabled:
internal_duplicates_raw = detect_internal_duplicates(
internal_duplicates = detect_internal_duplicates(
content,
threshold=config.duplicate_threshold,
min_lines=4,
)
# Cast after runtime validation - function returns compatible structure
internal_duplicates = cast("DuplicateResults", internal_duplicates_raw)
if internal_duplicates.get("duplicates"):
results["internal_duplicates"] = internal_duplicates
@@ -857,7 +878,7 @@ def _find_project_root(file_path: str) -> Path:
# Look for common project markers
while current != current.parent:
if any((current / marker).exists() for marker in [
".git", "pyrightconfig.json", "pyproject.toml", ".venv", "setup.py"
".git", "pyrightconfig.json", "pyproject.toml", ".venv", "setup.py",
]):
return current
current = current.parent
@@ -1226,7 +1247,8 @@ def _detect_any_usage(content: str) -> list[str]:
lines_with_any: set[int] = set()
try:
tree = ast.parse(content)
# Dedent the content to handle code fragments with leading indentation
tree = ast.parse(textwrap.dedent(content))
except SyntaxError:
for index, line in enumerate(content.splitlines(), start=1):
code_portion = line.split("#", 1)[0]
@@ -1247,23 +1269,24 @@ def _detect_any_usage(content: str) -> list[str]:
return [
"⚠️ Forbidden typing.Any usage at line(s) "
f"{display_lines}; replace with specific types"
f"{display_lines}; replace with specific types",
]
def _detect_type_ignore_usage(content: str) -> list[str]:
"""Detect forbidden # type: ignore usage in proposed content."""
pattern = re.compile(r"#\s*type:\s*ignore(?:\b|\[)", re.IGNORECASE)
lines_with_type_ignore: set[int] = set()
try:
# Dedent the content to handle code fragments with leading indentation
dedented_content = textwrap.dedent(content)
for token_type, token_string, start, _, _ in tokenize.generate_tokens(
StringIO(content).readline,
StringIO(dedented_content).readline,
):
if token_type == tokenize.COMMENT and pattern.search(token_string):
lines_with_type_ignore.add(start[0])
except tokenize.TokenError:
except (tokenize.TokenError, IndentationError):
for index, line in enumerate(content.splitlines(), start=1):
if pattern.search(line):
lines_with_type_ignore.add(index)
@@ -1278,28 +1301,32 @@ def _detect_type_ignore_usage(content: str) -> list[str]:
return [
"⚠️ Forbidden # type: ignore usage at line(s) "
f"{display_lines}; remove the suppression and fix typing issues instead"
f"{display_lines}; remove the suppression and fix typing issues instead",
]
def _detect_old_typing_patterns(content: str) -> list[str]:
"""Detect old typing patterns that should use modern syntax."""
issues: list[str] = []
# Old typing imports that should be replaced
old_patterns = {
r'\bfrom typing import.*\bUnion\b': 'Use | syntax instead of Union (e.g., str | int)',
r'\bfrom typing import.*\bOptional\b': 'Use | None syntax instead of Optional (e.g., str | None)',
r'\bfrom typing import.*\bList\b': 'Use list[T] instead of List[T]',
r'\bfrom typing import.*\bDict\b': 'Use dict[K, V] instead of Dict[K, V]',
r'\bfrom typing import.*\bSet\b': 'Use set[T] instead of Set[T]',
r'\bfrom typing import.*\bTuple\b': 'Use tuple[T, ...] instead of Tuple[T, ...]',
r'\bUnion\s*\[': 'Use | syntax instead of Union (e.g., str | int)',
r'\bOptional\s*\[': 'Use | None syntax instead of Optional (e.g., str | None)',
r'\bList\s*\[': 'Use list[T] instead of List[T]',
r'\bDict\s*\[': 'Use dict[K, V] instead of Dict[K, V]',
r'\bSet\s*\[': 'Use set[T] instead of Set[T]',
r'\bTuple\s*\[': 'Use tuple[T, ...] instead of Tuple[T, ...]',
r"\bfrom typing import.*\bUnion\b": (
"Use | syntax instead of Union (e.g., str | int)"
),
r"\bfrom typing import.*\bOptional\b": (
"Use | None syntax instead of Optional (e.g., str | None)"
),
r"\bfrom typing import.*\bList\b": "Use list[T] instead of List[T]",
r"\bfrom typing import.*\bDict\b": "Use dict[K, V] instead of Dict[K, V]",
r"\bfrom typing import.*\bSet\b": "Use set[T] instead of Set[T]",
r"\bfrom typing import.*\bTuple\b": (
"Use tuple[T, ...] instead of Tuple[T, ...]"
),
r"\bUnion\s*\[": "Use | syntax instead of Union (e.g., str | int)",
r"\bOptional\s*\[": "Use | None syntax instead of Optional (e.g., str | None)",
r"\bList\s*\[": "Use list[T] instead of List[T]",
r"\bDict\s*\[": "Use dict[K, V] instead of Dict[K, V]",
r"\bSet\s*\[": "Use set[T] instead of Set[T]",
r"\bTuple\s*\[": "Use tuple[T, ...] instead of Tuple[T, ...]",
}
lines = content.splitlines()
@@ -1309,7 +1336,7 @@ def _detect_old_typing_patterns(content: str) -> list[str]:
lines_with_pattern = []
for i, line in enumerate(lines, 1):
# Skip comments
code_part = line.split('#')[0]
code_part = line.split("#")[0]
if re.search(pattern, code_part):
lines_with_pattern.append(i)
@@ -1317,7 +1344,10 @@ def _detect_old_typing_patterns(content: str) -> list[str]:
display_lines = ", ".join(str(num) for num in lines_with_pattern[:5])
if len(lines_with_pattern) > 5:
display_lines += ", …"
found_issues.append(f"⚠️ Old typing pattern at line(s) {display_lines}: {message}")
issue_text = (
f"⚠️ Old typing pattern at line(s) {display_lines}: {message}"
)
found_issues.append(issue_text)
return found_issues
@@ -1326,22 +1356,6 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
"""Detect files and functions/classes with suspicious adjective/adverb suffixes."""
issues: list[str] = []
# Common adjective/adverb suffixes that indicate potential duplication
SUSPICIOUS_SUFFIXES = {
"enhanced", "improved", "better", "new", "updated", "modified", "refactored",
"optimized", "fixed", "clean", "simple", "advanced", "basic", "complete",
"final", "latest", "current", "temp", "temporary", "backup", "old", "legacy",
"unified", "merged", "combined", "integrated", "consolidated", "extended",
"enriched", "augmented", "upgraded", "revised", "polished", "streamlined",
"simplified", "modernized", "normalized", "sanitized", "validated", "verified",
"corrected", "patched", "stable", "experimental", "alpha", "beta", "draft",
"preliminary", "prototype", "working", "test", "debug", "custom", "special",
"generic", "specific", "general", "detailed", "minimal", "full", "partial",
"quick", "fast", "slow", "smart", "intelligent", "auto", "manual", "secure",
"safe", "robust", "flexible", "dynamic", "static", "reactive", "async",
"sync", "parallel", "serial", "distributed", "centralized", "decentralized"
}
# Check file name against other files in the same directory
file_path_obj = Path(file_path)
if file_path_obj.parent.exists():
@@ -1350,17 +1364,23 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
# Check if current file has suspicious suffix
for suffix in SUSPICIOUS_SUFFIXES:
if file_stem.endswith(f"_{suffix}") or file_stem.endswith(f"-{suffix}"):
base_name = file_stem[:-len(suffix)-1]
potential_original = file_path_obj.parent / f"{base_name}{file_suffix}"
for separator in ("_", "-"):
suffix_token = f"{separator}{suffix}"
if not file_stem.endswith(suffix_token):
continue
base_name = file_stem[: -len(suffix_token)]
potential_original = file_path_obj.parent / f"{base_name}{file_suffix}"
if potential_original.exists() and potential_original != file_path_obj:
issues.append(
f"⚠️ File '{file_path_obj.name}' appears to be a suffixed duplicate of "
f"'{potential_original.name}'. Consider refactoring instead of creating "
f"variations with adjective suffixes."
message = FILE_SUFFIX_DUPLICATE_MSG.format(
current=file_path_obj.name,
original=potential_original.name,
)
issues.append(message)
break
else:
continue
break
# Check if any existing files are suffixed versions of current file
for existing_file in file_path_obj.parent.glob(f"{file_stem}_*{file_suffix}"):
@@ -1369,11 +1389,11 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
if existing_stem.startswith(f"{file_stem}_"):
potential_suffix = existing_stem[len(file_stem)+1:]
if potential_suffix in SUSPICIOUS_SUFFIXES:
issues.append(
f"⚠️ Creating '{file_path_obj.name}' when '{existing_file.name}' "
f"already exists suggests duplication. Consider consolidating or "
f"using a more descriptive name."
message = EXISTING_FILE_DUPLICATE_MSG.format(
current=file_path_obj.name,
existing=existing_file.name,
)
issues.append(message)
break
# Same check for dash-separated suffixes
@@ -1383,16 +1403,17 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
if existing_stem.startswith(f"{file_stem}-"):
potential_suffix = existing_stem[len(file_stem)+1:]
if potential_suffix in SUSPICIOUS_SUFFIXES:
issues.append(
f"⚠️ Creating '{file_path_obj.name}' when '{existing_file.name}' "
f"already exists suggests duplication. Consider consolidating or "
f"using a more descriptive name."
message = EXISTING_FILE_DUPLICATE_MSG.format(
current=file_path_obj.name,
existing=existing_file.name,
)
issues.append(message)
break
# Check function and class names in content
try:
tree = ast.parse(content)
# Dedent the content to handle code fragments with leading indentation
tree = ast.parse(textwrap.dedent(content))
class SuffixVisitor(ast.NodeVisitor):
def __init__(self):
@@ -1417,36 +1438,47 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
# Check for suspicious function name patterns
for func_name in visitor.function_names:
for suffix in SUSPICIOUS_SUFFIXES:
if func_name.endswith(f"_{suffix}"):
base_name = func_name[:-len(suffix)-1]
if base_name in visitor.function_names:
issues.append(
f"⚠️ Function '{func_name}' appears to be a suffixed duplicate of "
f"'{base_name}'. Consider refactoring instead of creating variations."
)
break
suffix_token = f"_{suffix}"
if not func_name.endswith(suffix_token):
continue
base_name = func_name[: -len(suffix_token)]
if base_name in visitor.function_names:
message = NAME_SUFFIX_DUPLICATE_MSG.format(
kind="Function",
name=func_name,
base=base_name,
)
issues.append(message)
break
# Check for suspicious class name patterns
for class_name in visitor.class_names:
for suffix in SUSPICIOUS_SUFFIXES:
# Convert to check both PascalCase and snake_case patterns
pascal_suffix = suffix.capitalize()
if class_name.endswith(pascal_suffix):
base_name = class_name[:-len(pascal_suffix)]
if base_name in visitor.class_names:
issues.append(
f"⚠️ Class '{class_name}' appears to be a suffixed duplicate of "
f"'{base_name}'. Consider refactoring instead of creating variations."
)
break
elif class_name.endswith(f"_{suffix}"):
base_name = class_name[:-len(suffix)-1]
if base_name in visitor.class_names:
issues.append(
f"⚠️ Class '{class_name}' appears to be a suffixed duplicate of "
f"'{base_name}'. Consider refactoring instead of creating variations."
snake_suffix = f"_{suffix}"
potential_matches = (
(pascal_suffix, class_name[: -len(pascal_suffix)]),
(snake_suffix, class_name[: -len(snake_suffix)]),
)
for token, base_name in potential_matches:
if (
token
and class_name.endswith(token)
and base_name in visitor.class_names
):
message = NAME_SUFFIX_DUPLICATE_MSG.format(
kind="Class",
name=class_name,
base=base_name,
)
issues.append(message)
break
else:
continue
break
except SyntaxError:
# If we can't parse the AST, skip function/class checks
@@ -1612,19 +1644,24 @@ def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject:
enable_type_checks = tool_name == "Write"
# Always run core quality checks (Any, type: ignore, old typing, duplicates) regardless of skip patterns
# Always run core checks (Any, type: ignore, typing, duplicates) before skipping
any_usage_issues = _detect_any_usage(content)
type_ignore_issues = _detect_type_ignore_usage(content)
old_typing_issues = _detect_old_typing_patterns(content)
suffix_duplication_issues = _detect_suffix_duplication(file_path, content)
precheck_issues = any_usage_issues + type_ignore_issues + old_typing_issues + suffix_duplication_issues
precheck_issues = (
any_usage_issues
+ type_ignore_issues
+ old_typing_issues
+ suffix_duplication_issues
)
# Run test quality checks if enabled and file is a test file
if run_test_checks:
test_quality_issues = run_test_quality_checks(content, file_path, config)
precheck_issues.extend(test_quality_issues)
# Skip detailed analysis for configured patterns, but not if it's a test file with test checks enabled
# Skip detailed analysis for configured patterns unless test checks should run
# Note: Core quality checks (Any, type: ignore, duplicates) always run above
should_skip_detailed = should_skip_file(file_path, config) and not run_test_checks
@@ -1766,8 +1803,12 @@ def is_test_file(file_path: str) -> bool:
return any(part in ("test", "tests", "testing") for part in path_parts)
def run_test_quality_checks(content: str, file_path: str, config: QualityConfig) -> list[str]:
"""Run Sourcery with specific test-related rules and return issues with enhanced guidance."""
def run_test_quality_checks(
content: str,
file_path: str,
config: QualityConfig,
) -> list[str]:
"""Run Sourcery's test rules and return guidance-enhanced issues."""
issues: list[str] = []
# Only run test quality checks for test files
@@ -1846,20 +1887,40 @@ def run_test_quality_checks(content: str, file_path: str, config: QualityConfig)
# Issues were found - parse the output
output = result.stdout + result.stderr
# Try to extract rule names from the output
# Sourcery output format typically includes rule names in brackets or after specific markers
# Try to extract rule names from the output. Sourcery usually includes
# rule identifiers in brackets or descriptive text.
for rule in test_rules:
if rule in output or rule.replace("-", " ") in output.lower():
base_guidance = generate_test_quality_guidance(rule, content, file_path, config)
external_context = get_external_context(rule, content, file_path, config)
base_guidance = generate_test_quality_guidance(
rule,
content,
file_path,
config,
)
external_context = get_external_context(
rule,
content,
file_path,
config,
)
if external_context:
base_guidance += f"\n\n{external_context}"
issues.append(base_guidance)
break # Only add one guidance message
else:
# If no specific rule found, provide general guidance
base_guidance = generate_test_quality_guidance("unknown", content, file_path, config)
external_context = get_external_context("unknown", content, file_path, config)
base_guidance = generate_test_quality_guidance(
"unknown",
content,
file_path,
config,
)
external_context = get_external_context(
"unknown",
content,
file_path,
config,
)
if external_context:
base_guidance += f"\n\n{external_context}"
issues.append(base_guidance)
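A small sketch of exercising the detector helpers touched above directly; _detect_any_usage and _detect_old_typing_patterns are internal functions, and the hooks.code_quality_guard import path mirrors the tests in this commit:

from hooks.code_quality_guard import (
    _detect_any_usage,
    _detect_old_typing_patterns,
)

# A fragment using the legacy Optional style that the checks should flag.
snippet = (
    "from typing import Optional\n"
    "def load(path: str) -> Optional[dict]:\n"
    "    return None\n"
)
for issue in _detect_any_usage(snippet) + _detect_old_typing_patterns(snippet):
    print(issue)
# Expected: an "Old typing pattern" warning pointing at the Optional usage.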


@@ -7,9 +7,10 @@ import ast
import difflib
import hashlib
import re
import textwrap
from collections import defaultdict
from dataclasses import dataclass
from typing import Any
from dataclasses import dataclass, field
from typing import TypedDict
COMMON_DUPLICATE_METHODS = {
"__init__",
@@ -19,6 +20,57 @@ COMMON_DUPLICATE_METHODS = {
"__aexit__",
}
# Test-specific patterns that commonly have legitimate duplication
TEST_FIXTURE_PATTERNS = {
"fixture",
"mock",
"stub",
"setup",
"teardown",
"data",
"sample",
}
# Common test assertion patterns
TEST_ASSERTION_PATTERNS = {
"assert",
"expect",
"should",
}
class DuplicateLocation(TypedDict):
"""Location information for a duplicate code block."""
name: str
type: str
lines: str
class Duplicate(TypedDict):
"""Duplicate detection result entry."""
type: str
similarity: float
description: str
locations: list[DuplicateLocation]
class DuplicateSummary(TypedDict, total=False):
"""Summary data accompanying duplicate detection."""
total_duplicates: int
blocks_analyzed: int
duplicate_lines: int
class DuplicateResults(TypedDict, total=False):
"""Structured results returned by duplicate detection."""
duplicates: list[Duplicate]
summary: DuplicateSummary
error: str
@dataclass
class CodeBlock:
@@ -31,11 +83,12 @@ class CodeBlock:
source: str
ast_node: ast.AST
complexity: int = 0
tokens: list[str] = None
tokens: list[str] = field(init=False)
decorators: list[str] = field(init=False)
def __post_init__(self):
if self.tokens is None:
self.tokens = self._tokenize()
def __post_init__(self) -> None:
self.tokens = self._tokenize()
self.decorators = self._extract_decorators()
def _tokenize(self) -> list[str]:
"""Extract meaningful tokens from source code."""
@@ -47,6 +100,40 @@ class CodeBlock:
# Extract identifiers, keywords, operators
return re.findall(r"\b\w+\b|[=<>!+\-*/]+", code)
def _extract_decorators(self) -> list[str]:
"""Extract decorator names from the AST node."""
decorators: list[str] = []
if isinstance(
self.ast_node,
(ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef),
):
for decorator in self.ast_node.decorator_list:
if isinstance(decorator, ast.Name):
decorators.append(decorator.id)
elif isinstance(decorator, ast.Attribute):
decorators.append(decorator.attr)
elif isinstance(decorator, ast.Call):
if isinstance(decorator.func, ast.Name):
decorators.append(decorator.func.id)
elif isinstance(decorator.func, ast.Attribute):
decorators.append(decorator.func.attr)
return decorators
def is_test_fixture(self) -> bool:
"""Check if this block is a pytest fixture."""
return "fixture" in self.decorators
def is_test_function(self) -> bool:
"""Check if this block is a test function."""
return self.name.startswith("test_") or (
self.type == "method" and self.name.startswith("test_")
)
def has_test_pattern_name(self) -> bool:
"""Check if name contains common test fixture patterns."""
name_lower = self.name.lower()
return any(pattern in name_lower for pattern in TEST_FIXTURE_PATTERNS)
@dataclass
class DuplicateGroup:
@@ -67,15 +154,16 @@ class InternalDuplicateDetector:
min_lines: int = 4,
min_tokens: int = 20,
):
self.similarity_threshold = similarity_threshold
self.min_lines = min_lines
self.min_tokens = min_tokens
self.similarity_threshold: float = similarity_threshold
self.min_lines: int = min_lines
self.min_tokens: int = min_tokens
self.duplicate_groups: list[DuplicateGroup] = []
def analyze_code(self, source_code: str) -> dict[str, Any]:
def analyze_code(self, source_code: str) -> DuplicateResults:
"""Analyze source code for internal duplicates."""
try:
tree = ast.parse(source_code)
# Dedent the content to handle code fragments with leading indentation
tree = ast.parse(textwrap.dedent(source_code))
except SyntaxError:
return {
"error": "Failed to parse code",
@@ -104,7 +192,7 @@ class InternalDuplicateDetector:
}
# Find duplicates
duplicate_groups = []
duplicate_groups: list[DuplicateGroup] = []
# 1. Check for exact duplicates (normalized)
exact_groups = self._find_exact_duplicates(blocks)
@@ -125,7 +213,7 @@ class InternalDuplicateDetector:
and not self._should_ignore_group(group)
]
results = [
results: list[Duplicate] = [
{
"type": group.pattern_type,
"similarity": group.similarity_score,
@@ -155,26 +243,25 @@ class InternalDuplicateDetector:
def _extract_code_blocks(self, tree: ast.AST, source: str) -> list[CodeBlock]:
"""Extract functions, methods, and classes from AST."""
blocks = []
blocks: list[CodeBlock] = []
lines = source.split("\n")
def create_block(
node: ast.AST,
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
block_type: str,
lines: list[str],
) -> CodeBlock | None:
try:
start = node.lineno - 1
end = node.end_lineno - 1 if hasattr(node, "end_lineno") else start
end_lineno = getattr(node, "end_lineno", None)
end = end_lineno - 1 if end_lineno is not None else start
source = "\n".join(lines[start : end + 1])
return CodeBlock(
name=node.name,
type=block_type,
start_line=node.lineno,
end_line=node.end_lineno
if hasattr(node, "end_lineno")
else node.lineno,
end_line=end_lineno if end_lineno is not None else node.lineno,
source=source,
ast_node=node,
complexity=calculate_complexity(node),
@@ -459,6 +546,7 @@ class InternalDuplicateDetector:
if not group.blocks:
return False
# Check for common dunder methods
if all(block.name in COMMON_DUPLICATE_METHODS for block in group.blocks):
max_lines = max(
block.end_line - block.start_line + 1 for block in group.blocks
@@ -469,6 +557,36 @@ class InternalDuplicateDetector:
if max_lines <= 12 and max_complexity <= 3:
return True
# Check for pytest fixtures - they legitimately have repetitive structure
if all(block.is_test_fixture() for block in group.blocks):
max_lines = max(
block.end_line - block.start_line + 1 for block in group.blocks
)
# Allow fixtures up to 15 lines with similar structure
if max_lines <= 15:
return True
# Check for test functions with fixture-like names (data builders, mocks, etc.)
if all(block.has_test_pattern_name() for block in group.blocks):
max_lines = max(
block.end_line - block.start_line + 1 for block in group.blocks
)
max_complexity = max(block.complexity for block in group.blocks)
# Allow test helpers that are simple and short
if max_lines <= 10 and max_complexity <= 4:
return True
# Check for simple test functions with arrange-act-assert pattern
if all(block.is_test_function() for block in group.blocks):
max_complexity = max(block.complexity for block in group.blocks)
max_lines = max(
block.end_line - block.start_line + 1 for block in group.blocks
)
# Simple tests (<=15 lines) often share similar control flow.
# Permit full similarity for those cases; duplication is acceptable.
if max_complexity <= 5 and max_lines <= 15:
return True
return False
@@ -476,7 +594,7 @@ def detect_internal_duplicates(
source_code: str,
threshold: float = 0.7,
min_lines: int = 4,
) -> dict[str, Any]:
) -> DuplicateResults:
"""Main function to detect internal duplicates in code."""
detector = InternalDuplicateDetector(
similarity_threshold=threshold,
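A short sketch of calling detect_internal_duplicates on a fragment containing two near-identical functions; the hooks.internal_duplicate_detector import path is assumed from the package layout, and the sample loaders are hypothetical:

from hooks.internal_duplicate_detector import detect_internal_duplicates

source = (
    "def load_users(path):\n"
    "    with open(path) as handle:\n"
    "        rows = handle.read().splitlines()\n"
    "    return [row.strip() for row in rows if row.strip()]\n"
    "\n"
    "def load_accounts(path):\n"
    "    with open(path) as handle:\n"
    "        rows = handle.read().splitlines()\n"
    "    return [row.strip() for row in rows if row.strip()]\n"
)
results = detect_internal_duplicates(source, threshold=0.7, min_lines=4)
for duplicate in results.get("duplicates", []):
    # Each entry carries a similarity score, description, and block locations.
    print(duplicate["similarity"], duplicate["description"])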

File diff suppressed because it is too large.


@@ -1,5 +1,10 @@
"""Fixture module used to verify Any detection in the guard."""
# ruff: noqa: ANN401 # These annotations intentionally use Any for the test harness.
from typing import Any
def process_data(data: Any) -> Any:
"""This should be blocked by the hook."""
return data
"""Return the provided value; the guard should block this in practice."""
return data


@@ -1,169 +1,131 @@
#!/usr/bin/env python3
"""
Core hook validation tests - MUST ALL PASS
"""
import sys
from pathlib import Path
"""Core coverage tests for the code quality guard hooks."""
# Add hooks to path
sys.path.insert(0, str(Path(__file__).parent / "hooks"))
from __future__ import annotations
from code_quality_guard import pretooluse_hook, QualityConfig
from dataclasses import dataclass
from typing import Any
import pytest
from hooks.code_quality_guard import QualityConfig, pretooluse_hook
def test_core_blocking():
"""Test the core blocking functionality that MUST work."""
@pytest.fixture
def strict_config() -> QualityConfig:
"""Return a strict enforcement configuration for the guard."""
config = QualityConfig.from_env()
config.enforcement_mode = "strict"
return config
tests_passed = 0
tests_failed = 0
# Test 1: Any usage MUST be blocked
print("🧪 Test 1: Any usage blocking")
@dataclass(slots=True)
class BlockingScenario:
"""Parameters describing an expected blocking outcome."""
name: str
tool_name: str
tool_input: dict[str, Any]
reason_fragment: str
BLOCKING_SCENARIOS: tuple[BlockingScenario, ...] = (
BlockingScenario(
name="typing-any",
tool_name="Write",
tool_input={
"file_path": "/src/production.py",
"content": (
"from typing import Any\n"
"def bad(value: Any) -> Any:\n"
" return value\n"
),
},
reason_fragment="typing.Any usage",
),
BlockingScenario(
name="type-ignore",
tool_name="Write",
tool_input={
"file_path": "/src/production.py",
"content": (
"def bad() -> None:\n"
" value = call() # type: ignore\n"
" return value\n"
),
},
reason_fragment="type: ignore",
),
BlockingScenario(
name="legacy-typing",
tool_name="Write",
tool_input={
"file_path": "/src/production.py",
"content": (
"from typing import Optional, Union\n"
"def bad(value: Union[str, int]) -> Optional[str]:\n"
" return None\n"
),
},
reason_fragment="Old typing pattern",
),
BlockingScenario(
name="edit-tool-any",
tool_name="Edit",
tool_input={
"file_path": "/src/production.py",
"old_string": "def old():\n return 1\n",
"new_string": "def new(value: Any) -> Any:\n return value\n",
},
reason_fragment="typing.Any usage",
),
)
@pytest.mark.parametrize(
"scenario",
BLOCKING_SCENARIOS,
ids=lambda scenario: scenario.name,
)
def test_pretooluse_blocks_expected_patterns(
strict_config: QualityConfig,
scenario: BlockingScenario,
) -> None:
"""Verify the guard blocks known bad patterns."""
hook_data = {"tool_name": scenario.tool_name, "tool_input": scenario.tool_input}
result = pretooluse_hook(hook_data, strict_config)
assert result["permissionDecision"] == "deny"
assert scenario.reason_fragment in result.get("reason", "")
def test_pretooluse_allows_modern_code(strict_config: QualityConfig) -> None:
"""PreToolUse hook allows well-typed Python content."""
hook_data = {
"tool_name": "Write",
"tool_input": {
"file_path": "/src/production.py",
"content": "from typing import Any\ndef bad(x: Any) -> Any: return x"
}
"content": (
"def good(value: str | int) -> str | None:\n"
" return str(value) if value else None\n"
),
},
}
try:
result = pretooluse_hook(hook_data, config)
if result["permissionDecision"] == "deny" and "typing.Any usage" in result.get("reason", ""):
print("✅ PASS: Any usage properly blocked")
tests_passed += 1
else:
print(f"❌ FAIL: Any usage not blocked. Decision: {result['permissionDecision']}")
tests_failed += 1
except Exception as e:
print(f"❌ FAIL: Exception in Any test: {e}")
tests_failed += 1
result = pretooluse_hook(hook_data, strict_config)
# Test 2: type: ignore MUST be blocked
print("\n🧪 Test 2: type: ignore blocking")
hook_data = {
"tool_name": "Write",
"tool_input": {
"file_path": "/src/production.py",
"content": "def bad():\n x = call() # type: ignore\n return x"
}
}
assert result["permissionDecision"] == "allow"
try:
result = pretooluse_hook(hook_data, config)
if result["permissionDecision"] == "deny" and "type: ignore" in result.get("reason", ""):
print("✅ PASS: type: ignore properly blocked")
tests_passed += 1
else:
print(f"❌ FAIL: type: ignore not blocked. Decision: {result['permissionDecision']}")
tests_failed += 1
except Exception as e:
print(f"❌ FAIL: Exception in type: ignore test: {e}")
tests_failed += 1
# Test 3: Old typing patterns MUST be blocked
print("\n🧪 Test 3: Old typing patterns blocking")
hook_data = {
"tool_name": "Write",
"tool_input": {
"file_path": "/src/production.py",
"content": "from typing import Union, Optional\ndef bad(x: Union[str, int]) -> Optional[str]: return None"
}
}
try:
result = pretooluse_hook(hook_data, config)
if result["permissionDecision"] == "deny" and "Old typing pattern" in result.get("reason", ""):
print("✅ PASS: Old typing patterns properly blocked")
tests_passed += 1
else:
print(f"❌ FAIL: Old typing patterns not blocked. Decision: {result['permissionDecision']}")
tests_failed += 1
except Exception as e:
print(f"❌ FAIL: Exception in old typing test: {e}")
tests_failed += 1
# Test 4: Good code MUST be allowed
print("\n🧪 Test 4: Good code allowed")
hook_data = {
"tool_name": "Write",
"tool_input": {
"file_path": "/src/production.py",
"content": "def good(x: str | int) -> str | None:\n return str(x) if x else None"
}
}
try:
result = pretooluse_hook(hook_data, config)
if result["permissionDecision"] == "allow":
print("✅ PASS: Good code properly allowed")
tests_passed += 1
else:
print(f"❌ FAIL: Good code blocked. Decision: {result['permissionDecision']}")
print(f" Reason: {result.get('reason', 'No reason')}")
tests_failed += 1
except Exception as e:
print(f"❌ FAIL: Exception in good code test: {e}")
tests_failed += 1
# Test 5: Edit tool MUST also block
print("\n🧪 Test 5: Edit tool blocking")
hook_data = {
"tool_name": "Edit",
"tool_input": {
"file_path": "/src/production.py",
"old_string": "def old():",
"new_string": "def new(x: Any) -> Any:"
}
}
try:
result = pretooluse_hook(hook_data, config)
if result["permissionDecision"] == "deny" and "typing.Any usage" in result.get("reason", ""):
print("✅ PASS: Edit tool properly blocked")
tests_passed += 1
else:
print(f"❌ FAIL: Edit tool not blocked. Decision: {result['permissionDecision']}")
tests_failed += 1
except Exception as e:
print(f"❌ FAIL: Exception in Edit tool test: {e}")
tests_failed += 1
# Test 6: Non-Python files MUST be allowed
print("\n🧪 Test 6: Non-Python files allowed")
def test_pretooluse_allows_non_python_files(strict_config: QualityConfig) -> None:
"""Non-Python files should bypass quality restrictions."""
hook_data = {
"tool_name": "Write",
"tool_input": {
"file_path": "/src/config.json",
"content": '{"type": "Any", "ignore": true}'
}
"content": '{"type": "Any", "ignore": true}',
},
}
try:
result = pretooluse_hook(hook_data, config)
if result["permissionDecision"] == "allow":
print("✅ PASS: Non-Python files properly allowed")
tests_passed += 1
else:
print(f"❌ FAIL: Non-Python file blocked. Decision: {result['permissionDecision']}")
tests_failed += 1
except Exception as e:
print(f"❌ FAIL: Exception in non-Python test: {e}")
tests_failed += 1
result = pretooluse_hook(hook_data, strict_config)
print(f"\n📊 Results: {tests_passed} passed, {tests_failed} failed")
if tests_failed == 0:
print("🎉 ALL CORE TESTS PASSED! Hooks are working correctly.")
return True
else:
print(f"💥 {tests_failed} CRITICAL TESTS FAILED! Hooks are broken.")
return False
if __name__ == "__main__":
print("🚀 Running core hook validation tests...\n")
success = test_core_blocking()
sys.exit(0 if success else 1)
assert result["permissionDecision"] == "allow"

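Extending the scenario table is a matter of appending one more entry; a MultiEdit case mirroring the integration tests later in this commit could look like the following sketch (illustrative only, not part of this change):

MULTIEDIT_ANY_SCENARIO = BlockingScenario(
    name="multiedit-any",
    tool_name="MultiEdit",
    tool_input={
        "file_path": "/src/production.py",
        "edits": [
            {
                "old_string": "def old() -> int:\n    return 1\n",
                "new_string": "def new(value: Any) -> Any:\n    return value\n",
            },
        ],
    },
    reason_fragment="typing.Any usage",
)

Appending MULTIEDIT_ANY_SCENARIO to BLOCKING_SCENARIOS would exercise the MultiEdit path through the same parametrized test.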
View File

@@ -1,9 +1,15 @@
def bad_function():
"""Function with type ignore."""
x = "string"
return x + 5 # type: ignore
"""Fixture module used to check handling of type: ignore annotations."""
def another_bad():
"""Another function with type ignore."""
from __future__ import annotations
def bad_function() -> int:
"""Return a value while suppressing a typing error."""
x = "string"
return x + 5 # type: ignore[arg-type]
def another_bad() -> int:
"""Return a value after an ignored assignment mismatch."""
y: int = "not an int" # type: ignore[assignment]
return y
return y

2
tests/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
"""Test package marker for Ruff namespace rules."""

View File

@@ -1,19 +1,20 @@
"""Comprehensive test suite covering all hook interaction scenarios."""
# ruff: noqa: SLF001
# pyright: reportPrivateUsage=false, reportPrivateImportUsage=false, reportPrivateLocalImportUsage=false, reportUnusedCallResult=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownLambdaType=false, reportUnknownMemberType=false
from __future__ import annotations
import json
import os
import subprocess
import sys
from collections.abc import Mapping
from pathlib import Path
from tempfile import gettempdir
import pytest
HOOKS_DIR = Path(__file__).parent.parent.parent / "hooks"
sys.path.insert(0, str(HOOKS_DIR))
import code_quality_guard as guard
from hooks import code_quality_guard as guard
class TestProjectStructureVariations:
@@ -26,14 +27,14 @@ class TestProjectStructureVariations:
root.mkdir()
(root / ".venv/bin").mkdir(parents=True)
(root / "pyproject.toml").touch()
test_file = root / "main.py"
test_file.write_text("# test")
# Should find project root
found_root = guard._find_project_root(str(test_file))
assert found_root == root
# Should create .tmp in root
tmp_dir = guard._get_project_tmp_dir(str(test_file))
assert tmp_dir == root / ".tmp"
@@ -49,13 +50,13 @@ class TestProjectStructureVariations:
(root / "src/package").mkdir(parents=True)
(root / ".venv/bin").mkdir(parents=True)
(root / "pyproject.toml").touch()
test_file = root / "src/package/module.py"
test_file.write_text("# test")
found_root = guard._find_project_root(str(test_file))
assert found_root == root
venv_bin = guard._get_project_venv_bin(str(test_file))
assert venv_bin == root / ".venv/bin"
finally:
@@ -70,19 +71,19 @@ class TestProjectStructureVariations:
# Outer project
(outer / ".venv/bin").mkdir(parents=True)
(outer / ".git").mkdir()
# Inner project
inner = outer / "subproject"
(inner / ".venv/bin").mkdir(parents=True)
(inner / "pyproject.toml").touch()
test_file = inner / "main.py"
test_file.write_text("# test")
# Should find inner project root
found_root = guard._find_project_root(str(test_file))
assert found_root == inner
# Should use inner venv
venv_bin = guard._get_project_venv_bin(str(test_file))
assert venv_bin == inner / ".venv/bin"
@@ -116,10 +117,10 @@ class TestProjectStructureVariations:
deep = root / "a/b/c/d/e/f"
deep.mkdir(parents=True)
(root / ".git").mkdir()
test_file = deep / "module.py"
test_file.write_text("# test")
found_root = guard._find_project_root(str(test_file))
assert found_root == root
finally:
@@ -137,13 +138,13 @@ class TestConfigurationInheritance:
try:
(root / "src").mkdir(parents=True)
(root / ".venv/bin").mkdir(parents=True)
config = {"reportUnknownMemberType": False}
(root / "pyrightconfig.json").write_text(json.dumps(config))
test_file = root / "src/mod.py"
test_file.write_text("# test")
found_root = guard._find_project_root(str(test_file))
assert found_root == root
assert (found_root / "pyrightconfig.json").exists()
@@ -158,10 +159,10 @@ class TestConfigurationInheritance:
try:
root.mkdir()
(root / "pyproject.toml").write_text("[tool.mypy]\n")
test_file = root / "main.py"
test_file.write_text("# test")
found_root = guard._find_project_root(str(test_file))
assert found_root == root
finally:
@@ -176,13 +177,13 @@ class TestConfigurationInheritance:
root.mkdir()
(root / "pyproject.toml").touch()
(root / ".gitignore").write_text("*.pyc\n__pycache__/\n")
test_file = root / "main.py"
test_file.write_text("# test")
tmp_dir = guard._get_project_tmp_dir(str(test_file))
assert tmp_dir.exists()
gitignore_content = (root / ".gitignore").read_text()
assert ".tmp/" in gitignore_content
finally:
@@ -198,12 +199,12 @@ class TestConfigurationInheritance:
(root / "pyproject.toml").touch()
original = "*.pyc\n.tmp/\n"
(root / ".gitignore").write_text(original)
test_file = root / "main.py"
test_file.write_text("# test")
_ = guard._get_project_tmp_dir(str(test_file))
# Should not have been modified
assert (root / ".gitignore").read_text() == original
finally:
@@ -266,26 +267,30 @@ class TestVirtualEnvironmentEdgeCases:
tool = root / ".venv/bin/basedpyright"
tool.write_text("#!/bin/bash\necho fake")
tool.chmod(0o755)
test_file = root / "main.py"
test_file.write_text("# test")
captured_env = {}
def capture_run(cmd: list[str], **kw: object) -> subprocess.CompletedProcess[str]:
if "env" in kw:
captured_env.update(dict(kw["env"]))
return subprocess.CompletedProcess(cmd, 0, "", "")
captured_env: dict[str, str] = {}
def capture_run(
cmd: list[str],
**kw: object,
) -> subprocess.CompletedProcess[str]:
env_obj = kw.get("env")
if isinstance(env_obj, Mapping):
captured_env.update({str(k): str(v) for k, v in env_obj.items()})
return subprocess.CompletedProcess(list(cmd), 0, stdout="", stderr="")
monkeypatch.setattr(guard.subprocess, "run", capture_run)
guard._run_type_checker(
"basedpyright",
str(test_file),
guard.QualityConfig(),
original_file_path=str(test_file),
)
# PYTHONPATH should not be set (or not include src)
if "PYTHONPATH" in captured_env:
assert "src" not in captured_env["PYTHONPATH"]
@@ -305,7 +310,7 @@ class TestTypeCheckerIntegration:
pyrefly_enabled=False,
sourcery_enabled=False,
)
issues = guard.run_type_checks("test.py", config)
assert issues == []
@@ -316,13 +321,13 @@ class TestTypeCheckerIntegration:
"""Missing tool returns warning, doesn't crash."""
monkeypatch.setattr(guard.Path, "exists", lambda _: False, raising=False)
monkeypatch.setattr(guard, "_ensure_tool_installed", lambda _: False)
success, message = guard._run_type_checker(
"basedpyright",
"test.py",
guard.QualityConfig(),
)
assert success is True
assert "not available" in message
@@ -332,18 +337,18 @@ class TestTypeCheckerIntegration:
) -> None:
"""Tool timeout is handled gracefully."""
monkeypatch.setattr(guard.Path, "exists", lambda _: True, raising=False)
def timeout_run(*_args: object, **_kw: object) -> None:
raise subprocess.TimeoutExpired(cmd=["tool"], timeout=30)
monkeypatch.setattr(guard.subprocess, "run", timeout_run)
success, message = guard._run_type_checker(
"basedpyright",
"test.py",
guard.QualityConfig(),
)
assert success is True
assert "timeout" in message.lower()
@@ -353,22 +358,26 @@ class TestTypeCheckerIntegration:
) -> None:
"""OS errors from tools are handled."""
monkeypatch.setattr(guard.Path, "exists", lambda _: True, raising=False)
def error_run(*_args: object, **_kw: object) -> None:
raise OSError("Permission denied")
message = "Permission denied"
raise OSError(message)
monkeypatch.setattr(guard.subprocess, "run", error_run)
success, message = guard._run_type_checker(
"basedpyright",
"test.py",
guard.QualityConfig(),
)
assert success is True
assert "execution error" in message.lower()
def test_unknown_tool_returns_warning(self, monkeypatch: pytest.MonkeyPatch) -> None:
def test_unknown_tool_returns_warning(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Unknown tool name returns warning."""
# Mock tool not existing
monkeypatch.setattr(guard.Path, "exists", lambda _: False, raising=False)
@@ -397,32 +406,36 @@ class TestWorkingDirectoryScenarios:
(root / "src").mkdir(parents=True)
(root / ".venv/bin").mkdir(parents=True)
(root / "pyrightconfig.json").touch()
tool = root / ".venv/bin/basedpyright"
tool.write_text("#!/bin/bash\npwd")
tool.chmod(0o755)
test_file = root / "src/mod.py"
test_file.write_text("# test")
captured_cwd = []
def capture_run(cmd: list[str], **kw: object) -> subprocess.CompletedProcess[str]:
if "cwd" in kw:
captured_cwd.append(str(kw["cwd"]))
return subprocess.CompletedProcess(cmd, 0, "", "")
captured_cwd: list[Path] = []
def capture_run(
cmd: list[str],
**kw: object,
) -> subprocess.CompletedProcess[str]:
cwd_obj = kw.get("cwd")
if cwd_obj is not None:
captured_cwd.append(Path(str(cwd_obj)))
return subprocess.CompletedProcess(list(cmd), 0, stdout="", stderr="")
monkeypatch.setattr(guard.subprocess, "run", capture_run)
guard._run_type_checker(
"basedpyright",
str(test_file),
guard.QualityConfig(),
original_file_path=str(test_file),
)
assert len(captured_cwd) > 0
assert Path(captured_cwd[0]) == root
assert captured_cwd
assert captured_cwd[0] == root
finally:
import shutil
if root.exists():
@@ -444,10 +457,11 @@ class TestErrorConditions:
) -> None:
"""Permission error creating .tmp is handled."""
def raise_permission(*_args: object, **_kw: object) -> None:
raise PermissionError("Cannot create directory")
message = "Cannot create directory"
raise PermissionError(message)
monkeypatch.setattr(Path, "mkdir", raise_permission)
# Should raise and be caught by caller
with pytest.raises(PermissionError):
guard._get_project_tmp_dir("/some/file.py")
@@ -460,7 +474,7 @@ class TestErrorConditions:
(root / "pyproject.toml").touch()
test_file = root / "empty.py"
test_file.write_text("")
# Should not crash
tmp_dir = guard._get_project_tmp_dir(str(test_file))
assert tmp_dir.exists()
@@ -479,13 +493,13 @@ class TestFileLocationVariations:
try:
(root / "tests").mkdir(parents=True)
(root / ".git").mkdir()
test_file = root / "tests/test_module.py"
test_file.write_text("# test")
found_root = guard._find_project_root(str(test_file))
assert found_root == root
# Test file detection
assert guard.is_test_file(str(test_file))
finally:
@@ -499,10 +513,10 @@ class TestFileLocationVariations:
try:
root.mkdir()
(root / ".git").mkdir()
test_file = root / "main.py"
test_file.write_text("# test")
found_root = guard._find_project_root(str(test_file))
assert found_root == root
finally:
@@ -521,12 +535,12 @@ class TestTempFileManagement:
(root / "src").mkdir(parents=True)
(root / ".venv/bin").mkdir(parents=True)
(root / "pyproject.toml").touch()
test_file = root / "src/mod.py"
test_file.write_text("def foo(): pass")
tmp_dir = root / ".tmp"
# Analyze code (should create and delete temp file)
config = guard.QualityConfig(
duplicate_enabled=False,
@@ -536,14 +550,14 @@ class TestTempFileManagement:
pyrefly_enabled=False,
sourcery_enabled=False,
)
guard.analyze_code_quality(
"def foo(): pass",
str(test_file),
config,
enable_type_checks=False,
)
# .tmp directory should exist but temp file should be gone
if tmp_dir.exists():
temp_files = list(tmp_dir.glob("hook_validation_*"))
@@ -559,15 +573,15 @@ class TestTempFileManagement:
try:
(root / "src").mkdir(parents=True)
(root / "pyproject.toml").touch()
test_file = root / "src/mod.py"
test_file.write_text("# test")
tmp_dir = guard._get_project_tmp_dir(str(test_file))
# Should be in project, not /tmp
assert str(tmp_dir).startswith(str(root))
assert not str(tmp_dir).startswith("/tmp")
assert not str(tmp_dir).startswith(gettempdir())
finally:
import shutil
if root.exists():

View File

@@ -3,15 +3,16 @@
import os
import pytest
from code_quality_guard import QualityConfig
from hooks import code_quality_guard as guard
class TestQualityConfig:
"""Test QualityConfig dataclass and environment loading."""
"""Test the QualityConfig dataclass and environment loading."""
def test_default_config(self):
"""Test default configuration values."""
config = QualityConfig()
config = guard.QualityConfig()
# Core settings
assert config.duplicate_threshold == 0.7
@@ -29,14 +30,16 @@ class TestQualityConfig:
assert config.show_success is False
# Skip patterns
assert "test_" in config.skip_patterns
assert "_test.py" in config.skip_patterns
assert "/tests/" in config.skip_patterns
assert "/fixtures/" in config.skip_patterns
assert config.skip_patterns is not None
skip_patterns = config.skip_patterns
assert "test_" in skip_patterns
assert "_test.py" in skip_patterns
assert "/tests/" in skip_patterns
assert "/fixtures/" in skip_patterns
def test_from_env_with_defaults(self):
"""Test loading config from environment with defaults."""
config = QualityConfig.from_env()
config = guard.QualityConfig.from_env()
# Should use defaults when env vars not set
assert config.duplicate_threshold == 0.7
@@ -61,7 +64,7 @@ class TestQualityConfig:
},
)
config = QualityConfig.from_env()
config = guard.QualityConfig.from_env()
assert config.duplicate_threshold == 0.8
assert config.duplicate_enabled is False
@@ -78,7 +81,7 @@ class TestQualityConfig:
def test_from_env_with_invalid_boolean(self):
"""Test loading config with invalid boolean values."""
os.environ["QUALITY_DUP_ENABLED"] = "invalid"
config = QualityConfig.from_env()
config = guard.QualityConfig.from_env()
# Should default to False for invalid boolean
assert config.duplicate_enabled is False
@@ -88,14 +91,14 @@ class TestQualityConfig:
os.environ["QUALITY_DUP_THRESHOLD"] = "not_a_float"
with pytest.raises(ValueError, match="could not convert string to float"):
QualityConfig.from_env()
_ = guard.QualityConfig.from_env()
def test_from_env_with_invalid_int(self):
"""Test loading config with invalid int values."""
os.environ["QUALITY_COMPLEXITY_THRESHOLD"] = "not_an_int"
with pytest.raises(ValueError, match="invalid literal"):
QualityConfig.from_env()
_ = guard.QualityConfig.from_env()
def test_enforcement_modes(self):
"""Test different enforcement modes."""
@@ -103,87 +106,85 @@ class TestQualityConfig:
for mode in modes:
os.environ["QUALITY_ENFORCEMENT"] = mode
config = QualityConfig.from_env()
config = guard.QualityConfig.from_env()
assert config.enforcement_mode == mode
def test_skip_patterns_initialization(self):
"""Test skip patterns initialization."""
config = QualityConfig(skip_patterns=None)
config = guard.QualityConfig(skip_patterns=None)
assert config.skip_patterns is not None
assert len(config.skip_patterns) > 0
custom_patterns = ["custom_test_", "/custom/"]
config = QualityConfig(skip_patterns=custom_patterns)
config = guard.QualityConfig(skip_patterns=custom_patterns)
assert config.skip_patterns == custom_patterns
def test_threshold_boundaries(self):
"""Test threshold boundary values."""
# Test minimum threshold
os.environ["QUALITY_DUP_THRESHOLD"] = "0.0"
config = QualityConfig.from_env()
config = guard.QualityConfig.from_env()
assert config.duplicate_threshold == 0.0
# Test maximum threshold
os.environ["QUALITY_DUP_THRESHOLD"] = "1.0"
config = QualityConfig.from_env()
config = guard.QualityConfig.from_env()
assert config.duplicate_threshold == 1.0
# Test complexity threshold
os.environ["QUALITY_COMPLEXITY_THRESHOLD"] = "1"
config = QualityConfig.from_env()
config = guard.QualityConfig.from_env()
assert config.complexity_threshold == 1
def test_config_combinations(self):
def test_config_combinations(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test various configuration combinations."""
test_cases = [
# All checks disabled
{
"env": {
test_cases: list[tuple[dict[str, str], dict[str, bool]]] = [
(
{
"QUALITY_DUP_ENABLED": "false",
"QUALITY_COMPLEXITY_ENABLED": "false",
"QUALITY_MODERN_ENABLED": "false",
},
"expected": {
{
"duplicate_enabled": False,
"complexity_enabled": False,
"modernization_enabled": False,
},
},
# Only duplicate checking
{
"env": {
),
(
{
"QUALITY_DUP_ENABLED": "true",
"QUALITY_COMPLEXITY_ENABLED": "false",
"QUALITY_MODERN_ENABLED": "false",
},
"expected": {
{
"duplicate_enabled": True,
"complexity_enabled": False,
"modernization_enabled": False,
},
},
# PostToolUse only
{
"env": {
),
(
{
"QUALITY_DUP_ENABLED": "false",
"QUALITY_STATE_TRACKING": "true",
"QUALITY_VERIFY_NAMING": "true",
},
"expected": {
{
"duplicate_enabled": False,
"state_tracking_enabled": True,
"verify_naming": True,
},
},
),
]
for test_case in test_cases:
os.environ.clear()
os.environ.update(test_case["env"])
config = QualityConfig.from_env()
for env_values, expected_values in test_cases:
with monkeypatch.context() as mp:
for key, value in env_values.items():
mp.setenv(key, value)
config = guard.QualityConfig.from_env()
for key, expected_value in test_case["expected"].items():
assert getattr(config, key) == expected_value
for key, expected_value in expected_values.items():
assert getattr(config, key) == expected_value
def test_case_insensitive_boolean(self):
"""Test case-insensitive boolean parsing."""
@@ -192,5 +193,5 @@ class TestQualityConfig:
for value, expected_bool in zip(test_values, expected, strict=False):
os.environ["QUALITY_DUP_ENABLED"] = value
config = QualityConfig.from_env()
config = guard.QualityConfig.from_env()
assert config.duplicate_enabled == expected_bool

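For reference, a small sketch of how the environment-driven configuration is expected to behave, using only the QUALITY_* variables and attributes exercised in the tests above (not an exhaustive mapping):

import pytest

from hooks import code_quality_guard as guard


def test_env_override_sketch(monkeypatch: pytest.MonkeyPatch) -> None:
    """Illustrative only: env vars map onto QualityConfig attributes."""
    monkeypatch.setenv("QUALITY_ENFORCEMENT", "warn")
    monkeypatch.setenv("QUALITY_DUP_THRESHOLD", "0.8")
    monkeypatch.setenv("QUALITY_STATE_TRACKING", "true")

    config = guard.QualityConfig.from_env()

    assert config.enforcement_mode == "warn"
    assert config.duplicate_threshold == 0.8
    assert config.state_tracking_enabled is True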
View File

@@ -0,0 +1,309 @@
"""Fairness tests for async functions and fixtures in duplicate detection."""
from __future__ import annotations
from hooks.internal_duplicate_detector import (
Duplicate,
DuplicateResults,
detect_internal_duplicates,
)
def _run_detection(code: str, *, threshold: float) -> tuple[DuplicateResults, list[Duplicate]]:
"""Run duplicate detection and return typed results."""
result = detect_internal_duplicates(code, threshold=threshold)
duplicates = result.get("duplicates", []) or []
return result, duplicates
class TestAsyncFunctionFairness:
"""Verify async functions are treated fairly in duplicate detection."""
def test_async_and_sync_identical_logic(self) -> None:
"""Async and sync versions of the same logic should be flagged as duplicates."""
code = """
def fetch_user(user_id: int) -> dict[str, str]:
response = requests.get(f"/api/users/{user_id}")
data = response.json()
return {"id": str(data["id"]), "name": data["name"]}
async def fetch_user_async(user_id: int) -> dict[str, str]:
response = await client.get(f"/api/users/{user_id}")
data = await response.json()
return {"id": str(data["id"]), "name": data["name"]}
"""
_, duplicates = _run_detection(code, threshold=0.7)
# Should detect structural similarity despite async/await differences
assert len(duplicates) >= 1
assert any(d["similarity"] > 0.7 for d in duplicates)
def test_async_context_managers_exemption(self) -> None:
"""Async context manager dunder methods should be exempted like sync ones."""
code = """
async def __aenter__(self):
self.conn = await connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.conn.close()
async def __aenter__(self):
self.cache = await connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.cache.close()
"""
_, duplicates = _run_detection(code, threshold=0.7)
# __aenter__ and __aexit__ should be exempted as boilerplate
# Even though they have similar structure
assert len(duplicates) == 0
def test_mixed_async_sync_functions_no_bias(self) -> None:
"""Detection should work equally for mixed async/sync functions."""
code = """
def process_sync(data: list[int]) -> int:
total = 0
for item in data:
if item > 0:
total += item * 2
return total
async def process_async(data: list[int]) -> int:
total = 0
for item in data:
if item > 0:
total += item * 2
return total
def calculate_sync(values: list[int]) -> int:
result = 0
for val in values:
if val > 0:
result += val * 2
return result
"""
_, duplicates = _run_detection(code, threshold=0.7)
# All three should be detected as similar (regardless of async)
assert len(duplicates) >= 1
found_functions: set[str] = set()
for dup in duplicates:
for loc in dup["locations"]:
found_functions.add(loc["name"])
# Should find all three functions in duplicate groups
assert len(found_functions) >= 2
class TestFixtureFairness:
"""Verify pytest fixtures and test patterns are treated fairly."""
def test_pytest_fixtures_with_similar_data(self) -> None:
"""Pytest fixtures with similar structure should be exempted."""
code = """
import pytest
@pytest.fixture
def user_data() -> dict[str, str | int]:
return {
"name": "Alice",
"age": 30,
"email": "alice@example.com"
}
@pytest.fixture
def admin_data() -> dict[str, str | int]:
return {
"name": "Bob",
"age": 35,
"email": "bob@example.com"
}
@pytest.fixture
def guest_data() -> dict[str, str | int]:
return {
"name": "Charlie",
"age": 25,
"email": "charlie@example.com"
}
"""
_, duplicates = _run_detection(code, threshold=0.7)
# Fixtures should be exempted from duplicate detection
assert len(duplicates) == 0
def test_mock_builders_exemption(self) -> None:
"""Mock/stub builder functions should be exempted if short and simple."""
code = """
def mock_user_response() -> dict[str, str]:
return {
"id": "123",
"name": "Test User",
"status": "active"
}
def mock_admin_response() -> dict[str, str]:
return {
"id": "456",
"name": "Admin User",
"status": "active"
}
def stub_guest_response() -> dict[str, str]:
return {
"id": "789",
"name": "Guest User",
"status": "pending"
}
"""
_, duplicates = _run_detection(code, threshold=0.7)
# Short mock builders should be exempted
assert len(duplicates) == 0
def test_simple_test_functions_with_aaa_pattern(self) -> None:
"""Simple test functions following arrange-act-assert should be treated leniently."""
code = """
def test_user_creation() -> None:
# Arrange
user_data = {"name": "Alice", "email": "alice@test.com"}
# Act
user = create_user(user_data)
# Assert
assert user.name == "Alice"
assert user.email == "alice@test.com"
def test_admin_creation() -> None:
# Arrange
admin_data = {"name": "Bob", "email": "bob@test.com"}
# Act
admin = create_user(admin_data)
# Assert
assert admin.name == "Bob"
assert admin.email == "bob@test.com"
def test_guest_creation() -> None:
# Arrange
guest_data = {"name": "Charlie", "email": "charlie@test.com"}
# Act
guest = create_user(guest_data)
# Assert
assert guest.name == "Charlie"
assert guest.email == "charlie@test.com"
"""
_, duplicates = _run_detection(code, threshold=0.7)
# Simple test functions with AAA pattern should be exempted if similarity < 95%
assert len(duplicates) == 0
def test_complex_fixtures_still_flagged(self) -> None:
"""Complex fixtures with substantial duplication should still be flagged."""
code = """
import pytest
@pytest.fixture
def complex_user_setup() -> dict[str, object]:
# Lots of complex setup logic
db = connect_database()
cache = setup_cache()
logger = configure_logging()
user = create_user(db, {
"name": "Alice",
"permissions": ["read", "write", "delete"],
"metadata": {"created": "2024-01-01"}
})
cache.warm_up(user)
logger.info(f"Created user {user.id}")
return {"user": user, "db": db, "cache": cache, "logger": logger}
@pytest.fixture
def complex_admin_setup() -> dict[str, object]:
# Lots of complex setup logic
db = connect_database()
cache = setup_cache()
logger = configure_logging()
user = create_user(db, {
"name": "Bob",
"permissions": ["read", "write", "delete"],
"metadata": {"created": "2024-01-02"}
})
cache.warm_up(user)
logger.info(f"Created user {user.id}")
return {"user": user, "db": db, "cache": cache, "logger": logger}
"""
_, duplicates = _run_detection(code, threshold=0.7)
# Complex fixtures exceeding 15 lines should be flagged
assert len(duplicates) >= 1
def test_setup_teardown_methods(self) -> None:
"""Test setup/teardown methods should be exempted if simple."""
code = """
def setup_database() -> object:
db = connect_test_db()
db.clear()
return db
def teardown_database(db: object) -> None:
db.clear()
db.close()
def setup_cache() -> object:
cache = connect_test_cache()
cache.clear()
return cache
def teardown_cache(cache: object) -> None:
cache.clear()
cache.close()
"""
_, duplicates = _run_detection(code, threshold=0.7)
# Setup/teardown functions with pattern names should be exempted
assert len(duplicates) == 0
def test_non_test_code_still_strictly_checked(self) -> None:
"""Non-test production code should still have strict duplicate detection."""
code = """
def calculate_user_total(users: list[dict[str, float]]) -> float:
total = 0.0
for user in users:
if user.get("active"):
total += user.get("amount", 0.0) * user.get("rate", 1.0)
return total
def calculate_product_total(products: list[dict[str, float]]) -> float:
total = 0.0
for product in products:
if product.get("active"):
total += product.get("amount", 0.0) * product.get("rate", 1.0)
return total
def calculate_order_total(orders: list[dict[str, float]]) -> float:
total = 0.0
for order in orders:
if order.get("active"):
total += order.get("amount", 0.0) * order.get("rate", 1.0)
return total
"""
_, duplicates = _run_detection(code, threshold=0.7)
# Production code should be strictly checked
assert len(duplicates) >= 1
assert any(d["similarity"] > 0.85 for d in duplicates)

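The comments in these tests describe the exemption rules only informally. As a purely illustrative sketch (the helper name and exact checks are hypothetical; the real detector also weighs complexity and similarity), the boilerplate test shapes exercised above could be recognised roughly like this:

import ast


def _looks_like_test_boilerplate(  # hypothetical helper, for illustration only
    node: ast.FunctionDef | ast.AsyncFunctionDef,
) -> bool:
    length = (node.end_lineno or node.lineno) - node.lineno + 1
    if length > 15:
        return False  # long bodies (e.g. the complex fixtures above) stay checked
    if node.name in {"__enter__", "__exit__", "__aenter__", "__aexit__"}:
        return True  # context-manager dunders are boilerplate by definition
    if node.name.startswith(("mock_", "stub_", "setup_", "teardown_", "test_")):
        return True  # short builders, setup/teardown helpers, and simple tests
    decorator_names = {
        getattr(dec, "attr", getattr(dec, "id", ""))
        for dec in node.decorator_list
    }
    return "fixture" in decorator_names  # short pytest fixtures are exempt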
View File

@@ -3,21 +3,19 @@
from __future__ import annotations
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Iterator
from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
from code_quality_guard import QualityConfig, posttooluse_hook, pretooluse_hook
from code_quality_guard import (
QualityConfig,
posttooluse_hook,
pretooluse_hook,
)
if TYPE_CHECKING:
from collections.abc import Iterator
from pathlib import Path
@pytest.fixture()
@pytest.fixture
def multi_container_paths(tmp_path: Path) -> dict[str, Path]:
"""Create container/project directories used across tests."""
container_a = tmp_path / "container-a" / "project" / "src"
@@ -126,7 +124,9 @@ def test_pretooluse_handles_platform_metadata(
assert response["permissionDecision"] == "allow"
def test_state_tracking_isolation_between_containers(multi_container_paths: dict[str, Path]) -> None:
def test_state_tracking_isolation_between_containers(
multi_container_paths: dict[str, Path],
) -> None:
"""State tracking should stay isolated per container/project combination."""
config = QualityConfig(state_tracking_enabled=True)
config.skip_patterns = [] # Ensure state tracking runs even in pytest temp dirs.
@@ -173,11 +173,12 @@ def beta():
assert response_b_pre["permissionDecision"] == "allow"
# The first container writes fewer functions, which should trigger a warning.
file_a.write_text("""\
file_a.write_text(
"""\
def alpha():
return 1
"""
""",
)
# The second container preserves the original content.
@@ -232,9 +233,13 @@ def beta():
path_one.parent.mkdir(parents=True, exist_ok=True)
path_two.parent.mkdir(parents=True, exist_ok=True)
cmd: list[str],
**kwargs: object,
) -> SimpleNamespace:
with patch("code_quality_guard.analyze_code_quality", return_value={}):
pretooluse_hook(
_pre_request(
path_one,
base_content,
container_id=shared_container,
project_id=shared_project,
user_id="collision-user",
),
config,
@@ -250,11 +255,13 @@ def beta():
config,
)
path_one.write_text("""\
path_one.write_text(
"""\
def alpha():
return 1
""".lstrip())
""".lstrip(),
)
path_two.write_text(base_content)
degraded_response = posttooluse_hook(
@@ -304,14 +311,7 @@ def test_cross_file_duplicate_project_root_detection(
captured: dict[str, list[str]] = {}
def fake_run(
cmd: list[str],
*,
check: bool,
capture_output: bool,
text: bool,
timeout: int,
) -> SimpleNamespace:
def fake_run(cmd: list[str], **_kwargs: object) -> SimpleNamespace:
captured["cmd"] = cmd
return SimpleNamespace(returncode=0, stdout=json.dumps({"duplicates": []}))
@@ -336,7 +336,9 @@ def test_cross_file_duplicate_project_root_detection(
assert response.get("decision") is None
def test_main_handles_permission_decisions_for_multiple_users(monkeypatch: pytest.MonkeyPatch) -> None:
def test_main_handles_permission_decisions_for_multiple_users(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""`main` should surface deny/ask decisions for different user contexts."""
from code_quality_guard import main
@@ -394,7 +396,7 @@ def test_main_handles_permission_decisions_for_multiple_users(monkeypatch: pytes
},
"permissionDecision": "allow",
},
]
],
)
input_iter: Iterator[dict[str, object]] = iter(hook_inputs)
@@ -402,7 +404,10 @@ def test_main_handles_permission_decisions_for_multiple_users(monkeypatch: pytes
def fake_json_load(_stream: object) -> dict[str, object]:
return next(input_iter)
def fake_pretooluse(_hook_data: dict[str, object], _config: QualityConfig) -> dict[str, object]:
def fake_pretooluse(
_hook_data: dict[str, object],
_config: QualityConfig,
) -> dict[str, object]:
return next(responses)
exit_calls: list[tuple[str, int]] = []
@@ -413,7 +418,7 @@ def test_main_handles_permission_decisions_for_multiple_users(monkeypatch: pytes
printed: list[str] = []
def fake_print(message: str) -> None: # noqa: D401 - simple passthrough
def fake_print(message: str) -> None:
printed.append(message)
monkeypatch.setattr("json.load", fake_json_load)

View File

@@ -10,7 +10,6 @@ from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from code_quality_guard import (
QualityConfig,
analyze_code_quality,
@@ -95,8 +94,13 @@ class TestHelperFunctions:
cmd = get_claude_quality_command(repo_root=tmp_path)
assert cmd == [str(executable), "-m", "quality.cli.main"]
def test_get_claude_quality_command_python_and_python3(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch, set_platform) -> None:
"""Prefer python when both python and python3 executables exist in the venv."""
def test_get_claude_quality_command_python_and_python3(
self,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
set_platform,
) -> None:
"""Prefer python when both python and python3 executables exist."""
set_platform("linux")
monkeypatch.setattr(shutil, "which", lambda _name: None)
@@ -108,7 +112,12 @@ class TestHelperFunctions:
assert cmd == [str(python_path), "-m", "quality.cli.main"]
assert python3_path.exists() # Sanity check that both executables were present
def test_get_claude_quality_command_cli_fallback(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch, set_platform) -> None:
def test_get_claude_quality_command_cli_fallback(
self,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
set_platform,
) -> None:
"""Fallback to claude-quality script when python executables are absent."""
set_platform("linux")
@@ -125,7 +134,7 @@ class TestHelperFunctions:
monkeypatch: pytest.MonkeyPatch,
set_platform,
) -> None:
"""Handle Windows environments where the claude-quality script lacks an .exe suffix."""
"""Handle Windows when the claude-quality script lacks an .exe suffix."""
set_platform("win32")
monkeypatch.setattr(shutil, "which", lambda _name: None)
@@ -141,7 +150,7 @@ class TestHelperFunctions:
monkeypatch: pytest.MonkeyPatch,
set_platform,
) -> None:
"""Fall back to python3 on POSIX and python on Windows when no venv executables exist."""
"""Fall back to python3 on POSIX and python on Windows when venv tools are absent."""
set_platform("darwin")

View File

@@ -4,6 +4,10 @@ from unittest.mock import patch
from code_quality_guard import QualityConfig, pretooluse_hook
TEST_QUALITY_CONDITIONAL = (
"Test Quality: no-conditionals-in-tests - Conditional found in test"
)
class TestPreToolUseHook:
"""Test PreToolUse hook behavior."""
@@ -470,7 +474,7 @@ class TestPreToolUseHook:
"tool_input": {
"file_path": "example.py",
"content": (
"def example() -> None:\n" " value = unknown # type: ignore\n"
"def example() -> None:\n value = unknown # type: ignore\n"
),
},
}
@@ -492,7 +496,7 @@ class TestPreToolUseHook:
{
"old_string": "pass",
"new_string": (
"def helper() -> None:\n" " pass # type: ignore\n"
"def helper() -> None:\n pass # type: ignore\n"
),
},
{
@@ -545,7 +549,7 @@ class TestTestQualityChecks:
}
with patch("code_quality_guard.run_test_quality_checks") as mock_test_check:
mock_test_check.return_value = ["Test Quality: no-conditionals-in-tests - Conditional found in test"]
mock_test_check.return_value = [TEST_QUALITY_CONDITIONAL]
with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
@@ -639,7 +643,7 @@ class TestTestQualityChecks:
}
with patch("code_quality_guard.run_test_quality_checks") as mock_test_check:
mock_test_check.return_value = ["Test Quality: no-conditionals-in-tests - Conditional found in test"]
mock_test_check.return_value = [TEST_QUALITY_CONDITIONAL]
with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
@@ -659,14 +663,27 @@ class TestTestQualityChecks:
"tool_input": {
"file_path": "tests/test_example.py",
"edits": [
{"old_string": "a", "new_string": "def test_func1():\n assert True"},
{"old_string": "b", "new_string": "def test_func2():\n if False:\n pass"},
{
"old_string": "a",
"new_string": (
"def test_func1():\n"
" assert True"
),
},
{
"old_string": "b",
"new_string": (
"def test_func2():\n"
" if False:\n"
" pass"
),
},
],
},
}
with patch("code_quality_guard.run_test_quality_checks") as mock_test_check:
mock_test_check.return_value = ["Test Quality: no-conditionals-in-tests - Conditional found in test"]
mock_test_check.return_value = [TEST_QUALITY_CONDITIONAL]
with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
@@ -695,7 +712,7 @@ class TestTestQualityChecks:
}
with patch("code_quality_guard.run_test_quality_checks") as mock_test_check:
mock_test_check.return_value = ["Test Quality: no-conditionals-in-tests - Conditional found in test"]
mock_test_check.return_value = [TEST_QUALITY_CONDITIONAL]
with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}

View File

@@ -1,25 +1,33 @@
# ruff: noqa: SLF001
"""Tests targeting internal helpers for code_quality_guard."""
from __future__ import annotations
# pyright: reportPrivateUsage=false, reportPrivateImportUsage=false, reportPrivateLocalImportUsage=false, reportUnknownArgumentType=false, reportUnknownLambdaType=false, reportUnknownMemberType=false, reportUnusedCallResult=false
import json
import subprocess
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING, cast
import code_quality_guard as guard
import pytest
from hooks import code_quality_guard as guard
if TYPE_CHECKING:
from pathlib import Path
@pytest.mark.parametrize(
("env_key", "value", "attr", "expected"),
(
[
("QUALITY_DUP_THRESHOLD", "0.9", "duplicate_threshold", 0.9),
("QUALITY_DUP_ENABLED", "false", "duplicate_enabled", False),
("QUALITY_COMPLEXITY_THRESHOLD", "7", "complexity_threshold", 7),
("QUALITY_ENFORCEMENT", "warn", "enforcement_mode", "warn"),
("QUALITY_STATE_TRACKING", "true", "state_tracking_enabled", True),
),
],
)
def test_quality_config_from_env_parsing(
monkeypatch: pytest.MonkeyPatch,
@@ -36,12 +44,12 @@ def test_quality_config_from_env_parsing(
@pytest.mark.parametrize(
("tool_exists", "install_behavior", "expected"),
(
[
(True, None, True),
(False, "success", True),
(False, "failure", False),
(False, "timeout", False),
),
],
)
def test_ensure_tool_installed(
monkeypatch: pytest.MonkeyPatch,
@@ -63,10 +71,12 @@ def test_ensure_tool_installed(
def fake_run(cmd: Iterable[str], **_: object) -> subprocess.CompletedProcess[bytes]:
if install_behavior is None:
raise AssertionError("uv install should not run when tool already exists")
message = "uv install should not run when tool already exists"
raise AssertionError(message)
if install_behavior == "timeout":
raise subprocess.TimeoutExpired(cmd=list(cmd), timeout=60)
return subprocess.CompletedProcess(list(cmd), 0 if install_behavior == "success" else 1)
exit_code = 0 if install_behavior == "success" else 1
return subprocess.CompletedProcess(list(cmd), exit_code)
monkeypatch.setattr(guard.subprocess, "run", fake_run)
@@ -75,12 +85,32 @@ def test_ensure_tool_installed(
@pytest.mark.parametrize(
("tool_name", "run_payload", "expected_success", "expected_fragment"),
(
("basedpyright", {"returncode": 0, "stdout": ""}, True, ""),
("basedpyright", {"returncode": 1, "stdout": ""}, False, "failed to parse"),
("sourcery", {"returncode": 0, "stdout": "3 issues detected"}, False, "3 code quality issue"),
("pyrefly", {"returncode": 1, "stdout": "pyrefly issue"}, False, "pyrefly issue"),
),
[
(
"basedpyright",
{"returncode": 0, "stdout": ""},
True,
"",
),
(
"basedpyright",
{"returncode": 1, "stdout": ""},
False,
"failed to parse",
),
(
"sourcery",
{"returncode": 0, "stdout": "3 issues detected"},
False,
"3 code quality issue",
),
(
"pyrefly",
{"returncode": 1, "stdout": "pyrefly issue"},
False,
"pyrefly issue",
),
],
)
def test_run_type_checker_known_tools(
monkeypatch: pytest.MonkeyPatch,
@@ -94,11 +124,27 @@ def test_run_type_checker_known_tools(
monkeypatch.setattr(guard.Path, "exists", lambda _path: True, raising=False)
def fake_run(cmd: Iterable[str], **_: object) -> subprocess.CompletedProcess[str]:
return subprocess.CompletedProcess(list(cmd), int(run_payload["returncode"]), run_payload.get("stdout", ""), "")
returncode_obj = run_payload.get("returncode", 0)
if isinstance(returncode_obj, bool):
exit_code = int(returncode_obj)
elif isinstance(returncode_obj, int):
exit_code = returncode_obj
elif isinstance(returncode_obj, str):
exit_code = int(returncode_obj)
else:
raise AssertionError(f"Unexpected returncode type: {type(returncode_obj)!r}")
stdout_obj = run_payload.get("stdout", "")
stdout = str(stdout_obj)
return subprocess.CompletedProcess(list(cmd), exit_code, stdout=stdout, stderr="")
monkeypatch.setattr(guard.subprocess, "run", fake_run)
success, message = guard._run_type_checker(tool_name, "tmp.py", guard.QualityConfig())
success, message = guard._run_type_checker(
tool_name,
"tmp.py",
guard.QualityConfig(),
)
assert success is expected_success
if expected_fragment:
assert expected_fragment in message
@@ -108,10 +154,10 @@ def test_run_type_checker_known_tools(
@pytest.mark.parametrize(
("exception", "expected_fragment"),
(
[
(subprocess.TimeoutExpired(cmd=["tool"], timeout=30), "timeout"),
(OSError("boom"), "execution error"),
),
],
)
def test_run_type_checker_runtime_exceptions(
monkeypatch: pytest.MonkeyPatch,
@@ -126,7 +172,11 @@ def test_run_type_checker_runtime_exceptions(
monkeypatch.setattr(guard.subprocess, "run", raise_exc)
success, message = guard._run_type_checker("sourcery", "tmp.py", guard.QualityConfig())
success, message = guard._run_type_checker(
"sourcery",
"tmp.py",
guard.QualityConfig(),
)
assert success is True
assert expected_fragment in message
@@ -137,7 +187,11 @@ def test_run_type_checker_tool_missing(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(guard.Path, "exists", lambda _path: False, raising=False)
monkeypatch.setattr(guard, "_ensure_tool_installed", lambda _name: False)
success, message = guard._run_type_checker("pyrefly", "tmp.py", guard.QualityConfig())
success, message = guard._run_type_checker(
"pyrefly",
"tmp.py",
guard.QualityConfig(),
)
assert success is True
assert "not available" in message
@@ -148,12 +202,19 @@ def test_run_type_checker_unknown_tool(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(guard.Path, "exists", lambda _path: True, raising=False)
success, message = guard._run_type_checker("unknown", "tmp.py", guard.QualityConfig())
success, message = guard._run_type_checker(
"unknown",
"tmp.py",
guard.QualityConfig(),
)
assert success is True
assert "Unknown tool" in message
def test_run_quality_analyses_invokes_cli(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
def test_run_quality_analyses_invokes_cli(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
"""_run_quality_analyses aggregates CLI outputs and duplicates."""
script_path = tmp_path / "module.py"
@@ -202,7 +263,8 @@ def test_run_quality_analyses_invokes_cli(monkeypatch: pytest.MonkeyPatch, tmp_p
},
)
else:
raise AssertionError(f"Unexpected command: {cmd}")
message = f"Unexpected command: {cmd}"
raise AssertionError(message)
return subprocess.CompletedProcess(list(cmd), 0, payload, "")
monkeypatch.setattr(guard.subprocess, "run", fake_run)
@@ -221,11 +283,11 @@ def test_run_quality_analyses_invokes_cli(monkeypatch: pytest.MonkeyPatch, tmp_p
@pytest.mark.parametrize(
("content", "expected"),
(
[
("from typing import Any\n\nAny\n", True),
("def broken(:\n Any\n", True),
("def clean() -> None:\n return None\n", False),
),
],
)
def test_detect_any_usage(content: str, expected: bool) -> None:
"""_detect_any_usage flags Any usage even on syntax errors."""
@@ -236,12 +298,12 @@ def test_detect_any_usage(content: str, expected: bool) -> None:
@pytest.mark.parametrize(
("mode", "forced", "expected_permission"),
(
[
("strict", None, "deny"),
("warn", None, "ask"),
("permissive", None, "allow"),
("strict", "allow", "allow"),
),
],
)
def test_handle_quality_issues_modes(
mode: str,
@@ -253,17 +315,29 @@ def test_handle_quality_issues_modes(
config = guard.QualityConfig(enforcement_mode=mode)
issues = ["Issue one", "Issue two"]
response = guard._handle_quality_issues("example.py", issues, config, forced_permission=forced)
assert response["permissionDecision"] == expected_permission
response = guard._handle_quality_issues(
"example.py",
issues,
config,
forced_permission=forced,
)
decision = cast(str, response["permissionDecision"])
assert decision == expected_permission
if forced is None:
assert any(issue in response.get("reason", "") for issue in issues)
reason = cast(str, response.get("reason", ""))
assert any(issue in reason for issue in issues)
def test_perform_quality_check_with_state_tracking(monkeypatch: pytest.MonkeyPatch) -> None:
def test_perform_quality_check_with_state_tracking(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""_perform_quality_check stores state and reports detected issues."""
tracked_calls: list[str] = []
monkeypatch.setattr(guard, "store_pre_state", lambda path, content: tracked_calls.append(path))
def record_state(path: str, _content: str) -> None:
tracked_calls.append(path)
monkeypatch.setattr(guard, "store_pre_state", record_state)
def fake_analyze(*_args: object, **_kwargs: object) -> guard.AnalysisResults:
return {
@@ -276,11 +350,18 @@ def test_perform_quality_check_with_state_tracking(monkeypatch: pytest.MonkeyPat
config = guard.QualityConfig(state_tracking_enabled=True)
has_issues, issues = guard._perform_quality_check("example.py", "def old(): pass", config)
has_issues, issues = guard._perform_quality_check(
"example.py",
"def old(): pass",
config,
)
assert tracked_calls == ["example.py"]
assert has_issues is True
assert any("Modernization" in issue or "modernization" in issue.lower() for issue in issues)
assert any(
"Modernization" in issue or "modernization" in issue.lower()
for issue in issues
)
def test_check_cross_file_duplicates_command(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -296,7 +377,10 @@ def test_check_cross_file_duplicates_command(monkeypatch: pytest.MonkeyPatch) ->
monkeypatch.setattr(guard.subprocess, "run", fake_run)
issues = guard.check_cross_file_duplicates("/repo/example.py", guard.QualityConfig())
issues = guard.check_cross_file_duplicates(
"/repo/example.py",
guard.QualityConfig(),
)
assert issues
assert "duplicates" in captured_cmds[0]
@@ -314,9 +398,9 @@ def test_create_hook_response_includes_reason() -> None:
additional_context="context",
decision="block",
)
assert response["permissionDecision"] == "deny"
assert response["reason"] == "Testing"
assert response["systemMessage"] == "System"
assert response["hookSpecificOutput"]["additionalContext"] == "context"
assert response["decision"] == "block"
assert cast(str, response["permissionDecision"]) == "deny"
assert cast(str, response["reason"]) == "Testing"
assert cast(str, response["systemMessage"]) == "System"
hook_output = cast(dict[str, object], response["hookSpecificOutput"])
assert cast(str, hook_output["additionalContext"]) == "context"
assert cast(str, response["decision"]) == "block"

View File

@@ -2,25 +2,24 @@
from __future__ import annotations
# pyright: reportPrivateUsage=false, reportPrivateImportUsage=false, reportPrivateLocalImportUsage=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownLambdaType=false, reportUnknownMemberType=false, reportUnusedCallResult=false
# ruff: noqa: SLF001
import json
import os
import subprocess
import sys
from collections.abc import Mapping
from pathlib import Path
import pytest
# Add hooks directory to path
HOOKS_DIR = Path(__file__).parent.parent.parent / "hooks"
sys.path.insert(0, str(HOOKS_DIR))
import code_quality_guard as guard
from hooks import code_quality_guard as guard
class TestVenvDetection:
"""Test virtual environment detection."""
def test_finds_venv_from_file_path(self, tmp_path: Path) -> None:
def test_finds_venv_from_file_path(self) -> None:
"""Should find .venv by traversing up from file."""
# Use home directory to avoid /tmp check
root = Path.home() / f"test_proj_{os.getpid()}"
@@ -99,12 +98,17 @@ class TestPythonpathSetup:
tool.write_text("#!/bin/bash\necho fake")
tool.chmod(0o755)
captured_env = {}
captured_env: dict[str, str] = {}
def capture_run(cmd: list[str], **kw: object) -> subprocess.CompletedProcess[str]:
if "env" in kw:
captured_env.update(dict(kw["env"]))
return subprocess.CompletedProcess(cmd, 0, "", "")
def capture_run(
cmd: list[str],
**kwargs: object,
) -> subprocess.CompletedProcess[str]:
env_obj = kwargs.get("env")
if isinstance(env_obj, Mapping):
for key, value in env_obj.items():
captured_env[str(key)] = str(value)
return subprocess.CompletedProcess(list(cmd), 0, stdout="", stderr="")
monkeypatch.setattr(guard.subprocess, "run", capture_run)
@@ -137,10 +141,10 @@ class TestProjectRootAndTempFiles:
nested = root / "src/pkg/subpkg"
nested.mkdir(parents=True)
(root / ".git").mkdir()
test_file = nested / "module.py"
test_file.write_text("# test")
found_root = guard._find_project_root(str(test_file))
assert found_root == root
finally:
@@ -154,12 +158,12 @@ class TestProjectRootAndTempFiles:
try:
(root / "src").mkdir(parents=True)
(root / "pyproject.toml").touch()
test_file = root / "src/module.py"
test_file.write_text("# test")
tmp_dir = guard._get_project_tmp_dir(str(test_file))
assert tmp_dir.exists()
assert tmp_dir == root / ".tmp"
assert tmp_dir.parent == root
@@ -177,29 +181,33 @@ class TestProjectRootAndTempFiles:
tool = root / ".venv/bin/basedpyright"
tool.write_text("#!/bin/bash\necho fake")
tool.chmod(0o755)
# Create pyrightconfig.json
(root / "pyrightconfig.json").write_text('{"strict": []}')
captured_cwd = []
def capture_run(cmd: list[str], **kw: object) -> subprocess.CompletedProcess[str]:
if "cwd" in kw:
captured_cwd.append(Path(str(kw["cwd"])))
return subprocess.CompletedProcess(cmd, 0, "", "")
captured_cwd: list[Path] = []
def capture_run(
cmd: list[str],
**kwargs: object,
) -> subprocess.CompletedProcess[str]:
cwd_obj = kwargs.get("cwd")
if cwd_obj is not None:
captured_cwd.append(Path(str(cwd_obj)))
return subprocess.CompletedProcess(list(cmd), 0, stdout="", stderr="")
monkeypatch.setattr(guard.subprocess, "run", capture_run)
test_file = root / "src/mod.py"
test_file.write_text("# test")
guard._run_type_checker(
"basedpyright",
str(test_file),
guard.QualityConfig(),
original_file_path=str(test_file),
)
# Should have run from project root
assert len(captured_cwd) > 0
assert captured_cwd[0] == root

View File

@@ -1,4 +1,3 @@
#!/usr/bin/env python3
"""Comprehensive integration tests for code quality hooks.
This test suite validates that the hooks properly block forbidden code patterns
@@ -10,93 +9,103 @@ import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Any
from typing import cast
import pytest
# Add hooks directory to path
HOOKS_DIR = Path(__file__).parent.parent / "hooks"
sys.path.insert(0, str(HOOKS_DIR))
from code_quality_guard import (
from hooks.code_quality_guard import (
JsonObject,
QualityConfig,
pretooluse_hook,
_detect_any_usage, # pyright: ignore[reportPrivateUsage]
_detect_old_typing_patterns, # pyright: ignore[reportPrivateUsage]
_detect_type_ignore_usage, # pyright: ignore[reportPrivateUsage]
posttooluse_hook,
_detect_any_usage,
_detect_type_ignore_usage,
_detect_old_typing_patterns,
pretooluse_hook,
)
HOOKS_DIR = Path(__file__).parent.parent / "hooks"
class TestHookIntegration:
"""Integration tests for the complete hook system."""
def setup_method(self):
config: QualityConfig
def __init__(self) -> None:
super().__init__()
self.config = QualityConfig.from_env()
self.config.enforcement_mode = "strict"
def setup_method(self) -> None:
"""Set up test environment."""
self.config = QualityConfig.from_env()
self.config.enforcement_mode = "strict"
def test_any_usage_blocked(self):
"""Test that typing.Any usage is blocked."""
content = '''from typing import Any
content = """from typing import Any
def bad_function(param: Any) -> Any:
return param'''
return param"""
hook_data = {
hook_data = cast(
JsonObject,
{
"tool_name": "Write",
"tool_input": {
"file_path": "/src/production_code.py", # Non-test file
"content": content
}
}
"content": content,
},
},
)
result = pretooluse_hook(hook_data, self.config)
assert result["permissionDecision"] == "deny"
assert "typing.Any usage" in result["reason"]
print(f"✅ Any usage properly blocked: {result['reason']}")
decision = cast(str, result["permissionDecision"])
reason = cast(str, result["reason"])
assert decision == "deny"
assert "typing.Any usage" in reason
def test_type_ignore_blocked(self):
"""Test that # type: ignore is blocked."""
content = '''def bad_function():
content = """def bad_function():
x = some_untyped_call() # type: ignore
return x'''
return x"""
hook_data = {
hook_data = cast(
JsonObject,
{
"tool_name": "Write",
"tool_input": {
"file_path": "/src/production_code.py",
"content": content
}
}
"tool_input": {"file_path": "/src/production_code.py", "content": content},
},
)
result = pretooluse_hook(hook_data, self.config)
assert result["permissionDecision"] == "deny"
assert "type: ignore" in result["reason"]
print(f"✅ type: ignore properly blocked: {result['reason']}")
decision = cast(str, result["permissionDecision"])
reason = cast(str, result["reason"])
assert decision == "deny"
assert "type: ignore" in reason
def test_old_typing_patterns_blocked(self):
"""Test that old typing patterns are blocked."""
content = '''from typing import Union, Optional, List, Dict
content = """from typing import Union, Optional, List, Dict
def bad_function(param: Union[str, int]) -> Optional[List[Dict[str, int]]]:
return None'''
return None"""
hook_data = {
hook_data = cast(
JsonObject,
{
"tool_name": "Write",
"tool_input": {
"file_path": "/src/production_code.py",
"content": content
}
}
"tool_input": {"file_path": "/src/production_code.py", "content": content},
},
)
result = pretooluse_hook(hook_data, self.config)
assert result["permissionDecision"] == "deny"
assert "Old typing pattern" in result["reason"]
print(f"✅ Old typing patterns properly blocked: {result['reason']}")
decision = cast(str, result["permissionDecision"])
reason = cast(str, result["reason"])
assert decision == "deny"
assert "Old typing pattern" in reason
def test_good_code_allowed(self):
"""Test that good code is allowed through."""
@@ -106,175 +115,199 @@ def bad_function(param: Union[str, int]) -> Optional[List[Dict[str, int]]]:
return None
return [{"value": 1}]'''
hook_data = {
hook_data = cast(
JsonObject,
{
"tool_name": "Write",
"tool_input": {
"file_path": "/src/production_code.py",
"content": content
}
}
"tool_input": {"file_path": "/src/production_code.py", "content": content},
},
)
result = pretooluse_hook(hook_data, self.config)
assert result["permissionDecision"] == "allow"
print("✅ Good code properly allowed")
decision = cast(str, result["permissionDecision"])
assert decision == "allow"

    def test_test_file_conditionals_blocked(self):
        """Test that conditionals in test files are blocked."""
        content = """def test_something():
    for item in items:
        if item.valid:
            assert item.process()
        else:
            assert not item.process()"""
        hook_data = cast(
            JsonObject,
            {
                "tool_name": "Write",
                "tool_input": {
                    "file_path": "/tests/test_something.py",  # Test file
                    "content": content,
                },
            },
        )
        result = pretooluse_hook(hook_data, self.config)
        decision = cast(str, result["permissionDecision"])
        reason = cast(str, result["reason"])
        assert decision == "deny"
        assert (
            "Conditional Logic in Test Function" in reason
            or "Loop Found in Test Function" in reason
        )

    def test_enforcement_modes(self):
        """Test different enforcement modes."""
        content = """from typing import Any
def bad_function(param: Any) -> Any:
    return param"""
        hook_data = cast(
            JsonObject,
            {
                "tool_name": "Write",
                "tool_input": {"file_path": "/src/production_code.py", "content": content},
            },
        )

        # Test strict mode
        self.config.enforcement_mode = "strict"
        result = pretooluse_hook(hook_data, self.config)
        assert cast(str, result["permissionDecision"]) == "deny"

        # Test warn mode
        self.config.enforcement_mode = "warn"
        result = pretooluse_hook(hook_data, self.config)
        assert cast(str, result["permissionDecision"]) == "ask"

        # Test permissive mode
        self.config.enforcement_mode = "permissive"
        result = pretooluse_hook(hook_data, self.config)
        assert cast(str, result["permissionDecision"]) == "allow"

    def test_edit_tool_blocking(self):
        """Test that Edit tool is also properly blocked."""
        hook_data = cast(
            JsonObject,
            {
                "tool_name": "Edit",
                "tool_input": {
                    "file_path": "/src/production_code.py",
                    "old_string": "def old_func():",
                    "new_string": "def new_func(param: Any) -> Any:",
                },
            },
        )
        result = pretooluse_hook(hook_data, self.config)
        decision = cast(str, result["permissionDecision"])
        reason = cast(str, result["reason"])
        assert decision == "deny"
        assert "typing.Any usage" in reason

    def test_multiedit_tool_blocking(self):
        """Test that MultiEdit tool is also properly blocked."""
        hook_data = cast(
            JsonObject,
            {
                "tool_name": "MultiEdit",
                "tool_input": {
                    "file_path": "/src/production_code.py",
                    "edits": [
                        cast(
                            JsonObject,
                            {
                                "old_string": "def old_func():",
                                "new_string": "def new_func(param: Any) -> Any:",
                            },
                        ),
                    ],
                },
            },
        )
        result = pretooluse_hook(hook_data, self.config)
        decision = cast(str, result["permissionDecision"])
        reason = cast(str, result["reason"])
        assert decision == "deny"
        assert "typing.Any usage" in reason

    def test_non_python_files_allowed(self):
        """Test that non-Python files are allowed through."""
        hook_data = cast(
            JsonObject,
            {
                "tool_name": "Write",
                "tool_input": cast(
                    JsonObject,
                    {
                        "file_path": "/src/config.json",
                        "content": json.dumps({"any": "value", "type": "ignore"}),
                    },
                ),
            },
        )
        result = pretooluse_hook(hook_data, self.config)
        assert cast(str, result["permissionDecision"]) == "allow"

    def test_posttooluse_hook(self):
        """Test PostToolUse hook functionality."""
        # Create a temp file with bad content
        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
            _ = f.write("from typing import Any\ndef bad(x: Any) -> Any: return x")
            temp_path = f.name
        try:
            hook_data = cast(
                JsonObject,
                {
                    "tool_name": "Write",
                    "tool_response": cast(
                        JsonObject,
                        {"file_path": temp_path},
                    ),
                },
            )
            result = posttooluse_hook(hook_data, self.config)
            # PostToolUse should detect issues in the written file
            assert "decision" in result
        finally:
            Path(temp_path).unlink(missing_ok=True)

    def test_command_line_execution(self):
        """Test that the hook works when executed via command line."""
        test_input = json.dumps(
            {
                "tool_name": "Write",
                "tool_input": {
                    "file_path": "/src/test.py",
                    "content": "from typing import Any\ndef bad(x: Any): pass",
                },
            },
        )
        # Test from hooks directory
        script_path = HOOKS_DIR / "code_quality_guard.py"
        result = subprocess.run(  # noqa: S603 - executes trusted project script
            [sys.executable, str(script_path)],
            check=False,
            input=test_input,
            text=True,
            capture_output=True,
            cwd=HOOKS_DIR,
        )
        output = cast(JsonObject, json.loads(result.stdout))
        assert cast(str, output["permissionDecision"]) == "deny"
        assert result.returncode == 2  # Should exit with error code 2


class TestDetectionFunctions:
@@ -296,7 +329,6 @@ class TestDetectionFunctions:
            issues = _detect_any_usage(content)
            has_issues = len(issues) > 0
            assert has_issues == should_detect, f"Failed for: {content}"

    def test_type_ignore_detection_comprehensive(self):
        """Test comprehensive type: ignore detection."""
@@ -313,7 +345,6 @@ class TestDetectionFunctions:
            issues = _detect_type_ignore_usage(content)
            has_issues = len(issues) > 0
            assert has_issues == should_detect, f"Failed for: {content}"

    def test_old_typing_patterns_comprehensive(self):
        """Test comprehensive old typing patterns detection."""
@@ -334,62 +365,3 @@ class TestDetectionFunctions:
            issues = _detect_old_typing_patterns(content)
            has_issues = len(issues) > 0
            assert has_issues == should_detect, f"Failed for: {content}"


def run_comprehensive_test():
    """Run all tests and provide a summary."""
    print("🚀 Starting comprehensive hook testing...\n")

    # Run the test classes
    test_integration = TestHookIntegration()
    test_integration.setup_method()
    test_detection = TestDetectionFunctions()

    tests = [
        # Integration tests
        (test_integration.test_any_usage_blocked, "Any usage blocking"),
        (test_integration.test_type_ignore_blocked, "Type ignore blocking"),
        (test_integration.test_old_typing_patterns_blocked, "Old typing patterns blocking"),
        (test_integration.test_good_code_allowed, "Good code allowed"),
        (test_integration.test_test_file_conditionals_blocked, "Test file conditionals blocked"),
        (test_integration.test_enforcement_modes, "Enforcement modes"),
        (test_integration.test_edit_tool_blocking, "Edit tool blocking"),
        (test_integration.test_multiedit_tool_blocking, "MultiEdit tool blocking"),
        (test_integration.test_non_python_files_allowed, "Non-Python files allowed"),
        (test_integration.test_posttooluse_hook, "PostToolUse hook"),
        (test_integration.test_command_line_execution, "Command line execution"),
        # Detection function tests
        (test_detection.test_any_detection_comprehensive, "Any detection comprehensive"),
        (test_detection.test_type_ignore_detection_comprehensive, "Type ignore detection comprehensive"),
        (test_detection.test_old_typing_patterns_comprehensive, "Old typing patterns comprehensive"),
    ]

    passed = 0
    failed = 0
    for test_func, test_name in tests:
        try:
            print(f"\n🧪 Running: {test_name}")
            test_func()
            passed += 1
            print(f"✅ PASSED: {test_name}")
        except Exception as e:
            failed += 1
            print(f"❌ FAILED: {test_name} - {e}")

    print(f"\n📊 Test Results: {passed} passed, {failed} failed")
    if failed == 0:
        print("🎉 ALL TESTS PASSED! Hooks are working correctly.")
    else:
        print(f"⚠️ {failed} tests failed. Hooks need fixes.")

    return failed == 0


if __name__ == "__main__":
    success = run_comprehensive_test()
    sys.exit(0 if success else 1)
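
# A minimal alternative entry point (sketch, assuming pytest is installed) would let
# pytest collect these test classes directly instead of the manual runner:
#
#     import pytest
#     raise SystemExit(pytest.main([__file__, "-q"]))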