Source code for yoker.tools.write

"""Write tool implementation for Yoker.

Provides the WriteTool for writing file contents with guardrail validation,
overwrite protection, and explicit parent directory handling.
"""

import os
from pathlib import Path
from typing import TYPE_CHECKING, Any

from yoker.config.schema import Config
from yoker.logging import get_logger
from yoker.tools.base import Tool, ToolResult

if TYPE_CHECKING:
  from yoker.tools.guardrails import Guardrail

log = get_logger(__name__)


def _is_binary(content: str) -> bool:
  """Check if content appears to be binary.

  Checks for null bytes in the first 8KB of content.

  Args:
    content: Content to check.

  Returns:
    True if content appears to be binary, False otherwise.
  """
  # Check first 8KB for null bytes
  check_size = min(len(content), 8192)
  return "\x00" in content[:check_size]


def _truncate_content(
  content: str,
  max_lines: int,
  max_bytes: int,
) -> tuple[str, bool, int, int]:
  """Truncate content based on max lines and max bytes.

  Args:
    content: Content to truncate.
    max_lines: Maximum lines to include.
    max_bytes: Maximum bytes to include.

  Returns:
    Tuple of (truncated_content, was_truncated, original_lines, original_bytes).
  """
  original_bytes = len(content.encode("utf-8"))
  lines = content.splitlines(keepends=True)
  original_lines_count = len(lines)

  # Truncate by lines first
  if len(lines) > max_lines:
    lines = lines[:max_lines]
    was_truncated = True
  else:
    was_truncated = False

  # Truncate by bytes
  truncated_content = "".join(lines)
  truncated_bytes = len(truncated_content.encode("utf-8"))

  if truncated_bytes > max_bytes:
    # Truncate to max_bytes
    truncated_content = truncated_content[:max_bytes]
    was_truncated = True

  return truncated_content, was_truncated, original_lines_count, original_bytes


class WriteTool(Tool):
  """Tool for writing file contents.

  Writes content to a file with defense-in-depth validation. When a
  guardrail is provided, validates parameters before writing. Rejects
  paths that are symlinks, resolves paths with realpath, and supports
  overwrite protection and parent directory creation.

  Error messages returned to the LLM are sanitized to avoid leaking
  filesystem structure. Full paths are logged internally for debugging.
  """

  def __init__(
    self,
    guardrail: "Guardrail | None" = None,
    config: Config | None = None,
  ) -> None:
    """Initialize WriteTool with optional guardrail and config.

    Args:
      guardrail: Optional guardrail for parameter validation.
      config: Optional config for overwrite protection and size limits.
        If not provided, defaults to Config() (allow_overwrite=False).
    """
    super().__init__(guardrail=guardrail)
    self._config = config or Config()

  @property
  def name(self) -> str:
    """Name under which the tool is exposed."""
    return "write"

  @property
  def description(self) -> str:
    """Short human-readable description of the tool."""
    return "Write content to a file"

  def get_schema(self) -> dict[str, Any]:
    """Return Ollama-compatible schema for the write tool.

    Returns:
      Dict with 'type': 'function' and function metadata.
    """
    return {
      "type": "function",
      "function": {
        "name": self.name,
        "description": self.description,
        "parameters": {
          "type": "object",
          "properties": {
            "path": {
              "type": "string",
              "description": "Path to the file to write",
            },
            "content": {
              "type": "string",
              "description": "Content to write to the file",
            },
            "create_parents": {
              "type": "boolean",
              "description": ("If true, create missing parent directories. Defaults to false."),
            },
          },
          "required": ["path", "content"],
        },
      },
    }

  async def execute(self, **kwargs: Any) -> ToolResult:
    """Write content to a file.

    Steps:
      1. Validate parameters via guardrail if provided.
      2. Extract and validate path and content parameters.
      3. Verify the content can be encoded as UTF-8 before touching the
         filesystem.
      4. Reject the path if it is itself a symlink.
      5. Resolve the path with os.path.realpath().
      6. Check overwrite protection (config-based).
      7. Create parent directories if requested.
      8. Write with UTF-8 encoding.
      9. Log write for audit trail.
      10. Populate content_metadata for content display.

    Args:
      **kwargs: Must contain 'path' and 'content' keys.
        May contain 'create_parents' (default False).

    Returns:
      ToolResult with success status and output or error message.
    """
    path_str = kwargs.get("path", "")
    content = kwargs.get("content", "")
    create_parents = bool(kwargs.get("create_parents", False))

    # Defense-in-depth: validate via guardrail if provided
    if self._guardrail is not None:
      validation = self._guardrail.validate(self.name, kwargs)
      if not validation.valid:
        log.info(
          "write_guardrail_blocked",
          path=path_str,
          reason=validation.reason,
        )
        return ToolResult(
          success=False,
          result="",
          error=validation.reason,
        )

    # Ensure path is a non-empty string
    if not isinstance(path_str, str) or not path_str.strip():
      log.warning("write_invalid_path_type", path_type=type(path_str).__name__)
      return ToolResult(success=False, result="", error="Invalid path parameter")

    # Ensure content is a string (empty content is valid)
    if not isinstance(content, str):
      log.warning(
        "write_invalid_content_type",
        content_type=type(content).__name__,
      )
      return ToolResult(success=False, result="", error="Invalid content parameter")

    # Encode up front: text that cannot be represented in UTF-8 (e.g. lone
    # surrogates) would otherwise raise UnicodeEncodeError in the middle of
    # the write, after the target file has already been created or truncated,
    # and that exception is not an OSError so it would escape this method.
    try:
      encoded_size = len(content.encode("utf-8"))
    except UnicodeEncodeError:
      log.warning("write_unencodable_content", path=path_str)
      return ToolResult(success=False, result="", error="Invalid content parameter")

    # Reject symlinks before resolving to prevent traversal via symlinks
    original_path = Path(path_str)
    if original_path.is_symlink():
      log.warning("write_symlink_rejected", path=path_str)
      return ToolResult(
        success=False,
        result="",
        error="Writing to symlinks is not permitted",
      )

    # Resolve the path to prevent traversal and normalize
    try:
      resolved = Path(os.path.realpath(path_str))
    except (OSError, ValueError):
      log.warning("write_invalid_path", path=path_str)
      return ToolResult(success=False, result="", error="Invalid path")

    # Check overwrite protection
    is_overwrite = resolved.exists()
    if is_overwrite:
      allow_overwrite = self._config.tools.write.allow_overwrite
      if not allow_overwrite:
        log.info(
          "write_overwrite_blocked",
          path=str(resolved),
        )
        return ToolResult(
          success=False,
          result="",
          error="File already exists and overwrite is not permitted",
        )

    # Check parent directory
    parent = resolved.parent
    if not parent.exists():
      if create_parents:
        try:
          parent.mkdir(parents=True, exist_ok=True)
          log.info(
            "write_created_parents",
            path=str(parent),
          )
        except OSError as e:
          log.error(
            "write_create_parents_failed",
            path=str(parent),
            error=str(e),
          )
          return ToolResult(
            success=False,
            result="",
            error="Failed to create parent directories",
          )
      else:
        log.info("write_parent_missing", path=str(resolved))
        return ToolResult(
          success=False,
          result="",
          error="Parent directory does not exist",
        )

    # Write the file with explicit encoding
    try:
      resolved.write_text(content, encoding="utf-8")
      log.info(
        "write_success",
        path=str(resolved),
        bytes=encoded_size,
      )

      # Build content_metadata for content display
      content_metadata = self._build_content_metadata(
        content=content,
        resolved_path=resolved,
        is_overwrite=is_overwrite,
      )

      return ToolResult(
        success=True,
        result="File written successfully",
        content_metadata=content_metadata,
      )
    except PermissionError:
      # Detailed path stays in the log; the returned message is generic
      log.warning("write_permission_denied", path=str(resolved))
      return ToolResult(success=False, result="", error="Permission denied")
    except OSError as e:
      log.error("write_os_error", path=str(resolved), error=str(e))
      return ToolResult(success=False, result="", error="Error writing file")

  def _build_content_metadata(
    self,
    content: str,
    resolved_path: Path,
    is_overwrite: bool,
  ) -> dict[str, Any] | None:
    """Build content_metadata for ToolResult.

    The shape of the result depends on the configured content_display
    verbosity: 'silent' yields None, 'summary' yields counts only, and any
    other value yields the (possibly truncated) content. Content detected
    as binary always yields a summary without the content itself.

    Args:
      content: Content that was written.
      resolved_path: Resolved file path.
      is_overwrite: Whether this was an overwrite operation.

    Returns:
      Content metadata dict, or None if verbosity is 'silent'.
    """
    content_display = self._config.tools.content_display

    # Check verbosity - return None for silent mode
    if content_display.verbosity == "silent":
      return None

    byte_size = len(content.encode("utf-8"))

    # Detect binary content
    if _is_binary(content):
      # Binary content: return summary only
      return {
        "operation": "write",
        "path": str(resolved_path),
        "content_type": "summary",
        "content": None,
        "metadata": {
          "lines": 0,
          "bytes": byte_size,
          "is_new_file": not is_overwrite,
          "is_overwrite": is_overwrite,
          "is_binary": True,
        },
      }

    # Count lines
    line_count = len(content.splitlines())

    # Check if empty
    is_empty = line_count == 0

    # Determine content type based on verbosity
    if content_display.verbosity == "summary":
      # Summary mode: return line count only
      return {
        "operation": "write",
        "path": str(resolved_path),
        "content_type": "summary",
        "content": None,
        "metadata": {
          "lines": line_count,
          "bytes": byte_size,
          "is_new_file": not is_overwrite,
          "is_overwrite": is_overwrite,
          "is_empty": is_empty,
        },
      }

    # Content mode: return full content (possibly truncated)
    truncated_content, was_truncated, _, _ = _truncate_content(
      content,
      content_display.max_content_lines,
      content_display.max_content_bytes,
    )

    metadata: dict[str, Any] = {
      "lines": line_count,
      "bytes": byte_size,
      "is_new_file": not is_overwrite,
      "is_overwrite": is_overwrite,
      "is_empty": is_empty,
    }

    if was_truncated:
      metadata["truncated"] = True
      metadata["original_line_count"] = line_count

    return {
      "operation": "write",
      "path": str(resolved_path),
      "content_type": "full",
      "content": truncated_content,
      "metadata": metadata,
    }