diff --git a/crates/cortex-core/src/catalogue.rs b/crates/cortex-core/src/catalogue.rs index 1656e4f..ac0e028 100644 --- a/crates/cortex-core/src/catalogue.rs +++ b/crates/cortex-core/src/catalogue.rs @@ -2,6 +2,7 @@ use crate::discovery::DeviceInfo; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::path::Path; /// A model serving profile loaded from models.toml. @@ -34,6 +35,14 @@ fn default_min_devices() -> u32 { pub struct ModelCatalogue { #[serde(default)] pub models: Vec<ModelProfile>, + /// Tier aliases — clients can send a request with `model: "helexa/small"` + /// and the gateway transparently rewrites + routes to the concrete + /// model id this maps to. Lets operators define latency/quality + /// tiers (`small`/`balanced`/`large`, `fast`/`thinking`, etc.) + /// without imposing knowledge of specific model ids on clients. + /// Loaded from the `[aliases]` table in models.toml. + #[serde(default)] + pub aliases: HashMap<String, String>, } impl ModelCatalogue { @@ -70,6 +79,13 @@ impl ModelCatalogue { pub fn get(&self, model_id: &str) -> Option<&ModelProfile> { self.models.iter().find(|p| p.id == model_id) } + + /// Resolve an alias to its concrete model id. Returns `id` verbatim + /// when it isn't an alias. Aliases never chain — operator config + /// is treated as flat — so this is a single lookup. 
+ pub fn resolve_alias<'a>(&'a self, id: &'a str) -> &'a str { + self.aliases.get(id).map(String::as_str).unwrap_or(id) + } } impl ModelProfile { @@ -164,4 +180,32 @@ mod tests { let devices = [device(0, 1_000), device(1, 1_000)]; assert!(p.is_feasible_on("anywhere", &devices)); } + + #[test] + fn resolve_alias_returns_target_when_alias_present() { + let mut cat = ModelCatalogue::default(); + cat.aliases + .insert("helexa/small".into(), "Qwen/Qwen3-1.7B".into()); + assert_eq!(cat.resolve_alias("helexa/small"), "Qwen/Qwen3-1.7B"); + } + + #[test] + fn resolve_alias_passes_through_when_not_an_alias() { + let mut cat = ModelCatalogue::default(); + cat.aliases + .insert("helexa/small".into(), "Qwen/Qwen3-1.7B".into()); + assert_eq!(cat.resolve_alias("Qwen/Qwen3-8B"), "Qwen/Qwen3-8B"); + } + + #[test] + fn aliases_table_round_trips_through_toml() { + let src = r#" +[aliases] +"helexa/small" = "Qwen/Qwen3-1.7B" +"helexa/large" = "Qwen/Qwen3.6-27B" +"#; + let cat: ModelCatalogue = toml::from_str(src).expect("parse aliases table"); + assert_eq!(cat.resolve_alias("helexa/small"), "Qwen/Qwen3-1.7B"); + assert_eq!(cat.resolve_alias("helexa/large"), "Qwen/Qwen3.6-27B"); + } } diff --git a/crates/cortex-gateway/src/handlers.rs b/crates/cortex-gateway/src/handlers.rs index daf799f..a6df990 100644 --- a/crates/cortex-gateway/src/handlers.rs +++ b/crates/cortex-gateway/src/handlers.rs @@ -60,15 +60,16 @@ async fn chat_completions( } }; - touch_model(&fleet, &route.node_name, &model_id).await; + touch_model(&fleet, &route.node_name, &route.resolved_model_id).await; + let body = rewrite_model_in_body(body, &route.resolved_model_id); proxy_with_metrics( &fleet, &route, "/v1/chat/completions", headers, body, - &model_id, + &route.resolved_model_id, ) .await } @@ -107,9 +108,18 @@ async fn completions( } }; - touch_model(&fleet, &route.node_name, &model_id).await; + touch_model(&fleet, &route.node_name, &route.resolved_model_id).await; - proxy_with_metrics(&fleet, &route, 
"/v1/completions", headers, body, &model_id).await + let body = rewrite_model_in_body(body, &route.resolved_model_id); + proxy_with_metrics( + &fleet, + &route, + "/v1/completions", + headers, + body, + &route.resolved_model_id, + ) + .await } /// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back. @@ -166,10 +176,15 @@ async fn anthropic_messages( } }; - touch_model(&fleet, &route.node_name, &model_id).await; + touch_model(&fleet, &route.node_name, &route.resolved_model_id).await; + + // Swap the alias for the concrete id in the translated body so + // neuron's harness sees a model name that matches what it has + // loaded. + let openai_body = rewrite_model_in_body(openai_body, &route.resolved_model_id); let labels = [ - ("model", model_id.clone()), + ("model", route.resolved_model_id.clone()), ("node", route.node_name.clone()), ]; metrics::counter!("cortex_requests_total", &labels).increment(1); @@ -434,6 +449,35 @@ async fn list_models(State(fleet): State>) -> Json { } } + // Pass 4: surface aliases as their own entries pointing at the + // same locations as the target id, so a client browsing /v1/models + // sees "helexa/small" / "helexa/balanced" / "helexa/large" (or + // whatever the operator defined) and can request inference + // against them directly. Aliases that point at unknown targets + // are skipped — surfacing a dead alias would be misleading. 
+ for (alias, target) in &catalogue.aliases { + let Some(target_entry) = entries.get(target).cloned() else { + tracing::warn!( + alias = alias, + target = target, + "alias points at a model not present in catalogue or fleet; skipping" + ); + continue; + }; + entries.insert( + alias.clone(), + CortexModelEntry { + id: alias.clone(), + object: "model".into(), + created: now, + owned_by: "helexa".into(), + loaded: target_entry.loaded, + feasible_on: target_entry.feasible_on, + locations: target_entry.locations, + }, + ); + } + let data: Vec<Value> = entries.values().map(|e| json!(e)).collect(); Json(json!({ "object": "list", @@ -512,6 +556,38 @@ fn extract_model(body: &[u8]) -> Option<String> { v.get("model")?.as_str().map(|s| s.to_string()) } +/// Rewrite the `model` field of an OpenAI-style JSON request body to +/// the resolved concrete id (the body is re-serialized when rewritten). +/// Returns the original bytes when `model` already equals `new_model`, +/// is absent or is not a string, or when the body fails to parse. The +/// caller has already extracted `model` via `extract_model`, so a parse +/// failure is not expected; we'd rather proxy it unchanged than 500. +/// +/// Needed because neuron rejects requests whose `model` field doesn't +/// match a loaded model, so a client that sends `model: "helexa/small"` +/// would hit a 404 at the harness unless we swap it for the concrete +/// id the alias resolved to. 
+fn rewrite_model_in_body(body: Bytes, new_model: &str) -> Bytes { + let Ok(mut v) = serde_json::from_slice::<Value>(&body) else { + return body; + }; + let needs_rewrite = v + .get("model") + .and_then(|m| m.as_str()) + .map(|m| m != new_model) + .unwrap_or(false); + if !needs_rewrite { + return body; + } + if let Value::Object(obj) = &mut v { + obj.insert("model".into(), Value::String(new_model.to_string())); + } + match serde_json::to_vec(&v) { + Ok(bytes) => Bytes::from(bytes), + Err(_) => body, + } +} + fn error_response(status: u16, message: &str) -> Response { let code = axum::http::StatusCode::from_u16(status) .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR); diff --git a/crates/cortex-gateway/src/router.rs b/crates/cortex-gateway/src/router.rs index 5d4db72..5e59757 100644 --- a/crates/cortex-gateway/src/router.rs +++ b/crates/cortex-gateway/src/router.rs @@ -29,6 +29,13 @@ pub struct RouteDecision { /// when we just triggered an explicit cold-load via the catalogue /// path. pub cold_start: bool, + /// The concrete model id we actually routed to. Equal to the + /// caller's requested id unless an alias was resolved (e.g. caller + /// asked for `helexa/small`, this carries `Qwen/Qwen3-1.7B`). The + /// handler uses this to rewrite the request body's `model` field + /// before proxying — neurons reject requests where the body's + /// model name doesn't match a loaded model. + pub resolved_model_id: String, } #[derive(Debug, thiserror::Error)] @@ -55,8 +62,20 @@ pub enum RouteError { /// Asks the neuron for the inference endpoint after selecting a node. pub async fn resolve( fleet: &Arc<CortexState>, - model_id: &str, + requested_model_id: &str, ) -> Result<RouteDecision, RouteError> { + // Alias resolution first — swap `helexa/small` (etc.) for the + // concrete id before any node lookups so the rest of routing, + // loading, and metrics deal in concrete ids only. `resolve_alias` + // returns the input verbatim when it isn't an alias. 
+ let model_id = fleet.catalogue.resolve_alias(requested_model_id); + if model_id != requested_model_id { + tracing::debug!( + requested = requested_model_id, + resolved = model_id, + "alias resolved" + ); + } // Snapshot loaded / unloaded state from the poller cache. let (loaded_route, unloaded_route, any_healthy) = { let nodes = fleet.nodes.read().await; @@ -326,6 +345,7 @@ async fn finish( node_name: node_name.to_string(), endpoint, cold_start, + resolved_model_id: model_id.to_string(), }) } diff --git a/crates/cortex-gateway/tests/aliases.rs b/crates/cortex-gateway/tests/aliases.rs new file mode 100644 index 0000000..4bf5a40 --- /dev/null +++ b/crates/cortex-gateway/tests/aliases.rs @@ -0,0 +1,265 @@ +//! Alias resolution: a client request with `model: "helexa/small"` +//! routes to the concrete model id (e.g. `Qwen/Qwen3-1.7B`), with the +//! proxied request body rewritten so the upstream neuron sees a model +//! name that matches its loaded handle. + +mod common; + +use cortex_core::config::{ + EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint, +}; +use cortex_core::node::{ModelEntry, ModelStatus}; +use cortex_gateway::state::CortexState; +use serde_json::json; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::net::TcpListener; + +/// Write a `models.toml` holding one `[aliases]` entry to a unique file +/// under `std::env::temp_dir()`. The pid plus a nanosecond timestamp in +/// the file name make a collision unlikely without pulling in tempfile. +/// Returns the path; this helper never deletes the file afterwards. 
+fn write_models_toml(alias: &str, target: &str) -> PathBuf { + let contents = format!( + r#" +[aliases] +"{alias}" = "{target}" +"# + ); + let mut path = std::env::temp_dir(); + let pid = std::process::id(); + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + path.push(format!("cortex-test-models-{pid}-{now}.toml")); + std::fs::write(&path, contents).expect("write temp models.toml"); + path +} + +#[tokio::test] +async fn test_alias_resolves_in_chat_completions() { + let mock_url = common::spawn_mock_neuron().await; + let models_path = write_models_toml("helexa/small", "test-model"); + + let config = GatewayConfig { + gateway: GatewaySettings { + listen: "127.0.0.1:0".into(), + metrics_listen: "127.0.0.1:0".into(), + }, + eviction: EvictionSettings { + strategy: EvictionStrategy::Lru, + defrag_after_cycles: 0, + }, + neurons: vec![NeuronEndpoint { + name: "mock-node".into(), + endpoint: mock_url, + }], + models_config: models_path.to_string_lossy().to_string(), + }; + + let fleet = Arc::new(CortexState::from_config(&config)); + + // Seed the node as healthy with the concrete model loaded under + // the target id. The poller doesn't run in this test; we just + // populate state manually. + { + let mut nodes = fleet.nodes.write().await; + let node = nodes.get_mut("mock-node").expect("node must exist"); + node.healthy = true; + node.models.insert( + "test-model".into(), + ModelEntry { + id: "test-model".into(), + status: ModelStatus::Loaded, + last_accessed: None, + vram_estimate_mb: None, + }, + ); + } + + // Sanity: the catalogue actually picked up the alias. + assert_eq!( + fleet.catalogue.resolve_alias("helexa/small"), + "test-model", + "alias should resolve to target id" + ); + + // Spawn the gateway against this fleet. 
+ let app = cortex_gateway::build_app(Arc::clone(&fleet)); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let gateway_addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + let gateway_url = format!("http://{gateway_addr}"); + + // Send a chat completion against the alias. The mock backend + // echoes back the `model` field it received — so a body whose + // model wasn't rewritten would come back as "helexa/small", and a + // properly-rewritten one as "test-model". + let client = reqwest::Client::new(); + let resp = client + .post(format!("{gateway_url}/v1/chat/completions")) + .json(&json!({ + "model": "helexa/small", + "messages": [{"role": "user", "content": "hi"}], + })) + .send() + .await + .expect("gateway should respond"); + + assert!(resp.status().is_success(), "gateway returned non-2xx"); + let body: serde_json::Value = resp.json().await.expect("response is JSON"); + assert_eq!( + body.get("model").and_then(|m| m.as_str()), + Some("test-model"), + "mock backend should have seen the resolved model id, not the alias" + ); +} + +#[tokio::test] +async fn test_aliases_surface_in_v1_models() { + let mock_url = common::spawn_mock_neuron().await; + let models_path = write_models_toml("helexa/small", "test-model"); + + let config = GatewayConfig { + gateway: GatewaySettings { + listen: "127.0.0.1:0".into(), + metrics_listen: "127.0.0.1:0".into(), + }, + eviction: EvictionSettings { + strategy: EvictionStrategy::Lru, + defrag_after_cycles: 0, + }, + neurons: vec![NeuronEndpoint { + name: "mock-node".into(), + endpoint: mock_url, + }], + models_config: models_path.to_string_lossy().to_string(), + }; + + let fleet = Arc::new(CortexState::from_config(&config)); + + // Seed the target as loaded so the alias's mirrored entry shows + // loaded=true. 
+ { + let mut nodes = fleet.nodes.write().await; + let node = nodes.get_mut("mock-node").expect("node must exist"); + node.healthy = true; + node.models.insert( + "test-model".into(), + ModelEntry { + id: "test-model".into(), + status: ModelStatus::Loaded, + last_accessed: None, + vram_estimate_mb: Some(2000), + }, + ); + } + + let app = cortex_gateway::build_app(Arc::clone(&fleet)); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let gateway_addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + let gateway_url = format!("http://{gateway_addr}"); + + let resp = reqwest::get(format!("{gateway_url}/v1/models")) + .await + .expect("gateway should respond"); + let body: serde_json::Value = resp.json().await.unwrap(); + let entries = body + .get("data") + .and_then(|d| d.as_array()) + .expect("data array"); + + // Both the alias and the target should be present. + let ids: Vec<&str> = entries + .iter() + .filter_map(|e| e.get("id").and_then(|v| v.as_str())) + .collect(); + assert!(ids.contains(&"test-model"), "target should be listed"); + assert!(ids.contains(&"helexa/small"), "alias should be listed"); + + // The alias's `loaded` flag and locations should mirror the target. + let alias_entry = entries + .iter() + .find(|e| e.get("id").and_then(|v| v.as_str()) == Some("helexa/small")) + .expect("alias entry"); + assert_eq!(alias_entry.get("loaded"), Some(&json!(true))); + let locations = alias_entry + .get("locations") + .and_then(|l| l.as_array()) + .expect("locations array"); + assert_eq!(locations.len(), 1); + assert_eq!( + locations[0].get("node").and_then(|n| n.as_str()), + Some("mock-node") + ); +} + +#[tokio::test] +async fn test_alias_falls_through_for_unmapped_model() { + // Catalogue has an alias for some-other-thing but the request + // model "test-model" isn't an alias; resolution should be a no-op. 
+ let mock_url = common::spawn_mock_neuron().await; + let models_path = write_models_toml("helexa/large", "definitely-not-loaded"); + + let config = GatewayConfig { + gateway: GatewaySettings { + listen: "127.0.0.1:0".into(), + metrics_listen: "127.0.0.1:0".into(), + }, + eviction: EvictionSettings { + strategy: EvictionStrategy::Lru, + defrag_after_cycles: 0, + }, + neurons: vec![NeuronEndpoint { + name: "mock-node".into(), + endpoint: mock_url, + }], + models_config: models_path.to_string_lossy().to_string(), + }; + + let fleet = Arc::new(CortexState::from_config(&config)); + { + let mut nodes = fleet.nodes.write().await; + let node = nodes.get_mut("mock-node").expect("node must exist"); + node.healthy = true; + node.models.insert( + "test-model".into(), + ModelEntry { + id: "test-model".into(), + status: ModelStatus::Loaded, + last_accessed: None, + vram_estimate_mb: None, + }, + ); + } + + let app = cortex_gateway::build_app(Arc::clone(&fleet)); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let gateway_addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + let gateway_url = format!("http://{gateway_addr}"); + + let resp = reqwest::Client::new() + .post(format!("{gateway_url}/v1/chat/completions")) + .json(&json!({ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + })) + .send() + .await + .unwrap(); + assert!(resp.status().is_success()); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!( + body.get("model").and_then(|m| m.as_str()), + Some("test-model") + ); +} diff --git a/models.example.toml b/models.example.toml index 5f73c1f..12da8a5 100644 --- a/models.example.toml +++ b/models.example.toml @@ -48,3 +48,17 @@ quant = "Q4_K_M" vram_mb = 500 min_devices = 1 min_device_vram_mb = 4000 + +# -- Tier aliases ------------------------------------------------------------ +# Optional. 
Clients can request inference against an alias (e.g. +# `model: "helexa/small"` in /v1/chat/completions) and cortex routes to +# the concrete model id below, rewriting the body's model field so that +# neuron sees a name matching its loaded handle. Targets must be concrete +# model ids, because aliases never chain. /v1/models lists each alias next +# to its target, skipping aliases whose target is not in the catalogue or +# fleet. Operators can swap targets here without changing client code. +# +# [aliases] +# "helexa/small" = "Qwen/Qwen3-1.7B" +# "helexa/balanced" = "Qwen/Qwen3-8B" +# "helexa/large" = "Qwen/Qwen3.6-27B"