Source code for yoker.events.handlers

"""Event handlers for the Yoker event system."""

from collections.abc import Coroutine
from pathlib import Path
from typing import Any, Protocol, runtime_checkable

from rich.console import Console
from rich.style import Style

from yoker.events.spinner import LiveDisplay
from yoker.events.types import (
  CommandEvent,
  ContentChunkEvent,
  ContentEndEvent,
  ContentStartEvent,
  ErrorEvent,
  Event,
  EventType,
  SessionEndEvent,
  SessionStartEvent,
  ThinkingChunkEvent,
  ThinkingEndEvent,
  ThinkingStartEvent,
  ToolCallEvent,
  ToolContentEvent,
  ToolResultEvent,
  TurnEndEvent,
  TurnStartEvent,
)

# Styles for console output
# Dim grey used for thinking text printed through _print_wrapped.
THINKING_STYLE = Style(color="bright_black", dim=True)
# Bold red used for the "[Error] ..." line printed by _handle_error.
ERROR_STYLE = Style(color="red", bold=True)
# Cyan style for tool output; not referenced by the handlers visible in this file.
TOOL_STYLE = Style(color="cyan")


@runtime_checkable
class EventHandler(Protocol):
  """Structural type for anything that can receive agent events.

  Every event emitted during agent processing is passed to each handler.
  A handler may be written in either form:

  - synchronous: ``def __call__(self, event: Event) -> None``
  - asynchronous: ``async def __call__(self, event: Event) -> None``

  Keep handlers quick (ideally under 100ms) so the event loop is not
  blocked; prefer the async form for I/O-bound work.

  Security Note:
    Events may carry sensitive data such as tool results and file
    contents, so only register handlers from trusted sources.
  """

  def __call__(self, event: Event) -> None | Coroutine[None, None, None]:
    """Process a single event.

    Args:
      event: The event to handle.

    Returns:
      None when the handler is synchronous, or a coroutine to be awaited
      when it is asynchronous.
    """
    ...
[docs] class ConsoleEventHandler: """Handles events by rendering to Rich console.""" def __init__( self, console: Console | None = None, show_thinking: bool = True, show_tool_calls: bool = True, wrap_width: int | None = None, version: str = "0.1.0", ) -> None: """Initialize the console handler. Args: console: Rich console (default: new Console). show_thinking: Whether to display thinking output. show_tool_calls: Whether to display tool call info. wrap_width: Optional width for wrapping streaming output. version: Version string to display in session start. """ self.console = console if console is not None else Console() self.show_thinking = show_thinking self.show_tool_calls = show_tool_calls self.wrap_width = wrap_width self.version = version # Live display - managed internally, not passed from outside self._live_display: LiveDisplay | None = None self._spinner_active = False # Track if spinner is currently active # State for wrapping and thinking tracking self._column = 0 self._thinking_shown = False # Track if thinking was displayed self._content_shown = False # Track if any content was shown this turn async def __call__(self, event: Event) -> None: """Handle an event by dispatching to the appropriate handler method. This is an async handler that dispatches to synchronous _handle_* methods. Rich console operations are thread-safe and fast (<10ms), so we keep the handler methods synchronous for simplicity. Args: event: The event to handle. 
""" match event.type: case EventType.SESSION_START: self._handle_session_start(event) # type: ignore[arg-type] case EventType.SESSION_END: self._handle_session_end(event) # type: ignore[arg-type] case EventType.TURN_START: self._handle_turn_start(event) # type: ignore[arg-type] case EventType.TURN_END: self._handle_turn_end(event) # type: ignore[arg-type] case EventType.THINKING_START: self._handle_thinking_start(event) # type: ignore[arg-type] case EventType.THINKING_CHUNK: self._handle_thinking_chunk(event) # type: ignore[arg-type] case EventType.THINKING_END: self._handle_thinking_end(event) # type: ignore[arg-type] case EventType.CONTENT_START: self._handle_content_start(event) # type: ignore[arg-type] case EventType.CONTENT_CHUNK: self._handle_content_chunk(event) # type: ignore[arg-type] case EventType.CONTENT_END: self._handle_content_end(event) # type: ignore[arg-type] case EventType.TOOL_CALL: self._handle_tool_call(event) # type: ignore[arg-type] case EventType.TOOL_RESULT: self._handle_tool_result(event) # type: ignore[arg-type] case EventType.TOOL_CONTENT: self._handle_tool_content(event) # type: ignore[arg-type] case EventType.ERROR: self._handle_error(event) # type: ignore[arg-type] case EventType.COMMAND: self._handle_command(event) # type: ignore[arg-type] def _handle_session_start(self, event: SessionStartEvent) -> None: """Handle session start event.""" self.console.print(f"Yoker v{self.version} - Using model: {event.model}") thinking_status = "enabled" if event.thinking_enabled else "disabled" self.console.print(f"Thinking mode: {thinking_status} (use /think on|off to toggle)") self.console.print("Type /help for available commands.") self.console.print("Press Ctrl+D (or Ctrl+Z on Windows) to quit.\n") def _handle_session_end(self, event: SessionEndEvent) -> None: """Handle session end event.""" self.console.print("\nGoodbye!") def _handle_turn_start(self, event: TurnStartEvent) -> None: """Handle turn start event.""" # Reset flags for new turn 
self._thinking_shown = False self._content_shown = False # Note: LiveDisplay is created lazily in _handle_thinking_start or # _handle_content_start when content actually starts. This prevents # showing a "Processing..." spinner during replay mode where events # come immediately without real-time delays. def _handle_turn_end(self, event: TurnEndEvent) -> None: """Handle turn end event.""" # Show stats and exit LiveDisplay if active if self._live_display: # Stop the spinner and exit Live display self._live_display.stop_spinner() self._live_display.__exit__(None, None, None) self._live_display = None self._spinner_active = False # Print stats directly to console (outside Live, ensures SVG capture) # Only show timing if we have actual timing data (non-zero duration or tokens) duration_s = event.total_duration_ms / 1000.0 total_tokens = event.prompt_eval_count + event.eval_count if event.total_duration_ms > 0 or total_tokens > 0: if total_tokens > 0: tokens_per_sec = total_tokens / duration_s if duration_s > 0 else 0 stats = f"⏱ {duration_s:.1f}s | {event.prompt_eval_count}+{event.eval_count}={total_tokens} tokens | {tokens_per_sec:.0f} tok/s" else: stats = f"⏱ {duration_s:.1f}s" self.console.print(stats, style="dim") else: # No timing data available - just print blank line for spacing self.console.print() else: # Without live display, add blank line after response self.console.print() def _handle_thinking_start(self, event: ThinkingStartEvent) -> None: """Handle thinking start event.""" if self.show_thinking: # Add separator if there was previous content (e.g., tool calls) # Check BEFORE setting the flag needs_separator = self._content_shown and self._live_display is None self._thinking_shown = True self._content_shown = True # Content was shown this turn # Create LiveDisplay if not already active if self._live_display is None: if needs_separator: self.console.print() self._live_display = LiveDisplay(console=self.console) self._live_display.__enter__() 
self._live_display.start_spinner() # Show spinner while streaming # Without live display, add newline before thinking else: self._print_wrapped("\n", style=THINKING_STYLE) def _handle_thinking_chunk(self, event: ThinkingChunkEvent) -> None: """Handle thinking chunk event.""" if self.show_thinking: if self._live_display: self._live_display.append_thinking(event.text) else: self._print_wrapped(event.text, style=THINKING_STYLE) def _handle_thinking_end(self, event: ThinkingEndEvent) -> None: """Handle thinking end event.""" if self.show_thinking: if not self._live_display: # Without live display, add newlines after thinking self._print_wrapped("\n\n") def _handle_content_start(self, event: ContentStartEvent) -> None: """Handle content start event.""" # Add separator before response if there was previous content # This includes thinking shown, tool calls, or spinner activity if self._thinking_shown: # Thinking was shown - add separator in LiveDisplay if self._live_display: self._live_display.append_response("\n") else: self.console.print() self._thinking_shown = False # Reset for next turn elif self._content_shown or self._spinner_active: # Tool calls were shown or spinner was active - add separator if self._live_display: self._live_display.append_response("\n") else: self.console.print() # Create new LiveDisplay if not active (e.g., after tool calls) if self._live_display is None: self._live_display = LiveDisplay(console=self.console) self._live_display.__enter__() self._live_display.start_spinner() # Show spinner while streaming def _handle_content_chunk(self, event: ContentChunkEvent) -> None: """Handle content chunk event.""" if self._live_display: self._live_display.append_response(event.text) else: self._print_wrapped(event.text) def _handle_content_end(self, event: ContentEndEvent) -> None: """Handle content end event.""" # Final newline - only needed without live display if not self._live_display: self.console.print() @staticmethod def 
_extract_filename(arguments: dict[str, Any]) -> str: """Extract filename from tool arguments. Args: arguments: Tool arguments dictionary. Returns: Filename (basename) of the path argument, or first arg value if no path. """ # Special case: git tool shows operation, not path if "operation" in arguments: return str(arguments["operation"]) # Look for common path argument names for key in ("file_path", "path", "filepath"): if key in arguments: return Path(arguments[key]).name # Fallback: use first argument value if arguments: first_value = next(iter(arguments.values())) return str(first_value) return "" @staticmethod def _capitalize(name: str) -> str: """Capitalize first letter of name for display. Args: name: Tool name to capitalize. Returns: Name with first letter capitalized. """ if name: return name[0].upper() + name[1:] return name def _handle_tool_call(self, event: ToolCallEvent) -> None: """Handle tool call event.""" if self.show_tool_calls: # Exit current LiveDisplay to freeze buffered content # The content is already visible on screen, we just stop live updating if self._live_display: self._live_display.stop_spinner() # Remove spinner before exiting self._live_display.__exit__(None, None, None) self._live_display = None self._spinner_active = False # Spinner is no longer active self._content_shown = True # Content was shown this turn tool_name = self._capitalize(event.tool_name) details = self._format_tool_details(event.tool_name, event.arguments) # Print tool call with newline separator from previous segment self.console.print(f"\n{tool_name} tool: {details}") def _format_tool_details(self, tool_name: str, arguments: dict[str, Any]) -> str: """Format tool arguments for display. Args: tool_name: Name of the tool. arguments: Tool arguments dictionary. Returns: Formatted string showing relevant arguments. 
""" # Special formatting for git tool: show operation, path, and args if tool_name == "git": operation = arguments.get("operation", "") path = arguments.get("path", "") args = arguments.get("args", {}) # Build details string parts = [operation] if path: parts.append(f"on {path}") if args: # Show key args (first 2 to keep it concise) args_str = ", ".join(f"{k}={v}" for k, v in list(args.items())[:2]) if len(args) > 2: args_str += ", ..." parts.append(f"({args_str})") return " ".join(parts) if parts else str(arguments) # Special formatting for web_search: show query if tool_name == "web_search": query = arguments.get("query", "") if query: return str(query) return str(arguments) # For other tools: show filename/path return self._extract_filename(arguments) def _handle_tool_result(self, event: ToolResultEvent) -> None: """Handle tool result event.""" if self.show_tool_calls: # Show success/failure indicator (outside Live context) # LiveDisplay was already exited in _handle_tool_call if event.success: self.console.print(" ✓ Success") else: # Show first 50 chars of result (error message) error_msg = event.result[:50] if event.result else "Failed" self.console.print(f" ✗ {error_msg}") # Create LiveDisplay with spinner for subsequent processing # This ensures spinner is visible between tool calls and next segment if self._live_display is None: self._live_display = LiveDisplay(console=self.console) self._live_display.__enter__() self._live_display.start_spinner() # Show spinner while streaming def _handle_tool_content(self, event: ToolContentEvent) -> None: """Handle tool content event. 
Displays content based on content_type: - 'full': Show full content with line numbers - 'diff': Show unified diff with colors - 'summary': Show operation summary only """ if not self.show_tool_calls: return # Exit LiveDisplay (created in ToolResult) to print content if self._live_display: self._live_display.stop_spinner() self._live_display.__exit__(None, None, None) self._live_display = None self._spinner_active = False # Tool content is printed outside Live context # (Live was exited in _handle_tool_call or above) # Get operation details _ = event.operation # Available for future use filename = Path(event.path).name # Dispatch based on content_type if event.content_type == "summary": self._show_summary(event, filename) elif event.content_type == "diff": self._show_diff_content(event, filename) else: # content_type == "full" self._show_full_content(event, filename) # Create LiveDisplay with spinner for subsequent processing if self._live_display is None: self._live_display = LiveDisplay(console=self.console) self._live_display.__enter__() self._live_display.start_spinner() # Show spinner while streaming def _show_summary(self, event: ToolContentEvent, filename: str) -> None: """Show operation summary. Args: event: ToolContentEvent with summary metadata. filename: Basename of file. 
""" operation = event.operation metadata = event.metadata if operation == "write": lines = metadata.get("lines", 0) is_new_file = metadata.get("is_new_file", False) is_binary = metadata.get("is_binary", False) if is_binary: byte_size = metadata.get("bytes", 0) self.console.print(f" {filename} ({byte_size // 1024} KB binary)") elif lines == 0: self.console.print(f" {filename} (0 lines, empty)") elif is_new_file: self.console.print(f" Creating new file {filename} ({lines} lines)") else: self.console.print(f" Overwriting {filename} ({lines} lines)") elif operation in ("insert_before", "insert_after"): line_number = metadata.get("line_number", 0) inserted_lines = metadata.get("inserted_lines", 1) self.console.print(f" Insert at line {line_number} in {filename}: {inserted_lines} line(s)") elif operation == "replace": self.console.print(f" Replace in {filename}") elif operation == "delete": line_number = metadata.get("line_number") if line_number: self.console.print(f" Delete line {line_number} in {filename}") else: self.console.print(f" Delete in {filename}") def _show_full_content(self, event: ToolContentEvent, filename: str) -> None: """Show full content with line numbers. Args: event: ToolContentEvent with content. filename: Basename of file. """ content = event.content metadata = event.metadata if content is None: # Fall back to summary self._show_summary(event, filename) return # Show header _ = event.operation # Available for future use self.console.print(f"\n {filename}") # Show content with line numbers lines = content.splitlines() for i, line in enumerate(lines, start=1): # Escape brackets in user content to prevent Rich markup escaped_line = line.replace("[", "\\[").replace("]", "\\]") self.console.print(f" {i:4d}{escaped_line}") # Show truncation indicator if needed if metadata.get("truncated"): original_lines = metadata.get("original_line_count", 0) remaining = original_lines - len(lines) self.console.print(f" ... 
({remaining} more lines)") def _show_diff_content(self, event: ToolContentEvent, filename: str) -> None: """Show unified diff with colors. Args: event: ToolContentEvent with diff content. filename: Basename of file. """ content = event.content metadata = event.metadata if content is None: # Fall back to summary self._show_summary(event, filename) return # Show header self.console.print(f" {filename}") # Show diff with colors (using ANSI codes) lines = content.splitlines() for line in lines: # Skip file header lines if line.startswith("--- ") or line.startswith("+++ "): continue # Skip diff header if line.startswith("diff --"): continue # Escape brackets in user content escaped_line = line.replace("[", "\\[").replace("]", "\\]") # Color based on prefix (using Rich styles) if line.startswith("@@"): self.console.print(f" [cyan]{escaped_line}[/]") # Cyan elif line.startswith("-"): self.console.print(f" [red]{escaped_line}[/]") # Red elif line.startswith("+"): self.console.print(f" [green]{escaped_line}[/]") # Green else: self.console.print(f" {escaped_line}") # Show truncation indicator if needed if metadata.get("truncated"): original_lines = metadata.get("original_diff_lines", 0) remaining = original_lines - len(lines) self.console.print(f" ... ({remaining} more lines)") def _handle_error(self, event: ErrorEvent) -> None: """Handle error event.""" self.console.print( f"\n[Error] {event.error_type}: {event.message}", style=ERROR_STYLE, ) def _handle_command(self, event: CommandEvent) -> None: """Handle command event.""" if event.result: self.console.print(f"{event.result}\n") def _print_wrapped( self, text: str, style: Style | None = None, end: str = "", ) -> None: """Print text with optional wrapping at wrap_width. Uses word-aware wrapping - breaks at word boundaries when possible, only breaking mid-word when a single word exceeds the wrap width. Args: text: Text to print. style: Optional Rich style. end: String to append at the end (default: ""). 
""" if self.wrap_width is None: # No wrapping, use standard print self.console.print(text, style=style, end=end) return # Track current position in line current_line: list[str] = [] def flush_line() -> None: """Print current line and reset.""" nonlocal current_line if current_line: self.console.print("".join(current_line), style=style, end="") current_line = [] for char in text: if char == "\n": flush_line() self.console.print(style=style) self._column = 0 elif char == "\r": flush_line() self._column = 0 elif char == " ": # Space: check if adding it would exceed width if self._column + 1 > self.wrap_width: # Line break at word boundary flush_line() self.console.print(style=style) self._column = 0 else: current_line.append(char) self._column += 1 else: # Regular character if self._column >= self.wrap_width: # Break before this character if line is full flush_line() self.console.print(style=style) self._column = 0 current_line.append(char) self._column += 1 # Flush remaining content flush_line() if end: self.console.print(end, style=style, end="")