feat: enhance bash command guard with file locking and hook chaining
- Introduced file-based locking in `bash_command_guard.py` to prevent concurrent execution issues.
- Added configuration for lock timeout and polling interval in `bash_guard_constants.py`.
- Implemented a new `hook_chain.py` to unify hook execution, allowing for sequential processing of guards.
- Updated `claude-code-settings.json` to support the new hook chaining mechanism.
- Refactored subprocess lock handling to improve reliability and prevent deadlocks.

This update improves the robustness of the hook system by ensuring that bash commands are executed in a controlled manner, reducing the risk of concurrency-related errors. A condensed sketch of the locking approach follows.
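The locking change is shown in full in the diff below; as a condensed sketch (simplified names, constants copied from `bash_guard_constants.py` in this commit, `fcntl` is POSIX-only just like the original guard), the new acquisition loop polls a non-blocking `flock` until a deadline instead of giving up after a single attempt:

```python
import fcntl
import time
from contextlib import contextmanager
from pathlib import Path

LOCK_TIMEOUT_SECONDS = 45.0        # values introduced in bash_guard_constants.py below
LOCK_POLL_INTERVAL_SECONDS = 0.1


@contextmanager
def subprocess_lock(lock_file: Path, timeout: float = LOCK_TIMEOUT_SECONDS):
    """Yield True once an exclusive flock is held, False if the timeout expires."""
    deadline = time.monotonic() + timeout if timeout > 0 else None
    acquired = False
    with open(lock_file, "a") as f:
        try:
            while True:
                try:
                    # Non-blocking attempt; poll until the deadline instead of blocking forever.
                    fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                    acquired = True
                    break
                except OSError:
                    if deadline is None:
                        break
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                    time.sleep(min(LOCK_POLL_INTERVAL_SECONDS, remaining))
            yield acquired
        finally:
            if acquired:
                fcntl.flock(f.fileno(), fcntl.LOCK_UN)
```

The real `_subprocess_lock` in the diff derives the lock path from `_get_lock_file()` and treats non-positive timeouts as a single non-blocking attempt; the sketch above only illustrates the polling loop.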
@@ -11,6 +11,7 @@ import re
import subprocess
import sys
import tempfile
import time
from contextlib import contextmanager
from pathlib import Path
from shutil import which
@@ -21,6 +22,8 @@ try:
from .bash_guard_constants import (
DANGEROUS_SHELL_PATTERNS,
FORBIDDEN_PATTERNS,
LOCK_POLL_INTERVAL_SECONDS,
LOCK_TIMEOUT_SECONDS,
PYTHON_FILE_PATTERNS,
)
except ImportError:
@@ -28,6 +31,8 @@ except ImportError:
DANGEROUS_SHELL_PATTERNS = bash_guard_constants.DANGEROUS_SHELL_PATTERNS
FORBIDDEN_PATTERNS = bash_guard_constants.FORBIDDEN_PATTERNS
PYTHON_FILE_PATTERNS = bash_guard_constants.PYTHON_FILE_PATTERNS
LOCK_TIMEOUT_SECONDS = bash_guard_constants.LOCK_TIMEOUT_SECONDS
LOCK_POLL_INTERVAL_SECONDS = bash_guard_constants.LOCK_POLL_INTERVAL_SECONDS

class JsonObject(TypedDict, total=False):
@@ -51,31 +56,45 @@ def _get_lock_file() -> Path:

@contextmanager
def _subprocess_lock(timeout: float = 5.0):
def _subprocess_lock(timeout: float = LOCK_TIMEOUT_SECONDS):
"""Context manager for file-based subprocess locking.

Args:
timeout: Timeout in seconds for acquiring lock.
timeout: Maximum time in seconds to wait for the lock. Non-positive
values attempt a single non-blocking acquisition.

Yields:
True if lock was acquired, False if timeout occurred.
"""
lock_file = _get_lock_file()
deadline = (
time.monotonic() + timeout if timeout and timeout > 0 else None
)
acquired = False

# Open or create lock file
with open(lock_file, "a") as f:
try:
# Try to acquire exclusive lock (non-blocking)
fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
yield True
except (IOError, OSError):
# Lock is held by another process, skip to avoid blocking
yield False
while True:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
acquired = True
break
except (IOError, OSError):
if deadline is None:
break
remaining = deadline - time.monotonic()
if remaining <= 0:
break
time.sleep(min(LOCK_POLL_INTERVAL_SECONDS, remaining))

yield acquired
finally:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
except (IOError, OSError):
pass
if acquired:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
except (IOError, OSError):
pass

def _contains_forbidden_pattern(text: str) -> tuple[bool, str | None]:
@@ -125,7 +144,7 @@ def _is_dangerous_shell_command(command: str) -> tuple[bool, str | None]:
r"\b(sed|awk|perl|ed|echo|printf|cat|tee|find|xargs|python|vim|nano|emacs)\b",
pattern,
)
tool_name = tool_match.group(1) if tool_match else "shell utility"
tool_name = tool_match[1] if tool_match else "shell utility"
return True, f"Use of {tool_name} to modify Python files"

return False, None
@@ -362,7 +381,7 @@ def _get_staged_python_files() -> list[str]:

try:
# Acquire file-based lock to prevent subprocess concurrency issues
with _subprocess_lock(timeout=5.0) as acquired:
with _subprocess_lock(timeout=LOCK_TIMEOUT_SECONDS) as acquired:
if not acquired:
return []

@@ -430,10 +449,7 @@ def stop_hook(_hook_data: dict[str, object]) -> JsonObject:
if not changed_files:
return _create_hook_response("Stop", decision="approve")

# Scan all changed Python files for violations
violations = _check_files_for_violations(changed_files)

if violations:
if violations := _check_files_for_violations(changed_files):
violation_text = "\n".join(f" {v}" for v in violations)
message = (
f"🚫 Final Validation Failed\n\n"
@@ -4,6 +4,10 @@ This module contains patterns and constants used to detect forbidden code patter
and dangerous shell commands that could compromise type safety.
"""

# File locking configuration shared across hook modules.
LOCK_TIMEOUT_SECONDS = 45.0
LOCK_POLL_INTERVAL_SECONDS = 0.1

# Forbidden patterns that should never appear in Python code
FORBIDDEN_PATTERNS = [
r"\bfrom\s+typing\s+import\s+.*\bAny\b", # from typing import Any
@@ -2,22 +2,22 @@
"hooks": {
"PreToolUse": [
{
"matcher": "Write|Edit|MultiEdit",
"matcher": "Write|Edit|MultiEdit|Bash",
"hooks": [
{
"type": "command",
"command": "cd $CLAUDE_PROJECT_DIR/hooks && (python code_quality_guard.py 2>/dev/null || python3 code_quality_guard.py)"
"command": "cd $CLAUDE_PROJECT_DIR/hooks && (python hook_chain.py --event pre 2>/dev/null || python3 hook_chain.py --event pre)"
}
]
}
],
"PostToolUse": [
{
"matcher": "Write|Edit|MultiEdit",
"matcher": "Write|Edit|MultiEdit|Bash",
"hooks": [
{
"type": "command",
"command": "cd $CLAUDE_PROJECT_DIR/hooks && (python code_quality_guard.py 2>/dev/null || python3 code_quality_guard.py)"
"command": "cd $CLAUDE_PROJECT_DIR/hooks && (python hook_chain.py --event post 2>/dev/null || python3 hook_chain.py --event post)"
}
]
}
@@ -90,10 +90,8 @@ def _subprocess_lock(timeout: float = 10.0):
# Lock is held by another process, skip to avoid blocking
yield False
finally:
try:
with suppress(IOError, OSError):
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
except (IOError, OSError):
pass

SUSPICIOUS_SUFFIXES: tuple[str, ...] = (
@@ -296,11 +294,8 @@ def generate_test_quality_guidance(
_config: "QualityConfig",
) -> str:
"""Return enriched guidance for test quality rule violations."""
function_name = "test_function"
match = re.search(r"def\s+(\w+)\s*\(", content)
if match:
function_name = match.group(1)

function_name = match[1] if match else "test_function"
# Extract a small snippet of the violating code
code_snippet = ""
if match:
@@ -491,12 +486,10 @@ def get_claude_quality_command(repo_root: Path | None = None) -> list[str]:
["claude-quality.exe", "claude-quality"] if is_windows else ["claude-quality"]
)

candidates: list[tuple[Path, list[str]]] = []
for name in python_names:
candidates.append(_module_candidate(scripts_dir / name))
for name in cli_names:
candidates.append(_cli_candidate(scripts_dir / name))

candidates: list[tuple[Path, list[str]]] = [
_module_candidate(scripts_dir / name) for name in python_names
]
candidates.extend(_cli_candidate(scripts_dir / name) for name in cli_names)
for candidate_path, command in candidates:
if candidate_path.exists():
return command
@@ -592,7 +585,8 @@ def _format_pyrefly_errors(output: str) -> str:
lines = output.strip().split("\n")

# Count ERROR lines to provide summary
error_count = sum(1 for line in lines if line.strip().startswith("ERROR"))
error_count = sum(bool(line.strip().startswith("ERROR"))
for line in lines)

if error_count == 0:
return output.strip()
@@ -616,9 +610,8 @@ def _format_sourcery_errors(output: str) -> str:
# Try to extract the number
import re

match = re.search(r"(\d+)\s+issue", line)
if match:
issue_count = int(match.group(1))
if match := re.search(r"(\d+)\s+issue", line):
issue_count = int(match[1])
break

# Format the output, removing redundant summary lines
@@ -722,8 +715,7 @@ def _run_type_checker(
project_root = venv_bin.parent.parent # Go from .venv/bin to project root
src_dir = project_root / "src"
if src_dir.exists() and src_dir.is_dir():
existing_pythonpath = env.get("PYTHONPATH", "")
if existing_pythonpath:
if existing_pythonpath := env.get("PYTHONPATH", ""):
env["PYTHONPATH"] = f"{src_dir}:{existing_pythonpath}"
else:
env["PYTHONPATH"] = str(src_dir)
@@ -1534,7 +1526,7 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
break

# Check function and class names in content
try:
with suppress(SyntaxError):
# Dedent the content to handle code fragments with leading indentation
tree = ast.parse(textwrap.dedent(content))

@@ -1603,10 +1595,6 @@ def _detect_suffix_duplication(file_path: str, content: str) -> list[str]:
continue
break

except SyntaxError:
# If we can't parse the AST, skip function/class checks
pass

return issues

@@ -1646,8 +1634,7 @@ def _handle_quality_issues(
if "\n" in issue:
lines = issue.split("\n")
formatted_issues.append(f"• {lines[0]}")
for line in lines[1:]:
formatted_issues.append(f" {line}")
formatted_issues.extend(f" {line}" for line in lines[1:])
else:
formatted_issues.append(f"• {issue}")

@@ -2021,13 +2008,12 @@ def run_test_quality_checks(
file_path,
config,
)
external_context = get_external_context(
if external_context := get_external_context(
rule,
content,
file_path,
config,
)
if external_context:
):
base_guidance += f"\n\n{external_context}"
issues.append(base_guidance)
break # Only add one guidance message
@@ -2039,20 +2025,19 @@ def run_test_quality_checks(
file_path,
config,
)
external_context = get_external_context(
if external_context := get_external_context(
"unknown",
content,
file_path,
config,
)
if external_context:
):
base_guidance += f"\n\n{external_context}"
issues.append(base_guidance)
elif result.returncode == 2:
# Error occurred
logging.debug("Sourcery error: %s", result.stderr)
# Don't block on Sourcery errors - just log them
# Exit code 0 means no issues - do nothing
# Exit code 0 means no issues - do nothing

except (subprocess.TimeoutExpired, OSError, json.JSONDecodeError) as e:
# If Sourcery fails, don't block the operation
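Beyond the locking work, many hunks in this commit apply the same three mechanical refactors: append loops become comprehensions, bind-then-test becomes the walrus operator, and `try/except: pass` becomes `contextlib.suppress`. A small before/after sketch with placeholder names (illustration only, not code from the repository):

```python
"""Illustration of the recurring refactor patterns applied throughout this commit."""
from contextlib import suppress


def risky(value: str | None) -> None:
    """Stand-in for any call that may raise OSError."""
    if value is None:
        raise OSError("no value")


def before(items: list[str]) -> list[str]:
    # Append loop, separate bind-then-test, and try/except-pass.
    selected = []
    for item in items:
        if item.startswith("x"):
            selected.append(item)
    first = selected[0] if selected else None
    if first:
        print(first)
    try:
        risky(first)
    except OSError:
        pass
    return selected


def after(items: list[str]) -> list[str]:
    # Comprehension, walrus operator, and contextlib.suppress.
    selected = [item for item in items if item.startswith("x")]
    if first := (selected[0] if selected else None):
        print(first)
    with suppress(OSError):
        risky(first)
    return selected
```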
hooks/hook_chain.py (new file, 159 lines)
@@ -0,0 +1,159 @@
"""Unified hook runner that chains existing guards without parallel execution.

This entry point lets Claude Code invoke a single command per hook event while
still reusing the more specialized guards. It reacts to the hook payload to
decide which guard(s) to run and propagates their output/exit codes so Claude
continues to see the same responses. Post-tool Bash logging is handled here so
the old jq pipeline is no longer required.
"""

from __future__ import annotations

import argparse
import json
import subprocess
import sys
from pathlib import Path
from typing import Any


HookPayload = dict[str, Any]


def _load_payload() -> tuple[str, HookPayload | None]:
    """Read stdin payload once and return both raw text and parsed JSON."""

    raw = sys.stdin.read()
    if not raw:
        return raw, None

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return raw, None

    return (raw, parsed) if isinstance(parsed, dict) else (raw, None)


def _default_response(event: str) -> int:
    """Emit the minimal pass-through JSON for hooks that we skip."""

    hook_event = "PreToolUse" if event == "pre" else "PostToolUse"
    response: HookPayload = {"hookSpecificOutput": {"hookEventName": hook_event}}

    if event == "pre":
        response["hookSpecificOutput"]["permissionDecision"] = "allow"

    sys.stdout.write(json.dumps(response))
    sys.stdout.write("\n")
    sys.stdout.flush()
    return 0


def _run_guard(script_name: str, payload: str) -> int:
    """Execute a sibling guard script sequentially and relay its output."""

    script_path = Path(__file__).with_name(script_name)
    if not script_path.exists():
        raise FileNotFoundError(f"Missing guard script: {script_path}")

    proc = subprocess.run(  # noqa: S603
        [sys.executable, str(script_path)],
        input=payload,
        text=True,
        capture_output=True,
        check=False,
    )

    if proc.stdout:
        sys.stdout.write(proc.stdout)
    if proc.stderr:
        sys.stderr.write(proc.stderr)

    sys.stdout.flush()
    sys.stderr.flush()
    return proc.returncode


def _log_bash_command(payload: HookPayload) -> None:
    """Append successful Bash commands to Claude's standard log file."""

    tool_input = payload.get("tool_input")
    if not isinstance(tool_input, dict):
        return

    command = tool_input.get("command")
    if not isinstance(command, str) or not command.strip():
        return

    description = tool_input.get("description")
    if not isinstance(description, str) or not description.strip():
        description = "No description"

    log_path = Path.home() / ".claude" / "bash-command-log.txt"
    try:
        log_path.parent.mkdir(parents=True, exist_ok=True)
        with log_path.open("a", encoding="utf-8") as handle:
            handle.write(f"{command} - {description}\n")
    except OSError:
        # Logging is best-effort; ignore filesystem errors.
        pass


def _pre_hook(payload: HookPayload | None, raw: str) -> int:
    """Handle PreToolUse events with sequential guard execution."""

    if payload is None:
        return _default_response("pre")

    tool_name = str(payload.get("tool_name", ""))

    if tool_name in {"Write", "Edit", "MultiEdit"}:
        return _run_guard("code_quality_guard.py", raw)
    if tool_name == "Bash":
        return _run_guard("bash_command_guard.py", raw)

    return _default_response("pre")


def _post_hook(payload: HookPayload | None, raw: str) -> int:
    """Handle PostToolUse events with sequential guard execution."""

    if payload is None:
        return _default_response("post")

    tool_name = str(payload.get("tool_name", ""))

    if tool_name in {"Write", "Edit", "MultiEdit"}:
        return _run_guard("code_quality_guard.py", raw)
    if tool_name == "Bash":
        exit_code = _run_guard("bash_command_guard.py", raw)
        if exit_code == 0:
            _log_bash_command(payload)
        return exit_code

    return _default_response("post")


def main() -> None:
    parser = argparse.ArgumentParser(description="Chain Claude Code hook guards")
    parser.add_argument(
        "--event",
        choices={"pre", "post"},
        required=True,
        help="Hook event type to handle.",
    )
    args = parser.parse_args()

    raw_payload, parsed_payload = _load_payload()

    if args.event == "pre":
        exit_code = _pre_hook(parsed_payload, raw_payload)
    else:
        exit_code = _post_hook(parsed_payload, raw_payload)

    sys.exit(exit_code)


if __name__ == "__main__":  # pragma: no cover - CLI entry
    main()
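A quick way to exercise the new runner locally is to pipe a hook payload on stdin, mirroring what Claude Code sends; the paths below assume the scripts live side by side in `hooks/`, as in the settings above:

```python
# Minimal local smoke test for hook_chain.py (paths are an assumption; adjust
# "hooks/" to wherever the guard scripts actually live in your checkout).
import json
import subprocess

payload = {
    "tool_name": "Bash",
    "tool_input": {"command": "echo hello", "description": "greeting"},
}

proc = subprocess.run(
    ["python3", "hooks/hook_chain.py", "--event", "pre"],
    input=json.dumps(payload),
    text=True,
    capture_output=True,
    check=False,
)
print(proc.returncode)  # exit code relayed from bash_command_guard.py
print(proc.stdout)      # the guard's JSON response, passed through unchanged
```

Because stdout and the exit code are relayed unchanged, the chained command behaves the same as invoking the matching guard directly.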
@@ -105,13 +105,11 @@ class EnhancedMessageFormatter:
|
||||
loc_type: str = loc.get("type", "code")
|
||||
location_summary.append(f" • {name} ({loc_type}, lines {lines})")
|
||||
|
||||
# Determine refactoring strategy
|
||||
# Convert locations to proper format for refactoring strategy
|
||||
dict_locations: list[dict[str, str]] = []
|
||||
for loc_item in locations:
|
||||
if isinstance(loc_item, dict):
|
||||
dict_locations.append(cast(dict[str, str], loc_item))
|
||||
|
||||
dict_locations: list[dict[str, str]] = [
|
||||
cast(dict[str, str], loc_item)
|
||||
for loc_item in locations
|
||||
if isinstance(loc_item, dict)
|
||||
]
|
||||
strategy = EnhancedMessageFormatter._suggest_refactoring_strategy(
|
||||
duplicate_type,
|
||||
dict_locations,
|
||||
@@ -124,35 +122,29 @@ class EnhancedMessageFormatter:
|
||||
f"🔍 Duplicate Code Detected ({similarity:.0%} similar)",
|
||||
"",
|
||||
"📍 Locations:",
|
||||
]
|
||||
parts.extend(location_summary)
|
||||
parts.extend(
|
||||
[
|
||||
*location_summary,
|
||||
*[
|
||||
"",
|
||||
f"📊 Pattern Type: {duplicate_type}",
|
||||
]
|
||||
)
|
||||
|
||||
],
|
||||
]
|
||||
if include_refactoring and strategy:
|
||||
parts.append("")
|
||||
parts.append("💡 Refactoring Suggestion:")
|
||||
parts.append(f" Strategy: {strategy.strategy_type}")
|
||||
parts.append(f" {strategy.description}")
|
||||
parts.append("")
|
||||
parts.append("✅ Benefits:")
|
||||
for benefit in strategy.benefits:
|
||||
parts.append(f" • {benefit}")
|
||||
|
||||
parts.extend(
|
||||
(
|
||||
"",
|
||||
"💡 Refactoring Suggestion:",
|
||||
f" Strategy: {strategy.strategy_type}",
|
||||
f" {strategy.description}",
|
||||
"",
|
||||
"✅ Benefits:",
|
||||
)
|
||||
)
|
||||
parts.extend(f" • {benefit}" for benefit in strategy.benefits)
|
||||
if strategy.example_before and strategy.example_after:
|
||||
parts.append("")
|
||||
parts.append("📝 Example:")
|
||||
parts.append(" Before:")
|
||||
for line in strategy.example_before.splitlines():
|
||||
parts.append(f" {line}")
|
||||
parts.extend(("", "📝 Example:", " Before:"))
|
||||
parts.extend(f" {line}" for line in strategy.example_before.splitlines())
|
||||
parts.append(" After:")
|
||||
for line in strategy.example_after.splitlines():
|
||||
parts.append(f" {line}")
|
||||
|
||||
parts.extend(f" {line}" for line in strategy.example_after.splitlines())
|
||||
return "\n".join(parts)
|
||||
|
||||
@staticmethod
|
||||
@@ -460,8 +452,7 @@ class EnhancedMessageFormatter:
|
||||
int(line_num),
|
||||
context_lines=2,
|
||||
)
|
||||
parts.append(context.code_snippet)
|
||||
parts.append("")
|
||||
parts.extend((context.code_snippet, ""))
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
@@ -643,17 +634,12 @@ class EnhancedMessageFormatter:
|
||||
]
|
||||
fixes_list = guidance["fixes"]
|
||||
if isinstance(fixes_list, list):
|
||||
for fix in fixes_list:
|
||||
parts.append(f" • {fix}")
|
||||
|
||||
parts.extend(f" • {fix}" for fix in fixes_list)
|
||||
if include_examples and guidance.get("example_before"):
|
||||
parts.append("")
|
||||
parts.append("💡 Example:")
|
||||
parts.append(" ❌ Before:")
|
||||
parts.extend(("", "💡 Example:", " ❌ Before:"))
|
||||
example_before_str = guidance.get("example_before", "")
|
||||
if isinstance(example_before_str, str):
|
||||
for line in example_before_str.splitlines():
|
||||
parts.append(f" {line}")
|
||||
parts.extend(f" {line}" for line in example_before_str.splitlines())
|
||||
parts.append(" ✅ After:")
|
||||
example_after_str = guidance.get("example_after", "")
|
||||
if isinstance(example_after_str, str):
|
||||
@@ -661,8 +647,7 @@ class EnhancedMessageFormatter:
|
||||
parts.append(f" {line}")
|
||||
|
||||
if code_snippet:
|
||||
parts.append("")
|
||||
parts.append("📍 Your Code:")
|
||||
parts.extend(("", "📍 Your Code:"))
|
||||
for line in code_snippet.splitlines()[:10]:
|
||||
parts.append(f" {line}")
|
||||
|
||||
|
||||
@@ -58,9 +58,11 @@ class TypeInferenceHelper:
|
||||
assignments: list[ast.expr] = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and target.id == variable_name:
|
||||
assignments.append(node.value)
|
||||
assignments.extend(
|
||||
node.value
|
||||
for target in node.targets
|
||||
if isinstance(target, ast.Name) and target.id == variable_name
|
||||
)
|
||||
elif (
|
||||
isinstance(node, ast.AnnAssign)
|
||||
and isinstance(node.target, ast.Name)
|
||||
@@ -207,12 +209,10 @@ class TypeInferenceHelper:
|
||||
|
||||
# Try to infer from usage within function
|
||||
arg_name = arg.arg
|
||||
suggested_type = TypeInferenceHelper._infer_param_from_usage(
|
||||
if suggested_type := TypeInferenceHelper._infer_param_from_usage(
|
||||
arg_name,
|
||||
function_node,
|
||||
)
|
||||
|
||||
if suggested_type:
|
||||
):
|
||||
suggestions.append(
|
||||
TypeSuggestion(
|
||||
element_name=arg_name,
|
||||
@@ -289,8 +289,6 @@ class TypeInferenceHelper:
|
||||
|
||||
Returns list of (old_import, new_import, reason) tuples.
|
||||
"""
|
||||
suggestions = []
|
||||
|
||||
# Patterns to detect and replace
|
||||
patterns = {
|
||||
r"from typing import.*\bUnion\b": (
|
||||
@@ -325,11 +323,11 @@ class TypeInferenceHelper:
|
||||
),
|
||||
}
|
||||
|
||||
for pattern, (old, new, example) in patterns.items():
|
||||
if re.search(pattern, source_code):
|
||||
suggestions.append((old, new, example))
|
||||
|
||||
return suggestions
|
||||
return [
|
||||
(old, new, example)
|
||||
for pattern, (old, new, example) in patterns.items()
|
||||
if re.search(pattern, source_code)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def find_any_usage_with_context(source_code: str) -> list[dict[str, str | int]]:
|
||||
@@ -367,20 +365,18 @@ class TypeInferenceHelper:
|
||||
|
||||
# Find function parameters with Any
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
for arg in node.args.args:
|
||||
if arg.annotation and TypeInferenceHelper._contains_any(
|
||||
arg.annotation
|
||||
):
|
||||
results.append(
|
||||
{
|
||||
"line": getattr(node, "lineno", 0),
|
||||
"element": arg.arg,
|
||||
"current": "Any",
|
||||
"suggested": "Infer from usage",
|
||||
"context": f"parameter in {node.name}",
|
||||
}
|
||||
)
|
||||
|
||||
results.extend(
|
||||
{
|
||||
"line": getattr(node, "lineno", 0),
|
||||
"element": arg.arg,
|
||||
"current": "Any",
|
||||
"suggested": "Infer from usage",
|
||||
"context": f"parameter in {node.name}",
|
||||
}
|
||||
for arg in node.args.args
|
||||
if arg.annotation
|
||||
and TypeInferenceHelper._contains_any(arg.annotation)
|
||||
)
|
||||
# Check return type
|
||||
if node.returns and TypeInferenceHelper._contains_any(node.returns):
|
||||
suggestion = TypeInferenceHelper.suggest_function_return_type(
|
||||
|
||||
logs/status_line.json (24461 lines changed; diff suppressed because it is too large)
@@ -225,9 +225,6 @@ class ModernizationAnalyzer(ast.NodeVisitor):
|
||||
|
||||
def visit_BinOp(self, node: ast.BinOp) -> None:
|
||||
"""Check for Union usage that could be modernized."""
|
||||
if isinstance(node.op, ast.BitOr):
|
||||
# This is already modern syntax (X | Y)
|
||||
pass
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
||||
@@ -273,26 +270,22 @@ class ModernizationAnalyzer(ast.NodeVisitor):
|
||||
"""Add issue for typing import that can be replaced with built-ins."""
|
||||
modern_replacement = self.REPLACEABLE_TO_MODERN[typing_name]
|
||||
|
||||
if typing_name in ["List", "Dict", "Tuple", "Set", "FrozenSet"]:
|
||||
if typing_name in {"List", "Dict", "Tuple", "Set", "FrozenSet"}:
|
||||
description = (
|
||||
f"Use built-in '{modern_replacement}' instead of "
|
||||
f"'typing.{typing_name}' (Python 3.9+)"
|
||||
)
|
||||
severity = "warning"
|
||||
elif typing_name == "Union":
|
||||
description = (
|
||||
"Use '|' union operator instead of 'typing.Union' (Python 3.10+)"
|
||||
)
|
||||
severity = "warning"
|
||||
elif typing_name == "Optional":
|
||||
description = "Use '| None' instead of 'typing.Optional' (Python 3.10+)"
|
||||
severity = "warning"
|
||||
else:
|
||||
description = (
|
||||
f"Use '{modern_replacement}' instead of 'typing.{typing_name}'"
|
||||
)
|
||||
severity = "warning"
|
||||
|
||||
severity = "warning"
|
||||
self.issues.append(
|
||||
ModernizationIssue(
|
||||
file_path=self.file_path,
|
||||
@@ -359,7 +352,7 @@ class ModernizationAnalyzer(ast.NodeVisitor):
|
||||
"""Add issue for typing usage that can be modernized."""
|
||||
if typing_name in self.REPLACEABLE_TYPING_IMPORTS:
|
||||
modern_replacement = self.REPLACEABLE_TO_MODERN[typing_name]
|
||||
if typing_name in ["List", "Dict", "Tuple", "Set", "FrozenSet"]:
|
||||
if typing_name in {"List", "Dict", "Tuple", "Set", "FrozenSet"}:
|
||||
old_pattern = f"{typing_name}[...]"
|
||||
new_pattern = f"{modern_replacement.lower()}[...]"
|
||||
description = (
|
||||
@@ -696,7 +689,7 @@ class PydanticAnalyzer:
|
||||
# Check if line contains any valid v2 methods
|
||||
return any(f".{v2_method}(" in line for v2_method in self.V2_METHODS)
|
||||
|
||||
def _get_suggested_fix(self, pattern: str, line: str) -> str: # noqa: ARG002
|
||||
def _get_suggested_fix(self, pattern: str, line: str) -> str: # noqa: ARG002
|
||||
"""Get suggested fix for a Pydantic pattern."""
|
||||
fixes = {
|
||||
r"\.dict\(\)": line.replace(".dict()", ".model_dump()"),
|
||||
@@ -706,11 +699,14 @@ class PydanticAnalyzer:
|
||||
r"@root_validator": line.replace("@root_validator", "@model_validator"),
|
||||
}
|
||||
|
||||
for fix_pattern, fix_line in fixes.items():
|
||||
if re.search(fix_pattern, line):
|
||||
return fix_line.strip()
|
||||
|
||||
return "See Pydantic v2 migration guide"
|
||||
return next(
|
||||
(
|
||||
fix_line.strip()
|
||||
for fix_pattern, fix_line in fixes.items()
|
||||
if re.search(fix_pattern, line)
|
||||
),
|
||||
"See Pydantic v2 migration guide",
|
||||
)
|
||||
|
||||
|
||||
class ModernizationEngine:
|
||||
@@ -753,8 +749,7 @@ class ModernizationEngine:
|
||||
if file_path.suffix.lower() == ".py":
|
||||
issues = self.analyze_file(file_path)
|
||||
|
||||
# Apply exception filtering
|
||||
filtered_issues = self.exception_filter.filter_issues(
|
||||
if filtered_issues := self.exception_filter.filter_issues(
|
||||
"modernization",
|
||||
issues,
|
||||
get_file_path_fn=lambda issue: issue.file_path,
|
||||
@@ -764,9 +759,7 @@ class ModernizationEngine:
|
||||
issue.file_path,
|
||||
issue.line_number,
|
||||
),
|
||||
)
|
||||
|
||||
if filtered_issues: # Only include files with remaining issues
|
||||
):
|
||||
results[file_path] = filtered_issues
|
||||
|
||||
return results
|
||||
@@ -800,16 +793,16 @@ class ModernizationEngine:
|
||||
by_type.setdefault(issue.issue_type, []).append(issue)
|
||||
by_severity[issue.severity] += 1
|
||||
|
||||
# Top files with most issues
|
||||
file_counts = {}
|
||||
for file_path, issues in results.items():
|
||||
if issues:
|
||||
file_counts[file_path] = len(issues)
|
||||
|
||||
file_counts = {
|
||||
file_path: len(issues)
|
||||
for file_path, issues in results.items()
|
||||
if issues
|
||||
}
|
||||
top_files = sorted(file_counts.items(), key=lambda x: x[1], reverse=True)[:10]
|
||||
|
||||
# Auto-fixable issues
|
||||
auto_fixable = sum(1 for issue in all_issues if issue.can_auto_fix)
|
||||
auto_fixable = sum(bool(issue.can_auto_fix)
|
||||
for issue in all_issues)
|
||||
|
||||
return {
|
||||
"total_files_analyzed": len(results),
|
||||
|
||||
@@ -144,15 +144,15 @@ def duplicates(
|
||||
results["duplicates"].append({"group_id": i, "analysis": detailed_analysis})
|
||||
|
||||
# Output results
|
||||
if output_format == "json":
|
||||
if output_format == "console":
|
||||
_print_console_duplicates(results, verbose)
|
||||
elif output_format == "csv":
|
||||
_print_csv_duplicates(results, output)
|
||||
elif output_format == "json":
|
||||
if output:
|
||||
json.dump(results, output, indent=2, default=str)
|
||||
else:
|
||||
click.echo(json.dumps(results, indent=2, default=str))
|
||||
elif output_format == "console":
|
||||
_print_console_duplicates(results, verbose)
|
||||
elif output_format == "csv":
|
||||
_print_csv_duplicates(results, output)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -211,13 +211,13 @@ def complexity(
|
||||
overview = analyzer.get_project_complexity_overview(all_files)
|
||||
|
||||
# Output results
|
||||
if output_format == "json":
|
||||
if output_format == "console":
|
||||
_print_console_complexity(overview, verbose)
|
||||
elif output_format == "json":
|
||||
if output:
|
||||
json.dump(overview, output, indent=2, default=str)
|
||||
else:
|
||||
click.echo(json.dumps(overview, indent=2, default=str))
|
||||
elif output_format == "console":
|
||||
_print_console_complexity(overview, verbose)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -290,10 +290,11 @@ def modernization(
|
||||
if pydantic_only:
|
||||
filtered_results = {}
|
||||
for file_path, issues in results.items():
|
||||
pydantic_issues = [
|
||||
issue for issue in issues if issue.issue_type == "pydantic_v1_pattern"
|
||||
]
|
||||
if pydantic_issues:
|
||||
if pydantic_issues := [
|
||||
issue
|
||||
for issue in issues
|
||||
if issue.issue_type == "pydantic_v1_pattern"
|
||||
]:
|
||||
filtered_results[file_path] = pydantic_issues
|
||||
results = filtered_results
|
||||
|
||||
@@ -310,13 +311,13 @@ def modernization(
|
||||
},
|
||||
}
|
||||
|
||||
if output_format == "json":
|
||||
if output_format == "console":
|
||||
_print_console_modernization(final_results, verbose, include_type_hints)
|
||||
elif output_format == "json":
|
||||
if output:
|
||||
json.dump(final_results, output, indent=2, default=str)
|
||||
else:
|
||||
click.echo(json.dumps(final_results, indent=2, default=str))
|
||||
elif output_format == "console":
|
||||
_print_console_modernization(final_results, verbose, include_type_hints)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -418,12 +419,11 @@ def full_analysis(
|
||||
# Parse the AST and analyze
|
||||
tree = ast.parse(content)
|
||||
ast_analyzer.visit(tree)
|
||||
smells = ast_analyzer.detect_code_smells()
|
||||
if smells:
|
||||
if smells := ast_analyzer.detect_code_smells():
|
||||
all_smells.extend(
|
||||
[{"file": str(file_path), "smell": smell} for smell in smells],
|
||||
)
|
||||
except (OSError, PermissionError, UnicodeDecodeError):
|
||||
except (OSError, UnicodeDecodeError):
|
||||
continue
|
||||
|
||||
results["code_smells"] = {"total_smells": len(all_smells), "details": all_smells}
|
||||
@@ -432,13 +432,13 @@ def full_analysis(
|
||||
results["quality_score"] = _calculate_overall_quality_score(results)
|
||||
|
||||
# Output results
|
||||
if output_format == "json":
|
||||
if output_format == "console":
|
||||
_print_console_full_analysis(results, verbose)
|
||||
elif output_format == "json":
|
||||
if output:
|
||||
json.dump(results, output, indent=2, default=str)
|
||||
else:
|
||||
click.echo(json.dumps(results, indent=2, default=str))
|
||||
elif output_format == "console":
|
||||
_print_console_full_analysis(results, verbose)
|
||||
|
||||
|
||||
def _print_console_duplicates(results: dict[str, Any], verbose: bool) -> None:
|
||||
@@ -478,7 +478,7 @@ def _print_console_duplicates(results: dict[str, Any], verbose: bool) -> None:
|
||||
|
||||
def _print_csv_duplicates(results: dict[str, Any], output: IO[str] | None) -> None:
|
||||
"""Print duplicate results in CSV format."""
|
||||
csv_output = output if output else sys.stdout
|
||||
csv_output = output or sys.stdout
|
||||
writer = csv.writer(csv_output)
|
||||
writer.writerow(
|
||||
[
|
||||
|
||||
@@ -43,12 +43,10 @@ class ComplexityAnalyzer:
|
||||
"""Analyze multiple files in parallel."""
|
||||
raw_results = self.radon_analyzer.batch_analyze_files(file_paths, max_workers)
|
||||
|
||||
# Filter metrics based on configuration
|
||||
filtered_results = {}
|
||||
for path, metrics in raw_results.items():
|
||||
filtered_results[path] = self._filter_metrics_by_config(metrics)
|
||||
|
||||
return filtered_results
|
||||
return {
|
||||
path: self._filter_metrics_by_config(metrics)
|
||||
for path, metrics in raw_results.items()
|
||||
}
|
||||
|
||||
def get_complexity_summary(self, metrics: ComplexityMetrics) -> dict[str, Any]:
|
||||
"""Get a human-readable summary of complexity metrics."""
|
||||
|
||||
@@ -36,7 +36,11 @@ class ComplexityCalculator:
|
||||
|
||||
# AST-based metrics
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
if (
|
||||
isinstance(node, ast.FunctionDef)
|
||||
or not isinstance(node, ast.ClassDef)
|
||||
and isinstance(node, ast.AsyncFunctionDef)
|
||||
):
|
||||
metrics.function_count += 1
|
||||
# Count parameters
|
||||
metrics.parameters_count += len(node.args.args)
|
||||
@@ -46,13 +50,6 @@ class ComplexityCalculator:
|
||||
)
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
metrics.class_count += 1
|
||||
elif isinstance(node, ast.AsyncFunctionDef):
|
||||
metrics.function_count += 1
|
||||
metrics.parameters_count += len(node.args.args)
|
||||
metrics.returns_count += len(
|
||||
[n for n in ast.walk(node) if isinstance(n, ast.Return)],
|
||||
)
|
||||
|
||||
# Calculate cyclomatic complexity
|
||||
metrics.cyclomatic_complexity = self._calculate_cyclomatic_complexity(tree)
|
||||
|
||||
@@ -308,10 +305,8 @@ class ComplexityCalculator:
|
||||
|
||||
def _count_logical_lines(self, tree: ast.AST) -> int:
|
||||
"""Count logical lines of code (AST nodes that represent statements)."""
|
||||
count = 0
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(
|
||||
return sum(
|
||||
isinstance(
|
||||
node,
|
||||
ast.Assign
|
||||
| ast.AugAssign
|
||||
@@ -327,22 +322,21 @@ class ComplexityCalculator:
|
||||
| ast.Global
|
||||
| ast.Nonlocal
|
||||
| ast.Assert,
|
||||
):
|
||||
count += 1
|
||||
|
||||
return count
|
||||
)
|
||||
for node in ast.walk(tree)
|
||||
)
|
||||
|
||||
def _count_variables(self, tree: ast.AST) -> int:
|
||||
"""Count unique variable names."""
|
||||
variables = set()
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Name) and isinstance(
|
||||
variables = {
|
||||
node.id
|
||||
for node in ast.walk(tree)
|
||||
if isinstance(node, ast.Name)
|
||||
and isinstance(
|
||||
node.ctx,
|
||||
(ast.Store, ast.Del),
|
||||
):
|
||||
variables.add(node.id)
|
||||
|
||||
)
|
||||
}
|
||||
return len(variables)
|
||||
|
||||
def _count_methods(self, tree: ast.AST) -> int:
|
||||
|
||||
@@ -115,9 +115,7 @@ class ComplexityMetrics:
|
||||
return "Moderate"
|
||||
if score < 60:
|
||||
return "High"
|
||||
if score < 80:
|
||||
return "Very High"
|
||||
return "Extreme"
|
||||
return "Very High" if score < 80 else "Extreme"
|
||||
|
||||
def get_priority_score(self) -> float:
|
||||
"""Get priority score for refactoring (0-1, higher means higher priority)."""
|
||||
|
||||
@@ -39,11 +39,11 @@ class RadonComplexityAnalyzer:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
code = f.read()
|
||||
return self.analyze_code(code, str(file_path))
|
||||
except (OSError, PermissionError, UnicodeDecodeError):
|
||||
except (OSError, UnicodeDecodeError):
|
||||
# Return empty metrics for unreadable files
|
||||
return ComplexityMetrics()
|
||||
|
||||
def _analyze_with_radon(self, code: str, filename: str) -> ComplexityMetrics: # noqa: ARG002
|
||||
def _analyze_with_radon(self, code: str, filename: str) -> ComplexityMetrics: # noqa: ARG002
|
||||
"""Analyze code using Radon library."""
|
||||
metrics = ComplexityMetrics()
|
||||
|
||||
@@ -61,9 +61,7 @@ class RadonComplexityAnalyzer:
|
||||
|
||||
import radon.complexity
|
||||
|
||||
# Cyclomatic complexity
|
||||
cc_results = radon.complexity.cc_visit(code)
|
||||
if cc_results:
|
||||
if cc_results := radon.complexity.cc_visit(code):
|
||||
# Calculate average complexity from all functions/methods
|
||||
total_complexity = sum(
|
||||
getattr(block, "complexity", 0) for block in cc_results
|
||||
@@ -149,14 +147,15 @@ class RadonComplexityAnalyzer:
|
||||
[n for n in ast.walk(node) if isinstance(n, ast.Return)],
|
||||
)
|
||||
|
||||
# Count variables
|
||||
variables = set()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Name) and isinstance(
|
||||
variables = {
|
||||
node.id
|
||||
for node in ast.walk(tree)
|
||||
if isinstance(node, ast.Name)
|
||||
and isinstance(
|
||||
node.ctx,
|
||||
ast.Store | ast.Del,
|
||||
):
|
||||
variables.add(node.id)
|
||||
)
|
||||
}
|
||||
metrics.variables_count = len(variables)
|
||||
|
||||
except SyntaxError:
|
||||
@@ -266,24 +265,10 @@ class RadonComplexityAnalyzer:
|
||||
return "B" # Moderate
|
||||
if complexity_score <= 20:
|
||||
return "C" # High
|
||||
if complexity_score <= 30:
|
||||
return "D" # Very High
|
||||
return "F" # Extreme
|
||||
return "D" if complexity_score <= 30 else "F"
|
||||
import radon.complexity
|
||||
|
||||
if RADON_AVAILABLE:
|
||||
import radon.complexity
|
||||
|
||||
return str(radon.complexity.cc_rank(complexity_score))
|
||||
# Fallback if radon not available
|
||||
if complexity_score <= 5:
|
||||
return "A"
|
||||
if complexity_score <= 10:
|
||||
return "B"
|
||||
if complexity_score <= 20:
|
||||
return "C"
|
||||
if complexity_score <= 30:
|
||||
return "D"
|
||||
return "F"
|
||||
return str(radon.complexity.cc_rank(complexity_score))
|
||||
|
||||
def batch_analyze_files(
|
||||
self,
|
||||
@@ -310,7 +295,7 @@ class RadonComplexityAnalyzer:
|
||||
path = future_to_path[future]
|
||||
try:
|
||||
results[path] = future.result()
|
||||
except (OSError, PermissionError, UnicodeDecodeError):
|
||||
except (OSError, UnicodeDecodeError):
|
||||
# Create empty metrics for failed files
|
||||
results[path] = ComplexityMetrics()
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ class ASTAnalyzer(ast.NodeVisitor):
|
||||
return []
|
||||
|
||||
# Reset analyzer state
|
||||
self.file_path = str(file_path)
|
||||
self.file_path = file_path
|
||||
self.content = content
|
||||
self.content_lines = content.splitlines()
|
||||
self.functions = []
|
||||
@@ -49,13 +49,11 @@ class ASTAnalyzer(ast.NodeVisitor):
|
||||
else:
|
||||
self.visit(tree)
|
||||
|
||||
# Filter blocks by minimum size
|
||||
filtered_blocks = []
|
||||
for block in self.code_blocks:
|
||||
if (block.end_line - block.start_line + 1) >= min_lines:
|
||||
filtered_blocks.append(block)
|
||||
|
||||
return filtered_blocks
|
||||
return [
|
||||
block
|
||||
for block in self.code_blocks
|
||||
if (block.end_line - block.start_line + 1) >= min_lines
|
||||
]
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
||||
"""Visit function definitions with complexity analysis."""
|
||||
@@ -264,20 +262,17 @@ class ASTAnalyzer(ast.NodeVisitor):
|
||||
"""Detect common code smells."""
|
||||
smells = []
|
||||
|
||||
# Long methods
|
||||
long_methods = [f for f in self.functions if f.lines_count > 30]
|
||||
if long_methods:
|
||||
if long_methods := [f for f in self.functions if f.lines_count > 30]:
|
||||
smells.append(
|
||||
f"Long methods detected: {len(long_methods)} methods > 30 lines",
|
||||
)
|
||||
|
||||
# Complex methods
|
||||
complex_methods = [
|
||||
if complex_methods := [
|
||||
f
|
||||
for f in self.functions
|
||||
if f.complexity_metrics and f.complexity_metrics.cyclomatic_complexity > 10
|
||||
]
|
||||
if complex_methods:
|
||||
if f.complexity_metrics
|
||||
and f.complexity_metrics.cyclomatic_complexity > 10
|
||||
]:
|
||||
smells.append(
|
||||
f"Complex methods detected: {len(complex_methods)} methods "
|
||||
"with complexity > 10",
|
||||
@@ -287,12 +282,12 @@ class ASTAnalyzer(ast.NodeVisitor):
|
||||
for func in self.functions:
|
||||
try:
|
||||
tree = ast.parse(func.content)
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.FunctionDef) and len(node.args.args) > 5:
|
||||
smells.append(
|
||||
f"Method with many parameters: {func.function_name} "
|
||||
f"({len(node.args.args)} parameters)",
|
||||
)
|
||||
smells.extend(
|
||||
f"Method with many parameters: {func.function_name} ({len(node.args.args)} parameters)"
|
||||
for node in ast.walk(tree)
|
||||
if isinstance(node, ast.FunctionDef)
|
||||
and len(node.args.args) > 5
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logging.debug("Failed to analyze code smell for %s", self.file_path)
|
||||
|
||||
|
||||
@@ -73,22 +73,24 @@ class ExceptionFilter:
|
||||
if self._is_globally_excluded(file_path):
|
||||
return True, "File/directory globally excluded"
|
||||
|
||||
# Check exception rules
|
||||
for rule in self.active_rules:
|
||||
if self._rule_matches(
|
||||
rule,
|
||||
analysis_type,
|
||||
issue_type,
|
||||
file_path,
|
||||
line_number,
|
||||
line_content,
|
||||
):
|
||||
return (
|
||||
return next(
|
||||
(
|
||||
(
|
||||
True,
|
||||
rule.reason or f"Matched exception rule: {rule.analysis_type}",
|
||||
)
|
||||
|
||||
return False, None
|
||||
for rule in self.active_rules
|
||||
if self._rule_matches(
|
||||
rule,
|
||||
analysis_type,
|
||||
issue_type,
|
||||
file_path,
|
||||
line_number,
|
||||
line_content,
|
||||
)
|
||||
),
|
||||
(False, None),
|
||||
)
|
||||
|
||||
def _is_globally_excluded(self, file_path: str) -> bool:
|
||||
"""Check if file is globally excluded."""
|
||||
@@ -135,24 +137,22 @@ class ExceptionFilter:
|
||||
|
||||
# Check file patterns
|
||||
if rule.file_patterns:
|
||||
file_matches = False
|
||||
for pattern in rule.file_patterns:
|
||||
if fnmatch.fnmatch(file_path, pattern) or fnmatch.fnmatch(
|
||||
file_matches = any(
|
||||
fnmatch.fnmatch(file_path, pattern)
|
||||
or fnmatch.fnmatch(
|
||||
str(Path(file_path).name),
|
||||
pattern,
|
||||
):
|
||||
file_matches = True
|
||||
break
|
||||
)
|
||||
for pattern in rule.file_patterns
|
||||
)
|
||||
if not file_matches:
|
||||
return False
|
||||
|
||||
# Check line patterns
|
||||
if rule.line_patterns and line_content:
|
||||
line_matches = False
|
||||
for pattern in rule.line_patterns:
|
||||
if re.search(pattern, line_content):
|
||||
line_matches = True
|
||||
break
|
||||
line_matches = any(
|
||||
re.search(pattern, line_content) for pattern in rule.line_patterns
|
||||
)
|
||||
if not line_matches:
|
||||
return False
|
||||
|
||||
|
||||
@@ -259,7 +259,7 @@ class DuplicateDetectionEngine:
|
||||
complexity = self.complexity_analyzer.analyze_code(block.content)
|
||||
complexities.append(complexity.get_overall_score())
|
||||
|
||||
max_complexity = max(complexities) if complexities else 0.0
|
||||
max_complexity = max(complexities, default=0.0)
|
||||
|
||||
match = DuplicateMatch(
|
||||
blocks=group,
|
||||
@@ -365,21 +365,26 @@ class DuplicateDetectionEngine:
|
||||
has_class = any(isinstance(node, ast.ClassDef) for node in ast.walk(tree))
|
||||
|
||||
if has_function:
|
||||
suggestions.append(
|
||||
"Extract common function into a shared utility module",
|
||||
)
|
||||
suggestions.append(
|
||||
"Consider creating a base function with configurable parameters",
|
||||
suggestions.extend(
|
||||
(
|
||||
"Extract common function into a shared utility module",
|
||||
"Consider creating a base function with configurable parameters",
|
||||
)
|
||||
)
|
||||
elif has_class:
|
||||
suggestions.append("Extract common class into a base class or mixin")
|
||||
suggestions.append("Consider using inheritance or composition patterns")
|
||||
else:
|
||||
suggestions.append("Extract duplicate code into a reusable function")
|
||||
suggestions.append(
|
||||
"Consider creating a utility module for shared logic",
|
||||
suggestions.extend(
|
||||
(
|
||||
"Extract common class into a base class or mixin",
|
||||
"Consider using inheritance or composition patterns",
|
||||
)
|
||||
)
|
||||
else:
|
||||
suggestions.extend(
|
||||
(
|
||||
"Extract duplicate code into a reusable function",
|
||||
"Consider creating a utility module for shared logic",
|
||||
)
|
||||
)
|
||||
|
||||
# Complexity-based suggestions
|
||||
if duplicate_match.complexity_score > 60:
|
||||
suggestions.append(
|
||||
@@ -413,9 +418,7 @@ class DuplicateDetectionEngine:
|
||||
return "Low (1-2 hours)"
|
||||
if total_lines < 100:
|
||||
return "Medium (0.5-1 day)"
|
||||
if total_lines < 500:
|
||||
return "High (1-3 days)"
|
||||
return "Very High (1+ weeks)"
|
||||
return "High (1-3 days)" if total_lines < 500 else "Very High (1+ weeks)"
|
||||
|
||||
def _assess_refactoring_risk(self, duplicate_match: DuplicateMatch) -> str:
|
||||
"""Assess risk level of refactoring."""
|
||||
@@ -437,9 +440,7 @@ class DuplicateDetectionEngine:
|
||||
|
||||
if not risk_factors:
|
||||
return "Low"
|
||||
if len(risk_factors) <= 2:
|
||||
return "Medium"
|
||||
return "High"
|
||||
return "Medium" if len(risk_factors) <= 2 else "High"
|
||||
|
||||
def _get_content_preview(self, content: str, max_lines: int = 5) -> str:
|
||||
"""Get a preview of code content."""
|
||||
|
||||
@@ -161,19 +161,18 @@ class DuplicateMatcher:
|
||||
if len(match.blocks) < 2:
|
||||
return {"confidence": 0.0, "factors": []}
|
||||
|
||||
confidence_factors = []
|
||||
total_confidence = 0.0
|
||||
|
||||
# Similarity-based confidence
|
||||
similarity_confidence = match.similarity_score
|
||||
confidence_factors.append(
|
||||
confidence_factors = [
|
||||
{
|
||||
"factor": "similarity_score",
|
||||
"value": match.similarity_score,
|
||||
"weight": 0.4,
|
||||
"contribution": similarity_confidence * 0.4,
|
||||
},
|
||||
)
|
||||
}
|
||||
]
|
||||
total_confidence += similarity_confidence * 0.4
|
||||
|
||||
# Length-based confidence (longer matches are more reliable)
|
||||
@@ -293,9 +292,7 @@ class DuplicateMatcher:
|
||||
f"Merged cluster with {len(unique_blocks)} blocks "
|
||||
f"(avg similarity: {avg_score:.3f})"
|
||||
),
|
||||
complexity_score=max(complexity_scores)
|
||||
if complexity_scores
|
||||
else 0.0,
|
||||
complexity_score=max(complexity_scores, default=0.0),
|
||||
priority_score=avg_score,
|
||||
)
|
||||
merged_matches.append(merged_match)
|
||||
@@ -308,6 +305,4 @@ class DuplicateMatcher:
|
||||
return "High"
|
||||
if confidence >= 0.6:
|
||||
return "Medium"
|
||||
if confidence >= 0.4:
|
||||
return "Low"
|
||||
return "Very Low"
|
||||
return "Low" if confidence >= 0.4 else "Very Low"
|
||||
|
||||
@@ -182,7 +182,7 @@ class LSHDuplicateDetector:
|
||||
self.code_blocks: dict[str, CodeBlock] = {}
|
||||
|
||||
if LSH_AVAILABLE:
|
||||
self.lsh_index = MinHashLSH(threshold=float(threshold), num_perm=int(num_perm))
|
||||
self.lsh_index = MinHashLSH(threshold=threshold, num_perm=num_perm)
|
||||
|
||||
def add_code_block(self, block: CodeBlock) -> None:
|
||||
"""Add a code block to the LSH index."""
|
||||
@@ -220,8 +220,7 @@ class LSHDuplicateDetector:
|
||||
if candidate_id == block_id:
|
||||
continue
|
||||
|
||||
candidate_block = self.code_blocks.get(candidate_id)
|
||||
if candidate_block:
|
||||
if candidate_block := self.code_blocks.get(candidate_id):
|
||||
# Calculate exact similarity
|
||||
similarity = query_minhash.jaccard(self.minhashes[candidate_id])
|
||||
if similarity >= self.threshold:
|
||||
@@ -243,13 +242,9 @@ class LSHDuplicateDetector:
|
||||
if block_id in processed:
|
||||
continue
|
||||
|
||||
similar_blocks = self.find_similar_blocks(block)
|
||||
|
||||
if similar_blocks:
|
||||
if similar_blocks := self.find_similar_blocks(block):
|
||||
# Create group with original block and similar blocks
|
||||
group = [block]
|
||||
group.extend([similar_block for similar_block, _ in similar_blocks])
|
||||
|
||||
group = [block, *[similar_block for similar_block, _ in similar_blocks]]
|
||||
# Mark all blocks in group as processed
|
||||
processed.add(block_id)
|
||||
for similar_block, _ in similar_blocks:
|
||||
@@ -364,7 +359,7 @@ class BandingLSH:
|
||||
if len(sig1) != len(sig2):
|
||||
return 0.0
|
||||
|
||||
matches = sum(1 for a, b in zip(sig1, sig2, strict=False) if a == b)
|
||||
matches = sum(a == b for a, b in zip(sig1, sig2, strict=False))
|
||||
return matches / len(sig1)
|
||||
|
||||
def get_statistics(self) -> dict[str, object]:
|
||||
|
||||
@@ -93,9 +93,8 @@ class StructuralSimilarity(BaseSimilarityAlgorithm):
|
||||
|
||||
# Count methods in class
|
||||
method_count = sum(
|
||||
1
|
||||
isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef))
|
||||
for child in node.body
|
||||
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef))
|
||||
)
|
||||
structure.append(f"{depth_prefix}class_methods:{method_count}")
|
||||
|
||||
|
||||
@@ -251,7 +251,7 @@ class TFIDFSimilarity(BaseSimilarityAlgorithm):
|
||||
total_docs = len(documents)
|
||||
|
||||
for term in terms:
|
||||
docs_containing_term = sum(1 for doc in documents if term in doc)
|
||||
docs_containing_term = sum(term in doc for doc in documents)
|
||||
idf[term] = math.log(
|
||||
total_docs / (docs_containing_term + 1),
|
||||
) # +1 for smoothing
|
||||
|
||||
@@ -58,12 +58,12 @@ class FileFinder:
|
||||
if root_path.is_file():
|
||||
return [root_path] if self._is_python_file(root_path) else []
|
||||
|
||||
found_files = []
|
||||
for file_path in root_path.rglob("*.py"):
|
||||
if self._should_include_file(file_path) and self._is_python_file(file_path):
|
||||
found_files.append(file_path)
|
||||
|
||||
return found_files
|
||||
return [
|
||||
file_path
|
||||
for file_path in root_path.rglob("*.py")
|
||||
if self._should_include_file(file_path)
|
||||
and self._is_python_file(file_path)
|
||||
]
|
||||
|
||||
def _should_include_file(self, file_path: Path) -> bool:
|
||||
"""Check if a file should be included in analysis."""
|
||||
@@ -77,29 +77,30 @@ class FileFinder:
|
||||
):
|
||||
return False
|
||||
|
||||
# Check include patterns
|
||||
for pattern in self.path_config.include_patterns:
|
||||
if fnmatch.fnmatch(path_str, pattern) or fnmatch.fnmatch(
|
||||
file_path.name,
|
||||
pattern,
|
||||
):
|
||||
# Check if it's a supported file type
|
||||
return self._has_supported_extension(file_path)
|
||||
|
||||
return False
|
||||
return next(
|
||||
(
|
||||
self._has_supported_extension(file_path)
|
||||
for pattern in self.path_config.include_patterns
|
||||
if fnmatch.fnmatch(path_str, pattern)
|
||||
or fnmatch.fnmatch(
|
||||
file_path.name,
|
||||
pattern,
|
||||
)
|
||||
),
|
||||
False,
|
||||
)
|
||||
|
||||
def _has_supported_extension(self, file_path: Path) -> bool:
|
||||
"""Check if file has a supported extension."""
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
for lang in self.language_config.languages:
|
||||
if (
|
||||
return any(
|
||||
(
|
||||
lang in self.language_config.file_extensions
|
||||
and suffix in self.language_config.file_extensions[lang]
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
)
|
||||
for lang in self.language_config.languages
|
||||
)
|
||||
|
||||
def _is_python_file(self, file_path: Path) -> bool:
|
||||
"""Check if file is a Python file."""
|
||||
@@ -109,11 +110,14 @@ class FileFinder:
|
||||
"""Determine the programming language of a file."""
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
for lang, extensions in self.language_config.file_extensions.items():
|
||||
if suffix in extensions:
|
||||
return lang
|
||||
|
||||
return None
|
||||
return next(
|
||||
(
|
||||
lang
|
||||
for lang, extensions in self.language_config.file_extensions.items()
|
||||
if suffix in extensions
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
def get_project_stats(self, root_path: Path) -> dict[str, Any]:
|
||||
"""Get statistics about files in the project."""
|
||||
@@ -173,15 +177,14 @@ class FileFinder:
|
||||
|
||||
# Apply include patterns
|
||||
if include and include_patterns:
|
||||
include = False
|
||||
for pattern in include_patterns:
|
||||
if fnmatch.fnmatch(path_str, pattern) or fnmatch.fnmatch(
|
||||
include = any(
|
||||
fnmatch.fnmatch(path_str, pattern)
|
||||
or fnmatch.fnmatch(
|
||||
file_path.name,
|
||||
pattern,
|
||||
):
|
||||
include = True
|
||||
break
|
||||
|
||||
)
|
||||
for pattern in include_patterns
|
||||
)
|
||||
if include:
|
||||
filtered.append(file_path)
|
||||
|
||||
|
||||
@@ -218,49 +218,43 @@ def reset_environment():
|
||||
|
||||
# Restore original environment
|
||||
os.environ.clear()
|
||||
os.environ.update(original_env)
|
||||
os.environ |= original_env
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_env_strict():
|
||||
"""Set environment for strict mode."""
|
||||
os.environ.update(
|
||||
{
|
||||
"QUALITY_ENFORCEMENT": "strict",
|
||||
"QUALITY_DUP_THRESHOLD": "0.7",
|
||||
"QUALITY_COMPLEXITY_THRESHOLD": "10",
|
||||
"QUALITY_DUP_ENABLED": "true",
|
||||
"QUALITY_COMPLEXITY_ENABLED": "true",
|
||||
"QUALITY_MODERN_ENABLED": "true",
|
||||
"QUALITY_REQUIRE_TYPES": "true",
|
||||
},
|
||||
)
|
||||
os.environ |= {
|
||||
"QUALITY_ENFORCEMENT": "strict",
|
||||
"QUALITY_DUP_THRESHOLD": "0.7",
|
||||
"QUALITY_COMPLEXITY_THRESHOLD": "10",
|
||||
"QUALITY_DUP_ENABLED": "true",
|
||||
"QUALITY_COMPLEXITY_ENABLED": "true",
|
||||
"QUALITY_MODERN_ENABLED": "true",
|
||||
"QUALITY_REQUIRE_TYPES": "true",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_env_permissive():
|
||||
"""Set environment for permissive mode."""
|
||||
os.environ.update(
|
||||
{
|
||||
"QUALITY_ENFORCEMENT": "permissive",
|
||||
"QUALITY_DUP_THRESHOLD": "0.9",
|
||||
"QUALITY_COMPLEXITY_THRESHOLD": "20",
|
||||
"QUALITY_DUP_ENABLED": "true",
|
||||
"QUALITY_COMPLEXITY_ENABLED": "true",
|
||||
"QUALITY_MODERN_ENABLED": "false",
|
||||
"QUALITY_REQUIRE_TYPES": "false",
|
||||
},
|
||||
)
|
||||
os.environ |= {
|
||||
"QUALITY_ENFORCEMENT": "permissive",
|
||||
"QUALITY_DUP_THRESHOLD": "0.9",
|
||||
"QUALITY_COMPLEXITY_THRESHOLD": "20",
|
||||
"QUALITY_DUP_ENABLED": "true",
|
||||
"QUALITY_COMPLEXITY_ENABLED": "true",
|
||||
"QUALITY_MODERN_ENABLED": "false",
|
||||
"QUALITY_REQUIRE_TYPES": "false",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_env_posttooluse():
|
||||
"""Set environment for PostToolUse features."""
|
||||
os.environ.update(
|
||||
{
|
||||
"QUALITY_STATE_TRACKING": "true",
|
||||
"QUALITY_CROSS_FILE_CHECK": "true",
|
||||
"QUALITY_VERIFY_NAMING": "true",
|
||||
"QUALITY_SHOW_SUCCESS": "true",
|
||||
},
|
||||
)
|
||||
os.environ |= {
|
||||
"QUALITY_STATE_TRACKING": "true",
|
||||
"QUALITY_CROSS_FILE_CHECK": "true",
|
||||
"QUALITY_VERIFY_NAMING": "true",
|
||||
"QUALITY_SHOW_SUCCESS": "true",
|
||||
}
|
||||
|
||||
@@ -561,7 +561,7 @@ class TestTempFileManagement:
|
||||
# .tmp directory should exist but temp file should be gone
|
||||
if tmp_dir.exists():
|
||||
temp_files = list(tmp_dir.glob("hook_validation_*"))
|
||||
assert len(temp_files) == 0
|
||||
assert not temp_files
|
||||
finally:
|
||||
import shutil
|
||||
if root.exists():
|
||||
|
||||
@@ -48,21 +48,19 @@ class TestQualityConfig:
|
||||
|
||||
def test_from_env_with_custom_values(self):
|
||||
"""Test loading config from environment with custom values."""
|
||||
os.environ.update(
|
||||
{
|
||||
"QUALITY_DUP_THRESHOLD": "0.8",
|
||||
"QUALITY_DUP_ENABLED": "false",
|
||||
"QUALITY_COMPLEXITY_THRESHOLD": "15",
|
||||
"QUALITY_COMPLEXITY_ENABLED": "false",
|
||||
"QUALITY_MODERN_ENABLED": "false",
|
||||
"QUALITY_REQUIRE_TYPES": "false",
|
||||
"QUALITY_ENFORCEMENT": "permissive",
|
||||
"QUALITY_STATE_TRACKING": "true",
|
||||
"QUALITY_CROSS_FILE_CHECK": "true",
|
||||
"QUALITY_VERIFY_NAMING": "false",
|
||||
"QUALITY_SHOW_SUCCESS": "true",
|
||||
},
|
||||
)
|
||||
os.environ |= {
|
||||
"QUALITY_DUP_THRESHOLD": "0.8",
|
||||
"QUALITY_DUP_ENABLED": "false",
|
||||
"QUALITY_COMPLEXITY_THRESHOLD": "15",
|
||||
"QUALITY_COMPLEXITY_ENABLED": "false",
|
||||
"QUALITY_MODERN_ENABLED": "false",
|
||||
"QUALITY_REQUIRE_TYPES": "false",
|
||||
"QUALITY_ENFORCEMENT": "permissive",
|
||||
"QUALITY_STATE_TRACKING": "true",
|
||||
"QUALITY_CROSS_FILE_CHECK": "true",
|
||||
"QUALITY_VERIFY_NAMING": "false",
|
||||
"QUALITY_SHOW_SUCCESS": "true",
|
||||
}
|
||||
|
||||
config = guard.QualityConfig.from_env()
|
||||
|
||||
|
||||
@@ -23,10 +23,10 @@ class TestEdgeCases:
|
||||
def test_massive_file_content(self):
|
||||
"""Test handling of very large files."""
|
||||
config = QualityConfig()
|
||||
# Create a file with 10,000 lines
|
||||
massive_content = "\n".join(f"# Line {i}" for i in range(10000))
|
||||
massive_content += "\ndef func1():\n pass\n"
|
||||
|
||||
massive_content = (
|
||||
"\n".join(f"# Line {i}" for i in range(10000))
|
||||
+ "\ndef func1():\n pass\n"
|
||||
)
|
||||
hook_data = {
|
||||
"tool_name": "Write",
|
||||
"tool_input": {
|
||||
|
||||
@@ -185,9 +185,7 @@ class TestHelperFunctions:
|
||||
set_platform("linux")
|
||||
|
||||
def fake_which(name: str) -> str | None:
|
||||
if name == "claude-quality":
|
||||
return "/usr/local/bin/claude-quality"
|
||||
return None
|
||||
return "/usr/local/bin/claude-quality" if name == "claude-quality" else None
|
||||
|
||||
monkeypatch.setattr(shutil, "which", fake_which)
|
||||
|
||||
|
||||
@@ -146,7 +146,7 @@ class TestHookIntegration:
|
||||
"QUALITY_COMPLEXITY_ENABLED": "true",
|
||||
"QUALITY_MODERN_ENABLED": "false",
|
||||
}
|
||||
os.environ.update(env_overrides)
|
||||
os.environ |= env_overrides
|
||||
|
||||
complex_code = """
|
||||
def complex_func(a, b, c):
|
||||
|
||||
@@ -63,9 +63,7 @@ def test_ensure_tool_installed(
|
||||
suffix = str(path)
|
||||
if suffix.endswith("basedpyright"):
|
||||
return tool_exists
|
||||
if suffix.endswith("uv"):
|
||||
return not tool_exists
|
||||
return False
|
||||
return not tool_exists if suffix.endswith("uv") else False
|
||||
|
||||
monkeypatch.setattr(guard.Path, "exists", fake_exists, raising=False)
|
||||
|
||||
|
||||
@@ -209,7 +209,7 @@ class TestProjectRootAndTempFiles:
|
||||
)
|
||||
|
||||
# Should have run from project root
|
||||
assert len(captured_cwd) > 0
|
||||
assert captured_cwd
|
||||
assert captured_cwd[0] == root
|
||||
finally:
|
||||
import shutil
|
||||
|
||||