Add detection for typing.Any usage and enhance quality checks

- Implemented a new function to detect usage of typing.Any in code, which raises warnings during pre-tool use.
- Updated the pretooluse_hook to handle quality issues related to typing.Any and integrate it with existing quality checks.
- Modified the response structure to include permissionDecision instead of decision for consistency across hooks.
- Enhanced test coverage for typing.Any usage detection in both single and multi-edit scenarios.
- Adjusted existing tests to reflect changes in response structure and ensure proper validation of quality checks.
This commit is contained in:
2025-09-17 14:55:43 +00:00
parent 917b0de16c
commit f1b61a6ae7
6 changed files with 428 additions and 273 deletions

View File

@@ -5,6 +5,7 @@ Prevents writing duplicate, complex, or non-modernized code and verifies quality
after writes.
"""
import ast
import hashlib
import json
import logging
@@ -665,6 +666,64 @@ def verify_naming_conventions(file_path: str) -> list[str]:
return issues
def _detect_any_usage(content: str) -> list[str]:
"""Detect forbidden typing.Any usage in proposed content."""
class _AnyUsageVisitor(ast.NodeVisitor):
"""Collect line numbers where typing.Any is referenced."""
def __init__(self) -> None:
self.lines: set[int] = set()
def visit_Name(self, node: ast.Name) -> None:
if node.id == "Any":
self.lines.add(node.lineno)
self.generic_visit(node)
def visit_Attribute(self, node: ast.Attribute) -> None:
if node.attr == "Any":
self.lines.add(node.lineno)
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
for alias in node.names:
if alias.name == "Any" or alias.asname == "Any":
self.lines.add(node.lineno)
self.generic_visit(node)
def visit_Import(self, node: ast.Import) -> None:
for alias in node.names:
if alias.name == "Any" or alias.asname == "Any":
self.lines.add(node.lineno)
self.generic_visit(node)
lines_with_any: set[int] = set()
try:
tree = ast.parse(content)
except SyntaxError:
for index, line in enumerate(content.splitlines(), start=1):
code_portion = line.split("#", 1)[0]
if re.search(r"\bAny\b", code_portion):
lines_with_any.add(index)
else:
visitor = _AnyUsageVisitor()
visitor.visit(tree)
lines_with_any = visitor.lines
if not lines_with_any:
return []
sorted_lines = sorted(lines_with_any)
display_lines = ", ".join(str(num) for num in sorted_lines[:5])
if len(sorted_lines) > 5:
display_lines += ", …"
return [
"⚠️ Forbidden typing.Any usage at line(s) "
f"{display_lines}; replace with specific types",
]
def _perform_quality_check(
file_path: str,
content: str,
@@ -690,6 +749,8 @@ def _handle_quality_issues(
file_path: str,
issues: list[str],
config: QualityConfig,
*,
forced_permission: str | None = None,
) -> JsonObject:
"""Handle quality issues based on enforcement mode."""
# Prepare denial message
@@ -700,16 +761,20 @@ def _handle_quality_issues(
)
# Make decision based on enforcement mode
if forced_permission:
return _create_hook_response("PreToolUse", forced_permission, message)
if config.enforcement_mode == "strict":
return _create_hook_response("PreToolUse", "deny", message)
if config.enforcement_mode == "warn":
return _create_hook_response("PreToolUse", "ask", message)
# permissive
warning_message = f"⚠️ Quality Warning:\n{message}"
return _create_hook_response(
"PreToolUse",
"allow",
"",
f"⚠️ Quality Warning:\n{message}",
warning_message,
warning_message,
)
@@ -725,6 +790,8 @@ def _create_hook_response(
reason: str = "",
system_message: str = "",
additional_context: str = "",
*,
decision: str | None = None,
) -> JsonObject:
"""Create standardized hook response."""
hook_output: dict[str, object] = {
@@ -743,6 +810,15 @@ def _create_hook_response(
"hookSpecificOutput": hook_output,
}
if permission:
response["permissionDecision"] = permission
if decision:
response["decision"] = decision
if reason:
response["reason"] = reason
if system_message:
response["systemMessage"] = system_message
@@ -766,18 +842,23 @@ def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject:
content = ""
if tool_name == "Write":
content = str(tool_input.get("content", ""))
raw_content = tool_input.get("content", "")
content = "" if raw_content is None else str(raw_content)
elif tool_name == "Edit":
content = str(tool_input.get("new_string", ""))
new_string = tool_input.get("new_string", "")
content = "" if new_string is None else str(new_string)
elif tool_name == "MultiEdit":
edits = tool_input.get("edits", [])
if isinstance(edits, list):
edits_list = cast("list[object]", edits)
content = "\n".join(
str(cast("dict[str, object]", edit).get("new_string", ""))
for edit in edits_list
if isinstance(edit, dict)
)
parts: list[str] = []
for edit in edits_list:
if not isinstance(edit, dict):
continue
edit_dict = cast("dict[str, object]", edit)
new_str = edit_dict.get("new_string")
parts.append("") if new_str is None else parts.append(str(new_str))
content = "\n".join(parts)
# Only analyze Python files
if not file_path or not file_path.endswith(".py") or not content:
@@ -788,6 +869,7 @@ def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject:
return _create_hook_response("PreToolUse", "allow")
enable_type_checks = tool_name == "Write"
any_usage_issues = _detect_any_usage(content)
try:
has_issues, issues = _perform_quality_check(
@@ -797,15 +879,32 @@ def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject:
enable_type_checks=enable_type_checks,
)
if not has_issues:
all_issues = any_usage_issues + issues
if not all_issues:
return _create_hook_response("PreToolUse", "allow")
return _handle_quality_issues(file_path, issues, config)
if any_usage_issues:
return _handle_quality_issues(
file_path,
all_issues,
config,
forced_permission="ask",
)
return _handle_quality_issues(file_path, all_issues, config)
except Exception as e: # noqa: BLE001
if any_usage_issues:
return _handle_quality_issues(
file_path,
any_usage_issues,
config,
forced_permission="ask",
)
return _create_hook_response(
"PreToolUse",
"allow",
"",
f"Warning: Code quality check failed with error: {e}",
f"Warning: Code quality check failed with error: {e}",
)
@@ -816,7 +915,7 @@ def posttooluse_hook(
) -> JsonObject:
"""Handle PostToolUse hook - verify quality after write/edit."""
tool_name: str = str(hook_data.get("tool_name", ""))
tool_output = hook_data.get("tool_response", {})
tool_output = hook_data.get("tool_response", hook_data.get("tool_output", {}))
# Only process write/edit tools
if tool_name not in ["Write", "Edit", "MultiEdit"]:
@@ -866,13 +965,21 @@ def posttooluse_hook(
return _create_hook_response(
"PostToolUse",
"",
"",
message,
message,
message,
decision="block",
)
if config.show_success:
message = f"{Path(file_path).name} passed post-write verification"
return _create_hook_response("PostToolUse", "", "", message)
return _create_hook_response(
"PostToolUse",
"",
"",
message,
"",
decision="approve",
)
return _create_hook_response("PostToolUse")
@@ -898,7 +1005,7 @@ def main() -> None:
# Detect hook type: tool_response=PostToolUse, tool_input=PreToolUse
response: JsonObject
if "tool_response" in hook_data:
if "tool_response" in hook_data or "tool_output" in hook_data:
# PostToolUse hook
response = posttooluse_hook(hook_data, config)
else:

View File

@@ -87,7 +87,7 @@ class TestQualityConfig:
"""Test loading config with invalid float values."""
os.environ["QUALITY_DUP_THRESHOLD"] = "not_a_float"
with pytest.raises(ValueError, match="invalid literal"):
with pytest.raises(ValueError, match="could not convert string to float"):
QualityConfig.from_env()
def test_from_env_with_invalid_int(self):

View File

@@ -13,6 +13,10 @@ from code_quality_guard import (
)
def _perm(response: dict) -> str | None:
return response.get("hookSpecificOutput", {}).get("permissionDecision")
class TestEdgeCases:
"""Test edge cases and corner conditions."""
@@ -34,7 +38,7 @@ class TestEdgeCases:
with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert _perm(result) == "allow"
# Should still be called despite large file
mock_analyze.assert_called_once()
@@ -50,7 +54,7 @@ class TestEdgeCases:
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert _perm(result) == "allow"
def test_whitespace_only_content(self):
"""Test handling of whitespace-only content."""
@@ -64,7 +68,7 @@ class TestEdgeCases:
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert _perm(result) == "allow"
def test_malformed_python_syntax(self):
"""Test handling of syntax errors in Python code."""
@@ -86,11 +90,11 @@ def broken_func(
# Should gracefully handle syntax errors
result = pretooluse_hook(hook_data, config)
assert result["decision"] in ["allow", "deny", "ask"]
assert (
"error" in result.get("message", "").lower()
or result["decision"] == "allow"
)
decision = _perm(result)
assert decision in ["allow", "deny", "ask"]
if decision != "allow":
text = (result.get("reason") or "") + (result.get("systemMessage") or "")
assert "error" in text.lower()
def test_unicode_content(self):
"""Test handling of Unicode characters in code."""
@@ -115,7 +119,7 @@ def greet_世界():
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] in ["allow", "deny", "ask"]
assert _perm(result) in ["allow", "deny", "ask"]
def test_concurrent_hook_calls(self):
"""Test thread safety with concurrent calls."""
@@ -137,7 +141,7 @@ def greet_世界():
results.append(result)
# All should have the same decision
decisions = [r["decision"] for r in results]
decisions = [_perm(r) for r in results]
assert all(d == decisions[0] for d in decisions)
def test_missing_tool_input_fields(self):
@@ -150,7 +154,7 @@ def greet_世界():
"tool_input": {"content": "def test(): pass"},
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow" # Should handle gracefully
assert _perm(result) == "allow" # Should handle gracefully
# Missing content for Write
hook_data = {
@@ -158,7 +162,7 @@ def greet_世界():
"tool_input": {"file_path": "test.py"},
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow" # Should handle gracefully
assert _perm(result) == "allow" # Should handle gracefully
def test_circular_import_detection(self):
"""Test detection of circular imports."""
@@ -180,8 +184,7 @@ def func_c():
}
result = pretooluse_hook(hook_data, config)
# Should not crash on import analysis
assert result["decision"] in ["allow", "deny", "ask"]
assert _perm(result) in ["allow", "deny", "ask"]
def test_binary_file_path(self):
"""Test handling of binary file paths."""
@@ -195,7 +198,7 @@ def func_c():
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow" # Should skip non-Python files
assert _perm(result) == "allow" # Should skip non-Python files
def test_null_and_none_values(self):
"""Test handling of null/None values."""
@@ -210,7 +213,7 @@ def func_c():
},
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert _perm(result) == "allow"
# None as file_path
hook_data["tool_input"] = {
@@ -218,7 +221,7 @@ def func_c():
"content": "def test(): pass",
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert _perm(result) == "allow"
def test_path_traversal_attempts(self):
"""Test handling of path traversal attempts."""
@@ -239,8 +242,7 @@ def func_c():
},
}
result = pretooluse_hook(hook_data, config)
# Should handle without crashing
assert result["decision"] in ["allow", "deny", "ask"]
assert _perm(result) in ["allow", "deny", "ask"]
def test_extreme_thresholds(self):
"""Test with extreme threshold values."""
@@ -264,8 +266,7 @@ def func_c():
},
}
result = pretooluse_hook(hook_data, config)
# With threshold 0, everything should be flagged
assert result["decision"] == "deny"
assert _perm(result) == "deny"
# Maximum thresholds
config = QualityConfig(
@@ -282,8 +283,7 @@ def func_c():
},
}
result = pretooluse_hook(hook_data, config)
# With very high thresholds and permissive mode, should pass with warning
assert result["decision"] == "allow"
assert _perm(result) == "allow"
def test_subprocess_timeout(self):
"""Test handling of subprocess timeouts."""
@@ -345,8 +345,7 @@ def func_c():
mock_read.side_effect = PermissionError("Access denied")
result = posttooluse_hook(hook_data, config)
# Should handle permission errors gracefully
assert result["decision"] == "allow"
assert "decision" not in result
def test_deeply_nested_code_structure(self):
"""Test handling of deeply nested code."""
@@ -367,8 +366,7 @@ def func_c():
}
result = pretooluse_hook(hook_data, config)
# Should handle without stack overflow
assert result["decision"] in ["allow", "deny", "ask"]
assert _perm(result) in ["allow", "deny", "ask"]
def test_recursive_function_detection(self):
"""Test detection of recursive functions."""
@@ -392,8 +390,7 @@ def infinite_recursion():
}
result = pretooluse_hook(hook_data, config)
# Should handle recursive functions
assert result["decision"] in ["allow", "deny", "ask"]
assert _perm(result) in ["allow", "deny", "ask"]
def test_multifile_edit_paths(self):
"""Test MultiEdit with multiple file paths."""
@@ -413,7 +410,7 @@ def infinite_recursion():
with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert _perm(result) == "allow"
# Should concatenate all new_strings
call_args = mock_analyze.call_args[0][0]
assert "func1" in call_args
@@ -500,5 +497,4 @@ def cached_func():
}
result = pretooluse_hook(hook_data, config)
# Should handle modern Python syntax
assert result["decision"] in ["allow", "deny", "ask"]
assert _perm(result) in ["allow", "deny", "ask"]

View File

@@ -11,7 +11,7 @@ class TestHookIntegration:
"""Test complete hook integration scenarios."""
def test_main_entry_pretooluse(self):
"""Test main entry point detects PreToolUse."""
"""Ensure main dispatches to PreToolUse."""
from code_quality_guard import main
hook_input = {
@@ -22,19 +22,24 @@ class TestHookIntegration:
},
}
with patch("sys.stdin") as mock_stdin:
with patch("builtins.print"):
mock_stdin.read.return_value = json.dumps(hook_input)
mock_stdin.__iter__.return_value = [json.dumps(hook_input)]
with patch("sys.stdin") as mock_stdin, patch("builtins.print"):
mock_stdin.read.return_value = json.dumps(hook_input)
mock_stdin.__iter__.return_value = [json.dumps(hook_input)]
with patch("json.load", return_value=hook_input):
with patch("code_quality_guard.pretooluse_hook") as mock_pre:
mock_pre.return_value = {"decision": "allow"}
main()
mock_pre.assert_called_once()
with patch("json.load", return_value=hook_input), patch(
"code_quality_guard.pretooluse_hook",
return_value={
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "allow",
},
},
) as mock_pre:
main()
mock_pre.assert_called_once()
def test_main_entry_posttooluse(self):
"""Test main entry point detects PostToolUse."""
"""Ensure main dispatches to PostToolUse."""
from code_quality_guard import main
hook_input = {
@@ -45,39 +50,47 @@ class TestHookIntegration:
},
}
with patch("sys.stdin") as mock_stdin:
with patch("builtins.print"):
mock_stdin.read.return_value = json.dumps(hook_input)
mock_stdin.__iter__.return_value = [json.dumps(hook_input)]
with patch("sys.stdin") as mock_stdin, patch("builtins.print"):
mock_stdin.read.return_value = json.dumps(hook_input)
mock_stdin.__iter__.return_value = [json.dumps(hook_input)]
with patch("json.load", return_value=hook_input):
with patch("code_quality_guard.posttooluse_hook") as mock_post:
mock_post.return_value = {"decision": "allow"}
main()
mock_post.assert_called_once()
with patch("json.load", return_value=hook_input), patch(
"code_quality_guard.posttooluse_hook",
return_value={
"hookSpecificOutput": {
"hookEventName": "PostToolUse",
},
"decision": "approve",
},
) as mock_post:
main()
mock_post.assert_called_once()
def test_main_invalid_json(self):
"""Test main handles invalid JSON input."""
"""Invalid JSON falls back to allow."""
from code_quality_guard import main
with patch("sys.stdin"):
with patch("builtins.print") as mock_print:
with patch(
"json.load",
side_effect=json.JSONDecodeError("test", "test", 0),
):
main()
with patch("sys.stdin"), patch("builtins.print") as mock_print, patch(
"sys.stdout.write",
) as mock_write:
with patch(
"json.load",
side_effect=json.JSONDecodeError("test", "test", 0),
):
main()
# Should print allow decision
printed = mock_print.call_args[0][0]
response = json.loads(printed)
assert response["decision"] == "allow"
printed = (
mock_print.call_args[0][0]
if mock_print.call_args
else mock_write.call_args[0][0]
)
response = json.loads(printed)
assert response["hookSpecificOutput"]["permissionDecision"] == "allow"
def test_full_flow_clean_code(self, clean_code):
"""Test full flow with clean code."""
"""Clean code should pass both hook stages."""
from code_quality_guard import main
# PreToolUse
pre_input = {
"tool_name": "Write",
"tool_input": {
@@ -86,69 +99,66 @@ class TestHookIntegration:
},
}
with patch("sys.stdin"):
with patch("builtins.print") as mock_print:
with patch("json.load", return_value=pre_input):
main()
with patch("sys.stdin"), patch("builtins.print") as mock_print:
with patch("json.load", return_value=pre_input), patch(
"code_quality_guard.analyze_code_quality",
return_value={},
):
main()
printed = mock_print.call_args[0][0]
response = json.loads(printed)
assert response["decision"] == "allow"
response = json.loads(mock_print.call_args[0][0])
assert response["hookSpecificOutput"]["permissionDecision"] == "allow"
# Simulate file write
test_file = Path(f"{tempfile.gettempdir()}/clean.py")
test_file.write_text(clean_code)
# PostToolUse
post_input = {
"tool_name": "Write",
"tool_output": {
"file_path": f"{tempfile.gettempdir()}/clean.py",
"file_path": str(test_file),
"status": "success",
},
}
with patch("sys.stdin"):
with patch("builtins.print") as mock_print:
os.environ["QUALITY_SHOW_SUCCESS"] = "true"
try:
with patch("sys.stdin"), patch("builtins.print") as mock_print:
with patch("json.load", return_value=post_input):
os.environ["QUALITY_SHOW_SUCCESS"] = "true"
main()
printed = mock_print.call_args[0][0]
response = json.loads(printed)
assert response["decision"] == "allow"
assert "passed" in response.get("message", "").lower()
test_file.unlink(missing_ok=True)
response = json.loads(mock_print.call_args[0][0])
assert response.get("decision") == "approve"
assert "passed" in response.get("systemMessage", "").lower()
finally:
os.environ.pop("QUALITY_SHOW_SUCCESS", None)
test_file.unlink(missing_ok=True)
def test_environment_configuration_flow(self):
"""Test that environment variables are properly used."""
"""Environment settings change enforcement."""
from code_quality_guard import main
# Set strict environment
os.environ.update(
{
"QUALITY_ENFORCEMENT": "strict",
"QUALITY_COMPLEXITY_THRESHOLD": "5", # Very low threshold
"QUALITY_DUP_ENABLED": "false",
"QUALITY_COMPLEXITY_ENABLED": "true", # Keep complexity enabled
"QUALITY_MODERN_ENABLED": "false",
},
)
env_overrides = {
"QUALITY_ENFORCEMENT": "strict",
"QUALITY_COMPLEXITY_THRESHOLD": "5",
"QUALITY_DUP_ENABLED": "false",
"QUALITY_COMPLEXITY_ENABLED": "true",
"QUALITY_MODERN_ENABLED": "false",
}
os.environ.update(env_overrides)
complex_code = """
def complex_func(a, b, c):
if a:
if b:
if c:
return 1
def complex_func(a, b, c):
if a:
if b:
if c:
return 1
else:
return 2
else:
return 3
else:
return 2
else:
return 3
else:
return 4
"""
return 4
"""
hook_input = {
"tool_name": "Write",
@@ -158,38 +168,36 @@ def complex_func(a, b, c):
},
}
with patch("sys.stdin"):
with patch("builtins.print") as mock_print:
with patch("json.load", return_value=hook_input):
with patch(
"code_quality_guard.analyze_code_quality",
) as mock_analyze:
# Mock the complexity analysis result
mock_analyze.return_value = {
"complexity": {
"summary": {
"average_cyclomatic_complexity": 8,
}, # Above threshold
"distribution": {"High": 1},
},
}
try:
main()
msg = "Expected SystemExit"
raise AssertionError(msg)
except SystemExit as e:
assert e.code == 2, "Expected exit code 2 for deny" # noqa: PT017
try:
with patch("sys.stdin"), patch("builtins.print") as mock_print:
with patch("json.load", return_value=hook_input), patch(
"code_quality_guard.analyze_code_quality",
return_value={
"complexity": {
"summary": {"average_cyclomatic_complexity": 8},
"distribution": {"High": 1},
},
},
):
try:
main()
raise AssertionError("Expected SystemExit")
except SystemExit as exc:
assert exc.code == 2
printed = mock_print.call_args[0][0]
response = json.loads(printed)
# Should be denied due to low complexity threshold
assert response["decision"] == "deny"
response = json.loads(mock_print.call_args[0][0])
assert (
response["hookSpecificOutput"]["permissionDecision"]
== "deny"
)
finally:
for key in env_overrides:
os.environ.pop(key, None)
def test_skip_patterns_integration(self):
"""Test skip patterns work in integration."""
"""Skip patterns should bypass checks."""
from code_quality_guard import main
# Test file should be skipped
hook_input = {
"tool_name": "Write",
"tool_input": {
@@ -198,72 +206,62 @@ def complex_func(a, b, c):
},
}
with patch("sys.stdin"):
with patch("builtins.print") as mock_print:
with patch("json.load", return_value=hook_input):
main()
with patch("sys.stdin"), patch("builtins.print") as mock_print:
with patch("json.load", return_value=hook_input):
main()
printed = mock_print.call_args[0][0]
response = json.loads(printed)
assert response["decision"] == "allow"
response = json.loads(mock_print.call_args[0][0])
assert response["hookSpecificOutput"]["permissionDecision"] == "allow"
def test_state_tracking_flow(self, temp_python_file):
"""Test state tracking between pre and post."""
"""State tracking should flag regressions."""
from code_quality_guard import main
os.environ["QUALITY_STATE_TRACKING"] = "true"
try:
pre_input = {
"tool_name": "Write",
"tool_input": {
"file_path": str(temp_python_file),
"content": "def func1(): pass\ndef func2(): pass\ndef func3(): pass",
},
}
# PreToolUse - store state
initial_content = "def func1(): pass\ndef func2(): pass\ndef func3(): pass"
pre_input = {
"tool_name": "Write",
"tool_input": {
"file_path": str(temp_python_file),
"content": initial_content,
},
}
with patch("sys.stdin"):
with patch("builtins.print") as mock_print:
with patch("json.load", return_value=pre_input):
with patch("sys.stdin"), patch("builtins.print") as mock_print:
with patch("json.load", return_value=pre_input), patch(
"code_quality_guard.analyze_code_quality",
return_value={},
):
main()
# Simulate file modification (fewer functions)
modified_content = "def func1(): pass"
temp_python_file.write_text(modified_content)
temp_python_file.write_text("def func1(): pass")
# PostToolUse - check state
post_input = {
"tool_name": "Write",
"tool_output": {
"file_path": str(temp_python_file),
"status": "success",
},
}
post_input = {
"tool_name": "Write",
"tool_output": {
"file_path": str(temp_python_file),
"status": "success",
},
}
with patch("sys.stdin"):
with patch("builtins.print") as mock_print:
with patch("sys.stdin"), patch("builtins.print") as mock_print:
with patch("json.load", return_value=post_input):
main()
printed = mock_print.call_args[0][0]
response = json.loads(printed)
assert response["decision"] == "allow"
# Should detect function reduction
if "message" in response:
assert (
"reduced" in response["message"].lower()
or len(response["message"]) == 0
)
response = json.loads(mock_print.call_args[0][0])
assert response["decision"] == "block"
assert "reduced" in response["reason"].lower()
finally:
os.environ.pop("QUALITY_STATE_TRACKING", None)
def test_cross_tool_handling(self):
"""Test different tools are handled correctly."""
"""Supported tools should respond with allow."""
from code_quality_guard import main
tools = ["Write", "Edit", "MultiEdit", "Read", "Bash", "Task"]
for tool in tools:
if tool in ["Write", "Edit", "MultiEdit"]:
if tool in {"Write", "Edit", "MultiEdit"}:
hook_input = {
"tool_name": tool,
"tool_input": {
@@ -272,22 +270,20 @@ def complex_func(a, b, c):
},
}
else:
hook_input = {
"tool_name": tool,
"tool_input": {},
}
hook_input = {"tool_name": tool, "tool_input": {}}
with patch("sys.stdin"):
with patch("builtins.print") as mock_print:
with patch("json.load", return_value=hook_input):
main()
with patch("sys.stdin"), patch("builtins.print") as mock_print:
with patch("json.load", return_value=hook_input), patch(
"code_quality_guard.analyze_code_quality",
return_value={},
):
main()
printed = mock_print.call_args[0][0]
response = json.loads(printed)
assert response["decision"] == "allow"
response = json.loads(mock_print.call_args[0][0])
assert response["hookSpecificOutput"]["permissionDecision"] == "allow"
def test_enforcement_mode_progression(self, complex_code):
"""Test progression through enforcement modes."""
"""Strict/warn/permissive modes map to deny/ask/allow."""
from code_quality_guard import main
hook_input = {
@@ -298,33 +294,40 @@ def complex_func(a, b, c):
},
}
modes_and_decisions = [
scenarios = [
("strict", "deny"),
("warn", "ask"),
("permissive", "allow"),
]
for mode, expected_decision in modes_and_decisions:
for mode, expected in scenarios:
os.environ["QUALITY_ENFORCEMENT"] = mode
os.environ["QUALITY_COMPLEXITY_THRESHOLD"] = "10"
with patch("sys.stdin"):
with patch("builtins.print") as mock_print:
with patch("json.load", return_value=hook_input):
if expected_decision in ["deny", "ask"]:
# Expect SystemExit with code 2 for deny/ask decisions
try:
with patch("sys.stdin"), patch("builtins.print") as mock_print:
with patch("json.load", return_value=hook_input), patch(
"code_quality_guard.analyze_code_quality",
return_value={
"complexity": {
"summary": {"average_cyclomatic_complexity": 25},
"distribution": {"High": 1},
},
},
):
if expected in {"deny", "ask"}:
try:
main()
msg = f"Expected SystemExit for {mode} mode"
raise AssertionError(msg)
except SystemExit as e:
assert e.code == 2, ( # noqa: PT017
f"Expected exit code 2 for {mode} mode"
)
raise AssertionError("Expected SystemExit")
except SystemExit as exc:
assert exc.code == 2
else:
# Permissive mode should not exit
main()
printed = mock_print.call_args[0][0]
response = json.loads(printed)
assert response["decision"] == expected_decision
response = json.loads(mock_print.call_args[0][0])
assert (
response["hookSpecificOutput"]["permissionDecision"]
== expected
)
finally:
os.environ.pop("QUALITY_ENFORCEMENT", None)
os.environ.pop("QUALITY_COMPLEXITY_THRESHOLD", None)

View File

@@ -18,7 +18,8 @@ class TestPostToolUseHook:
}
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert result["hookSpecificOutput"]["hookEventName"] == "PostToolUse"
assert "decision" not in result
def test_file_path_extraction_dict(self):
"""Test file path extraction from dict output."""
@@ -34,7 +35,7 @@ class TestPostToolUseHook:
with patch("pathlib.Path.exists", return_value=True):
with patch("pathlib.Path.read_text", return_value="def test(): pass"):
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "decision" not in result
# Test with path key
hook_data["tool_output"] = {"path": test_file}
@@ -42,7 +43,7 @@ class TestPostToolUseHook:
with patch("pathlib.Path.exists", return_value=True):
with patch("pathlib.Path.read_text", return_value="def test(): pass"):
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "decision" not in result
def test_file_path_extraction_string(self):
"""Test file path extraction from string output."""
@@ -55,7 +56,7 @@ class TestPostToolUseHook:
with patch("pathlib.Path.exists", return_value=True):
with patch("pathlib.Path.read_text", return_value="def test(): pass"):
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "decision" not in result
def test_non_python_file_skipped(self):
"""Test that non-Python files are skipped."""
@@ -66,7 +67,7 @@ class TestPostToolUseHook:
}
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "decision" not in result
def test_nonexistent_file_skipped(self):
"""Test that nonexistent files are skipped."""
@@ -78,7 +79,7 @@ class TestPostToolUseHook:
with patch("pathlib.Path.exists", return_value=False):
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "decision" not in result
def test_state_tracking_degradation(self):
"""Test state tracking detects quality degradation."""
@@ -97,9 +98,10 @@ class TestPostToolUseHook:
]
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "Post-write quality notes" in result["message"]
assert "Reduced functions" in result["message"]
assert result["decision"] == "block"
reason_text = result["reason"].lower()
assert "post-write quality notes" in reason_text
assert "reduced functions" in reason_text
def test_cross_file_duplicates(self):
"""Test cross-file duplicate detection."""
@@ -117,8 +119,8 @@ class TestPostToolUseHook:
mock_check.return_value = ["⚠️ Cross-file duplication detected"]
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "Cross-file duplication" in result["message"]
assert result["decision"] == "block"
assert "cross-file duplication" in result["reason"].lower()
def test_naming_convention_violations(self, non_pep8_code):
"""Test naming convention verification."""
@@ -131,9 +133,9 @@ class TestPostToolUseHook:
with patch("pathlib.Path.exists", return_value=True):
with patch("pathlib.Path.read_text", return_value=non_pep8_code):
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "Non-PEP8 function names" in result["message"]
assert "Non-PEP8 class names" in result["message"]
assert result["decision"] == "block"
assert "non-pep8 function names" in result["reason"].lower()
assert "non-pep8 class names" in result["reason"].lower()
def test_show_success_message(self, clean_code):
"""Test success message when enabled."""
@@ -146,8 +148,8 @@ class TestPostToolUseHook:
with patch("pathlib.Path.exists", return_value=True):
with patch("pathlib.Path.read_text", return_value=clean_code):
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "passed post-write verification" in result["message"]
assert result["decision"] == "approve"
assert "passed post-write verification" in result["systemMessage"].lower()
def test_no_message_when_success_disabled(self, clean_code):
"""Test no message when show_success is disabled."""
@@ -160,8 +162,8 @@ class TestPostToolUseHook:
with patch("pathlib.Path.exists", return_value=True):
with patch("pathlib.Path.read_text", return_value=clean_code):
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "message" not in result
assert "decision" not in result
assert "systemMessage" not in result
def test_all_features_combined(self):
"""Test all PostToolUse features combined."""
@@ -190,10 +192,11 @@ class TestPostToolUseHook:
mock_naming.return_value = ["⚠️ Issue 3"]
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "Issue 1" in result["message"]
assert "Issue 2" in result["message"]
assert "Issue 3" in result["message"]
assert result["decision"] == "block"
reason_text = result["reason"].lower()
assert "issue 1" in reason_text
assert "issue 2" in reason_text
assert "issue 3" in reason_text
def test_edit_tool_output(self):
"""Test Edit tool output handling."""
@@ -209,7 +212,7 @@ class TestPostToolUseHook:
with patch("pathlib.Path.exists", return_value=True):
with patch("pathlib.Path.read_text", return_value="def test(): pass"):
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "decision" not in result
def test_multiedit_tool_output(self):
"""Test MultiEdit tool output handling."""
@@ -225,7 +228,7 @@ class TestPostToolUseHook:
with patch("pathlib.Path.exists", return_value=True):
with patch("pathlib.Path.read_text", return_value="def test(): pass"):
result = posttooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "decision" not in result
def test_features_disabled(self):
"""Test with all features disabled."""
@@ -257,5 +260,5 @@ class TestPostToolUseHook:
mock_cross.assert_not_called()
mock_naming.assert_not_called()
assert result["decision"] == "allow"
assert "message" not in result
assert "decision" not in result
assert "systemMessage" not in result

View File

@@ -17,7 +17,7 @@ class TestPreToolUseHook:
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert result["permissionDecision"] == "allow"
def test_non_python_file_allowed(self):
"""Test that non-Python files are always allowed."""
@@ -31,7 +31,7 @@ class TestPreToolUseHook:
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert result["permissionDecision"] == "allow"
def test_test_file_skipped(self):
"""Test that test files are skipped when configured."""
@@ -45,7 +45,7 @@ class TestPreToolUseHook:
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert result["permissionDecision"] == "allow"
def test_clean_code_allowed(self, clean_code):
"""Test that clean code is allowed."""
@@ -61,7 +61,7 @@ class TestPreToolUseHook:
with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert result["permissionDecision"] == "allow"
def test_complex_code_denied_strict(self, complex_code):
"""Test that complex code is denied in strict mode."""
@@ -83,8 +83,8 @@ class TestPreToolUseHook:
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "deny"
assert "quality check failed" in result["message"].lower()
assert result["permissionDecision"] == "deny"
assert "quality check failed" in result["reason"].lower()
def test_complex_code_ask_warn_mode(self, complex_code):
"""Test that complex code triggers ask in warn mode."""
@@ -106,7 +106,7 @@ class TestPreToolUseHook:
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "ask"
assert result["permissionDecision"] == "ask"
def test_complex_code_allowed_permissive(self, complex_code):
"""Test that complex code is allowed with warning in permissive mode."""
@@ -128,8 +128,8 @@ class TestPreToolUseHook:
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "warning" in result["message"].lower()
assert result["permissionDecision"] == "allow"
assert "warning" in result.get("reason", "").lower()
def test_duplicate_code_detection(self, duplicate_code):
"""Test internal duplicate detection."""
@@ -162,8 +162,8 @@ class TestPreToolUseHook:
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "deny"
assert "duplication" in result["message"].lower()
assert result["permissionDecision"] == "deny"
assert "duplication" in result["reason"].lower()
def test_edit_tool_handling(self):
"""Test Edit tool content extraction."""
@@ -180,7 +180,7 @@ class TestPreToolUseHook:
with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert result["permissionDecision"] == "allow"
# Verify new_string was analyzed
call_args = mock_analyze.call_args[0]
@@ -203,7 +203,7 @@ class TestPreToolUseHook:
with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert result["permissionDecision"] == "allow"
# Verify concatenated content was analyzed
call_args = mock_analyze.call_args[0]
@@ -245,8 +245,8 @@ class TestPreToolUseHook:
mock_analyze.side_effect = Exception("Analysis failed")
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert "error" in result["message"].lower()
assert result["permissionDecision"] == "allow"
assert "error" in result.get("reason", "").lower()
def test_custom_skip_patterns(self):
"""Test custom skip patterns."""
@@ -261,12 +261,12 @@ class TestPreToolUseHook:
},
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert result["permissionDecision"] == "allow"
# Test path pattern match
hook_data["tool_input"]["file_path"] = "/ignored/file.py"
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "allow"
assert result["permissionDecision"] == "allow"
def test_modernization_issues(self, old_style_code):
"""Test modernization issue detection."""
@@ -292,8 +292,8 @@ class TestPreToolUseHook:
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "deny"
assert "modernization" in result["message"].lower()
assert result["permissionDecision"] == "deny"
assert "modernization" in result["reason"].lower()
def test_type_hint_threshold(self):
"""Test type hint issue threshold."""
@@ -320,5 +320,51 @@ class TestPreToolUseHook:
}
result = pretooluse_hook(hook_data, config)
assert result["decision"] == "deny"
assert "type hints" in result["message"].lower()
assert result["permissionDecision"] == "deny"
assert "type hints" in result["reason"].lower()
def test_any_usage_denied(self):
"""Test that typing.Any usage triggers a denial."""
config = QualityConfig(enforcement_mode="strict")
hook_data = {
"tool_name": "Write",
"tool_input": {
"file_path": "example.py",
"content": "from typing import Any\n\n"
"def example(value: Any) -> None:\n pass\n",
},
}
with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
result = pretooluse_hook(hook_data, config)
assert result["permissionDecision"] == "ask"
assert "any" in result["reason"].lower()
def test_any_usage_detected_in_multiedit(self):
"""Test that MultiEdit content is scanned for typing.Any usage."""
config = QualityConfig()
hook_data = {
"tool_name": "MultiEdit",
"tool_input": {
"file_path": "example.py",
"edits": [
{
"old_string": "pass",
"new_string": "from typing import Any\n",
},
{
"old_string": "pass",
"new_string": "def handler(arg: Any) -> str:\n return str(arg)\n",
},
],
},
}
with patch("code_quality_guard.analyze_code_quality") as mock_analyze:
mock_analyze.return_value = {}
result = pretooluse_hook(hook_data, config)
assert result["permissionDecision"] == "ask"
assert "any" in result["reason"].lower()