refactor: cortex talks to neurons instead of mistral.rs directly
All checks were successful
CI / Format, lint, build, test (push) Successful in 2m46s
CI / Build SRPM (push) Has been skipped
CI / Publish to COPR (push) Has been skipped

Replace NodeConfig (static vram_mb, pinned) with NeuronEndpoint.
Hardware discovery and model pinning now come from neuron API and
models.toml catalogue respectively.

- config.rs: nodes -> neurons, add models_config path
- catalogue.rs: ModelProfile with pinned_on, ModelCatalogue
- poller.rs: poll neuron GET /models (ModelInfo format)
- router.rs: resolve inference endpoint via neuron GET /models/{id}/endpoint
- evictor.rs: call neuron POST /models/unload
- node.rs: remove vram_mb, pinned fields (come from discovery/catalogue)
- All 22 gateway tests updated to mock neuron API
- Remove MistralModelsResponse, ModelLifecycleRequest (no longer needed)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-15 14:42:52 +03:00
parent 26e5e7ead8
commit e42e8ee81f
19 changed files with 385 additions and 437 deletions

View File

@@ -1,29 +1,19 @@
//! Model eviction logic.
//!
//! The evictor runs as a background task. When the router determines that a
//! model needs to be loaded on a node but VRAM is tight, it can request
//! eviction via a channel. The evictor then:
//! 1. Identifies the LRU model on that node (excluding pinned models)
//! 2. Calls `POST /v1/models/unload` on the node
//! 3. Increments the lifecycle cycle counter (for defrag tracking)
//! The evictor identifies the LRU model on a node (excluding pinned models),
//! calls neuron's `POST /models/unload` to free the model, and updates
//! local state.
use crate::state::CortexState;
use cortex_core::node::{ModelLifecycleRequest, ModelStatus};
use cortex_core::node::ModelStatus;
use std::sync::Arc;
use std::time::Duration;
/// Background eviction task. Runs forever and never returns.
///
/// This is currently a placeholder: it wakes every 30 seconds and performs
/// no work. The intent is to replace it with a channel-driven design in
/// which eviction requests are delivered to this task instead of it polling
/// on a timer.
///
/// `fleet` is the shared gateway state; it is only touched here so the
/// parameter does not trigger an unused-variable warning.
pub async fn eviction_loop(fleet: Arc<CortexState>) {
    loop {
        tokio::time::sleep(Duration::from_secs(30)).await;
        // No periodic work yet; reference `fleet` to keep the binding used.
        let _ = &fleet;
    }
}
@@ -33,18 +23,19 @@ pub async fn evict_lru_on_node(
fleet: &CortexState,
node_name: &str,
) -> anyhow::Result<Option<String>> {
let (endpoint, candidate) = {
let (neuron_endpoint, candidate) = {
let nodes = fleet.nodes.read().await;
let Some(node) = nodes.get(node_name) else {
anyhow::bail!("node '{node_name}' not found");
};
// Find the loaded model with the oldest last_accessed, excluding pinned.
// Find the loaded model with the oldest last_accessed,
// excluding models pinned on this neuron (from catalogue).
let candidate = node
.models
.values()
.filter(|m| m.status == ModelStatus::Loaded)
.filter(|m| !node.pinned.contains(&m.id))
.filter(|m| !fleet.catalogue.is_pinned(&m.id, node_name))
.min_by_key(|m| m.last_accessed)
.map(|m| m.id.clone());
@@ -58,18 +49,16 @@ pub async fn evict_lru_on_node(
tracing::info!(node = node_name, model = %model_id, "evicting model");
let url = format!("{endpoint}/v1/models/unload");
// Call neuron's unload endpoint.
let url = format!("{neuron_endpoint}/models/unload");
let resp = fleet
.http_client
.post(&url)
.json(&ModelLifecycleRequest {
model_id: model_id.clone(),
})
.json(&serde_json::json!({ "model_id": model_id }))
.send()
.await?;
if resp.status().is_success() {
// Update local state.
let mut nodes = fleet.nodes.write().await;
if let Some(node) = nodes.get_mut(node_name) {
if let Some(entry) = node.models.get_mut(&model_id) {
@@ -77,14 +66,13 @@ pub async fn evict_lru_on_node(
}
node.lifecycle_cycles += 1;
// Check if we should flag for defrag.
if fleet.eviction.defrag_after_cycles > 0
&& node.lifecycle_cycles >= fleet.eviction.defrag_after_cycles
{
tracing::warn!(
node = node_name,
cycles = node.lifecycle_cycles,
"VRAM fragmentation threshold reached — consider restarting mistralrs"
"VRAM fragmentation threshold reached — consider restarting harness"
);
}
}

View File

@@ -1,15 +1,16 @@
//! Background poller that periodically queries each node's `/v1/models`
//! endpoint to refresh the fleet state.
//! Background poller that periodically queries each neuron's API
//! to refresh the fleet state.
use crate::state::CortexState;
use chrono::Utc;
use cortex_core::node::{MistralModelsResponse, ModelEntry, ModelStatus};
use cortex_core::harness::ModelInfo;
use cortex_core::node::{ModelEntry, ModelStatus};
use std::sync::Arc;
use std::time::Duration;
const POLL_INTERVAL: Duration = Duration::from_secs(10);
/// Runs forever, polling all nodes on a fixed interval.
/// Runs forever, polling all neurons on a fixed interval.
pub async fn poll_loop(fleet: Arc<CortexState>) {
loop {
poll_once(&fleet).await;
@@ -17,15 +18,15 @@ pub async fn poll_loop(fleet: Arc<CortexState>) {
}
}
/// Poll all nodes once. Used by `poll_loop` and available for testing.
/// Poll all neurons once. Used by `poll_loop` and available for testing.
pub async fn poll_once(fleet: &CortexState) {
for nc in &fleet.node_configs {
poll_node(fleet, &nc.name, &nc.endpoint).await;
for nc in &fleet.neuron_configs {
poll_neuron(fleet, &nc.name, &nc.endpoint).await;
}
}
async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
let url = format!("{endpoint}/v1/models");
async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
let url = format!("{endpoint}/models");
let result = fleet
.http_client
@@ -41,38 +42,36 @@ async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
match result {
Ok(resp) if resp.status().is_success() => {
match resp.json::<MistralModelsResponse>().await {
Ok(models_resp) => {
// Merge upstream model list into our state, preserving
// our local metadata (last_accessed, vram_estimate).
match resp.json::<Vec<ModelInfo>>().await {
Ok(models) => {
let mut seen = std::collections::HashSet::new();
for upstream in &models_resp.data {
for upstream in &models {
seen.insert(upstream.id.clone());
let status = parse_status(upstream.status.as_deref());
let status = parse_status(&upstream.status);
node.models
.entry(upstream.id.clone())
.and_modify(|e| {
e.status = status;
e.vram_estimate_mb = upstream.vram_used_mb;
})
.or_insert_with(|| ModelEntry {
id: upstream.id.clone(),
status,
last_accessed: None,
vram_estimate_mb: None,
vram_estimate_mb: upstream.vram_used_mb,
});
}
// Remove models that are no longer reported by the node
// (e.g. after a config change / restart).
// Remove models no longer reported by the neuron.
node.models.retain(|id, _| seen.contains(id));
node.healthy = true;
node.last_poll = Some(Utc::now());
tracing::debug!(node = name, models = models_resp.data.len(), "poll ok");
tracing::debug!(node = name, models = models.len(), "poll ok");
}
Err(e) => {
tracing::warn!(node = name, error = %e, "failed to parse /v1/models response");
tracing::warn!(node = name, error = %e, "failed to parse /models response");
node.healthy = false;
}
}
@@ -81,24 +80,22 @@ async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
tracing::warn!(
node = name,
status = %resp.status(),
"node returned non-success status"
"neuron returned non-success status"
);
node.healthy = false;
}
Err(e) => {
tracing::warn!(node = name, error = %e, "failed to reach node");
tracing::warn!(node = name, error = %e, "failed to reach neuron");
node.healthy = false;
}
}
}
fn parse_status(s: Option<&str>) -> ModelStatus {
fn parse_status(s: &str) -> ModelStatus {
match s {
Some("loaded") => ModelStatus::Loaded,
Some("unloaded") => ModelStatus::Unloaded,
Some("reloading") => ModelStatus::Reloading,
// If the status field is absent, assume loaded (older mistral.rs versions
// may not include it).
"loaded" => ModelStatus::Loaded,
"unloaded" => ModelStatus::Unloaded,
"reloading" => ModelStatus::Reloading,
_ => ModelStatus::Loaded,
}
}

View File

@@ -14,6 +14,7 @@ use std::sync::Arc;
#[derive(Debug, Clone)]
pub struct RouteDecision {
pub node_name: String,
/// The inference endpoint to proxy to (from neuron's /models/{id}/endpoint).
pub endpoint: String,
/// Whether the model will need to load (cold start).
pub cold_start: bool,
@@ -25,51 +26,76 @@ pub enum RouteError {
ModelNotFound(String),
#[error("no healthy nodes available")]
NoHealthyNodes,
#[error("failed to resolve inference endpoint for model '{0}' on node '{1}'")]
EndpointResolveFailed(String, String),
}
/// Resolve which node should serve a request for the given model.
/// Asks the neuron for the inference endpoint after selecting a node.
pub async fn resolve(
fleet: &Arc<CortexState>,
model_id: &str,
) -> Result<RouteDecision, RouteError> {
let nodes = fleet.nodes.read().await;
let (node_name, neuron_endpoint, cold_start) = {
let nodes = fleet.nodes.read().await;
// Pass 1: find a node where the model is already loaded.
let mut loaded_candidate = None;
let mut unloaded_candidate = None;
let mut loaded_candidate = None;
let mut unloaded_candidate = None;
for node in nodes.values() {
if !node.healthy {
continue;
}
if let Some(entry) = node.models.get(model_id) {
match entry.status {
ModelStatus::Loaded | ModelStatus::Reloading => {
loaded_candidate = Some(RouteDecision {
node_name: node.name.clone(),
endpoint: node.endpoint.clone(),
cold_start: false,
});
break; // loaded is best, stop searching
}
ModelStatus::Unloaded => {
if unloaded_candidate.is_none() {
unloaded_candidate = Some(RouteDecision {
node_name: node.name.clone(),
endpoint: node.endpoint.clone(),
cold_start: true,
});
for node in nodes.values() {
if !node.healthy {
continue;
}
if let Some(entry) = node.models.get(model_id) {
match entry.status {
ModelStatus::Loaded | ModelStatus::Reloading => {
loaded_candidate = Some((node.name.clone(), node.endpoint.clone(), false));
break;
}
ModelStatus::Unloaded => {
if unloaded_candidate.is_none() {
unloaded_candidate =
Some((node.name.clone(), node.endpoint.clone(), true));
}
}
}
}
}
}
loaded_candidate.or(unloaded_candidate).ok_or_else(|| {
if nodes.values().any(|n| n.healthy) {
RouteError::ModelNotFound(model_id.to_string())
} else {
RouteError::NoHealthyNodes
}
loaded_candidate.or(unloaded_candidate).ok_or_else(|| {
if nodes.values().any(|n| n.healthy) {
RouteError::ModelNotFound(model_id.to_string())
} else {
RouteError::NoHealthyNodes
}
})?
};
// Ask the neuron for the inference endpoint for this model.
let endpoint_url = format!(
"{}/models/{}/endpoint",
neuron_endpoint,
urlencoding::encode(model_id)
);
let inference_endpoint = match fleet.http_client.get(&endpoint_url).send().await {
Ok(resp) if resp.status().is_success() => match resp.json::<serde_json::Value>().await {
Ok(body) => body
.get("url")
.and_then(|v| v.as_str())
.map(|s| s.to_string()),
Err(_) => None,
},
_ => None,
};
let endpoint = inference_endpoint.ok_or_else(|| {
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.clone())
})?;
Ok(RouteDecision {
node_name,
endpoint,
cold_start,
})
}

View File

@@ -1,4 +1,5 @@
use cortex_core::config::{EvictionSettings, GatewayConfig, NodeConfig};
use cortex_core::catalogue::ModelCatalogue;
use cortex_core::config::{EvictionSettings, GatewayConfig, NeuronEndpoint};
use cortex_core::node::NodeState;
use std::collections::HashMap;
use tokio::sync::RwLock;
@@ -6,23 +7,22 @@ use tokio::sync::RwLock;
/// Shared fleet state, protected by a RwLock for concurrent reader access.
pub struct CortexState {
    /// Live per-neuron state, keyed by the neuron's configured name.
    /// Entries start out with `healthy: false` and an empty model map.
    pub nodes: RwLock<HashMap<String, NodeState>>,
    /// Neuron endpoints as read from the gateway configuration
    /// (`config.neurons`).
    pub neuron_configs: Vec<NeuronEndpoint>,
    /// Eviction settings copied from the gateway configuration.
    pub eviction: EvictionSettings,
    /// Model catalogue loaded from the path in `config.models_config`.
    pub catalogue: ModelCatalogue,
    /// HTTP client shared by all outgoing requests made through this state.
    pub http_client: reqwest::Client,
}
impl CortexState {
pub fn from_config(config: &GatewayConfig) -> Self {
let mut nodes = HashMap::new();
for nc in &config.nodes {
for nc in &config.neurons {
nodes.insert(
nc.name.clone(),
NodeState {
name: nc.name.clone(),
endpoint: nc.endpoint.clone(),
vram_mb: nc.vram_mb,
pinned: nc.pinned.clone(),
healthy: false, // will be set by first poll
healthy: false,
models: HashMap::new(),
lifecycle_cycles: 0,
last_poll: None,
@@ -30,10 +30,13 @@ impl CortexState {
);
}
let catalogue = ModelCatalogue::load(&config.models_config);
Self {
nodes: RwLock::new(nodes),
node_configs: config.nodes.clone(),
neuron_configs: config.neurons.clone(),
eviction: config.eviction.clone(),
catalogue,
http_client: reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300))
.build()