Source code for yoker.tools.search

"""Search tool implementation for Yoker.

Provides the SearchTool for searching files and their contents with
regex patterns (content search) and glob patterns (filename search).
"""

import fnmatch
import os
import re
import time
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any

from yoker.logging import get_logger
from yoker.tools.base import Tool, ToolResult

if TYPE_CHECKING:
  from yoker.tools.guardrails import Guardrail

log = get_logger(__name__)


class SearchTool(Tool):
  """Tool for searching files and their contents.

  Supports two search modes:
  - 'content': Search within file contents (grep-like, using regex)
  - 'filename': Search file names (glob-like, using fnmatch)

  When a guardrail is supplied, every call is validated against it before
  any filesystem access. Result counts, file sizes, pattern length and
  (for content search) wall-clock time are bounded to limit resource use.

  ReDoS Prevention:
  - Pattern length limit (MAX_PATTERN_LENGTH characters)
  - Forbidden pattern detection (nested quantifiers)
  - Compile-time regex validation
  - File size filtering (skip large files)
  """

  # Operational limits
  DEFAULT_MAX_RESULTS: int = 100
  ABSOLUTE_MAX_RESULTS: int = 1000
  MAX_FILE_SIZE_KB: int = 500
  MAX_PATTERN_LENGTH: int = 500

  # Timeout limits (in milliseconds)
  DEFAULT_TIMEOUT_MS: int = 5000
  ABSOLUTE_TIMEOUT_MS: int = 30000

  # Forbidden regex patterns that cause ReDoS.
  # These patterns match dangerous constructs in user-provided regex.
  FORBIDDEN_PATTERNS: tuple[str, ...] = (
    # Nested quantifiers: (a+)+, (a*)*, (a+)*, etc.
    r"\([^)]*[+*][^)]*\)[+*]",
    # Alternation with nested quantifiers
    r"\([^)]*\|[^)]*\)[+*]",
  )

  # Directories to skip during search. Entries containing glob
  # metacharacters (e.g. "*.egg-info") are matched as fnmatch patterns;
  # all other entries are matched by exact name.
  SKIP_DIRS: frozenset[str] = frozenset(
    {
      ".git",
      "__pycache__",
      "node_modules",
      ".venv",
      "venv",
      "build",
      "dist",
      ".mypy_cache",
      ".pytest_cache",
      "htmlcov",
      ".tox",
      ".eggs",
      "*.egg-info",
    }
  )

  def __init__(self, guardrail: "Guardrail | None" = None) -> None:
    """Initialize SearchTool with optional guardrail.

    Args:
      guardrail: Optional guardrail for parameter validation.
    """
    super().__init__(guardrail=guardrail)

  @property
  def name(self) -> str:
    """Return the tool identifier used in the schema and guardrail checks."""
    return "search"

  @property
  def description(self) -> str:
    """Return the human-readable description embedded in the schema."""
    return (
      "Search for patterns in files. "
      "Use type='content' for grep-like regex search in file contents. "
      "Use type='filename' for find-like glob pattern matching."
    )

  def get_schema(self) -> dict[str, Any]:
    """Return Ollama-compatible schema for the search tool.

    Returns:
      Dict with 'type': 'function' and function metadata.
    """
    return {
      "type": "function",
      "function": {
        "name": self.name,
        "description": self.description,
        "parameters": {
          "type": "object",
          "properties": {
            "path": {
              "type": "string",
              "description": "Directory to search in",
            },
            "pattern": {
              "type": "string",
              "description": (
                "Search pattern. For 'content' type: regex pattern. "
                "For 'filename' type: glob pattern (e.g., '*.py')"
              ),
            },
            "type": {
              "type": "string",
              "enum": ["content", "filename"],
              "description": (
                "Type of search: 'content' searches within files, "
                "'filename' searches file names. Defaults to 'content'."
              ),
            },
            "max_results": {
              "type": "integer",
              "description": (
                f"Maximum results to return. Defaults to {self.DEFAULT_MAX_RESULTS}."
              ),
              "minimum": 1,
              "maximum": self.ABSOLUTE_MAX_RESULTS,
            },
            "timeout_ms": {
              "type": "integer",
              "description": (
                f"Maximum search time in milliseconds. Defaults to {self.DEFAULT_TIMEOUT_MS}."
              ),
              "minimum": 100,
              "maximum": self.ABSOLUTE_TIMEOUT_MS,
            },
          },
          "required": ["path"],
        },
      },
    }

  async def execute(self, **kwargs: Any) -> ToolResult:
    """Search files with pattern matching and limits.

    Args:
      **kwargs: Must contain 'path'. May contain 'pattern', 'type',
        'max_results' and 'timeout_ms'. Out-of-range numeric values are
        clamped rather than rejected. 'timeout_ms' is applied to content
        searches only; filename searches are not time-limited.

    Returns:
      ToolResult with search results or error message. On success the
      result is a dict with 'success', 'matches', 'total_matches' and
      'truncated', plus 'files_searched' for content searches.
    """
    path_str = kwargs.get("path", "")
    if not path_str:
      return ToolResult(
        success=False,
        result="",
        error="Missing required parameter: path",
      )

    # Defense-in-depth: validate via guardrail if provided
    if self._guardrail is not None:
      validation = self._guardrail.validate(self.name, kwargs)
      if not validation.valid:
        log.info(
          "search_guardrail_blocked",
          path=path_str,
          reason=validation.reason,
        )
        return ToolResult(
          success=False,
          result="",
          error=validation.reason,
        )

    # Parse and clamp parameters
    try:
      max_results = self._clamp(
        int(kwargs.get("max_results", self.DEFAULT_MAX_RESULTS)),
        1,
        self.ABSOLUTE_MAX_RESULTS,
      )
    except (ValueError, TypeError):
      return ToolResult(
        success=False,
        result="",
        error="Invalid numeric parameter: max_results",
      )

    try:
      timeout_ms = self._clamp(
        int(kwargs.get("timeout_ms", self.DEFAULT_TIMEOUT_MS)),
        100,
        self.ABSOLUTE_TIMEOUT_MS,
      )
    except (ValueError, TypeError):
      return ToolResult(
        success=False,
        result="",
        error="Invalid numeric parameter: timeout_ms",
      )

    search_type = kwargs.get("type", "content")
    if search_type not in ("content", "filename"):
      return ToolResult(
        success=False,
        result="",
        error=f"Invalid type: {search_type}. Must be 'content' or 'filename'",
      )

    pattern = kwargs.get("pattern", "")
    if not pattern:
      # Missing/empty pattern means "match everything" for the chosen mode.
      pattern = "*" if search_type == "filename" else ".*"
    elif not isinstance(pattern, str):
      # Reject early: a non-string would otherwise raise TypeError inside
      # regex validation, escaping execute() instead of yielding a ToolResult.
      return ToolResult(
        success=False,
        result="",
        error="Invalid pattern: must be a string",
      )

    # Validate regex for content search
    if search_type == "content":
      is_valid, error = self._validate_regex(pattern)
      if not is_valid:
        return ToolResult(success=False, result="", error=error)

    # Validate path
    try:
      path = Path(path_str)
      if not path.exists():
        return ToolResult(
          success=False,
          result="",
          error=f"Path not found: {path_str}",
        )
      if not path.is_dir():
        return ToolResult(
          success=False,
          result="",
          error=f"Path is not a directory: {path_str}",
        )
    except PermissionError:
      return ToolResult(success=False, result="", error=f"Permission denied: {path_str}")
    except Exception as e:
      return ToolResult(success=False, result="", error=f"Invalid path: {e}")

    # Execute search
    try:
      result: dict[str, Any]
      if search_type == "content":
        matches, total, truncated, files_searched = self._search_content(
          path, pattern, max_results, timeout_ms
        )
        result = {
          "success": True,
          "matches": matches,
          "total_matches": total,
          "truncated": truncated,
          "files_searched": files_searched,
        }
      else:
        matches, total, truncated = self._search_filename(path, pattern, max_results)
        result = {
          "success": True,
          "matches": matches,
          "total_matches": total,
          "truncated": truncated,
        }

      log.info(
        "search_success",
        path=str(path),
        type=search_type,
        pattern=pattern,
        total_matches=total,
        files_searched=result.get("files_searched"),
      )
      return ToolResult(success=True, result=result)

    except PermissionError:
      return ToolResult(success=False, result="", error=f"Permission denied: {path_str}")
    except Exception as e:
      # Top-level boundary: report any unexpected failure as a tool error.
      log.error("search_error", error=str(e))
      return ToolResult(success=False, result="", error=f"Error searching: {e}")

  def _clamp(self, value: int, minimum: int, maximum: int) -> int:
    """Clamp a value to a range.

    Args:
      value: Value to clamp.
      minimum: Minimum value.
      maximum: Maximum value.

    Returns:
      Clamped value.
    """
    return max(minimum, min(value, maximum))

  def _validate_regex(self, pattern: str) -> tuple[bool, str]:
    """Validate regex pattern for safety.

    Checks for:
    - Maximum pattern length
    - Dangerous constructs (nested quantifiers)
    - Compile-time regex validity

    Args:
      pattern: Regex pattern to validate.

    Returns:
      Tuple of (is_valid, error_message). The message is empty when valid.
    """
    # Length check
    if len(pattern) > self.MAX_PATTERN_LENGTH:
      return False, f"Pattern too long: max {self.MAX_PATTERN_LENGTH} characters"

    # Check for forbidden patterns (ReDoS vectors)
    for forbidden in self.FORBIDDEN_PATTERNS:
      if re.search(forbidden, pattern):
        return (
          False,
          "Pattern rejected: potential ReDoS vulnerability (nested quantifiers)",
        )

    # Try to compile the pattern
    try:
      re.compile(pattern)
    except re.error as e:
      return False, f"Invalid regex pattern: {e}"
    return True, ""

  def _walk_files(self, root: Path) -> Iterator[Path]:
    """Walk directory tree, yielding files.

    Skips:
    - Hidden files/directories (starting with .)
    - Directories listed in SKIP_DIRS, where entries containing glob
      metacharacters (such as "*.egg-info") are matched with fnmatch

    Symlinked directories are not descended into (os.walk's default).
    Symlinked files ARE yielded; callers that read file content are
    responsible for filtering them.

    Args:
      root: Root directory to walk.

    Yields:
      Path objects for each file found.
    """
    # Split SKIP_DIRS once per walk: exact names use set membership, glob
    # entries need fnmatch. Exact membership alone never matches a glob
    # entry such as "*.egg-info" against a real directory name.
    glob_entries = tuple(
      entry for entry in self.SKIP_DIRS if any(ch in entry for ch in "*?[")
    )

    def is_skipped(dirname: str) -> bool:
      if dirname.startswith(".") or dirname in self.SKIP_DIRS:
        return True
      return any(fnmatch.fnmatchcase(dirname, entry) for entry in glob_entries)

    for dirpath, dirnames, filenames in os.walk(root):
      # Prune in place so os.walk does not descend into skipped directories.
      dirnames[:] = [d for d in dirnames if not is_skipped(d)]

      for filename in filenames:
        # Skip hidden files
        if filename.startswith("."):
          continue
        yield Path(dirpath) / filename

  def _search_content(
    self,
    root: Path,
    pattern: str,
    max_results: int,
    timeout_ms: int,
  ) -> tuple[list[dict[str, Any]], int, bool, int]:
    """Search file contents using regex.

    Every matching line is counted, but only the first ``max_results``
    are returned. The time budget is checked once per file, so a single
    file is always scanned to completion once started.

    Args:
      root: Root directory to search.
      pattern: Regex pattern to match.
      max_results: Maximum results to return.
      timeout_ms: Maximum search time in milliseconds.

    Returns:
      Tuple of (matches, total_count, truncated, files_searched), where
      files_searched counts only files whose content was actually read.
    """
    matches: list[dict[str, Any]] = []
    total_count = 0
    truncated = False
    files_searched = 0

    # Compile regex (already validated)
    regex = re.compile(pattern)
    max_size = self.MAX_FILE_SIZE_KB * 1024

    # Track timeout
    start_time = time.monotonic()
    timeout_seconds = timeout_ms / 1000.0

    for file_path in self._walk_files(root):
      # Check timeout on each file iteration
      if time.monotonic() - start_time > timeout_seconds:
        truncated = True
        break

      try:
        # Skip symlinks (checked here so a PermissionError is handled below)
        if file_path.is_symlink():
          continue

        # Skip large files
        if file_path.stat().st_size > max_size:
          continue

        # errors="replace" substitutes undecodable bytes instead of raising,
        # so non-UTF-8 files under the size cap are still scanned.
        content = file_path.read_text(encoding="utf-8", errors="replace")
      except (UnicodeDecodeError, OSError):
        # Unreadable file (permission denied, vanished, etc.): skip it.
        continue

      # Count only files that were actually read, not skipped ones.
      files_searched += 1

      for line_num, line in enumerate(content.splitlines(), 1):
        if regex.search(line):
          total_count += 1
          if len(matches) < max_results:
            matches.append(
              {
                "file": str(file_path),
                "line": line_num,
                "content": line.strip(),
              }
            )

    if len(matches) < total_count:
      truncated = True

    return matches, total_count, truncated, files_searched

  def _search_filename(
    self,
    root: Path,
    pattern: str,
    max_results: int,
  ) -> tuple[list[dict[str, Any]], int, bool]:
    """Search file names using glob pattern.

    The pattern is matched against each file's base name only, not its
    full path. Every match is counted, but only the first ``max_results``
    are returned.

    Args:
      root: Root directory to search.
      pattern: Glob pattern to match.
      max_results: Maximum results to return.

    Returns:
      Tuple of (matches, total_count, truncated).
    """
    matches: list[dict[str, Any]] = []
    total_count = 0

    for file_path in self._walk_files(root):
      if fnmatch.fnmatch(file_path.name, pattern):
        total_count += 1
        if len(matches) < max_results:
          matches.append({"file": str(file_path)})

    truncated = len(matches) < total_count
    return matches, total_count, truncated