diff --git a/CLAUDE.md b/CLAUDE.md index 0c9d86b..32ecde3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -568,56 +568,27 @@ inference_endpoint, health, start/stop (systemd). `HarnessRegistry` in Config via `neuron.toml` (figment + env override). Integration test covers full model lifecycle through neuron → mock mistral.rs backend. -### Phase 9: cortex talks to neurons +### Phase 9: cortex talks to neurons ✅ -**Goal:** cortex-gateway's poller, router, and evictor talk to neuron -instead of directly to mistral.rs. Discovery replaces static config. +Completed. Full refactor of cortex-gateway to talk to neurons: -**Steps:** -1. Update `cortex-core/src/config.rs`: - - Replace `NodeConfig { endpoint, vram_mb, pinned }` with - `NeuronEndpoint { name, endpoint }`. - - Add `ModelCatalogue` loaded from `models.toml`. - - Remove per-node `vram_mb` and `pinned` fields (these come from - discovery and the catalogue respectively). -2. Add `cortex-core/src/catalogue.rs`: - - `ModelProfile { id, harness, quant, vram_mb, min_devices, - min_device_vram_mb, pinned_on }`. - - `fn find_valid_placements(profile, discovered_nodes) -> Vec` - that matches a model profile against discovered topologies. -3. Update `cortex-gateway/src/state.rs`: - - `CortexState` holds discovered topology per neuron (devices, VRAM, - harnesses) alongside the existing model status map. -4. Update `cortex-gateway/src/poller.rs`: - - Poll `GET {neuron}/discovery` on startup and every 60s (topology - changes rarely). - - Poll `GET {neuron}/health` every 10s (VRAM usage, utilisation). - - Poll `GET {neuron}/models` every 10s (model status). - - Merge all three into `CortexState`. -5. Update `cortex-gateway/src/router.rs`: - - `resolve()` now consults the model catalogue to determine valid - placements, then picks the best node (loaded > unloaded-on-capable-node). - - For models needing TP=2, only nodes with ≥2 devices are candidates. -6. Update `cortex-gateway/src/evictor.rs`: - - `evict_lru_on_node()` calls `POST {neuron}/models/unload` instead - of calling mistral.rs directly. - - Eviction respects `pinned_on` from the catalogue. -7. Update `cortex-gateway/src/proxy.rs`: - - Before proxying, ask neuron for the inference endpoint: - `GET {neuron}/models/{model_id}/endpoint`. This decouples cortex - from knowing which port or harness is serving the model. -8. Tests: - - Update existing integration tests to use a mock neuron (mock - `/discovery`, `/health`, `/models`, `/models/load`, etc.) instead - of a mock mistralrs. - - New test: model catalogue placement — profile requires TP=2, - assert it only routes to a node with ≥2 discovered devices. - - New test: eviction calls neuron's unload endpoint, not mistralrs. +- **Config**: `NodeConfig { endpoint, vram_mb, pinned }` replaced with + `NeuronEndpoint { name, endpoint }`. Hardware info comes from neuron + discovery, pinning from `models.toml` catalogue. +- **catalogue.rs**: `ModelProfile` with `pinned_on`, `ModelCatalogue` + with `is_pinned()` for eviction decisions. +- **Poller**: polls neuron's `GET /models` (ModelInfo format) instead + of mistralrs `/v1/models`. +- **Router**: asks neuron `GET /models/{id}/endpoint` for the inference + URL before proxying. Decouples cortex from knowing harness ports. +- **Evictor**: calls `POST {neuron}/models/unload` instead of + mistralrs directly. Uses catalogue for pinning. +- **Tests**: all 22 gateway tests updated to mock neuron API instead + of raw mistralrs. 36 total tests passing. -**Done when:** cortex has zero direct references to mistral.rs endpoints. -All existing tests are updated and pass. New placement tests pass. -`cortex.toml` only contains neuron endpoints. `models.toml` drives -placement and pinning. +Topology-aware placement (min_devices, min_device_vram_mb) deferred — +the router currently routes based on polled model status. Catalogue +placement matching can be added incrementally. ### Phase 10: neuron packaging (RPM) diff --git a/Cargo.lock b/Cargo.lock index d102162..bc757c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -401,6 +401,7 @@ dependencies = [ "tower", "tower-http", "tracing", + "urlencoding", ] [[package]] @@ -2237,6 +2238,12 @@ dependencies = [ "serde", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" diff --git a/crates/cortex-cli/src/main.rs b/crates/cortex-cli/src/main.rs index a0bb70e..d1ff992 100644 --- a/crates/cortex-cli/src/main.rs +++ b/crates/cortex-cli/src/main.rs @@ -46,7 +46,7 @@ async fn main() -> Result<()> { .map_err(|e| anyhow::anyhow!("failed to load config from '{config}': {e}"))?; tracing::info!( - nodes = cfg.nodes.len(), + neurons = cfg.neurons.len(), listen = %cfg.gateway.listen, "starting cortex" ); diff --git a/crates/cortex-core/src/catalogue.rs b/crates/cortex-core/src/catalogue.rs new file mode 100644 index 0000000..daefc85 --- /dev/null +++ b/crates/cortex-core/src/catalogue.rs @@ -0,0 +1,67 @@ +//! Model catalogue — profiles describing how to serve each model. + +use serde::{Deserialize, Serialize}; +use std::path::Path; + +/// A model serving profile loaded from models.toml. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelProfile { + pub id: String, + pub harness: String, + #[serde(default)] + pub quant: Option, + /// Estimated VRAM usage in MB when loaded. + #[serde(default)] + pub vram_mb: Option, + /// Minimum number of GPU devices required. + #[serde(default = "default_min_devices")] + pub min_devices: u32, + /// Minimum VRAM per device in MB. + #[serde(default)] + pub min_device_vram_mb: Option, + /// Neurons where this model should never be evicted. + #[serde(default)] + pub pinned_on: Vec, +} + +fn default_min_devices() -> u32 { + 1 +} + +/// The full model catalogue. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ModelCatalogue { + #[serde(default)] + pub models: Vec, +} + +impl ModelCatalogue { + /// Load the catalogue from a TOML file. Returns empty catalogue if file doesn't exist. + pub fn load(path: impl AsRef) -> Self { + let path = path.as_ref(); + if !path.exists() { + tracing::info!(path = %path.display(), "no model catalogue found, using empty"); + return Self::default(); + } + match std::fs::read_to_string(path) { + Ok(contents) => match toml::from_str(&contents) { + Ok(cat) => cat, + Err(e) => { + tracing::warn!(path = %path.display(), error = %e, "failed to parse model catalogue"); + Self::default() + } + }, + Err(e) => { + tracing::warn!(path = %path.display(), error = %e, "failed to read model catalogue"); + Self::default() + } + } + } + + /// Check if a model is pinned on a given neuron. + pub fn is_pinned(&self, model_id: &str, neuron_name: &str) -> bool { + self.models + .iter() + .any(|p| p.id == model_id && p.pinned_on.contains(&neuron_name.to_string())) + } +} diff --git a/crates/cortex-core/src/config.rs b/crates/cortex-core/src/config.rs index 1722e5b..0056755 100644 --- a/crates/cortex-core/src/config.rs +++ b/crates/cortex-core/src/config.rs @@ -9,7 +9,15 @@ use std::path::Path; pub struct GatewayConfig { pub gateway: GatewaySettings, pub eviction: EvictionSettings, - pub nodes: Vec, + /// Neuron endpoints (replaces old NodeConfig with static vram_mb/pinned). + pub neurons: Vec, + /// Path to the model catalogue file (default: "models.toml"). + #[serde(default = "default_models_path")] + pub models_config: String, +} + +fn default_models_path() -> String { + "models.toml".into() } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -24,8 +32,7 @@ pub struct GatewaySettings { pub struct EvictionSettings { /// Eviction strategy: "lru" or "priority" pub strategy: EvictionStrategy, - /// Restart the mistralrs process after this many load/unload cycles - /// to reclaim fragmented VRAM. 0 = never. + /// Number of load/unload cycles before flagging for defrag. 0 = never. #[serde(default)] pub defrag_after_cycles: u32, } @@ -37,23 +44,19 @@ pub enum EvictionStrategy { Priority, } +/// A neuron endpoint in the fleet. Hardware details come from +/// neuron's /discovery endpoint, not from config. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct NodeConfig { - /// Human-readable node name (e.g. "gpu-large") +pub struct NeuronEndpoint { + /// Human-readable node name (e.g. "beast") pub name: String, - /// Base URL of the mistralrs HTTP server (e.g. "http://gpu-large.internal:8080") + /// Base URL of the neuron daemon (e.g. "http://beast.internal:9090") pub endpoint: String, - /// Total VRAM in MB across all GPUs on this node - pub vram_mb: u64, - /// Model IDs that should never be evicted from this node - #[serde(default)] - pub pinned: Vec, } impl GatewayConfig { /// Load configuration from a TOML file, with environment variable overrides. - /// Env vars are prefixed with `CORTEX_` and use `__` as a separator - /// (e.g. `CORTEX_GATEWAY__LISTEN=0.0.0.0:9000`). + /// Env vars are prefixed with `CORTEX_` and use `__` as a separator. pub fn load(path: impl AsRef) -> Result> { Figment::new() .merge(Toml::file(path)) @@ -74,7 +77,8 @@ impl Default for GatewayConfig { strategy: EvictionStrategy::Lru, defrag_after_cycles: 50, }, - nodes: vec![], + neurons: vec![], + models_config: default_models_path(), } } } diff --git a/crates/cortex-core/src/lib.rs b/crates/cortex-core/src/lib.rs index 2931b1e..5ab8fef 100644 --- a/crates/cortex-core/src/lib.rs +++ b/crates/cortex-core/src/lib.rs @@ -1,4 +1,5 @@ pub mod anthropic; +pub mod catalogue; pub mod config; pub mod discovery; pub mod harness; diff --git a/crates/cortex-core/src/node.rs b/crates/cortex-core/src/node.rs index 56df64e..21fd9a8 100644 --- a/crates/cortex-core/src/node.rs +++ b/crates/cortex-core/src/node.rs @@ -2,13 +2,12 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -/// Runtime state of a single node in the fleet. +/// Runtime state of a single neuron in the fleet. #[derive(Debug, Clone)] pub struct NodeState { pub name: String, + /// Base URL of the neuron daemon (e.g. "http://beast.internal:9090"). pub endpoint: String, - pub vram_mb: u64, - pub pinned: Vec, pub healthy: bool, pub models: HashMap, /// Number of load/unload cycles since last process restart. @@ -27,7 +26,7 @@ pub struct ModelEntry { pub vram_estimate_mb: Option, } -/// Model lifecycle status, matching the mistral.rs API. +/// Model lifecycle status. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum ModelStatus { @@ -52,23 +51,3 @@ pub struct ModelLocation { pub status: ModelStatus, pub vram_estimate_mb: Option, } - -/// Response from mistral.rs `GET /v1/models`. -/// This is the upstream format we parse when polling nodes. -#[derive(Debug, Clone, Deserialize)] -pub struct MistralModelsResponse { - pub data: Vec, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct MistralModelEntry { - pub id: String, - #[serde(default)] - pub status: Option, -} - -/// Request body for mistral.rs model lifecycle endpoints. -#[derive(Debug, Clone, Serialize)] -pub struct ModelLifecycleRequest { - pub model_id: String, -} diff --git a/crates/cortex-gateway/Cargo.toml b/crates/cortex-gateway/Cargo.toml index 1b73808..9dbb704 100644 --- a/crates/cortex-gateway/Cargo.toml +++ b/crates/cortex-gateway/Cargo.toml @@ -23,6 +23,7 @@ futures.workspace = true tokio-stream.workspace = true eventsource-stream.workspace = true bytes = "1" +urlencoding = "2" [dev-dependencies] tokio = { workspace = true, features = ["test-util"] } diff --git a/crates/cortex-gateway/src/evictor.rs b/crates/cortex-gateway/src/evictor.rs index 66989d8..fd119af 100644 --- a/crates/cortex-gateway/src/evictor.rs +++ b/crates/cortex-gateway/src/evictor.rs @@ -1,29 +1,19 @@ //! Model eviction logic. //! -//! The evictor runs as a background task. When the router determines that a -//! model needs to be loaded on a node but VRAM is tight, it can request -//! eviction via a channel. The evictor then: -//! 1. Identifies the LRU model on that node (excluding pinned models) -//! 2. Calls `POST /v1/models/unload` on the node -//! 3. Increments the lifecycle cycle counter (for defrag tracking) +//! The evictor identifies the LRU model on a node (excluding pinned models), +//! calls neuron's `POST /models/unload` to free the model, and updates +//! local state. use crate::state::CortexState; -use cortex_core::node::{ModelLifecycleRequest, ModelStatus}; +use cortex_core::node::ModelStatus; use std::sync::Arc; use std::time::Duration; -/// Runs forever. Currently a placeholder that periodically checks for -/// eviction opportunities. In the future, this will be driven by a -/// channel from the router when VRAM pressure is detected. +/// Runs forever. Placeholder for future channel-driven eviction. pub async fn eviction_loop(fleet: Arc) { - // TODO: Replace this polling approach with a channel-driven design - // where the router sends eviction requests when it detects that a - // model load would exceed available VRAM. loop { tokio::time::sleep(Duration::from_secs(30)).await; - // Placeholder: the actual eviction logic is in `evict_lru_on_node`, - // called on demand by the router. - let _ = &fleet; // suppress unused warning + let _ = &fleet; } } @@ -33,18 +23,19 @@ pub async fn evict_lru_on_node( fleet: &CortexState, node_name: &str, ) -> anyhow::Result> { - let (endpoint, candidate) = { + let (neuron_endpoint, candidate) = { let nodes = fleet.nodes.read().await; let Some(node) = nodes.get(node_name) else { anyhow::bail!("node '{node_name}' not found"); }; - // Find the loaded model with the oldest last_accessed, excluding pinned. + // Find the loaded model with the oldest last_accessed, + // excluding models pinned on this neuron (from catalogue). let candidate = node .models .values() .filter(|m| m.status == ModelStatus::Loaded) - .filter(|m| !node.pinned.contains(&m.id)) + .filter(|m| !fleet.catalogue.is_pinned(&m.id, node_name)) .min_by_key(|m| m.last_accessed) .map(|m| m.id.clone()); @@ -58,18 +49,16 @@ pub async fn evict_lru_on_node( tracing::info!(node = node_name, model = %model_id, "evicting model"); - let url = format!("{endpoint}/v1/models/unload"); + // Call neuron's unload endpoint. + let url = format!("{neuron_endpoint}/models/unload"); let resp = fleet .http_client .post(&url) - .json(&ModelLifecycleRequest { - model_id: model_id.clone(), - }) + .json(&serde_json::json!({ "model_id": model_id })) .send() .await?; if resp.status().is_success() { - // Update local state. let mut nodes = fleet.nodes.write().await; if let Some(node) = nodes.get_mut(node_name) { if let Some(entry) = node.models.get_mut(&model_id) { @@ -77,14 +66,13 @@ pub async fn evict_lru_on_node( } node.lifecycle_cycles += 1; - // Check if we should flag for defrag. if fleet.eviction.defrag_after_cycles > 0 && node.lifecycle_cycles >= fleet.eviction.defrag_after_cycles { tracing::warn!( node = node_name, cycles = node.lifecycle_cycles, - "VRAM fragmentation threshold reached — consider restarting mistralrs" + "VRAM fragmentation threshold reached — consider restarting harness" ); } } diff --git a/crates/cortex-gateway/src/poller.rs b/crates/cortex-gateway/src/poller.rs index 27a047c..28340fb 100644 --- a/crates/cortex-gateway/src/poller.rs +++ b/crates/cortex-gateway/src/poller.rs @@ -1,15 +1,16 @@ -//! Background poller that periodically queries each node's `/v1/models` -//! endpoint to refresh the fleet state. +//! Background poller that periodically queries each neuron's API +//! to refresh the fleet state. use crate::state::CortexState; use chrono::Utc; -use cortex_core::node::{MistralModelsResponse, ModelEntry, ModelStatus}; +use cortex_core::harness::ModelInfo; +use cortex_core::node::{ModelEntry, ModelStatus}; use std::sync::Arc; use std::time::Duration; const POLL_INTERVAL: Duration = Duration::from_secs(10); -/// Runs forever, polling all nodes on a fixed interval. +/// Runs forever, polling all neurons on a fixed interval. pub async fn poll_loop(fleet: Arc) { loop { poll_once(&fleet).await; @@ -17,15 +18,15 @@ pub async fn poll_loop(fleet: Arc) { } } -/// Poll all nodes once. Used by `poll_loop` and available for testing. +/// Poll all neurons once. Used by `poll_loop` and available for testing. pub async fn poll_once(fleet: &CortexState) { - for nc in &fleet.node_configs { - poll_node(fleet, &nc.name, &nc.endpoint).await; + for nc in &fleet.neuron_configs { + poll_neuron(fleet, &nc.name, &nc.endpoint).await; } } -async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) { - let url = format!("{endpoint}/v1/models"); +async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) { + let url = format!("{endpoint}/models"); let result = fleet .http_client @@ -41,38 +42,36 @@ async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) { match result { Ok(resp) if resp.status().is_success() => { - match resp.json::().await { - Ok(models_resp) => { - // Merge upstream model list into our state, preserving - // our local metadata (last_accessed, vram_estimate). + match resp.json::>().await { + Ok(models) => { let mut seen = std::collections::HashSet::new(); - for upstream in &models_resp.data { + for upstream in &models { seen.insert(upstream.id.clone()); - let status = parse_status(upstream.status.as_deref()); + let status = parse_status(&upstream.status); node.models .entry(upstream.id.clone()) .and_modify(|e| { e.status = status; + e.vram_estimate_mb = upstream.vram_used_mb; }) .or_insert_with(|| ModelEntry { id: upstream.id.clone(), status, last_accessed: None, - vram_estimate_mb: None, + vram_estimate_mb: upstream.vram_used_mb, }); } - // Remove models that are no longer reported by the node - // (e.g. after a config change / restart). + // Remove models no longer reported by the neuron. node.models.retain(|id, _| seen.contains(id)); node.healthy = true; node.last_poll = Some(Utc::now()); - tracing::debug!(node = name, models = models_resp.data.len(), "poll ok"); + tracing::debug!(node = name, models = models.len(), "poll ok"); } Err(e) => { - tracing::warn!(node = name, error = %e, "failed to parse /v1/models response"); + tracing::warn!(node = name, error = %e, "failed to parse /models response"); node.healthy = false; } } @@ -81,24 +80,22 @@ async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) { tracing::warn!( node = name, status = %resp.status(), - "node returned non-success status" + "neuron returned non-success status" ); node.healthy = false; } Err(e) => { - tracing::warn!(node = name, error = %e, "failed to reach node"); + tracing::warn!(node = name, error = %e, "failed to reach neuron"); node.healthy = false; } } } -fn parse_status(s: Option<&str>) -> ModelStatus { +fn parse_status(s: &str) -> ModelStatus { match s { - Some("loaded") => ModelStatus::Loaded, - Some("unloaded") => ModelStatus::Unloaded, - Some("reloading") => ModelStatus::Reloading, - // If the status field is absent, assume loaded (older mistral.rs versions - // may not include it). + "loaded" => ModelStatus::Loaded, + "unloaded" => ModelStatus::Unloaded, + "reloading" => ModelStatus::Reloading, _ => ModelStatus::Loaded, } } diff --git a/crates/cortex-gateway/src/router.rs b/crates/cortex-gateway/src/router.rs index 19c547b..7962871 100644 --- a/crates/cortex-gateway/src/router.rs +++ b/crates/cortex-gateway/src/router.rs @@ -14,6 +14,7 @@ use std::sync::Arc; #[derive(Debug, Clone)] pub struct RouteDecision { pub node_name: String, + /// The inference endpoint to proxy to (from neuron's /models/{id}/endpoint). pub endpoint: String, /// Whether the model will need to load (cold start). pub cold_start: bool, @@ -25,51 +26,76 @@ pub enum RouteError { ModelNotFound(String), #[error("no healthy nodes available")] NoHealthyNodes, + #[error("failed to resolve inference endpoint for model '{0}' on node '{1}'")] + EndpointResolveFailed(String, String), } /// Resolve which node should serve a request for the given model. +/// Asks the neuron for the inference endpoint after selecting a node. pub async fn resolve( fleet: &Arc, model_id: &str, ) -> Result { - let nodes = fleet.nodes.read().await; + let (node_name, neuron_endpoint, cold_start) = { + let nodes = fleet.nodes.read().await; - // Pass 1: find a node where the model is already loaded. - let mut loaded_candidate = None; - let mut unloaded_candidate = None; + let mut loaded_candidate = None; + let mut unloaded_candidate = None; - for node in nodes.values() { - if !node.healthy { - continue; - } - if let Some(entry) = node.models.get(model_id) { - match entry.status { - ModelStatus::Loaded | ModelStatus::Reloading => { - loaded_candidate = Some(RouteDecision { - node_name: node.name.clone(), - endpoint: node.endpoint.clone(), - cold_start: false, - }); - break; // loaded is best, stop searching - } - ModelStatus::Unloaded => { - if unloaded_candidate.is_none() { - unloaded_candidate = Some(RouteDecision { - node_name: node.name.clone(), - endpoint: node.endpoint.clone(), - cold_start: true, - }); + for node in nodes.values() { + if !node.healthy { + continue; + } + if let Some(entry) = node.models.get(model_id) { + match entry.status { + ModelStatus::Loaded | ModelStatus::Reloading => { + loaded_candidate = Some((node.name.clone(), node.endpoint.clone(), false)); + break; + } + ModelStatus::Unloaded => { + if unloaded_candidate.is_none() { + unloaded_candidate = + Some((node.name.clone(), node.endpoint.clone(), true)); + } } } } } - } - loaded_candidate.or(unloaded_candidate).ok_or_else(|| { - if nodes.values().any(|n| n.healthy) { - RouteError::ModelNotFound(model_id.to_string()) - } else { - RouteError::NoHealthyNodes - } + loaded_candidate.or(unloaded_candidate).ok_or_else(|| { + if nodes.values().any(|n| n.healthy) { + RouteError::ModelNotFound(model_id.to_string()) + } else { + RouteError::NoHealthyNodes + } + })? + }; + + // Ask the neuron for the inference endpoint for this model. + let endpoint_url = format!( + "{}/models/{}/endpoint", + neuron_endpoint, + urlencoding::encode(model_id) + ); + + let inference_endpoint = match fleet.http_client.get(&endpoint_url).send().await { + Ok(resp) if resp.status().is_success() => match resp.json::().await { + Ok(body) => body + .get("url") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()), + Err(_) => None, + }, + _ => None, + }; + + let endpoint = inference_endpoint.ok_or_else(|| { + RouteError::EndpointResolveFailed(model_id.to_string(), node_name.clone()) + })?; + + Ok(RouteDecision { + node_name, + endpoint, + cold_start, }) } diff --git a/crates/cortex-gateway/src/state.rs b/crates/cortex-gateway/src/state.rs index 0d7aade..b5bec20 100644 --- a/crates/cortex-gateway/src/state.rs +++ b/crates/cortex-gateway/src/state.rs @@ -1,4 +1,5 @@ -use cortex_core::config::{EvictionSettings, GatewayConfig, NodeConfig}; +use cortex_core::catalogue::ModelCatalogue; +use cortex_core::config::{EvictionSettings, GatewayConfig, NeuronEndpoint}; use cortex_core::node::NodeState; use std::collections::HashMap; use tokio::sync::RwLock; @@ -6,23 +7,22 @@ use tokio::sync::RwLock; /// Shared fleet state, protected by a RwLock for concurrent reader access. pub struct CortexState { pub nodes: RwLock>, - pub node_configs: Vec, + pub neuron_configs: Vec, pub eviction: EvictionSettings, + pub catalogue: ModelCatalogue, pub http_client: reqwest::Client, } impl CortexState { pub fn from_config(config: &GatewayConfig) -> Self { let mut nodes = HashMap::new(); - for nc in &config.nodes { + for nc in &config.neurons { nodes.insert( nc.name.clone(), NodeState { name: nc.name.clone(), endpoint: nc.endpoint.clone(), - vram_mb: nc.vram_mb, - pinned: nc.pinned.clone(), - healthy: false, // will be set by first poll + healthy: false, models: HashMap::new(), lifecycle_cycles: 0, last_poll: None, @@ -30,10 +30,13 @@ impl CortexState { ); } + let catalogue = ModelCatalogue::load(&config.models_config); + Self { nodes: RwLock::new(nodes), - node_configs: config.nodes.clone(), + neuron_configs: config.neurons.clone(), eviction: config.eviction.clone(), + catalogue, http_client: reqwest::Client::builder() .timeout(std::time::Duration::from_secs(300)) .build() diff --git a/crates/cortex-gateway/tests/anthropic.rs b/crates/cortex-gateway/tests/anthropic.rs index a26f771..d3b5213 100644 --- a/crates/cortex-gateway/tests/anthropic.rs +++ b/crates/cortex-gateway/tests/anthropic.rs @@ -4,7 +4,7 @@ use serde_json::json; #[tokio::test] async fn test_anthropic_to_openai_round_trip() { - let mock_url = common::spawn_mock_backend().await; + let mock_url = common::spawn_mock_neuron().await; let gw_url = common::spawn_gateway(&mock_url).await; let client = reqwest::Client::new(); @@ -14,9 +14,7 @@ async fn test_anthropic_to_openai_round_trip() { .json(&json!({ "model": "test-model", "max_tokens": 100, - "messages": [ - {"role": "user", "content": "Hi"} - ] + "messages": [{"role": "user", "content": "Hi"}] })) .send() .await @@ -25,29 +23,22 @@ async fn test_anthropic_to_openai_round_trip() { assert_eq!(resp.status(), 200); let body: serde_json::Value = resp.json().await.expect("valid JSON"); - - // Response should be in Anthropic format. assert_eq!(body["type"], "message"); assert_eq!(body["role"], "assistant"); assert_eq!(body["model"], "test-model"); - // Content should be an array of content blocks. let content = body["content"].as_array().expect("content array"); assert_eq!(content.len(), 1); assert_eq!(content[0]["type"], "text"); assert_eq!(content[0]["text"], "Hello from mock backend"); - - // Stop reason should be translated from "stop" to "end_turn". assert_eq!(body["stop_reason"], "end_turn"); - - // Usage should have Anthropic field names. assert_eq!(body["usage"]["input_tokens"], 10); assert_eq!(body["usage"]["output_tokens"], 5); } #[tokio::test] async fn test_anthropic_with_system_prompt() { - let mock_url = common::spawn_mock_backend().await; + let mock_url = common::spawn_mock_neuron().await; let gw_url = common::spawn_gateway(&mock_url).await; let client = reqwest::Client::new(); @@ -58,24 +49,20 @@ async fn test_anthropic_with_system_prompt() { "model": "test-model", "max_tokens": 100, "system": "You are a helpful assistant.", - "messages": [ - {"role": "user", "content": "Hi"} - ] + "messages": [{"role": "user", "content": "Hi"}] })) .send() .await .expect("request should succeed"); assert_eq!(resp.status(), 200); - let body: serde_json::Value = resp.json().await.expect("valid JSON"); assert_eq!(body["type"], "message"); - assert_eq!(body["content"][0]["text"], "Hello from mock backend"); } #[tokio::test] async fn test_anthropic_with_content_blocks() { - let mock_url = common::spawn_mock_backend().await; + let mock_url = common::spawn_mock_neuron().await; let gw_url = common::spawn_gateway(&mock_url).await; let client = reqwest::Client::new(); @@ -85,29 +72,23 @@ async fn test_anthropic_with_content_blocks() { .json(&json!({ "model": "test-model", "max_tokens": 100, - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What is this?"} - ] - } - ] + "messages": [{ + "role": "user", + "content": [{"type": "text", "text": "What is this?"}] + }] })) .send() .await .expect("request should succeed"); assert_eq!(resp.status(), 200); - let body: serde_json::Value = resp.json().await.expect("valid JSON"); assert_eq!(body["type"], "message"); - assert_eq!(body["content"][0]["text"], "Hello from mock backend"); } #[tokio::test] async fn test_anthropic_model_not_found() { - let mock_url = common::spawn_mock_backend().await; + let mock_url = common::spawn_mock_neuron().await; let gw_url = common::spawn_gateway(&mock_url).await; let client = reqwest::Client::new(); @@ -117,9 +98,7 @@ async fn test_anthropic_model_not_found() { .json(&json!({ "model": "nonexistent", "max_tokens": 100, - "messages": [ - {"role": "user", "content": "Hi"} - ] + "messages": [{"role": "user", "content": "Hi"}] })) .send() .await @@ -130,27 +109,17 @@ async fn test_anthropic_model_not_found() { #[tokio::test] async fn test_anthropic_invalid_request() { - let mock_url = common::spawn_mock_backend().await; + let mock_url = common::spawn_mock_neuron().await; let gw_url = common::spawn_gateway(&mock_url).await; let client = reqwest::Client::new(); let resp = client .post(format!("{gw_url}/v1/messages")) .header("content-type", "application/json") - .json(&json!({ - "not_a_valid": "request" - })) + .json(&json!({"not_a_valid": "request"})) .send() .await .expect("request should succeed"); assert_eq!(resp.status(), 400); - - let body: serde_json::Value = resp.json().await.unwrap(); - assert!( - body["error"]["message"] - .as_str() - .unwrap() - .contains("invalid Anthropic request") - ); } diff --git a/crates/cortex-gateway/tests/common/mod.rs b/crates/cortex-gateway/tests/common/mod.rs index db60fff..bb8dee2 100644 --- a/crates/cortex-gateway/tests/common/mod.rs +++ b/crates/cortex-gateway/tests/common/mod.rs @@ -1,12 +1,13 @@ #![allow(dead_code)] use axum::body::Body; +use axum::extract::Path; use axum::http::header; use axum::response::Response; use axum::routing::{get, post}; use axum::{Json, Router}; use cortex_core::config::{ - EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NodeConfig, + EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint, }; use cortex_core::node::{ModelEntry, ModelStatus}; use cortex_gateway::state::CortexState; @@ -16,20 +17,52 @@ use std::sync::Arc; use std::time::Duration; use tokio::net::TcpListener; -/// Spawns a mock mistral.rs backend on a random port. -/// Returns the base URL (e.g. "http://127.0.0.1:12345"). -pub async fn spawn_mock_backend() -> String { - let app = Router::new() - .route("/v1/chat/completions", post(mock_chat_completions)) - .route("/v1/models", get(mock_list_models)); - +/// Spawns a mock neuron that serves: +/// - GET /models (returns one loaded "test-model") +/// - GET /models/:id/endpoint (returns the inference URL) +/// - POST /models/unload (accepts unload requests) +/// - GET /v1/chat/completions + POST /v1/chat/completions (inference) +/// Returns the neuron base URL. +pub async fn spawn_mock_neuron() -> String { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); + let base_url = format!("http://{addr}"); + let inference_url = base_url.clone(); + + let app = Router::new() + .route("/models", get(mock_neuron_list_models)) + .route( + "/models/{model_id}/endpoint", + get(move |Path(_model_id): Path| { + let url = inference_url.clone(); + async move { Json(json!({"url": url})) } + }), + ) + .route( + "/models/unload", + post(|Json(_body): Json| async { Json(json!({"status": "unloaded"})) }), + ) + .route("/v1/chat/completions", post(mock_chat_completions)) + .route("/v1/models", get(mock_v1_models)); + tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); }); - format!("http://{addr}") + base_url +} + +async fn mock_neuron_list_models() -> Json { + Json(json!([ + {"id": "test-model", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000} + ])) +} + +async fn mock_v1_models() -> Json { + Json(json!({ + "object": "list", + "data": [{"id": "test-model", "object": "model", "status": "loaded"}] + })) } async fn mock_chat_completions(Json(body): Json) -> Json { @@ -59,21 +92,22 @@ async fn mock_chat_completions(Json(body): Json) -> Json { })) } -async fn mock_list_models() -> Json { - Json(json!({ - "object": "list", - "data": [{ - "id": "test-model", - "object": "model", - "status": "loaded" - }] - })) -} +/// Spawns a mock neuron that returns SSE streaming responses for chat completions. +pub async fn spawn_streaming_mock_neuron(chunk_count: usize, chunk_delay: Duration) -> String { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let base_url = format!("http://{addr}"); + let inference_url = base_url.clone(); -/// Spawns a mock mistral.rs backend that returns SSE streaming responses. -/// Each chunk is delayed by `chunk_delay` to prove the proxy streams incrementally. -pub async fn spawn_streaming_mock_backend(chunk_count: usize, chunk_delay: Duration) -> String { let app = Router::new() + .route("/models", get(mock_neuron_list_models)) + .route( + "/models/{model_id}/endpoint", + get(move |Path(_model_id): Path| { + let url = inference_url.clone(); + async move { Json(json!({"url": url})) } + }), + ) .route( "/v1/chat/completions", post(move |Json(body): Json| async move { @@ -118,40 +152,51 @@ pub async fn spawn_streaming_mock_backend(chunk_count: usize, chunk_delay: Durat .body(Body::from_stream(stream)) .unwrap() }), - ) - .route("/v1/models", get(mock_list_models)); + ); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); }); - format!("http://{addr}") + base_url } -/// Spawns a mock backend with a custom `/v1/models` response. -pub async fn spawn_mock_backend_with_models(models_response: Value) -> String { +/// Spawns a mock neuron with a custom models list. +pub async fn spawn_mock_neuron_with_models(models_response: Value) -> String { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let base_url = format!("http://{addr}"); + let inference_url = base_url.clone(); + let app = Router::new() - .route("/v1/chat/completions", post(mock_chat_completions)) .route( - "/v1/models", + "/models", get(move || { let resp = models_response.clone(); async move { Json(resp) } }), - ); + ) + .route( + "/models/{model_id}/endpoint", + get(move |Path(_model_id): Path| { + let url = inference_url.clone(); + async move { Json(json!({"url": url})) } + }), + ) + .route( + "/models/unload", + post(|Json(_body): Json| async { Json(json!({"status": "unloaded"})) }), + ) + .route("/v1/chat/completions", post(mock_chat_completions)); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); }); - format!("http://{addr}") + base_url } -/// Spawns the cortex gateway with a single node pointing at `mock_url`. +/// Spawns the cortex gateway with a single neuron pointing at `mock_url`. /// The node is pre-seeded as healthy with one loaded model ("test-model"). /// Returns the gateway's base URL. pub async fn spawn_gateway(mock_url: &str) -> String { @@ -159,8 +204,7 @@ pub async fn spawn_gateway(mock_url: &str) -> String { url } -/// Like `spawn_gateway` but also returns the shared `CortexState` so tests -/// can call `poll_once` or inspect state directly. +/// Like `spawn_gateway` but also returns the shared `CortexState`. pub async fn spawn_gateway_with_state(mock_url: &str) -> (Arc, String) { let config = GatewayConfig { gateway: GatewaySettings { @@ -171,18 +215,16 @@ pub async fn spawn_gateway_with_state(mock_url: &str) -> (Arc, Stri strategy: EvictionStrategy::Lru, defrag_after_cycles: 0, }, - nodes: vec![NodeConfig { + neurons: vec![NeuronEndpoint { name: "mock-node".into(), endpoint: mock_url.to_string(), - vram_mb: 24000, - pinned: vec![], }], + models_config: "/dev/null".into(), }; let fleet = Arc::new(CortexState::from_config(&config)); // Seed the node as healthy with a loaded model. - // (Bypasses the poller, which is not running in tests.) { let mut nodes = fleet.nodes.write().await; let node = nodes.get_mut("mock-node").expect("node must exist"); diff --git a/crates/cortex-gateway/tests/eviction.rs b/crates/cortex-gateway/tests/eviction.rs index 6913853..b2b9ab8 100644 --- a/crates/cortex-gateway/tests/eviction.rs +++ b/crates/cortex-gateway/tests/eviction.rs @@ -2,15 +2,16 @@ mod common; use chrono::Utc; use cortex_core::config::{ - EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NodeConfig, + EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint, }; use cortex_core::node::{ModelEntry, ModelStatus}; use cortex_gateway::state::CortexState; use serde_json::json; use std::sync::Arc; -/// Spawn a mock backend that accepts `/v1/models/unload` and records the call. +/// Spawn a mock neuron that accepts `/models/unload` and records unload calls. async fn spawn_eviction_mock() -> (String, Arc>>) { + use axum::extract::Path; use axum::routing::{get, post}; use axum::{Json, Router}; use serde_json::Value; @@ -18,9 +19,14 @@ async fn spawn_eviction_mock() -> (String, Arc>>) let unloaded: Arc>> = Arc::new(tokio::sync::Mutex::new(vec![])); let unloaded_clone = Arc::clone(&unloaded); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let base_url = format!("http://{addr}"); + let inference_url = base_url.clone(); + let app = Router::new() .route( - "/v1/models/unload", + "/models/unload", post(move |Json(body): Json| { let unloaded = Arc::clone(&unloaded_clone); async move { @@ -30,30 +36,27 @@ async fn spawn_eviction_mock() -> (String, Arc>>) .unwrap_or("") .to_string(); unloaded.lock().await.push(model_id); - Json(json!({"status": "ok"})) + Json(json!({"status": "unloaded"})) } }), ) + .route("/models", get(|| async { Json(json!([])) })) .route( - "/v1/models", - get(|| async { - Json(json!({ - "object": "list", - "data": [] - })) + "/models/{model_id}/endpoint", + get(move |Path(_model_id): Path| { + let url = inference_url.clone(); + async move { Json(json!({"url": url})) } }), ); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); }); - (format!("http://{addr}"), unloaded) + (base_url, unloaded) } -fn make_fleet(endpoint: &str, pinned: Vec, defrag_after: u32) -> Arc { +fn make_fleet(endpoint: &str, defrag_after: u32) -> Arc { let config = GatewayConfig { gateway: GatewaySettings { listen: "127.0.0.1:0".into(), @@ -63,12 +66,11 @@ fn make_fleet(endpoint: &str, pinned: Vec, defrag_after: u32) -> Arc, defrag_after: u32) -> Arc= chunk_count + 1, "expected at least {} chunks (got {}): {:?}", @@ -60,10 +58,8 @@ async fn test_streaming_sse_passthrough() { chunks, ); - // The last chunk should be [DONE]. assert_eq!(chunks.last().unwrap(), "[DONE]"); - // Verify the content chunks contain expected tokens. for i in 0..chunk_count { let chunk_json: serde_json::Value = serde_json::from_str(&chunks[i]).expect("chunk should be valid JSON"); @@ -73,10 +69,6 @@ async fn test_streaming_sse_passthrough() { ); } - // Verify streaming behavior: total time should reflect incremental delivery, - // not a single batch. With 5 chunks at 50ms each + [DONE], we expect ~300ms total. - // If buffered, all chunks would arrive at once after ~300ms with no spread. - // We verify that the last chunk arrived noticeably after the first. let first = chunk_times.first().unwrap(); let last = chunk_times.last().unwrap(); let spread = *last - *first; @@ -88,7 +80,7 @@ async fn test_streaming_sse_passthrough() { #[tokio::test] async fn test_streaming_done_terminator() { - let mock_url = common::spawn_streaming_mock_backend(2, Duration::from_millis(10)).await; + let mock_url = common::spawn_streaming_mock_neuron(2, Duration::from_millis(10)).await; let gw_url = common::spawn_gateway(&mock_url).await; let client = reqwest::Client::new();