Files
claude-scripts/hooks/bash_command_guard.py
Travis Vasceannie bfb7773096 feat: enhance bash command guard with file locking and hook chaining
- Introduced file-based locking in `bash_command_guard.py` to prevent concurrent execution issues.
- Added configuration for lock timeout and polling interval in `bash_guard_constants.py`.
- Implemented a new `hook_chain.py` to unify hook execution, allowing for sequential processing of guards.
- Updated `claude-code-settings.json` to support the new hook chaining mechanism.
- Refactored subprocess lock handling to improve reliability and prevent deadlocks.

This update improves the robustness of the hook system by serializing the guard's own subprocess calls (such as the `git diff` query made by the Stop hook) behind an inter-process file lock, reducing the risk of concurrency-related errors.
2025-10-26 02:14:12 +00:00

576 lines
17 KiB
Python

"""Shell command guard for Claude Code PreToolUse/PostToolUse hooks.
Prevents circumvention of type safety rules via shell commands that could inject
'Any' types or type ignore comments into Python files.
"""
import fcntl
import json
import os
import re
import subprocess
import sys
import tempfile
import time
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from pathlib import Path
from shutil import which
from typing import TypedDict

# Handle both relative imports (when run as module) and direct imports (when run as script)
try:
    from .bash_guard_constants import (
        DANGEROUS_SHELL_PATTERNS,
        FORBIDDEN_PATTERNS,
        LOCK_POLL_INTERVAL_SECONDS,
        LOCK_TIMEOUT_SECONDS,
        PYTHON_FILE_PATTERNS,
    )
except ImportError:
    import bash_guard_constants

    DANGEROUS_SHELL_PATTERNS = bash_guard_constants.DANGEROUS_SHELL_PATTERNS
    FORBIDDEN_PATTERNS = bash_guard_constants.FORBIDDEN_PATTERNS
    PYTHON_FILE_PATTERNS = bash_guard_constants.PYTHON_FILE_PATTERNS
    LOCK_TIMEOUT_SECONDS = bash_guard_constants.LOCK_TIMEOUT_SECONDS
    LOCK_POLL_INTERVAL_SECONDS = bash_guard_constants.LOCK_POLL_INTERVAL_SECONDS
class JsonObject(TypedDict, total=False):
    """Declared shape of the JSON object a hook writes to stdout.

    ``total=False`` makes every key optional: each builder sets only the
    keys relevant to the event it is answering.
    """

    # Nested per-event payload (a plain dict of string keys to values).
    hookSpecificOutput: dict[str, object]
    # Free-text message surfaced alongside the decision.
    systemMessage: str
    # Top-level decision fields (e.g. approve / block) and their explanation.
    decision: str
    reason: str
    # Permission-style fields; the builders in this module also place
    # hookEventName / permissionDecisionReason inside hookSpecificOutput.
    hookEventName: str
    permissionDecision: str
    permissionDecisionReason: str
# File-based lock for inter-process synchronization
def _get_lock_file() -> Path:
    """Return the path of the lock file that serializes hook subprocesses.

    The ``.claude_hooks`` directory under the system temp directory is
    created on first use with requested mode 0o700; an existing directory
    is reused as-is.
    """
    hooks_dir = Path(tempfile.gettempdir()).joinpath(".claude_hooks")
    hooks_dir.mkdir(mode=0o700, exist_ok=True)
    return hooks_dir.joinpath("subprocess.lock")
@contextmanager
def _subprocess_lock(timeout: float = LOCK_TIMEOUT_SECONDS) -> Iterator[bool]:
    """Context manager for file-based subprocess locking.

    Repeatedly attempts a non-blocking exclusive ``flock`` on the shared lock
    file (polling every LOCK_POLL_INTERVAL_SECONDS) until it succeeds or the
    timeout elapses. The lock is released when the ``with`` block exits.

    Args:
        timeout: Maximum time in seconds to wait for the lock. Non-positive
            values attempt a single non-blocking acquisition.

    Yields:
        True if the lock was acquired, False if it could not be obtained
        (timed out, or ``flock`` failed for a reason other than contention).
    """
    lock_file = _get_lock_file()
    deadline = time.monotonic() + timeout if timeout and timeout > 0 else None
    acquired = False
    # Append mode creates the lock file if missing without truncating it.
    with open(lock_file, "a") as handle:
        try:
            while True:
                try:
                    fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                except (BlockingIOError, PermissionError):
                    # Lock held elsewhere (EAGAIN/EWOULDBLOCK, or EACCES on
                    # some platforms): keep polling until the deadline.
                    if deadline is None:
                        break
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                    time.sleep(min(LOCK_POLL_INTERVAL_SECONDS, remaining))
                except OSError:
                    # Not contention (e.g. locking unsupported here); waiting
                    # cannot help, so report failure immediately.
                    break
                else:
                    acquired = True
                    break
            yield acquired
        finally:
            if acquired:
                try:
                    fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
                except OSError:
                    # Closing the handle below drops the lock anyway.
                    pass
def _contains_forbidden_pattern(text: str) -> tuple[bool, str | None]:
"""Check if text contains any forbidden patterns.
Args:
text: The text to check for forbidden patterns.
Returns:
Tuple of (has_violation, matched_pattern_description)
"""
for pattern in FORBIDDEN_PATTERNS:
if re.search(pattern, text, re.IGNORECASE):
if "Any" in pattern:
return True, "typing.Any usage"
if "type.*ignore" in pattern:
return True, "type suppression comment"
return False, None
def _is_dangerous_shell_command(command: str) -> tuple[bool, str | None]:
"""Check if shell command uses dangerous patterns.
Args:
command: The shell command to analyze.
Returns:
Tuple of (is_dangerous, reason)
"""
# Check if command targets Python files
targets_python = any(
re.search(pattern, command) for pattern in PYTHON_FILE_PATTERNS
)
if not targets_python:
return False, None
# Allow operations on temporary files (they're not project files)
temp_dirs = [r"/tmp/", r"/var/tmp/", r"\.tmp/", r"tempfile"]
if any(re.search(temp_pattern, command) for temp_pattern in temp_dirs):
return False, None
# Check for dangerous shell patterns
for pattern in DANGEROUS_SHELL_PATTERNS:
if re.search(pattern, command):
tool_match = re.search(
r"\b(sed|awk|perl|ed|echo|printf|cat|tee|find|xargs|python|vim|nano|emacs)\b",
pattern,
)
tool_name = tool_match[1] if tool_match else "shell utility"
return True, f"Use of {tool_name} to modify Python files"
return False, None
def _decode_shell_escapes(command: str) -> str:
    """Return *command* with common shell escaping and quoting tricks undone.

    Used only for detection; the decoded text is never executed.
    """
    # Handle common escape sequences
    decoded = command.replace("\\n", "\n").replace("\\t", "\t")
    decoded = re.sub(r"\\\s", " ", decoded)
    # printf / echo -e / $'...' byte escapes: "\x41" and "\101" both spell "A".
    decoded = re.sub(
        r"\\x([0-9A-Fa-f]{1,2})",
        lambda match: chr(int(match.group(1), 16)),
        decoded,
    )
    decoded = re.sub(
        r"\\0?([0-7]{1,3})",
        lambda match: chr(int(match.group(1), 8)),
        decoded,
    )
    # The shell concatenates adjacent quoted segments: 'A'"ny" -> Any.
    return re.sub(r"['\"]", "", decoded)


def _command_contains_forbidden_injection(command: str) -> tuple[bool, str | None]:
    """Check if command attempts to inject forbidden patterns.

    The raw command is checked first; if it is clean, a copy with escape
    sequences decoded and quotes removed is checked, so hidden spellings such
    as ``printf '\\x41ny'`` or ``A''ny`` are still caught.

    Args:
        command: The shell command to analyze.

    Returns:
        Tuple of (has_injection, violation_description). Matches found only
        after decoding carry an " (escaped)" suffix.
    """
    # Check if the command itself contains forbidden patterns
    has_violation, violation_type = _contains_forbidden_pattern(command)
    if has_violation:
        return True, violation_type
    # Check for encoded or escaped patterns
    has_violation, violation_type = _contains_forbidden_pattern(
        _decode_shell_escapes(command),
    )
    if has_violation:
        return True, f"{violation_type} (escaped)"
    return False, None
def _analyze_bash_command(command: str) -> tuple[bool, list[str]]:
    """Run every safety check against a Bash command.

    Two checks run, in order: forbidden-pattern injection, then dangerous
    shell usage against Python files.

    Args:
        command: Raw command string taken from the tool input.

    Returns:
        ``(should_block, violations)``; ``should_block`` is True exactly when
        at least one human-readable violation message was collected.
    """
    injected, injection_kind = _command_contains_forbidden_injection(command)
    dangerous, danger_kind = _is_dangerous_shell_command(command)
    checks = (
        (injected, f"⛔ Shell command attempts to inject {injection_kind}"),
        (dangerous, f"{danger_kind} is forbidden - use Edit/Write tools instead"),
    )
    found = [message for flagged, message in checks if flagged]
    return bool(found), found
def _create_hook_response(
    event_name: str,
    permission: str = "",
    reason: str = "",
    system_message: str = "",
    *,
    decision: str | None = None,
) -> JsonObject:
    """Build the JSON object a hook prints on stdout.

    Empty/None arguments are left out of the result. ``hookSpecificOutput``
    always contains ``hookEventName`` plus, when given,
    ``permissionDecision`` and ``permissionDecisionReason``. The top level
    holds ``hookSpecificOutput`` followed by whichever of
    ``permissionDecision``, ``decision``, ``reason`` and ``systemMessage``
    were supplied.

    Args:
        event_name: Hook event name (PreToolUse, PostToolUse, Stop).
        permission: Permission decision (allow, deny, ask).
        reason: Explanation; used for both ``permissionDecisionReason`` and
            the top-level ``reason``.
        system_message: Value for the top-level ``systemMessage``.
        decision: Top-level decision for PostToolUse/Stop (approve, block).

    Returns:
        The assembled response object.
    """
    specific: dict[str, object] = {"hookEventName": event_name}
    result: JsonObject = {"hookSpecificOutput": specific}
    # ``specific`` is shared by reference with ``result``, so filling it after
    # insertion is visible in the output; key order follows insertion order.
    if permission:
        specific["permissionDecision"] = permission
        result["permissionDecision"] = permission
    if reason:
        specific["permissionDecisionReason"] = reason
    if decision:
        result["decision"] = decision
    if reason:
        result["reason"] = reason
    if system_message:
        result["systemMessage"] = system_message
    return result
def pretooluse_bash_hook(hook_data: dict[str, object]) -> JsonObject:
    """Decide whether a Bash tool call may run (PreToolUse).

    Non-Bash tools, a non-dict ``tool_input`` and empty commands are allowed
    outright. Otherwise the command is analysed and, if any violation is
    found, denied with a message listing the violations and the first 200
    characters of the command.

    Args:
        hook_data: Hook input; reads ``tool_name`` and ``tool_input.command``.

    Returns:
        PreToolUse response with permission ``allow`` or ``deny``.
    """
    raw_input = hook_data.get("tool_input", {})
    command = ""
    if str(hook_data.get("tool_name", "")) == "Bash" and isinstance(raw_input, dict):
        command = str(raw_input.get("command", ""))
    if not command:
        return _create_hook_response("PreToolUse", "allow")

    should_block, violations = _analyze_bash_command(command)
    if not should_block:
        return _create_hook_response("PreToolUse", "allow")

    listed = "\n".join(f" {item}" for item in violations)
    ellipsis = "..." if len(command) > 200 else ""
    message = "".join(
        (
            "🚫 Shell Command Blocked\n\n",
            f"Violations:\n{listed}\n\n",
            f"Command: {command[:200]}{ellipsis}\n\n",
            "Use Edit/Write tools to modify Python files with proper type safety.",
        ),
    )
    return _create_hook_response("PreToolUse", "deny", message, message)
def posttooluse_bash_hook(hook_data: dict[str, object]) -> JsonObject:
    """Handle PostToolUse hook for Bash commands.

    Every ``.py``/``.pyi`` path named in the executed command that exists on
    disk (relative paths resolve against the process working directory) is
    re-read and scanned for forbidden patterns.

    Args:
        hook_data: Hook output data; reads ``tool_name`` and
            ``tool_input.command``.

    Returns:
        A PostToolUse response with ``decision="block"`` and a message listing
        the offending files if any violation is found, otherwise an empty
        PostToolUse response.
    """
    tool_name = str(hook_data.get("tool_name", ""))
    # Only analyze Bash commands
    if tool_name != "Bash":
        return _create_hook_response("PostToolUse")
    tool_input_raw = hook_data.get("tool_input", {})
    if not isinstance(tool_input_raw, dict):
        return _create_hook_response("PostToolUse")
    tool_input: dict[str, object] = dict(tool_input_raw)
    command = str(tool_input.get("command", ""))

    # Path tokens stop at whitespace, quotes, backticks, "=" and shell
    # operators, so quoted or redirected targets such as "app.py", 'app.py',
    # >app.py or --out=app.py resolve to the real file name instead of keeping
    # the leading quote/operator. dict.fromkeys de-duplicates in first-seen order.
    candidates = dict.fromkeys(re.findall(r"[^\s'\"`;|&<>()=]+\.pyi?\b", command))
    python_files = [path for path in candidates if Path(path).exists()]
    if not python_files:
        return _create_hook_response("PostToolUse")

    # Scan modified files for violations
    violations: list[str] = []
    for file_path in python_files:
        try:
            with open(file_path, encoding="utf-8") as file_handle:
                content = file_handle.read()
        except (OSError, UnicodeDecodeError):
            # If we can't read the file, skip it
            continue
        has_violation, violation_type = _contains_forbidden_pattern(content)
        if has_violation:
            violations.append(
                f"⛔ File '{Path(file_path).name}' contains {violation_type}",
            )

    if not violations:
        return _create_hook_response("PostToolUse")
    violation_text = "\n".join(f" {v}" for v in violations)
    message = (
        f"🚫 Post-Execution Violation Detected\n\n"
        f"Violations:\n{violation_text}\n\n"
        f"Shell command introduced forbidden patterns. "
        f"Please revert changes and use proper typing."
    )
    return _create_hook_response(
        "PostToolUse",
        "",
        message,
        message,
        decision="block",
    )
def _get_staged_python_files() -> list[str]:
    """Get list of staged Python files from git.

    Runs ``git diff --name-only --cached -z`` while holding the shared
    subprocess lock. ``-z`` makes git emit paths verbatim and NUL-terminated
    instead of C-quoting "unusual" (e.g. non-ASCII) names, which would
    otherwise hide their ``.py`` suffix behind a closing quote. Raw bytes are
    decoded with ``os.fsdecode`` so any on-disk name maps back to an openable
    path.

    Returns:
        Staged paths ending in ``.py`` or ``.pyi``, as printed by git. Empty
        if git is not on PATH, the lock was not acquired within
        LOCK_TIMEOUT_SECONDS, or git fails or times out.
    """
    git_path = which("git")
    if git_path is None:
        return []
    try:
        # Acquire file-based lock to prevent subprocess concurrency issues
        with _subprocess_lock(timeout=LOCK_TIMEOUT_SECONDS) as acquired:
            if not acquired:
                return []
            # Safe: invokes git with fixed arguments, no user input interpolation.
            result = subprocess.run(  # noqa: S603
                [git_path, "diff", "--name-only", "--cached", "-z"],
                capture_output=True,
                check=False,
                timeout=10,
            )
    except (OSError, subprocess.SubprocessError, TimeoutError):
        return []
    if result.returncode != 0:
        return []
    # The trailing NUL yields an empty entry, which the suffix test drops.
    return [
        path
        for path in map(os.fsdecode, result.stdout.split(b"\0"))
        if path.endswith((".py", ".pyi"))
    ]
def _check_files_for_violations(file_paths: list[str]) -> list[str]:
    """Collect forbidden-pattern findings for the given files.

    Paths that do not exist are ignored; files that cannot be opened or
    decoded as UTF-8 are skipped without error.

    Args:
        file_paths: Paths to scan, in the order findings should be reported.

    Returns:
        One ``"<path>: <description>"`` entry per offending file.
    """
    findings: list[str] = []
    present = (candidate for candidate in file_paths if Path(candidate).exists())
    for candidate in present:
        try:
            text = Path(candidate).read_text(encoding="utf-8")
            flagged, description = _contains_forbidden_pattern(text)
        except (OSError, UnicodeDecodeError):
            continue
        if flagged:
            findings.append(f"{candidate}: {description}")
    return findings
def stop_hook(_hook_data: dict[str, object]) -> JsonObject:
    """Final validation gate run when the agent stops.

    Staged Python files from git are scanned for forbidden patterns. Any
    finding blocks the stop. No staged files, or no findings, approves it.
    If an OS, subprocess or timeout error escapes the scan, the stop is
    approved with a warning.

    Args:
        _hook_data: Stop hook payload; not inspected.

    Returns:
        Stop response whose ``decision`` is ``approve`` or ``block``.
    """
    try:
        staged_files = _get_staged_python_files()
        findings = _check_files_for_violations(staged_files) if staged_files else []
    except (OSError, subprocess.SubprocessError, TimeoutError) as exc:
        # Validation itself failed: allow, but surface a warning.
        warning = f"Warning: Final validation error: {exc}"
        return _create_hook_response("Stop", "", warning, warning, decision="approve")

    if not findings:
        return _create_hook_response("Stop", decision="approve")

    bullet_list = "\n".join(f" {finding}" for finding in findings)
    message = (
        "🚫 Final Validation Failed\n\n"
        f"Violations:\n{bullet_list}\n\n"
        "Please remove forbidden patterns before completing."
    )
    return _create_hook_response("Stop", "", message, message, decision="block")
def _exit_blocking(message: str) -> None:
    """Write *message* to stderr and exit with status 2 (blocking error)."""
    sys.stderr.write(message)
    sys.stderr.flush()
    sys.exit(2)


def _handle_hook_exit_code(response: JsonObject) -> None:
    """Handle exit codes based on hook response.

    Exits with status 2, after writing the reason to stderr, when
    ``hookSpecificOutput.permissionDecision`` is ``deny`` or ``ask``, or when
    the top-level ``decision`` is ``block``. Otherwise returns normally.

    Args:
        response: Hook response object.
    """
    hook_output_raw = response.get("hookSpecificOutput", {})
    hook_output: dict[str, object] = (
        hook_output_raw if isinstance(hook_output_raw, dict) else {}
    )
    permission_decision = hook_output.get("permissionDecision")
    if permission_decision == "deny":
        _exit_blocking(
            str(hook_output.get("permissionDecisionReason", "Permission denied")),
        )
    if permission_decision == "ask":
        # Exit code 2 for ask decisions
        _exit_blocking(
            str(hook_output.get("permissionDecisionReason", "Permission request")),
        )
    # Checked regardless of hookSpecificOutput: a missing or empty nested
    # object must not disable a top-level block decision.
    if response.get("decision") == "block":
        _exit_blocking(str(response.get("reason", "Validation failed")))
def _detect_hook_type(hook_data: dict[str, object]) -> JsonObject:
    """Detect hook type and route to appropriate handler.

    The event name is read from ``hook_event_name`` (the snake_case key in
    Claude Code's documented hook input) and, for backward compatibility,
    ``hookEventName``. Input carrying ``tool_response``/``tool_output`` is
    treated as PostToolUse even without an event name; anything
    unrecognised falls back to the PreToolUse handler.

    Args:
        hook_data: Hook input data.

    Returns:
        Hook response object.
    """
    event_name = hook_data.get("hook_event_name") or hook_data.get("hookEventName")
    if (
        event_name == "PostToolUse"
        or "tool_response" in hook_data
        or "tool_output" in hook_data
    ):
        return posttooluse_bash_hook(hook_data)
    if event_name == "Stop":
        return stop_hook(hook_data)
    return pretooluse_bash_hook(hook_data)
def main() -> None:
    """Main hook entry point.

    Reads a JSON payload from stdin, routes it via ``_detect_hook_type``,
    writes the JSON response to stdout, then lets ``_handle_hook_exit_code``
    turn blocking decisions into exit status 2. Invalid JSON, or JSON that is
    not an object, gets a PreToolUse ``allow`` response. Listed runtime errors
    exit with status 1 (non-blocking).
    """
    try:
        # Read hook input from stdin
        try:
            payload: object = json.load(sys.stdin)
        except json.JSONDecodeError:
            payload = None
        if isinstance(payload, dict):
            hook_data: dict[str, object] = payload
            response: JsonObject = _detect_hook_type(hook_data)
        else:
            # Malformed or non-object input (e.g. a JSON list) cannot be
            # routed; allow rather than crash on a missing .get().
            response = {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "allow",
                },
            }
        # Write response to stdout with explicit flush
        sys.stdout.write(json.dumps(response))
        sys.stdout.write("\n")
        sys.stdout.flush()
        # No-op for allow responses; deny/ask/block exit with status 2.
        _handle_hook_exit_code(response)
    except (OSError, ValueError, subprocess.SubprocessError, TimeoutError) as exc:
        # Unexpected error - use exit code 1 (non-blocking error)
        sys.stderr.write(f"Hook error: {exc}")
        sys.stderr.flush()
        sys.exit(1)
if __name__ == "__main__":
    # Script entry point: handle one hook invocation read from stdin.
    main()