From 0609f1ac5d1c4a24286b8d7fd793106e98b5cb7e Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Thu, 28 May 2026 10:01:32 +0300 Subject: [PATCH] feat(helexa-acp): add tools, session modes, and permission gating Stage 3 introduces five tools (read_file, write_file, edit_file, list_dir, bash) backed by ACP fs/* and terminal/* calls, a ClientOps trait so the runner is mock-testable, two session modes (default + bypassPermissions) with session/set_mode honouring them, and a tool-call loop in the agent that streams the model, dispatches each call, feeds results back into history, and re-enters until the model finishes or MAX_TOOL_ROUNDS is hit. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/helexa-acp/src/agent.rs | 354 ++++++--- crates/helexa-acp/src/main.rs | 2 + crates/helexa-acp/src/session.rs | 16 +- crates/helexa-acp/src/tool_runner.rs | 1049 ++++++++++++++++++++++++++ crates/helexa-acp/src/tools.rs | 179 +++++ 5 files changed, 1514 insertions(+), 86 deletions(-) create mode 100644 crates/helexa-acp/src/tool_runner.rs create mode 100644 crates/helexa-acp/src/tools.rs diff --git a/crates/helexa-acp/src/agent.rs b/crates/helexa-acp/src/agent.rs index c4e7af0..4e7b970 100644 --- a/crates/helexa-acp/src/agent.rs +++ b/crates/helexa-acp/src/agent.rs @@ -1,19 +1,17 @@ -//! ACP agent loop — text-only (Stage 2). +//! ACP agent loop with tools and session modes (Stage 3). //! //! Handlers: //! -//! | ACP method | Behaviour | -//! |-------------------|------------------------------------------------------------| -//! | `initialize` | echo client's protocol version, advertise capabilities | -//! | `session/new` | mint a session id, register state, return it | -//! | `session/prompt` | flatten user blocks → history, stream provider → updates | -//! | `session/cancel` | fire the session's cancellation token | -//! | (anything else) | "not implemented yet" error | +//! | ACP method | Behaviour | +//! |-----------------------|-------------------------------------------------------------| +//! | `initialize` | echo protocol version, advertise capabilities | +//! | `session/new` | mint id, register state, advertise [Default, Bypass] modes | +//! | `session/prompt` | tool-call loop: stream → dispatch tools → re-enter, repeat | +//! | `session/cancel` | fire the session's cancellation token | +//! | `session/set_mode` | mutate the session's mode (gated vs. bypass-permissions) | +//! | (anything else) | "not implemented yet" error | //! -//! Stage 3 adds tool calls; Stage 4 wires `session/set_model`; Stage 5 -//! flips on image content. Stage 2 deliberately answers the model-picker -//! and session-modes fields with `None` so editors render a single model -//! / single mode UI. +//! Stage 4 wires `session/set_model`; Stage 5 flips on image content. use std::path::PathBuf; use std::sync::Arc; @@ -22,18 +20,28 @@ use std::sync::atomic::{AtomicU64, Ordering}; use agent_client_protocol::schema::{ AgentCapabilities, CancelNotification, ContentBlock, InitializeRequest, InitializeResponse, NewSessionRequest, NewSessionResponse, PromptCapabilities, PromptRequest, PromptResponse, - SessionId, SessionNotification, SessionUpdate, StopReason, TextContent, + SessionId, SessionMode, SessionModeId, SessionModeState, SessionNotification, SessionUpdate, + SetSessionModeRequest, SetSessionModeResponse, StopReason, TextContent, }; use agent_client_protocol::{Agent as AgentRole, Client, ConnectionTo, Dispatch, Stdio}; use futures::StreamExt; +use std::collections::BTreeMap; use tokio_util::sync::CancellationToken; use crate::config::{Config, parse_model_selector}; use crate::prompt::build_system_prompt; use crate::provider::{ - CompletionEvent, CompletionRequest, Message, MessageContent, Provider, Role, + CompletionEvent, CompletionRequest, Message, MessageContent, Provider, Role, ToolCall, }; -use crate::session::{self, SessionState, SessionStore}; +use crate::session::{self, MODE_BYPASS, MODE_DEFAULT, SessionState, SessionStore}; +use crate::tool_runner::{AcpClientOps, ToolCallEvent, dispatch_tool_call}; +use crate::tools; + +/// Maximum number of provider→tool→provider round-trips per +/// `session/prompt` request. Bound exists to keep a runaway model +/// from looping forever; the spec maps this to +/// [`StopReason::MaxTurnRequests`]. +const MAX_TOOL_ROUNDS: usize = 25; /// Public entry point. Wraps an `Arc` so handlers can clone /// it cheaply into every closure. @@ -126,6 +134,18 @@ impl Agent { }, agent_client_protocol::on_receive_request!(), ) + .on_receive_request( + { + let inner = inner.clone(); + async move |req: SetSessionModeRequest, responder, _cx| { + match handle_set_session_mode(&inner, req).await { + Ok(()) => responder.respond(SetSessionModeResponse::new()), + Err(e) => responder.respond_with_internal_error(format!("{e:#}")), + } + } + }, + agent_client_protocol::on_receive_request!(), + ) .on_receive_notification( { let inner = inner.clone(); @@ -187,7 +207,48 @@ async fn handle_new_session( cwd = %cwd_display, "session created" ); - Ok(NewSessionResponse::new(session_id)) + Ok(NewSessionResponse::new(session_id).modes(default_mode_state())) +} + +/// The two modes every Stage 3 session advertises. Stage 7 may grow +/// this list (e.g. "plan" for plan-only output, "ask" for read-only), +/// but Default + Bypass cover the two operationally distinct +/// permission policies. +fn default_mode_state() -> SessionModeState { + SessionModeState::new( + SessionModeId::new(MODE_DEFAULT), + vec![ + SessionMode::new(SessionModeId::new(MODE_DEFAULT), "Default") + .description("Prompt for permission before writes or shell commands."), + SessionMode::new(SessionModeId::new(MODE_BYPASS), "Bypass Permissions") + .description("Auto-allow all tool calls. Use with care."), + ], + ) +} + +async fn handle_set_session_mode( + inner: &AgentInner, + req: SetSessionModeRequest, +) -> anyhow::Result<()> { + let Some(state) = session::get(&inner.sessions, &req.session_id).await else { + anyhow::bail!("unknown session id {}", req.session_id.0); + }; + let accepted = req.mode_id.0.as_ref() == MODE_DEFAULT || req.mode_id.0.as_ref() == MODE_BYPASS; + if !accepted { + anyhow::bail!( + "unknown mode '{}' — must be one of: {}, {}", + req.mode_id.0, + MODE_DEFAULT, + MODE_BYPASS + ); + } + state.lock().await.mode_id = req.mode_id.clone(); + tracing::info!( + session_id = %req.session_id.0, + mode = %req.mode_id.0, + "session mode changed" + ); + Ok(()) } async fn handle_cancel(inner: &AgentInner, notif: CancelNotification) { @@ -239,11 +300,11 @@ async fn drive_prompt( return Ok(()); }; - // Snapshot the inputs to the upstream call under the session - // lock, then drop the lock before any `await` that touches the - // network. We *also* install a fresh cancellation token so - // `session/cancel` can fire only this prompt. - let (mut history, model_id, cwd, cancel) = { + // Snapshot the inputs under the session lock, then drop the lock + // before any `await` that touches the network. `mode_id` is + // refreshed between tool rounds (the user can toggle modes + // mid-turn). + let (existing_history, model_id, cwd, cancel, mut mode_id) = { let mut state = session_arc.lock().await; let cancel = CancellationToken::new(); state.cancel = cancel.clone(); @@ -257,6 +318,7 @@ async fn drive_prompt( state.model_id.clone(), state.cwd.clone(), cancel, + state.mode_id.clone(), ) }; @@ -276,98 +338,220 @@ async fn drive_prompt( session_id = %session_id.0, endpoint = %provider.name(), model = %local_model, - history_turns = history.len(), + mode = %mode_id.0, + history_turns = existing_history.len(), "sending prompt upstream" ); - let mut messages = Vec::with_capacity(history.len() + 1); + let ops = AcpClientOps::new(cx.clone()); + + // `messages` is the rolling conversation we send to the provider + // each round. We seed it with the system prompt + the snapshot + // (which includes the new user turn) and grow it with each + // round's assistant turn + tool-result turns. + let mut messages: Vec = Vec::with_capacity(existing_history.len() + 1); messages.push(Message { role: Role::System, content: MessageContent::Text(system_prompt), }); - messages.append(&mut history); + messages.extend(existing_history); - let completion_req = CompletionRequest { - model: local_model, - messages, - tools: vec![], - temperature: None, - top_p: None, - max_tokens: None, - }; + // Whatever new turns this prompt generates beyond the user's + // input — we persist these to session.history at the end so + // future prompts see them. + let mut new_turns: Vec = Vec::new(); - let stream_result = provider.complete(completion_req, cancel.clone()).await; - let mut stream = match stream_result { - Ok(s) => s, - Err(e) => { - let _ = responder - .respond_with_internal_error(format!("{} complete: {e:#}", provider.name())); - return Ok(()); - } - }; - - let mut assistant_text = String::new(); + let tool_specs = tools::all_tools(); let mut stop_reason = StopReason::EndTurn; - while let Some(event) = stream.next().await { - let event = match event { - Ok(e) => e, + for round in 0..MAX_TOOL_ROUNDS { + if cancel.is_cancelled() { + stop_reason = StopReason::Cancelled; + break; + } + + let completion_req = CompletionRequest { + model: local_model.clone(), + messages: messages.clone(), + tools: tool_specs.clone(), + temperature: None, + top_p: None, + max_tokens: None, + }; + + let mut stream = match provider.complete(completion_req, cancel.clone()).await { + Ok(s) => s, Err(e) => { - tracing::warn!(error = %format!("{e:#}"), "stream error; ending turn"); - break; + let _ = responder + .respond_with_internal_error(format!("{} complete: {e:#}", provider.name())); + return Ok(()); } }; - match event { - CompletionEvent::TextDelta(t) => { - assistant_text.push_str(&t); - send_chunk( - &cx, - &session_id, - SessionUpdate::AgentMessageChunk(text_chunk(t)), - ); + + let mut assistant_text = String::new(); + let mut finish_reason: Option = None; + // `BTreeMap` keyed by the provider's tool-call index keeps + // insertion order while allowing arg deltas to mutate any + // bucket — `ToolCallStart` may arrive interleaved with + // `ToolCallArgsDelta` for different indices. + let mut tool_buckets: BTreeMap = BTreeMap::new(); + + while let Some(event) = stream.next().await { + let event = match event { + Ok(e) => e, + Err(e) => { + tracing::warn!(error = %format!("{e:#}"), "stream error; ending round"); + break; + } + }; + match event { + CompletionEvent::TextDelta(t) => { + assistant_text.push_str(&t); + send_chunk( + &cx, + &session_id, + SessionUpdate::AgentMessageChunk(text_chunk(t)), + ); + } + CompletionEvent::ReasoningDelta(t) => { + send_chunk( + &cx, + &session_id, + SessionUpdate::AgentThoughtChunk(text_chunk(t)), + ); + } + CompletionEvent::ToolCallStart { index, id, name } => { + tool_buckets.insert( + index, + ToolCallBucket { + id, + name, + arguments: String::new(), + }, + ); + } + CompletionEvent::ToolCallArgsDelta { index, args_delta } => { + tool_buckets + .entry(index) + .or_default() + .arguments + .push_str(&args_delta); + } + CompletionEvent::Finish { reason } => finish_reason = reason, + CompletionEvent::Usage(_) => {} } - CompletionEvent::ReasoningDelta(t) => { - send_chunk( - &cx, - &session_id, - SessionUpdate::AgentThoughtChunk(text_chunk(t)), - ); + } + + if cancel.is_cancelled() { + stop_reason = StopReason::Cancelled; + // Persist any partial text so the next turn has context. + if !assistant_text.is_empty() { + new_turns.push(Message { + role: Role::Assistant, + content: MessageContent::Text(assistant_text), + }); } - CompletionEvent::Finish { reason } => { - stop_reason = map_finish_reason(reason.as_deref()); + break; + } + + let has_tool_calls = !tool_buckets.is_empty(); + + if !has_tool_calls { + // Terminal turn: just text. Save and finish. + if !assistant_text.is_empty() { + new_turns.push(Message { + role: Role::Assistant, + content: MessageContent::Text(assistant_text), + }); } - // Stage 2 ignores tool calls and usage. Tool calls land in - // Stage 3; usage telemetry isn't in the (non-unstable) - // PromptResponse, so there's nothing to attach it to today. - CompletionEvent::ToolCallStart { .. } - | CompletionEvent::ToolCallArgsDelta { .. } - | CompletionEvent::Usage(_) => {} + stop_reason = map_finish_reason(finish_reason.as_deref()); + break; + } + + // Assistant turn carrying the tool calls. + let calls: Vec = tool_buckets + .values() + .map(|b| ToolCall { + id: b.id.clone(), + name: b.name.clone(), + arguments: b.arguments.clone(), + }) + .collect(); + let assistant_turn = Message { + role: Role::Assistant, + content: MessageContent::ToolCalls { + text: (!assistant_text.is_empty()).then_some(assistant_text), + calls, + }, + }; + new_turns.push(assistant_turn.clone()); + messages.push(assistant_turn); + + // Refresh the mode in case the user toggled it during the + // streaming above (cheap — one mutex acquisition). + mode_id = session_arc.lock().await.mode_id.clone(); + + // Dispatch every tool call sequentially. Parallelism is + // tempting but would require Zed to handle interleaved + // permission prompts; serial is friendlier. + for bucket in tool_buckets.into_values() { + if cancel.is_cancelled() { + stop_reason = StopReason::Cancelled; + break; + } + let event = ToolCallEvent { + id: bucket.id, + name: bucket.name, + arguments: bucket.arguments, + }; + let result = + dispatch_tool_call(&ops, &session_id, &mode_id, &cwd, event, &cancel).await; + let result_turn = Message { + role: Role::Tool, + content: MessageContent::ToolResult { + tool_call_id: result.tool_call_id, + content: result.content, + }, + }; + new_turns.push(result_turn.clone()); + messages.push(result_turn); + } + + if cancel.is_cancelled() { + stop_reason = StopReason::Cancelled; + break; + } + + if round + 1 == MAX_TOOL_ROUNDS { + tracing::warn!( + session_id = %session_id.0, + rounds = MAX_TOOL_ROUNDS, + "hit MAX_TOOL_ROUNDS, returning MaxTurnRequests" + ); + stop_reason = StopReason::MaxTurnRequests; } } - // If cancellation fired, override whatever finish reason we got - // (or didn't get). Per spec: a `session/cancel` MUST result in - // `StopReason::Cancelled`, regardless of partial output. - if cancel.is_cancelled() { - stop_reason = StopReason::Cancelled; - } - - // Re-acquire the lock just long enough to persist the assistant - // turn (even partial output, so future turns have the context). { let mut state = session_arc.lock().await; - if !assistant_text.is_empty() { - state.history.push(Message { - role: Role::Assistant, - content: MessageContent::Text(assistant_text), - }); - } + state.history.extend(new_turns); } let _ = responder.respond(PromptResponse::new(stop_reason)); Ok(()) } +/// Accumulator for one streamed tool call: the OpenAI wire format +/// sends `id` + `name` once (in the first chunk for that index) and +/// then argument bytes piecemeal. We gather them all before +/// dispatching. +#[derive(Debug, Default)] +struct ToolCallBucket { + id: String, + name: String, + arguments: String, +} + fn send_chunk(cx: &ConnectionTo, session_id: &SessionId, update: SessionUpdate) { let notif = SessionNotification::new(session_id.clone(), update); if let Err(e) = cx.send_notification(notif) { diff --git a/crates/helexa-acp/src/main.rs b/crates/helexa-acp/src/main.rs index abe40ed..9ca9523 100644 --- a/crates/helexa-acp/src/main.rs +++ b/crates/helexa-acp/src/main.rs @@ -20,6 +20,8 @@ mod config; mod prompt; mod provider; mod session; +mod tool_runner; +mod tools; use agent::Agent; use config::{Config, EndpointConfig, WireApi}; diff --git a/crates/helexa-acp/src/session.rs b/crates/helexa-acp/src/session.rs index 18c5954..d13de34 100644 --- a/crates/helexa-acp/src/session.rs +++ b/crates/helexa-acp/src/session.rs @@ -18,12 +18,20 @@ use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; -use agent_client_protocol::schema::SessionId; +use agent_client_protocol::schema::{SessionId, SessionModeId}; use tokio::sync::{Mutex, RwLock}; use tokio_util::sync::CancellationToken; use crate::provider::Message; +/// Mode id advertised as the gated default. Writes / bash prompt for +/// permission via `session/request_permission`. +pub const MODE_DEFAULT: &str = "default"; + +/// Mode id advertised as "auto-allow everything". Matches the +/// favorite name (`bypassPermissions`) Zed clients tend to reference. +pub const MODE_BYPASS: &str = "bypassPermissions"; + /// State carried for a single ACP session. /// /// Mutated under `Mutex`; never share a clone across @@ -50,6 +58,11 @@ pub struct SessionState { /// token is "spent" — firing it does nothing — which is fine, /// `session/cancel` is a no-op when there's nothing to cancel. pub cancel: CancellationToken, + /// Permission gating mode. Stage 3 advertises two ids in + /// `NewSessionResponse.modes`: [`MODE_DEFAULT`] (writes / bash + /// prompt the user) and [`MODE_BYPASS`] (auto-allow). Mutated by + /// `session/set_mode`. + pub mode_id: SessionModeId, } impl SessionState { @@ -59,6 +72,7 @@ impl SessionState { cwd, model_id, cancel: CancellationToken::new(), + mode_id: SessionModeId::new(MODE_DEFAULT), } } } diff --git a/crates/helexa-acp/src/tool_runner.rs b/crates/helexa-acp/src/tool_runner.rs new file mode 100644 index 0000000..2441061 --- /dev/null +++ b/crates/helexa-acp/src/tool_runner.rs @@ -0,0 +1,1049 @@ +//! Execute the LLM's tool calls against the editor client. +//! +//! Each tool call goes through `dispatch_tool_call`, which: +//! +//! 1. Emits `SessionUpdate::ToolCall { status: Pending }` so the +//! editor can show "agent is about to do X". +//! 2. If the tool is gated (write / edit / bash) and the session mode +//! is the default, asks the user via `session/request_permission`. +//! Bypass mode skips this step. +//! 3. Executes the tool by calling the appropriate ACP client method +//! (`fs/read_text_file`, `fs/write_text_file`, `terminal/create` + +//! `terminal/wait_for_exit` + `terminal/output` + `terminal/release`) +//! or, for `list_dir`, a local `std::fs` call. +//! 4. Emits `SessionUpdate::ToolCallUpdate { status: Completed | Failed }` +//! with the result content (text, diff, or error). +//! 5. Returns a [`ToolResult`] string that the agent loop folds back +//! into the model's conversation history. +//! +//! Client-side ACP calls are abstracted behind the [`ClientOps`] +//! trait. Production wires it to a real `ConnectionTo`; tests +//! pass in a recording fake. + +use std::path::{Path, PathBuf}; + +use agent_client_protocol::schema::{ + ContentBlock, CreateTerminalRequest, Diff, KillTerminalRequest, PermissionOption, + PermissionOptionId, PermissionOptionKind, ReadTextFileRequest, ReleaseTerminalRequest, + RequestPermissionOutcome, RequestPermissionRequest, SessionId, SessionModeId, + SessionNotification, SessionUpdate, TerminalExitStatus, TerminalId, TerminalOutputRequest, + TerminalOutputResponse, TextContent, ToolCall, ToolCallContent, ToolCallId, ToolCallStatus, + ToolCallUpdate, ToolCallUpdateFields, ToolKind, WaitForTerminalExitRequest, + WriteTextFileRequest, +}; +use agent_client_protocol::{Client, ConnectionTo, util::internal_error}; +use async_trait::async_trait; +use serde::Deserialize; +use serde_json::json; +use tokio_util::sync::CancellationToken; + +use crate::session::{MODE_BYPASS, MODE_DEFAULT}; +use crate::tools::{BASH, EDIT_FILE, LIST_DIR, READ_FILE, WRITE_FILE}; + +/// Accumulated state of a single tool call streamed from the +/// provider. The agent loop gathers `ToolCallStart` + N +/// `ToolCallArgsDelta` events into one of these before dispatch. +#[derive(Debug, Clone)] +pub struct ToolCallEvent { + /// Provider-assigned id (e.g. OpenAI's `call_…`). Used as the + /// `tool_call_id` in both the assistant turn and the tool-result + /// turn we feed back to the model. + pub id: String, + pub name: String, + /// Concatenated JSON argument bytes. Parsed lazily by the runner. + pub arguments: String, +} + +/// What the runner sends back to the agent loop after a dispatch. +#[derive(Debug, Clone)] +pub struct ToolResult { + /// Echoes [`ToolCallEvent::id`] so the agent can build the + /// `MessageContent::ToolResult { tool_call_id, … }` history entry. + pub tool_call_id: String, + /// Human/agent-readable result text. Always non-empty: errors are + /// stringified so the model can react to them. + pub content: String, + /// True for failures (so the agent can decide whether to stop on + /// repeated tool errors — currently unused but worth surfacing). + #[allow(dead_code)] + pub is_error: bool, +} + +/// Client-side ACP RPCs the runner needs. Real wiring lives in +/// [`AcpClientOps`]; tests use a recording fake. +#[async_trait] +pub trait ClientOps: Send + Sync { + async fn read_text_file( + &self, + session: &SessionId, + path: PathBuf, + line: Option, + limit: Option, + ) -> anyhow::Result; + + async fn write_text_file( + &self, + session: &SessionId, + path: PathBuf, + content: String, + ) -> anyhow::Result<()>; + + async fn request_permission( + &self, + session: &SessionId, + tool_call: ToolCallUpdate, + options: Vec, + ) -> anyhow::Result; + + async fn create_terminal( + &self, + session: &SessionId, + command: String, + args: Vec, + cwd: Option, + ) -> anyhow::Result; + + async fn wait_for_terminal_exit( + &self, + session: &SessionId, + terminal: &TerminalId, + ) -> anyhow::Result; + + async fn terminal_output( + &self, + session: &SessionId, + terminal: &TerminalId, + ) -> anyhow::Result; + + async fn kill_terminal(&self, session: &SessionId, terminal: &TerminalId) + -> anyhow::Result<()>; + + async fn release_terminal( + &self, + session: &SessionId, + terminal: &TerminalId, + ) -> anyhow::Result<()>; + + /// Fire-and-forget. Failures are logged inside the impl, not + /// propagated — losing a `session/update` is non-fatal. + fn send_session_update(&self, session: &SessionId, update: SessionUpdate); +} + +/// Production wrapper around a live ACP connection. +pub struct AcpClientOps { + cx: ConnectionTo, +} + +impl AcpClientOps { + pub fn new(cx: ConnectionTo) -> Self { + Self { cx } + } +} + +#[async_trait] +impl ClientOps for AcpClientOps { + async fn read_text_file( + &self, + session: &SessionId, + path: PathBuf, + line: Option, + limit: Option, + ) -> anyhow::Result { + let mut req = ReadTextFileRequest::new(session.clone(), path); + req = req.line(line).limit(limit); + let resp = self + .cx + .send_request(req) + .block_task() + .await + .map_err(|e| anyhow::anyhow!("fs/read_text_file: {e}"))?; + Ok(resp.content) + } + + async fn write_text_file( + &self, + session: &SessionId, + path: PathBuf, + content: String, + ) -> anyhow::Result<()> { + let req = WriteTextFileRequest::new(session.clone(), path, content); + self.cx + .send_request(req) + .block_task() + .await + .map_err(|e| anyhow::anyhow!("fs/write_text_file: {e}"))?; + Ok(()) + } + + async fn request_permission( + &self, + session: &SessionId, + tool_call: ToolCallUpdate, + options: Vec, + ) -> anyhow::Result { + let req = RequestPermissionRequest::new(session.clone(), tool_call, options); + let resp = self + .cx + .send_request(req) + .block_task() + .await + .map_err(|e| anyhow::anyhow!("session/request_permission: {e}"))?; + Ok(resp.outcome) + } + + async fn create_terminal( + &self, + session: &SessionId, + command: String, + args: Vec, + cwd: Option, + ) -> anyhow::Result { + let mut req = CreateTerminalRequest::new(session.clone(), command).args(args); + req = req.cwd(cwd); + let resp = self + .cx + .send_request(req) + .block_task() + .await + .map_err(|e| anyhow::anyhow!("terminal/create: {e}"))?; + Ok(resp.terminal_id) + } + + async fn wait_for_terminal_exit( + &self, + session: &SessionId, + terminal: &TerminalId, + ) -> anyhow::Result { + let req = WaitForTerminalExitRequest::new(session.clone(), terminal.clone()); + let resp = self + .cx + .send_request(req) + .block_task() + .await + .map_err(|e| anyhow::anyhow!("terminal/wait_for_exit: {e}"))?; + Ok(resp.exit_status) + } + + async fn terminal_output( + &self, + session: &SessionId, + terminal: &TerminalId, + ) -> anyhow::Result { + let req = TerminalOutputRequest::new(session.clone(), terminal.clone()); + let resp = self + .cx + .send_request(req) + .block_task() + .await + .map_err(|e| anyhow::anyhow!("terminal/output: {e}"))?; + Ok(resp) + } + + async fn kill_terminal( + &self, + session: &SessionId, + terminal: &TerminalId, + ) -> anyhow::Result<()> { + let req = KillTerminalRequest::new(session.clone(), terminal.clone()); + self.cx + .send_request(req) + .block_task() + .await + .map_err(|e| anyhow::anyhow!("terminal/kill: {e}"))?; + Ok(()) + } + + async fn release_terminal( + &self, + session: &SessionId, + terminal: &TerminalId, + ) -> anyhow::Result<()> { + let req = ReleaseTerminalRequest::new(session.clone(), terminal.clone()); + self.cx + .send_request(req) + .block_task() + .await + .map_err(|e| anyhow::anyhow!("terminal/release: {e}"))?; + Ok(()) + } + + fn send_session_update(&self, session: &SessionId, update: SessionUpdate) { + let notif = SessionNotification::new(session.clone(), update); + if let Err(e) = self.cx.send_notification(notif) { + tracing::warn!(error = %internal_error(format!("{e}")), "session/update notification dropped"); + } + } +} + +// ── Tool argument shapes ───────────────────────────────────────────── + +#[derive(Debug, Deserialize)] +struct ReadFileArgs { + path: PathBuf, + #[serde(default)] + line: Option, + #[serde(default)] + limit: Option, +} + +#[derive(Debug, Deserialize)] +struct WriteFileArgs { + path: PathBuf, + content: String, +} + +#[derive(Debug, Deserialize)] +struct EditFileArgs { + path: PathBuf, + old_text: String, + new_text: String, +} + +#[derive(Debug, Deserialize)] +struct ListDirArgs { + path: PathBuf, +} + +#[derive(Debug, Deserialize)] +struct BashArgs { + command: String, + #[serde(default)] + cwd: Option, +} + +// ── Dispatch ───────────────────────────────────────────────────────── + +/// Tools whose default-mode behaviour is to ask the user first. +pub fn is_gated(tool_name: &str) -> bool { + matches!(tool_name, WRITE_FILE | EDIT_FILE | BASH) +} + +/// Map a tool name to the [`ToolKind`] icon hint Zed uses. +fn tool_kind(name: &str) -> ToolKind { + match name { + READ_FILE | LIST_DIR => ToolKind::Read, + WRITE_FILE | EDIT_FILE => ToolKind::Edit, + BASH => ToolKind::Execute, + _ => ToolKind::Other, + } +} + +/// Human-readable one-line title shown next to the tool-call card. +fn tool_title(name: &str, args_value: &serde_json::Value) -> String { + fn path(args: &serde_json::Value) -> &str { + args.get("path").and_then(|v| v.as_str()).unwrap_or("?") + } + match name { + READ_FILE => format!("Read {}", path(args_value)), + WRITE_FILE => format!("Write {}", path(args_value)), + EDIT_FILE => format!("Edit {}", path(args_value)), + LIST_DIR => format!("List {}", path(args_value)), + BASH => { + let cmd = args_value + .get("command") + .and_then(|v| v.as_str()) + .unwrap_or("?"); + let snippet = if cmd.len() > 60 { + format!("{}…", &cmd[..60]) + } else { + cmd.to_string() + }; + format!("Run: {snippet}") + } + other => format!("Tool: {other}"), + } +} + +/// Run a single tool call. Always returns a [`ToolResult`] — failures +/// are reported as `is_error = true` strings, not Err. +pub async fn dispatch_tool_call( + ops: &dyn ClientOps, + session_id: &SessionId, + mode: &SessionModeId, + session_cwd: &Path, + call: ToolCallEvent, + cancel: &CancellationToken, +) -> ToolResult { + let tool_call_id = ToolCallId::new(call.id.clone()); + + // Parse args once, up front. If the model produced invalid JSON + // we surface that to it so it can retry rather than to the user. + let args_value: serde_json::Value = match serde_json::from_str(&call.arguments) { + Ok(v) => v, + Err(e) => { + let msg = format!("tool '{}' had invalid JSON arguments: {e}", call.name); + let init = ToolCall::new(tool_call_id.clone(), tool_title(&call.name, &json!({}))) + .kind(tool_kind(&call.name)) + .status(ToolCallStatus::Failed) + .content(vec![ToolCallContent::Content( + agent_client_protocol::schema::Content::new(ContentBlock::Text( + TextContent::new(msg.clone()), + )), + )]) + .raw_input(serde_json::Value::String(call.arguments.clone())); + ops.send_session_update(session_id, SessionUpdate::ToolCall(init)); + return ToolResult { + tool_call_id: call.id, + content: msg, + is_error: true, + }; + } + }; + + let title = tool_title(&call.name, &args_value); + let kind = tool_kind(&call.name); + let initial = ToolCall::new(tool_call_id.clone(), title) + .kind(kind) + .status(ToolCallStatus::Pending) + .raw_input(args_value.clone()); + ops.send_session_update(session_id, SessionUpdate::ToolCall(initial)); + + if cancel.is_cancelled() { + return finish_failed( + ops, + session_id, + &tool_call_id, + &call.id, + "cancelled before tool ran", + ); + } + + // ── Permission gate ────────────────────────────────────────────── + if is_gated(&call.name) && mode.0.as_ref() != MODE_BYPASS { + // Default mode (or any non-bypass id): always ask. The user's + // "Allow" decision is per-call here; we don't carry over an + // "Allow always" across calls — that's a Stage 7 polish item + // (persisted permission grants). + let _ = mode.0.as_ref() == MODE_DEFAULT; // explicit acknowledgement that's our intent + let options = vec![ + PermissionOption::new( + PermissionOptionId::new("allow_once"), + "Allow", + PermissionOptionKind::AllowOnce, + ), + PermissionOption::new( + PermissionOptionId::new("reject_once"), + "Reject", + PermissionOptionKind::RejectOnce, + ), + ]; + let permission_call = + ToolCallUpdate::new(tool_call_id.clone(), ToolCallUpdateFields::new()); + match ops + .request_permission(session_id, permission_call, options) + .await + { + Ok(RequestPermissionOutcome::Selected(sel)) + if sel.option_id.0.as_ref().starts_with("allow") => {} + Ok(RequestPermissionOutcome::Selected(_)) => { + return finish_failed( + ops, + session_id, + &tool_call_id, + &call.id, + "user rejected the action", + ); + } + Ok(RequestPermissionOutcome::Cancelled) => { + return finish_failed( + ops, + session_id, + &tool_call_id, + &call.id, + "permission request cancelled", + ); + } + // `RequestPermissionOutcome` is `#[non_exhaustive]`. If a + // future protocol version adds a new variant, treat it + // conservatively as "did not explicitly allow" rather + // than letting the call through. + Ok(_) => { + return finish_failed( + ops, + session_id, + &tool_call_id, + &call.id, + "unknown permission outcome", + ); + } + Err(e) => { + return finish_failed( + ops, + session_id, + &tool_call_id, + &call.id, + &format!("permission request failed: {e:#}"), + ); + } + } + } + + // ── In-progress update ─────────────────────────────────────────── + ops.send_session_update( + session_id, + SessionUpdate::ToolCallUpdate(ToolCallUpdate::new( + tool_call_id.clone(), + ToolCallUpdateFields::new().status(ToolCallStatus::InProgress), + )), + ); + + // ── Execute ────────────────────────────────────────────────────── + let outcome: Result<(String, Vec), String> = match call.name.as_str() { + READ_FILE => exec_read_file(ops, session_id, &args_value).await, + WRITE_FILE => exec_write_file(ops, session_id, &args_value).await, + EDIT_FILE => exec_edit_file(ops, session_id, &args_value).await, + LIST_DIR => exec_list_dir(&args_value), + BASH => exec_bash(ops, session_id, session_cwd, &args_value, cancel).await, + other => Err(format!("unknown tool '{other}'")), + }; + + match outcome { + Ok((result_text, content)) => { + ops.send_session_update( + session_id, + SessionUpdate::ToolCallUpdate(ToolCallUpdate::new( + tool_call_id.clone(), + ToolCallUpdateFields::new() + .status(ToolCallStatus::Completed) + .content(content), + )), + ); + ToolResult { + tool_call_id: call.id, + content: result_text, + is_error: false, + } + } + Err(msg) => finish_failed(ops, session_id, &tool_call_id, &call.id, &msg), + } +} + +fn finish_failed( + ops: &dyn ClientOps, + session_id: &SessionId, + tool_call_id: &ToolCallId, + raw_id: &str, + message: &str, +) -> ToolResult { + ops.send_session_update( + session_id, + SessionUpdate::ToolCallUpdate(ToolCallUpdate::new( + tool_call_id.clone(), + ToolCallUpdateFields::new() + .status(ToolCallStatus::Failed) + .content(vec![ToolCallContent::Content( + agent_client_protocol::schema::Content::new(ContentBlock::Text( + TextContent::new(message.to_string()), + )), + )]), + )), + ); + ToolResult { + tool_call_id: raw_id.to_string(), + content: format!("ERROR: {message}"), + is_error: true, + } +} + +// ── Per-tool executors ────────────────────────────────────────────── + +async fn exec_read_file( + ops: &dyn ClientOps, + session_id: &SessionId, + args_value: &serde_json::Value, +) -> Result<(String, Vec), String> { + let args: ReadFileArgs = + serde_json::from_value(args_value.clone()).map_err(|e| format!("read_file: {e}"))?; + let content = ops + .read_text_file(session_id, args.path, args.line, args.limit) + .await + .map_err(|e| format!("read_file: {e:#}"))?; + let blocks = vec![ToolCallContent::Content( + agent_client_protocol::schema::Content::new(ContentBlock::Text(TextContent::new( + content.clone(), + ))), + )]; + Ok((content, blocks)) +} + +async fn exec_write_file( + ops: &dyn ClientOps, + session_id: &SessionId, + args_value: &serde_json::Value, +) -> Result<(String, Vec), String> { + let args: WriteFileArgs = + serde_json::from_value(args_value.clone()).map_err(|e| format!("write_file: {e}"))?; + // Best-effort read of the existing file so Zed can render a diff. + // Failure here just means we render the write as an additive diff + // — not a fatal error, the actual write below still runs. + let old_text = ops + .read_text_file(session_id, args.path.clone(), None, None) + .await + .ok(); + ops.write_text_file(session_id, args.path.clone(), args.content.clone()) + .await + .map_err(|e| format!("write_file: {e:#}"))?; + let mut diff = Diff::new(args.path.clone(), args.content.clone()); + if let Some(old) = old_text { + diff = diff.old_text(old); + } + let summary = format!( + "wrote {} ({} bytes)", + args.path.display(), + args.content.len() + ); + Ok((summary, vec![ToolCallContent::Diff(diff)])) +} + +async fn exec_edit_file( + ops: &dyn ClientOps, + session_id: &SessionId, + args_value: &serde_json::Value, +) -> Result<(String, Vec), String> { + let args: EditFileArgs = + serde_json::from_value(args_value.clone()).map_err(|e| format!("edit_file: {e}"))?; + let original = ops + .read_text_file(session_id, args.path.clone(), None, None) + .await + .map_err(|e| format!("edit_file: read {}: {e:#}", args.path.display()))?; + let occurrences = original.matches(args.old_text.as_str()).count(); + if occurrences == 0 { + return Err(format!( + "edit_file: old_text not found in {}", + args.path.display() + )); + } + if occurrences > 1 { + return Err(format!( + "edit_file: old_text appears {occurrences} times in {} — make it unique", + args.path.display() + )); + } + let new_content = original.replacen(args.old_text.as_str(), args.new_text.as_str(), 1); + ops.write_text_file(session_id, args.path.clone(), new_content.clone()) + .await + .map_err(|e| format!("edit_file: write {}: {e:#}", args.path.display()))?; + let diff = Diff::new(args.path.clone(), new_content.clone()).old_text(original); + let summary = format!( + "edited {} ({} bytes)", + args.path.display(), + new_content.len() + ); + Ok((summary, vec![ToolCallContent::Diff(diff)])) +} + +fn exec_list_dir(args_value: &serde_json::Value) -> Result<(String, Vec), String> { + let args: ListDirArgs = + serde_json::from_value(args_value.clone()).map_err(|e| format!("list_dir: {e}"))?; + let entries = std::fs::read_dir(&args.path) + .map_err(|e| format!("list_dir: read {}: {e}", args.path.display()))?; + let mut lines: Vec = Vec::new(); + for entry in entries.flatten() { + let name = entry.file_name().to_string_lossy().into_owned(); + let kind = match entry.file_type() { + Ok(t) if t.is_dir() => 'd', + Ok(t) if t.is_symlink() => 'l', + Ok(_) => 'f', + Err(_) => '?', + }; + lines.push(format!("{kind} {name}")); + } + lines.sort(); + let body = lines.join("\n"); + let blocks = vec![ToolCallContent::Content( + agent_client_protocol::schema::Content::new(ContentBlock::Text(TextContent::new( + body.clone(), + ))), + )]; + Ok((body, blocks)) +} + +async fn exec_bash( + ops: &dyn ClientOps, + session_id: &SessionId, + session_cwd: &Path, + args_value: &serde_json::Value, + cancel: &CancellationToken, +) -> Result<(String, Vec), String> { + let args: BashArgs = + serde_json::from_value(args_value.clone()).map_err(|e| format!("bash: {e}"))?; + let cwd = args.cwd.unwrap_or_else(|| session_cwd.to_path_buf()); + + let terminal = ops + .create_terminal( + session_id, + "sh".to_string(), + vec!["-c".to_string(), args.command.clone()], + Some(cwd), + ) + .await + .map_err(|e| format!("bash: terminal/create: {e:#}"))?; + + // Wait for completion. If cancelled, ask the client to kill the + // process. We still try to release the terminal afterwards. + let exit = tokio::select! { + biased; + _ = cancel.cancelled() => { + let _ = ops.kill_terminal(session_id, &terminal).await; + let _ = ops.release_terminal(session_id, &terminal).await; + return Err("bash: cancelled".to_string()); + } + res = ops.wait_for_terminal_exit(session_id, &terminal) => { + res.map_err(|e| format!("bash: terminal/wait_for_exit: {e:#}"))? + } + }; + + let output_resp = ops + .terminal_output(session_id, &terminal) + .await + .map_err(|e| format!("bash: terminal/output: {e:#}"))?; + let _ = ops.release_terminal(session_id, &terminal).await; + + let summary = render_bash_result(&exit, &output_resp); + let blocks = vec![ToolCallContent::Content( + agent_client_protocol::schema::Content::new(ContentBlock::Text(TextContent::new( + summary.clone(), + ))), + )]; + Ok((summary, blocks)) +} + +fn render_bash_result(exit: &TerminalExitStatus, output: &TerminalOutputResponse) -> String { + let mut out = String::new(); + match (exit.exit_code, exit.signal.as_deref()) { + (Some(0), _) => out.push_str("exit 0\n"), + (Some(code), _) => out.push_str(&format!("exit {code}\n")), + (None, Some(sig)) => out.push_str(&format!("terminated by signal {sig}\n")), + (None, None) => out.push_str("exit ?\n"), + } + if output.truncated { + out.push_str("(output truncated)\n"); + } + out.push_str(&output.output); + out +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Mutex; + + /// Recording fake. Captures every outbound op so tests can + /// assert on what the runner did. + #[derive(Default)] + struct FakeClient { + events: Mutex>, + /// Canned response for read_text_file. + read_responses: Mutex>>, + /// Canned response for request_permission. + permission: Mutex>, + } + + // Fields are read only through the `{:?}` formatter in + // `events()`; clippy's dead-code pass doesn't notice that, so + // we suppress the warning at the enum level. The payloads stay + // typed (vs. `String`-everything) so a test failure surfaces + // useful detail in the Debug output. + #[allow(dead_code)] + #[derive(Debug)] + enum FakeEvent { + Read(PathBuf), + Write(PathBuf, String), + RequestPermission, + CreateTerminal(String, Vec), + WaitForExit, + TerminalOutput, + KillTerminal, + ReleaseTerminal, + Update(String), + } + + impl FakeClient { + fn set_read(&self, path: PathBuf, body: anyhow::Result) { + self.read_responses.lock().unwrap().insert(path, body); + } + fn set_permission(&self, outcome: RequestPermissionOutcome) { + *self.permission.lock().unwrap() = Some(outcome); + } + fn events(&self) -> Vec { + self.events + .lock() + .unwrap() + .iter() + .map(|e| format!("{e:?}")) + .collect() + } + } + + #[async_trait] + impl ClientOps for FakeClient { + async fn read_text_file( + &self, + _session: &SessionId, + path: PathBuf, + _line: Option, + _limit: Option, + ) -> anyhow::Result { + self.events + .lock() + .unwrap() + .push(FakeEvent::Read(path.clone())); + self.read_responses + .lock() + .unwrap() + .remove(&path) + .unwrap_or_else(|| Err(anyhow::anyhow!("no canned read for {}", path.display()))) + } + async fn write_text_file( + &self, + _session: &SessionId, + path: PathBuf, + content: String, + ) -> anyhow::Result<()> { + self.events + .lock() + .unwrap() + .push(FakeEvent::Write(path, content)); + Ok(()) + } + async fn request_permission( + &self, + _session: &SessionId, + _tc: ToolCallUpdate, + _options: Vec, + ) -> anyhow::Result { + self.events + .lock() + .unwrap() + .push(FakeEvent::RequestPermission); + self.permission + .lock() + .unwrap() + .clone() + .ok_or_else(|| anyhow::anyhow!("no canned permission outcome")) + } + async fn create_terminal( + &self, + _session: &SessionId, + command: String, + args: Vec, + _cwd: Option, + ) -> anyhow::Result { + self.events + .lock() + .unwrap() + .push(FakeEvent::CreateTerminal(command, args)); + Ok(TerminalId::new("t1")) + } + async fn wait_for_terminal_exit( + &self, + _session: &SessionId, + _terminal: &TerminalId, + ) -> anyhow::Result { + self.events.lock().unwrap().push(FakeEvent::WaitForExit); + Ok(TerminalExitStatus::new().exit_code(0u32)) + } + async fn terminal_output( + &self, + _session: &SessionId, + _terminal: &TerminalId, + ) -> anyhow::Result { + self.events.lock().unwrap().push(FakeEvent::TerminalOutput); + Ok(TerminalOutputResponse::new("ok\n", false)) + } + async fn kill_terminal( + &self, + _session: &SessionId, + _terminal: &TerminalId, + ) -> anyhow::Result<()> { + self.events.lock().unwrap().push(FakeEvent::KillTerminal); + Ok(()) + } + async fn release_terminal( + &self, + _session: &SessionId, + _terminal: &TerminalId, + ) -> anyhow::Result<()> { + self.events.lock().unwrap().push(FakeEvent::ReleaseTerminal); + Ok(()) + } + fn send_session_update(&self, _session: &SessionId, update: SessionUpdate) { + let tag = match update { + SessionUpdate::ToolCall(_) => "tool_call".to_string(), + SessionUpdate::ToolCallUpdate(u) => format!( + "tool_call_update:{:?}", + u.fields.status.unwrap_or(ToolCallStatus::Pending) + ), + _ => "other".to_string(), + }; + self.events.lock().unwrap().push(FakeEvent::Update(tag)); + } + } + + fn sid() -> SessionId { + SessionId::new("s1") + } + fn mode_default() -> SessionModeId { + SessionModeId::new(MODE_DEFAULT) + } + fn mode_bypass() -> SessionModeId { + SessionModeId::new(MODE_BYPASS) + } + + fn make_call(name: &str, args: serde_json::Value) -> ToolCallEvent { + ToolCallEvent { + id: "call_1".to_string(), + name: name.to_string(), + arguments: args.to_string(), + } + } + + #[tokio::test] + async fn read_file_is_not_gated_in_default_mode() { + let fake = FakeClient::default(); + fake.set_read(PathBuf::from("/tmp/x"), Ok("hello".to_string())); + let res = dispatch_tool_call( + &fake, + &sid(), + &mode_default(), + Path::new("/tmp"), + make_call(READ_FILE, json!({"path": "/tmp/x"})), + &CancellationToken::new(), + ) + .await; + assert!(!res.is_error, "result: {}", res.content); + assert_eq!(res.content, "hello"); + let events = fake.events(); + // Pending ToolCall → Read → InProgress update → Completed update + assert!(!events.iter().any(|e| e == "RequestPermission")); + assert!(events.iter().any(|e| e.starts_with("Read"))); + } + + #[tokio::test] + async fn write_file_gated_in_default_mode_and_asks_permission() { + let fake = FakeClient::default(); + fake.set_permission(RequestPermissionOutcome::Selected( + agent_client_protocol::schema::SelectedPermissionOutcome::new("allow_once"), + )); + // The pre-write read fails; we tolerate that. + let res = dispatch_tool_call( + &fake, + &sid(), + &mode_default(), + Path::new("/tmp"), + make_call(WRITE_FILE, json!({"path": "/tmp/y", "content": "hi"})), + &CancellationToken::new(), + ) + .await; + assert!(!res.is_error, "result: {}", res.content); + let events = fake.events(); + assert!(events.iter().any(|e| e == "RequestPermission")); + assert!(events.iter().any(|e| e.starts_with("Write"))); + } + + #[tokio::test] + async fn bypass_mode_skips_permission_prompt() { + let fake = FakeClient::default(); + let res = dispatch_tool_call( + &fake, + &sid(), + &mode_bypass(), + Path::new("/tmp"), + make_call(WRITE_FILE, json!({"path": "/tmp/y", "content": "hi"})), + &CancellationToken::new(), + ) + .await; + assert!(!res.is_error, "result: {}", res.content); + let events = fake.events(); + assert!( + !events.iter().any(|e| e == "RequestPermission"), + "bypass mode must not prompt: {events:?}" + ); + assert!(events.iter().any(|e| e.starts_with("Write"))); + } + + #[tokio::test] + async fn rejected_permission_returns_error() { + let fake = FakeClient::default(); + fake.set_permission(RequestPermissionOutcome::Selected( + agent_client_protocol::schema::SelectedPermissionOutcome::new("reject_once"), + )); + let res = dispatch_tool_call( + &fake, + &sid(), + &mode_default(), + Path::new("/tmp"), + make_call(BASH, json!({"command": "rm -rf /"})), + &CancellationToken::new(), + ) + .await; + assert!(res.is_error, "expected error: {}", res.content); + assert!(res.content.contains("reject")); + } + + #[tokio::test] + async fn bash_runs_through_terminal_lifecycle() { + let fake = FakeClient::default(); + let res = dispatch_tool_call( + &fake, + &sid(), + &mode_bypass(), + Path::new("/tmp"), + make_call(BASH, json!({"command": "echo ok"})), + &CancellationToken::new(), + ) + .await; + assert!(!res.is_error, "result: {}", res.content); + assert!(res.content.contains("exit 0")); + assert!(res.content.contains("ok")); + let events = fake.events(); + let sequence: Vec<&str> = events.iter().map(|s| s.as_str()).collect(); + // create → wait_for_exit → output → release + let create = sequence + .iter() + .position(|e| e.starts_with("CreateTerminal")) + .expect("CreateTerminal event"); + let wait = sequence + .iter() + .position(|e| e == &"WaitForExit") + .expect("WaitForExit event"); + let out = sequence + .iter() + .position(|e| e == &"TerminalOutput") + .expect("TerminalOutput event"); + let release = sequence + .iter() + .position(|e| e == &"ReleaseTerminal") + .expect("ReleaseTerminal event"); + assert!(create < wait && wait < out && out < release); + } + + #[tokio::test] + async fn edit_file_rejects_ambiguous_match() { + let fake = FakeClient::default(); + fake.set_read(PathBuf::from("/tmp/dup"), Ok("foo bar foo".to_string())); + let res = dispatch_tool_call( + &fake, + &sid(), + &mode_bypass(), + Path::new("/tmp"), + make_call( + EDIT_FILE, + json!({"path": "/tmp/dup", "old_text": "foo", "new_text": "baz"}), + ), + &CancellationToken::new(), + ) + .await; + assert!(res.is_error, "expected error, got {}", res.content); + assert!(res.content.contains("2 times") || res.content.contains("appears")); + } + + #[test] + fn gated_set_matches_spec() { + assert!(!is_gated(READ_FILE)); + assert!(!is_gated(LIST_DIR)); + assert!(is_gated(WRITE_FILE)); + assert!(is_gated(EDIT_FILE)); + assert!(is_gated(BASH)); + } +} diff --git a/crates/helexa-acp/src/tools.rs b/crates/helexa-acp/src/tools.rs new file mode 100644 index 0000000..eb93768 --- /dev/null +++ b/crates/helexa-acp/src/tools.rs @@ -0,0 +1,179 @@ +//! Tool schemas sent to the upstream model on every completion. +//! +//! These are the OpenAI-function-style declarations the LLM sees in +//! `CompletionRequest.tools`; the runtime dispatch happens in +//! [`crate::tool_runner`]. Keeping declarations and execution in +//! separate modules makes it easy to add a tool without touching the +//! runner, and vice versa. +//! +//! Stage 3 ships five: filesystem read / write / edit, directory +//! listing, and `bash`. Image generation, web fetch, MCP-derived +//! tools, etc. are out of scope here. + +use serde_json::json; + +use crate::provider::ToolSpec; + +pub const READ_FILE: &str = "read_file"; +pub const WRITE_FILE: &str = "write_file"; +pub const EDIT_FILE: &str = "edit_file"; +pub const LIST_DIR: &str = "list_dir"; +pub const BASH: &str = "bash"; + +/// Build the static tool list passed to the model on every prompt. +/// Cheap — the JSON Schema fragments are constructed each call but +/// the bodies are small constants. If this ever shows up in a +/// profile we can `OnceLock` the Vec. +pub fn all_tools() -> Vec { + vec![ + ToolSpec { + name: READ_FILE.to_string(), + description: "Read the contents of a text file. Returns the file's text.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Absolute path to the file." + }, + "line": { + "type": "integer", + "description": "Optional 1-based line number to start reading from.", + "minimum": 1 + }, + "limit": { + "type": "integer", + "description": "Optional maximum number of lines to read.", + "minimum": 1 + } + }, + "required": ["path"], + "additionalProperties": false + }), + }, + ToolSpec { + name: WRITE_FILE.to_string(), + description: "Write text content to a file, replacing any existing contents. \ + Creates the file (and parent directories) if needed." + .to_string(), + parameters: json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Absolute path to the file." + }, + "content": { + "type": "string", + "description": "Full new contents of the file." + } + }, + "required": ["path", "content"], + "additionalProperties": false + }), + }, + ToolSpec { + name: EDIT_FILE.to_string(), + description: "Replace one exact substring in a file with another. \ + Fails if `old_text` does not appear in the file, or appears more than once. \ + Use multiple edit_file calls for multiple edits." + .to_string(), + parameters: json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Absolute path to the file." + }, + "old_text": { + "type": "string", + "description": "Exact text fragment to replace. Must be unique within the file." + }, + "new_text": { + "type": "string", + "description": "Replacement text." + } + }, + "required": ["path", "old_text", "new_text"], + "additionalProperties": false + }), + }, + ToolSpec { + name: LIST_DIR.to_string(), + description: + "List the entries of a directory. Returns names and a (f|d|l) kind per entry." + .to_string(), + parameters: json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Absolute path to the directory." + } + }, + "required": ["path"], + "additionalProperties": false + }), + }, + ToolSpec { + name: BASH.to_string(), + description: "Run a shell command via `sh -c`. \ + Returns combined stdout+stderr and the exit status. \ + The command runs in the session's working directory unless `cwd` is given." + .to_string(), + parameters: json!({ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Shell command line, evaluated by `sh -c`." + }, + "cwd": { + "type": "string", + "description": "Optional absolute path to run the command from." + } + }, + "required": ["command"], + "additionalProperties": false + }), + }, + ] +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn all_tools_has_five_named_entries() { + let tools = all_tools(); + let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect(); + assert_eq!( + names, + vec![READ_FILE, WRITE_FILE, EDIT_FILE, LIST_DIR, BASH] + ); + } + + #[test] + fn every_tool_has_an_object_parameter_schema() { + for tool in all_tools() { + let ty = tool.parameters.get("type").and_then(|v| v.as_str()); + assert_eq!( + ty, + Some("object"), + "tool {} parameters.type must be \"object\"", + tool.name + ); + assert!( + tool.parameters.get("properties").is_some(), + "tool {} missing properties", + tool.name + ); + assert!( + tool.parameters.get("required").is_some(), + "tool {} missing required list", + tool.name + ); + } + } +}