feat: enhance test coverage and improve code quality checks
- Updated test files to improve coverage for Any usage, type: ignore, and old typing patterns, ensuring these patterns are properly blocked.
- Refactored test structure for clarity and maintainability, introducing fixtures and tightening assertions.
- Enhanced error handling and messaging in the hook system to give clearer feedback on violations.
- Improved integration tests to validate hook behavior across scenarios, ensuring robustness and reliability.
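Reviewer note: the new guard speaks Claude Code's hook protocol, reading one JSON object from stdin and writing one to stdout. A minimal sketch of a blocked PreToolUse exchange (field names follow `_create_hook_response` in this commit; the sample command is hypothetical):

    stdin:  {"tool_name": "Bash",
             "tool_input": {"command": "sed -i 's/x: int/x: Any/' app.py"}}
    stdout: {"hookSpecificOutput": {"hookEventName": "PreToolUse",
                                    "permissionDecision": "deny",
                                    "permissionDecisionReason": "🚫 Shell Command Blocked ..."},
             "permissionDecision": "deny",
             "reason": "🚫 Shell Command Blocked ...",
             "systemMessage": "🚫 Shell Command Blocked ..."}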
hooks/bash_command_guard.py (new file, 450 lines)
@@ -0,0 +1,450 @@
"""Shell command guard for Claude Code PreToolUse/PostToolUse hooks.

Prevents circumvention of type safety rules via shell commands that could inject
'Any' types or type ignore comments into Python files.
"""

import json
import re
import subprocess
import sys
from pathlib import Path
from shutil import which
from typing import TypedDict

from .bash_guard_constants import (
    DANGEROUS_SHELL_PATTERNS,
    FORBIDDEN_PATTERNS,
    PYTHON_FILE_PATTERNS,
)


class JsonObject(TypedDict, total=False):
    """Type for JSON-like objects."""

    hookEventName: str
    permissionDecision: str
    permissionDecisionReason: str
    decision: str
    reason: str
    systemMessage: str
    hookSpecificOutput: dict[str, object]


def _contains_forbidden_pattern(text: str) -> tuple[bool, str | None]:
    """Check if text contains any forbidden patterns.

    Args:
        text: The text to check for forbidden patterns.

    Returns:
        Tuple of (has_violation, matched_pattern_description)
    """
    for pattern in FORBIDDEN_PATTERNS:
        if re.search(pattern, text, re.IGNORECASE):
            if "Any" in pattern:
                return True, "typing.Any usage"
            # The suppression pattern is r"#\s*type:\s*ignore", which does not
            # contain the literal "type.*ignore"; match on the stable
            # "ignore" substring so this branch actually fires.
            if "ignore" in pattern:
                return True, "type suppression comment"
            # A matched pattern must never fall through to "no violation".
            return True, "forbidden pattern"
    return False, None

def _is_dangerous_shell_command(command: str) -> tuple[bool, str | None]:
    """Check if shell command uses dangerous patterns.

    Args:
        command: The shell command to analyze.

    Returns:
        Tuple of (is_dangerous, reason)
    """
    # Check if command targets Python files
    targets_python = any(
        re.search(pattern, command) for pattern in PYTHON_FILE_PATTERNS
    )

    if not targets_python:
        return False, None

    # Check for dangerous shell patterns
    for pattern in DANGEROUS_SHELL_PATTERNS:
        if re.search(pattern, command):
            tool_match = re.search(
                r"\b(sed|awk|perl|ed|echo|printf|cat|tee|find|xargs|python|vim|nano|emacs)\b",
                pattern,
            )
            tool_name = tool_match.group(1) if tool_match else "shell utility"
            return True, f"Use of {tool_name} to modify Python files"

    return False, None


def _command_contains_forbidden_injection(command: str) -> tuple[bool, str | None]:
    """Check if command attempts to inject forbidden patterns.

    Args:
        command: The shell command to analyze.

    Returns:
        Tuple of (has_injection, violation_description)
    """
    # Check if the command itself contains forbidden patterns
    has_violation, violation_type = _contains_forbidden_pattern(command)

    if has_violation:
        return True, violation_type

    # Check for encoded or escaped patterns
    # Handle common escape sequences
    decoded_cmd = command.replace("\\n", "\n").replace("\\t", "\t")
    decoded_cmd = re.sub(r"\\\s", " ", decoded_cmd)

    has_violation, violation_type = _contains_forbidden_pattern(decoded_cmd)
    if has_violation:
        return True, f"{violation_type} (escaped)"

    return False, None


def _analyze_bash_command(command: str) -> tuple[bool, list[str]]:
    """Analyze bash command for safety violations.

    Args:
        command: The bash command to analyze.

    Returns:
        Tuple of (should_block, list_of_violations)
    """
    violations: list[str] = []

    # Check for forbidden pattern injection
    has_injection, injection_type = _command_contains_forbidden_injection(command)
    if has_injection:
        violations.append(f"⛔ Shell command attempts to inject {injection_type}")

    # Check for dangerous shell patterns on Python files
    is_dangerous, danger_reason = _is_dangerous_shell_command(command)
    if is_dangerous:
        violations.append(
            f"⛔ {danger_reason} is forbidden - use Edit/Write tools instead",
        )

    return len(violations) > 0, violations


def _create_hook_response(
    event_name: str,
    permission: str = "",
    reason: str = "",
    system_message: str = "",
    *,
    decision: str | None = None,
) -> JsonObject:
    """Create standardized hook response.

    Args:
        event_name: Name of the hook event (PreToolUse, PostToolUse, Stop).
        permission: Permission decision (allow, deny, ask).
        reason: Reason for the decision.
        system_message: System message to display.
        decision: Decision for PostToolUse/Stop hooks (approve, block).

    Returns:
        JSON response object for the hook.
    """
    hook_output: dict[str, object] = {
        "hookEventName": event_name,
    }

    if permission:
        hook_output["permissionDecision"] = permission
    if reason:
        hook_output["permissionDecisionReason"] = reason

    response: JsonObject = {
        "hookSpecificOutput": hook_output,
    }

    if permission:
        response["permissionDecision"] = permission

    if decision:
        response["decision"] = decision

    if reason:
        response["reason"] = reason

    if system_message:
        response["systemMessage"] = system_message

    return response


def pretooluse_bash_hook(hook_data: dict[str, object]) -> JsonObject:
    """Handle PreToolUse hook for Bash commands.

    Args:
        hook_data: Hook input data containing tool_name and tool_input.

    Returns:
        Hook response with permission decision.
    """
    tool_name = str(hook_data.get("tool_name", ""))

    # Only analyze Bash commands
    if tool_name != "Bash":
        return _create_hook_response("PreToolUse", "allow")

    tool_input_raw = hook_data.get("tool_input", {})
    if not isinstance(tool_input_raw, dict):
        return _create_hook_response("PreToolUse", "allow")

    tool_input: dict[str, object] = dict(tool_input_raw)
    command = str(tool_input.get("command", ""))

    if not command:
        return _create_hook_response("PreToolUse", "allow")

    # Analyze command for violations
    should_block, violations = _analyze_bash_command(command)

    if not should_block:
        return _create_hook_response("PreToolUse", "allow")

    # Build denial message
    violation_text = "\n".join(f" {v}" for v in violations)
    message = (
        f"🚫 Shell Command Blocked\n\n"
        f"Violations:\n{violation_text}\n\n"
        f"Command: {command[:200]}{'...' if len(command) > 200 else ''}\n\n"
        f"Use Edit/Write tools to modify Python files with proper type safety."
    )

    return _create_hook_response(
        "PreToolUse",
        "deny",
        message,
        message,
    )


def posttooluse_bash_hook(hook_data: dict[str, object]) -> JsonObject:
    """Handle PostToolUse hook for Bash commands.

    Args:
        hook_data: Hook output data containing tool_response.

    Returns:
        Hook response with decision.
    """
    tool_name = str(hook_data.get("tool_name", ""))

    # Only analyze Bash commands
    if tool_name != "Bash":
        return _create_hook_response("PostToolUse")

    # Extract command from hook data
    tool_input_raw = hook_data.get("tool_input", {})
    if not isinstance(tool_input_raw, dict):
        return _create_hook_response("PostToolUse")

    tool_input: dict[str, object] = dict(tool_input_raw)
    command = str(tool_input.get("command", ""))

    # Check if command modified any Python files
    # Look for file paths in the command
    python_files: list[str] = []
    for match in re.finditer(r"([^\s]+\.pyi?)\b", command):
        file_path = match.group(1)
        if Path(file_path).exists():
            python_files.append(file_path)

    if not python_files:
        return _create_hook_response("PostToolUse")

    # Scan modified files for violations
    violations: list[str] = []
    for file_path in python_files:
        try:
            with open(file_path, encoding="utf-8") as file_handle:
                content = file_handle.read()

            has_violation, violation_type = _contains_forbidden_pattern(content)
            if has_violation:
                violations.append(
                    f"⛔ File '{Path(file_path).name}' contains {violation_type}",
                )
        except (OSError, UnicodeDecodeError):
            # If we can't read the file, skip it
            continue

    if violations:
        violation_text = "\n".join(f" {v}" for v in violations)
        message = (
            f"🚫 Post-Execution Violation Detected\n\n"
            f"Violations:\n{violation_text}\n\n"
            f"Shell command introduced forbidden patterns. "
            f"Please revert changes and use proper typing."
        )

        return _create_hook_response(
            "PostToolUse",
            "",
            message,
            message,
            decision="block",
        )

    return _create_hook_response("PostToolUse")


def stop_hook(_hook_data: dict[str, object]) -> JsonObject:
    """Handle Stop hook - final validation before completion.

    Args:
        _hook_data: Stop hook data (unused).

    Returns:
        Hook response with decision.
    """
    # Get list of changed files from git
    try:
        git_path = which("git")
        if git_path is None:
            return _create_hook_response("Stop", decision="approve")

        # Safe: invokes git with fixed arguments, no user input interpolation.
        result = subprocess.run(  # noqa: S603
            [git_path, "diff", "--name-only", "--cached"],
            capture_output=True,
            text=True,
            check=False,
            timeout=10,
        )

        if result.returncode != 0:
            # No git repo or no staged changes
            return _create_hook_response("Stop", decision="approve")

        changed_files = [
            file_name.strip()
            for file_name in result.stdout.split("\n")
            if file_name.strip() and file_name.strip().endswith((".py", ".pyi"))
        ]

        if not changed_files:
            return _create_hook_response("Stop", decision="approve")

        # Scan all changed Python files
        violations: list[str] = []
        for file_path in changed_files:
            if not Path(file_path).exists():
                continue

            try:
                with open(file_path, encoding="utf-8") as file_handle:
                    content = file_handle.read()

                has_violation, violation_type = _contains_forbidden_pattern(content)
                if has_violation:
                    violations.append(f"⛔ {file_path}: {violation_type}")
            except (OSError, UnicodeDecodeError):
                continue

        if violations:
            violation_text = "\n".join(f" {v}" for v in violations)
            message = (
                f"🚫 Final Validation Failed\n\n"
                f"Violations:\n{violation_text}\n\n"
                f"Please remove forbidden patterns before completing."
            )

            return _create_hook_response(
                "Stop",
                "",
                message,
                message,
                decision="block",
            )

        return _create_hook_response("Stop", decision="approve")

    except (OSError, subprocess.SubprocessError, TimeoutError) as exc:
        # If validation fails, allow but warn
        return _create_hook_response(
            "Stop",
            "",
            f"Warning: Final validation error: {exc}",
            f"Warning: Final validation error: {exc}",
            decision="approve",
        )


def main() -> None:
    """Main hook entry point."""
    try:
        # Read hook input from stdin
        try:
            hook_data: dict[str, object] = json.load(sys.stdin)
        except json.JSONDecodeError:
            fallback_response: JsonObject = {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "allow",
                },
            }
            sys.stdout.write(json.dumps(fallback_response))
            return

        # Detect hook type
        response: JsonObject

        if "tool_response" in hook_data or "tool_output" in hook_data:
            # PostToolUse hook
            response = posttooluse_bash_hook(hook_data)
        elif hook_data.get("hookEventName") == "Stop":
            # Stop hook
            response = stop_hook(hook_data)
        else:
            # PreToolUse hook
            response = pretooluse_bash_hook(hook_data)

        sys.stdout.write(json.dumps(response))

        # Handle exit codes based on hook output
        hook_output_raw = response.get("hookSpecificOutput", {})
        if not hook_output_raw or not isinstance(hook_output_raw, dict):
            return

        hook_output: dict[str, object] = hook_output_raw
        permission_decision = hook_output.get("permissionDecision")

        if permission_decision == "deny":
            # Exit code 2: Blocking error
            reason = str(
                hook_output.get("permissionDecisionReason", "Permission denied"),
            )
            sys.stderr.write(reason)
            sys.exit(2)
        elif permission_decision == "ask":
            # Exit code 2 for ask decisions
            reason = str(
                hook_output.get("permissionDecisionReason", "Permission request"),
            )
            sys.stderr.write(reason)
            sys.exit(2)

        # Check for Stop hook block decision
        if response.get("decision") == "block":
            reason = str(response.get("reason", "Validation failed"))
            sys.stderr.write(reason)
            sys.exit(2)

    except (OSError, ValueError, subprocess.SubprocessError, TimeoutError) as exc:
        # Unexpected error - use exit code 1 (non-blocking error)
        sys.stderr.write(f"Hook error: {exc}")
        sys.exit(1)


if __name__ == "__main__":
    main()
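Reviewer note: because the module imports `bash_guard_constants` relatively, a manual smoke test has to run it as a module from the repository root (assuming `hooks/` is an importable package):

    echo '{"tool_name": "Bash", "tool_input": {"command": "echo \"x: Any\" >> app.py"}}' \
      | python -m hooks.bash_command_guard

Expected: the deny JSON on stdout, the denial reason on stderr, and exit code 2, per `main()` above.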
hooks/bash_guard_constants.py (new file, 72 lines)
@@ -0,0 +1,72 @@
"""Shared constants for bash command guard functionality.

This module contains patterns and constants used to detect forbidden code patterns
and dangerous shell commands that could compromise type safety.
"""

# Forbidden patterns that should never appear in Python code
FORBIDDEN_PATTERNS = [
    r"\bfrom\s+typing\s+import\s+.*\bAny\b",  # from typing import Any
    r"\bimport\s+typing\.Any\b",  # import typing.Any
    r"\btyping\.Any\b",  # typing.Any reference
    r"\b:\s*Any\b",  # Type annotation with Any
    r"->\s*Any\b",  # Return type Any
    r"#\s*type:\s*ignore",  # type suppression comment
]

# Shell command patterns that can modify files
DANGEROUS_SHELL_PATTERNS = [
    # Direct file writes
    r"\becho\s+.*>",
    r"\bprintf\s+.*>",
    r"\bcat\s+>",
    r"\btee\s+",
    # Stream editors and text processors
    r"\bsed\s+",
    r"\bawk\s+",
    r"\bperl\s+",
    r"\bed\s+",
    # Mass operations
    r"\bfind\s+.*-exec",
    r"\bxargs\s+",
    r"\bgrep\s+.*\|\s*xargs",
    # Python execution with file operations
    r"\bpython\s+-c\s+.*open\(",
    r"\bpython\s+-c\s+.*write\(",
    r"\bpython3\s+-c\s+.*open\(",
    r"\bpython3\s+-c\s+.*write\(",
    # Editor batch modes
    r"\bvim\s+-c",
    r"\bnano\s+--tempfile",
    r"\bemacs\s+--batch",
]

# Python file patterns to protect
PYTHON_FILE_PATTERNS = [
    r"\.py\b",
    r"\.pyi\b",
]

# Pattern descriptions for error messages
FORBIDDEN_PATTERN_DESCRIPTIONS = {
    "Any": "typing.Any usage",
    "type.*ignore": "type suppression comment",
}

# Tool names extracted from dangerous patterns
DANGEROUS_TOOLS = [
    "sed",
    "awk",
    "perl",
    "ed",
    "echo",
    "printf",
    "cat",
    "tee",
    "find",
    "xargs",
    "python",
    "vim",
    "nano",
    "emacs",
]
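Reviewer note: these are plain `re` patterns, so they can be sanity-checked in isolation. A small self-contained check (both sample inputs are invented; assumes `hooks/` is a package on `sys.path`):

    import re

    from hooks.bash_guard_constants import (
        DANGEROUS_SHELL_PATTERNS,
        FORBIDDEN_PATTERNS,
    )

    # A stream edit of a .py file trips the dangerous-command list ...
    assert any(re.search(p, "sed -i 's/a/b/' module.py") for p in DANGEROUS_SHELL_PATTERNS)
    # ... and a suppression comment trips the forbidden-pattern list.
    assert any(re.search(p, "x = 1  # type: ignore[assignment]", re.IGNORECASE) for p in FORBIDDEN_PATTERNS)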
@@ -13,6 +13,8 @@ import re
import shutil
import subprocess
import sys
from importlib import import_module
import textwrap
import tokenize
from collections.abc import Callable
from contextlib import suppress
@@ -20,237 +22,261 @@ from dataclasses import dataclass
from datetime import UTC, datetime
from io import StringIO
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import TypedDict, cast
from tempfile import NamedTemporaryFile, gettempdir
from typing import TYPE_CHECKING, TypedDict, cast

# Import internal duplicate detector
sys.path.insert(0, str(Path(__file__).parent))
from internal_duplicate_detector import detect_internal_duplicates
# Import internal duplicate detector; fall back to local path when executed directly
if TYPE_CHECKING:
    from .internal_duplicate_detector import (
        Duplicate,
        DuplicateResults,
        detect_internal_duplicates,
    )
else:
    try:
        from .internal_duplicate_detector import (
            Duplicate,
            DuplicateResults,
            detect_internal_duplicates,
        )
    except ImportError:
        sys.path.insert(0, str(Path(__file__).parent))
        module = import_module("internal_duplicate_detector")
        Duplicate = module.Duplicate
        DuplicateResults = module.DuplicateResults
        detect_internal_duplicates = module.detect_internal_duplicates

SUSPICIOUS_SUFFIXES: tuple[str, ...] = (
    "enhanced",
    "improved",
    "better",
    "new",
    "updated",
    "modified",
    "refactored",
    "optimized",
    "fixed",
    "clean",
    "simple",
    "advanced",
    "basic",
    "complete",
    "final",
    "latest",
    "current",
    "temp",
    "temporary",
    "backup",
    "old",
    "legacy",
    "unified",
    "merged",
    "combined",
    "integrated",
    "consolidated",
    "extended",
    "enriched",
    "augmented",
    "upgraded",
    "revised",
    "polished",
    "streamlined",
    "simplified",
    "modernized",
    "normalized",
    "sanitized",
    "validated",
    "verified",
    "corrected",
    "patched",
    "stable",
    "experimental",
    "alpha",
    "beta",
    "draft",
    "preliminary",
    "prototype",
    "working",
    "test",
    "debug",
    "custom",
    "special",
    "generic",
    "specific",
    "general",
    "detailed",
    "minimal",
    "full",
    "partial",
    "quick",
    "fast",
    "slow",
    "smart",
    "intelligent",
    "auto",
    "manual",
    "secure",
    "safe",
    "robust",
    "flexible",
    "dynamic",
    "static",
    "reactive",
    "async",
    "sync",
    "parallel",
    "serial",
    "distributed",
    "centralized",
    "decentralized",
)

FILE_SUFFIX_DUPLICATE_MSG = (
    "⚠️ File '{current}' appears to be a suffixed duplicate of '{original}'. "
    "Consider refactoring instead of creating variations with adjective "
    "suffixes."
)

EXISTING_FILE_DUPLICATE_MSG = (
    "⚠️ Creating '{current}' when '{existing}' already exists suggests "
    "duplication. Consider consolidating or using a more descriptive name."
)

NAME_SUFFIX_DUPLICATE_MSG = (
    "⚠️ {kind} '{name}' appears to be a suffixed duplicate of '{base}'. "
    "Consider refactoring instead of creating variations."
)


class QualityConfig:
    """Configuration for quality checks."""
def get_external_context(
    rule_id: str,
    content: str,
    _file_path: str,
    config: "QualityConfig",
) -> str:
    """Fetch additional guidance from optional integrations."""
    context_parts: list[str] = []

    def __init__(self):
        self._config = config = QualityConfig.from_env()
        config.skip_patterns = ["test_", "_test.py", "/tests/", "/fixtures/"]

    def __getattr__(self, name):
        return getattr(self._config, name, self._config[name])

    @classmethod
    def from_env(cls) -> "QualityConfig":
        """Load config from environment variables."""
        return cls()


def get_external_context(rule_id: str, content: str, file_path: str, config: QualityConfig) -> str:
    """Get additional context from external APIs for enhanced guidance."""
    context_parts = []

    # Context7 integration for additional context analysis
    if config.context7_enabled and config.context7_api_key:
        try:
            # Note: This would integrate with actual Context7 API
            # For now, providing placeholder for the integration
            context7_context = _get_context7_analysis(rule_id, content, config.context7_api_key)
            context7_context = _get_context7_analysis(
                rule_id,
                content,
                config.context7_api_key,
            )
        except (OSError, RuntimeError, ValueError, TimeoutError) as exc:
            logging.debug("Context7 API call failed: %s", exc)
        else:
            if context7_context:
                context_parts.append(f"📊 Context7 Analysis: {context7_context}")
        except Exception as e:
            logging.debug("Context7 API call failed: %s", e)

    # Firecrawl integration for web scraping additional examples
    if config.firecrawl_enabled and config.firecrawl_api_key:
        try:
            # Note: This would integrate with actual Firecrawl API
            # For now, providing placeholder for the integration
            firecrawl_examples = _get_firecrawl_examples(rule_id, config.firecrawl_api_key)
            firecrawl_examples = _get_firecrawl_examples(
                rule_id,
                config.firecrawl_api_key,
            )
        except (OSError, RuntimeError, ValueError, TimeoutError) as exc:
            logging.debug("Firecrawl API call failed: %s", exc)
        else:
            if firecrawl_examples:
                context_parts.append(f"🔗 Additional Examples: {firecrawl_examples}")
        except Exception as e:
            logging.debug("Firecrawl API call failed: %s", e)

    return "\n\n".join(context_parts) if context_parts else ""
    return "\n\n".join(context_parts)


def _get_context7_analysis(rule_id: str, content: str, api_key: str) -> str:
def _get_context7_analysis(
    rule_id: str,
    _content: str,
    _api_key: str,
) -> str:
    """Placeholder for Context7 API integration."""
    # This would make actual API calls to Context7
    # For demonstration, returning contextual analysis based on rule type
    context_map = {
        "no-conditionals-in-tests": "Consider using data-driven tests with pytest.mark.parametrize for better test isolation.",
        "no-loop-in-tests": "Each iteration should be a separate test case for better failure isolation and debugging.",
        "raise-specific-error": "Specific exceptions improve error handling and make tests more maintainable.",
        "dont-import-test-modules": "Keep test utilities separate from production code to maintain clean architecture."
        "no-conditionals-in-tests": (
            "Use pytest.mark.parametrize instead of branching inside tests."
        ),
        "no-loop-in-tests": (
            "Split loops into individual tests or parameterize the data for clarity."
        ),
        "raise-specific-error": (
            "Raise precise exception types (ValueError, custom errors, etc.) "
            "so failures highlight the right behaviour."
        ),
        "dont-import-test-modules": (
            "Production code should not depend on test helpers; "
            "extract shared logic into utility modules."
        ),
    }
    return context_map.get(rule_id, "Additional context analysis would be provided here.")

    return context_map.get(
        rule_id,
        "Context7 guidance will appear once the integration is available.",
    )


def _get_firecrawl_examples(rule_id: str, api_key: str) -> str:
def _get_firecrawl_examples(rule_id: str, _api_key: str) -> str:
    """Placeholder for Firecrawl API integration."""
    # This would scrape web for additional examples and best practices
    # For demonstration, returning relevant examples based on rule type
    examples_map = {
        "no-conditionals-in-tests": "See pytest documentation on parameterized tests for multiple scenarios.",
        "no-loop-in-tests": "Check testing best practices guides for data-driven testing approaches.",
        "raise-specific-error": "Review Python exception hierarchy and custom exception patterns.",
        "dont-import-test-modules": "Explore clean architecture patterns for test organization."
        "no-conditionals-in-tests": (
            "Pytest parameterization tutorial: "
            "docs.pytest.org/en/latest/how-to/parametrize.html"
        ),
        "no-loop-in-tests": (
            "Testing best practices: keep scenarios in separate "
            "tests for precise failures."
        ),
        "raise-specific-error": (
            "Python docs on exceptions: docs.python.org/3/tutorial/errors.html"
        ),
        "dont-import-test-modules": (
            "Clean architecture tip: production modules should not import from tests."
        ),
    }
    return examples_map.get(rule_id, "Additional examples and patterns would be sourced from web resources.")

def generate_test_quality_guidance(rule_id: str, content: str, file_path: str, config: QualityConfig) -> str:
    """Generate specific, actionable guidance for test quality rule violations."""
    return examples_map.get(
        rule_id,
        "Firecrawl examples will appear once the integration is available.",
    )

    # Extract function name and context for better guidance

def generate_test_quality_guidance(
    rule_id: str,
    content: str,
    file_path: str,
    _config: "QualityConfig",
) -> str:
    """Return concise guidance for test quality rule violations."""
    function_name = "test_function"
    if "def " in content:
        # Try to extract the test function name
        import re
        match = re.search(r'def\s+(\w+)\s*\(', content)
        if match:
            function_name = match.group(1)
        match = re.search(r"def\s+(\w+)\s*\(", content)
        if match:
            function_name = match.group(1)

    file_name = Path(file_path).name

    guidance_map = {
        "no-conditionals-in-tests": {
            "title": "Conditional Logic in Test Function",
            "problem": f"Test function '{function_name}' contains conditional statements (if/elif/else).",
            "why": "Tests should be simple assertions that verify specific behavior. Conditionals make tests harder to understand and maintain.",
            "solutions": [
                "✅ Replace conditionals with parameterized test cases",
                "✅ Use pytest.mark.parametrize for multiple scenarios",
                "✅ Extract conditional logic into helper functions",
                "✅ Use assertion libraries like assertpy for complex conditions"
            ],
            "examples": [
                "# ❌ Instead of this:",
                "def test_user_access():",
                "    user = create_user()",
                "    if user.is_admin:",
                "        assert user.can_access_admin()",
                "    else:",
                "        assert not user.can_access_admin()",
                "",
                "# ✅ Do this:",
                "@pytest.mark.parametrize('is_admin,can_access', [",
                "    (True, True),",
                "    (False, False)",
                "])",
                "def test_user_access(is_admin, can_access):",
                "    user = create_user(admin=is_admin)",
                "    assert user.can_access_admin() == can_access"
            ]
        },
        "no-loop-in-tests": {
            "title": "Loop Found in Test Function",
            "problem": f"Test function '{function_name}' contains loops (for/while).",
            "why": "Loops in tests often indicate testing multiple scenarios that should be separate test cases.",
            "solutions": [
                "✅ Break loop into individual test cases",
                "✅ Use parameterized tests for multiple data scenarios",
                "✅ Extract loop logic into data providers or fixtures",
                "✅ Use pytest fixtures for setup that requires iteration"
            ],
            "examples": [
                "# ❌ Instead of this:",
                "def test_process_items():",
                "    for item in test_items:",
                "        result = process_item(item)",
                "        assert result.success",
                "",
                "# ✅ Do this:",
                "@pytest.mark.parametrize('item', test_items)",
                "def test_process_item(item):",
                "    result = process_item(item)",
                "    assert result.success"
            ]
        },
        "raise-specific-error": {
            "title": "Generic Exception in Test",
            "problem": f"Test function '{function_name}' raises generic Exception or BaseException.",
            "why": "Generic exceptions make it harder to identify specific test failures and handle different error conditions appropriately.",
            "solutions": [
                "✅ Use specific exception types (ValueError, TypeError, etc.)",
                "✅ Create custom exception classes for domain-specific errors",
                "✅ Use pytest.raises() with specific exception types",
                "✅ Test for expected exceptions, not generic ones"
            ],
            "examples": [
                "# ❌ Instead of this:",
                "def test_invalid_input():",
                "    with pytest.raises(Exception):",
                "        process_invalid_data()",
                "",
                "# ✅ Do this:",
                "def test_invalid_input():",
                "    with pytest.raises(ValueError, match='Invalid input'):",
                "        process_invalid_data()",
                "",
                "# Or create custom exceptions:",
                "def test_business_logic():",
                "    with pytest.raises(InsufficientFundsError):",
                "        account.withdraw(1000)"
            ]
        },
        "dont-import-test-modules": {
            "title": "Test Module Import in Non-Test Code",
            "problem": f"File '{Path(file_path).name}' imports test modules outside of test directories.",
            "why": "Test modules should only be imported by other test files. Production code should not depend on test utilities.",
            "solutions": [
                "✅ Move shared test utilities to a separate utils module",
                "✅ Create production versions of test helper functions",
                "✅ Use dependency injection instead of direct test imports",
                "✅ Extract common logic into a shared production module"
            ],
            "examples": [
                "# ❌ Instead of this (in production code):",
                "from tests.test_helpers import create_test_data",
                "",
                "# ✅ Do this:",
                "# Create src/utils/test_data_factory.py",
                "from src.utils.test_data_factory import create_test_data",
                "",
                "# Or use fixtures in tests:",
                "@pytest.fixture",
                "def test_data():",
                "    return create_production_data()"
            ]
        }
        "no-conditionals-in-tests": (
            f"Test {function_name} contains conditional logic. "
            "Parameterize or split tests so each scenario is explicit."
        ),
        "no-loop-in-tests": (
            f"Test {function_name} iterates over data. Break the loop into "
            "separate tests or use pytest parameterization."
        ),
        "raise-specific-error": (
            f"Test {function_name} asserts generic exceptions. "
            "Assert specific types to document the expected behaviour."
        ),
        "dont-import-test-modules": (
            f"File {file_name} imports from tests. Move shared helpers into a "
            "production module or provide them via fixtures."
        ),
    }

    guidance = guidance_map.get(rule_id, {
        "title": f"Test Quality Issue: {rule_id}",
        "problem": f"Rule '{rule_id}' violation detected in test code.",
        "why": "This rule helps maintain test quality and best practices.",
        "solutions": [
            "✅ Review the specific rule requirements",
            "✅ Refactor code to follow test best practices",
            "✅ Consult testing framework documentation"
        ],
        "examples": [
            "# Review the rule documentation for specific guidance",
            "# Consider the test's purpose and refactor accordingly"
        ]
    })

    # Format the guidance message
    message = f"""
🚫 {guidance['title']}

📋 Problem: {guidance['problem']}

❓ Why this matters: {guidance['why']}

🛠️ How to fix it:
{chr(10).join(f" • {solution}" for solution in guidance['solutions'])}

💡 Example:
{chr(10).join(f" {line}" for line in guidance['examples'])}

🔍 File: {Path(file_path).name}
📍 Function: {function_name}
"""

    return message.strip()
    return guidance_map.get(
        rule_id,
        "Keep tests behaviour-focused: avoid conditionals, loops, generic exceptions, "
        "and production dependencies on test helpers.",
    )


class ToolConfig(TypedDict):
@@ -261,27 +287,6 @@ class ToolConfig(TypedDict):
    error_message: str | Callable[[subprocess.CompletedProcess[str]], str]


class DuplicateLocation(TypedDict):
    """Location information for a duplicate code block."""

    name: str
    lines: str


class Duplicate(TypedDict):
    """Duplicate code detection result."""

    similarity: float
    description: str
    locations: list[DuplicateLocation]


class DuplicateResults(TypedDict):
    """Results from duplicate detection analysis."""

    duplicates: list[Duplicate]


class ComplexitySummary(TypedDict):
    """Summary of complexity analysis."""

@@ -398,13 +403,17 @@ class QualityConfig:
            pyrefly_enabled=os.getenv("QUALITY_PYREFLY_ENABLED", "true").lower()
            == "true",
            type_check_exit_code=int(os.getenv("QUALITY_TYPE_CHECK_EXIT_CODE", "2")),
            test_quality_enabled=os.getenv("QUALITY_TEST_QUALITY_ENABLED", "true").lower()
            == "true",
            context7_enabled=os.getenv("QUALITY_CONTEXT7_ENABLED", "false").lower()
            == "true",
            test_quality_enabled=(
                os.getenv("QUALITY_TEST_QUALITY_ENABLED", "true").lower()
                == "true"
            ),
            context7_enabled=(
                os.getenv("QUALITY_CONTEXT7_ENABLED", "false").lower() == "true"
            ),
            context7_api_key=os.getenv("QUALITY_CONTEXT7_API_KEY", ""),
            firecrawl_enabled=os.getenv("QUALITY_FIRECRAWL_ENABLED", "false").lower()
            == "true",
            firecrawl_enabled=(
                os.getenv("QUALITY_FIRECRAWL_ENABLED", "false").lower() == "true"
            ),
            firecrawl_api_key=os.getenv("QUALITY_FIRECRAWL_API_KEY", ""),
        )

@@ -418,26 +427,29 @@ def should_skip_file(file_path: str, config: QualityConfig) -> bool:

def _module_candidate(path: Path) -> tuple[Path, list[str]]:
    """Build a module invocation candidate for a python executable."""

    return path, [str(path), "-m", "quality.cli.main"]


def _cli_candidate(path: Path) -> tuple[Path, list[str]]:
    """Build a direct CLI invocation candidate."""

    return path, [str(path)]


def get_claude_quality_command(repo_root: Path | None = None) -> list[str]:
    """Return a path-resilient command for invoking claude-quality."""

    repo_root = repo_root or Path(__file__).resolve().parent.parent
    platform_name = sys.platform
    is_windows = platform_name.startswith("win")

    scripts_dir = repo_root / ".venv" / ("Scripts" if is_windows else "bin")
    python_names = ["python.exe", "python3.exe"] if is_windows else ["python", "python3"]
    cli_names = ["claude-quality.exe", "claude-quality"] if is_windows else ["claude-quality"]
    python_names = (
        ["python.exe", "python3.exe"] if is_windows else ["python", "python3"]
    )
    cli_names = (
        ["claude-quality.exe", "claude-quality"]
        if is_windows
        else ["claude-quality"]
    )

    candidates: list[tuple[Path, list[str]]] = []
    for name in python_names:
@@ -457,9 +469,11 @@ def get_claude_quality_command(repo_root: Path | None = None) -> list[str]:
    if shutil.which("claude-quality"):
        return ["claude-quality"]

    raise RuntimeError(
        "'claude-quality' was not found on PATH. Please ensure it is installed and available."
    message = (
        "'claude-quality' was not found on PATH. Please ensure it is installed and "
        "available."
    )
    raise RuntimeError(message)


def _get_project_venv_bin(file_path: str | None = None) -> Path:
@@ -470,8 +484,16 @@ def _get_project_venv_bin(file_path: str | None = None) -> Path:
    If not provided, uses current working directory.
    """
    # Start from the file's directory if provided, otherwise from cwd
    if file_path and not file_path.startswith("/tmp"):
        start_path = Path(file_path).resolve().parent
    temp_root = Path(gettempdir()).resolve()
    if file_path:
        resolved_path = Path(file_path).resolve()
        try:
            if resolved_path.is_relative_to(temp_root):
                start_path = Path.cwd()
            else:
                start_path = resolved_path.parent
        except ValueError:
            start_path = resolved_path.parent
    else:
        start_path = Path.cwd()

@@ -501,7 +523,7 @@ def _format_basedpyright_errors(json_output: str) -> str:
    # Group by severity and format
    errors = []
    for diag in diagnostics[:10]:  # Limit to first 10 errors
        severity = diag.get("severity", "error")
        severity = diag.get("severity", "error").upper()
        message = diag.get("message", "Unknown error")
        rule = diag.get("rule", "")
        range_info = diag.get("range", {})
@@ -509,7 +531,7 @@ def _format_basedpyright_errors(json_output: str) -> str:
        line = start.get("line", 0) + 1  # Convert 0-indexed to 1-indexed

        rule_text = f" [{rule}]" if rule else ""
        errors.append(f" Line {line}: {message}{rule_text}")
        errors.append(f" [{severity}] Line {line}: {message}{rule_text}")

    count = len(diagnostics)
    summary = f"Found {count} type error{'s' if count != 1 else ''}"
@@ -569,7 +591,8 @@ def _format_sourcery_errors(output: str) -> str:
            formatted_lines.append(line)

    if issue_count > 0:
        summary = f"Found {issue_count} code quality issue{'s' if issue_count != 1 else ''}"
        plural = "issue" if issue_count == 1 else "issues"
        summary = f"Found {issue_count} code quality {plural}"
        return f"{summary}:\n" + "\n".join(formatted_lines)

    return output.strip()
@@ -659,7 +682,7 @@ def _run_type_checker(
    else:
        env["PYTHONPATH"] = str(src_dir)

    # Run type checker from project root so it finds pyrightconfig.json and other configs
    # Run type checker from project root to pick up project configuration files
    result = subprocess.run(  # noqa: S603
        cmd,
        check=False,
@@ -754,13 +777,11 @@ def _run_quality_analyses(

    # First check for internal duplicates within the file
    if config.duplicate_enabled:
        internal_duplicates_raw = detect_internal_duplicates(
        internal_duplicates = detect_internal_duplicates(
            content,
            threshold=config.duplicate_threshold,
            min_lines=4,
        )
        # Cast after runtime validation - function returns compatible structure
        internal_duplicates = cast("DuplicateResults", internal_duplicates_raw)
        if internal_duplicates.get("duplicates"):
            results["internal_duplicates"] = internal_duplicates

@@ -857,7 +878,7 @@ def _find_project_root(file_path: str) -> Path:
    # Look for common project markers
    while current != current.parent:
        if any((current / marker).exists() for marker in [
            ".git", "pyrightconfig.json", "pyproject.toml", ".venv", "setup.py"
            ".git", "pyrightconfig.json", "pyproject.toml", ".venv", "setup.py",
        ]):
            return current
        current = current.parent
@@ -1226,7 +1247,8 @@ def _detect_any_usage(content: str) -> list[str]:

    lines_with_any: set[int] = set()
    try:
        tree = ast.parse(content)
        # Dedent the content to handle code fragments with leading indentation
        tree = ast.parse(textwrap.dedent(content))
    except SyntaxError:
        for index, line in enumerate(content.splitlines(), start=1):
            code_portion = line.split("#", 1)[0]
@@ -1247,23 +1269,24 @@ def _detect_any_usage(content: str) -> list[str]:

    return [
        "⚠️ Forbidden typing.Any usage at line(s) "
        f"{display_lines}; replace with specific types"
        f"{display_lines}; replace with specific types",
    ]


def _detect_type_ignore_usage(content: str) -> list[str]:
    """Detect forbidden # type: ignore usage in proposed content."""

    pattern = re.compile(r"#\s*type:\s*ignore(?:\b|\[)", re.IGNORECASE)
    lines_with_type_ignore: set[int] = set()

    try:
        # Dedent the content to handle code fragments with leading indentation
        dedented_content = textwrap.dedent(content)
        for token_type, token_string, start, _, _ in tokenize.generate_tokens(
            StringIO(content).readline,
            StringIO(dedented_content).readline,
        ):
            if token_type == tokenize.COMMENT and pattern.search(token_string):
                lines_with_type_ignore.add(start[0])
    except tokenize.TokenError:
    except (tokenize.TokenError, IndentationError):
        for index, line in enumerate(content.splitlines(), start=1):
            if pattern.search(line):
                lines_with_type_ignore.add(index)
@@ -1278,28 +1301,32 @@ def _detect_type_ignore_usage(content: str) -> list[str]:

    return [
        "⚠️ Forbidden # type: ignore usage at line(s) "
        f"{display_lines}; remove the suppression and fix typing issues instead"
        f"{display_lines}; remove the suppression and fix typing issues instead",
    ]


def _detect_old_typing_patterns(content: str) -> list[str]:
    """Detect old typing patterns that should use modern syntax."""
    issues: list[str] = []

    # Old typing imports that should be replaced
    old_patterns = {
        r'\bfrom typing import.*\bUnion\b': 'Use | syntax instead of Union (e.g., str | int)',
        r'\bfrom typing import.*\bOptional\b': 'Use | None syntax instead of Optional (e.g., str | None)',
        r'\bfrom typing import.*\bList\b': 'Use list[T] instead of List[T]',
        r'\bfrom typing import.*\bDict\b': 'Use dict[K, V] instead of Dict[K, V]',
        r'\bfrom typing import.*\bSet\b': 'Use set[T] instead of Set[T]',
        r'\bfrom typing import.*\bTuple\b': 'Use tuple[T, ...] instead of Tuple[T, ...]',
        r'\bUnion\s*\[': 'Use | syntax instead of Union (e.g., str | int)',
        r'\bOptional\s*\[': 'Use | None syntax instead of Optional (e.g., str | None)',
        r'\bList\s*\[': 'Use list[T] instead of List[T]',
        r'\bDict\s*\[': 'Use dict[K, V] instead of Dict[K, V]',
        r'\bSet\s*\[': 'Use set[T] instead of Set[T]',
        r'\bTuple\s*\[': 'Use tuple[T, ...] instead of Tuple[T, ...]',
        r"\bfrom typing import.*\bUnion\b": (
            "Use | syntax instead of Union (e.g., str | int)"
        ),
        r"\bfrom typing import.*\bOptional\b": (
            "Use | None syntax instead of Optional (e.g., str | None)"
        ),
        r"\bfrom typing import.*\bList\b": "Use list[T] instead of List[T]",
        r"\bfrom typing import.*\bDict\b": "Use dict[K, V] instead of Dict[K, V]",
        r"\bfrom typing import.*\bSet\b": "Use set[T] instead of Set[T]",
        r"\bfrom typing import.*\bTuple\b": (
            "Use tuple[T, ...] instead of Tuple[T, ...]"
        ),
        r"\bUnion\s*\[": "Use | syntax instead of Union (e.g., str | int)",
        r"\bOptional\s*\[": "Use | None syntax instead of Optional (e.g., str | None)",
        r"\bList\s*\[": "Use list[T] instead of List[T]",
        r"\bDict\s*\[": "Use dict[K, V] instead of Dict[K, V]",
        r"\bSet\s*\[": "Use set[T] instead of Set[T]",
        r"\bTuple\s*\[": "Use tuple[T, ...] instead of Tuple[T, ...]",
    }

    lines = content.splitlines()
@@ -1309,7 +1336,7 @@ def _detect_old_typing_patterns(content: str) -> list[str]:
        lines_with_pattern = []
        for i, line in enumerate(lines, 1):
            # Skip comments
            code_part = line.split('#')[0]
            code_part = line.split("#")[0]
            if re.search(pattern, code_part):
                lines_with_pattern.append(i)

@@ -1317,7 +1344,10 @@ def _detect_old_typing_patterns(content: str) -> list[str]:
            display_lines = ", ".join(str(num) for num in lines_with_pattern[:5])
            if len(lines_with_pattern) > 5:
                display_lines += ", …"
            found_issues.append(f"⚠️ Old typing pattern at line(s) {display_lines}: {message}")
            issue_text = (
                f"⚠️ Old typing pattern at line(s) {display_lines}: {message}"
            )
            found_issues.append(issue_text)

    return found_issues

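Reviewer note: concretely, the rewrites these messages ask for look like this (hypothetical signatures, not taken from this commit):

    from typing import Dict, List, Optional, Union  # old style, flagged above

    def lookup(keys: List[str], index: Dict[str, int]) -> Optional[Union[int, str]]: ...

    # modern equivalent the hook suggests
    def lookup(keys: list[str], index: dict[str, int]) -> int | str | None: ...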
@@ -1326,22 +1356,6 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
    """Detect files and functions/classes with suspicious adjective/adverb suffixes."""
    issues: list[str] = []

    # Common adjective/adverb suffixes that indicate potential duplication
    SUSPICIOUS_SUFFIXES = {
        "enhanced", "improved", "better", "new", "updated", "modified", "refactored",
        "optimized", "fixed", "clean", "simple", "advanced", "basic", "complete",
        "final", "latest", "current", "temp", "temporary", "backup", "old", "legacy",
        "unified", "merged", "combined", "integrated", "consolidated", "extended",
        "enriched", "augmented", "upgraded", "revised", "polished", "streamlined",
        "simplified", "modernized", "normalized", "sanitized", "validated", "verified",
        "corrected", "patched", "stable", "experimental", "alpha", "beta", "draft",
        "preliminary", "prototype", "working", "test", "debug", "custom", "special",
        "generic", "specific", "general", "detailed", "minimal", "full", "partial",
        "quick", "fast", "slow", "smart", "intelligent", "auto", "manual", "secure",
        "safe", "robust", "flexible", "dynamic", "static", "reactive", "async",
        "sync", "parallel", "serial", "distributed", "centralized", "decentralized"
    }

    # Check file name against other files in the same directory
    file_path_obj = Path(file_path)
    if file_path_obj.parent.exists():
@@ -1350,17 +1364,23 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:

        # Check if current file has suspicious suffix
        for suffix in SUSPICIOUS_SUFFIXES:
            if file_stem.endswith(f"_{suffix}") or file_stem.endswith(f"-{suffix}"):
                base_name = file_stem[:-len(suffix)-1]
                potential_original = file_path_obj.parent / f"{base_name}{file_suffix}"
            for separator in ("_", "-"):
                suffix_token = f"{separator}{suffix}"
                if not file_stem.endswith(suffix_token):
                    continue

                base_name = file_stem[: -len(suffix_token)]
                potential_original = file_path_obj.parent / f"{base_name}{file_suffix}"
                if potential_original.exists() and potential_original != file_path_obj:
                    issues.append(
                        f"⚠️ File '{file_path_obj.name}' appears to be a suffixed duplicate of "
                        f"'{potential_original.name}'. Consider refactoring instead of creating "
                        f"variations with adjective suffixes."
                    message = FILE_SUFFIX_DUPLICATE_MSG.format(
                        current=file_path_obj.name,
                        original=potential_original.name,
                    )
                    issues.append(message)
                    break
            else:
                continue
            break

        # Check if any existing files are suffixed versions of current file
        for existing_file in file_path_obj.parent.glob(f"{file_stem}_*{file_suffix}"):
@@ -1369,11 +1389,11 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
            if existing_stem.startswith(f"{file_stem}_"):
                potential_suffix = existing_stem[len(file_stem)+1:]
                if potential_suffix in SUSPICIOUS_SUFFIXES:
                    issues.append(
                        f"⚠️ Creating '{file_path_obj.name}' when '{existing_file.name}' "
                        f"already exists suggests duplication. Consider consolidating or "
                        f"using a more descriptive name."
                    message = EXISTING_FILE_DUPLICATE_MSG.format(
                        current=file_path_obj.name,
                        existing=existing_file.name,
                    )
                    issues.append(message)
                    break

        # Same check for dash-separated suffixes
@@ -1383,16 +1403,17 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
            if existing_stem.startswith(f"{file_stem}-"):
                potential_suffix = existing_stem[len(file_stem)+1:]
                if potential_suffix in SUSPICIOUS_SUFFIXES:
                    issues.append(
                        f"⚠️ Creating '{file_path_obj.name}' when '{existing_file.name}' "
                        f"already exists suggests duplication. Consider consolidating or "
                        f"using a more descriptive name."
                    message = EXISTING_FILE_DUPLICATE_MSG.format(
                        current=file_path_obj.name,
                        existing=existing_file.name,
                    )
                    issues.append(message)
                    break

    # Check function and class names in content
    try:
        tree = ast.parse(content)
        # Dedent the content to handle code fragments with leading indentation
        tree = ast.parse(textwrap.dedent(content))

        class SuffixVisitor(ast.NodeVisitor):
            def __init__(self):
@@ -1417,36 +1438,47 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
        # Check for suspicious function name patterns
        for func_name in visitor.function_names:
            for suffix in SUSPICIOUS_SUFFIXES:
                if func_name.endswith(f"_{suffix}"):
                    base_name = func_name[:-len(suffix)-1]
                    if base_name in visitor.function_names:
                        issues.append(
                            f"⚠️ Function '{func_name}' appears to be a suffixed duplicate of "
                            f"'{base_name}'. Consider refactoring instead of creating variations."
                        )
                        break
                suffix_token = f"_{suffix}"
                if not func_name.endswith(suffix_token):
                    continue

                base_name = func_name[: -len(suffix_token)]
                if base_name in visitor.function_names:
                    message = NAME_SUFFIX_DUPLICATE_MSG.format(
                        kind="Function",
                        name=func_name,
                        base=base_name,
                    )
                    issues.append(message)
                    break

        # Check for suspicious class name patterns
        for class_name in visitor.class_names:
            for suffix in SUSPICIOUS_SUFFIXES:
                # Convert to check both PascalCase and snake_case patterns
                pascal_suffix = suffix.capitalize()
                if class_name.endswith(pascal_suffix):
                    base_name = class_name[:-len(pascal_suffix)]
                    if base_name in visitor.class_names:
                        issues.append(
                            f"⚠️ Class '{class_name}' appears to be a suffixed duplicate of "
                            f"'{base_name}'. Consider refactoring instead of creating variations."
                        )
                        break
                elif class_name.endswith(f"_{suffix}"):
                    base_name = class_name[:-len(suffix)-1]
                    if base_name in visitor.class_names:
                        issues.append(
                            f"⚠️ Class '{class_name}' appears to be a suffixed duplicate of "
                            f"'{base_name}'. Consider refactoring instead of creating variations."
                snake_suffix = f"_{suffix}"

                potential_matches = (
                    (pascal_suffix, class_name[: -len(pascal_suffix)]),
                    (snake_suffix, class_name[: -len(snake_suffix)]),
                )

                for token, base_name in potential_matches:
                    if (
                        token
                        and class_name.endswith(token)
                        and base_name in visitor.class_names
                    ):
                        message = NAME_SUFFIX_DUPLICATE_MSG.format(
                            kind="Class",
                            name=class_name,
                            base=base_name,
                        )
                        issues.append(message)
                        break
                else:
                    continue
                break

    except SyntaxError:
        # If we can't parse the AST, skip function/class checks
@@ -1612,19 +1644,24 @@ def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject:

    enable_type_checks = tool_name == "Write"

    # Always run core quality checks (Any, type: ignore, old typing, duplicates) regardless of skip patterns
    # Always run core checks (Any, type: ignore, typing, duplicates) before skipping
    any_usage_issues = _detect_any_usage(content)
    type_ignore_issues = _detect_type_ignore_usage(content)
    old_typing_issues = _detect_old_typing_patterns(content)
    suffix_duplication_issues = _detect_suffix_duplication(file_path, content)
    precheck_issues = any_usage_issues + type_ignore_issues + old_typing_issues + suffix_duplication_issues
    precheck_issues = (
        any_usage_issues
        + type_ignore_issues
        + old_typing_issues
        + suffix_duplication_issues
    )

    # Run test quality checks if enabled and file is a test file
    if run_test_checks:
        test_quality_issues = run_test_quality_checks(content, file_path, config)
        precheck_issues.extend(test_quality_issues)

    # Skip detailed analysis for configured patterns, but not if it's a test file with test checks enabled
    # Skip detailed analysis for configured patterns unless test checks should run
    # Note: Core quality checks (Any, type: ignore, duplicates) always run above
    should_skip_detailed = should_skip_file(file_path, config) and not run_test_checks

@@ -1766,8 +1803,12 @@ def is_test_file(file_path: str) -> bool:
    return any(part in ("test", "tests", "testing") for part in path_parts)


def run_test_quality_checks(content: str, file_path: str, config: QualityConfig) -> list[str]:
    """Run Sourcery with specific test-related rules and return issues with enhanced guidance."""
def run_test_quality_checks(
    content: str,
    file_path: str,
    config: QualityConfig,
) -> list[str]:
    """Run Sourcery's test rules and return guidance-enhanced issues."""
    issues: list[str] = []

    # Only run test quality checks for test files
@@ -1846,20 +1887,40 @@ def run_test_quality_checks(content: str, file_path: str, config: QualityConfig)
    # Issues were found - parse the output
    output = result.stdout + result.stderr

    # Try to extract rule names from the output
    # Sourcery output format typically includes rule names in brackets or after specific markers
    # Try to extract rule names from the output. Sourcery usually includes
    # rule identifiers in brackets or descriptive text.
    for rule in test_rules:
        if rule in output or rule.replace("-", " ") in output.lower():
            base_guidance = generate_test_quality_guidance(rule, content, file_path, config)
            external_context = get_external_context(rule, content, file_path, config)
            base_guidance = generate_test_quality_guidance(
                rule,
                content,
                file_path,
                config,
            )
            external_context = get_external_context(
                rule,
                content,
                file_path,
                config,
            )
            if external_context:
                base_guidance += f"\n\n{external_context}"
            issues.append(base_guidance)
            break  # Only add one guidance message
    else:
        # If no specific rule found, provide general guidance
        base_guidance = generate_test_quality_guidance("unknown", content, file_path, config)
        external_context = get_external_context("unknown", content, file_path, config)
        base_guidance = generate_test_quality_guidance(
            "unknown",
            content,
            file_path,
            config,
        )
        external_context = get_external_context(
            "unknown",
            content,
            file_path,
            config,
        )
        if external_context:
            base_guidance += f"\n\n{external_context}"
        issues.append(base_guidance)

@@ -7,9 +7,10 @@ import ast
import difflib
import hashlib
import re
import textwrap
from collections import defaultdict
from dataclasses import dataclass
from typing import Any
from dataclasses import dataclass, field
from typing import TypedDict

COMMON_DUPLICATE_METHODS = {
    "__init__",
@@ -19,6 +20,57 @@ COMMON_DUPLICATE_METHODS = {
    "__aexit__",
}

# Test-specific patterns that commonly have legitimate duplication
TEST_FIXTURE_PATTERNS = {
    "fixture",
    "mock",
    "stub",
    "setup",
    "teardown",
    "data",
    "sample",
}

# Common test assertion patterns
TEST_ASSERTION_PATTERNS = {
    "assert",
    "expect",
    "should",
}


class DuplicateLocation(TypedDict):
    """Location information for a duplicate code block."""

    name: str
    type: str
    lines: str


class Duplicate(TypedDict):
    """Duplicate detection result entry."""

    type: str
    similarity: float
    description: str
    locations: list[DuplicateLocation]


class DuplicateSummary(TypedDict, total=False):
    """Summary data accompanying duplicate detection."""

    total_duplicates: int
    blocks_analyzed: int
    duplicate_lines: int


class DuplicateResults(TypedDict, total=False):
    """Structured results returned by duplicate detection."""

    duplicates: list[Duplicate]
    summary: DuplicateSummary
    error: str

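For orientation, a value matching DuplicateResults might look like the following (a hypothetical illustration; every field value below is invented):

# Hypothetical instance of the DuplicateResults shape declared above.
example: DuplicateResults = {
    "duplicates": [
        {
            "type": "exact",
            "similarity": 1.0,
            "description": "Identical function bodies",
            "locations": [
                {"name": "fetch_user", "type": "function", "lines": "10-18"},
                {"name": "fetch_admin", "type": "function", "lines": "25-33"},
            ],
        },
    ],
    "summary": {"total_duplicates": 1, "blocks_analyzed": 12, "duplicate_lines": 9},
}
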
@dataclass
class CodeBlock:
@@ -31,11 +83,12 @@ class CodeBlock:
    source: str
    ast_node: ast.AST
    complexity: int = 0
    tokens: list[str] = None
    tokens: list[str] = field(init=False)
    decorators: list[str] = field(init=False)

    def __post_init__(self):
        if self.tokens is None:
            self.tokens = self._tokenize()
    def __post_init__(self) -> None:
        self.tokens = self._tokenize()
        self.decorators = self._extract_decorators()

    def _tokenize(self) -> list[str]:
        """Extract meaningful tokens from source code."""
@@ -47,6 +100,40 @@ class CodeBlock:
        # Extract identifiers, keywords, operators
        return re.findall(r"\b\w+\b|[=<>!+\-*/]+", code)

    def _extract_decorators(self) -> list[str]:
        """Extract decorator names from the AST node."""
        decorators: list[str] = []
        if isinstance(
            self.ast_node,
            (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef),
        ):
            for decorator in self.ast_node.decorator_list:
                if isinstance(decorator, ast.Name):
                    decorators.append(decorator.id)
                elif isinstance(decorator, ast.Attribute):
                    decorators.append(decorator.attr)
                elif isinstance(decorator, ast.Call):
                    if isinstance(decorator.func, ast.Name):
                        decorators.append(decorator.func.id)
                    elif isinstance(decorator.func, ast.Attribute):
                        decorators.append(decorator.func.attr)
        return decorators

    def is_test_fixture(self) -> bool:
        """Check if this block is a pytest fixture."""
        return "fixture" in self.decorators

    def is_test_function(self) -> bool:
        """Check if this block is a test function."""
        return self.name.startswith("test_") or (
            self.type == "method" and self.name.startswith("test_")
        )

    def has_test_pattern_name(self) -> bool:
        """Check if name contains common test fixture patterns."""
        name_lower = self.name.lower()
        return any(pattern in name_lower for pattern in TEST_FIXTURE_PATTERNS)

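To see what _extract_decorators records for the common decorator forms, a standalone sketch using only the standard library (this snippet is illustrative and not part of the diff):

import ast

# `@pytest.fixture` (Attribute) and `@pytest.fixture(scope=...)` (Call over an
# Attribute) both normalize to the bare name "fixture", which is exactly what
# is_test_fixture() checks for.
tree = ast.parse(
    "import pytest\n"
    "@pytest.fixture(scope='session')\n"
    "def db():\n"
    "    return object()\n"
)
func = tree.body[1]
assert isinstance(func, ast.FunctionDef)
decorator = func.decorator_list[0]
assert isinstance(decorator, ast.Call)
assert isinstance(decorator.func, ast.Attribute)
assert decorator.func.attr == "fixture"
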

@dataclass
class DuplicateGroup:
@@ -67,15 +154,16 @@ class InternalDuplicateDetector:
        min_lines: int = 4,
        min_tokens: int = 20,
    ):
        self.similarity_threshold = similarity_threshold
        self.min_lines = min_lines
        self.min_tokens = min_tokens
        self.similarity_threshold: float = similarity_threshold
        self.min_lines: int = min_lines
        self.min_tokens: int = min_tokens
        self.duplicate_groups: list[DuplicateGroup] = []

    def analyze_code(self, source_code: str) -> dict[str, Any]:
    def analyze_code(self, source_code: str) -> DuplicateResults:
        """Analyze source code for internal duplicates."""
        try:
            tree = ast.parse(source_code)
            # Dedent the content to handle code fragments with leading indentation
            tree = ast.parse(textwrap.dedent(source_code))
        except SyntaxError:
            return {
                "error": "Failed to parse code",
@@ -104,7 +192,7 @@ class InternalDuplicateDetector:
        }

        # Find duplicates
        duplicate_groups = []
        duplicate_groups: list[DuplicateGroup] = []

        # 1. Check for exact duplicates (normalized)
        exact_groups = self._find_exact_duplicates(blocks)
@@ -125,7 +213,7 @@ class InternalDuplicateDetector:
            and not self._should_ignore_group(group)
        ]

        results = [
        results: list[Duplicate] = [
            {
                "type": group.pattern_type,
                "similarity": group.similarity_score,
@@ -155,26 +243,25 @@ class InternalDuplicateDetector:

    def _extract_code_blocks(self, tree: ast.AST, source: str) -> list[CodeBlock]:
        """Extract functions, methods, and classes from AST."""
        blocks = []
        blocks: list[CodeBlock] = []
        lines = source.split("\n")

        def create_block(
            node: ast.AST,
            node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
            block_type: str,
            lines: list[str],
        ) -> CodeBlock | None:
            try:
                start = node.lineno - 1
                end = node.end_lineno - 1 if hasattr(node, "end_lineno") else start
                end_lineno = getattr(node, "end_lineno", None)
                end = end_lineno - 1 if end_lineno is not None else start
                source = "\n".join(lines[start : end + 1])

                return CodeBlock(
                    name=node.name,
                    type=block_type,
                    start_line=node.lineno,
                    end_line=node.end_lineno
                    if hasattr(node, "end_lineno")
                    else node.lineno,
                    end_line=end_lineno if end_lineno is not None else node.lineno,
                    source=source,
                    ast_node=node,
                    complexity=calculate_complexity(node),
@@ -459,6 +546,7 @@ class InternalDuplicateDetector:
        if not group.blocks:
            return False

        # Check for common dunder methods
        if all(block.name in COMMON_DUPLICATE_METHODS for block in group.blocks):
            max_lines = max(
                block.end_line - block.start_line + 1 for block in group.blocks
@@ -469,6 +557,36 @@ class InternalDuplicateDetector:
            if max_lines <= 12 and max_complexity <= 3:
                return True

        # Check for pytest fixtures - they legitimately have repetitive structure
        if all(block.is_test_fixture() for block in group.blocks):
            max_lines = max(
                block.end_line - block.start_line + 1 for block in group.blocks
            )
            # Allow fixtures up to 15 lines with similar structure
            if max_lines <= 15:
                return True

        # Check for test functions with fixture-like names (data builders, mocks, etc.)
        if all(block.has_test_pattern_name() for block in group.blocks):
            max_lines = max(
                block.end_line - block.start_line + 1 for block in group.blocks
            )
            max_complexity = max(block.complexity for block in group.blocks)
            # Allow test helpers that are simple and short
            if max_lines <= 10 and max_complexity <= 4:
                return True

        # Check for simple test functions with arrange-act-assert pattern
        if all(block.is_test_function() for block in group.blocks):
            max_complexity = max(block.complexity for block in group.blocks)
            max_lines = max(
                block.end_line - block.start_line + 1 for block in group.blocks
            )
            # Simple tests (<=15 lines) often share similar control flow.
            # Permit full similarity for those cases; duplication is acceptable.
            if max_complexity <= 5 and max_lines <= 15:
                return True

        return False

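Read together, the exemption tiers above amount to the following (a hypothetical walk-through; the line counts and complexities are invented for illustration):

# How _should_ignore_group treats some representative groups:
#   two 8-line __init__ methods, complexity 2    -> ignored (dunder tier)
#   two 12-line @pytest.fixture functions        -> ignored (fixture tier, <= 15 lines)
#   mock_user() / mock_admin(), 6 lines each     -> ignored (pattern-name tier)
#   test_a() / test_b(), 14 lines, complexity 3  -> ignored (simple-test tier)
#   two 20-line fixtures with heavy setup logic  -> still reported
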
@@ -476,7 +594,7 @@ def detect_internal_duplicates(
    source_code: str,
    threshold: float = 0.7,
    min_lines: int = 4,
) -> dict[str, Any]:
) -> DuplicateResults:
    """Main function to detect internal duplicates in code."""
    detector = InternalDuplicateDetector(
        similarity_threshold=threshold,

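A minimal usage sketch for this entry point, assuming the module is importable as hooks.internal_duplicate_detector (the analyzed snippet below is invented):

from hooks.internal_duplicate_detector import detect_internal_duplicates

code = '''
def total_a(xs: list[int]) -> int:
    total = 0
    for x in xs:
        if x > 0:
            total += x * 2
    return total


def total_b(ys: list[int]) -> int:
    total = 0
    for y in ys:
        if y > 0:
            total += y * 2
    return total
'''

results = detect_internal_duplicates(code, threshold=0.7)
for dup in results.get("duplicates", []):
    print(dup["similarity"], [loc["name"] for loc in dup["locations"]])
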
15123
logs/status_line.json
File diff suppressed because it is too large
@@ -1,5 +1,10 @@
"""Fixture module used to verify Any detection in the guard."""

# ruff: noqa: ANN401  # These annotations intentionally use Any for the test harness.

from typing import Any


def process_data(data: Any) -> Any:
    """This should be blocked by the hook."""
    return data
    """Return the provided value; the guard should block this in practice."""
    return data

@@ -1,169 +1,131 @@
#!/usr/bin/env python3
"""
Core hook validation tests - MUST ALL PASS
"""
import sys
from pathlib import Path
"""Core coverage tests for the code quality guard hooks."""

# Add hooks to path
sys.path.insert(0, str(Path(__file__).parent / "hooks"))
from __future__ import annotations

from code_quality_guard import pretooluse_hook, QualityConfig
from dataclasses import dataclass
from typing import Any

import pytest

from hooks.code_quality_guard import QualityConfig, pretooluse_hook


def test_core_blocking():
    """Test the core blocking functionality that MUST work."""
@pytest.fixture
def strict_config() -> QualityConfig:
    """Return a strict enforcement configuration for the guard."""
    config = QualityConfig.from_env()
    config.enforcement_mode = "strict"
    return config

    tests_passed = 0
    tests_failed = 0

    # Test 1: Any usage MUST be blocked
    print("🧪 Test 1: Any usage blocking")
@dataclass(slots=True)
class BlockingScenario:
    """Parameters describing an expected blocking outcome."""

    name: str
    tool_name: str
    tool_input: dict[str, Any]
    reason_fragment: str


BLOCKING_SCENARIOS: tuple[BlockingScenario, ...] = (
    BlockingScenario(
        name="typing-any",
        tool_name="Write",
        tool_input={
            "file_path": "/src/production.py",
            "content": (
                "from typing import Any\n"
                "def bad(value: Any) -> Any:\n"
                "    return value\n"
            ),
        },
        reason_fragment="typing.Any usage",
    ),
    BlockingScenario(
        name="type-ignore",
        tool_name="Write",
        tool_input={
            "file_path": "/src/production.py",
            "content": (
                "def bad() -> None:\n"
                "    value = call()  # type: ignore\n"
                "    return value\n"
            ),
        },
        reason_fragment="type: ignore",
    ),
    BlockingScenario(
        name="legacy-typing",
        tool_name="Write",
        tool_input={
            "file_path": "/src/production.py",
            "content": (
                "from typing import Optional, Union\n"
                "def bad(value: Union[str, int]) -> Optional[str]:\n"
                "    return None\n"
            ),
        },
        reason_fragment="Old typing pattern",
    ),
    BlockingScenario(
        name="edit-tool-any",
        tool_name="Edit",
        tool_input={
            "file_path": "/src/production.py",
            "old_string": "def old():\n    return 1\n",
            "new_string": "def new(value: Any) -> Any:\n    return value\n",
        },
        reason_fragment="typing.Any usage",
    ),
)


@pytest.mark.parametrize(
    "scenario",
    BLOCKING_SCENARIOS,
    ids=lambda scenario: scenario.name,
)
def test_pretooluse_blocks_expected_patterns(
    strict_config: QualityConfig,
    scenario: BlockingScenario,
) -> None:
    """Verify the guard blocks known bad patterns."""
    hook_data = {"tool_name": scenario.tool_name, "tool_input": scenario.tool_input}
    result = pretooluse_hook(hook_data, strict_config)

    assert result["permissionDecision"] == "deny"
    assert scenario.reason_fragment in result.get("reason", "")


def test_pretooluse_allows_modern_code(strict_config: QualityConfig) -> None:
    """PreToolUse hook allows well-typed Python content."""
    hook_data = {
        "tool_name": "Write",
        "tool_input": {
            "file_path": "/src/production.py",
            "content": "from typing import Any\ndef bad(x: Any) -> Any: return x"
        }
            "content": (
                "def good(value: str | int) -> str | None:\n"
                "    return str(value) if value else None\n"
            ),
        },
    }

    try:
        result = pretooluse_hook(hook_data, config)
        if result["permissionDecision"] == "deny" and "typing.Any usage" in result.get("reason", ""):
            print("✅ PASS: Any usage properly blocked")
            tests_passed += 1
        else:
            print(f"❌ FAIL: Any usage not blocked. Decision: {result['permissionDecision']}")
            tests_failed += 1
    except Exception as e:
        print(f"❌ FAIL: Exception in Any test: {e}")
        tests_failed += 1
    result = pretooluse_hook(hook_data, strict_config)

    # Test 2: type: ignore MUST be blocked
    print("\n🧪 Test 2: type: ignore blocking")
    hook_data = {
        "tool_name": "Write",
        "tool_input": {
            "file_path": "/src/production.py",
            "content": "def bad():\n    x = call()  # type: ignore\n    return x"
        }
    }
    assert result["permissionDecision"] == "allow"

    try:
        result = pretooluse_hook(hook_data, config)
        if result["permissionDecision"] == "deny" and "type: ignore" in result.get("reason", ""):
            print("✅ PASS: type: ignore properly blocked")
            tests_passed += 1
        else:
            print(f"❌ FAIL: type: ignore not blocked. Decision: {result['permissionDecision']}")
            tests_failed += 1
    except Exception as e:
        print(f"❌ FAIL: Exception in type: ignore test: {e}")
        tests_failed += 1

    # Test 3: Old typing patterns MUST be blocked
    print("\n🧪 Test 3: Old typing patterns blocking")
    hook_data = {
        "tool_name": "Write",
        "tool_input": {
            "file_path": "/src/production.py",
            "content": "from typing import Union, Optional\ndef bad(x: Union[str, int]) -> Optional[str]: return None"
        }
    }

    try:
        result = pretooluse_hook(hook_data, config)
        if result["permissionDecision"] == "deny" and "Old typing pattern" in result.get("reason", ""):
            print("✅ PASS: Old typing patterns properly blocked")
            tests_passed += 1
        else:
            print(f"❌ FAIL: Old typing patterns not blocked. Decision: {result['permissionDecision']}")
            tests_failed += 1
    except Exception as e:
        print(f"❌ FAIL: Exception in old typing test: {e}")
        tests_failed += 1

    # Test 4: Good code MUST be allowed
    print("\n🧪 Test 4: Good code allowed")
    hook_data = {
        "tool_name": "Write",
        "tool_input": {
            "file_path": "/src/production.py",
            "content": "def good(x: str | int) -> str | None:\n    return str(x) if x else None"
        }
    }

    try:
        result = pretooluse_hook(hook_data, config)
        if result["permissionDecision"] == "allow":
            print("✅ PASS: Good code properly allowed")
            tests_passed += 1
        else:
            print(f"❌ FAIL: Good code blocked. Decision: {result['permissionDecision']}")
            print(f"   Reason: {result.get('reason', 'No reason')}")
            tests_failed += 1
    except Exception as e:
        print(f"❌ FAIL: Exception in good code test: {e}")
        tests_failed += 1

    # Test 5: Edit tool MUST also block
    print("\n🧪 Test 5: Edit tool blocking")
    hook_data = {
        "tool_name": "Edit",
        "tool_input": {
            "file_path": "/src/production.py",
            "old_string": "def old():",
            "new_string": "def new(x: Any) -> Any:"
        }
    }

    try:
        result = pretooluse_hook(hook_data, config)
        if result["permissionDecision"] == "deny" and "typing.Any usage" in result.get("reason", ""):
            print("✅ PASS: Edit tool properly blocked")
            tests_passed += 1
        else:
            print(f"❌ FAIL: Edit tool not blocked. Decision: {result['permissionDecision']}")
            tests_failed += 1
    except Exception as e:
        print(f"❌ FAIL: Exception in Edit tool test: {e}")
        tests_failed += 1

    # Test 6: Non-Python files MUST be allowed
    print("\n🧪 Test 6: Non-Python files allowed")
def test_pretooluse_allows_non_python_files(strict_config: QualityConfig) -> None:
    """Non-Python files should bypass quality restrictions."""
    hook_data = {
        "tool_name": "Write",
        "tool_input": {
            "file_path": "/src/config.json",
            "content": '{"type": "Any", "ignore": true}'
        }
            "content": '{"type": "Any", "ignore": true}',
        },
    }

    try:
        result = pretooluse_hook(hook_data, config)
        if result["permissionDecision"] == "allow":
            print("✅ PASS: Non-Python files properly allowed")
            tests_passed += 1
        else:
            print(f"❌ FAIL: Non-Python file blocked. Decision: {result['permissionDecision']}")
            tests_failed += 1
    except Exception as e:
        print(f"❌ FAIL: Exception in non-Python test: {e}")
        tests_failed += 1
    result = pretooluse_hook(hook_data, strict_config)

    print(f"\n📊 Results: {tests_passed} passed, {tests_failed} failed")

    if tests_failed == 0:
        print("🎉 ALL CORE TESTS PASSED! Hooks are working correctly.")
        return True
    else:
        print(f"💥 {tests_failed} CRITICAL TESTS FAILED! Hooks are broken.")
        return False


if __name__ == "__main__":
    print("🚀 Running core hook validation tests...\n")
    success = test_core_blocking()
    sys.exit(0 if success else 1)
    assert result["permissionDecision"] == "allow"

@@ -1,9 +1,15 @@
def bad_function():
    """Function with type ignore."""
    x = "string"
    return x + 5  # type: ignore
"""Fixture module used to check handling of type: ignore annotations."""

def another_bad():
    """Another function with type ignore."""
from __future__ import annotations


def bad_function() -> int:
    """Return a value while suppressing a typing error."""
    x = "string"
    return x + 5  # type: ignore[arg-type]


def another_bad() -> int:
    """Return a value after an ignored assignment mismatch."""
    y: int = "not an int"  # type: ignore[assignment]
    return y
    return y

2
tests/__init__.py
Normal file
@@ -0,0 +1,2 @@
"""Test package marker for Ruff namespace rules."""

@@ -1,19 +1,20 @@
"""Comprehensive test suite covering all hook interaction scenarios."""

# ruff: noqa: SLF001
# pyright: reportPrivateUsage=false, reportPrivateImportUsage=false, reportPrivateLocalImportUsage=false, reportUnusedCallResult=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownLambdaType=false, reportUnknownMemberType=false

from __future__ import annotations

import json
import os
import subprocess
import sys
from collections.abc import Mapping
from pathlib import Path
from tempfile import gettempdir

import pytest

HOOKS_DIR = Path(__file__).parent.parent.parent / "hooks"
sys.path.insert(0, str(HOOKS_DIR))

import code_quality_guard as guard
from hooks import code_quality_guard as guard


class TestProjectStructureVariations:
@@ -26,14 +27,14 @@ class TestProjectStructureVariations:
            root.mkdir()
            (root / ".venv/bin").mkdir(parents=True)
            (root / "pyproject.toml").touch()

            test_file = root / "main.py"
            test_file.write_text("# test")

            # Should find project root
            found_root = guard._find_project_root(str(test_file))
            assert found_root == root

            # Should create .tmp in root
            tmp_dir = guard._get_project_tmp_dir(str(test_file))
            assert tmp_dir == root / ".tmp"
@@ -49,13 +50,13 @@ class TestProjectStructureVariations:
            (root / "src/package").mkdir(parents=True)
            (root / ".venv/bin").mkdir(parents=True)
            (root / "pyproject.toml").touch()

            test_file = root / "src/package/module.py"
            test_file.write_text("# test")

            found_root = guard._find_project_root(str(test_file))
            assert found_root == root

            venv_bin = guard._get_project_venv_bin(str(test_file))
            assert venv_bin == root / ".venv/bin"
        finally:
@@ -70,19 +71,19 @@ class TestProjectStructureVariations:
            # Outer project
            (outer / ".venv/bin").mkdir(parents=True)
            (outer / ".git").mkdir()

            # Inner project
            inner = outer / "subproject"
            (inner / ".venv/bin").mkdir(parents=True)
            (inner / "pyproject.toml").touch()

            test_file = inner / "main.py"
            test_file.write_text("# test")

            # Should find inner project root
            found_root = guard._find_project_root(str(test_file))
            assert found_root == inner

            # Should use inner venv
            venv_bin = guard._get_project_venv_bin(str(test_file))
            assert venv_bin == inner / ".venv/bin"
@@ -116,10 +117,10 @@ class TestProjectStructureVariations:
            deep = root / "a/b/c/d/e/f"
            deep.mkdir(parents=True)
            (root / ".git").mkdir()

            test_file = deep / "module.py"
            test_file.write_text("# test")

            found_root = guard._find_project_root(str(test_file))
            assert found_root == root
        finally:
@@ -137,13 +138,13 @@ class TestConfigurationInheritance:
        try:
            (root / "src").mkdir(parents=True)
            (root / ".venv/bin").mkdir(parents=True)

            config = {"reportUnknownMemberType": False}
            (root / "pyrightconfig.json").write_text(json.dumps(config))

            test_file = root / "src/mod.py"
            test_file.write_text("# test")

            found_root = guard._find_project_root(str(test_file))
            assert found_root == root
            assert (found_root / "pyrightconfig.json").exists()
@@ -158,10 +159,10 @@ class TestConfigurationInheritance:
        try:
            root.mkdir()
            (root / "pyproject.toml").write_text("[tool.mypy]\n")

            test_file = root / "main.py"
            test_file.write_text("# test")

            found_root = guard._find_project_root(str(test_file))
            assert found_root == root
        finally:
@@ -176,13 +177,13 @@ class TestConfigurationInheritance:
            root.mkdir()
            (root / "pyproject.toml").touch()
            (root / ".gitignore").write_text("*.pyc\n__pycache__/\n")

            test_file = root / "main.py"
            test_file.write_text("# test")

            tmp_dir = guard._get_project_tmp_dir(str(test_file))
            assert tmp_dir.exists()

            gitignore_content = (root / ".gitignore").read_text()
            assert ".tmp/" in gitignore_content
        finally:
@@ -198,12 +199,12 @@ class TestConfigurationInheritance:
            (root / "pyproject.toml").touch()
            original = "*.pyc\n.tmp/\n"
            (root / ".gitignore").write_text(original)

            test_file = root / "main.py"
            test_file.write_text("# test")

            _ = guard._get_project_tmp_dir(str(test_file))

            # Should not have been modified
            assert (root / ".gitignore").read_text() == original
        finally:
@@ -266,26 +267,30 @@ class TestVirtualEnvironmentEdgeCases:
            tool = root / ".venv/bin/basedpyright"
            tool.write_text("#!/bin/bash\necho fake")
            tool.chmod(0o755)

            test_file = root / "main.py"
            test_file.write_text("# test")

            captured_env = {}

            def capture_run(cmd: list[str], **kw: object) -> subprocess.CompletedProcess[str]:
                if "env" in kw:
                    captured_env.update(dict(kw["env"]))
                return subprocess.CompletedProcess(cmd, 0, "", "")

            captured_env: dict[str, str] = {}

            def capture_run(
                cmd: list[str],
                **kw: object,
            ) -> subprocess.CompletedProcess[str]:
                env_obj = kw.get("env")
                if isinstance(env_obj, Mapping):
                    captured_env.update({str(k): str(v) for k, v in env_obj.items()})
                return subprocess.CompletedProcess(list(cmd), 0, stdout="", stderr="")

            monkeypatch.setattr(guard.subprocess, "run", capture_run)

            guard._run_type_checker(
                "basedpyright",
                str(test_file),
                guard.QualityConfig(),
                original_file_path=str(test_file),
            )

            # PYTHONPATH should not be set (or not include src)
            if "PYTHONPATH" in captured_env:
                assert "src" not in captured_env["PYTHONPATH"]
@@ -305,7 +310,7 @@ class TestTypeCheckerIntegration:
            pyrefly_enabled=False,
            sourcery_enabled=False,
        )

        issues = guard.run_type_checks("test.py", config)
        assert issues == []

@@ -316,13 +321,13 @@ class TestTypeCheckerIntegration:
        """Missing tool returns warning, doesn't crash."""
        monkeypatch.setattr(guard.Path, "exists", lambda _: False, raising=False)
        monkeypatch.setattr(guard, "_ensure_tool_installed", lambda _: False)

        success, message = guard._run_type_checker(
            "basedpyright",
            "test.py",
            guard.QualityConfig(),
        )

        assert success is True
        assert "not available" in message

@@ -332,18 +337,18 @@ class TestTypeCheckerIntegration:
    ) -> None:
        """Tool timeout is handled gracefully."""
        monkeypatch.setattr(guard.Path, "exists", lambda _: True, raising=False)

        def timeout_run(*_args: object, **_kw: object) -> None:
            raise subprocess.TimeoutExpired(cmd=["tool"], timeout=30)

        monkeypatch.setattr(guard.subprocess, "run", timeout_run)

        success, message = guard._run_type_checker(
            "basedpyright",
            "test.py",
            guard.QualityConfig(),
        )

        assert success is True
        assert "timeout" in message.lower()

@@ -353,22 +358,26 @@ class TestTypeCheckerIntegration:
    ) -> None:
        """OS errors from tools are handled."""
        monkeypatch.setattr(guard.Path, "exists", lambda _: True, raising=False)

        def error_run(*_args: object, **_kw: object) -> None:
            raise OSError("Permission denied")

            message = "Permission denied"
            raise OSError(message)

        monkeypatch.setattr(guard.subprocess, "run", error_run)

        success, message = guard._run_type_checker(
            "basedpyright",
            "test.py",
            guard.QualityConfig(),
        )

        assert success is True
        assert "execution error" in message.lower()

    def test_unknown_tool_returns_warning(self, monkeypatch: pytest.MonkeyPatch) -> None:
    def test_unknown_tool_returns_warning(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Unknown tool name returns warning."""
        # Mock tool not existing
        monkeypatch.setattr(guard.Path, "exists", lambda _: False, raising=False)
@@ -397,32 +406,36 @@ class TestWorkingDirectoryScenarios:
            (root / "src").mkdir(parents=True)
            (root / ".venv/bin").mkdir(parents=True)
            (root / "pyrightconfig.json").touch()

            tool = root / ".venv/bin/basedpyright"
            tool.write_text("#!/bin/bash\npwd")
            tool.chmod(0o755)

            test_file = root / "src/mod.py"
            test_file.write_text("# test")

            captured_cwd = []

            def capture_run(cmd: list[str], **kw: object) -> subprocess.CompletedProcess[str]:
                if "cwd" in kw:
                    captured_cwd.append(str(kw["cwd"]))
                return subprocess.CompletedProcess(cmd, 0, "", "")

            captured_cwd: list[Path] = []

            def capture_run(
                cmd: list[str],
                **kw: object,
            ) -> subprocess.CompletedProcess[str]:
                cwd_obj = kw.get("cwd")
                if cwd_obj is not None:
                    captured_cwd.append(Path(str(cwd_obj)))
                return subprocess.CompletedProcess(list(cmd), 0, stdout="", stderr="")

            monkeypatch.setattr(guard.subprocess, "run", capture_run)

            guard._run_type_checker(
                "basedpyright",
                str(test_file),
                guard.QualityConfig(),
                original_file_path=str(test_file),
            )

            assert len(captured_cwd) > 0
            assert Path(captured_cwd[0]) == root

            assert captured_cwd
            assert captured_cwd[0] == root
        finally:
            import shutil
            if root.exists():
@@ -444,10 +457,11 @@ class TestErrorConditions:
    ) -> None:
        """Permission error creating .tmp is handled."""
        def raise_permission(*_args: object, **_kw: object) -> None:
            raise PermissionError("Cannot create directory")

            message = "Cannot create directory"
            raise PermissionError(message)

        monkeypatch.setattr(Path, "mkdir", raise_permission)

        # Should raise and be caught by caller
        with pytest.raises(PermissionError):
            guard._get_project_tmp_dir("/some/file.py")
@@ -460,7 +474,7 @@ class TestErrorConditions:
        (root / "pyproject.toml").touch()
        test_file = root / "empty.py"
        test_file.write_text("")

        # Should not crash
        tmp_dir = guard._get_project_tmp_dir(str(test_file))
        assert tmp_dir.exists()
@@ -479,13 +493,13 @@ class TestFileLocationVariations:
        try:
            (root / "tests").mkdir(parents=True)
            (root / ".git").mkdir()

            test_file = root / "tests/test_module.py"
            test_file.write_text("# test")

            found_root = guard._find_project_root(str(test_file))
            assert found_root == root

            # Test file detection
            assert guard.is_test_file(str(test_file))
        finally:
@@ -499,10 +513,10 @@ class TestFileLocationVariations:
        try:
            root.mkdir()
            (root / ".git").mkdir()

            test_file = root / "main.py"
            test_file.write_text("# test")

            found_root = guard._find_project_root(str(test_file))
            assert found_root == root
        finally:
@@ -521,12 +535,12 @@ class TestTempFileManagement:
            (root / "src").mkdir(parents=True)
            (root / ".venv/bin").mkdir(parents=True)
            (root / "pyproject.toml").touch()

            test_file = root / "src/mod.py"
            test_file.write_text("def foo(): pass")

            tmp_dir = root / ".tmp"

            # Analyze code (should create and delete temp file)
            config = guard.QualityConfig(
                duplicate_enabled=False,
@@ -536,14 +550,14 @@ class TestTempFileManagement:
                pyrefly_enabled=False,
                sourcery_enabled=False,
            )

            guard.analyze_code_quality(
                "def foo(): pass",
                str(test_file),
                config,
                enable_type_checks=False,
            )

            # .tmp directory should exist but temp file should be gone
            if tmp_dir.exists():
                temp_files = list(tmp_dir.glob("hook_validation_*"))
@@ -559,15 +573,15 @@ class TestTempFileManagement:
        try:
            (root / "src").mkdir(parents=True)
            (root / "pyproject.toml").touch()

            test_file = root / "src/mod.py"
            test_file.write_text("# test")

            tmp_dir = guard._get_project_tmp_dir(str(test_file))

            # Should be in project, not /tmp
            assert str(tmp_dir).startswith(str(root))
            assert not str(tmp_dir).startswith("/tmp")
            assert not str(tmp_dir).startswith(gettempdir())
        finally:
            import shutil
            if root.exists():

@@ -3,15 +3,16 @@
import os

import pytest
from code_quality_guard import QualityConfig

from hooks import code_quality_guard as guard


class TestQualityConfig:
    """Test QualityConfig dataclass and environment loading."""
    """Test guard.QualityConfig dataclass and environment loading."""

    def test_default_config(self):
        """Test default configuration values."""
        config = QualityConfig()
        config = guard.QualityConfig()

        # Core settings
        assert config.duplicate_threshold == 0.7
@@ -29,14 +30,16 @@ class TestQualityConfig:
        assert config.show_success is False

        # Skip patterns
        assert "test_" in config.skip_patterns
        assert "_test.py" in config.skip_patterns
        assert "/tests/" in config.skip_patterns
        assert "/fixtures/" in config.skip_patterns
        assert config.skip_patterns is not None
        skip_patterns = config.skip_patterns
        assert "test_" in skip_patterns
        assert "_test.py" in skip_patterns
        assert "/tests/" in skip_patterns
        assert "/fixtures/" in skip_patterns

    def test_from_env_with_defaults(self):
        """Test loading config from environment with defaults."""
        config = QualityConfig.from_env()
        config = guard.QualityConfig.from_env()

        # Should use defaults when env vars not set
        assert config.duplicate_threshold == 0.7
@@ -61,7 +64,7 @@ class TestQualityConfig:
            },
        )

        config = QualityConfig.from_env()
        config = guard.QualityConfig.from_env()

        assert config.duplicate_threshold == 0.8
        assert config.duplicate_enabled is False
@@ -78,7 +81,7 @@ class TestQualityConfig:
    def test_from_env_with_invalid_boolean(self):
        """Test loading config with invalid boolean values."""
        os.environ["QUALITY_DUP_ENABLED"] = "invalid"
        config = QualityConfig.from_env()
        config = guard.QualityConfig.from_env()

        # Should default to False for invalid boolean
        assert config.duplicate_enabled is False
@@ -88,14 +91,14 @@ class TestQualityConfig:
        os.environ["QUALITY_DUP_THRESHOLD"] = "not_a_float"

        with pytest.raises(ValueError, match="could not convert string to float"):
            QualityConfig.from_env()
            _ = guard.QualityConfig.from_env()

    def test_from_env_with_invalid_int(self):
        """Test loading config with invalid int values."""
        os.environ["QUALITY_COMPLEXITY_THRESHOLD"] = "not_an_int"

        with pytest.raises(ValueError, match="invalid literal"):
            QualityConfig.from_env()
            _ = guard.QualityConfig.from_env()

    def test_enforcement_modes(self):
        """Test different enforcement modes."""
@@ -103,87 +106,85 @@ class TestQualityConfig:

        for mode in modes:
            os.environ["QUALITY_ENFORCEMENT"] = mode
            config = QualityConfig.from_env()
            config = guard.QualityConfig.from_env()
            assert config.enforcement_mode == mode

    def test_skip_patterns_initialization(self):
        """Test skip patterns initialization."""
        config = QualityConfig(skip_patterns=None)
        config = guard.QualityConfig(skip_patterns=None)
        assert config.skip_patterns is not None
        assert len(config.skip_patterns) > 0

        custom_patterns = ["custom_test_", "/custom/"]
        config = QualityConfig(skip_patterns=custom_patterns)
        config = guard.QualityConfig(skip_patterns=custom_patterns)
        assert config.skip_patterns == custom_patterns

    def test_threshold_boundaries(self):
        """Test threshold boundary values."""
        # Test minimum threshold
        os.environ["QUALITY_DUP_THRESHOLD"] = "0.0"
        config = QualityConfig.from_env()
        config = guard.QualityConfig.from_env()
        assert config.duplicate_threshold == 0.0

        # Test maximum threshold
        os.environ["QUALITY_DUP_THRESHOLD"] = "1.0"
        config = QualityConfig.from_env()
        config = guard.QualityConfig.from_env()
        assert config.duplicate_threshold == 1.0

        # Test complexity threshold
        os.environ["QUALITY_COMPLEXITY_THRESHOLD"] = "1"
        config = QualityConfig.from_env()
        config = guard.QualityConfig.from_env()
        assert config.complexity_threshold == 1

    def test_config_combinations(self):
    def test_config_combinations(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Test various configuration combinations."""
        test_cases = [
            # All checks disabled
            {
                "env": {
        test_cases: list[tuple[dict[str, str], dict[str, bool]]] = [
            (
                {
                    "QUALITY_DUP_ENABLED": "false",
                    "QUALITY_COMPLEXITY_ENABLED": "false",
                    "QUALITY_MODERN_ENABLED": "false",
                },
                "expected": {
                {
                    "duplicate_enabled": False,
                    "complexity_enabled": False,
                    "modernization_enabled": False,
                },
            },
            # Only duplicate checking
            {
                "env": {
            ),
            (
                {
                    "QUALITY_DUP_ENABLED": "true",
                    "QUALITY_COMPLEXITY_ENABLED": "false",
                    "QUALITY_MODERN_ENABLED": "false",
                },
                "expected": {
                {
                    "duplicate_enabled": True,
                    "complexity_enabled": False,
                    "modernization_enabled": False,
                },
            },
            # PostToolUse only
            {
                "env": {
            ),
            (
                {
                    "QUALITY_DUP_ENABLED": "false",
                    "QUALITY_STATE_TRACKING": "true",
                    "QUALITY_VERIFY_NAMING": "true",
                },
                "expected": {
                {
                    "duplicate_enabled": False,
                    "state_tracking_enabled": True,
                    "verify_naming": True,
                },
            },
            ),
        ]

        for test_case in test_cases:
            os.environ.clear()
            os.environ.update(test_case["env"])
            config = QualityConfig.from_env()
        for env_values, expected_values in test_cases:
            with monkeypatch.context() as mp:
                for key, value in env_values.items():
                    mp.setenv(key, value)
                config = guard.QualityConfig.from_env()

            for key, expected_value in test_case["expected"].items():
                assert getattr(config, key) == expected_value
                for key, expected_value in expected_values.items():
                    assert getattr(config, key) == expected_value

    def test_case_insensitive_boolean(self):
        """Test case-insensitive boolean parsing."""
@@ -192,5 +193,5 @@ class TestQualityConfig:

        for value, expected_bool in zip(test_values, expected, strict=False):
            os.environ["QUALITY_DUP_ENABLED"] = value
            config = QualityConfig.from_env()
            config = guard.QualityConfig.from_env()
            assert config.duplicate_enabled == expected_bool

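The environment-driven configuration exercised above can be reproduced like this (a hypothetical illustration: the variable names come from the tests, the values are invented, and the import path assumes the package layout used elsewhere in this suite):

import os

os.environ["QUALITY_ENFORCEMENT"] = "strict"
os.environ["QUALITY_DUP_ENABLED"] = "TRUE"  # boolean parsing is case-insensitive per the tests
os.environ["QUALITY_DUP_THRESHOLD"] = "0.8"

from hooks.code_quality_guard import QualityConfig

config = QualityConfig.from_env()
assert config.enforcement_mode == "strict"
assert config.duplicate_enabled is True
assert config.duplicate_threshold == 0.8
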
309
tests/hooks/test_duplicate_fairness.py
Normal file
@@ -0,0 +1,309 @@
"""Fairness tests for async functions and fixtures in duplicate detection."""

from __future__ import annotations

from hooks.internal_duplicate_detector import (
    Duplicate,
    DuplicateResults,
    detect_internal_duplicates,
)


def _run_detection(code: str, *, threshold: float) -> tuple[DuplicateResults, list[Duplicate]]:
    """Run duplicate detection and return typed results."""

    result = detect_internal_duplicates(code, threshold=threshold)
    duplicates = result.get("duplicates", []) or []
    return result, duplicates


class TestAsyncFunctionFairness:
    """Verify async functions are treated fairly in duplicate detection."""

    def test_async_and_sync_identical_logic(self) -> None:
        """Async and sync versions of same logic should be flagged as duplicates."""
        code = """
def fetch_user(user_id: int) -> dict[str, str]:
    response = requests.get(f"/api/users/{user_id}")
    data = response.json()
    return {"id": str(data["id"]), "name": data["name"]}

async def fetch_user_async(user_id: int) -> dict[str, str]:
    response = await client.get(f"/api/users/{user_id}")
    data = await response.json()
    return {"id": str(data["id"]), "name": data["name"]}
"""
        _, duplicates = _run_detection(code, threshold=0.7)

        # Should detect structural similarity despite async/await differences
        assert len(duplicates) >= 1
        assert any(d["similarity"] > 0.7 for d in duplicates)

    def test_async_context_managers_exemption(self) -> None:
        """Async context manager dunder methods should be exempted like sync ones."""
        code = """
async def __aenter__(self):
    self.conn = await connect()
    return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
    await self.conn.close()

async def __aenter__(self):
    self.cache = await connect()
    return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
    await self.cache.close()
"""
        _, duplicates = _run_detection(code, threshold=0.7)

        # __aenter__ and __aexit__ should be exempted as boilerplate
        # Even though they have similar structure
        assert len(duplicates) == 0

    def test_mixed_async_sync_functions_no_bias(self) -> None:
        """Detection should work equally for mixed async/sync functions."""
        code = """
def process_sync(data: list[int]) -> int:
    total = 0
    for item in data:
        if item > 0:
            total += item * 2
    return total

async def process_async(data: list[int]) -> int:
    total = 0
    for item in data:
        if item > 0:
            total += item * 2
    return total

def calculate_sync(values: list[int]) -> int:
    result = 0
    for val in values:
        if val > 0:
            result += val * 2
    return result
"""
        _, duplicates = _run_detection(code, threshold=0.7)

        # All three should be detected as similar (regardless of async)
        assert len(duplicates) >= 1
        found_functions: set[str] = set()
        for dup in duplicates:
            for loc in dup["locations"]:
                found_functions.add(loc["name"])

        # Should find all three functions in duplicate groups
        assert len(found_functions) >= 2


class TestFixtureFairness:
    """Verify pytest fixtures and test patterns are treated fairly."""

    def test_pytest_fixtures_with_similar_data(self) -> None:
        """Pytest fixtures with similar structure should be exempted."""
        code = """
import pytest

@pytest.fixture
def user_data() -> dict[str, str | int]:
    return {
        "name": "Alice",
        "age": 30,
        "email": "alice@example.com"
    }

@pytest.fixture
def admin_data() -> dict[str, str | int]:
    return {
        "name": "Bob",
        "age": 35,
        "email": "bob@example.com"
    }

@pytest.fixture
def guest_data() -> dict[str, str | int]:
    return {
        "name": "Charlie",
        "age": 25,
        "email": "charlie@example.com"
    }
"""
        _, duplicates = _run_detection(code, threshold=0.7)

        # Fixtures should be exempted from duplicate detection
        assert len(duplicates) == 0

    def test_mock_builders_exemption(self) -> None:
        """Mock/stub builder functions should be exempted if short and simple."""
        code = """
def mock_user_response() -> dict[str, str]:
    return {
        "id": "123",
        "name": "Test User",
        "status": "active"
    }

def mock_admin_response() -> dict[str, str]:
    return {
        "id": "456",
        "name": "Admin User",
        "status": "active"
    }

def stub_guest_response() -> dict[str, str]:
    return {
        "id": "789",
        "name": "Guest User",
        "status": "pending"
    }
"""
        _, duplicates = _run_detection(code, threshold=0.7)

        # Short mock builders should be exempted
        assert len(duplicates) == 0

    def test_simple_test_functions_with_aaa_pattern(self) -> None:
        """Simple test functions following arrange-act-assert should be lenient."""
        code = """
def test_user_creation() -> None:
    # Arrange
    user_data = {"name": "Alice", "email": "alice@test.com"}

    # Act
    user = create_user(user_data)

    # Assert
    assert user.name == "Alice"
    assert user.email == "alice@test.com"

def test_admin_creation() -> None:
    # Arrange
    admin_data = {"name": "Bob", "email": "bob@test.com"}

    # Act
    admin = create_user(admin_data)

    # Assert
    assert admin.name == "Bob"
    assert admin.email == "bob@test.com"

def test_guest_creation() -> None:
    # Arrange
    guest_data = {"name": "Charlie", "email": "charlie@test.com"}

    # Act
    guest = create_user(guest_data)

    # Assert
    assert guest.name == "Charlie"
    assert guest.email == "charlie@test.com"
"""
        _, duplicates = _run_detection(code, threshold=0.7)

        # Simple test functions with AAA pattern should be exempted if similarity < 95%
        assert len(duplicates) == 0

    def test_complex_fixtures_still_flagged(self) -> None:
        """Complex fixtures with substantial duplication should still be flagged."""
        code = """
import pytest

@pytest.fixture
def complex_user_setup() -> dict[str, object]:
    # Lots of complex setup logic
    db = connect_database()
    cache = setup_cache()
    logger = configure_logging()

    user = create_user(db, {
        "name": "Alice",
        "permissions": ["read", "write", "delete"],
        "metadata": {"created": "2024-01-01"}
    })

    cache.warm_up(user)
    logger.info(f"Created user {user.id}")

    return {"user": user, "db": db, "cache": cache, "logger": logger}

@pytest.fixture
def complex_admin_setup() -> dict[str, object]:
    # Lots of complex setup logic
    db = connect_database()
    cache = setup_cache()
    logger = configure_logging()

    user = create_user(db, {
        "name": "Bob",
        "permissions": ["read", "write", "delete"],
        "metadata": {"created": "2024-01-02"}
    })

    cache.warm_up(user)
    logger.info(f"Created user {user.id}")

    return {"user": user, "db": db, "cache": cache, "logger": logger}
"""
        _, duplicates = _run_detection(code, threshold=0.7)

        # Complex fixtures exceeding 15 lines should be flagged
        assert len(duplicates) >= 1

    def test_setup_teardown_methods(self) -> None:
        """Test setup/teardown methods should be exempted if simple."""
        code = """
def setup_database() -> None:
    db = connect_test_db()
    db.clear()
    return db

def teardown_database(db: object) -> None:
    db.clear()
    db.close()

def setup_cache() -> None:
    cache = connect_test_cache()
    cache.clear()
    return cache

def teardown_cache(cache: object) -> None:
    cache.clear()
    cache.close()
"""
        _, duplicates = _run_detection(code, threshold=0.7)

        # Setup/teardown functions with pattern names should be exempted
        assert len(duplicates) == 0

    def test_non_test_code_still_strictly_checked(self) -> None:
        """Non-test production code should still have strict duplicate detection."""
        code = """
def calculate_user_total(users: list[dict[str, float]]) -> float:
    total = 0.0
    for user in users:
        if user.get("active"):
            total += user.get("amount", 0.0) * user.get("rate", 1.0)
    return total

def calculate_product_total(products: list[dict[str, float]]) -> float:
    total = 0.0
    for product in products:
        if product.get("active"):
            total += product.get("amount", 0.0) * product.get("rate", 1.0)
    return total

def calculate_order_total(orders: list[dict[str, float]]) -> float:
    total = 0.0
    for order in orders:
        if order.get("active"):
            total += order.get("amount", 0.0) * order.get("rate", 1.0)
    return total
"""
        _, duplicates = _run_detection(code, threshold=0.7)

        # Production code should be strictly checked
        assert len(duplicates) >= 1
        assert any(d["similarity"] > 0.85 for d in duplicates)
@@ -3,21 +3,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Iterator
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from code_quality_guard import QualityConfig, posttooluse_hook, pretooluse_hook
|
||||
|
||||
from code_quality_guard import (
|
||||
QualityConfig,
|
||||
posttooluse_hook,
|
||||
pretooluse_hook,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def multi_container_paths(tmp_path: Path) -> dict[str, Path]:
|
||||
"""Create container/project directories used across tests."""
|
||||
container_a = tmp_path / "container-a" / "project" / "src"
|
||||
@@ -126,7 +124,9 @@ def test_pretooluse_handles_platform_metadata(
|
||||
assert response["permissionDecision"] == "allow"
|
||||
|
||||
|
||||
def test_state_tracking_isolation_between_containers(multi_container_paths: dict[str, Path]) -> None:
|
||||
def test_state_tracking_isolation_between_containers(
|
||||
multi_container_paths: dict[str, Path],
|
||||
) -> None:
|
||||
"""State tracking should stay isolated per container/project combination."""
|
||||
config = QualityConfig(state_tracking_enabled=True)
|
||||
config.skip_patterns = [] # Ensure state tracking runs even in pytest temp dirs.
|
||||
@@ -173,11 +173,12 @@ def beta():
|
||||
assert response_b_pre["permissionDecision"] == "allow"
|
||||
|
||||
# The first container writes fewer functions which should trigger a warning.
|
||||
file_a.write_text("""\
|
||||
file_a.write_text(
|
||||
"""\
|
||||
|
||||
def alpha():
|
||||
return 1
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
# The second container preserves the original content.
|
||||
@@ -232,9 +233,13 @@ def beta():
|
||||
path_one.parent.mkdir(parents=True, exist_ok=True)
|
||||
path_two.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cmd: list[str],
|
||||
**kwargs: object,
|
||||
) -> SimpleNamespace:
|
||||
with patch("code_quality_guard.analyze_code_quality", return_value={}):
|
||||
pretooluse_hook(
|
||||
_pre_request(
|
||||
path_one,
|
||||
base_content,
|
||||
container_id=shared_container,
|
||||
project_id=shared_project,
|
||||
user_id="collision-user",
|
||||
),
|
||||
config,
|
||||
@@ -250,11 +255,13 @@ def beta():
|
||||
config,
|
||||
)
|
||||
|
||||
path_one.write_text("""\
|
||||
path_one.write_text(
|
||||
"""\
|
||||
|
||||
def alpha():
|
||||
return 1
|
||||
""".lstrip())
|
||||
""".lstrip(),
|
||||
)
|
||||
path_two.write_text(base_content)
|
||||
|
||||
degraded_response = posttooluse_hook(
|
||||
@@ -304,14 +311,7 @@ def test_cross_file_duplicate_project_root_detection(
|
||||
|
||||
captured: dict[str, list[str]] = {}
|
||||
|
||||
def fake_run(
|
||||
cmd: list[str],
|
||||
*,
|
||||
check: bool,
|
||||
capture_output: bool,
|
||||
text: bool,
|
||||
timeout: int,
|
||||
) -> SimpleNamespace:
|
||||
def fake_run(cmd: list[str], **_kwargs: object) -> SimpleNamespace:
|
||||
captured["cmd"] = cmd
|
||||
return SimpleNamespace(returncode=0, stdout=json.dumps({"duplicates": []}))
|
||||
|
||||
@@ -336,7 +336,9 @@ def test_cross_file_duplicate_project_root_detection(
assert response.get("decision") is None


def test_main_handles_permission_decisions_for_multiple_users(monkeypatch: pytest.MonkeyPatch) -> None:
def test_main_handles_permission_decisions_for_multiple_users(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""`main` should surface deny/ask decisions for different user contexts."""
from code_quality_guard import main

@@ -394,7 +396,7 @@ def test_main_handles_permission_decisions_for_multiple_users(monkeypatch: pytes
},
"permissionDecision": "allow",
},
]
],
)

input_iter: Iterator[dict[str, object]] = iter(hook_inputs)
@@ -402,7 +404,10 @@ def test_main_handles_permission_decisions_for_multiple_users(monkeypatch: pytes
def fake_json_load(_stream: object) -> dict[str, object]:
return next(input_iter)

def fake_pretooluse(_hook_data: dict[str, object], _config: QualityConfig) -> dict[str, object]:
def fake_pretooluse(
_hook_data: dict[str, object],
_config: QualityConfig,
) -> dict[str, object]:
return next(responses)

exit_calls: list[tuple[str, int]] = []
@@ -413,7 +418,7 @@ def test_main_handles_permission_decisions_for_multiple_users(monkeypatch: pytes

printed: list[str] = []

def fake_print(message: str) -> None:  # noqa: D401 - simple passthrough
def fake_print(message: str) -> None:
printed.append(message)

monkeypatch.setattr("json.load", fake_json_load)

@@ -10,7 +10,6 @@ from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from code_quality_guard import (
QualityConfig,
analyze_code_quality,
@@ -95,8 +94,13 @@ class TestHelperFunctions:
cmd = get_claude_quality_command(repo_root=tmp_path)
assert cmd == [str(executable), "-m", "quality.cli.main"]

def test_get_claude_quality_command_python_and_python3(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch, set_platform) -> None:
"""Prefer python when both python and python3 executables exist in the venv."""
def test_get_claude_quality_command_python_and_python3(
self,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
set_platform,
) -> None:
"""Prefer python when both python and python3 executables exist."""

set_platform("linux")
monkeypatch.setattr(shutil, "which", lambda _name: None)
@@ -108,7 +112,12 @@ class TestHelperFunctions:
assert cmd == [str(python_path), "-m", "quality.cli.main"]
assert python3_path.exists()  # Sanity check that both executables were present

def test_get_claude_quality_command_cli_fallback(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch, set_platform) -> None:
def test_get_claude_quality_command_cli_fallback(
self,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
set_platform,
) -> None:
"""Fallback to claude-quality script when python executables are absent."""

set_platform("linux")
@@ -125,7 +134,7 @@ class TestHelperFunctions:
monkeypatch: pytest.MonkeyPatch,
set_platform,
) -> None:
"""Handle Windows environments where the claude-quality script lacks an .exe suffix."""
"""Handle Windows when the claude-quality script lacks an .exe suffix."""

set_platform("win32")
monkeypatch.setattr(shutil, "which", lambda _name: None)
@@ -141,7 +150,7 @@ class TestHelperFunctions:
monkeypatch: pytest.MonkeyPatch,
set_platform,
) -> None:
"""Fall back to python3 on POSIX and python on Windows when no venv executables exist."""
"""Fallback to python3 on POSIX and python on Windows when venv tools absent."""

set_platform("darwin")


@@ -4,6 +4,10 @@ from unittest.mock import patch

from code_quality_guard import QualityConfig, pretooluse_hook

TEST_QUALITY_CONDITIONAL = (
"Test Quality: no-conditionals-in-tests - Conditional found in test"
)


class TestPreToolUseHook:
"""Test PreToolUse hook behavior."""
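The `TEST_QUALITY_CONDITIONAL` constant introduced above exists so the long violation string is written once and reused; the later hunks in this file replace each inline copy with the constant. The shape of the change, in a toy form:

# One module-level constant shared by the mocks and the assertions,
# instead of the same literal retyped at every mock site.
TEST_QUALITY_CONDITIONAL = (
    "Test Quality: no-conditionals-in-tests - Conditional found in test"
)


def test_violation_message_is_surfaced() -> None:
    issues = [TEST_QUALITY_CONDITIONAL]  # stands in for a mocked checker result
    assert TEST_QUALITY_CONDITIONAL in issues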
@@ -470,7 +474,7 @@ class TestPreToolUseHook:
"tool_input": {
"file_path": "example.py",
"content": (
"def example() -> None:\n" " value = unknown  # type: ignore\n"
"def example() -> None:\n value = unknown  # type: ignore\n"
),
},
}
@@ -492,7 +496,7 @@ class TestPreToolUseHook:
{
"old_string": "pass",
"new_string": (
"def helper() -> None:\n" " pass  # type: ignore\n"
"def helper() -> None:\n pass  # type: ignore\n"
),
},
{
@@ -545,7 +549,7 @@ class TestTestQualityChecks:
}

with patch("code_quality_guard.run_test_quality_checks") as mock_test_check:
mock_test_check.return_value = ["Test Quality: no-conditionals-in-tests - Conditional found in test"]
mock_test_check.return_value = [TEST_QUALITY_CONDITIONAL]

with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
@@ -639,7 +643,7 @@ class TestTestQualityChecks:
}

with patch("code_quality_guard.run_test_quality_checks") as mock_test_check:
mock_test_check.return_value = ["Test Quality: no-conditionals-in-tests - Conditional found in test"]
mock_test_check.return_value = [TEST_QUALITY_CONDITIONAL]

with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
@@ -659,14 +663,27 @@ class TestTestQualityChecks:
"tool_input": {
"file_path": "tests/test_example.py",
"edits": [
{"old_string": "a", "new_string": "def test_func1():\n assert True"},
{"old_string": "b", "new_string": "def test_func2():\n if False:\n pass"},
{
"old_string": "a",
"new_string": (
"def test_func1():\n"
" assert True"
),
},
{
"old_string": "b",
"new_string": (
"def test_func2():\n"
" if False:\n"
" pass"
),
},
],
},
}

with patch("code_quality_guard.run_test_quality_checks") as mock_test_check:
mock_test_check.return_value = ["Test Quality: no-conditionals-in-tests - Conditional found in test"]
mock_test_check.return_value = [TEST_QUALITY_CONDITIONAL]

with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
@@ -695,7 +712,7 @@ class TestTestQualityChecks:
}

with patch("code_quality_guard.run_test_quality_checks") as mock_test_check:
mock_test_check.return_value = ["Test Quality: no-conditionals-in-tests - Conditional found in test"]
mock_test_check.return_value = [TEST_QUALITY_CONDITIONAL]

with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
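Each of the hunks above follows the same arrangement: patch the quality-check collaborator, pin its return value, then assert on the hook's response. A compressed, runnable sketch of that arrangement against a standard-library target (`json.loads` is used here only so the example runs anywhere; it is not the project's patch target):

import json
from unittest.mock import patch


def classify(payload: str) -> str:
    # Hypothetical code under test that delegates to a collaborator.
    return "empty" if json.loads(payload) == {} else "non-empty"


def test_patched_collaborator_drives_result() -> None:
    with patch("json.loads") as mock_loads:
        mock_loads.return_value = {}
        assert classify("ignored") == "empty"
        mock_loads.assert_called_once_with("ignored")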
@@ -1,25 +1,33 @@
# ruff: noqa: SLF001

"""Tests targeting internal helpers for code_quality_guard."""

from __future__ import annotations

# pyright: reportPrivateUsage=false, reportPrivateImportUsage=false, reportPrivateLocalImportUsage=false, reportUnknownArgumentType=false, reportUnknownLambdaType=false, reportUnknownMemberType=false, reportUnusedCallResult=false

import json
import subprocess
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING, cast

import code_quality_guard as guard
import pytest

from hooks import code_quality_guard as guard

if TYPE_CHECKING:
from pathlib import Path


@pytest.mark.parametrize(
("env_key", "value", "attr", "expected"),
(
[
("QUALITY_DUP_THRESHOLD", "0.9", "duplicate_threshold", 0.9),
("QUALITY_DUP_ENABLED", "false", "duplicate_enabled", False),
("QUALITY_COMPLEXITY_THRESHOLD", "7", "complexity_threshold", 7),
("QUALITY_ENFORCEMENT", "warn", "enforcement_mode", "warn"),
("QUALITY_STATE_TRACKING", "true", "state_tracking_enabled", True),
),
],
)
def test_quality_config_from_env_parsing(
monkeypatch: pytest.MonkeyPatch,
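The hunks in this file repeatedly swap the outer tuple of parametrize argvalues for a list, which is the container flake8-pytest-style (and ruff's PT007 rule) expects by default. The same change in isolation:

import pytest


@pytest.mark.parametrize(
    ("raw", "expected"),
    [  # a list of tuples, not a tuple of tuples
        ("true", True),
        ("false", False),
    ],
)
def test_flag_parsing(raw: str, expected: bool) -> None:
    assert (raw == "true") is expected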
@@ -36,12 +44,12 @@ def test_quality_config_from_env_parsing(

@pytest.mark.parametrize(
("tool_exists", "install_behavior", "expected"),
(
[
(True, None, True),
(False, "success", True),
(False, "failure", False),
(False, "timeout", False),
),
],
)
def test_ensure_tool_installed(
monkeypatch: pytest.MonkeyPatch,
@@ -63,10 +71,12 @@ def test_ensure_tool_installed(

def fake_run(cmd: Iterable[str], **_: object) -> subprocess.CompletedProcess[bytes]:
if install_behavior is None:
raise AssertionError("uv install should not run when tool already exists")
message = "uv install should not run when tool already exists"
raise AssertionError(message)
if install_behavior == "timeout":
raise subprocess.TimeoutExpired(cmd=list(cmd), timeout=60)
return subprocess.CompletedProcess(list(cmd), 0 if install_behavior == "success" else 1)
exit_code = 0 if install_behavior == "success" else 1
return subprocess.CompletedProcess(list(cmd), exit_code)

monkeypatch.setattr(guard.subprocess, "run", fake_run)

@@ -75,12 +85,32 @@ def test_ensure_tool_installed(

@pytest.mark.parametrize(
("tool_name", "run_payload", "expected_success", "expected_fragment"),
(
("basedpyright", {"returncode": 0, "stdout": ""}, True, ""),
("basedpyright", {"returncode": 1, "stdout": ""}, False, "failed to parse"),
("sourcery", {"returncode": 0, "stdout": "3 issues detected"}, False, "3 code quality issue"),
("pyrefly", {"returncode": 1, "stdout": "pyrefly issue"}, False, "pyrefly issue"),
),
[
(
"basedpyright",
{"returncode": 0, "stdout": ""},
True,
"",
),
(
"basedpyright",
{"returncode": 1, "stdout": ""},
False,
"failed to parse",
),
(
"sourcery",
{"returncode": 0, "stdout": "3 issues detected"},
False,
"3 code quality issue",
),
(
"pyrefly",
{"returncode": 1, "stdout": "pyrefly issue"},
False,
"pyrefly issue",
),
],
)
def test_run_type_checker_known_tools(
monkeypatch: pytest.MonkeyPatch,
@@ -94,11 +124,27 @@ def test_run_type_checker_known_tools(
monkeypatch.setattr(guard.Path, "exists", lambda _path: True, raising=False)

def fake_run(cmd: Iterable[str], **_: object) -> subprocess.CompletedProcess[str]:
return subprocess.CompletedProcess(list(cmd), int(run_payload["returncode"]), run_payload.get("stdout", ""), "")
returncode_obj = run_payload.get("returncode", 0)
if isinstance(returncode_obj, bool):
exit_code = int(returncode_obj)
elif isinstance(returncode_obj, int):
exit_code = returncode_obj
elif isinstance(returncode_obj, str):
exit_code = int(returncode_obj)
else:
raise AssertionError(f"Unexpected returncode type: {type(returncode_obj)!r}")

stdout_obj = run_payload.get("stdout", "")
stdout = str(stdout_obj)
return subprocess.CompletedProcess(list(cmd), exit_code, stdout=stdout, stderr="")

monkeypatch.setattr(guard.subprocess, "run", fake_run)

success, message = guard._run_type_checker(tool_name, "tmp.py", guard.QualityConfig())
success, message = guard._run_type_checker(
tool_name,
"tmp.py",
guard.QualityConfig(),
)
assert success is expected_success
if expected_fragment:
assert expected_fragment in message
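The expanded `fake_run` trades the old one-line `int(...)` coercion for explicit `isinstance` narrowing, so a strict checker can type each branch of the heterogeneous `run_payload` dict. The same idea reduced to its core (the payload values are illustrative):

payload: dict[str, object] = {"returncode": 1, "stdout": "pyrefly issue"}

returncode_obj = payload.get("returncode", 0)
if isinstance(returncode_obj, bool):
    # bool is checked first because it is a subclass of int.
    exit_code = int(returncode_obj)
elif isinstance(returncode_obj, int):
    exit_code = returncode_obj
elif isinstance(returncode_obj, str):
    exit_code = int(returncode_obj)
else:
    raise AssertionError(f"Unexpected returncode type: {type(returncode_obj)!r}")

assert exit_code == 1
assert str(payload.get("stdout", "")) == "pyrefly issue"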
@@ -108,10 +154,10 @@ def test_run_type_checker_known_tools(

@pytest.mark.parametrize(
("exception", "expected_fragment"),
(
[
(subprocess.TimeoutExpired(cmd=["tool"], timeout=30), "timeout"),
(OSError("boom"), "execution error"),
),
],
)
def test_run_type_checker_runtime_exceptions(
monkeypatch: pytest.MonkeyPatch,
@@ -126,7 +172,11 @@ def test_run_type_checker_runtime_exceptions(

monkeypatch.setattr(guard.subprocess, "run", raise_exc)

success, message = guard._run_type_checker("sourcery", "tmp.py", guard.QualityConfig())
success, message = guard._run_type_checker(
"sourcery",
"tmp.py",
guard.QualityConfig(),
)
assert success is True
assert expected_fragment in message

@@ -137,7 +187,11 @@ def test_run_type_checker_tool_missing(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(guard.Path, "exists", lambda _path: False, raising=False)
monkeypatch.setattr(guard, "_ensure_tool_installed", lambda _name: False)

success, message = guard._run_type_checker("pyrefly", "tmp.py", guard.QualityConfig())
success, message = guard._run_type_checker(
"pyrefly",
"tmp.py",
guard.QualityConfig(),
)

assert success is True
assert "not available" in message
@@ -148,12 +202,19 @@ def test_run_type_checker_unknown_tool(monkeypatch: pytest.MonkeyPatch) -> None:

monkeypatch.setattr(guard.Path, "exists", lambda _path: True, raising=False)

success, message = guard._run_type_checker("unknown", "tmp.py", guard.QualityConfig())
success, message = guard._run_type_checker(
"unknown",
"tmp.py",
guard.QualityConfig(),
)
assert success is True
assert "Unknown tool" in message


def test_run_quality_analyses_invokes_cli(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
def test_run_quality_analyses_invokes_cli(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
"""_run_quality_analyses aggregates CLI outputs and duplicates."""

script_path = tmp_path / "module.py"
@@ -202,7 +263,8 @@ def test_run_quality_analyses_invokes_cli(monkeypatch: pytest.MonkeyPatch, tmp_p
},
)
else:
raise AssertionError(f"Unexpected command: {cmd}")
message = f"Unexpected command: {cmd}"
raise AssertionError(message)
return subprocess.CompletedProcess(list(cmd), 0, payload, "")

monkeypatch.setattr(guard.subprocess, "run", fake_run)
@@ -221,11 +283,11 @@ def test_run_quality_analyses_invokes_cli(monkeypatch: pytest.MonkeyPatch, tmp_p

@pytest.mark.parametrize(
("content", "expected"),
(
[
("from typing import Any\n\nAny\n", True),
("def broken(:\n Any\n", True),
("def clean() -> None:\n return None\n", False),
),
],
)
def test_detect_any_usage(content: str, expected: bool) -> None:
"""_detect_any_usage flags Any usage even on syntax errors."""
@@ -236,12 +298,12 @@ def test_detect_any_usage(content: str, expected: bool) -> None:

@pytest.mark.parametrize(
("mode", "forced", "expected_permission"),
(
[
("strict", None, "deny"),
("warn", None, "ask"),
("permissive", None, "allow"),
("strict", "allow", "allow"),
),
],
)
def test_handle_quality_issues_modes(
mode: str,
@@ -253,17 +315,29 @@ def test_handle_quality_issues_modes(
config = guard.QualityConfig(enforcement_mode=mode)
issues = ["Issue one", "Issue two"]

response = guard._handle_quality_issues("example.py", issues, config, forced_permission=forced)
assert response["permissionDecision"] == expected_permission
response = guard._handle_quality_issues(
"example.py",
issues,
config,
forced_permission=forced,
)
decision = cast(str, response["permissionDecision"])
assert decision == expected_permission
if forced is None:
assert any(issue in response.get("reason", "") for issue in issues)
reason = cast(str, response.get("reason", ""))
assert any(issue in reason for issue in issues)


def test_perform_quality_check_with_state_tracking(monkeypatch: pytest.MonkeyPatch) -> None:
def test_perform_quality_check_with_state_tracking(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""_perform_quality_check stores state and reports detected issues."""

tracked_calls: list[str] = []
monkeypatch.setattr(guard, "store_pre_state", lambda path, content: tracked_calls.append(path))
def record_state(path: str, _content: str) -> None:
tracked_calls.append(path)

monkeypatch.setattr(guard, "store_pre_state", record_state)

def fake_analyze(*_args: object, **_kwargs: object) -> guard.AnalysisResults:
return {
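The `record_state` hunk above replaces a side-effecting lambda with a named function; a `def` can carry parameter and return annotations and satisfies linters that forbid assigning lambdas. A self-contained sketch of the pattern (the `Store` class is illustrative, not the project's):

import pytest


class Store:
    """Illustrative collaborator that the test replaces."""

    @staticmethod
    def save(path: str, content: str) -> None:
        raise RuntimeError("should be monkeypatched in tests")


def test_save_is_recorded(monkeypatch: pytest.MonkeyPatch) -> None:
    tracked_calls: list[str] = []

    def record_state(path: str, _content: str) -> None:
        tracked_calls.append(path)

    monkeypatch.setattr(Store, "save", record_state)
    Store.save("example.py", "def old(): pass")
    assert tracked_calls == ["example.py"]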
@@ -276,11 +350,18 @@ def test_perform_quality_check_with_state_tracking(monkeypatch: pytest.MonkeyPat

config = guard.QualityConfig(state_tracking_enabled=True)

has_issues, issues = guard._perform_quality_check("example.py", "def old(): pass", config)
has_issues, issues = guard._perform_quality_check(
"example.py",
"def old(): pass",
config,
)

assert tracked_calls == ["example.py"]
assert has_issues is True
assert any("Modernization" in issue or "modernization" in issue.lower() for issue in issues)
assert any(
"Modernization" in issue or "modernization" in issue.lower()
for issue in issues
)


def test_check_cross_file_duplicates_command(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -296,7 +377,10 @@ def test_check_cross_file_duplicates_command(monkeypatch: pytest.MonkeyPatch) ->

monkeypatch.setattr(guard.subprocess, "run", fake_run)

issues = guard.check_cross_file_duplicates("/repo/example.py", guard.QualityConfig())
issues = guard.check_cross_file_duplicates(
"/repo/example.py",
guard.QualityConfig(),
)

assert issues
assert "duplicates" in captured_cmds[0]
@@ -314,9 +398,9 @@ def test_create_hook_response_includes_reason() -> None:
additional_context="context",
decision="block",
)
assert response["permissionDecision"] == "deny"
assert response["reason"] == "Testing"
assert response["systemMessage"] == "System"
assert response["hookSpecificOutput"]["additionalContext"] == "context"
assert response["decision"] == "block"

assert cast(str, response["permissionDecision"]) == "deny"
assert cast(str, response["reason"]) == "Testing"
assert cast(str, response["systemMessage"]) == "System"
hook_output = cast(dict[str, object], response["hookSpecificOutput"])
assert cast(str, hook_output["additionalContext"]) == "context"
assert cast(str, response["decision"]) == "block"
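The rewritten assertions above lean on `typing.cast` because the response maps string keys to `object`; casting at the assertion site narrows each value without loosening the response type. Reduced to its core (the response literal is illustrative):

from typing import cast

response: dict[str, object] = {
    "permissionDecision": "deny",
    "hookSpecificOutput": {"additionalContext": "context"},
}

# cast() is a runtime no-op; it only tells the type checker what to assume.
assert cast(str, response["permissionDecision"]) == "deny"
hook_output = cast(dict[str, object], response["hookSpecificOutput"])
assert cast(str, hook_output["additionalContext"]) == "context"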
@@ -2,25 +2,24 @@

from __future__ import annotations

# pyright: reportPrivateUsage=false, reportPrivateImportUsage=false, reportPrivateLocalImportUsage=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownLambdaType=false, reportUnknownMemberType=false, reportUnusedCallResult=false

# ruff: noqa: SLF001
import json
import os
import subprocess
import sys
from collections.abc import Mapping
from pathlib import Path

import pytest

# Add hooks directory to path
HOOKS_DIR = Path(__file__).parent.parent.parent / "hooks"
sys.path.insert(0, str(HOOKS_DIR))

import code_quality_guard as guard
from hooks import code_quality_guard as guard


class TestVenvDetection:
"""Test virtual environment detection."""

def test_finds_venv_from_file_path(self, tmp_path: Path) -> None:
def test_finds_venv_from_file_path(self) -> None:
"""Should find .venv by traversing up from file."""
# Use home directory to avoid /tmp check
root = Path.home() / f"test_proj_{os.getpid()}"
@@ -99,12 +98,17 @@ class TestPythonpathSetup:
tool.write_text("#!/bin/bash\necho fake")
tool.chmod(0o755)

captured_env = {}
captured_env: dict[str, str] = {}

def capture_run(cmd: list[str], **kw: object) -> subprocess.CompletedProcess[str]:
if "env" in kw:
captured_env.update(dict(kw["env"]))
return subprocess.CompletedProcess(cmd, 0, "", "")
def capture_run(
cmd: list[str],
**kwargs: object,
) -> subprocess.CompletedProcess[str]:
env_obj = kwargs.get("env")
if isinstance(env_obj, Mapping):
for key, value in env_obj.items():
captured_env[str(key)] = str(value)
return subprocess.CompletedProcess(list(cmd), 0, stdout="", stderr="")

monkeypatch.setattr(guard.subprocess, "run", capture_run)
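The new `capture_run` reads its keyword arguments with `kwargs.get("env")` plus an `isinstance(..., Mapping)` guard rather than `"env" in kw` followed by an unchecked index, which keeps the stub fully typed. The same shape standalone:

import subprocess
from collections.abc import Mapping

captured_env: dict[str, str] = {}


def capture_run(cmd: list[str], **kwargs: object) -> subprocess.CompletedProcess[str]:
    env_obj = kwargs.get("env")
    if isinstance(env_obj, Mapping):
        # Copy key/value pairs as strings so captured_env stays dict[str, str].
        for key, value in env_obj.items():
            captured_env[str(key)] = str(value)
    return subprocess.CompletedProcess(list(cmd), 0, stdout="", stderr="")


capture_run(["tool"], env={"PYTHONPATH": "/repo/src"})
assert captured_env["PYTHONPATH"] == "/repo/src"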
@@ -137,10 +141,10 @@ class TestProjectRootAndTempFiles:
nested = root / "src/pkg/subpkg"
nested.mkdir(parents=True)
(root / ".git").mkdir()

test_file = nested / "module.py"
test_file.write_text("# test")

found_root = guard._find_project_root(str(test_file))
assert found_root == root
finally:
@@ -154,12 +158,12 @@ class TestProjectRootAndTempFiles:
try:
(root / "src").mkdir(parents=True)
(root / "pyproject.toml").touch()

test_file = root / "src/module.py"
test_file.write_text("# test")

tmp_dir = guard._get_project_tmp_dir(str(test_file))

assert tmp_dir.exists()
assert tmp_dir == root / ".tmp"
assert tmp_dir.parent == root
@@ -177,29 +181,33 @@ class TestProjectRootAndTempFiles:
tool = root / ".venv/bin/basedpyright"
tool.write_text("#!/bin/bash\necho fake")
tool.chmod(0o755)

# Create pyrightconfig.json
(root / "pyrightconfig.json").write_text('{"strict": []}')

captured_cwd = []

def capture_run(cmd: list[str], **kw: object) -> subprocess.CompletedProcess[str]:
if "cwd" in kw:
captured_cwd.append(Path(str(kw["cwd"])))
return subprocess.CompletedProcess(cmd, 0, "", "")

captured_cwd: list[Path] = []

def capture_run(
cmd: list[str],
**kwargs: object,
) -> subprocess.CompletedProcess[str]:
cwd_obj = kwargs.get("cwd")
if cwd_obj is not None:
captured_cwd.append(Path(str(cwd_obj)))
return subprocess.CompletedProcess(list(cmd), 0, stdout="", stderr="")

monkeypatch.setattr(guard.subprocess, "run", capture_run)

test_file = root / "src/mod.py"
test_file.write_text("# test")

guard._run_type_checker(
"basedpyright",
str(test_file),
guard.QualityConfig(),
original_file_path=str(test_file),
)

# Should have run from project root
assert len(captured_cwd) > 0
assert captured_cwd[0] == root
@@ -1,4 +1,3 @@
#!/usr/bin/env python3
"""Comprehensive integration tests for code quality hooks.

This test suite validates that the hooks properly block forbidden code patterns
@@ -10,93 +9,103 @@ import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Any
from typing import cast

import pytest

# Add hooks directory to path
HOOKS_DIR = Path(__file__).parent.parent / "hooks"
sys.path.insert(0, str(HOOKS_DIR))

from code_quality_guard import (
from hooks.code_quality_guard import (
JsonObject,
QualityConfig,
pretooluse_hook,
_detect_any_usage,  # pyright: ignore[reportPrivateUsage]
_detect_old_typing_patterns,  # pyright: ignore[reportPrivateUsage]
_detect_type_ignore_usage,  # pyright: ignore[reportPrivateUsage]
posttooluse_hook,
_detect_any_usage,
_detect_type_ignore_usage,
_detect_old_typing_patterns,
pretooluse_hook,
)

HOOKS_DIR = Path(__file__).parent.parent / "hooks"


class TestHookIntegration:
"""Integration tests for the complete hook system."""

def setup_method(self):
config: QualityConfig

def __init__(self) -> None:
super().__init__()
self.config = QualityConfig.from_env()
self.config.enforcement_mode = "strict"

def setup_method(self) -> None:
"""Set up test environment."""
self.config = QualityConfig.from_env()
self.config.enforcement_mode = "strict"
def test_any_usage_blocked(self):
"""Test that typing.Any usage is blocked."""
content = '''from typing import Any
content = """from typing import Any

def bad_function(param: Any) -> Any:
return param'''
return param"""

hook_data = {
hook_data = cast(
JsonObject,
{
"tool_name": "Write",
"tool_input": {
"file_path": "/src/production_code.py",  # Non-test file
"content": content
}
}
"content": content,
},
},
)

result = pretooluse_hook(hook_data, self.config)

assert result["permissionDecision"] == "deny"
assert "typing.Any usage" in result["reason"]
print(f"✅ Any usage properly blocked: {result['reason']}")
decision = cast(str, result["permissionDecision"])
reason = cast(str, result["reason"])
assert decision == "deny"
assert "typing.Any usage" in reason

def test_type_ignore_blocked(self):
"""Test that # type: ignore is blocked."""
content = '''def bad_function():
content = """def bad_function():
x = some_untyped_call()  # type: ignore
return x'''
return x"""

hook_data = {
hook_data = cast(
JsonObject,
{
"tool_name": "Write",
"tool_input": {
"file_path": "/src/production_code.py",
"content": content
}
}
"tool_input": {"file_path": "/src/production_code.py", "content": content},
},
)

result = pretooluse_hook(hook_data, self.config)

assert result["permissionDecision"] == "deny"
assert "type: ignore" in result["reason"]
print(f"✅ type: ignore properly blocked: {result['reason']}")
decision = cast(str, result["permissionDecision"])
reason = cast(str, result["reason"])
assert decision == "deny"
assert "type: ignore" in reason

def test_old_typing_patterns_blocked(self):
"""Test that old typing patterns are blocked."""
content = '''from typing import Union, Optional, List, Dict
content = """from typing import Union, Optional, List, Dict

def bad_function(param: Union[str, int]) -> Optional[List[Dict[str, int]]]:
return None'''
return None"""

hook_data = {
hook_data = cast(
JsonObject,
{
"tool_name": "Write",
"tool_input": {
"file_path": "/src/production_code.py",
"content": content
}
}
"tool_input": {"file_path": "/src/production_code.py", "content": content},
},
)

result = pretooluse_hook(hook_data, self.config)

assert result["permissionDecision"] == "deny"
assert "Old typing pattern" in result["reason"]
print(f"✅ Old typing patterns properly blocked: {result['reason']}")
decision = cast(str, result["permissionDecision"])
reason = cast(str, result["reason"])
assert decision == "deny"
assert "Old typing pattern" in reason

def test_good_code_allowed(self):
"""Test that good code is allowed through."""
@@ -106,175 +115,199 @@ def bad_function(param: Union[str, int]) -> Optional[List[Dict[str, int]]]:
return None
return [{"value": 1}]'''

hook_data = {
hook_data = cast(
JsonObject,
{
"tool_name": "Write",
"tool_input": {
"file_path": "/src/production_code.py",
"content": content
}
}
"tool_input": {"file_path": "/src/production_code.py", "content": content},
},
)

result = pretooluse_hook(hook_data, self.config)

assert result["permissionDecision"] == "allow"
print("✅ Good code properly allowed")
decision = cast(str, result["permissionDecision"])
assert decision == "allow"
def test_test_file_conditionals_blocked(self):
"""Test that conditionals in test files are blocked."""
content = '''def test_something():
content = """def test_something():
for item in items:
if item.valid:
assert item.process()
else:
assert not item.process()'''
assert not item.process()"""

hook_data = {
hook_data = cast(
JsonObject,
{
"tool_name": "Write",
"tool_input": {
"file_path": "/tests/test_something.py",  # Test file
"content": content
}
}
"content": content,
},
},
)

result = pretooluse_hook(hook_data, self.config)

assert result["permissionDecision"] == "deny"
assert ("Conditional Logic in Test Function" in result["reason"] or
"Loop Found in Test Function" in result["reason"])
print(f"✅ Test conditionals/loops properly blocked: {result['reason']}")
decision = cast(str, result["permissionDecision"])
reason = cast(str, result["reason"])
assert decision == "deny"
assert (
"Conditional Logic in Test Function" in reason
or "Loop Found in Test Function" in reason
)

def test_enforcement_modes(self):
"""Test different enforcement modes."""
content = '''from typing import Any
content = """from typing import Any
def bad_function(param: Any) -> Any:
return param'''
return param"""

hook_data = {
hook_data = cast(
JsonObject,
{
"tool_name": "Write",
"tool_input": {
"file_path": "/src/production_code.py",
"content": content
}
}
"tool_input": {"file_path": "/src/production_code.py", "content": content},
},
)

# Test strict mode
self.config.enforcement_mode = "strict"
result = pretooluse_hook(hook_data, self.config)
assert result["permissionDecision"] == "deny"
print("✅ Strict mode properly blocks")
assert cast(str, result["permissionDecision"]) == "deny"

# Test warn mode
self.config.enforcement_mode = "warn"
result = pretooluse_hook(hook_data, self.config)
assert result["permissionDecision"] == "ask"
print("✅ Warn mode properly asks")
assert cast(str, result["permissionDecision"]) == "ask"

# Test permissive mode
self.config.enforcement_mode = "permissive"
result = pretooluse_hook(hook_data, self.config)
assert result["permissionDecision"] == "allow"
print("✅ Permissive mode properly allows with warning")
assert cast(str, result["permissionDecision"]) == "allow"

def test_edit_tool_blocking(self):
"""Test that Edit tool is also properly blocked."""
hook_data = {
"tool_name": "Edit",
"tool_input": {
"file_path": "/src/production_code.py",
"old_string": "def old_func():",
"new_string": "def new_func(param: Any) -> Any:"
}
}
hook_data = cast(
JsonObject,
{
"tool_name": "Edit",
"tool_input": {
"file_path": "/src/production_code.py",
"old_string": "def old_func():",
"new_string": "def new_func(param: Any) -> Any:",
},
},
)

result = pretooluse_hook(hook_data, self.config)

assert result["permissionDecision"] == "deny"
assert "typing.Any usage" in result["reason"]
print(f"✅ Edit tool properly blocked: {result['reason']}")
decision = cast(str, result["permissionDecision"])
reason = cast(str, result["reason"])
assert decision == "deny"
assert "typing.Any usage" in reason
def test_multiedit_tool_blocking(self):
"""Test that MultiEdit tool is also properly blocked."""
hook_data = {
"tool_name": "MultiEdit",
"tool_input": {
"file_path": "/src/production_code.py",
"edits": [
{
"old_string": "def old_func():",
"new_string": "def new_func(param: Any) -> Any:"
}
]
}
}
hook_data = cast(
JsonObject,
{
"tool_name": "MultiEdit",
"tool_input": {
"file_path": "/src/production_code.py",
"edits": [
cast(
JsonObject,
{
"old_string": "def old_func():",
"new_string": "def new_func(param: Any) -> Any:",
},
),
],
},
},
)

result = pretooluse_hook(hook_data, self.config)

assert result["permissionDecision"] == "deny"
assert "typing.Any usage" in result["reason"]
print(f"✅ MultiEdit tool properly blocked: {result['reason']}")
decision = cast(str, result["permissionDecision"])
reason = cast(str, result["reason"])
assert decision == "deny"
assert "typing.Any usage" in reason

def test_non_python_files_allowed(self):
"""Test that non-Python files are allowed through."""
hook_data = {
"tool_name": "Write",
"tool_input": {
"file_path": "/src/config.json",
"content": '{"any": "value", "type": "ignore"}'
}
}

hook_data = cast(
JsonObject,
{
"tool_name": "Write",
"tool_input": cast(
JsonObject,
{
"file_path": "/src/config.json",
"content": json.dumps({"any": "value", "type": "ignore"}),
},
),
},
)
result = pretooluse_hook(hook_data, self.config)

assert result["permissionDecision"] == "allow"
print("✅ Non-Python files properly allowed")
assert cast(str, result["permissionDecision"]) == "allow"

def test_posttooluse_hook(self):
"""Test PostToolUse hook functionality."""
# Create a temp file with bad content
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
f.write("from typing import Any\ndef bad(x: Any) -> Any: return x")
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
_ = f.write("from typing import Any\ndef bad(x: Any) -> Any: return x")
temp_path = f.name

try:
hook_data = {
"tool_name": "Write",
"tool_response": {
"file_path": temp_path
}
}
hook_data = cast(
JsonObject,
{
"tool_name": "Write",
"tool_response": cast(
JsonObject,
{"file_path": temp_path},
),
},
)

result = posttooluse_hook(hook_data, self.config)

# PostToolUse should detect issues in the written file
assert "decision" in result
print(f"✅ PostToolUse hook working: {result}")

finally:
Path(temp_path).unlink(missing_ok=True)
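The temp-file handling above keeps its `delete=False` plus `finally: unlink` shape; `delete=False` is what lets code outside the `with` block reopen the file by path, and `_ = f.write(...)` silences unused-result lint rules. The pattern on its own:

import tempfile
from pathlib import Path

with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
    _ = f.write("from typing import Any\n")
    temp_path = f.name

try:
    assert Path(temp_path).read_text().startswith("from typing")
finally:
    Path(temp_path).unlink(missing_ok=True)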
def test_command_line_execution(self):
"""Test that the hook works when executed via command line."""
test_input = json.dumps({
"tool_name": "Write",
"tool_input": {
"file_path": "/src/test.py",
"content": "from typing import Any\ndef bad(x: Any): pass"
}
})
test_input = json.dumps(
{
"tool_name": "Write",
"tool_input": {
"file_path": "/src/test.py",
"content": "from typing import Any\ndef bad(x: Any): pass",
},
},
)

# Test from hooks directory
result = subprocess.run(
["python", "code_quality_guard.py"],
script_path = HOOKS_DIR / "code_quality_guard.py"
result = subprocess.run(  # noqa: S603 - executes trusted project script
[sys.executable, str(script_path)],
check=False,
input=test_input,
text=True,
capture_output=True,
cwd=HOOKS_DIR
cwd=HOOKS_DIR,
)

output = json.loads(result.stdout)
assert output["permissionDecision"] == "deny"
output = cast(JsonObject, json.loads(result.stdout))
assert cast(str, output["permissionDecision"]) == "deny"
assert result.returncode == 2  # Should exit with error code 2
print(f"✅ Command line execution properly blocks: {output['reason']}")
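The command-line test above stops depending on a bare `python` binary on PATH and a working directory containing the script: it resolves the script under `HOOKS_DIR` and launches it with `sys.executable`, the interpreter already running the tests. A runnable sketch of that invocation shape (the inline echo script is a placeholder for the real hook):

import json
import subprocess
import sys
import tempfile
from pathlib import Path

# Placeholder script standing in for code_quality_guard.py.
script = Path(tempfile.mkdtemp()) / "echo_hook.py"
script.write_text(
    "import json, sys\n"
    "print(json.dumps({'echo': json.load(sys.stdin)['tool_name']}))\n"
)

result = subprocess.run(
    [sys.executable, str(script)],  # same interpreter as the test run
    check=False,  # the real hook exits non-zero on a deny
    input=json.dumps({"tool_name": "Write"}),
    text=True,
    capture_output=True,
)
assert json.loads(result.stdout) == {"echo": "Write"}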
class TestDetectionFunctions:
@@ -296,7 +329,6 @@ class TestDetectionFunctions:
issues = _detect_any_usage(content)
has_issues = len(issues) > 0
assert has_issues == should_detect, f"Failed for: {content}"
print(f"✅ Any detection for '{content}': {has_issues}")

def test_type_ignore_detection_comprehensive(self):
"""Test comprehensive type: ignore detection."""
@@ -313,7 +345,6 @@ class TestDetectionFunctions:
issues = _detect_type_ignore_usage(content)
has_issues = len(issues) > 0
assert has_issues == should_detect, f"Failed for: {content}"
print(f"✅ Type ignore detection for '{content}': {has_issues}")

def test_old_typing_patterns_comprehensive(self):
"""Test comprehensive old typing patterns detection."""
@@ -334,62 +365,3 @@ class TestDetectionFunctions:
issues = _detect_old_typing_patterns(content)
has_issues = len(issues) > 0
assert has_issues == should_detect, f"Failed for: {content}"
print(f"✅ Old typing pattern detection for '{content}': {has_issues}")


def run_comprehensive_test():
"""Run all tests and provide a summary."""
print("🚀 Starting comprehensive hook testing...\n")

# Run the test classes
test_integration = TestHookIntegration()
test_integration.setup_method()

test_detection = TestDetectionFunctions()

tests = [
# Integration tests
(test_integration.test_any_usage_blocked, "Any usage blocking"),
(test_integration.test_type_ignore_blocked, "Type ignore blocking"),
(test_integration.test_old_typing_patterns_blocked, "Old typing patterns blocking"),
(test_integration.test_good_code_allowed, "Good code allowed"),
(test_integration.test_test_file_conditionals_blocked, "Test file conditionals blocked"),
(test_integration.test_enforcement_modes, "Enforcement modes"),
(test_integration.test_edit_tool_blocking, "Edit tool blocking"),
(test_integration.test_multiedit_tool_blocking, "MultiEdit tool blocking"),
(test_integration.test_non_python_files_allowed, "Non-Python files allowed"),
(test_integration.test_posttooluse_hook, "PostToolUse hook"),
(test_integration.test_command_line_execution, "Command line execution"),

# Detection function tests
(test_detection.test_any_detection_comprehensive, "Any detection comprehensive"),
(test_detection.test_type_ignore_detection_comprehensive, "Type ignore detection comprehensive"),
(test_detection.test_old_typing_patterns_comprehensive, "Old typing patterns comprehensive"),
]

passed = 0
failed = 0

for test_func, test_name in tests:
try:
print(f"\n🧪 Running: {test_name}")
test_func()
passed += 1
print(f"✅ PASSED: {test_name}")
except Exception as e:
failed += 1
print(f"❌ FAILED: {test_name} - {e}")

print(f"\n📊 Test Results: {passed} passed, {failed} failed")

if failed == 0:
print("🎉 ALL TESTS PASSED! Hooks are working correctly.")
else:
print(f"⚠️ {failed} tests failed. Hooks need fixes.")

return failed == 0


if __name__ == "__main__":
success = run_comprehensive_test()
sys.exit(0 if success else 1)