From adbc52bfcd582cbfc9282c55c61870b74b0b3a9c Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Fri, 29 May 2026 09:10:16 +0300 Subject: [PATCH] feat(helexa-acp): model picker + session/set_model handler Stage 4. Zed's model dropdown now lists every model from every configured endpoint, and switching it routes the next prompt to a new endpoint+model. - Enable `unstable_session_model` on the agent-client-protocol dep so SessionModelState / SetSessionModelRequest / ModelInfo are available. - Agent::new becomes async and calls Provider::list_models on every provider at startup; per-endpoint failures warn-and-skip instead of aborting the agent. - With a single endpoint configured, model ids appear bare; with multiple endpoints every id carries the `endpoint:` prefix so the picker is unambiguous and parse_model_selector routes correctly. - NewSessionResponse and LoadSessionResponse attach SessionModelState with the session's current model id + the aggregated catalogue. - session/set_model: validates the requested model id against resolve_provider, mutates session.model_id, and persists so the on-disk transcript reflects the new model. Three new aggregate_models tests cover the prefixing rule (bare vs multi-endpoint) and warn-and-skip on a failing endpoint. Co-Authored-By: Claude Opus 4.7 --- crates/helexa-acp/Cargo.toml | 7 +- crates/helexa-acp/src/agent.rs | 255 +++++++++++++++++++++++++++++++-- crates/helexa-acp/src/main.rs | 1 + 3 files changed, 253 insertions(+), 10 deletions(-) diff --git a/crates/helexa-acp/Cargo.toml b/crates/helexa-acp/Cargo.toml index 25e41f7..f48f021 100644 --- a/crates/helexa-acp/Cargo.toml +++ b/crates/helexa-acp/Cargo.toml @@ -16,7 +16,12 @@ to cortex (helexa's reverse-proxy / fleet gateway). # a painless migration to a dedicated GitHub repo in the future if the # project grows beyond helexa's needs. All deps are crates.io. 
[dependencies] -agent-client-protocol = "0.12" +# `unstable_session_model` flips on the SessionModelState type and the +# session/set_model RPC the model-picker dropdown in Zed needs. The +# feature is upstream-marked unstable; we accept that risk because the +# model picker is core UX and the alternative (rolling our own +# extension method) drifts further from spec each time it moves. +agent-client-protocol = { version = "0.12", features = ["unstable_session_model"] } tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "io-util", "process", "signal"] } reqwest = { version = "0.12", features = ["json", "stream", "rustls-tls"], default-features = false } serde = { version = "1", features = ["derive"] } diff --git a/crates/helexa-acp/src/agent.rs b/crates/helexa-acp/src/agent.rs index 9d242cb..63d3546 100644 --- a/crates/helexa-acp/src/agent.rs +++ b/crates/helexa-acp/src/agent.rs @@ -9,9 +9,10 @@ //! | `session/prompt` | tool-call loop: stream → dispatch tools → re-enter, repeat | //! | `session/cancel` | fire the session's cancellation token | //! | `session/set_mode` | mutate the session's mode (gated vs. bypass-permissions) | +//! | `session/set_model` | switch the session's active model (endpoint:model selector) | //! | (anything else) | "not implemented yet" error | //! -//! Stage 4 wires `session/set_model`; Stage 5 flips on image content. +//! Stage 5 flips on image content. 
use std::path::PathBuf; use std::sync::Arc; @@ -19,11 +20,12 @@ use std::sync::atomic::{AtomicU64, Ordering}; use agent_client_protocol::schema::{ AgentCapabilities, CancelNotification, ContentBlock, InitializeRequest, InitializeResponse, - ListSessionsRequest, ListSessionsResponse, LoadSessionRequest, LoadSessionResponse, - NewSessionRequest, NewSessionResponse, PromptCapabilities, PromptRequest, PromptResponse, - SessionCapabilities, SessionId, SessionInfo, SessionListCapabilities, SessionMode, - SessionModeId, SessionModeState, SessionNotification, SessionUpdate, SetSessionModeRequest, - SetSessionModeResponse, StopReason, TextContent, + ListSessionsRequest, ListSessionsResponse, LoadSessionRequest, LoadSessionResponse, ModelId, + ModelInfo as AcpModelInfo, NewSessionRequest, NewSessionResponse, PromptCapabilities, + PromptRequest, PromptResponse, SessionCapabilities, SessionId, SessionInfo, + SessionListCapabilities, SessionMode, SessionModeId, SessionModeState, SessionModelState, + SessionNotification, SessionUpdate, SetSessionModeRequest, SetSessionModeResponse, + SetSessionModelRequest, SetSessionModelResponse, StopReason, TextContent, }; use agent_client_protocol::{Agent as AgentRole, Client, ConnectionTo, Dispatch, Stdio}; use futures::StreamExt; @@ -73,6 +75,14 @@ struct AgentInner { /// fits inside `context_window - max_tokens - safety` tokens. /// Absent entry → no compaction (legacy behaviour). context_window: std::collections::HashMap, + /// Aggregated list of selectable models across every configured + /// endpoint, computed once at startup. With a single endpoint + /// the model ids appear bare; with multiple endpoints every id + /// carries the `endpoint:` prefix so the picker is unambiguous. + /// Empty when every provider's `list_models` failed at startup — + /// the dropdown then shows nothing and the session keeps using + /// the configured `default_model`. 
+ available_models: Vec, sessions: SessionStore, system_prompt_path: Option, /// Monotonic counter for minting session ids. The wire format is @@ -85,7 +95,13 @@ struct AgentInner { impl Agent { /// Construct an agent from a validated [`Config`] and the providers /// that were successfully built for each endpoint. - pub fn new(cfg: &Config, providers: Vec>) -> anyhow::Result { + /// + /// `async` because we call `Provider::list_models` on every + /// provider up-front so the model-picker dropdown is populated + /// from the very first `session/new`. Per-endpoint failure + /// warns and skips rather than aborting startup — a single + /// unreachable endpoint shouldn't take down the agent. + pub async fn new(cfg: &Config, providers: Vec>) -> anyhow::Result { if providers.is_empty() { anyhow::bail!("no usable providers"); } @@ -110,6 +126,12 @@ impl Agent { .iter() .filter_map(|ep| ep.context_window.map(|w| (ep.name.clone(), w))) .collect(); + let available_models = aggregate_models(&providers).await; + tracing::info!( + models = available_models.len(), + endpoints = providers.len(), + "model catalogue assembled" + ); Ok(Self { inner: Arc::new(AgentInner { providers, @@ -117,6 +139,7 @@ impl Agent { default_model: default.default_model.clone(), max_tokens, context_window, + available_models, sessions: session::new_store(), system_prompt_path: cfg.system_prompt_path.clone(), next_session_id: AtomicU64::new(1), @@ -207,6 +230,18 @@ impl Agent { }, agent_client_protocol::on_receive_request!(), ) + .on_receive_request( + { + let inner = inner.clone(); + async move |req: SetSessionModelRequest, responder, _cx| { + match handle_set_session_model(&inner, req).await { + Ok(()) => responder.respond(SetSessionModelResponse::new()), + Err(e) => responder.respond_with_internal_error(format!("{e:#}")), + } + } + }, + agent_client_protocol::on_receive_request!(), + ) .on_receive_notification( { let inner = inner.clone(); @@ -300,7 +335,12 @@ async fn handle_new_session( cwd = 
%cwd_display, "session created" ); - Ok(NewSessionResponse::new(session_id).modes(default_mode_state())) + let resp = NewSessionResponse::new(session_id).modes(default_mode_state()); + let resp = match session_model_state(inner, &log_model) { + Some(models) => resp.models(Some(models)), + None => resp, + }; + Ok(resp) } /// Rehydrate a session from disk. @@ -353,7 +393,12 @@ async fn handle_load_session( SessionModeId::new(mode_id), default_mode_state().available_modes, ); - Ok((LoadSessionResponse::new().modes(modes), history_for_replay)) + let resp = LoadSessionResponse::new().modes(modes); + let resp = match session_model_state(inner, &model_id) { + Some(models) => resp.models(Some(models)), + None => resp, + }; + Ok((resp, history_for_replay)) } /// Re-emit a session's persisted history as `session/update` @@ -559,6 +604,108 @@ fn derive_session_title(history: &[Message]) -> Option { .filter(|s| !s.is_empty()) } +/// Build the model catalogue advertised in `NewSessionResponse.models` +/// (and the resume equivalent). Walks every provider, calls +/// `list_models`, and prefixes ids with `endpoint:` when more than one +/// provider was passed in (decided up front, so the prefix stays even if +/// only one endpoint answers). A failing endpoint logs and contributes +/// nothing — losing one endpoint must not blank the whole dropdown. 
+async fn aggregate_models(providers: &[Arc]) -> Vec { + let multi_endpoint = providers.len() > 1; + let mut out: Vec = Vec::new(); + for provider in providers { + let endpoint = provider.name().to_string(); + match provider.list_models().await { + Ok(models) => { + tracing::info!( + endpoint = %endpoint, + count = models.len(), + "fetched models from endpoint" + ); + for m in models { + let id = if multi_endpoint { + format!("{endpoint}:{}", m.id) + } else { + m.id.clone() + }; + let display = m.display_name.unwrap_or_else(|| m.id.clone()); + let info = AcpModelInfo::new(ModelId::new(id), display) + .description(Some(format!("endpoint: {endpoint}"))); + out.push(info); + } + } + Err(e) => { + tracing::warn!( + endpoint = %endpoint, + error = %format!("{e:#}"), + "list_models failed; this endpoint's models won't appear in the picker" + ); + } + } + } + out +} + +/// Build the `SessionModelState` that Zed renders as the +/// model-picker dropdown. The current model id is exactly what +/// the session is using right now (already in `endpoint:model` +/// form if it was set that way). Returns `None` when the +/// catalogue is empty — no point showing an empty dropdown. +fn session_model_state(inner: &AgentInner, current: &str) -> Option { + if inner.available_models.is_empty() { + return None; + } + Some(SessionModelState::new( + ModelId::new(current.to_string()), + inner.available_models.clone(), + )) +} + +async fn handle_set_session_model( + inner: &AgentInner, + req: SetSessionModelRequest, +) -> anyhow::Result<()> { + let Some(state) = session::get(&inner.sessions, &req.session_id).await else { + anyhow::bail!("unknown session id {}", req.session_id.0); + }; + let target = req.model_id.0.as_ref().to_string(); + // Validate the requested model id resolves to a configured + // provider. We don't require it to appear in `available_models` + // because the catalogue may be stale (endpoint added a model + // after startup) and rejecting unknown ids would be too rigid. 
+ // Provider lookup is the actual source of truth. + let (_, _) = resolve_provider(&inner.providers, &inner.default_endpoint_name, &target) + .map_err(|e| anyhow::anyhow!("set_session_model: {e:#}"))?; + // Persist the new model id on the session under the mutex, + // then snapshot for disk persistence outside the lock. + let snapshot = { + let mut s = state.lock().await; + s.model_id = target.clone(); + PersistedSession { + session_id: req.session_id.0.as_ref().to_string(), + cwd: s.cwd.clone(), + model_id: s.model_id.clone(), + mode_id: s.mode_id.0.as_ref().to_string(), + history: s.history.clone(), + created_at: store::now_secs(), + updated_at: store::now_secs(), + } + }; + if let Err(e) = store::save(&snapshot) { + tracing::warn!( + session_id = %req.session_id.0, + error = %format!("{e:#}"), + "session persist after set_model failed; on-disk model id stays stale" + ); + } + tracing::info!( + session_id = %req.session_id.0, + model_id = %target, + "session model changed" + ); + Ok(()) +} + /// The three modes every Stage 3 session advertises: /// /// - **Default** — writes / bash prompt the user. @@ -1375,6 +1522,55 @@ mod tests { } } + /// Provider stub whose `list_models` returns canned results. + /// Used by the `aggregate_models` tests. 
+ struct ModelProvider { + name: &'static str, + models: anyhow::Result>, + } + + impl ModelProvider { + fn ok(name: &'static str, ids: &[&str]) -> Arc { + let models = ids + .iter() + .map(|id| crate::provider::ModelInfo { + id: (*id).to_string(), + display_name: None, + }) + .collect(); + Arc::new(Self { + name, + models: Ok(models), + }) + } + fn err(name: &'static str, msg: &'static str) -> Arc { + Arc::new(Self { + name, + models: Err(anyhow::anyhow!(msg)), + }) + } + } + + #[async_trait] + impl Provider for ModelProvider { + fn name(&self) -> &str { + self.name + } + async fn list_models(&self) -> anyhow::Result> { + match &self.models { + Ok(v) => Ok(v.clone()), + Err(e) => Err(anyhow::anyhow!("{e:#}")), + } + } + async fn complete( + &self, + _request: CompletionRequest, + _cancel: CancellationToken, + ) -> anyhow::Result>> { + unimplemented!() + } + } + fn providers() -> Vec> { vec![ Arc::new(StubProvider("helexa")), @@ -1431,6 +1627,47 @@ mod tests { assert_eq!(prompt_budget(1_000, Some(8_192)), 0); } + // ── aggregate_models ──────────────────────────────────────────── + + #[tokio::test] + async fn aggregate_models_single_endpoint_has_bare_ids() { + let providers = vec![ModelProvider::ok( + "helexa", + &["helexa/large", "helexa/small"], + )]; + let models = aggregate_models(&providers).await; + let ids: Vec<&str> = models.iter().map(|m| m.model_id.0.as_ref()).collect(); + assert_eq!(ids, vec!["helexa/large", "helexa/small"]); + } + + #[tokio::test] + async fn aggregate_models_multi_endpoint_prefixes_every_id() { + let providers = vec![ + ModelProvider::ok("helexa", &["helexa/large"]), + ModelProvider::ok("openrouter", &["anthropic/claude-opus-4"]), + ]; + let models = aggregate_models(&providers).await; + let ids: Vec<&str> = models.iter().map(|m| m.model_id.0.as_ref()).collect(); + assert_eq!( + ids, + vec!["helexa:helexa/large", "openrouter:anthropic/claude-opus-4"] + ); + } + + #[tokio::test] + async fn aggregate_models_skips_failing_endpoint() { + 
let providers = vec![ + ModelProvider::err("flaky", "boom"), + ModelProvider::ok("openrouter", &["gpt-9"]), + ]; + let models = aggregate_models(&providers).await; + let ids: Vec<&str> = models.iter().map(|m| m.model_id.0.as_ref()).collect(); + // Multi-endpoint case → prefix survives even when one + // endpoint dropped out. flaky's models are absent, not + // null-filled. + assert_eq!(ids, vec!["openrouter:gpt-9"]); + } + #[test] fn maps_known_finish_reasons() { assert!(matches!( diff --git a/crates/helexa-acp/src/main.rs b/crates/helexa-acp/src/main.rs index 7e6def5..91ae108 100644 --- a/crates/helexa-acp/src/main.rs +++ b/crates/helexa-acp/src/main.rs @@ -142,6 +142,7 @@ async fn main() -> Result<()> { } let agent = Agent::new(&cfg, providers) + .await .map_err(|e| agent_client_protocol::util::internal_error(format!("agent: {e:#}")))?; agent.serve(Stdio::new()).await }