feat: enhance bash command guard with file locking and hook chaining

- Introduced file-based locking in `bash_command_guard.py` so concurrent hook runs no longer interfere with each other's subprocess calls (a minimal sketch of the pattern follows the summary below).
- Added configuration for lock timeout and polling interval in `bash_guard_constants.py`.
- Implemented a new `hook_chain.py` that unifies hook execution, running the existing guards sequentially from a single entry point (see the invocation example after the commit metadata).
- Updated `claude-code-settings.json` to support the new hook chaining mechanism.
- Refactored subprocess lock handling to improve reliability and prevent deadlocks.

This update improves the robustness of the hook system: the guards' subprocess calls are serialized behind a shared file lock, reducing the risk of concurrency-related errors when several hook invocations run at once.
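For reference, a minimal self-contained sketch of the polling flock pattern (POSIX-only): the constant values mirror bash_guard_constants.py, while the lock-file path is illustrative rather than the one the guard derives at runtime.

# Minimal sketch of the file-based lock: retry a non-blocking flock until
# a deadline passes, yield whether the lock was acquired, then release it.
import fcntl
import time
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path

LOCK_TIMEOUT_SECONDS = 45.0        # mirrors bash_guard_constants.py
LOCK_POLL_INTERVAL_SECONDS = 0.1   # mirrors bash_guard_constants.py

@contextmanager
def subprocess_lock(lock_file: Path, timeout: float = LOCK_TIMEOUT_SECONDS) -> Iterator[bool]:
    """Yield True if the exclusive lock was acquired before the deadline."""
    deadline = time.monotonic() + timeout if timeout > 0 else None
    acquired = False
    with open(lock_file, "a") as handle:
        try:
            while True:
                try:
                    # Non-blocking attempt; poll again until the deadline passes.
                    fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                    acquired = True
                    break
                except OSError:
                    if deadline is None or deadline - time.monotonic() <= 0:
                        break
                    time.sleep(LOCK_POLL_INTERVAL_SECONDS)
            yield acquired
        finally:
            if acquired:
                fcntl.flock(handle.fileno(), fcntl.LOCK_UN)

Callers such as _get_staged_python_files() wrap their git subprocess calls in this context manager and simply skip the check when the lock cannot be acquired in time.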
Commit: bfb7773096
Parent: b4813e124d
Date: 2025-10-26 02:14:12 +00:00
30 changed files with 24,968 additions and 418 deletions
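For orientation, a hedged example of exercising the new chain entry point by hand; the payload keys follow the PreToolUse schema seen in the diffs below, while the command and description values are made up for the demo.

# Hypothetical smoke test: feed a Bash PreToolUse payload to hook_chain.py
# on stdin and relay the guard's exit code and JSON response.
import json
import subprocess
import sys

payload = {
    "tool_name": "Bash",
    "tool_input": {"command": "echo hello", "description": "smoke test"},
}
proc = subprocess.run(
    [sys.executable, "hooks/hook_chain.py", "--event", "pre"],
    input=json.dumps(payload),
    text=True,
    capture_output=True,
    check=False,
)
print(proc.returncode)  # 0 when bash_command_guard.py allows the command
print(proc.stdout)      # JSON hook response relayed from the guard

Claude Code itself reaches the same entry point through the commands configured in claude-code-settings.json.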


@@ -11,6 +11,7 @@ import re
import subprocess
import sys
import tempfile
import time
from contextlib import contextmanager
from pathlib import Path
from shutil import which
@@ -21,6 +22,8 @@ try:
from .bash_guard_constants import (
DANGEROUS_SHELL_PATTERNS,
FORBIDDEN_PATTERNS,
LOCK_POLL_INTERVAL_SECONDS,
LOCK_TIMEOUT_SECONDS,
PYTHON_FILE_PATTERNS,
)
except ImportError:
@@ -28,6 +31,8 @@ except ImportError:
DANGEROUS_SHELL_PATTERNS = bash_guard_constants.DANGEROUS_SHELL_PATTERNS
FORBIDDEN_PATTERNS = bash_guard_constants.FORBIDDEN_PATTERNS
PYTHON_FILE_PATTERNS = bash_guard_constants.PYTHON_FILE_PATTERNS
LOCK_TIMEOUT_SECONDS = bash_guard_constants.LOCK_TIMEOUT_SECONDS
LOCK_POLL_INTERVAL_SECONDS = bash_guard_constants.LOCK_POLL_INTERVAL_SECONDS
class JsonObject(TypedDict, total=False):
@@ -51,31 +56,45 @@ def _get_lock_file() -> Path:
@contextmanager
def _subprocess_lock(timeout: float = 5.0):
def _subprocess_lock(timeout: float = LOCK_TIMEOUT_SECONDS):
"""Context manager for file-based subprocess locking.
Args:
timeout: Timeout in seconds for acquiring lock.
timeout: Maximum time in seconds to wait for the lock. Non-positive
values attempt a single non-blocking acquisition.
Yields:
True if lock was acquired, False if timeout occurred.
"""
lock_file = _get_lock_file()
deadline = (
time.monotonic() + timeout if timeout and timeout > 0 else None
)
acquired = False
# Open or create lock file
with open(lock_file, "a") as f:
try:
# Try to acquire exclusive lock (non-blocking)
fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
yield True
except (IOError, OSError):
# Lock is held by another process, skip to avoid blocking
yield False
while True:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
acquired = True
break
except (IOError, OSError):
if deadline is None:
break
remaining = deadline - time.monotonic()
if remaining <= 0:
break
time.sleep(min(LOCK_POLL_INTERVAL_SECONDS, remaining))
yield acquired
finally:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
except (IOError, OSError):
pass
if acquired:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
except (IOError, OSError):
pass
def _contains_forbidden_pattern(text: str) -> tuple[bool, str | None]:
@@ -125,7 +144,7 @@ def _is_dangerous_shell_command(command: str) -> tuple[bool, str | None]:
r"\b(sed|awk|perl|ed|echo|printf|cat|tee|find|xargs|python|vim|nano|emacs)\b",
pattern,
)
tool_name = tool_match.group(1) if tool_match else "shell utility"
tool_name = tool_match[1] if tool_match else "shell utility"
return True, f"Use of {tool_name} to modify Python files"
return False, None
@@ -362,7 +381,7 @@ def _get_staged_python_files() -> list[str]:
try:
# Acquire file-based lock to prevent subprocess concurrency issues
with _subprocess_lock(timeout=5.0) as acquired:
with _subprocess_lock(timeout=LOCK_TIMEOUT_SECONDS) as acquired:
if not acquired:
return []
@@ -430,10 +449,7 @@ def stop_hook(_hook_data: dict[str, object]) -> JsonObject:
if not changed_files:
return _create_hook_response("Stop", decision="approve")
# Scan all changed Python files for violations
violations = _check_files_for_violations(changed_files)
if violations:
if violations := _check_files_for_violations(changed_files):
violation_text = "\n".join(f" {v}" for v in violations)
message = (
f"🚫 Final Validation Failed\n\n"


@@ -4,6 +4,10 @@ This module contains patterns and constants used to detect forbidden code patter
and dangerous shell commands that could compromise type safety.
"""
# File locking configuration shared across hook modules.
LOCK_TIMEOUT_SECONDS = 45.0
LOCK_POLL_INTERVAL_SECONDS = 0.1
# Forbidden patterns that should never appear in Python code
FORBIDDEN_PATTERNS = [
r"\bfrom\s+typing\s+import\s+.*\bAny\b", # from typing import Any


@@ -2,22 +2,22 @@
"hooks": {
"PreToolUse": [
{
"matcher": "Write|Edit|MultiEdit",
"matcher": "Write|Edit|MultiEdit|Bash",
"hooks": [
{
"type": "command",
"command": "cd $CLAUDE_PROJECT_DIR/hooks && (python code_quality_guard.py 2>/dev/null || python3 code_quality_guard.py)"
"command": "cd $CLAUDE_PROJECT_DIR/hooks && (python hook_chain.py --event pre 2>/dev/null || python3 hook_chain.py --event pre)"
}
]
}
],
"PostToolUse": [
{
"matcher": "Write|Edit|MultiEdit",
"matcher": "Write|Edit|MultiEdit|Bash",
"hooks": [
{
"type": "command",
"command": "cd $CLAUDE_PROJECT_DIR/hooks && (python code_quality_guard.py 2>/dev/null || python3 code_quality_guard.py)"
"command": "cd $CLAUDE_PROJECT_DIR/hooks && (python hook_chain.py --event post 2>/dev/null || python3 hook_chain.py --event post)"
}
]
}


@@ -90,10 +90,8 @@ def _subprocess_lock(timeout: float = 10.0):
# Lock is held by another process, skip to avoid blocking
yield False
finally:
try:
with suppress(IOError, OSError):
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
except (IOError, OSError):
pass
SUSPICIOUS_SUFFIXES: tuple[str, ...] = (
@@ -296,11 +294,8 @@ def generate_test_quality_guidance(
_config: "QualityConfig",
) -> str:
"""Return enriched guidance for test quality rule violations."""
function_name = "test_function"
match = re.search(r"def\s+(\w+)\s*\(", content)
if match:
function_name = match.group(1)
function_name = match[1] if match else "test_function"
# Extract a small snippet of the violating code
code_snippet = ""
if match:
@@ -491,12 +486,10 @@ def get_claude_quality_command(repo_root: Path | None = None) -> list[str]:
["claude-quality.exe", "claude-quality"] if is_windows else ["claude-quality"]
)
candidates: list[tuple[Path, list[str]]] = []
for name in python_names:
candidates.append(_module_candidate(scripts_dir / name))
for name in cli_names:
candidates.append(_cli_candidate(scripts_dir / name))
candidates: list[tuple[Path, list[str]]] = [
_module_candidate(scripts_dir / name) for name in python_names
]
candidates.extend(_cli_candidate(scripts_dir / name) for name in cli_names)
for candidate_path, command in candidates:
if candidate_path.exists():
return command
@@ -592,7 +585,8 @@ def _format_pyrefly_errors(output: str) -> str:
lines = output.strip().split("\n")
# Count ERROR lines to provide summary
error_count = sum(1 for line in lines if line.strip().startswith("ERROR"))
error_count = sum(bool(line.strip().startswith("ERROR"))
for line in lines)
if error_count == 0:
return output.strip()
@@ -616,9 +610,8 @@ def _format_sourcery_errors(output: str) -> str:
# Try to extract the number
import re
match = re.search(r"(\d+)\s+issue", line)
if match:
issue_count = int(match.group(1))
if match := re.search(r"(\d+)\s+issue", line):
issue_count = int(match[1])
break
# Format the output, removing redundant summary lines
@@ -722,8 +715,7 @@ def _run_type_checker(
project_root = venv_bin.parent.parent # Go from .venv/bin to project root
src_dir = project_root / "src"
if src_dir.exists() and src_dir.is_dir():
existing_pythonpath = env.get("PYTHONPATH", "")
if existing_pythonpath:
if existing_pythonpath := env.get("PYTHONPATH", ""):
env["PYTHONPATH"] = f"{src_dir}:{existing_pythonpath}"
else:
env["PYTHONPATH"] = str(src_dir)
@@ -1534,7 +1526,7 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
break
# Check function and class names in content
try:
with suppress(SyntaxError):
# Dedent the content to handle code fragments with leading indentation
tree = ast.parse(textwrap.dedent(content))
@@ -1603,10 +1595,6 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
continue
break
except SyntaxError:
# If we can't parse the AST, skip function/class checks
pass
return issues
@@ -1646,8 +1634,7 @@ def _handle_quality_issues(
if "\n" in issue:
lines = issue.split("\n")
formatted_issues.append(f"{lines[0]}")
for line in lines[1:]:
formatted_issues.append(f" {line}")
formatted_issues.extend(f" {line}" for line in lines[1:])
else:
formatted_issues.append(f"{issue}")
@@ -2021,13 +2008,12 @@ def run_test_quality_checks(
file_path,
config,
)
external_context = get_external_context(
if external_context := get_external_context(
rule,
content,
file_path,
config,
)
if external_context:
):
base_guidance += f"\n\n{external_context}"
issues.append(base_guidance)
break # Only add one guidance message
@@ -2039,20 +2025,19 @@ def run_test_quality_checks(
file_path,
config,
)
external_context = get_external_context(
if external_context := get_external_context(
"unknown",
content,
file_path,
config,
)
if external_context:
):
base_guidance += f"\n\n{external_context}"
issues.append(base_guidance)
elif result.returncode == 2:
# Error occurred
logging.debug("Sourcery error: %s", result.stderr)
# Don't block on Sourcery errors - just log them
# Exit code 0 means no issues - do nothing
# Exit code 0 means no issues - do nothing
except (subprocess.TimeoutExpired, OSError, json.JSONDecodeError) as e:
# If Sourcery fails, don't block the operation

hooks/hook_chain.py (new file, 159 lines)

@@ -0,0 +1,159 @@
"""Unified hook runner that chains existing guards without parallel execution.
This entry point lets Claude Code invoke a single command per hook event while
still reusing the more specialized guards. It reacts to the hook payload to
decide which guard(s) to run and propagates their output/exit codes so Claude
continues to see the same responses. Post-tool Bash logging is handled here so
the old jq pipeline is no longer required.
"""
from __future__ import annotations
import argparse
import json
import subprocess
import sys
from pathlib import Path
from typing import Any
HookPayload = dict[str, Any]
def _load_payload() -> tuple[str, HookPayload | None]:
"""Read stdin payload once and return both raw text and parsed JSON."""
raw = sys.stdin.read()
if not raw:
return raw, None
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
return raw, None
return (raw, parsed) if isinstance(parsed, dict) else (raw, None)
def _default_response(event: str) -> int:
"""Emit the minimal pass-through JSON for hooks that we skip."""
hook_event = "PreToolUse" if event == "pre" else "PostToolUse"
response: HookPayload = {"hookSpecificOutput": {"hookEventName": hook_event}}
if event == "pre":
response["hookSpecificOutput"]["permissionDecision"] = "allow"
sys.stdout.write(json.dumps(response))
sys.stdout.write("\n")
sys.stdout.flush()
return 0
def _run_guard(script_name: str, payload: str) -> int:
"""Execute a sibling guard script sequentially and relay its output."""
script_path = Path(__file__).with_name(script_name)
if not script_path.exists():
raise FileNotFoundError(f"Missing guard script: {script_path}")
proc = subprocess.run( # noqa: S603
[sys.executable, str(script_path)],
input=payload,
text=True,
capture_output=True,
check=False,
)
if proc.stdout:
sys.stdout.write(proc.stdout)
if proc.stderr:
sys.stderr.write(proc.stderr)
sys.stdout.flush()
sys.stderr.flush()
return proc.returncode
def _log_bash_command(payload: HookPayload) -> None:
"""Append successful Bash commands to Claude's standard log file."""
tool_input = payload.get("tool_input")
if not isinstance(tool_input, dict):
return
command = tool_input.get("command")
if not isinstance(command, str) or not command.strip():
return
description = tool_input.get("description")
if not isinstance(description, str) or not description.strip():
description = "No description"
log_path = Path.home() / ".claude" / "bash-command-log.txt"
try:
log_path.parent.mkdir(parents=True, exist_ok=True)
with log_path.open("a", encoding="utf-8") as handle:
handle.write(f"{command} - {description}\n")
except OSError:
# Logging is best-effort; ignore filesystem errors.
pass
def _pre_hook(payload: HookPayload | None, raw: str) -> int:
"""Handle PreToolUse events with sequential guard execution."""
if payload is None:
return _default_response("pre")
tool_name = str(payload.get("tool_name", ""))
if tool_name in {"Write", "Edit", "MultiEdit"}:
return _run_guard("code_quality_guard.py", raw)
if tool_name == "Bash":
return _run_guard("bash_command_guard.py", raw)
return _default_response("pre")
def _post_hook(payload: HookPayload | None, raw: str) -> int:
"""Handle PostToolUse events with sequential guard execution."""
if payload is None:
return _default_response("post")
tool_name = str(payload.get("tool_name", ""))
if tool_name in {"Write", "Edit", "MultiEdit"}:
return _run_guard("code_quality_guard.py", raw)
if tool_name == "Bash":
exit_code = _run_guard("bash_command_guard.py", raw)
if exit_code == 0:
_log_bash_command(payload)
return exit_code
return _default_response("post")
def main() -> None:
parser = argparse.ArgumentParser(description="Chain Claude Code hook guards")
parser.add_argument(
"--event",
choices={"pre", "post"},
required=True,
help="Hook event type to handle.",
)
args = parser.parse_args()
raw_payload, parsed_payload = _load_payload()
if args.event == "pre":
exit_code = _pre_hook(parsed_payload, raw_payload)
else:
exit_code = _post_hook(parsed_payload, raw_payload)
sys.exit(exit_code)
if __name__ == "__main__": # pragma: no cover - CLI entry
main()


@@ -105,13 +105,11 @@ class EnhancedMessageFormatter:
loc_type: str = loc.get("type", "code")
location_summary.append(f"{name} ({loc_type}, lines {lines})")
# Determine refactoring strategy
# Convert locations to proper format for refactoring strategy
dict_locations: list[dict[str, str]] = []
for loc_item in locations:
if isinstance(loc_item, dict):
dict_locations.append(cast(dict[str, str], loc_item))
dict_locations: list[dict[str, str]] = [
cast(dict[str, str], loc_item)
for loc_item in locations
if isinstance(loc_item, dict)
]
strategy = EnhancedMessageFormatter._suggest_refactoring_strategy(
duplicate_type,
dict_locations,
@@ -124,35 +122,29 @@ class EnhancedMessageFormatter:
f"🔍 Duplicate Code Detected ({similarity:.0%} similar)",
"",
"📍 Locations:",
]
parts.extend(location_summary)
parts.extend(
[
*location_summary,
*[
"",
f"📊 Pattern Type: {duplicate_type}",
]
)
],
]
if include_refactoring and strategy:
parts.append("")
parts.append("💡 Refactoring Suggestion:")
parts.append(f" Strategy: {strategy.strategy_type}")
parts.append(f" {strategy.description}")
parts.append("")
parts.append("✅ Benefits:")
for benefit in strategy.benefits:
parts.append(f"{benefit}")
parts.extend(
(
"",
"💡 Refactoring Suggestion:",
f" Strategy: {strategy.strategy_type}",
f" {strategy.description}",
"",
"✅ Benefits:",
)
)
parts.extend(f"{benefit}" for benefit in strategy.benefits)
if strategy.example_before and strategy.example_after:
parts.append("")
parts.append("📝 Example:")
parts.append(" Before:")
for line in strategy.example_before.splitlines():
parts.append(f" {line}")
parts.extend(("", "📝 Example:", " Before:"))
parts.extend(f" {line}" for line in strategy.example_before.splitlines())
parts.append(" After:")
for line in strategy.example_after.splitlines():
parts.append(f" {line}")
parts.extend(f" {line}" for line in strategy.example_after.splitlines())
return "\n".join(parts)
@staticmethod
@@ -460,8 +452,7 @@ class EnhancedMessageFormatter:
int(line_num),
context_lines=2,
)
parts.append(context.code_snippet)
parts.append("")
parts.extend((context.code_snippet, ""))
except (ValueError, IndexError):
pass
@@ -643,17 +634,12 @@ class EnhancedMessageFormatter:
]
fixes_list = guidance["fixes"]
if isinstance(fixes_list, list):
for fix in fixes_list:
parts.append(f"{fix}")
parts.extend(f"{fix}" for fix in fixes_list)
if include_examples and guidance.get("example_before"):
parts.append("")
parts.append("💡 Example:")
parts.append(" ❌ Before:")
parts.extend(("", "💡 Example:", " ❌ Before:"))
example_before_str = guidance.get("example_before", "")
if isinstance(example_before_str, str):
for line in example_before_str.splitlines():
parts.append(f" {line}")
parts.extend(f" {line}" for line in example_before_str.splitlines())
parts.append(" ✅ After:")
example_after_str = guidance.get("example_after", "")
if isinstance(example_after_str, str):
@@ -661,8 +647,7 @@ class EnhancedMessageFormatter:
parts.append(f" {line}")
if code_snippet:
parts.append("")
parts.append("📍 Your Code:")
parts.extend(("", "📍 Your Code:"))
for line in code_snippet.splitlines()[:10]:
parts.append(f" {line}")


@@ -58,9 +58,11 @@ class TypeInferenceHelper:
assignments: list[ast.expr] = []
for node in ast.walk(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and target.id == variable_name:
assignments.append(node.value)
assignments.extend(
node.value
for target in node.targets
if isinstance(target, ast.Name) and target.id == variable_name
)
elif (
isinstance(node, ast.AnnAssign)
and isinstance(node.target, ast.Name)
@@ -207,12 +209,10 @@ class TypeInferenceHelper:
# Try to infer from usage within function
arg_name = arg.arg
suggested_type = TypeInferenceHelper._infer_param_from_usage(
if suggested_type := TypeInferenceHelper._infer_param_from_usage(
arg_name,
function_node,
)
if suggested_type:
):
suggestions.append(
TypeSuggestion(
element_name=arg_name,
@@ -289,8 +289,6 @@ class TypeInferenceHelper:
Returns list of (old_import, new_import, reason) tuples.
"""
suggestions = []
# Patterns to detect and replace
patterns = {
r"from typing import.*\bUnion\b": (
@@ -325,11 +323,11 @@ class TypeInferenceHelper:
),
}
for pattern, (old, new, example) in patterns.items():
if re.search(pattern, source_code):
suggestions.append((old, new, example))
return suggestions
return [
(old, new, example)
for pattern, (old, new, example) in patterns.items()
if re.search(pattern, source_code)
]
@staticmethod
def find_any_usage_with_context(source_code: str) -> list[dict[str, str | int]]:
@@ -367,20 +365,18 @@ class TypeInferenceHelper:
# Find function parameters with Any
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
for arg in node.args.args:
if arg.annotation and TypeInferenceHelper._contains_any(
arg.annotation
):
results.append(
{
"line": getattr(node, "lineno", 0),
"element": arg.arg,
"current": "Any",
"suggested": "Infer from usage",
"context": f"parameter in {node.name}",
}
)
results.extend(
{
"line": getattr(node, "lineno", 0),
"element": arg.arg,
"current": "Any",
"suggested": "Infer from usage",
"context": f"parameter in {node.name}",
}
for arg in node.args.args
if arg.annotation
and TypeInferenceHelper._contains_any(arg.annotation)
)
# Check return type
if node.returns and TypeInferenceHelper._contains_any(node.returns):
suggestion = TypeInferenceHelper.suggest_function_return_type(

File diff suppressed because it is too large.


@@ -225,9 +225,6 @@ class ModernizationAnalyzer(ast.NodeVisitor):
def visit_BinOp(self, node: ast.BinOp) -> None:
"""Check for Union usage that could be modernized."""
if isinstance(node.op, ast.BitOr):
# This is already modern syntax (X | Y)
pass
self.generic_visit(node)
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
@@ -273,26 +270,22 @@ class ModernizationAnalyzer(ast.NodeVisitor):
"""Add issue for typing import that can be replaced with built-ins."""
modern_replacement = self.REPLACEABLE_TO_MODERN[typing_name]
if typing_name in ["List", "Dict", "Tuple", "Set", "FrozenSet"]:
if typing_name in {"List", "Dict", "Tuple", "Set", "FrozenSet"}:
description = (
f"Use built-in '{modern_replacement}' instead of "
f"'typing.{typing_name}' (Python 3.9+)"
)
severity = "warning"
elif typing_name == "Union":
description = (
"Use '|' union operator instead of 'typing.Union' (Python 3.10+)"
)
severity = "warning"
elif typing_name == "Optional":
description = "Use '| None' instead of 'typing.Optional' (Python 3.10+)"
severity = "warning"
else:
description = (
f"Use '{modern_replacement}' instead of 'typing.{typing_name}'"
)
severity = "warning"
severity = "warning"
self.issues.append(
ModernizationIssue(
file_path=self.file_path,
@@ -359,7 +352,7 @@ class ModernizationAnalyzer(ast.NodeVisitor):
"""Add issue for typing usage that can be modernized."""
if typing_name in self.REPLACEABLE_TYPING_IMPORTS:
modern_replacement = self.REPLACEABLE_TO_MODERN[typing_name]
if typing_name in ["List", "Dict", "Tuple", "Set", "FrozenSet"]:
if typing_name in {"List", "Dict", "Tuple", "Set", "FrozenSet"}:
old_pattern = f"{typing_name}[...]"
new_pattern = f"{modern_replacement.lower()}[...]"
description = (
@@ -696,7 +689,7 @@ class PydanticAnalyzer:
# Check if line contains any valid v2 methods
return any(f".{v2_method}(" in line for v2_method in self.V2_METHODS)
def _get_suggested_fix(self, pattern: str, line: str) -> str: # noqa: ARG002
def _get_suggested_fix(self, pattern: str, line: str) -> str: # noqa: ARG002
"""Get suggested fix for a Pydantic pattern."""
fixes = {
r"\.dict\(\)": line.replace(".dict()", ".model_dump()"),
@@ -706,11 +699,14 @@ class PydanticAnalyzer:
r"@root_validator": line.replace("@root_validator", "@model_validator"),
}
for fix_pattern, fix_line in fixes.items():
if re.search(fix_pattern, line):
return fix_line.strip()
return "See Pydantic v2 migration guide"
return next(
(
fix_line.strip()
for fix_pattern, fix_line in fixes.items()
if re.search(fix_pattern, line)
),
"See Pydantic v2 migration guide",
)
class ModernizationEngine:
@@ -753,8 +749,7 @@ class ModernizationEngine:
if file_path.suffix.lower() == ".py":
issues = self.analyze_file(file_path)
# Apply exception filtering
filtered_issues = self.exception_filter.filter_issues(
if filtered_issues := self.exception_filter.filter_issues(
"modernization",
issues,
get_file_path_fn=lambda issue: issue.file_path,
@@ -764,9 +759,7 @@ class ModernizationEngine:
issue.file_path,
issue.line_number,
),
)
if filtered_issues: # Only include files with remaining issues
):
results[file_path] = filtered_issues
return results
@@ -800,16 +793,16 @@ class ModernizationEngine:
by_type.setdefault(issue.issue_type, []).append(issue)
by_severity[issue.severity] += 1
# Top files with most issues
file_counts = {}
for file_path, issues in results.items():
if issues:
file_counts[file_path] = len(issues)
file_counts = {
file_path: len(issues)
for file_path, issues in results.items()
if issues
}
top_files = sorted(file_counts.items(), key=lambda x: x[1], reverse=True)[:10]
# Auto-fixable issues
auto_fixable = sum(1 for issue in all_issues if issue.can_auto_fix)
auto_fixable = sum(bool(issue.can_auto_fix)
for issue in all_issues)
return {
"total_files_analyzed": len(results),


@@ -144,15 +144,15 @@ def duplicates(
results["duplicates"].append({"group_id": i, "analysis": detailed_analysis})
# Output results
if output_format == "json":
if output_format == "console":
_print_console_duplicates(results, verbose)
elif output_format == "csv":
_print_csv_duplicates(results, output)
elif output_format == "json":
if output:
json.dump(results, output, indent=2, default=str)
else:
click.echo(json.dumps(results, indent=2, default=str))
elif output_format == "console":
_print_console_duplicates(results, verbose)
elif output_format == "csv":
_print_csv_duplicates(results, output)
@cli.command()
@@ -211,13 +211,13 @@ def complexity(
overview = analyzer.get_project_complexity_overview(all_files)
# Output results
if output_format == "json":
if output_format == "console":
_print_console_complexity(overview, verbose)
elif output_format == "json":
if output:
json.dump(overview, output, indent=2, default=str)
else:
click.echo(json.dumps(overview, indent=2, default=str))
elif output_format == "console":
_print_console_complexity(overview, verbose)
@cli.command()
@@ -290,10 +290,11 @@ def modernization(
if pydantic_only:
filtered_results = {}
for file_path, issues in results.items():
pydantic_issues = [
issue for issue in issues if issue.issue_type == "pydantic_v1_pattern"
]
if pydantic_issues:
if pydantic_issues := [
issue
for issue in issues
if issue.issue_type == "pydantic_v1_pattern"
]:
filtered_results[file_path] = pydantic_issues
results = filtered_results
@@ -310,13 +311,13 @@ def modernization(
},
}
if output_format == "json":
if output_format == "console":
_print_console_modernization(final_results, verbose, include_type_hints)
elif output_format == "json":
if output:
json.dump(final_results, output, indent=2, default=str)
else:
click.echo(json.dumps(final_results, indent=2, default=str))
elif output_format == "console":
_print_console_modernization(final_results, verbose, include_type_hints)
@cli.command()
@@ -418,12 +419,11 @@ def full_analysis(
# Parse the AST and analyze
tree = ast.parse(content)
ast_analyzer.visit(tree)
smells = ast_analyzer.detect_code_smells()
if smells:
if smells := ast_analyzer.detect_code_smells():
all_smells.extend(
[{"file": str(file_path), "smell": smell} for smell in smells],
)
except (OSError, PermissionError, UnicodeDecodeError):
except (OSError, UnicodeDecodeError):
continue
results["code_smells"] = {"total_smells": len(all_smells), "details": all_smells}
@@ -432,13 +432,13 @@ def full_analysis(
results["quality_score"] = _calculate_overall_quality_score(results)
# Output results
if output_format == "json":
if output_format == "console":
_print_console_full_analysis(results, verbose)
elif output_format == "json":
if output:
json.dump(results, output, indent=2, default=str)
else:
click.echo(json.dumps(results, indent=2, default=str))
elif output_format == "console":
_print_console_full_analysis(results, verbose)
def _print_console_duplicates(results: dict[str, Any], verbose: bool) -> None:
@@ -478,7 +478,7 @@ def _print_console_duplicates(results: dict[str, Any], verbose: bool) -> None:
def _print_csv_duplicates(results: dict[str, Any], output: IO[str] | None) -> None:
"""Print duplicate results in CSV format."""
csv_output = output if output else sys.stdout
csv_output = output or sys.stdout
writer = csv.writer(csv_output)
writer.writerow(
[

View File

@@ -43,12 +43,10 @@ class ComplexityAnalyzer:
"""Analyze multiple files in parallel."""
raw_results = self.radon_analyzer.batch_analyze_files(file_paths, max_workers)
# Filter metrics based on configuration
filtered_results = {}
for path, metrics in raw_results.items():
filtered_results[path] = self._filter_metrics_by_config(metrics)
return filtered_results
return {
path: self._filter_metrics_by_config(metrics)
for path, metrics in raw_results.items()
}
def get_complexity_summary(self, metrics: ComplexityMetrics) -> dict[str, Any]:
"""Get a human-readable summary of complexity metrics."""


@@ -36,7 +36,11 @@ class ComplexityCalculator:
# AST-based metrics
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
if (
isinstance(node, ast.FunctionDef)
or not isinstance(node, ast.ClassDef)
and isinstance(node, ast.AsyncFunctionDef)
):
metrics.function_count += 1
# Count parameters
metrics.parameters_count += len(node.args.args)
@@ -46,13 +50,6 @@ class ComplexityCalculator:
)
elif isinstance(node, ast.ClassDef):
metrics.class_count += 1
elif isinstance(node, ast.AsyncFunctionDef):
metrics.function_count += 1
metrics.parameters_count += len(node.args.args)
metrics.returns_count += len(
[n for n in ast.walk(node) if isinstance(n, ast.Return)],
)
# Calculate cyclomatic complexity
metrics.cyclomatic_complexity = self._calculate_cyclomatic_complexity(tree)
@@ -308,10 +305,8 @@ class ComplexityCalculator:
def _count_logical_lines(self, tree: ast.AST) -> int:
"""Count logical lines of code (AST nodes that represent statements)."""
count = 0
for node in ast.walk(tree):
if isinstance(
return sum(
isinstance(
node,
ast.Assign
| ast.AugAssign
@@ -327,22 +322,21 @@ class ComplexityCalculator:
| ast.Global
| ast.Nonlocal
| ast.Assert,
):
count += 1
return count
)
for node in ast.walk(tree)
)
def _count_variables(self, tree: ast.AST) -> int:
"""Count unique variable names."""
variables = set()
for node in ast.walk(tree):
if isinstance(node, ast.Name) and isinstance(
variables = {
node.id
for node in ast.walk(tree)
if isinstance(node, ast.Name)
and isinstance(
node.ctx,
(ast.Store, ast.Del),
):
variables.add(node.id)
)
}
return len(variables)
def _count_methods(self, tree: ast.AST) -> int:


@@ -115,9 +115,7 @@ class ComplexityMetrics:
return "Moderate"
if score < 60:
return "High"
if score < 80:
return "Very High"
return "Extreme"
return "Very High" if score < 80 else "Extreme"
def get_priority_score(self) -> float:
"""Get priority score for refactoring (0-1, higher means higher priority)."""


@@ -39,11 +39,11 @@ class RadonComplexityAnalyzer:
with open(file_path, encoding="utf-8") as f:
code = f.read()
return self.analyze_code(code, str(file_path))
except (OSError, PermissionError, UnicodeDecodeError):
except (OSError, UnicodeDecodeError):
# Return empty metrics for unreadable files
return ComplexityMetrics()
def _analyze_with_radon(self, code: str, filename: str) -> ComplexityMetrics: # noqa: ARG002
def _analyze_with_radon(self, code: str, filename: str) -> ComplexityMetrics: # noqa: ARG002
"""Analyze code using Radon library."""
metrics = ComplexityMetrics()
@@ -61,9 +61,7 @@ class RadonComplexityAnalyzer:
import radon.complexity
# Cyclomatic complexity
cc_results = radon.complexity.cc_visit(code)
if cc_results:
if cc_results := radon.complexity.cc_visit(code):
# Calculate average complexity from all functions/methods
total_complexity = sum(
getattr(block, "complexity", 0) for block in cc_results
@@ -149,14 +147,15 @@ class RadonComplexityAnalyzer:
[n for n in ast.walk(node) if isinstance(n, ast.Return)],
)
# Count variables
variables = set()
for node in ast.walk(tree):
if isinstance(node, ast.Name) and isinstance(
variables = {
node.id
for node in ast.walk(tree)
if isinstance(node, ast.Name)
and isinstance(
node.ctx,
ast.Store | ast.Del,
):
variables.add(node.id)
)
}
metrics.variables_count = len(variables)
except SyntaxError:
@@ -266,24 +265,10 @@ class RadonComplexityAnalyzer:
return "B" # Moderate
if complexity_score <= 20:
return "C" # High
if complexity_score <= 30:
return "D" # Very High
return "F" # Extreme
return "D" if complexity_score <= 30 else "F"
import radon.complexity
if RADON_AVAILABLE:
import radon.complexity
return str(radon.complexity.cc_rank(complexity_score))
# Fallback if radon not available
if complexity_score <= 5:
return "A"
if complexity_score <= 10:
return "B"
if complexity_score <= 20:
return "C"
if complexity_score <= 30:
return "D"
return "F"
return str(radon.complexity.cc_rank(complexity_score))
def batch_analyze_files(
self,
@@ -310,7 +295,7 @@ class RadonComplexityAnalyzer:
path = future_to_path[future]
try:
results[path] = future.result()
except (OSError, PermissionError, UnicodeDecodeError):
except (OSError, UnicodeDecodeError):
# Create empty metrics for failed files
results[path] = ComplexityMetrics()


@@ -32,7 +32,7 @@ class ASTAnalyzer(ast.NodeVisitor):
return []
# Reset analyzer state
self.file_path = str(file_path)
self.file_path = file_path
self.content = content
self.content_lines = content.splitlines()
self.functions = []
@@ -49,13 +49,11 @@ class ASTAnalyzer(ast.NodeVisitor):
else:
self.visit(tree)
# Filter blocks by minimum size
filtered_blocks = []
for block in self.code_blocks:
if (block.end_line - block.start_line + 1) >= min_lines:
filtered_blocks.append(block)
return filtered_blocks
return [
block
for block in self.code_blocks
if (block.end_line - block.start_line + 1) >= min_lines
]
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
"""Visit function definitions with complexity analysis."""
@@ -264,20 +262,17 @@ class ASTAnalyzer(ast.NodeVisitor):
"""Detect common code smells."""
smells = []
# Long methods
long_methods = [f for f in self.functions if f.lines_count > 30]
if long_methods:
if long_methods := [f for f in self.functions if f.lines_count > 30]:
smells.append(
f"Long methods detected: {len(long_methods)} methods > 30 lines",
)
# Complex methods
complex_methods = [
if complex_methods := [
f
for f in self.functions
if f.complexity_metrics and f.complexity_metrics.cyclomatic_complexity > 10
]
if complex_methods:
if f.complexity_metrics
and f.complexity_metrics.cyclomatic_complexity > 10
]:
smells.append(
f"Complex methods detected: {len(complex_methods)} methods "
"with complexity > 10",
@@ -287,12 +282,12 @@ class ASTAnalyzer(ast.NodeVisitor):
for func in self.functions:
try:
tree = ast.parse(func.content)
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef) and len(node.args.args) > 5:
smells.append(
f"Method with many parameters: {func.function_name} "
f"({len(node.args.args)} parameters)",
)
smells.extend(
f"Method with many parameters: {func.function_name} ({len(node.args.args)} parameters)"
for node in ast.walk(tree)
if isinstance(node, ast.FunctionDef)
and len(node.args.args) > 5
)
except Exception: # noqa: BLE001
logging.debug("Failed to analyze code smell for %s", self.file_path)


@@ -73,22 +73,24 @@ class ExceptionFilter:
if self._is_globally_excluded(file_path):
return True, "File/directory globally excluded"
# Check exception rules
for rule in self.active_rules:
if self._rule_matches(
rule,
analysis_type,
issue_type,
file_path,
line_number,
line_content,
):
return (
return next(
(
(
True,
rule.reason or f"Matched exception rule: {rule.analysis_type}",
)
return False, None
for rule in self.active_rules
if self._rule_matches(
rule,
analysis_type,
issue_type,
file_path,
line_number,
line_content,
)
),
(False, None),
)
def _is_globally_excluded(self, file_path: str) -> bool:
"""Check if file is globally excluded."""
@@ -135,24 +137,22 @@ class ExceptionFilter:
# Check file patterns
if rule.file_patterns:
file_matches = False
for pattern in rule.file_patterns:
if fnmatch.fnmatch(file_path, pattern) or fnmatch.fnmatch(
file_matches = any(
fnmatch.fnmatch(file_path, pattern)
or fnmatch.fnmatch(
str(Path(file_path).name),
pattern,
):
file_matches = True
break
)
for pattern in rule.file_patterns
)
if not file_matches:
return False
# Check line patterns
if rule.line_patterns and line_content:
line_matches = False
for pattern in rule.line_patterns:
if re.search(pattern, line_content):
line_matches = True
break
line_matches = any(
re.search(pattern, line_content) for pattern in rule.line_patterns
)
if not line_matches:
return False


@@ -259,7 +259,7 @@ class DuplicateDetectionEngine:
complexity = self.complexity_analyzer.analyze_code(block.content)
complexities.append(complexity.get_overall_score())
max_complexity = max(complexities) if complexities else 0.0
max_complexity = max(complexities, default=0.0)
match = DuplicateMatch(
blocks=group,
@@ -365,21 +365,26 @@ class DuplicateDetectionEngine:
has_class = any(isinstance(node, ast.ClassDef) for node in ast.walk(tree))
if has_function:
suggestions.append(
"Extract common function into a shared utility module",
)
suggestions.append(
"Consider creating a base function with configurable parameters",
suggestions.extend(
(
"Extract common function into a shared utility module",
"Consider creating a base function with configurable parameters",
)
)
elif has_class:
suggestions.append("Extract common class into a base class or mixin")
suggestions.append("Consider using inheritance or composition patterns")
else:
suggestions.append("Extract duplicate code into a reusable function")
suggestions.append(
"Consider creating a utility module for shared logic",
suggestions.extend(
(
"Extract common class into a base class or mixin",
"Consider using inheritance or composition patterns",
)
)
else:
suggestions.extend(
(
"Extract duplicate code into a reusable function",
"Consider creating a utility module for shared logic",
)
)
# Complexity-based suggestions
if duplicate_match.complexity_score > 60:
suggestions.append(
@@ -413,9 +418,7 @@ class DuplicateDetectionEngine:
return "Low (1-2 hours)"
if total_lines < 100:
return "Medium (0.5-1 day)"
if total_lines < 500:
return "High (1-3 days)"
return "Very High (1+ weeks)"
return "High (1-3 days)" if total_lines < 500 else "Very High (1+ weeks)"
def _assess_refactoring_risk(self, duplicate_match: DuplicateMatch) -> str:
"""Assess risk level of refactoring."""
@@ -437,9 +440,7 @@ class DuplicateDetectionEngine:
if not risk_factors:
return "Low"
if len(risk_factors) <= 2:
return "Medium"
return "High"
return "Medium" if len(risk_factors) <= 2 else "High"
def _get_content_preview(self, content: str, max_lines: int = 5) -> str:
"""Get a preview of code content."""


@@ -161,19 +161,18 @@ class DuplicateMatcher:
if len(match.blocks) < 2:
return {"confidence": 0.0, "factors": []}
confidence_factors = []
total_confidence = 0.0
# Similarity-based confidence
similarity_confidence = match.similarity_score
confidence_factors.append(
confidence_factors = [
{
"factor": "similarity_score",
"value": match.similarity_score,
"weight": 0.4,
"contribution": similarity_confidence * 0.4,
},
)
}
]
total_confidence += similarity_confidence * 0.4
# Length-based confidence (longer matches are more reliable)
@@ -293,9 +292,7 @@ class DuplicateMatcher:
f"Merged cluster with {len(unique_blocks)} blocks "
f"(avg similarity: {avg_score:.3f})"
),
complexity_score=max(complexity_scores)
if complexity_scores
else 0.0,
complexity_score=max(complexity_scores, default=0.0),
priority_score=avg_score,
)
merged_matches.append(merged_match)
@@ -308,6 +305,4 @@ class DuplicateMatcher:
return "High"
if confidence >= 0.6:
return "Medium"
if confidence >= 0.4:
return "Low"
return "Very Low"
return "Low" if confidence >= 0.4 else "Very Low"


@@ -182,7 +182,7 @@ class LSHDuplicateDetector:
self.code_blocks: dict[str, CodeBlock] = {}
if LSH_AVAILABLE:
self.lsh_index = MinHashLSH(threshold=float(threshold), num_perm=int(num_perm))
self.lsh_index = MinHashLSH(threshold=threshold, num_perm=num_perm)
def add_code_block(self, block: CodeBlock) -> None:
"""Add a code block to the LSH index."""
@@ -220,8 +220,7 @@ class LSHDuplicateDetector:
if candidate_id == block_id:
continue
candidate_block = self.code_blocks.get(candidate_id)
if candidate_block:
if candidate_block := self.code_blocks.get(candidate_id):
# Calculate exact similarity
similarity = query_minhash.jaccard(self.minhashes[candidate_id])
if similarity >= self.threshold:
@@ -243,13 +242,9 @@ class LSHDuplicateDetector:
if block_id in processed:
continue
similar_blocks = self.find_similar_blocks(block)
if similar_blocks:
if similar_blocks := self.find_similar_blocks(block):
# Create group with original block and similar blocks
group = [block]
group.extend([similar_block for similar_block, _ in similar_blocks])
group = [block, *[similar_block for similar_block, _ in similar_blocks]]
# Mark all blocks in group as processed
processed.add(block_id)
for similar_block, _ in similar_blocks:
@@ -364,7 +359,7 @@ class BandingLSH:
if len(sig1) != len(sig2):
return 0.0
matches = sum(1 for a, b in zip(sig1, sig2, strict=False) if a == b)
matches = sum(a == b for a, b in zip(sig1, sig2, strict=False))
return matches / len(sig1)
def get_statistics(self) -> dict[str, object]:


@@ -93,9 +93,8 @@ class StructuralSimilarity(BaseSimilarityAlgorithm):
# Count methods in class
method_count = sum(
1
isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef))
for child in node.body
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef))
)
structure.append(f"{depth_prefix}class_methods:{method_count}")


@@ -251,7 +251,7 @@ class TFIDFSimilarity(BaseSimilarityAlgorithm):
total_docs = len(documents)
for term in terms:
docs_containing_term = sum(1 for doc in documents if term in doc)
docs_containing_term = sum(term in doc for doc in documents)
idf[term] = math.log(
total_docs / (docs_containing_term + 1),
) # +1 for smoothing


@@ -58,12 +58,12 @@ class FileFinder:
if root_path.is_file():
return [root_path] if self._is_python_file(root_path) else []
found_files = []
for file_path in root_path.rglob("*.py"):
if self._should_include_file(file_path) and self._is_python_file(file_path):
found_files.append(file_path)
return found_files
return [
file_path
for file_path in root_path.rglob("*.py")
if self._should_include_file(file_path)
and self._is_python_file(file_path)
]
def _should_include_file(self, file_path: Path) -> bool:
"""Check if a file should be included in analysis."""
@@ -77,29 +77,30 @@ class FileFinder:
):
return False
# Check include patterns
for pattern in self.path_config.include_patterns:
if fnmatch.fnmatch(path_str, pattern) or fnmatch.fnmatch(
file_path.name,
pattern,
):
# Check if it's a supported file type
return self._has_supported_extension(file_path)
return False
return next(
(
self._has_supported_extension(file_path)
for pattern in self.path_config.include_patterns
if fnmatch.fnmatch(path_str, pattern)
or fnmatch.fnmatch(
file_path.name,
pattern,
)
),
False,
)
def _has_supported_extension(self, file_path: Path) -> bool:
"""Check if file has a supported extension."""
suffix = file_path.suffix.lower()
for lang in self.language_config.languages:
if (
return any(
(
lang in self.language_config.file_extensions
and suffix in self.language_config.file_extensions[lang]
):
return True
return False
)
for lang in self.language_config.languages
)
def _is_python_file(self, file_path: Path) -> bool:
"""Check if file is a Python file."""
@@ -109,11 +110,14 @@ class FileFinder:
"""Determine the programming language of a file."""
suffix = file_path.suffix.lower()
for lang, extensions in self.language_config.file_extensions.items():
if suffix in extensions:
return lang
return None
return next(
(
lang
for lang, extensions in self.language_config.file_extensions.items()
if suffix in extensions
),
None,
)
def get_project_stats(self, root_path: Path) -> dict[str, Any]:
"""Get statistics about files in the project."""
@@ -173,15 +177,14 @@ class FileFinder:
# Apply include patterns
if include and include_patterns:
include = False
for pattern in include_patterns:
if fnmatch.fnmatch(path_str, pattern) or fnmatch.fnmatch(
include = any(
fnmatch.fnmatch(path_str, pattern)
or fnmatch.fnmatch(
file_path.name,
pattern,
):
include = True
break
)
for pattern in include_patterns
)
if include:
filtered.append(file_path)


@@ -218,49 +218,43 @@ def reset_environment():
# Restore original environment
os.environ.clear()
os.environ.update(original_env)
os.environ |= original_env
@pytest.fixture
def set_env_strict():
"""Set environment for strict mode."""
os.environ.update(
{
"QUALITY_ENFORCEMENT": "strict",
"QUALITY_DUP_THRESHOLD": "0.7",
"QUALITY_COMPLEXITY_THRESHOLD": "10",
"QUALITY_DUP_ENABLED": "true",
"QUALITY_COMPLEXITY_ENABLED": "true",
"QUALITY_MODERN_ENABLED": "true",
"QUALITY_REQUIRE_TYPES": "true",
},
)
os.environ |= {
"QUALITY_ENFORCEMENT": "strict",
"QUALITY_DUP_THRESHOLD": "0.7",
"QUALITY_COMPLEXITY_THRESHOLD": "10",
"QUALITY_DUP_ENABLED": "true",
"QUALITY_COMPLEXITY_ENABLED": "true",
"QUALITY_MODERN_ENABLED": "true",
"QUALITY_REQUIRE_TYPES": "true",
}
@pytest.fixture
def set_env_permissive():
"""Set environment for permissive mode."""
os.environ.update(
{
"QUALITY_ENFORCEMENT": "permissive",
"QUALITY_DUP_THRESHOLD": "0.9",
"QUALITY_COMPLEXITY_THRESHOLD": "20",
"QUALITY_DUP_ENABLED": "true",
"QUALITY_COMPLEXITY_ENABLED": "true",
"QUALITY_MODERN_ENABLED": "false",
"QUALITY_REQUIRE_TYPES": "false",
},
)
os.environ |= {
"QUALITY_ENFORCEMENT": "permissive",
"QUALITY_DUP_THRESHOLD": "0.9",
"QUALITY_COMPLEXITY_THRESHOLD": "20",
"QUALITY_DUP_ENABLED": "true",
"QUALITY_COMPLEXITY_ENABLED": "true",
"QUALITY_MODERN_ENABLED": "false",
"QUALITY_REQUIRE_TYPES": "false",
}
@pytest.fixture
def set_env_posttooluse():
"""Set environment for PostToolUse features."""
os.environ.update(
{
"QUALITY_STATE_TRACKING": "true",
"QUALITY_CROSS_FILE_CHECK": "true",
"QUALITY_VERIFY_NAMING": "true",
"QUALITY_SHOW_SUCCESS": "true",
},
)
os.environ |= {
"QUALITY_STATE_TRACKING": "true",
"QUALITY_CROSS_FILE_CHECK": "true",
"QUALITY_VERIFY_NAMING": "true",
"QUALITY_SHOW_SUCCESS": "true",
}


@@ -561,7 +561,7 @@ class TestTempFileManagement:
# .tmp directory should exist but temp file should be gone
if tmp_dir.exists():
temp_files = list(tmp_dir.glob("hook_validation_*"))
assert len(temp_files) == 0
assert not temp_files
finally:
import shutil
if root.exists():


@@ -48,21 +48,19 @@ class TestQualityConfig:
def test_from_env_with_custom_values(self):
"""Test loading config from environment with custom values."""
os.environ.update(
{
"QUALITY_DUP_THRESHOLD": "0.8",
"QUALITY_DUP_ENABLED": "false",
"QUALITY_COMPLEXITY_THRESHOLD": "15",
"QUALITY_COMPLEXITY_ENABLED": "false",
"QUALITY_MODERN_ENABLED": "false",
"QUALITY_REQUIRE_TYPES": "false",
"QUALITY_ENFORCEMENT": "permissive",
"QUALITY_STATE_TRACKING": "true",
"QUALITY_CROSS_FILE_CHECK": "true",
"QUALITY_VERIFY_NAMING": "false",
"QUALITY_SHOW_SUCCESS": "true",
},
)
os.environ |= {
"QUALITY_DUP_THRESHOLD": "0.8",
"QUALITY_DUP_ENABLED": "false",
"QUALITY_COMPLEXITY_THRESHOLD": "15",
"QUALITY_COMPLEXITY_ENABLED": "false",
"QUALITY_MODERN_ENABLED": "false",
"QUALITY_REQUIRE_TYPES": "false",
"QUALITY_ENFORCEMENT": "permissive",
"QUALITY_STATE_TRACKING": "true",
"QUALITY_CROSS_FILE_CHECK": "true",
"QUALITY_VERIFY_NAMING": "false",
"QUALITY_SHOW_SUCCESS": "true",
}
config = guard.QualityConfig.from_env()


@@ -23,10 +23,10 @@ class TestEdgeCases:
def test_massive_file_content(self):
"""Test handling of very large files."""
config = QualityConfig()
# Create a file with 10,000 lines
massive_content = "\n".join(f"# Line {i}" for i in range(10000))
massive_content += "\ndef func1():\n pass\n"
massive_content = (
"\n".join(f"# Line {i}" for i in range(10000))
+ "\ndef func1():\n pass\n"
)
hook_data = {
"tool_name": "Write",
"tool_input": {


@@ -185,9 +185,7 @@ class TestHelperFunctions:
set_platform("linux")
def fake_which(name: str) -> str | None:
if name == "claude-quality":
return "/usr/local/bin/claude-quality"
return None
return "/usr/local/bin/claude-quality" if name == "claude-quality" else None
monkeypatch.setattr(shutil, "which", fake_which)


@@ -146,7 +146,7 @@ class TestHookIntegration:
"QUALITY_COMPLEXITY_ENABLED": "true",
"QUALITY_MODERN_ENABLED": "false",
}
os.environ.update(env_overrides)
os.environ |= env_overrides
complex_code = """
def complex_func(a, b, c):


@@ -63,9 +63,7 @@ def test_ensure_tool_installed(
suffix = str(path)
if suffix.endswith("basedpyright"):
return tool_exists
if suffix.endswith("uv"):
return not tool_exists
return False
return not tool_exists if suffix.endswith("uv") else False
monkeypatch.setattr(guard.Path, "exists", fake_exists, raising=False)


@@ -209,7 +209,7 @@ class TestProjectRootAndTempFiles:
)
# Should have run from project root
assert len(captured_cwd) > 0
assert captured_cwd
assert captured_cwd[0] == root
finally:
import shutil