feat(helexa-acp): model picker + session/set_model handler
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 41s
CI / Clippy (push) Successful in 2m32s
build-prerelease / Build cortex binary (push) Successful in 4m45s
CI / Test (push) Successful in 5m52s
build-prerelease / Build neuron-blackwell (push) Successful in 5m59s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-ampere (push) Successful in 7m21s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Build neuron-ada (push) Successful in 4m54s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m58s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m48s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m3s
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 41s
CI / Clippy (push) Successful in 2m32s
build-prerelease / Build cortex binary (push) Successful in 4m45s
CI / Test (push) Successful in 5m52s
build-prerelease / Build neuron-blackwell (push) Successful in 5m59s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-ampere (push) Successful in 7m21s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Build neuron-ada (push) Successful in 4m54s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m58s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m48s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m3s
Stage 4. Zed's model dropdown now lists every model from every configured endpoint, and switching it routes the next prompt to a new endpoint+model. - Enable `unstable_session_model` on the agent-client-protocol dep so SessionModelState / SetSessionModelRequest / ModelInfo are available. - Agent::new becomes async and calls Provider::list_models on every provider at startup; per-endpoint failures warn-and-skip instead of aborting the agent. - With a single endpoint configured, model ids appear bare; with multiple endpoints every id carries the `endpoint:` prefix so the picker is unambiguous and parse_model_selector routes correctly. - NewSessionResponse and LoadSessionResponse attach SessionModelState with the session's current model id + the aggregated catalogue. - session/set_model: validates the requested model id against resolve_provider, mutates session.model_id, and persists so the on-disk transcript reflects the new model. Three new aggregate_models tests cover the prefixing rule (bare vs multi-endpoint) and warn-and-skip on a failing endpoint. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -16,7 +16,12 @@ to cortex (helexa's reverse-proxy / fleet gateway).
|
||||
# a painless migration to a dedicated GitHub repo in the future if the
|
||||
# project grows beyond helexa's needs. All deps are crates.io.
|
||||
[dependencies]
|
||||
agent-client-protocol = "0.12"
|
||||
# `unstable_session_model` flips on the SessionModelState type and the
|
||||
# session/set_model RPC the model-picker dropdown in Zed needs. The
|
||||
# feature is upstream-marked unstable; we accept that risk because the
|
||||
# model picker is core UX and the alternative (rolling our own
|
||||
# extension method) drifts further from spec each time it moves.
|
||||
agent-client-protocol = { version = "0.12", features = ["unstable_session_model"] }
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "io-util", "process", "signal"] }
|
||||
reqwest = { version = "0.12", features = ["json", "stream", "rustls-tls"], default-features = false }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
|
||||
@@ -9,9 +9,10 @@
|
||||
//! | `session/prompt` | tool-call loop: stream → dispatch tools → re-enter, repeat |
|
||||
//! | `session/cancel` | fire the session's cancellation token |
|
||||
//! | `session/set_mode` | mutate the session's mode (gated vs. bypass-permissions) |
|
||||
//! | `session/set_model` | switch the session's active model (endpoint:model selector) |
|
||||
//! | (anything else) | "not implemented yet" error |
|
||||
//!
|
||||
//! Stage 4 wires `session/set_model`; Stage 5 flips on image content.
|
||||
//! Stage 5 flips on image content.
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
@@ -19,11 +20,12 @@ use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use agent_client_protocol::schema::{
|
||||
AgentCapabilities, CancelNotification, ContentBlock, InitializeRequest, InitializeResponse,
|
||||
ListSessionsRequest, ListSessionsResponse, LoadSessionRequest, LoadSessionResponse,
|
||||
NewSessionRequest, NewSessionResponse, PromptCapabilities, PromptRequest, PromptResponse,
|
||||
SessionCapabilities, SessionId, SessionInfo, SessionListCapabilities, SessionMode,
|
||||
SessionModeId, SessionModeState, SessionNotification, SessionUpdate, SetSessionModeRequest,
|
||||
SetSessionModeResponse, StopReason, TextContent,
|
||||
ListSessionsRequest, ListSessionsResponse, LoadSessionRequest, LoadSessionResponse, ModelId,
|
||||
ModelInfo as AcpModelInfo, NewSessionRequest, NewSessionResponse, PromptCapabilities,
|
||||
PromptRequest, PromptResponse, SessionCapabilities, SessionId, SessionInfo,
|
||||
SessionListCapabilities, SessionMode, SessionModeId, SessionModeState, SessionModelState,
|
||||
SessionNotification, SessionUpdate, SetSessionModeRequest, SetSessionModeResponse,
|
||||
SetSessionModelRequest, SetSessionModelResponse, StopReason, TextContent,
|
||||
};
|
||||
use agent_client_protocol::{Agent as AgentRole, Client, ConnectionTo, Dispatch, Stdio};
|
||||
use futures::StreamExt;
|
||||
@@ -73,6 +75,14 @@ struct AgentInner {
|
||||
/// fits inside `context_window - max_tokens - safety` tokens.
|
||||
/// Absent entry → no compaction (legacy behaviour).
|
||||
context_window: std::collections::HashMap<String, usize>,
|
||||
/// Aggregated list of selectable models across every configured
|
||||
/// endpoint, computed once at startup. With a single endpoint
|
||||
/// the model ids appear bare; with multiple endpoints every id
|
||||
/// carries the `endpoint:` prefix so the picker is unambiguous.
|
||||
/// Empty when every provider's `list_models` failed at startup —
|
||||
/// the dropdown then shows nothing and the session keeps using
|
||||
/// the configured `default_model`.
|
||||
available_models: Vec<AcpModelInfo>,
|
||||
sessions: SessionStore,
|
||||
system_prompt_path: Option<PathBuf>,
|
||||
/// Monotonic counter for minting session ids. The wire format is
|
||||
@@ -85,7 +95,13 @@ struct AgentInner {
|
||||
impl Agent {
|
||||
/// Construct an agent from a validated [`Config`] and the providers
|
||||
/// that were successfully built for each endpoint.
|
||||
pub fn new(cfg: &Config, providers: Vec<Arc<dyn Provider>>) -> anyhow::Result<Self> {
|
||||
///
|
||||
/// `async` because we call `Provider::list_models` on every
|
||||
/// provider up-front so the model-picker dropdown is populated
|
||||
/// from the very first `session/new`. Per-endpoint failure
|
||||
/// warns and skips rather than aborting startup — a single
|
||||
/// unreachable endpoint shouldn't take down the agent.
|
||||
pub async fn new(cfg: &Config, providers: Vec<Arc<dyn Provider>>) -> anyhow::Result<Self> {
|
||||
if providers.is_empty() {
|
||||
anyhow::bail!("no usable providers");
|
||||
}
|
||||
@@ -110,6 +126,12 @@ impl Agent {
|
||||
.iter()
|
||||
.filter_map(|ep| ep.context_window.map(|w| (ep.name.clone(), w)))
|
||||
.collect();
|
||||
let available_models = aggregate_models(&providers).await;
|
||||
tracing::info!(
|
||||
models = available_models.len(),
|
||||
endpoints = providers.len(),
|
||||
"model catalogue assembled"
|
||||
);
|
||||
Ok(Self {
|
||||
inner: Arc::new(AgentInner {
|
||||
providers,
|
||||
@@ -117,6 +139,7 @@ impl Agent {
|
||||
default_model: default.default_model.clone(),
|
||||
max_tokens,
|
||||
context_window,
|
||||
available_models,
|
||||
sessions: session::new_store(),
|
||||
system_prompt_path: cfg.system_prompt_path.clone(),
|
||||
next_session_id: AtomicU64::new(1),
|
||||
@@ -207,6 +230,18 @@ impl Agent {
|
||||
},
|
||||
agent_client_protocol::on_receive_request!(),
|
||||
)
|
||||
.on_receive_request(
|
||||
{
|
||||
let inner = inner.clone();
|
||||
async move |req: SetSessionModelRequest, responder, _cx| {
|
||||
match handle_set_session_model(&inner, req).await {
|
||||
Ok(()) => responder.respond(SetSessionModelResponse::new()),
|
||||
Err(e) => responder.respond_with_internal_error(format!("{e:#}")),
|
||||
}
|
||||
}
|
||||
},
|
||||
agent_client_protocol::on_receive_request!(),
|
||||
)
|
||||
.on_receive_notification(
|
||||
{
|
||||
let inner = inner.clone();
|
||||
@@ -300,7 +335,12 @@ async fn handle_new_session(
|
||||
cwd = %cwd_display,
|
||||
"session created"
|
||||
);
|
||||
Ok(NewSessionResponse::new(session_id).modes(default_mode_state()))
|
||||
let resp = NewSessionResponse::new(session_id).modes(default_mode_state());
|
||||
let resp = match session_model_state(inner, &log_model) {
|
||||
Some(models) => resp.models(Some(models)),
|
||||
None => resp,
|
||||
};
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
/// Rehydrate a session from disk.
|
||||
@@ -353,7 +393,12 @@ async fn handle_load_session(
|
||||
SessionModeId::new(mode_id),
|
||||
default_mode_state().available_modes,
|
||||
);
|
||||
Ok((LoadSessionResponse::new().modes(modes), history_for_replay))
|
||||
let resp = LoadSessionResponse::new().modes(modes);
|
||||
let resp = match session_model_state(inner, &model_id) {
|
||||
Some(models) => resp.models(Some(models)),
|
||||
None => resp,
|
||||
};
|
||||
Ok((resp, history_for_replay))
|
||||
}
|
||||
|
||||
/// Re-emit a session's persisted history as `session/update`
|
||||
@@ -559,6 +604,108 @@ fn derive_session_title(history: &[Message]) -> Option<String> {
|
||||
.filter(|s| !s.is_empty())
|
||||
}
|
||||
|
||||
/// Build the model catalogue advertised in `NewSessionResponse.models`
|
||||
/// (and the resume equivalent). Walks every provider, calls
|
||||
/// `list_models`, and prefixes ids with `endpoint:` when the user has
|
||||
/// more than one endpoint configured. A failing endpoint logs and
|
||||
/// contributes nothing — losing one endpoint must not blank the whole
|
||||
/// dropdown.
|
||||
async fn aggregate_models(providers: &[Arc<dyn Provider>]) -> Vec<AcpModelInfo> {
|
||||
let multi_endpoint = providers.len() > 1;
|
||||
let mut out: Vec<AcpModelInfo> = Vec::new();
|
||||
for provider in providers {
|
||||
let endpoint = provider.name().to_string();
|
||||
match provider.list_models().await {
|
||||
Ok(models) => {
|
||||
tracing::info!(
|
||||
endpoint = %endpoint,
|
||||
count = models.len(),
|
||||
"fetched models from endpoint"
|
||||
);
|
||||
for m in models {
|
||||
let id = if multi_endpoint {
|
||||
format!("{endpoint}:{}", m.id)
|
||||
} else {
|
||||
m.id.clone()
|
||||
};
|
||||
let display = m.display_name.unwrap_or_else(|| m.id.clone());
|
||||
let info = AcpModelInfo::new(ModelId::new(id), display)
|
||||
.description(Some(format!("endpoint: {endpoint}")));
|
||||
out.push(info);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
endpoint = %endpoint,
|
||||
error = %format!("{e:#}"),
|
||||
"list_models failed; this endpoint's models won't appear in the picker"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Build the `SessionModelState` that Zed renders as the
|
||||
/// model-picker dropdown. The current model id is exactly what
|
||||
/// the session is using right now (already in `endpoint:model`
|
||||
/// form if it was set that way). Returns `None` when the
|
||||
/// catalogue is empty — no point showing an empty dropdown.
|
||||
fn session_model_state(inner: &AgentInner, current: &str) -> Option<SessionModelState> {
|
||||
if inner.available_models.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(SessionModelState::new(
|
||||
ModelId::new(current.to_string()),
|
||||
inner.available_models.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn handle_set_session_model(
|
||||
inner: &AgentInner,
|
||||
req: SetSessionModelRequest,
|
||||
) -> anyhow::Result<()> {
|
||||
let Some(state) = session::get(&inner.sessions, &req.session_id).await else {
|
||||
anyhow::bail!("unknown session id {}", req.session_id.0);
|
||||
};
|
||||
let target = req.model_id.0.as_ref().to_string();
|
||||
// Validate the requested model id resolves to a configured
|
||||
// provider. We don't require it to appear in `available_models`
|
||||
// because the catalogue may be stale (endpoint added a model
|
||||
// after startup) and rejecting unknown ids would be too rigid.
|
||||
// Provider lookup is the actual source of truth.
|
||||
let (_, _) = resolve_provider(&inner.providers, &inner.default_endpoint_name, &target)
|
||||
.map_err(|e| anyhow::anyhow!("set_session_model: {e:#}"))?;
|
||||
// Persist the new model id on the session under the mutex,
|
||||
// then snapshot for disk persistence outside the lock.
|
||||
let snapshot = {
|
||||
let mut s = state.lock().await;
|
||||
s.model_id = target.clone();
|
||||
PersistedSession {
|
||||
session_id: req.session_id.0.as_ref().to_string(),
|
||||
cwd: s.cwd.clone(),
|
||||
model_id: s.model_id.clone(),
|
||||
mode_id: s.mode_id.0.as_ref().to_string(),
|
||||
history: s.history.clone(),
|
||||
created_at: store::now_secs(),
|
||||
updated_at: store::now_secs(),
|
||||
}
|
||||
};
|
||||
if let Err(e) = store::save(&snapshot) {
|
||||
tracing::warn!(
|
||||
session_id = %req.session_id.0,
|
||||
error = %format!("{e:#}"),
|
||||
"session persist after set_model failed; on-disk model id stays stale"
|
||||
);
|
||||
}
|
||||
tracing::info!(
|
||||
session_id = %req.session_id.0,
|
||||
model_id = %target,
|
||||
"session model changed"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// The three modes every Stage 3 session advertises:
|
||||
///
|
||||
/// - **Default** — writes / bash prompt the user.
|
||||
@@ -1375,6 +1522,55 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
/// Provider stub whose `list_models` returns canned results.
|
||||
/// Used by the `aggregate_models` tests.
|
||||
struct ModelProvider {
|
||||
name: &'static str,
|
||||
models: anyhow::Result<Vec<crate::provider::ModelInfo>>,
|
||||
}
|
||||
|
||||
impl ModelProvider {
|
||||
fn ok(name: &'static str, ids: &[&str]) -> Arc<dyn Provider> {
|
||||
let models = ids
|
||||
.iter()
|
||||
.map(|id| crate::provider::ModelInfo {
|
||||
id: (*id).to_string(),
|
||||
display_name: None,
|
||||
})
|
||||
.collect();
|
||||
Arc::new(Self {
|
||||
name,
|
||||
models: Ok(models),
|
||||
})
|
||||
}
|
||||
fn err(name: &'static str, msg: &'static str) -> Arc<dyn Provider> {
|
||||
Arc::new(Self {
|
||||
name,
|
||||
models: Err(anyhow::anyhow!(msg)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for ModelProvider {
|
||||
fn name(&self) -> &str {
|
||||
self.name
|
||||
}
|
||||
async fn list_models(&self) -> anyhow::Result<Vec<crate::provider::ModelInfo>> {
|
||||
match &self.models {
|
||||
Ok(v) => Ok(v.clone()),
|
||||
Err(e) => Err(anyhow::anyhow!("{e:#}")),
|
||||
}
|
||||
}
|
||||
async fn complete(
|
||||
&self,
|
||||
_request: CompletionRequest,
|
||||
_cancel: CancellationToken,
|
||||
) -> anyhow::Result<BoxStream<'static, anyhow::Result<CompletionEvent>>> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
fn providers() -> Vec<Arc<dyn Provider>> {
|
||||
vec![
|
||||
Arc::new(StubProvider("helexa")),
|
||||
@@ -1431,6 +1627,47 @@ mod tests {
|
||||
assert_eq!(prompt_budget(1_000, Some(8_192)), 0);
|
||||
}
|
||||
|
||||
// ── aggregate_models ────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn aggregate_models_single_endpoint_has_bare_ids() {
|
||||
let providers = vec![ModelProvider::ok(
|
||||
"helexa",
|
||||
&["helexa/large", "helexa/small"],
|
||||
)];
|
||||
let models = aggregate_models(&providers).await;
|
||||
let ids: Vec<&str> = models.iter().map(|m| m.model_id.0.as_ref()).collect();
|
||||
assert_eq!(ids, vec!["helexa/large", "helexa/small"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn aggregate_models_multi_endpoint_prefixes_every_id() {
|
||||
let providers = vec![
|
||||
ModelProvider::ok("helexa", &["helexa/large"]),
|
||||
ModelProvider::ok("openrouter", &["anthropic/claude-opus-4"]),
|
||||
];
|
||||
let models = aggregate_models(&providers).await;
|
||||
let ids: Vec<&str> = models.iter().map(|m| m.model_id.0.as_ref()).collect();
|
||||
assert_eq!(
|
||||
ids,
|
||||
vec!["helexa:helexa/large", "openrouter:anthropic/claude-opus-4"]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn aggregate_models_skips_failing_endpoint() {
|
||||
let providers = vec![
|
||||
ModelProvider::err("flaky", "boom"),
|
||||
ModelProvider::ok("openrouter", &["gpt-9"]),
|
||||
];
|
||||
let models = aggregate_models(&providers).await;
|
||||
let ids: Vec<&str> = models.iter().map(|m| m.model_id.0.as_ref()).collect();
|
||||
// Multi-endpoint case → prefix survives even when one
|
||||
// endpoint dropped out. flaky's models are absent, not
|
||||
// null-filled.
|
||||
assert_eq!(ids, vec!["openrouter:gpt-9"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn maps_known_finish_reasons() {
|
||||
assert!(matches!(
|
||||
|
||||
@@ -142,6 +142,7 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
let agent = Agent::new(&cfg, providers)
|
||||
.await
|
||||
.map_err(|e| agent_client_protocol::util::internal_error(format!("agent: {e:#}")))?;
|
||||
agent.serve(Stdio::new()).await
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user