- Updated test files to improve coverage for Any usage, type: ignore, and old typing patterns, ensuring that these patterns are properly blocked. - Refactored test structure for better clarity and maintainability, including the introduction of fixtures and improved assertions. - Enhanced error handling and messaging in the hook system to provide clearer feedback on violations. - Improved integration tests to validate hook behavior across various scenarios, ensuring robustness and reliability.
451 lines
14 KiB
Python
451 lines
14 KiB
Python
"""Shell command guard for Claude Code PreToolUse/PostToolUse/Stop hooks.

Prevents circumvention of type safety rules via shell commands that could
inject 'Any' types or type-ignore comments into Python files:

* PreToolUse: inspects a Bash command before it runs and denies it when it
  injects forbidden patterns or uses shell utilities to modify Python files.
* PostToolUse: re-scans existing Python files named in the executed command.
* Stop: scans staged Python files (``git diff --name-only --cached``).
"""
|
|
|
|
import json
import re
import subprocess
import sys
from collections.abc import Iterable
from pathlib import Path
from shutil import which
from typing import TypedDict

from .bash_guard_constants import (
    DANGEROUS_SHELL_PATTERNS,
    FORBIDDEN_PATTERNS,
    PYTHON_FILE_PATTERNS,
)
|
|
|
|
|
|
class JsonObject(TypedDict, total=False):
    """Shape of the hook response that ``main`` serialises to stdout.

    Every key is optional (``total=False``); builders only set the keys
    that carry a non-empty value.
    """

    # Keys that mirror the entries placed inside ``hookSpecificOutput``.
    hookEventName: str
    permissionDecision: str
    permissionDecisionReason: str
    # Top-level decision / explanation keys.
    decision: str
    reason: str
    systemMessage: str
    # Event-specific nested payload (event name, permission decision, reason).
    hookSpecificOutput: dict[str, object]
|
|
|
|
|
|
def _contains_forbidden_pattern(text: str) -> tuple[bool, str | None]:
|
|
"""Check if text contains any forbidden patterns.
|
|
|
|
Args:
|
|
text: The text to check for forbidden patterns.
|
|
|
|
Returns:
|
|
Tuple of (has_violation, matched_pattern_description)
|
|
"""
|
|
for pattern in FORBIDDEN_PATTERNS:
|
|
if re.search(pattern, text, re.IGNORECASE):
|
|
if "Any" in pattern:
|
|
return True, "typing.Any usage"
|
|
if "type.*ignore" in pattern:
|
|
return True, "type suppression comment"
|
|
return False, None
|
|
|
|
|
|
def _is_dangerous_shell_command(command: str) -> tuple[bool, str | None]:
|
|
"""Check if shell command uses dangerous patterns.
|
|
|
|
Args:
|
|
command: The shell command to analyze.
|
|
|
|
Returns:
|
|
Tuple of (is_dangerous, reason)
|
|
"""
|
|
# Check if command targets Python files
|
|
targets_python = any(
|
|
re.search(pattern, command) for pattern in PYTHON_FILE_PATTERNS
|
|
)
|
|
|
|
if not targets_python:
|
|
return False, None
|
|
|
|
# Check for dangerous shell patterns
|
|
for pattern in DANGEROUS_SHELL_PATTERNS:
|
|
if re.search(pattern, command):
|
|
tool_match = re.search(
|
|
r"\b(sed|awk|perl|ed|echo|printf|cat|tee|find|xargs|python|vim|nano|emacs)\b",
|
|
pattern,
|
|
)
|
|
tool_name = tool_match.group(1) if tool_match else "shell utility"
|
|
return True, f"Use of {tool_name} to modify Python files"
|
|
|
|
return False, None
|
|
|
|
|
|
# ``\xHH`` hex escapes (printf, ``echo -e``, bash ``$'...'`` quoting).
_HEX_ESCAPE_RE = re.compile(r"\\x([0-9A-Fa-f]{1,2})")
# ``\uHHHH`` escapes (printf, bash ``$'...'`` quoting).
_UNICODE_ESCAPE_RE = re.compile(r"\\u([0-9A-Fa-f]{4})")
# ``\NNN`` / ``\0NNN`` octal escapes (printf, ``echo -e``).
_OCTAL_ESCAPE_RE = re.compile(r"\\0?([0-7]{1,3})")
# Shell quote characters; removing them undoes word splicing like 'A'"ny".
_SHELL_QUOTE_RE = re.compile(r"['\"]")


def _decode_shell_escapes(command: str) -> str:
    """Undo common shell escape sequences so hidden patterns become visible.

    The result is only used for detection; it is not a faithful shell parse.
    """
    decoded = command.replace("\\n", "\n").replace("\\t", "\t")
    decoded = re.sub(r"\\\s", " ", decoded)
    decoded = _HEX_ESCAPE_RE.sub(lambda m: chr(int(m.group(1), 16)), decoded)
    decoded = _UNICODE_ESCAPE_RE.sub(lambda m: chr(int(m.group(1), 16)), decoded)
    return _OCTAL_ESCAPE_RE.sub(lambda m: chr(int(m.group(1), 8)), decoded)


def _command_contains_forbidden_injection(command: str) -> tuple[bool, str | None]:
    """Check if command attempts to inject forbidden patterns.

    The raw command is checked first. Then an escape-decoded variant
    (``\\n``, ``\\t``, backslash-whitespace, ``\\xHH``, ``\\uHHHH``, octal) is
    checked, followed by that variant with shell quotes removed. Findings
    from either derived variant are suffixed with ``(escaped)``.

    Args:
        command: The shell command to analyze.

    Returns:
        Tuple of (has_injection, violation_description). The description is
        ``None`` when nothing was found.
    """
    has_violation, violation_type = _contains_forbidden_pattern(command)
    if has_violation:
        return True, violation_type

    decoded_cmd = _decode_shell_escapes(command)
    for variant in (decoded_cmd, _SHELL_QUOTE_RE.sub("", decoded_cmd)):
        has_violation, violation_type = _contains_forbidden_pattern(variant)
        if has_violation:
            return True, f"{violation_type} (escaped)"

    return False, None
|
|
|
|
|
|
def _analyze_bash_command(command: str) -> tuple[bool, list[str]]:
    """Run all safety checks on a Bash command and collect their findings.

    The injection check runs first, then the dangerous-shell-usage check;
    findings are listed in that order.

    Args:
        command: The bash command to analyze.

    Returns:
        Tuple of (should_block, list_of_violations). ``should_block`` is
        true exactly when at least one finding was recorded.
    """
    findings: list[str] = []

    injected, injection_kind = _command_contains_forbidden_injection(command)
    if injected:
        findings.append(f"⛔ Shell command attempts to inject {injection_kind}")

    dangerous, danger_reason = _is_dangerous_shell_command(command)
    if dangerous:
        findings.append(f"⛔ {danger_reason} is forbidden - use Edit/Write tools instead")

    return bool(findings), findings
|
|
|
|
|
|
def _create_hook_response(
    event_name: str,
    permission: str = "",
    reason: str = "",
    system_message: str = "",
    *,
    decision: str | None = None,
) -> JsonObject:
    """Build a hook response in the shape this module emits.

    ``hookSpecificOutput`` always carries ``hookEventName``. The permission
    and reason are mirrored into it when non-empty. Top-level keys are added
    only for non-empty arguments, in the order ``permissionDecision``,
    ``decision``, ``reason``, ``systemMessage``.

    Args:
        event_name: Name of the hook event (PreToolUse, PostToolUse, Stop).
        permission: Permission decision (allow, deny, ask).
        reason: Reason for the decision.
        system_message: System message to display.
        decision: Decision for PostToolUse/Stop hooks (approve, block).

    Returns:
        JSON response object for the hook.
    """
    specific: dict[str, object] = {
        "hookEventName": event_name,
        **({"permissionDecision": permission} if permission else {}),
        **({"permissionDecisionReason": reason} if reason else {}),
    }

    payload: JsonObject = {"hookSpecificOutput": specific}
    if permission:
        payload["permissionDecision"] = permission
    if decision:
        payload["decision"] = decision
    if reason:
        payload["reason"] = reason
    if system_message:
        payload["systemMessage"] = system_message
    return payload
|
|
|
|
|
|
def pretooluse_bash_hook(hook_data: dict[str, object]) -> JsonObject:
    """Decide whether a Bash tool call may run (PreToolUse).

    The response is ``allow`` in these cases:
    - the tool is not ``Bash``;
    - ``tool_input`` is not a dict;
    - the command is empty;
    - the command has no violations.

    Otherwise the response is ``deny``. Its reason and system message both
    list the violations and the command, truncated to 200 characters.

    Args:
        hook_data: Hook input data containing tool_name and tool_input.

    Returns:
        Hook response with permission decision.
    """
    if str(hook_data.get("tool_name", "")) != "Bash":
        return _create_hook_response("PreToolUse", "allow")

    raw_input = hook_data.get("tool_input", {})
    # A malformed tool_input behaves exactly like an empty command.
    command = str(raw_input.get("command", "")) if isinstance(raw_input, dict) else ""
    if not command:
        return _create_hook_response("PreToolUse", "allow")

    blocked, violations = _analyze_bash_command(command)
    if not blocked:
        return _create_hook_response("PreToolUse", "allow")

    listing = "\n".join(f" {item}" for item in violations)
    preview = command[:200] + ("..." if len(command) > 200 else "")
    message = (
        "🚫 Shell Command Blocked\n\n"
        f"Violations:\n{listing}\n\n"
        f"Command: {preview}\n\n"
        "Use Edit/Write tools to modify Python files with proper type safety."
    )
    return _create_hook_response("PreToolUse", "deny", message, message)
|
|
|
|
|
|
# Candidate Python file paths inside a shell command. Quotes, redirections
# and other shell metacharacters are excluded from the path characters, so
# tokens like "src/mod.py", 'a.py', >out.py or --file=x.py yield the bare
# path instead of a string that never exists on disk.
_PYTHON_PATH_IN_COMMAND_RE = re.compile(r"""([^\s'"`;|&<>()=]+\.pyi?)\b""")


def _python_files_in_command(command: str) -> list[str]:
    """Return the existing Python files named in ``command``, deduplicated in order."""
    candidates = dict.fromkeys(
        match.group(1) for match in _PYTHON_PATH_IN_COMMAND_RE.finditer(command)
    )
    return [path for path in candidates if Path(path).is_file()]


def posttooluse_bash_hook(hook_data: dict[str, object]) -> JsonObject:
    """Handle PostToolUse hook for Bash commands.

    Re-reads every existing ``.py``/``.pyi`` file mentioned in the executed
    command and blocks when any of them now contains a forbidden pattern.
    Unreadable files (I/O or decoding errors) are skipped on purpose; the
    scan is best effort.

    Args:
        hook_data: Hook output data containing tool_name and tool_input.

    Returns:
        Hook response; ``decision="block"`` with an explanatory reason and
        system message when a violation is found.
    """
    if str(hook_data.get("tool_name", "")) != "Bash":
        return _create_hook_response("PostToolUse")

    tool_input_raw = hook_data.get("tool_input", {})
    if not isinstance(tool_input_raw, dict):
        return _create_hook_response("PostToolUse")

    command = str(tool_input_raw.get("command", ""))
    python_files = _python_files_in_command(command)
    if not python_files:
        return _create_hook_response("PostToolUse")

    violations: list[str] = []
    for file_path in python_files:
        try:
            with open(file_path, encoding="utf-8") as file_handle:
                content = file_handle.read()
        except (OSError, UnicodeDecodeError):
            # If we can't read the file, skip it.
            continue

        has_violation, violation_type = _contains_forbidden_pattern(content)
        if has_violation:
            violations.append(
                f"⛔ File '{Path(file_path).name}' contains {violation_type}",
            )

    if not violations:
        return _create_hook_response("PostToolUse")

    violation_text = "\n".join(f" {v}" for v in violations)
    message = (
        "🚫 Post-Execution Violation Detected\n\n"
        f"Violations:\n{violation_text}\n\n"
        "Shell command introduced forbidden patterns. "
        "Please revert changes and use proper typing."
    )
    return _create_hook_response(
        "PostToolUse",
        "",
        message,
        message,
        decision="block",
    )
|
|
|
|
|
|
def stop_hook(_hook_data: dict[str, object]) -> JsonObject:
    """Final validation before completion (Stop hook).

    The hook lists staged files with ``git diff --name-only --cached`` and
    scans the ``.py``/``.pyi`` files that exist on disk for forbidden
    patterns. It blocks when any are found. A missing git executable, a
    non-zero git exit status or no staged Python files all approve. OS,
    subprocess and timeout errors also approve, but with a warning message.

    Args:
        _hook_data: Stop hook data (unused).

    Returns:
        Hook response with decision.
    """
    try:
        git_executable = which("git")
        if git_executable is None:
            return _create_hook_response("Stop", decision="approve")

        # Safe: invokes git with fixed arguments, no user input interpolation.
        git_result = subprocess.run(  # noqa: S603
            [git_executable, "diff", "--name-only", "--cached"],
            capture_output=True,
            text=True,
            check=False,
            timeout=10,
        )
        if git_result.returncode != 0:
            # No git repo or git failed: nothing to validate.
            return _create_hook_response("Stop", decision="approve")

        stripped_lines = (line.strip() for line in git_result.stdout.split("\n"))
        staged_python = [
            name for name in stripped_lines if name and name.endswith((".py", ".pyi"))
        ]
        if not staged_python:
            return _create_hook_response("Stop", decision="approve")

        problems: list[str] = []
        for staged_path in staged_python:
            if not Path(staged_path).exists():
                continue
            try:
                with open(staged_path, encoding="utf-8") as source_file:
                    source_text = source_file.read()
                found, kind = _contains_forbidden_pattern(source_text)
                if found:
                    problems.append(f"⛔ {staged_path}: {kind}")
            except (OSError, UnicodeDecodeError):
                continue

        if not problems:
            return _create_hook_response("Stop", decision="approve")

        listing = "\n".join(f" {entry}" for entry in problems)
        summary = (
            "🚫 Final Validation Failed\n\n"
            f"Violations:\n{listing}\n\n"
            "Please remove forbidden patterns before completing."
        )
        return _create_hook_response("Stop", "", summary, summary, decision="block")

    except (OSError, subprocess.SubprocessError, TimeoutError) as exc:
        # If validation itself fails, allow completion but surface a warning.
        warning = f"Warning: Final validation error: {exc}"
        return _create_hook_response("Stop", "", warning, warning, decision="approve")
|
|
|
|
|
|
def main() -> None:
    """Main hook entry point.

    Reads one hook payload (a JSON object) from stdin and dispatches it:
    - payloads carrying ``tool_response``/``tool_output``, or whose event is
      ``PostToolUse``, go to the PostToolUse handler;
    - ``Stop`` events go to the Stop handler;
    - anything else goes to the PreToolUse handler.

    The event name is read from ``hook_event_name`` (the key Claude Code
    sends), falling back to ``hookEventName``. The JSON response is written
    to stdout.

    A ``deny``/``ask`` permission or a top-level ``decision == "block"``
    writes the reason to stderr and exits with code 2. Invalid or non-object
    input yields an ``allow`` fallback. Unexpected I/O or value errors exit
    with code 1.
    """
    try:
        # Read hook input from stdin.
        try:
            parsed: object = json.load(sys.stdin)
        except json.JSONDecodeError:
            parsed = None

        if not isinstance(parsed, dict):
            # Unparseable or non-object input: fail open with an "allow".
            fallback_response: JsonObject = {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "allow",
                },
            }
            sys.stdout.write(json.dumps(fallback_response))
            return

        hook_data: dict[str, object] = parsed
        event_name = hook_data.get("hook_event_name") or hook_data.get("hookEventName")

        response: JsonObject
        if (
            "tool_response" in hook_data
            or "tool_output" in hook_data
            or event_name == "PostToolUse"
        ):
            response = posttooluse_bash_hook(hook_data)
        elif event_name == "Stop":
            response = stop_hook(hook_data)
        else:
            response = pretooluse_bash_hook(hook_data)

        sys.stdout.write(json.dumps(response))

        # Exit code 2 (blocking) for deny/ask permission decisions.
        hook_output = response.get("hookSpecificOutput")
        if isinstance(hook_output, dict):
            permission_decision = hook_output.get("permissionDecision")
            if permission_decision == "deny":
                reason = str(
                    hook_output.get("permissionDecisionReason", "Permission denied"),
                )
                sys.stderr.write(reason)
                sys.exit(2)
            if permission_decision == "ask":
                reason = str(
                    hook_output.get("permissionDecisionReason", "Permission request"),
                )
                sys.stderr.write(reason)
                sys.exit(2)

        # Top-level block decision (PostToolUse / Stop). Checked even when
        # hookSpecificOutput is missing or empty.
        if response.get("decision") == "block":
            reason = str(response.get("reason", "Validation failed"))
            sys.stderr.write(reason)
            sys.exit(2)

    except (OSError, ValueError, subprocess.SubprocessError, TimeoutError) as exc:
        # Unexpected error - use exit code 1 (non-blocking error).
        sys.stderr.write(f"Hook error: {exc}")
        sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: main() reads the hook payload from stdin and writes
    # the JSON response to stdout.
    main()
|