refactor: cortex talks to neurons instead of mistral.rs directly
Replace NodeConfig (static vram_mb, pinned) with NeuronEndpoint.
Hardware discovery and model pinning now come from the neuron API and
the models.toml catalogue, respectively.
- config.rs: nodes -> neurons, add models_config path
- catalogue.rs: ModelProfile with pinned_on, ModelCatalogue
- poller.rs: poll neuron GET /models (ModelInfo format)
- router.rs: resolve inference endpoint via neuron GET /models/{id}/endpoint
- evictor.rs: call neuron POST /models/unload
- node.rs: remove vram_mb and pinned fields (these now come from discovery/catalogue)
- All 22 gateway tests updated to mock neuron API
- Remove MistralModelsResponse, ModelLifecycleRequest (no longer needed)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,29 +1,19 @@
|
||||
//! Model eviction logic.
|
||||
//!
|
||||
//! The evictor runs as a background task. When the router determines that a
|
||||
//! model needs to be loaded on a node but VRAM is tight, it can request
|
||||
//! eviction via a channel. The evictor then:
|
||||
//! 1. Identifies the LRU model on that node (excluding pinned models)
|
||||
//! 2. Calls `POST /v1/models/unload` on the node
|
||||
//! 3. Increments the lifecycle cycle counter (for defrag tracking)
|
||||
//! The evictor identifies the LRU model on a node (excluding pinned models),
|
||||
//! calls neuron's `POST /models/unload` to free the model, and updates
|
||||
//! local state.
|
||||
|
||||
use crate::state::CortexState;
|
||||
use cortex_core::node::{ModelLifecycleRequest, ModelStatus};
|
||||
use cortex_core::node::ModelStatus;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Runs forever. Currently a placeholder that periodically checks for
|
||||
/// eviction opportunities. In the future, this will be driven by a
|
||||
/// channel from the router when VRAM pressure is detected.
|
||||
/// Runs forever. Placeholder for future channel-driven eviction.
|
||||
pub async fn eviction_loop(fleet: Arc<CortexState>) {
|
||||
// TODO: Replace this polling approach with a channel-driven design
|
||||
// where the router sends eviction requests when it detects that a
|
||||
// model load would exceed available VRAM.
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||
// Placeholder: the actual eviction logic is in `evict_lru_on_node`,
|
||||
// called on demand by the router.
|
||||
let _ = &fleet; // suppress unused warning
|
||||
let _ = &fleet;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,18 +23,19 @@ pub async fn evict_lru_on_node(
|
||||
fleet: &CortexState,
|
||||
node_name: &str,
|
||||
) -> anyhow::Result<Option<String>> {
|
||||
let (endpoint, candidate) = {
|
||||
let (neuron_endpoint, candidate) = {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let Some(node) = nodes.get(node_name) else {
|
||||
anyhow::bail!("node '{node_name}' not found");
|
||||
};
|
||||
|
||||
// Find the loaded model with the oldest last_accessed, excluding pinned.
|
||||
// Find the loaded model with the oldest last_accessed,
|
||||
// excluding models pinned on this neuron (from catalogue).
|
||||
let candidate = node
|
||||
.models
|
||||
.values()
|
||||
.filter(|m| m.status == ModelStatus::Loaded)
|
||||
.filter(|m| !node.pinned.contains(&m.id))
|
||||
.filter(|m| !fleet.catalogue.is_pinned(&m.id, node_name))
|
||||
.min_by_key(|m| m.last_accessed)
|
||||
.map(|m| m.id.clone());
|
||||
|
||||
@@ -58,18 +49,16 @@ pub async fn evict_lru_on_node(
|
||||
|
||||
tracing::info!(node = node_name, model = %model_id, "evicting model");
|
||||
|
||||
let url = format!("{endpoint}/v1/models/unload");
|
||||
// Call neuron's unload endpoint.
|
||||
let url = format!("{neuron_endpoint}/models/unload");
|
||||
let resp = fleet
|
||||
.http_client
|
||||
.post(&url)
|
||||
.json(&ModelLifecycleRequest {
|
||||
model_id: model_id.clone(),
|
||||
})
|
||||
.json(&serde_json::json!({ "model_id": model_id }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if resp.status().is_success() {
|
||||
// Update local state.
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
if let Some(node) = nodes.get_mut(node_name) {
|
||||
if let Some(entry) = node.models.get_mut(&model_id) {
|
||||
@@ -77,14 +66,13 @@ pub async fn evict_lru_on_node(
|
||||
}
|
||||
node.lifecycle_cycles += 1;
|
||||
|
||||
// Check if we should flag for defrag.
|
||||
if fleet.eviction.defrag_after_cycles > 0
|
||||
&& node.lifecycle_cycles >= fleet.eviction.defrag_after_cycles
|
||||
{
|
||||
tracing::warn!(
|
||||
node = node_name,
|
||||
cycles = node.lifecycle_cycles,
|
||||
"VRAM fragmentation threshold reached — consider restarting mistralrs"
|
||||
"VRAM fragmentation threshold reached — consider restarting harness"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
//! Background poller that periodically queries each node's `/v1/models`
|
||||
//! endpoint to refresh the fleet state.
|
||||
//! Background poller that periodically queries each neuron's API
|
||||
//! to refresh the fleet state.
|
||||
|
||||
use crate::state::CortexState;
|
||||
use chrono::Utc;
|
||||
use cortex_core::node::{MistralModelsResponse, ModelEntry, ModelStatus};
|
||||
use cortex_core::harness::ModelInfo;
|
||||
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
const POLL_INTERVAL: Duration = Duration::from_secs(10);
|
||||
|
||||
/// Runs forever, polling all nodes on a fixed interval.
|
||||
/// Runs forever, polling all neurons on a fixed interval.
|
||||
pub async fn poll_loop(fleet: Arc<CortexState>) {
|
||||
loop {
|
||||
poll_once(&fleet).await;
|
||||
@@ -17,15 +18,15 @@ pub async fn poll_loop(fleet: Arc<CortexState>) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Poll all nodes once. Used by `poll_loop` and available for testing.
|
||||
/// Poll all neurons once. Used by `poll_loop` and available for testing.
|
||||
pub async fn poll_once(fleet: &CortexState) {
|
||||
for nc in &fleet.node_configs {
|
||||
poll_node(fleet, &nc.name, &nc.endpoint).await;
|
||||
for nc in &fleet.neuron_configs {
|
||||
poll_neuron(fleet, &nc.name, &nc.endpoint).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
let url = format!("{endpoint}/v1/models");
|
||||
async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
let url = format!("{endpoint}/models");
|
||||
|
||||
let result = fleet
|
||||
.http_client
|
||||
@@ -41,38 +42,36 @@ async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
|
||||
match result {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
match resp.json::<MistralModelsResponse>().await {
|
||||
Ok(models_resp) => {
|
||||
// Merge upstream model list into our state, preserving
|
||||
// our local metadata (last_accessed, vram_estimate).
|
||||
match resp.json::<Vec<ModelInfo>>().await {
|
||||
Ok(models) => {
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
for upstream in &models_resp.data {
|
||||
for upstream in &models {
|
||||
seen.insert(upstream.id.clone());
|
||||
let status = parse_status(upstream.status.as_deref());
|
||||
let status = parse_status(&upstream.status);
|
||||
|
||||
node.models
|
||||
.entry(upstream.id.clone())
|
||||
.and_modify(|e| {
|
||||
e.status = status;
|
||||
e.vram_estimate_mb = upstream.vram_used_mb;
|
||||
})
|
||||
.or_insert_with(|| ModelEntry {
|
||||
id: upstream.id.clone(),
|
||||
status,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: None,
|
||||
vram_estimate_mb: upstream.vram_used_mb,
|
||||
});
|
||||
}
|
||||
|
||||
// Remove models that are no longer reported by the node
|
||||
// (e.g. after a config change / restart).
|
||||
// Remove models no longer reported by the neuron.
|
||||
node.models.retain(|id, _| seen.contains(id));
|
||||
|
||||
node.healthy = true;
|
||||
node.last_poll = Some(Utc::now());
|
||||
tracing::debug!(node = name, models = models_resp.data.len(), "poll ok");
|
||||
tracing::debug!(node = name, models = models.len(), "poll ok");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(node = name, error = %e, "failed to parse /v1/models response");
|
||||
tracing::warn!(node = name, error = %e, "failed to parse /models response");
|
||||
node.healthy = false;
|
||||
}
|
||||
}
|
||||
@@ -81,24 +80,22 @@ async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
tracing::warn!(
|
||||
node = name,
|
||||
status = %resp.status(),
|
||||
"node returned non-success status"
|
||||
"neuron returned non-success status"
|
||||
);
|
||||
node.healthy = false;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(node = name, error = %e, "failed to reach node");
|
||||
tracing::warn!(node = name, error = %e, "failed to reach neuron");
|
||||
node.healthy = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_status(s: Option<&str>) -> ModelStatus {
|
||||
fn parse_status(s: &str) -> ModelStatus {
|
||||
match s {
|
||||
Some("loaded") => ModelStatus::Loaded,
|
||||
Some("unloaded") => ModelStatus::Unloaded,
|
||||
Some("reloading") => ModelStatus::Reloading,
|
||||
// If the status field is absent, assume loaded (older mistral.rs versions
|
||||
// may not include it).
|
||||
"loaded" => ModelStatus::Loaded,
|
||||
"unloaded" => ModelStatus::Unloaded,
|
||||
"reloading" => ModelStatus::Reloading,
|
||||
_ => ModelStatus::Loaded,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ use std::sync::Arc;
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RouteDecision {
|
||||
pub node_name: String,
|
||||
/// The inference endpoint to proxy to (from neuron's /models/{id}/endpoint).
|
||||
pub endpoint: String,
|
||||
/// Whether the model will need to load (cold start).
|
||||
pub cold_start: bool,
|
||||
@@ -25,51 +26,76 @@ pub enum RouteError {
|
||||
ModelNotFound(String),
|
||||
#[error("no healthy nodes available")]
|
||||
NoHealthyNodes,
|
||||
#[error("failed to resolve inference endpoint for model '{0}' on node '{1}'")]
|
||||
EndpointResolveFailed(String, String),
|
||||
}
|
||||
|
||||
/// Resolve which node should serve a request for the given model.
|
||||
/// Asks the neuron for the inference endpoint after selecting a node.
|
||||
pub async fn resolve(
|
||||
fleet: &Arc<CortexState>,
|
||||
model_id: &str,
|
||||
) -> Result<RouteDecision, RouteError> {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let (node_name, neuron_endpoint, cold_start) = {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
|
||||
// Pass 1: find a node where the model is already loaded.
|
||||
let mut loaded_candidate = None;
|
||||
let mut unloaded_candidate = None;
|
||||
let mut loaded_candidate = None;
|
||||
let mut unloaded_candidate = None;
|
||||
|
||||
for node in nodes.values() {
|
||||
if !node.healthy {
|
||||
continue;
|
||||
}
|
||||
if let Some(entry) = node.models.get(model_id) {
|
||||
match entry.status {
|
||||
ModelStatus::Loaded | ModelStatus::Reloading => {
|
||||
loaded_candidate = Some(RouteDecision {
|
||||
node_name: node.name.clone(),
|
||||
endpoint: node.endpoint.clone(),
|
||||
cold_start: false,
|
||||
});
|
||||
break; // loaded is best, stop searching
|
||||
}
|
||||
ModelStatus::Unloaded => {
|
||||
if unloaded_candidate.is_none() {
|
||||
unloaded_candidate = Some(RouteDecision {
|
||||
node_name: node.name.clone(),
|
||||
endpoint: node.endpoint.clone(),
|
||||
cold_start: true,
|
||||
});
|
||||
for node in nodes.values() {
|
||||
if !node.healthy {
|
||||
continue;
|
||||
}
|
||||
if let Some(entry) = node.models.get(model_id) {
|
||||
match entry.status {
|
||||
ModelStatus::Loaded | ModelStatus::Reloading => {
|
||||
loaded_candidate = Some((node.name.clone(), node.endpoint.clone(), false));
|
||||
break;
|
||||
}
|
||||
ModelStatus::Unloaded => {
|
||||
if unloaded_candidate.is_none() {
|
||||
unloaded_candidate =
|
||||
Some((node.name.clone(), node.endpoint.clone(), true));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loaded_candidate.or(unloaded_candidate).ok_or_else(|| {
|
||||
if nodes.values().any(|n| n.healthy) {
|
||||
RouteError::ModelNotFound(model_id.to_string())
|
||||
} else {
|
||||
RouteError::NoHealthyNodes
|
||||
}
|
||||
loaded_candidate.or(unloaded_candidate).ok_or_else(|| {
|
||||
if nodes.values().any(|n| n.healthy) {
|
||||
RouteError::ModelNotFound(model_id.to_string())
|
||||
} else {
|
||||
RouteError::NoHealthyNodes
|
||||
}
|
||||
})?
|
||||
};
|
||||
|
||||
// Ask the neuron for the inference endpoint for this model.
|
||||
let endpoint_url = format!(
|
||||
"{}/models/{}/endpoint",
|
||||
neuron_endpoint,
|
||||
urlencoding::encode(model_id)
|
||||
);
|
||||
|
||||
let inference_endpoint = match fleet.http_client.get(&endpoint_url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => match resp.json::<serde_json::Value>().await {
|
||||
Ok(body) => body
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string()),
|
||||
Err(_) => None,
|
||||
},
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let endpoint = inference_endpoint.ok_or_else(|| {
|
||||
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.clone())
|
||||
})?;
|
||||
|
||||
Ok(RouteDecision {
|
||||
node_name,
|
||||
endpoint,
|
||||
cold_start,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use cortex_core::config::{EvictionSettings, GatewayConfig, NodeConfig};
|
||||
use cortex_core::catalogue::ModelCatalogue;
|
||||
use cortex_core::config::{EvictionSettings, GatewayConfig, NeuronEndpoint};
|
||||
use cortex_core::node::NodeState;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::RwLock;
|
||||
@@ -6,23 +7,22 @@ use tokio::sync::RwLock;
|
||||
/// Shared fleet state, protected by a RwLock for concurrent reader access.
|
||||
pub struct CortexState {
|
||||
pub nodes: RwLock<HashMap<String, NodeState>>,
|
||||
pub node_configs: Vec<NodeConfig>,
|
||||
pub neuron_configs: Vec<NeuronEndpoint>,
|
||||
pub eviction: EvictionSettings,
|
||||
pub catalogue: ModelCatalogue,
|
||||
pub http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl CortexState {
|
||||
pub fn from_config(config: &GatewayConfig) -> Self {
|
||||
let mut nodes = HashMap::new();
|
||||
for nc in &config.nodes {
|
||||
for nc in &config.neurons {
|
||||
nodes.insert(
|
||||
nc.name.clone(),
|
||||
NodeState {
|
||||
name: nc.name.clone(),
|
||||
endpoint: nc.endpoint.clone(),
|
||||
vram_mb: nc.vram_mb,
|
||||
pinned: nc.pinned.clone(),
|
||||
healthy: false, // will be set by first poll
|
||||
healthy: false,
|
||||
models: HashMap::new(),
|
||||
lifecycle_cycles: 0,
|
||||
last_poll: None,
|
||||
@@ -30,10 +30,13 @@ impl CortexState {
|
||||
);
|
||||
}
|
||||
|
||||
let catalogue = ModelCatalogue::load(&config.models_config);
|
||||
|
||||
Self {
|
||||
nodes: RwLock::new(nodes),
|
||||
node_configs: config.nodes.clone(),
|
||||
neuron_configs: config.neurons.clone(),
|
||||
eviction: config.eviction.clone(),
|
||||
catalogue,
|
||||
http_client: reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.build()
|
||||
|
||||
Reference in New Issue
Block a user