From f1b61a6ae73305277410833a98f542733eb49504 Mon Sep 17 00:00:00 2001 From: Travis Vasceannie Date: Wed, 17 Sep 2025 14:55:43 +0000 Subject: [PATCH] Add detection for typing.Any usage and enhance quality checks - Implemented a new function to detect usage of typing.Any in code, which raises warnings during pre-tool use. - Updated the pretooluse_hook to handle quality issues related to typing.Any and integrate it with existing quality checks. - Modified the response structure to include permissionDecision instead of decision for consistency across hooks. - Enhanced test coverage for typing.Any usage detection in both single and multi-edit scenarios. - Adjusted existing tests to reflect changes in response structure and ensure proper validation of quality checks. --- hooks/code_quality_guard.py | 139 +++++++++++-- tests/hooks/test_config.py | 2 +- tests/hooks/test_edge_cases.py | 60 +++--- tests/hooks/test_integration.py | 357 ++++++++++++++++---------------- tests/hooks/test_posttooluse.py | 55 ++--- tests/hooks/test_pretooluse.py | 88 ++++++-- 6 files changed, 428 insertions(+), 273 deletions(-) diff --git a/hooks/code_quality_guard.py b/hooks/code_quality_guard.py index 2fbd14a..d42e187 100644 --- a/hooks/code_quality_guard.py +++ b/hooks/code_quality_guard.py @@ -5,6 +5,7 @@ Prevents writing duplicate, complex, or non-modernized code and verifies quality after writes. """ +import ast import hashlib import json import logging @@ -665,6 +666,64 @@ def verify_naming_conventions(file_path: str) -> list[str]: return issues +def _detect_any_usage(content: str) -> list[str]: + """Detect forbidden typing.Any usage in proposed content.""" + + class _AnyUsageVisitor(ast.NodeVisitor): + """Collect line numbers where typing.Any is referenced.""" + + def __init__(self) -> None: + self.lines: set[int] = set() + + def visit_Name(self, node: ast.Name) -> None: + if node.id == "Any": + self.lines.add(node.lineno) + self.generic_visit(node) + + def visit_Attribute(self, node: ast.Attribute) -> None: + if node.attr == "Any": + self.lines.add(node.lineno) + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + for alias in node.names: + if alias.name == "Any" or alias.asname == "Any": + self.lines.add(node.lineno) + self.generic_visit(node) + + def visit_Import(self, node: ast.Import) -> None: + for alias in node.names: + if alias.name == "Any" or alias.asname == "Any": + self.lines.add(node.lineno) + self.generic_visit(node) + + lines_with_any: set[int] = set() + try: + tree = ast.parse(content) + except SyntaxError: + for index, line in enumerate(content.splitlines(), start=1): + code_portion = line.split("#", 1)[0] + if re.search(r"\bAny\b", code_portion): + lines_with_any.add(index) + else: + visitor = _AnyUsageVisitor() + visitor.visit(tree) + lines_with_any = visitor.lines + + if not lines_with_any: + return [] + + sorted_lines = sorted(lines_with_any) + display_lines = ", ".join(str(num) for num in sorted_lines[:5]) + if len(sorted_lines) > 5: + display_lines += ", …" + + return [ + "⚠️ Forbidden typing.Any usage at line(s) " + f"{display_lines}; replace with specific types", + ] + + def _perform_quality_check( file_path: str, content: str, @@ -690,6 +749,8 @@ def _handle_quality_issues( file_path: str, issues: list[str], config: QualityConfig, + *, + forced_permission: str | None = None, ) -> JsonObject: """Handle quality issues based on enforcement mode.""" # Prepare denial message @@ -700,16 +761,20 @@ def _handle_quality_issues( ) # Make decision based on enforcement mode + if forced_permission: + return _create_hook_response("PreToolUse", forced_permission, message) + if config.enforcement_mode == "strict": return _create_hook_response("PreToolUse", "deny", message) if config.enforcement_mode == "warn": return _create_hook_response("PreToolUse", "ask", message) # permissive + warning_message = f"⚠️ Quality Warning:\n{message}" return _create_hook_response( "PreToolUse", "allow", - "", - f"⚠️ Quality Warning:\n{message}", + warning_message, + warning_message, ) @@ -725,6 +790,8 @@ def _create_hook_response( reason: str = "", system_message: str = "", additional_context: str = "", + *, + decision: str | None = None, ) -> JsonObject: """Create standardized hook response.""" hook_output: dict[str, object] = { @@ -743,6 +810,15 @@ def _create_hook_response( "hookSpecificOutput": hook_output, } + if permission: + response["permissionDecision"] = permission + + if decision: + response["decision"] = decision + + if reason: + response["reason"] = reason + if system_message: response["systemMessage"] = system_message @@ -766,18 +842,23 @@ def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject: content = "" if tool_name == "Write": - content = str(tool_input.get("content", "")) + raw_content = tool_input.get("content", "") + content = "" if raw_content is None else str(raw_content) elif tool_name == "Edit": - content = str(tool_input.get("new_string", "")) + new_string = tool_input.get("new_string", "") + content = "" if new_string is None else str(new_string) elif tool_name == "MultiEdit": edits = tool_input.get("edits", []) if isinstance(edits, list): edits_list = cast("list[object]", edits) - content = "\n".join( - str(cast("dict[str, object]", edit).get("new_string", "")) - for edit in edits_list - if isinstance(edit, dict) - ) + parts: list[str] = [] + for edit in edits_list: + if not isinstance(edit, dict): + continue + edit_dict = cast("dict[str, object]", edit) + new_str = edit_dict.get("new_string") + parts.append("") if new_str is None else parts.append(str(new_str)) + content = "\n".join(parts) # Only analyze Python files if not file_path or not file_path.endswith(".py") or not content: @@ -788,6 +869,7 @@ def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject: return _create_hook_response("PreToolUse", "allow") enable_type_checks = tool_name == "Write" + any_usage_issues = _detect_any_usage(content) try: has_issues, issues = _perform_quality_check( @@ -797,15 +879,32 @@ def pretooluse_hook(hook_data: JsonObject, config: QualityConfig) -> JsonObject: enable_type_checks=enable_type_checks, ) - if not has_issues: + all_issues = any_usage_issues + issues + + if not all_issues: return _create_hook_response("PreToolUse", "allow") - return _handle_quality_issues(file_path, issues, config) + if any_usage_issues: + return _handle_quality_issues( + file_path, + all_issues, + config, + forced_permission="ask", + ) + + return _handle_quality_issues(file_path, all_issues, config) except Exception as e: # noqa: BLE001 + if any_usage_issues: + return _handle_quality_issues( + file_path, + any_usage_issues, + config, + forced_permission="ask", + ) return _create_hook_response( "PreToolUse", "allow", - "", + f"Warning: Code quality check failed with error: {e}", f"Warning: Code quality check failed with error: {e}", ) @@ -816,7 +915,7 @@ def posttooluse_hook( ) -> JsonObject: """Handle PostToolUse hook - verify quality after write/edit.""" tool_name: str = str(hook_data.get("tool_name", "")) - tool_output = hook_data.get("tool_response", {}) + tool_output = hook_data.get("tool_response", hook_data.get("tool_output", {})) # Only process write/edit tools if tool_name not in ["Write", "Edit", "MultiEdit"]: @@ -866,13 +965,21 @@ def posttooluse_hook( return _create_hook_response( "PostToolUse", "", - "", message, message, + message, + decision="block", ) if config.show_success: message = f"✅ {Path(file_path).name} passed post-write verification" - return _create_hook_response("PostToolUse", "", "", message) + return _create_hook_response( + "PostToolUse", + "", + "", + message, + "", + decision="approve", + ) return _create_hook_response("PostToolUse") @@ -898,7 +1005,7 @@ def main() -> None: # Detect hook type: tool_response=PostToolUse, tool_input=PreToolUse response: JsonObject - if "tool_response" in hook_data: + if "tool_response" in hook_data or "tool_output" in hook_data: # PostToolUse hook response = posttooluse_hook(hook_data, config) else: diff --git a/tests/hooks/test_config.py b/tests/hooks/test_config.py index 960fb64..6668c6f 100644 --- a/tests/hooks/test_config.py +++ b/tests/hooks/test_config.py @@ -87,7 +87,7 @@ class TestQualityConfig: """Test loading config with invalid float values.""" os.environ["QUALITY_DUP_THRESHOLD"] = "not_a_float" - with pytest.raises(ValueError, match="invalid literal"): + with pytest.raises(ValueError, match="could not convert string to float"): QualityConfig.from_env() def test_from_env_with_invalid_int(self): diff --git a/tests/hooks/test_edge_cases.py b/tests/hooks/test_edge_cases.py index cda718f..9524d36 100644 --- a/tests/hooks/test_edge_cases.py +++ b/tests/hooks/test_edge_cases.py @@ -13,6 +13,10 @@ from code_quality_guard import ( ) +def _perm(response: dict) -> str | None: + return response.get("hookSpecificOutput", {}).get("permissionDecision") + + class TestEdgeCases: """Test edge cases and corner conditions.""" @@ -34,7 +38,7 @@ class TestEdgeCases: with patch("code_quality_guard.analyze_code_quality") as mock_analyze: mock_analyze.return_value = {} result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert _perm(result) == "allow" # Should still be called despite large file mock_analyze.assert_called_once() @@ -50,7 +54,7 @@ class TestEdgeCases: } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert _perm(result) == "allow" def test_whitespace_only_content(self): """Test handling of whitespace-only content.""" @@ -64,7 +68,7 @@ class TestEdgeCases: } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert _perm(result) == "allow" def test_malformed_python_syntax(self): """Test handling of syntax errors in Python code.""" @@ -86,11 +90,11 @@ def broken_func( # Should gracefully handle syntax errors result = pretooluse_hook(hook_data, config) - assert result["decision"] in ["allow", "deny", "ask"] - assert ( - "error" in result.get("message", "").lower() - or result["decision"] == "allow" - ) + decision = _perm(result) + assert decision in ["allow", "deny", "ask"] + if decision != "allow": + text = (result.get("reason") or "") + (result.get("systemMessage") or "") + assert "error" in text.lower() def test_unicode_content(self): """Test handling of Unicode characters in code.""" @@ -115,7 +119,7 @@ def greet_世界(): } result = pretooluse_hook(hook_data, config) - assert result["decision"] in ["allow", "deny", "ask"] + assert _perm(result) in ["allow", "deny", "ask"] def test_concurrent_hook_calls(self): """Test thread safety with concurrent calls.""" @@ -137,7 +141,7 @@ def greet_世界(): results.append(result) # All should have the same decision - decisions = [r["decision"] for r in results] + decisions = [_perm(r) for r in results] assert all(d == decisions[0] for d in decisions) def test_missing_tool_input_fields(self): @@ -150,7 +154,7 @@ def greet_世界(): "tool_input": {"content": "def test(): pass"}, } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" # Should handle gracefully + assert _perm(result) == "allow" # Should handle gracefully # Missing content for Write hook_data = { @@ -158,7 +162,7 @@ def greet_世界(): "tool_input": {"file_path": "test.py"}, } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" # Should handle gracefully + assert _perm(result) == "allow" # Should handle gracefully def test_circular_import_detection(self): """Test detection of circular imports.""" @@ -180,8 +184,7 @@ def func_c(): } result = pretooluse_hook(hook_data, config) - # Should not crash on import analysis - assert result["decision"] in ["allow", "deny", "ask"] + assert _perm(result) in ["allow", "deny", "ask"] def test_binary_file_path(self): """Test handling of binary file paths.""" @@ -195,7 +198,7 @@ def func_c(): } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" # Should skip non-Python files + assert _perm(result) == "allow" # Should skip non-Python files def test_null_and_none_values(self): """Test handling of null/None values.""" @@ -210,7 +213,7 @@ def func_c(): }, } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert _perm(result) == "allow" # None as file_path hook_data["tool_input"] = { @@ -218,7 +221,7 @@ def func_c(): "content": "def test(): pass", } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert _perm(result) == "allow" def test_path_traversal_attempts(self): """Test handling of path traversal attempts.""" @@ -239,8 +242,7 @@ def func_c(): }, } result = pretooluse_hook(hook_data, config) - # Should handle without crashing - assert result["decision"] in ["allow", "deny", "ask"] + assert _perm(result) in ["allow", "deny", "ask"] def test_extreme_thresholds(self): """Test with extreme threshold values.""" @@ -264,8 +266,7 @@ def func_c(): }, } result = pretooluse_hook(hook_data, config) - # With threshold 0, everything should be flagged - assert result["decision"] == "deny" + assert _perm(result) == "deny" # Maximum thresholds config = QualityConfig( @@ -282,8 +283,7 @@ def func_c(): }, } result = pretooluse_hook(hook_data, config) - # With very high thresholds and permissive mode, should pass with warning - assert result["decision"] == "allow" + assert _perm(result) == "allow" def test_subprocess_timeout(self): """Test handling of subprocess timeouts.""" @@ -345,8 +345,7 @@ def func_c(): mock_read.side_effect = PermissionError("Access denied") result = posttooluse_hook(hook_data, config) - # Should handle permission errors gracefully - assert result["decision"] == "allow" + assert "decision" not in result def test_deeply_nested_code_structure(self): """Test handling of deeply nested code.""" @@ -367,8 +366,7 @@ def func_c(): } result = pretooluse_hook(hook_data, config) - # Should handle without stack overflow - assert result["decision"] in ["allow", "deny", "ask"] + assert _perm(result) in ["allow", "deny", "ask"] def test_recursive_function_detection(self): """Test detection of recursive functions.""" @@ -392,8 +390,7 @@ def infinite_recursion(): } result = pretooluse_hook(hook_data, config) - # Should handle recursive functions - assert result["decision"] in ["allow", "deny", "ask"] + assert _perm(result) in ["allow", "deny", "ask"] def test_multifile_edit_paths(self): """Test MultiEdit with multiple file paths.""" @@ -413,7 +410,7 @@ def infinite_recursion(): with patch("code_quality_guard.analyze_code_quality") as mock_analyze: mock_analyze.return_value = {} result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert _perm(result) == "allow" # Should concatenate all new_strings call_args = mock_analyze.call_args[0][0] assert "func1" in call_args @@ -500,5 +497,4 @@ def cached_func(): } result = pretooluse_hook(hook_data, config) - # Should handle modern Python syntax - assert result["decision"] in ["allow", "deny", "ask"] + assert _perm(result) in ["allow", "deny", "ask"] diff --git a/tests/hooks/test_integration.py b/tests/hooks/test_integration.py index 90f3ad0..fb0dd6f 100644 --- a/tests/hooks/test_integration.py +++ b/tests/hooks/test_integration.py @@ -11,7 +11,7 @@ class TestHookIntegration: """Test complete hook integration scenarios.""" def test_main_entry_pretooluse(self): - """Test main entry point detects PreToolUse.""" + """Ensure main dispatches to PreToolUse.""" from code_quality_guard import main hook_input = { @@ -22,19 +22,24 @@ class TestHookIntegration: }, } - with patch("sys.stdin") as mock_stdin: - with patch("builtins.print"): - mock_stdin.read.return_value = json.dumps(hook_input) - mock_stdin.__iter__.return_value = [json.dumps(hook_input)] + with patch("sys.stdin") as mock_stdin, patch("builtins.print"): + mock_stdin.read.return_value = json.dumps(hook_input) + mock_stdin.__iter__.return_value = [json.dumps(hook_input)] - with patch("json.load", return_value=hook_input): - with patch("code_quality_guard.pretooluse_hook") as mock_pre: - mock_pre.return_value = {"decision": "allow"} - main() - mock_pre.assert_called_once() + with patch("json.load", return_value=hook_input), patch( + "code_quality_guard.pretooluse_hook", + return_value={ + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "allow", + }, + }, + ) as mock_pre: + main() + mock_pre.assert_called_once() def test_main_entry_posttooluse(self): - """Test main entry point detects PostToolUse.""" + """Ensure main dispatches to PostToolUse.""" from code_quality_guard import main hook_input = { @@ -45,39 +50,47 @@ class TestHookIntegration: }, } - with patch("sys.stdin") as mock_stdin: - with patch("builtins.print"): - mock_stdin.read.return_value = json.dumps(hook_input) - mock_stdin.__iter__.return_value = [json.dumps(hook_input)] + with patch("sys.stdin") as mock_stdin, patch("builtins.print"): + mock_stdin.read.return_value = json.dumps(hook_input) + mock_stdin.__iter__.return_value = [json.dumps(hook_input)] - with patch("json.load", return_value=hook_input): - with patch("code_quality_guard.posttooluse_hook") as mock_post: - mock_post.return_value = {"decision": "allow"} - main() - mock_post.assert_called_once() + with patch("json.load", return_value=hook_input), patch( + "code_quality_guard.posttooluse_hook", + return_value={ + "hookSpecificOutput": { + "hookEventName": "PostToolUse", + }, + "decision": "approve", + }, + ) as mock_post: + main() + mock_post.assert_called_once() def test_main_invalid_json(self): - """Test main handles invalid JSON input.""" + """Invalid JSON falls back to allow.""" from code_quality_guard import main - with patch("sys.stdin"): - with patch("builtins.print") as mock_print: - with patch( - "json.load", - side_effect=json.JSONDecodeError("test", "test", 0), - ): - main() + with patch("sys.stdin"), patch("builtins.print") as mock_print, patch( + "sys.stdout.write", + ) as mock_write: + with patch( + "json.load", + side_effect=json.JSONDecodeError("test", "test", 0), + ): + main() - # Should print allow decision - printed = mock_print.call_args[0][0] - response = json.loads(printed) - assert response["decision"] == "allow" + printed = ( + mock_print.call_args[0][0] + if mock_print.call_args + else mock_write.call_args[0][0] + ) + response = json.loads(printed) + assert response["hookSpecificOutput"]["permissionDecision"] == "allow" def test_full_flow_clean_code(self, clean_code): - """Test full flow with clean code.""" + """Clean code should pass both hook stages.""" from code_quality_guard import main - # PreToolUse pre_input = { "tool_name": "Write", "tool_input": { @@ -86,69 +99,66 @@ class TestHookIntegration: }, } - with patch("sys.stdin"): - with patch("builtins.print") as mock_print: - with patch("json.load", return_value=pre_input): - main() + with patch("sys.stdin"), patch("builtins.print") as mock_print: + with patch("json.load", return_value=pre_input), patch( + "code_quality_guard.analyze_code_quality", + return_value={}, + ): + main() - printed = mock_print.call_args[0][0] - response = json.loads(printed) - assert response["decision"] == "allow" + response = json.loads(mock_print.call_args[0][0]) + assert response["hookSpecificOutput"]["permissionDecision"] == "allow" - # Simulate file write test_file = Path(f"{tempfile.gettempdir()}/clean.py") test_file.write_text(clean_code) - # PostToolUse post_input = { "tool_name": "Write", "tool_output": { - "file_path": f"{tempfile.gettempdir()}/clean.py", + "file_path": str(test_file), "status": "success", }, } - with patch("sys.stdin"): - with patch("builtins.print") as mock_print: + os.environ["QUALITY_SHOW_SUCCESS"] = "true" + try: + with patch("sys.stdin"), patch("builtins.print") as mock_print: with patch("json.load", return_value=post_input): - os.environ["QUALITY_SHOW_SUCCESS"] = "true" main() - printed = mock_print.call_args[0][0] - response = json.loads(printed) - assert response["decision"] == "allow" - assert "passed" in response.get("message", "").lower() - - test_file.unlink(missing_ok=True) + response = json.loads(mock_print.call_args[0][0]) + assert response.get("decision") == "approve" + assert "passed" in response.get("systemMessage", "").lower() + finally: + os.environ.pop("QUALITY_SHOW_SUCCESS", None) + test_file.unlink(missing_ok=True) def test_environment_configuration_flow(self): - """Test that environment variables are properly used.""" + """Environment settings change enforcement.""" from code_quality_guard import main - # Set strict environment - os.environ.update( - { - "QUALITY_ENFORCEMENT": "strict", - "QUALITY_COMPLEXITY_THRESHOLD": "5", # Very low threshold - "QUALITY_DUP_ENABLED": "false", - "QUALITY_COMPLEXITY_ENABLED": "true", # Keep complexity enabled - "QUALITY_MODERN_ENABLED": "false", - }, - ) + env_overrides = { + "QUALITY_ENFORCEMENT": "strict", + "QUALITY_COMPLEXITY_THRESHOLD": "5", + "QUALITY_DUP_ENABLED": "false", + "QUALITY_COMPLEXITY_ENABLED": "true", + "QUALITY_MODERN_ENABLED": "false", + } + os.environ.update(env_overrides) complex_code = """ -def complex_func(a, b, c): - if a: - if b: - if c: - return 1 + def complex_func(a, b, c): + if a: + if b: + if c: + return 1 + else: + return 2 + else: + return 3 else: - return 2 - else: - return 3 - else: - return 4 -""" + return 4 + """ hook_input = { "tool_name": "Write", @@ -158,38 +168,36 @@ def complex_func(a, b, c): }, } - with patch("sys.stdin"): - with patch("builtins.print") as mock_print: - with patch("json.load", return_value=hook_input): - with patch( - "code_quality_guard.analyze_code_quality", - ) as mock_analyze: - # Mock the complexity analysis result - mock_analyze.return_value = { - "complexity": { - "summary": { - "average_cyclomatic_complexity": 8, - }, # Above threshold - "distribution": {"High": 1}, - }, - } - try: - main() - msg = "Expected SystemExit" - raise AssertionError(msg) - except SystemExit as e: - assert e.code == 2, "Expected exit code 2 for deny" # noqa: PT017 + try: + with patch("sys.stdin"), patch("builtins.print") as mock_print: + with patch("json.load", return_value=hook_input), patch( + "code_quality_guard.analyze_code_quality", + return_value={ + "complexity": { + "summary": {"average_cyclomatic_complexity": 8}, + "distribution": {"High": 1}, + }, + }, + ): + try: + main() + raise AssertionError("Expected SystemExit") + except SystemExit as exc: + assert exc.code == 2 - printed = mock_print.call_args[0][0] - response = json.loads(printed) - # Should be denied due to low complexity threshold - assert response["decision"] == "deny" + response = json.loads(mock_print.call_args[0][0]) + assert ( + response["hookSpecificOutput"]["permissionDecision"] + == "deny" + ) + finally: + for key in env_overrides: + os.environ.pop(key, None) def test_skip_patterns_integration(self): - """Test skip patterns work in integration.""" + """Skip patterns should bypass checks.""" from code_quality_guard import main - # Test file should be skipped hook_input = { "tool_name": "Write", "tool_input": { @@ -198,72 +206,62 @@ def complex_func(a, b, c): }, } - with patch("sys.stdin"): - with patch("builtins.print") as mock_print: - with patch("json.load", return_value=hook_input): - main() + with patch("sys.stdin"), patch("builtins.print") as mock_print: + with patch("json.load", return_value=hook_input): + main() - printed = mock_print.call_args[0][0] - response = json.loads(printed) - assert response["decision"] == "allow" + response = json.loads(mock_print.call_args[0][0]) + assert response["hookSpecificOutput"]["permissionDecision"] == "allow" def test_state_tracking_flow(self, temp_python_file): - """Test state tracking between pre and post.""" + """State tracking should flag regressions.""" from code_quality_guard import main os.environ["QUALITY_STATE_TRACKING"] = "true" + try: + pre_input = { + "tool_name": "Write", + "tool_input": { + "file_path": str(temp_python_file), + "content": "def func1(): pass\ndef func2(): pass\ndef func3(): pass", + }, + } - # PreToolUse - store state - initial_content = "def func1(): pass\ndef func2(): pass\ndef func3(): pass" - pre_input = { - "tool_name": "Write", - "tool_input": { - "file_path": str(temp_python_file), - "content": initial_content, - }, - } - - with patch("sys.stdin"): - with patch("builtins.print") as mock_print: - with patch("json.load", return_value=pre_input): + with patch("sys.stdin"), patch("builtins.print") as mock_print: + with patch("json.load", return_value=pre_input), patch( + "code_quality_guard.analyze_code_quality", + return_value={}, + ): main() - # Simulate file modification (fewer functions) - modified_content = "def func1(): pass" - temp_python_file.write_text(modified_content) + temp_python_file.write_text("def func1(): pass") - # PostToolUse - check state - post_input = { - "tool_name": "Write", - "tool_output": { - "file_path": str(temp_python_file), - "status": "success", - }, - } + post_input = { + "tool_name": "Write", + "tool_output": { + "file_path": str(temp_python_file), + "status": "success", + }, + } - with patch("sys.stdin"): - with patch("builtins.print") as mock_print: + with patch("sys.stdin"), patch("builtins.print") as mock_print: with patch("json.load", return_value=post_input): main() - printed = mock_print.call_args[0][0] - response = json.loads(printed) - assert response["decision"] == "allow" - # Should detect function reduction - if "message" in response: - assert ( - "reduced" in response["message"].lower() - or len(response["message"]) == 0 - ) + response = json.loads(mock_print.call_args[0][0]) + assert response["decision"] == "block" + assert "reduced" in response["reason"].lower() + finally: + os.environ.pop("QUALITY_STATE_TRACKING", None) def test_cross_tool_handling(self): - """Test different tools are handled correctly.""" + """Supported tools should respond with allow.""" from code_quality_guard import main tools = ["Write", "Edit", "MultiEdit", "Read", "Bash", "Task"] for tool in tools: - if tool in ["Write", "Edit", "MultiEdit"]: + if tool in {"Write", "Edit", "MultiEdit"}: hook_input = { "tool_name": tool, "tool_input": { @@ -272,22 +270,20 @@ def complex_func(a, b, c): }, } else: - hook_input = { - "tool_name": tool, - "tool_input": {}, - } + hook_input = {"tool_name": tool, "tool_input": {}} - with patch("sys.stdin"): - with patch("builtins.print") as mock_print: - with patch("json.load", return_value=hook_input): - main() + with patch("sys.stdin"), patch("builtins.print") as mock_print: + with patch("json.load", return_value=hook_input), patch( + "code_quality_guard.analyze_code_quality", + return_value={}, + ): + main() - printed = mock_print.call_args[0][0] - response = json.loads(printed) - assert response["decision"] == "allow" + response = json.loads(mock_print.call_args[0][0]) + assert response["hookSpecificOutput"]["permissionDecision"] == "allow" def test_enforcement_mode_progression(self, complex_code): - """Test progression through enforcement modes.""" + """Strict/warn/permissive modes map to deny/ask/allow.""" from code_quality_guard import main hook_input = { @@ -298,33 +294,40 @@ def complex_func(a, b, c): }, } - modes_and_decisions = [ + scenarios = [ ("strict", "deny"), ("warn", "ask"), ("permissive", "allow"), ] - for mode, expected_decision in modes_and_decisions: + for mode, expected in scenarios: os.environ["QUALITY_ENFORCEMENT"] = mode os.environ["QUALITY_COMPLEXITY_THRESHOLD"] = "10" - - with patch("sys.stdin"): - with patch("builtins.print") as mock_print: - with patch("json.load", return_value=hook_input): - if expected_decision in ["deny", "ask"]: - # Expect SystemExit with code 2 for deny/ask decisions + try: + with patch("sys.stdin"), patch("builtins.print") as mock_print: + with patch("json.load", return_value=hook_input), patch( + "code_quality_guard.analyze_code_quality", + return_value={ + "complexity": { + "summary": {"average_cyclomatic_complexity": 25}, + "distribution": {"High": 1}, + }, + }, + ): + if expected in {"deny", "ask"}: try: main() - msg = f"Expected SystemExit for {mode} mode" - raise AssertionError(msg) - except SystemExit as e: - assert e.code == 2, ( # noqa: PT017 - f"Expected exit code 2 for {mode} mode" - ) + raise AssertionError("Expected SystemExit") + except SystemExit as exc: + assert exc.code == 2 else: - # Permissive mode should not exit main() - printed = mock_print.call_args[0][0] - response = json.loads(printed) - assert response["decision"] == expected_decision + response = json.loads(mock_print.call_args[0][0]) + assert ( + response["hookSpecificOutput"]["permissionDecision"] + == expected + ) + finally: + os.environ.pop("QUALITY_ENFORCEMENT", None) + os.environ.pop("QUALITY_COMPLEXITY_THRESHOLD", None) diff --git a/tests/hooks/test_posttooluse.py b/tests/hooks/test_posttooluse.py index 8597c01..864d9a3 100644 --- a/tests/hooks/test_posttooluse.py +++ b/tests/hooks/test_posttooluse.py @@ -18,7 +18,8 @@ class TestPostToolUseHook: } result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert result["hookSpecificOutput"]["hookEventName"] == "PostToolUse" + assert "decision" not in result def test_file_path_extraction_dict(self): """Test file path extraction from dict output.""" @@ -34,7 +35,7 @@ class TestPostToolUseHook: with patch("pathlib.Path.exists", return_value=True): with patch("pathlib.Path.read_text", return_value="def test(): pass"): result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert "decision" not in result # Test with path key hook_data["tool_output"] = {"path": test_file} @@ -42,7 +43,7 @@ class TestPostToolUseHook: with patch("pathlib.Path.exists", return_value=True): with patch("pathlib.Path.read_text", return_value="def test(): pass"): result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert "decision" not in result def test_file_path_extraction_string(self): """Test file path extraction from string output.""" @@ -55,7 +56,7 @@ class TestPostToolUseHook: with patch("pathlib.Path.exists", return_value=True): with patch("pathlib.Path.read_text", return_value="def test(): pass"): result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert "decision" not in result def test_non_python_file_skipped(self): """Test that non-Python files are skipped.""" @@ -66,7 +67,7 @@ class TestPostToolUseHook: } result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert "decision" not in result def test_nonexistent_file_skipped(self): """Test that nonexistent files are skipped.""" @@ -78,7 +79,7 @@ class TestPostToolUseHook: with patch("pathlib.Path.exists", return_value=False): result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert "decision" not in result def test_state_tracking_degradation(self): """Test state tracking detects quality degradation.""" @@ -97,9 +98,10 @@ class TestPostToolUseHook: ] result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" - assert "Post-write quality notes" in result["message"] - assert "Reduced functions" in result["message"] + assert result["decision"] == "block" + reason_text = result["reason"].lower() + assert "post-write quality notes" in reason_text + assert "reduced functions" in reason_text def test_cross_file_duplicates(self): """Test cross-file duplicate detection.""" @@ -117,8 +119,8 @@ class TestPostToolUseHook: mock_check.return_value = ["⚠️ Cross-file duplication detected"] result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" - assert "Cross-file duplication" in result["message"] + assert result["decision"] == "block" + assert "cross-file duplication" in result["reason"].lower() def test_naming_convention_violations(self, non_pep8_code): """Test naming convention verification.""" @@ -131,9 +133,9 @@ class TestPostToolUseHook: with patch("pathlib.Path.exists", return_value=True): with patch("pathlib.Path.read_text", return_value=non_pep8_code): result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" - assert "Non-PEP8 function names" in result["message"] - assert "Non-PEP8 class names" in result["message"] + assert result["decision"] == "block" + assert "non-pep8 function names" in result["reason"].lower() + assert "non-pep8 class names" in result["reason"].lower() def test_show_success_message(self, clean_code): """Test success message when enabled.""" @@ -146,8 +148,8 @@ class TestPostToolUseHook: with patch("pathlib.Path.exists", return_value=True): with patch("pathlib.Path.read_text", return_value=clean_code): result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" - assert "passed post-write verification" in result["message"] + assert result["decision"] == "approve" + assert "passed post-write verification" in result["systemMessage"].lower() def test_no_message_when_success_disabled(self, clean_code): """Test no message when show_success is disabled.""" @@ -160,8 +162,8 @@ class TestPostToolUseHook: with patch("pathlib.Path.exists", return_value=True): with patch("pathlib.Path.read_text", return_value=clean_code): result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" - assert "message" not in result + assert "decision" not in result + assert "systemMessage" not in result def test_all_features_combined(self): """Test all PostToolUse features combined.""" @@ -190,10 +192,11 @@ class TestPostToolUseHook: mock_naming.return_value = ["⚠️ Issue 3"] result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" - assert "Issue 1" in result["message"] - assert "Issue 2" in result["message"] - assert "Issue 3" in result["message"] + assert result["decision"] == "block" + reason_text = result["reason"].lower() + assert "issue 1" in reason_text + assert "issue 2" in reason_text + assert "issue 3" in reason_text def test_edit_tool_output(self): """Test Edit tool output handling.""" @@ -209,7 +212,7 @@ class TestPostToolUseHook: with patch("pathlib.Path.exists", return_value=True): with patch("pathlib.Path.read_text", return_value="def test(): pass"): result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert "decision" not in result def test_multiedit_tool_output(self): """Test MultiEdit tool output handling.""" @@ -225,7 +228,7 @@ class TestPostToolUseHook: with patch("pathlib.Path.exists", return_value=True): with patch("pathlib.Path.read_text", return_value="def test(): pass"): result = posttooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert "decision" not in result def test_features_disabled(self): """Test with all features disabled.""" @@ -257,5 +260,5 @@ class TestPostToolUseHook: mock_cross.assert_not_called() mock_naming.assert_not_called() - assert result["decision"] == "allow" - assert "message" not in result + assert "decision" not in result + assert "systemMessage" not in result diff --git a/tests/hooks/test_pretooluse.py b/tests/hooks/test_pretooluse.py index 9efea81..9fec21d 100644 --- a/tests/hooks/test_pretooluse.py +++ b/tests/hooks/test_pretooluse.py @@ -17,7 +17,7 @@ class TestPreToolUseHook: } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert result["permissionDecision"] == "allow" def test_non_python_file_allowed(self): """Test that non-Python files are always allowed.""" @@ -31,7 +31,7 @@ class TestPreToolUseHook: } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert result["permissionDecision"] == "allow" def test_test_file_skipped(self): """Test that test files are skipped when configured.""" @@ -45,7 +45,7 @@ class TestPreToolUseHook: } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert result["permissionDecision"] == "allow" def test_clean_code_allowed(self, clean_code): """Test that clean code is allowed.""" @@ -61,7 +61,7 @@ class TestPreToolUseHook: with patch("code_quality_guard.analyze_code_quality") as mock_analyze: mock_analyze.return_value = {} result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert result["permissionDecision"] == "allow" def test_complex_code_denied_strict(self, complex_code): """Test that complex code is denied in strict mode.""" @@ -83,8 +83,8 @@ class TestPreToolUseHook: } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "deny" - assert "quality check failed" in result["message"].lower() + assert result["permissionDecision"] == "deny" + assert "quality check failed" in result["reason"].lower() def test_complex_code_ask_warn_mode(self, complex_code): """Test that complex code triggers ask in warn mode.""" @@ -106,7 +106,7 @@ class TestPreToolUseHook: } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "ask" + assert result["permissionDecision"] == "ask" def test_complex_code_allowed_permissive(self, complex_code): """Test that complex code is allowed with warning in permissive mode.""" @@ -128,8 +128,8 @@ class TestPreToolUseHook: } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" - assert "warning" in result["message"].lower() + assert result["permissionDecision"] == "allow" + assert "warning" in result.get("reason", "").lower() def test_duplicate_code_detection(self, duplicate_code): """Test internal duplicate detection.""" @@ -162,8 +162,8 @@ class TestPreToolUseHook: } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "deny" - assert "duplication" in result["message"].lower() + assert result["permissionDecision"] == "deny" + assert "duplication" in result["reason"].lower() def test_edit_tool_handling(self): """Test Edit tool content extraction.""" @@ -180,7 +180,7 @@ class TestPreToolUseHook: with patch("code_quality_guard.analyze_code_quality") as mock_analyze: mock_analyze.return_value = {} result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert result["permissionDecision"] == "allow" # Verify new_string was analyzed call_args = mock_analyze.call_args[0] @@ -203,7 +203,7 @@ class TestPreToolUseHook: with patch("code_quality_guard.analyze_code_quality") as mock_analyze: mock_analyze.return_value = {} result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert result["permissionDecision"] == "allow" # Verify concatenated content was analyzed call_args = mock_analyze.call_args[0] @@ -245,8 +245,8 @@ class TestPreToolUseHook: mock_analyze.side_effect = Exception("Analysis failed") result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" - assert "error" in result["message"].lower() + assert result["permissionDecision"] == "allow" + assert "error" in result.get("reason", "").lower() def test_custom_skip_patterns(self): """Test custom skip patterns.""" @@ -261,12 +261,12 @@ class TestPreToolUseHook: }, } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert result["permissionDecision"] == "allow" # Test path pattern match hook_data["tool_input"]["file_path"] = "/ignored/file.py" result = pretooluse_hook(hook_data, config) - assert result["decision"] == "allow" + assert result["permissionDecision"] == "allow" def test_modernization_issues(self, old_style_code): """Test modernization issue detection.""" @@ -292,8 +292,8 @@ class TestPreToolUseHook: } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "deny" - assert "modernization" in result["message"].lower() + assert result["permissionDecision"] == "deny" + assert "modernization" in result["reason"].lower() def test_type_hint_threshold(self): """Test type hint issue threshold.""" @@ -320,5 +320,51 @@ class TestPreToolUseHook: } result = pretooluse_hook(hook_data, config) - assert result["decision"] == "deny" - assert "type hints" in result["message"].lower() + assert result["permissionDecision"] == "deny" + assert "type hints" in result["reason"].lower() + + def test_any_usage_denied(self): + """Test that typing.Any usage triggers a denial.""" + config = QualityConfig(enforcement_mode="strict") + hook_data = { + "tool_name": "Write", + "tool_input": { + "file_path": "example.py", + "content": "from typing import Any\n\n" + "def example(value: Any) -> None:\n pass\n", + }, + } + + with patch("code_quality_guard.analyze_code_quality") as mock_analyze: + mock_analyze.return_value = {} + + result = pretooluse_hook(hook_data, config) + assert result["permissionDecision"] == "ask" + assert "any" in result["reason"].lower() + + def test_any_usage_detected_in_multiedit(self): + """Test that MultiEdit content is scanned for typing.Any usage.""" + config = QualityConfig() + hook_data = { + "tool_name": "MultiEdit", + "tool_input": { + "file_path": "example.py", + "edits": [ + { + "old_string": "pass", + "new_string": "from typing import Any\n", + }, + { + "old_string": "pass", + "new_string": "def handler(arg: Any) -> str:\n return str(arg)\n", + }, + ], + }, + } + + with patch("code_quality_guard.analyze_code_quality") as mock_analyze: + mock_analyze.return_value = {} + + result = pretooluse_hook(hook_data, config) + assert result["permissionDecision"] == "ask" + assert "any" in result["reason"].lower()