feat(helexa-acp): model picker + session/set_model handler
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 41s
CI / Clippy (push) Successful in 2m32s
build-prerelease / Build cortex binary (push) Successful in 4m45s
CI / Test (push) Successful in 5m52s
build-prerelease / Build neuron-blackwell (push) Successful in 5m59s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-ampere (push) Successful in 7m21s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Build neuron-ada (push) Successful in 4m54s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m58s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m48s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m3s

Stage 4. Zed's model dropdown now lists every model from every
configured endpoint, and switching it routes the next prompt to a
new endpoint+model.

- Enable `unstable_session_model` on the agent-client-protocol dep
  so SessionModelState / SetSessionModelRequest / ModelInfo are
  available.
- Agent::new becomes async and calls Provider::list_models on every
  provider at startup; per-endpoint failures warn-and-skip instead
  of aborting the agent.
- With a single endpoint configured, model ids appear bare; with
  multiple endpoints every id carries the `endpoint:` prefix so the
  picker is unambiguous and parse_model_selector routes correctly.
- NewSessionResponse and LoadSessionResponse attach SessionModelState
  with the session's current model id + the aggregated catalogue.
- session/set_model: validates the requested model id against
  resolve_provider, mutates session.model_id, and persists so the
  on-disk transcript reflects the new model.

Three new aggregate_models tests cover the prefixing rule (bare vs
multi-endpoint) and warn-and-skip on a failing endpoint.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-29 09:10:16 +03:00
parent 537a0fe7f2
commit adbc52bfcd
3 changed files with 253 additions and 10 deletions

View File

@@ -16,7 +16,12 @@ to cortex (helexa's reverse-proxy / fleet gateway).
# a painless migration to a dedicated GitHub repo in the future if the # a painless migration to a dedicated GitHub repo in the future if the
# project grows beyond helexa's needs. All deps are crates.io. # project grows beyond helexa's needs. All deps are crates.io.
[dependencies] [dependencies]
agent-client-protocol = "0.12" # `unstable_session_model` flips on the SessionModelState type and the
# session/set_model RPC the model-picker dropdown in Zed needs. The
# feature is upstream-marked unstable; we accept that risk because the
# model picker is core UX and the alternative (rolling our own
# extension method) drifts further from spec each time it moves.
agent-client-protocol = { version = "0.12", features = ["unstable_session_model"] }
tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "io-util", "process", "signal"] } tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "io-util", "process", "signal"] }
reqwest = { version = "0.12", features = ["json", "stream", "rustls-tls"], default-features = false } reqwest = { version = "0.12", features = ["json", "stream", "rustls-tls"], default-features = false }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }

View File

@@ -9,9 +9,10 @@
//! | `session/prompt` | tool-call loop: stream → dispatch tools → re-enter, repeat | //! | `session/prompt` | tool-call loop: stream → dispatch tools → re-enter, repeat |
//! | `session/cancel` | fire the session's cancellation token | //! | `session/cancel` | fire the session's cancellation token |
//! | `session/set_mode` | mutate the session's mode (gated vs. bypass-permissions) | //! | `session/set_mode` | mutate the session's mode (gated vs. bypass-permissions) |
//! | `session/set_model` | switch the session's active model (endpoint:model selector) |
//! | (anything else) | "not implemented yet" error | //! | (anything else) | "not implemented yet" error |
//! //!
//! Stage 4 wires `session/set_model`; Stage 5 flips on image content. //! Stage 5 flips on image content.
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
@@ -19,11 +20,12 @@ use std::sync::atomic::{AtomicU64, Ordering};
use agent_client_protocol::schema::{ use agent_client_protocol::schema::{
AgentCapabilities, CancelNotification, ContentBlock, InitializeRequest, InitializeResponse, AgentCapabilities, CancelNotification, ContentBlock, InitializeRequest, InitializeResponse,
ListSessionsRequest, ListSessionsResponse, LoadSessionRequest, LoadSessionResponse, ListSessionsRequest, ListSessionsResponse, LoadSessionRequest, LoadSessionResponse, ModelId,
NewSessionRequest, NewSessionResponse, PromptCapabilities, PromptRequest, PromptResponse, ModelInfo as AcpModelInfo, NewSessionRequest, NewSessionResponse, PromptCapabilities,
SessionCapabilities, SessionId, SessionInfo, SessionListCapabilities, SessionMode, PromptRequest, PromptResponse, SessionCapabilities, SessionId, SessionInfo,
SessionModeId, SessionModeState, SessionNotification, SessionUpdate, SetSessionModeRequest, SessionListCapabilities, SessionMode, SessionModeId, SessionModeState, SessionModelState,
SetSessionModeResponse, StopReason, TextContent, SessionNotification, SessionUpdate, SetSessionModeRequest, SetSessionModeResponse,
SetSessionModelRequest, SetSessionModelResponse, StopReason, TextContent,
}; };
use agent_client_protocol::{Agent as AgentRole, Client, ConnectionTo, Dispatch, Stdio}; use agent_client_protocol::{Agent as AgentRole, Client, ConnectionTo, Dispatch, Stdio};
use futures::StreamExt; use futures::StreamExt;
@@ -73,6 +75,14 @@ struct AgentInner {
/// fits inside `context_window - max_tokens - safety` tokens. /// fits inside `context_window - max_tokens - safety` tokens.
/// Absent entry → no compaction (legacy behaviour). /// Absent entry → no compaction (legacy behaviour).
context_window: std::collections::HashMap<String, usize>, context_window: std::collections::HashMap<String, usize>,
/// Aggregated list of selectable models across every configured
/// endpoint, computed once at startup. With a single endpoint
/// the model ids appear bare; with multiple endpoints every id
/// carries the `endpoint:` prefix so the picker is unambiguous.
/// Empty when every provider's `list_models` failed at startup —
/// the dropdown then shows nothing and the session keeps using
/// the configured `default_model`.
available_models: Vec<AcpModelInfo>,
sessions: SessionStore, sessions: SessionStore,
system_prompt_path: Option<PathBuf>, system_prompt_path: Option<PathBuf>,
/// Monotonic counter for minting session ids. The wire format is /// Monotonic counter for minting session ids. The wire format is
@@ -85,7 +95,13 @@ struct AgentInner {
impl Agent { impl Agent {
/// Construct an agent from a validated [`Config`] and the providers /// Construct an agent from a validated [`Config`] and the providers
/// that were successfully built for each endpoint. /// that were successfully built for each endpoint.
pub fn new(cfg: &Config, providers: Vec<Arc<dyn Provider>>) -> anyhow::Result<Self> { ///
/// `async` because we call `Provider::list_models` on every
/// provider up-front so the model-picker dropdown is populated
/// from the very first `session/new`. Per-endpoint failure
/// warns and skips rather than aborting startup — a single
/// unreachable endpoint shouldn't take down the agent.
pub async fn new(cfg: &Config, providers: Vec<Arc<dyn Provider>>) -> anyhow::Result<Self> {
if providers.is_empty() { if providers.is_empty() {
anyhow::bail!("no usable providers"); anyhow::bail!("no usable providers");
} }
@@ -110,6 +126,12 @@ impl Agent {
.iter() .iter()
.filter_map(|ep| ep.context_window.map(|w| (ep.name.clone(), w))) .filter_map(|ep| ep.context_window.map(|w| (ep.name.clone(), w)))
.collect(); .collect();
let available_models = aggregate_models(&providers).await;
tracing::info!(
models = available_models.len(),
endpoints = providers.len(),
"model catalogue assembled"
);
Ok(Self { Ok(Self {
inner: Arc::new(AgentInner { inner: Arc::new(AgentInner {
providers, providers,
@@ -117,6 +139,7 @@ impl Agent {
default_model: default.default_model.clone(), default_model: default.default_model.clone(),
max_tokens, max_tokens,
context_window, context_window,
available_models,
sessions: session::new_store(), sessions: session::new_store(),
system_prompt_path: cfg.system_prompt_path.clone(), system_prompt_path: cfg.system_prompt_path.clone(),
next_session_id: AtomicU64::new(1), next_session_id: AtomicU64::new(1),
@@ -207,6 +230,18 @@ impl Agent {
}, },
agent_client_protocol::on_receive_request!(), agent_client_protocol::on_receive_request!(),
) )
.on_receive_request(
{
let inner = inner.clone();
async move |req: SetSessionModelRequest, responder, _cx| {
match handle_set_session_model(&inner, req).await {
Ok(()) => responder.respond(SetSessionModelResponse::new()),
Err(e) => responder.respond_with_internal_error(format!("{e:#}")),
}
}
},
agent_client_protocol::on_receive_request!(),
)
.on_receive_notification( .on_receive_notification(
{ {
let inner = inner.clone(); let inner = inner.clone();
@@ -300,7 +335,12 @@ async fn handle_new_session(
cwd = %cwd_display, cwd = %cwd_display,
"session created" "session created"
); );
Ok(NewSessionResponse::new(session_id).modes(default_mode_state())) let resp = NewSessionResponse::new(session_id).modes(default_mode_state());
let resp = match session_model_state(inner, &log_model) {
Some(models) => resp.models(Some(models)),
None => resp,
};
Ok(resp)
} }
/// Rehydrate a session from disk. /// Rehydrate a session from disk.
@@ -353,7 +393,12 @@ async fn handle_load_session(
SessionModeId::new(mode_id), SessionModeId::new(mode_id),
default_mode_state().available_modes, default_mode_state().available_modes,
); );
Ok((LoadSessionResponse::new().modes(modes), history_for_replay)) let resp = LoadSessionResponse::new().modes(modes);
let resp = match session_model_state(inner, &model_id) {
Some(models) => resp.models(Some(models)),
None => resp,
};
Ok((resp, history_for_replay))
} }
/// Re-emit a session's persisted history as `session/update` /// Re-emit a session's persisted history as `session/update`
@@ -559,6 +604,108 @@ fn derive_session_title(history: &[Message]) -> Option<String> {
.filter(|s| !s.is_empty()) .filter(|s| !s.is_empty())
} }
/// Build the model catalogue advertised in `NewSessionResponse.models`
/// (and the resume equivalent). Walks every provider, calls
/// `list_models`, and prefixes ids with `endpoint:` when the user has
/// more than one endpoint configured. A failing endpoint logs and
/// contributes nothing — losing one endpoint must not blank the whole
/// dropdown.
async fn aggregate_models(providers: &[Arc<dyn Provider>]) -> Vec<AcpModelInfo> {
let multi_endpoint = providers.len() > 1;
let mut out: Vec<AcpModelInfo> = Vec::new();
for provider in providers {
let endpoint = provider.name().to_string();
match provider.list_models().await {
Ok(models) => {
tracing::info!(
endpoint = %endpoint,
count = models.len(),
"fetched models from endpoint"
);
for m in models {
let id = if multi_endpoint {
format!("{endpoint}:{}", m.id)
} else {
m.id.clone()
};
let display = m.display_name.unwrap_or_else(|| m.id.clone());
let info = AcpModelInfo::new(ModelId::new(id), display)
.description(Some(format!("endpoint: {endpoint}")));
out.push(info);
}
}
Err(e) => {
tracing::warn!(
endpoint = %endpoint,
error = %format!("{e:#}"),
"list_models failed; this endpoint's models won't appear in the picker"
);
}
}
}
out
}
/// Build the `SessionModelState` that Zed renders as the
/// model-picker dropdown. The current model id is exactly what
/// the session is using right now (already in `endpoint:model`
/// form if it was set that way). Returns `None` when the
/// catalogue is empty — no point showing an empty dropdown.
fn session_model_state(inner: &AgentInner, current: &str) -> Option<SessionModelState> {
if inner.available_models.is_empty() {
return None;
}
Some(SessionModelState::new(
ModelId::new(current.to_string()),
inner.available_models.clone(),
))
}
async fn handle_set_session_model(
inner: &AgentInner,
req: SetSessionModelRequest,
) -> anyhow::Result<()> {
let Some(state) = session::get(&inner.sessions, &req.session_id).await else {
anyhow::bail!("unknown session id {}", req.session_id.0);
};
let target = req.model_id.0.as_ref().to_string();
// Validate the requested model id resolves to a configured
// provider. We don't require it to appear in `available_models`
// because the catalogue may be stale (endpoint added a model
// after startup) and rejecting unknown ids would be too rigid.
// Provider lookup is the actual source of truth.
let (_, _) = resolve_provider(&inner.providers, &inner.default_endpoint_name, &target)
.map_err(|e| anyhow::anyhow!("set_session_model: {e:#}"))?;
// Persist the new model id on the session under the mutex,
// then snapshot for disk persistence outside the lock.
let snapshot = {
let mut s = state.lock().await;
s.model_id = target.clone();
PersistedSession {
session_id: req.session_id.0.as_ref().to_string(),
cwd: s.cwd.clone(),
model_id: s.model_id.clone(),
mode_id: s.mode_id.0.as_ref().to_string(),
history: s.history.clone(),
created_at: store::now_secs(),
updated_at: store::now_secs(),
}
};
if let Err(e) = store::save(&snapshot) {
tracing::warn!(
session_id = %req.session_id.0,
error = %format!("{e:#}"),
"session persist after set_model failed; on-disk model id stays stale"
);
}
tracing::info!(
session_id = %req.session_id.0,
model_id = %target,
"session model changed"
);
Ok(())
}
/// The three modes every Stage 3 session advertises: /// The three modes every Stage 3 session advertises:
/// ///
/// - **Default** — writes / bash prompt the user. /// - **Default** — writes / bash prompt the user.
@@ -1375,6 +1522,55 @@ mod tests {
} }
} }
/// Provider stub whose `list_models` returns canned results.
/// Used by the `aggregate_models` tests.
struct ModelProvider {
name: &'static str,
models: anyhow::Result<Vec<crate::provider::ModelInfo>>,
}
impl ModelProvider {
fn ok(name: &'static str, ids: &[&str]) -> Arc<dyn Provider> {
let models = ids
.iter()
.map(|id| crate::provider::ModelInfo {
id: (*id).to_string(),
display_name: None,
})
.collect();
Arc::new(Self {
name,
models: Ok(models),
})
}
fn err(name: &'static str, msg: &'static str) -> Arc<dyn Provider> {
Arc::new(Self {
name,
models: Err(anyhow::anyhow!(msg)),
})
}
}
#[async_trait]
impl Provider for ModelProvider {
fn name(&self) -> &str {
self.name
}
async fn list_models(&self) -> anyhow::Result<Vec<crate::provider::ModelInfo>> {
match &self.models {
Ok(v) => Ok(v.clone()),
Err(e) => Err(anyhow::anyhow!("{e:#}")),
}
}
async fn complete(
&self,
_request: CompletionRequest,
_cancel: CancellationToken,
) -> anyhow::Result<BoxStream<'static, anyhow::Result<CompletionEvent>>> {
unimplemented!()
}
}
fn providers() -> Vec<Arc<dyn Provider>> { fn providers() -> Vec<Arc<dyn Provider>> {
vec![ vec![
Arc::new(StubProvider("helexa")), Arc::new(StubProvider("helexa")),
@@ -1431,6 +1627,47 @@ mod tests {
assert_eq!(prompt_budget(1_000, Some(8_192)), 0); assert_eq!(prompt_budget(1_000, Some(8_192)), 0);
} }
// ── aggregate_models ────────────────────────────────────────────
#[tokio::test]
async fn aggregate_models_single_endpoint_has_bare_ids() {
let providers = vec![ModelProvider::ok(
"helexa",
&["helexa/large", "helexa/small"],
)];
let models = aggregate_models(&providers).await;
let ids: Vec<&str> = models.iter().map(|m| m.model_id.0.as_ref()).collect();
assert_eq!(ids, vec!["helexa/large", "helexa/small"]);
}
#[tokio::test]
async fn aggregate_models_multi_endpoint_prefixes_every_id() {
let providers = vec![
ModelProvider::ok("helexa", &["helexa/large"]),
ModelProvider::ok("openrouter", &["anthropic/claude-opus-4"]),
];
let models = aggregate_models(&providers).await;
let ids: Vec<&str> = models.iter().map(|m| m.model_id.0.as_ref()).collect();
assert_eq!(
ids,
vec!["helexa:helexa/large", "openrouter:anthropic/claude-opus-4"]
);
}
#[tokio::test]
async fn aggregate_models_skips_failing_endpoint() {
let providers = vec![
ModelProvider::err("flaky", "boom"),
ModelProvider::ok("openrouter", &["gpt-9"]),
];
let models = aggregate_models(&providers).await;
let ids: Vec<&str> = models.iter().map(|m| m.model_id.0.as_ref()).collect();
// Multi-endpoint case → prefix survives even when one
// endpoint dropped out. flaky's models are absent, not
// null-filled.
assert_eq!(ids, vec!["openrouter:gpt-9"]);
}
#[test] #[test]
fn maps_known_finish_reasons() { fn maps_known_finish_reasons() {
assert!(matches!( assert!(matches!(

View File

@@ -142,6 +142,7 @@ async fn main() -> Result<()> {
} }
let agent = Agent::new(&cfg, providers) let agent = Agent::new(&cfg, providers)
.await
.map_err(|e| agent_client_protocol::util::internal_error(format!("agent: {e:#}")))?; .map_err(|e| agent_client_protocol::util::internal_error(format!("agent: {e:#}")))?;
agent.serve(Stdio::new()).await agent.serve(Stdio::new()).await
} }