Files
claude-scripts/hooks/bash_command_guard.py
Travis Vasceannie f3832bdf3d feat: enhance test coverage and improve code quality checks
- Updated test files to improve coverage for Any usage, type: ignore, and old typing patterns, ensuring that these patterns are properly blocked.
- Refactored test structure for better clarity and maintainability, including the introduction of fixtures and improved assertions.
- Enhanced error handling and messaging in the hook system to provide clearer feedback on violations.
- Improved integration tests to validate hook behavior across various scenarios, ensuring robustness and reliability.
2025-10-08 09:10:32 +00:00

451 lines
14 KiB
Python

"""Shell command guard for Claude Code PreToolUse/PostToolUse hooks.
Prevents circumvention of type safety rules via shell commands that could inject
'Any' types or type ignore comments into Python files.
"""
import json
import re
import subprocess
import sys
from collections.abc import Iterable
from pathlib import Path
from shutil import which
from typing import TypedDict

from .bash_guard_constants import (
    DANGEROUS_SHELL_PATTERNS,
    FORBIDDEN_PATTERNS,
    PYTHON_FILE_PATTERNS,
)
class JsonObject(TypedDict, total=False):
    """Shape of the JSON document emitted by this hook.

    Every key is optional (``total=False``); responses carry only the keys
    that apply to them.
    """

    # Name of the hook event the response belongs to.
    hookEventName: str
    # Permission outcome for a tool call and the explanation for it.
    permissionDecision: str
    permissionDecisionReason: str
    # Decision/reason pair (e.g. "block"/"approve") for PostToolUse and Stop.
    decision: str
    reason: str
    # Free-form system message accompanying the response.
    systemMessage: str
    # Nested, event-specific payload.
    hookSpecificOutput: dict[str, object]
def _contains_forbidden_pattern(text: str) -> tuple[bool, str | None]:
"""Check if text contains any forbidden patterns.
Args:
text: The text to check for forbidden patterns.
Returns:
Tuple of (has_violation, matched_pattern_description)
"""
for pattern in FORBIDDEN_PATTERNS:
if re.search(pattern, text, re.IGNORECASE):
if "Any" in pattern:
return True, "typing.Any usage"
if "type.*ignore" in pattern:
return True, "type suppression comment"
return False, None
def _is_dangerous_shell_command(command: str) -> tuple[bool, str | None]:
"""Check if shell command uses dangerous patterns.
Args:
command: The shell command to analyze.
Returns:
Tuple of (is_dangerous, reason)
"""
# Check if command targets Python files
targets_python = any(
re.search(pattern, command) for pattern in PYTHON_FILE_PATTERNS
)
if not targets_python:
return False, None
# Check for dangerous shell patterns
for pattern in DANGEROUS_SHELL_PATTERNS:
if re.search(pattern, command):
tool_match = re.search(
r"\b(sed|awk|perl|ed|echo|printf|cat|tee|find|xargs|python|vim|nano|emacs)\b",
pattern,
)
tool_name = tool_match.group(1) if tool_match else "shell utility"
return True, f"Use of {tool_name} to modify Python files"
return False, None
# Escape sequences understood by printf / echo -e / bash $'...' quoting.
_HEX_ESCAPE_RE = re.compile(r"\\x([0-9A-Fa-f]{1,2})")
_UNICODE_ESCAPE_RE = re.compile(r"\\u([0-9A-Fa-f]{4})")
_OCTAL_ESCAPE_RE = re.compile(r"\\0?([0-7]{1,3})")
# Quotes sitting between two word characters glue fragments together ('A'ny).
_GLUED_QUOTES_RE = re.compile(r"""(?<=\w)['"]+(?=\w)""")


def _decode_shell_escapes(command: str) -> str:
    r"""Undo common escaping/quoting tricks so hidden patterns become visible.

    Handles literal \n and \t, backslash-whitespace continuations, \xHH,
    \uHHHH and octal escapes, and quotes that join word fragments
    (e.g. 'A'ny -> Any).
    """
    decoded = command.replace("\\n", "\n").replace("\\t", "\t")
    decoded = re.sub(r"\\\s", " ", decoded)
    decoded = _HEX_ESCAPE_RE.sub(lambda m: chr(int(m.group(1), 16)), decoded)
    decoded = _UNICODE_ESCAPE_RE.sub(lambda m: chr(int(m.group(1), 16)), decoded)
    decoded = _OCTAL_ESCAPE_RE.sub(lambda m: chr(int(m.group(1), 8)), decoded)
    return _GLUED_QUOTES_RE.sub("", decoded)


def _command_contains_forbidden_injection(command: str) -> tuple[bool, str | None]:
    """Check if command attempts to inject forbidden patterns.

    The raw command is checked first; if it is clean, a decoded form (see
    ``_decode_shell_escapes``) is checked and any hit is labelled "(escaped)".

    Args:
        command: The shell command to analyze.

    Returns:
        Tuple of (has_injection, violation_description)
    """
    has_violation, violation_type = _contains_forbidden_pattern(command)
    if has_violation:
        return True, violation_type

    has_violation, violation_type = _contains_forbidden_pattern(
        _decode_shell_escapes(command),
    )
    if has_violation:
        return True, f"{violation_type} (escaped)"
    return False, None
def _analyze_bash_command(command: str) -> tuple[bool, list[str]]:
    """Run every shell-safety check against a Bash command.

    Two checks run in order: forbidden-pattern injection, then dangerous
    shell usage on Python files.

    Args:
        command: Command line about to be executed.

    Returns:
        ``(should_block, violations)``: one message per failed check, and
        ``should_block`` is true exactly when that list is non-empty.
    """
    findings: list[str] = []

    injected, injection_kind = _command_contains_forbidden_injection(command)
    if injected:
        findings.append(f"⛔ Shell command attempts to inject {injection_kind}")

    risky, risk_reason = _is_dangerous_shell_command(command)
    if risky:
        findings.append(
            f"{risk_reason} is forbidden - use Edit/Write tools instead",
        )

    return bool(findings), findings
def _create_hook_response(
    event_name: str,
    permission: str = "",
    reason: str = "",
    system_message: str = "",
    *,
    decision: str | None = None,
) -> JsonObject:
    """Assemble the JSON payload a hook writes to stdout.

    Empty or ``None`` values are omitted rather than emitted as blanks.

    Args:
        event_name: Event label stored as ``hookSpecificOutput.hookEventName``.
        permission: Permission decision; stored in the nested object and
            mirrored as the top-level ``permissionDecision``.
        reason: Explanation; stored as the nested ``permissionDecisionReason``
            and as the top-level ``reason``.
        system_message: Value for the top-level ``systemMessage`` key.
        decision: Value for the top-level ``decision`` key.

    Returns:
        The assembled response object.
    """
    optional_nested: dict[str, object] = {
        "permissionDecision": permission,
        "permissionDecisionReason": reason,
    }
    nested: dict[str, object] = {"hookEventName": event_name}
    nested.update(
        {key: value for key, value in optional_nested.items() if value},
    )

    payload = JsonObject(hookSpecificOutput=nested)
    if permission:
        payload["permissionDecision"] = permission
    if decision:
        payload["decision"] = decision
    if reason:
        payload["reason"] = reason
    if system_message:
        payload["systemMessage"] = system_message
    return payload
def pretooluse_bash_hook(hook_data: dict[str, object]) -> JsonObject:
    """Decide whether a pending Bash tool call may run.

    Args:
        hook_data: PreToolUse payload; only ``tool_name`` and
            ``tool_input["command"]`` are read.

    Returns:
        A ``deny`` response listing every violation found by
        ``_analyze_bash_command``. Otherwise an ``allow`` response, which is
        also returned for non-Bash tools, a non-dict ``tool_input`` and an
        empty command.
    """
    # NOTE(review): an explicit "allow" may auto-approve the call and skip the
    # normal permission prompt; confirm this is intended for a guard hook.
    if str(hook_data.get("tool_name", "")) != "Bash":
        return _create_hook_response("PreToolUse", "allow")

    raw_input = hook_data.get("tool_input", {})
    command = (
        str(dict(raw_input).get("command", ""))
        if isinstance(raw_input, dict)
        else ""
    )
    if not command:
        return _create_hook_response("PreToolUse", "allow")

    blocked, problems = _analyze_bash_command(command)
    if not blocked:
        return _create_hook_response("PreToolUse", "allow")

    # Long commands are cut to 200 characters in the message.
    shown = command[:200]
    if len(command) > 200:
        shown += "..."
    bullet_list = "\n".join(f" {problem}" for problem in problems)
    denial = (
        "🚫 Shell Command Blocked\n\n"
        f"Violations:\n{bullet_list}\n\n"
        f"Command: {shown}\n\n"
        "Use Edit/Write tools to modify Python files with proper type safety."
    )
    return _create_hook_response("PreToolUse", "deny", denial, denial)
# A .py/.pyi path token. Quotes, backticks, redirects, pipes, separators,
# parentheses and "=" are excluded so that '"src/a.py"', '>a.py' and
# '--out=a.py' all yield the bare path.
_PYTHON_PATH_RE = re.compile(r"""([^\s'"`<>|;&()=]+\.pyi?)\b""")


def _existing_python_files(command: str) -> list[str]:
    """Return the distinct existing .py/.pyi files named in ``command``, in order."""
    found: list[str] = []
    candidates = (match.group(1) for match in _PYTHON_PATH_RE.finditer(command))
    for candidate in dict.fromkeys(candidates):
        try:
            if Path(candidate).is_file():
                found.append(candidate)
        except (OSError, ValueError):
            # Names the OS rejects (e.g. too long) are treated as absent.
            continue
    return found


def posttooluse_bash_hook(hook_data: dict[str, object]) -> JsonObject:
    """Handle PostToolUse hook for Bash commands.

    Scans every existing Python file mentioned in the executed command for
    forbidden patterns.

    Args:
        hook_data: Hook payload; ``tool_name`` and ``tool_input["command"]``
            are read.

    Returns:
        A ``decision="block"`` response listing offending files, or a bare
        PostToolUse response when nothing was found.
    """
    if str(hook_data.get("tool_name", "")) != "Bash":
        return _create_hook_response("PostToolUse")

    tool_input_raw = hook_data.get("tool_input", {})
    if not isinstance(tool_input_raw, dict):
        return _create_hook_response("PostToolUse")
    tool_input: dict[str, object] = dict(tool_input_raw)
    command = str(tool_input.get("command", ""))

    python_files = _existing_python_files(command)
    if not python_files:
        return _create_hook_response("PostToolUse")

    violations: list[str] = []
    for file_path in python_files:
        try:
            content = Path(file_path).read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            # Unreadable or non-UTF-8 files are skipped (best effort).
            continue
        has_violation, violation_type = _contains_forbidden_pattern(content)
        if has_violation:
            violations.append(
                f"⛔ File '{Path(file_path).name}' contains {violation_type}",
            )

    if not violations:
        return _create_hook_response("PostToolUse")

    violation_text = "\n".join(f" {v}" for v in violations)
    message = (
        f"🚫 Post-Execution Violation Detected\n\n"
        f"Violations:\n{violation_text}\n\n"
        f"Shell command introduced forbidden patterns. "
        f"Please revert changes and use proper typing."
    )
    return _create_hook_response(
        "PostToolUse",
        "",
        message,
        message,
        decision="block",
    )
def stop_hook(_hook_data: dict[str, object]) -> JsonObject:
    """Scan staged Python files one last time before the session stops.

    Args:
        _hook_data: Stop event payload; not consulted.

    Returns:
        ``decision="block"`` listing ``path: violation`` entries when any
        staged ``.py``/``.pyi`` file contains a forbidden pattern; otherwise
        ``decision="approve"``. Missing git, a failing git call, or an
        OS/subprocess error also approve (the latter with a warning).
    """
    try:
        git_executable = which("git")
        if git_executable is None:
            return _create_hook_response("Stop", decision="approve")

        # NOTE(review): "--cached" lists staged changes only; unstaged edits
        # are not examined.
        # Safe: invokes git with fixed arguments, no user input interpolation.
        completed = subprocess.run(  # noqa: S603
            [git_executable, "diff", "--name-only", "--cached"],
            capture_output=True,
            text=True,
            check=False,
            timeout=10,
        )
        if completed.returncode != 0:
            # git itself failed, e.g. the working directory is not a repository.
            return _create_hook_response("Stop", decision="approve")

        stripped_names = (line.strip() for line in completed.stdout.split("\n"))
        staged_python = [
            name for name in stripped_names if name and name.endswith((".py", ".pyi"))
        ]
        if not staged_python:
            return _create_hook_response("Stop", decision="approve")

        findings: list[str] = []
        for staged_path in staged_python:
            # Paths are checked relative to the current working directory.
            if not Path(staged_path).exists():
                continue
            try:
                with open(staged_path, encoding="utf-8") as handle:
                    source_text = handle.read()
                flagged, label = _contains_forbidden_pattern(source_text)
                if flagged:
                    findings.append(f"{staged_path}: {label}")
            except (OSError, UnicodeDecodeError):
                continue

        if not findings:
            return _create_hook_response("Stop", decision="approve")

        listing = "\n".join(f" {entry}" for entry in findings)
        summary = (
            "🚫 Final Validation Failed\n\n"
            f"Violations:\n{listing}\n\n"
            "Please remove forbidden patterns before completing."
        )
        return _create_hook_response(
            "Stop",
            "",
            summary,
            summary,
            decision="block",
        )
    except (OSError, subprocess.SubprocessError, TimeoutError) as exc:
        # Validation could not run: approve, but surface the error.
        warning = f"Warning: Final validation error: {exc}"
        return _create_hook_response(
            "Stop",
            "",
            warning,
            warning,
            decision="approve",
        )
def _detect_hook_event(hook_data: dict[str, object]) -> str:
    """Classify a hook payload as "PostToolUse", "Stop" or "PreToolUse"."""
    # Tool-result keys identify PostToolUse payloads.
    if "tool_response" in hook_data or "tool_output" in hook_data:
        return "PostToolUse"
    # Claude Code's hook input names the event under the snake_case key
    # "hook_event_name"; the camelCase spelling is still accepted.
    for key in ("hook_event_name", "hookEventName"):
        event_name = hook_data.get(key)
        if event_name in ("PostToolUse", "Stop"):
            return str(event_name)
    return "PreToolUse"


def main() -> None:
    """Main hook entry point.

    Reads a JSON payload from stdin, dispatches it to the matching handler and
    writes the JSON response to stdout. Exits with status 2 (stderr carrying
    the reason) when the response denies/asks or has ``decision == "block"``,
    and with status 1 on unexpected OS/value/subprocess errors.
    """
    try:
        try:
            payload: object = json.load(sys.stdin)
        except json.JSONDecodeError:
            payload = None
        if not isinstance(payload, dict):
            # Unparseable or non-object input: fail open with a PreToolUse allow.
            fallback_response: JsonObject = {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "allow",
                },
            }
            sys.stdout.write(json.dumps(fallback_response))
            return
        hook_data: dict[str, object] = payload

        event = _detect_hook_event(hook_data)
        response: JsonObject
        if event == "PostToolUse":
            response = posttooluse_bash_hook(hook_data)
        elif event == "Stop":
            response = stop_hook(hook_data)
        else:
            response = pretooluse_bash_hook(hook_data)

        sys.stdout.write(json.dumps(response))

        hook_output = response.get("hookSpecificOutput")
        if isinstance(hook_output, dict):
            permission_decision = hook_output.get("permissionDecision")
            if permission_decision == "deny":
                # Exit code 2: Blocking error
                reason = str(
                    hook_output.get("permissionDecisionReason", "Permission denied"),
                )
                sys.stderr.write(reason)
                sys.exit(2)
            elif permission_decision == "ask":
                # NOTE(review): exit code 2 makes "ask" behave like a block;
                # confirm that is intended.
                reason = str(
                    hook_output.get("permissionDecisionReason", "Permission request"),
                )
                sys.stderr.write(reason)
                sys.exit(2)

        # Checked even when hookSpecificOutput is absent, so a block decision
        # is never dropped.
        if response.get("decision") == "block":
            reason = str(response.get("reason", "Validation failed"))
            sys.stderr.write(reason)
            sys.exit(2)
    except (OSError, ValueError, subprocess.SubprocessError, TimeoutError) as exc:
        # Unexpected error - use exit code 1 (non-blocking error)
        sys.stderr.write(f"Hook error: {exc}")
        sys.exit(1)


if __name__ == "__main__":
    main()