# Source code for yoker.tools.web_guardrail

"""Web guardrail for security enforcement.

Provides security validation for web tools including:
- SSRF (Server-Side Request Forgery) protection
- Domain allowlist/blocklist (heuristic for free-text queries)
- Query sanitization (invisible-Unicode stripping, sensitive-pattern blocking)
- Rate limiting (per minute, per hour, concurrent)
- URL validation for web fetch (scheme/HTTPS, SSRF and domain checks)
"""

import ipaddress
import logging
import re
import socket
import time
from collections import defaultdict
from dataclasses import dataclass, field
from threading import Lock
from typing import TYPE_CHECKING, Any
from urllib.parse import unquote, urlparse

from .base import ValidationResult
from .guardrails import Guardrail

if TYPE_CHECKING:
  pass

logger = logging.getLogger(__name__)

# IP ranges for SSRF protection. An address inside any of these networks is
# rejected by WebGuardrail.
PRIVATE_CIDRS = [
  ipaddress.ip_network("10.0.0.0/8"),  # RFC 1918
  ipaddress.ip_network("172.16.0.0/12"),  # RFC 1918
  ipaddress.ip_network("192.168.0.0/16"),  # RFC 1918
  ipaddress.ip_network("100.64.0.0/10"),  # RFC 6598 shared space (Alibaba Cloud metadata lives here)
  ipaddress.ip_network("127.0.0.0/8"),  # Loopback
  ipaddress.ip_network("169.254.0.0/16"),  # Link-local (cloud metadata)
  ipaddress.ip_network("::1/128"),  # IPv6 loopback
  ipaddress.ip_network("fe80::/10"),  # IPv6 link-local
  ipaddress.ip_network("fc00::/7"),  # IPv6 ULA
]

# Cloud metadata endpoints, matched against str(ip_address).
METADATA_IPS = [
  "169.254.169.254",  # AWS/GCP/Azure metadata
  "100.100.100.200",  # Alibaba Cloud metadata
  "fd00:ec2::254",  # AWS IMDS over IPv6
]

# Sensitive patterns to block in queries
# Note: These block patterns that indicate actual secret exposure attempts,
# not documentation/tutorial searches. A web search for ".env tutorial" is fine.
SENSITIVE_PATTERNS = [
  r"password\s*=\s*['\"]?\S+",  # password='actual_password'
  r"api[_-]?key\s*=\s*['\"]?\S+",  # api_key='actual_key'
  r"apikey\s*=\s*['\"]?\S+",  # apikey='actual_key'
  r"secret\s*=\s*['\"]?\S+",  # secret='actual_secret'
  r"token\s*=\s*['\"]?\S+",  # token='actual_token'
  r"credentials\s*=\s*['\"]?\S+",  # credentials='actual_creds'
  r"private[_-]?key\s*=\s*['\"]?\S+",  # private_key='actual_key'
  r"bearer\s+['\"]?[a-zA-Z0-9_-]+",  # bearer actual_token
]

# Invisible Unicode characters stripped from queries before validation.
# Written as escapes so they cannot be silently lost by editors/renderers.
INVISIBLE_UNICODE = [
  "\u200b",  # Zero-width space
  "\u200c",  # Zero-width non-joiner
  "\u200d",  # Zero-width joiner
  "\u2060",  # Word joiner
  "\ufeff",  # BOM / zero-width no-break space
]


@dataclass
class RateLimitState:
  """Mutable rate-limiting bookkeeping for one user or session.

  Attributes:
    requests_per_minute: time.time() stamps of requests counted against the
      rolling one-minute window.
    requests_per_hour: time.time() stamps of requests counted against the
      rolling one-hour window.
    concurrent_requests: Number of admitted requests not yet released.
    last_reset: Timestamp of the last reset; defaults to creation time.
  """

  # Each instance gets its own fresh list (never a shared default).
  requests_per_minute: list[float] = field(default_factory=list)
  requests_per_hour: list[float] = field(default_factory=list)
  # In-flight counter.
  concurrent_requests: int = 0
  # Evaluated per instance, at construction time.
  last_reset: float = field(default_factory=time.time)


@dataclass
class WebGuardrailConfig:
  """Tunable settings consumed by WebGuardrail.

  Attributes:
    max_query_length: Longest accepted (whitespace-stripped) query; default 500.
    domain_allowlist: When non-empty, only these domains may appear in a
      query or URL. Entries may use the "*.example.com" wildcard form.
      Empty means every domain is allowed.
    domain_blocklist: Domains that are always rejected (same pattern syntax).
      Empty means nothing is blocked.
    requests_per_minute: Per-user request cap per rolling minute (0 = unlimited).
    requests_per_hour: Per-user request cap per rolling hour (0 = unlimited).
    max_concurrent_requests: Per-user in-flight cap (0 = unlimited).
    block_private_cidrs: Whether to block private IP ranges (SSRF protection).
    timeout_seconds: Search timeout in seconds.
    require_https: Whether to require HTTPS URLs (block HTTP).
  """

  # Query shape.
  max_query_length: int = 500
  # Domain filtering.
  domain_allowlist: tuple[str, ...] = ()
  domain_blocklist: tuple[str, ...] = ()
  # Rate limiting; 0 disables the corresponding limit.
  requests_per_minute: int = 60
  requests_per_hour: int = 1000
  max_concurrent_requests: int = 0
  # Network safety.
  block_private_cidrs: bool = True
  timeout_seconds: int = 30
  require_https: bool = True
class WebGuardrail(Guardrail):
  """Guardrail for web tool validation.

  Validates:
  - Query length (prevents excessive queries)
  - Domain allowlist (optional, for restricted searches)
  - Domain blocklist (optional, for blocked domains)
  - SSRF protection (blocks private IPs, cloud metadata, localhost)
  - Query sanitization (blocks sensitive patterns)
  - Rate limiting (requests per minute, per hour, concurrent)

  validate_url() applies the scheme, SSRF and allowlist/blocklist checks to a
  single URL intended for web fetch.

  Note: Domain filtering is client-side validation only. Ollama backend may
  still access blocked domains. For full control, use LocalWebSearchBackend.
  """

  # --- Host extraction for free-text queries (applied to lower-cased text) ---
  # URL with scheme. An optional "userinfo@" part is skipped so that
  # "https://user@evil.com" yields "evil.com" rather than "user".
  _URL_HOST_RE = re.compile(r"https?://(?:[^\s/?#@]*@)?([a-z0-9.-]+)")
  # Bare domain (must contain at least one dot). Look-arounds are used rather
  # than consuming the surrounding whitespace, so that "a.com b.com" yields
  # both hosts instead of only every other one.
  _BARE_DOMAIN_RE = re.compile(r"(?<!\S)([a-z0-9-]+\.[a-z0-9.-]+)(?=[/\s]|$)")
  # Search-engine "site:" directive.
  _SITE_RE = re.compile(r"\bsite:\s*([a-z0-9.-]+)")

  # --- IP literal extraction ---
  # Dotted-quad IPv4.
  _IPV4_RE = re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")
  # Any run of hex digits, colons and dots that contains a colon. This is a
  # superset of IPv6 literals in full, compressed ("fe80::1", "::1"),
  # bracketed ("[::1]") and IPv4-mapped ("::ffff:127.0.0.1") form.
  # ipaddress.ip_address() filters out non-addresses (times, MACs, ...).
  _IPV6_CANDIDATE_RE = re.compile(r"[0-9a-fA-F:.]*:[0-9a-fA-F:.]*")
  # Integer-encoded IPv4, e.g. 0xa9fea9fe == 169.254.169.254.
  _HEX_INT_RE = re.compile(r"0x([0-9a-fA-F]+)")
  _DECIMAL_INT_RE = re.compile(r"\b(\d{8,12})\b")
  _LOCALHOST_RE = re.compile(r"\blocalhost\b", re.IGNORECASE)

  def __init__(self, config: "WebGuardrailConfig | None" = None) -> None:
    """Initialize guardrail with configuration.

    Args:
      config: WebGuardrailConfig with validation settings; a default
        WebGuardrailConfig() is used when omitted.
    """
    self._config = config or WebGuardrailConfig()
    # Guards _rate_limits; per-user state is created lazily on first lookup.
    self._rate_limit_lock = Lock()
    self._rate_limits: dict[str, RateLimitState] = defaultdict(RateLimitState)

  def validate(self, tool_name: str, params: dict[str, Any]) -> ValidationResult:
    """Validate web search parameters.

    Steps:
      1. Validate query is present, a string, and not blank.
      2. Validate query length <= max_query_length.
      3. Strip invisible Unicode characters.
      4. Check for SSRF attempts (private IPs, cloud metadata, localhost).
      5. Check domain allowlist / blocklist if configured.
      6. Check for sensitive patterns in query.
      7. Check rate limits (an allowed request is recorded).

    Args:
      tool_name: Name of tool being validated (not used by these checks).
      params: Tool parameters from LLM. Reads "query" and the optional
        "_user_id" rate-limit key (default "default").

    Returns:
      ValidationResult with success/failure and reason.
    """
    # Step 1: Validate query parameter
    query = params.get("query", "")
    if not query:
      return ValidationResult(valid=False, reason="Query is required")
    if not isinstance(query, str):
      return ValidationResult(valid=False, reason="Query must be a string")

    stripped_query = query.strip()
    if not stripped_query:
      return ValidationResult(valid=False, reason="Query cannot be empty or whitespace")

    # Step 2: Validate query length
    if len(stripped_query) > self._config.max_query_length:
      return ValidationResult(
        valid=False,
        reason=f"Query exceeds maximum length: {len(stripped_query)} > {self._config.max_query_length}",
      )

    # Step 3: Remove characters that could hide text from the checks below.
    cleaned_query = self._strip_invisible_unicode(stripped_query)

    # Steps 4-6: content checks in order; the first failure wins. The
    # allowlist/blocklist checks return None when they are not configured.
    content_checks = (
      self._check_ssrf,
      self._check_domain_allowlist,
      self._check_domain_blocklist,
      self._check_sensitive_patterns,
    )
    for check in content_checks:
      error = check(cleaned_query)
      if error:
        return ValidationResult(valid=False, reason=error)

    # Step 7: Check rate limits
    user_id = params.get("_user_id", "default")
    rate_error = self._check_rate_limit(user_id)
    if rate_error:
      return ValidationResult(valid=False, reason=rate_error)

    return ValidationResult(valid=True)

  def _strip_invisible_unicode(self, text: str) -> str:
    """Strip invisible Unicode characters from text.

    Args:
      text: Input text.

    Returns:
      Text with every INVISIBLE_UNICODE character and every Unicode Tag
      character (U+E0000-U+E007F) removed.
    """
    result = text
    for char in INVISIBLE_UNICODE:
      result = result.replace(char, "")
    # Unicode Tag characters are invisible and can be used for prompt injection.
    return re.sub(r"[\U000e0000-\U000e007f]", "", result)

  @staticmethod
  def _query_texts(query: str) -> list[str]:
    """Return the query plus its percent-decoded form when that differs.

    Scanning both forms stops encodings such as "%31%32%37.0.0.1" from hiding
    an address or host from the checks.
    """
    decoded = unquote(query)
    return [query] if decoded == query else [query, decoded]

  def _extract_domains(self, query: str) -> list[str]:
    """Return the lower-cased host names mentioned in a query.

    Hosts come from "http(s)://host" URLs, bare dotted tokens delimited by
    whitespace, and "site:host" directives, in both the raw and the
    percent-decoded query. Duplicates are dropped; first occurrence order kept.
    """
    found: dict[str, None] = {}
    for text in self._query_texts(query):
      lowered = text.lower()
      for regex in (self._URL_HOST_RE, self._BARE_DOMAIN_RE, self._SITE_RE):
        for domain in regex.findall(lowered):
          found.setdefault(domain, None)
    return list(found)

  def _extract_ip_candidates(self, text: str) -> list[str]:
    """Return substrings of text that may denote an IP address.

    Covers dotted IPv4, IPv6 literals (any form containing at least two
    colons), and hex / decimal integer-encoded IPv4 (converted to dotted form).
    Candidates are not validated here.
    """
    candidates = self._IPV4_RE.findall(text)
    for token in self._IPV6_CANDIDATE_RE.findall(text):
      # A sentence-ending dot is not part of the address.
      candidate = token.rstrip(".")
      if candidate.count(":") >= 2:  # every IPv6 literal has at least "::"
        candidates.append(candidate)
    for digits in self._HEX_INT_RE.findall(text):
      value = int(digits, 16)
      if value < 2**32:
        candidates.append(str(ipaddress.IPv4Address(value)))
    for digits in self._DECIMAL_INT_RE.findall(text):
      value = int(digits)
      if value < 2**32:
        candidates.append(str(ipaddress.IPv4Address(value)))
    return candidates

  @staticmethod
  def _unmap(
    ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
  ) -> ipaddress.IPv4Address | ipaddress.IPv6Address:
    """Return the embedded IPv4 address of an IPv4-mapped IPv6 address, else ip."""
    if ip.version == 6 and ip.ipv4_mapped is not None:
      return ip.ipv4_mapped
    return ip

  def _blocked_ip_reason(
    self,
    ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
    label: str,
    *,
    block_unspecified: bool = False,
  ) -> str | None:
    """Return an SSRF error message if ip must not be reached, else None.

    IPv4-mapped IPv6 addresses (e.g. ::ffff:127.0.0.1) are judged by the
    IPv4 address they embed, so they cannot sneak past the IPv4 ranges.

    Args:
      ip: Address to check.
      label: Text shown in the error message.
      block_unspecified: Also reject 0.0.0.0 / "::". Connecting to these can
        reach local services, so this is enabled for hosts that will actually
        be contacted. It is off for free-text scanning, where e.g. "0x0" is
        common.

    Returns:
      Error message, or None if the address is allowed.
    """
    effective = self._unmap(ip)
    if any(effective in cidr for cidr in PRIVATE_CIDRS):
      return f"SSRF blocked: private IP address detected ({label})"
    if str(effective) in METADATA_IPS:
      return "SSRF blocked: cloud metadata endpoint detected"
    if block_unspecified and effective.is_unspecified:
      return f"SSRF blocked: unspecified address detected ({label})"
    return None

  def _check_ssrf(self, query: str) -> str | None:
    """Check for SSRF attempts in query.

    Checks, on both the raw and the percent-decoded query:
    - IP literals: IPv4, IPv6 (compressed, bracketed, IPv4-mapped), and hex-
      or decimal-encoded IPv4 against PRIVATE_CIDRS and METADATA_IPS
    - The word "localhost"
    - Host names, which are resolved via DNS and checked the same way

    Args:
      query: Search query string.

    Returns:
      Error message if SSRF detected, None if safe (or if
      block_private_cidrs is disabled).
    """
    if not self._config.block_private_cidrs:
      return None

    texts = self._query_texts(query)

    # 1. Literal and encoded IP addresses.
    for text in texts:
      for candidate in self._extract_ip_candidates(text):
        try:
          ip = ipaddress.ip_address(candidate)
        except ValueError:
          continue  # not an address, e.g. "10:30:00" or "999.1.1.1"
        reason = self._blocked_ip_reason(ip, candidate)
        if reason is not None:
          return reason

    # 2. The localhost name.
    if any(self._LOCALHOST_RE.search(text) for text in texts):
      return "SSRF blocked: localhost detected"

    # 3. Host names that resolve to blocked addresses (performs DNS lookups).
    for domain in self._extract_domains(query):
      if not self._is_safe_domain(domain):
        return f"SSRF blocked: domain may resolve to private IP ({domain})"

    return None

  def _is_safe_domain(self, domain: str) -> bool:
    """Check if a domain is safe (doesn't resolve to a blocked IP).

    Note: This performs an uncached DNS resolution, which can be slow.

    Args:
      domain: Domain name to check.

    Returns:
      False if any resolved address is private, a metadata endpoint or
      unspecified (0.0.0.0 / ::). True otherwise, including when the name
      cannot be resolved.
    """
    try:
      infos = socket.getaddrinfo(domain, None)
    except (OSError, UnicodeError):
      # Unresolvable: NXDOMAIN, no network, or a malformed name such as
      # "a..b" that the IDNA codec rejects with UnicodeError. Allow but log.
      logger.debug("Could not resolve domain: %s", domain)
      return True

    for info in infos:
      try:
        ip = ipaddress.ip_address(info[4][0])
      except ValueError:
        continue
      if self._blocked_ip_reason(ip, domain, block_unspecified=True) is not None:
        return False
    return True

  def _check_domain_allowlist(self, query: str) -> str | None:
    """Check if query violates domain allowlist.

    Note: This is a heuristic check. It looks for domain patterns in the
    query string. For full control, use LocalWebSearchBackend which can
    filter results by domain. A query that mentions no domain is allowed.

    Args:
      query: Search query string.

    Returns:
      Error message if blocked, None if allowed.
    """
    if not self._config.domain_allowlist:
      return None
    for domain in self._extract_domains(query):
      if not self._domain_matches_list(domain, self._config.domain_allowlist):
        return f"Query contains non-whitelisted domain: {domain}"
    return None

  def _check_domain_blocklist(self, query: str) -> str | None:
    """Check if query mentions a blocked domain.

    Args:
      query: Search query string.

    Returns:
      Error message if blocked, None if allowed.
    """
    if not self._config.domain_blocklist:
      return None
    for domain in self._extract_domains(query):
      if self._domain_matches_list(domain, self._config.domain_blocklist):
        return f"Query contains blocked domain: {domain}"
    return None

  def _domain_matches_list(self, domain: str, patterns: tuple[str, ...]) -> bool:
    """Check if domain matches any pattern in list.

    Matching is case-insensitive and ignores a trailing dot, so the FQDN form
    "example.com." is treated like "example.com".
    - "*.example.com" matches "example.com", "api.example.com" and
      "v1.api.example.com"
    - "example.com" matches only "example.com"

    Args:
      domain: Domain to check.
      patterns: Patterns (may include "*." wildcards).

    Returns:
      True if domain matches any pattern.
    """
    host = domain.lower().rstrip(".")
    for pattern in patterns:
      normalized = pattern.lower().rstrip(".")
      if normalized.startswith("*."):
        base = normalized[2:]
        if host == base or host.endswith("." + base):
          return True
      elif host == normalized:
        return True
    return False

  def _check_sensitive_patterns(self, query: str) -> str | None:
    """Check for sensitive patterns in query.

    Args:
      query: Search query string.

    Returns:
      Error message if any SENSITIVE_PATTERNS entry matches
      (case-insensitive), None if safe.
    """
    if any(re.search(pattern, query, re.IGNORECASE) for pattern in SENSITIVE_PATTERNS):
      return "Query contains sensitive pattern"
    return None

  def _check_rate_limit(self, user_id: str) -> str | None:
    """Check rate limits for user and record the request if it is allowed.

    Args:
      user_id: User/session identifier.

    Returns:
      Error message if rate limited, None if allowed. An allowed request is
      appended to both windows and takes a concurrent slot until
      release_concurrent() is called.
    """
    config = self._config
    if (
      config.requests_per_minute == 0
      and config.requests_per_hour == 0
      and config.max_concurrent_requests == 0
    ):
      return None

    with self._rate_limit_lock:
      state = self._rate_limits[user_id]
      now = time.time()

      # Drop timestamps that have left the rolling windows.
      minute_ago = now - 60
      hour_ago = now - 3600
      state.requests_per_minute = [t for t in state.requests_per_minute if t > minute_ago]
      state.requests_per_hour = [t for t in state.requests_per_hour if t > hour_ago]

      if (
        config.requests_per_minute > 0
        and len(state.requests_per_minute) >= config.requests_per_minute
      ):
        return f"Rate limit exceeded: {config.requests_per_minute} requests per minute"

      if (
        config.requests_per_hour > 0
        and len(state.requests_per_hour) >= config.requests_per_hour
      ):
        return f"Rate limit exceeded: {config.requests_per_hour} requests per hour"

      if (
        config.max_concurrent_requests > 0
        and state.concurrent_requests >= config.max_concurrent_requests
      ):
        return f"Rate limit exceeded: {config.max_concurrent_requests} concurrent requests"

      # Record this request
      state.requests_per_minute.append(now)
      state.requests_per_hour.append(now)
      state.concurrent_requests += 1
      return None

  def release_concurrent(self, user_id: str = "default") -> None:
    """Release a concurrent request slot.

    Call this after search completes to decrement concurrent count. Unknown
    users and a count already at zero are ignored.

    Args:
      user_id: User/session identifier.
    """
    with self._rate_limit_lock:
      if user_id in self._rate_limits:
        state = self._rate_limits[user_id]
        if state.concurrent_requests > 0:
          state.concurrent_requests -= 1

  def validate_url(self, url: str) -> ValidationResult:
    """Validate a URL for web fetch.

    Steps:
      1. Parse the URL; only http/https schemes are accepted, and only https
         when require_https is set.
      2. Require a host.
      3. Check for SSRF attempts (localhost, private/mapped/unspecified IPs,
         metadata endpoints, hosts resolving to those).
      4. Check domain allowlist if configured.
      5. Check domain blocklist if configured.

    Args:
      url: URL string to validate.

    Returns:
      ValidationResult with success/failure and reason.
    """
    if not isinstance(url, str):
      return ValidationResult(valid=False, reason="Invalid URL format")
    try:
      parsed = urlparse(url)
    except ValueError:  # e.g. malformed bracketed IPv6 host
      return ValidationResult(valid=False, reason="Invalid URL format")

    # Check scheme
    if self._config.require_https and parsed.scheme != "https":
      return ValidationResult(valid=False, reason="Only HTTPS URLs are allowed")
    if parsed.scheme not in ("http", "https"):
      return ValidationResult(valid=False, reason="Only HTTP and HTTPS URLs are allowed")

    # Extract host (urlparse lower-cases it and strips IPv6 brackets)
    host = parsed.hostname
    if not host:
      return ValidationResult(valid=False, reason="URL must have a host")

    if self._config.block_private_cidrs:
      ssrf_error = self._check_ssrf_for_host(host)
      if ssrf_error:
        return ValidationResult(valid=False, reason=ssrf_error)

    if self._config.domain_allowlist:
      if not self._domain_matches_list(host, self._config.domain_allowlist):
        return ValidationResult(valid=False, reason=f"Domain not in allowlist: {host}")

    if self._config.domain_blocklist:
      if self._domain_matches_list(host, self._config.domain_blocklist):
        return ValidationResult(valid=False, reason=f"Domain is blocked: {host}")

    return ValidationResult(valid=True)

  def _check_ssrf_for_host(self, host: str) -> str | None:
    """Check whether a URL host is (or resolves to) a blocked address.

    Args:
      host: Hostname or IP address (without brackets).

    Returns:
      Error message if SSRF detected, None if safe.
    """
    # "localhost", its ".localdomain" alias and any "*.localhost" name
    # (trailing FQDN dot ignored).
    name = host.lower().rstrip(".")
    if name in ("localhost", "localhost.localdomain") or name.endswith(".localhost"):
      return "SSRF blocked: localhost detected"

    try:
      ip = ipaddress.ip_address(host)
    except ValueError:
      # Not an IP literal: resolve the hostname.
      if not self._is_safe_domain(host):
        return f"SSRF blocked: domain may resolve to private IP ({host})"
      return None

    return self._blocked_ip_reason(ip, host, block_unspecified=True)
# Names exported by "from yoker.tools.web_guardrail import *".
__all__ = [
  "WebGuardrail",
  "WebGuardrailConfig",
  "RateLimitState",
]