feat: scaffold cortex workspace

Rust reverse-proxy for multi-node mistral.rs inference clusters.
Includes crate structure (cortex-core, cortex-gateway, cortex-agent,
cortex-cli), config loading, OpenAI/Anthropic translation stubs,
model routing, eviction, polling, and streaming proxy scaffolding.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-14 18:13:30 +03:00
commit 0da68833af
28 changed files with 4659 additions and 0 deletions

View File

@@ -0,0 +1,14 @@
[package]
name = "cortex-agent"
version.workspace = true
edition.workspace = true
license.workspace = true
[dependencies]
cortex-core.workspace = true
tokio.workspace = true
serde.workspace = true
serde_json.workspace = true
reqwest.workspace = true
tracing.workspace = true
anyhow.workspace = true

View File

@@ -0,0 +1,72 @@
//! Per-node agent sidecar.
//!
//! This is a future component that runs on each GPU node alongside mistralrs.
//! It handles:
//! - VRAM defragmentation (restarting the mistralrs systemd unit when the
//! gateway signals that lifecycle_cycles has exceeded the threshold)
//! - Local nvidia-smi polling for actual VRAM usage reporting
//! - Systemd unit management for mistralrs process restarts
//!
//! For now this is a stub. The gateway's poller + evictor handle the critical
//! path (model lifecycle via the mistralrs HTTP API). The agent adds
//! operational niceties that can be built incrementally.
/// Placeholder for agent configuration.
///
/// Identifies the local mistralrs instance the sidecar looks after. Of the two
/// fields, only `systemd_unit` is read by the functions visible in this module
/// (`restart_mistralrs`).
#[derive(Debug, Clone)]
pub struct AgentConfig {
/// The local mistralrs endpoint to monitor.
pub mistralrs_endpoint: String,
/// The systemd unit name for mistralrs (e.g. "mistralrs.service").
/// Passed verbatim as the argument to `systemctl restart`.
pub systemd_unit: String,
}
/// Restart the local mistralrs systemd unit.
///
/// Runs `systemctl restart <unit>` and waits for the command to finish. This
/// is the heavy-handed way of reclaiming fragmented VRAM.
///
/// # Errors
///
/// Returns an error if `systemctl` cannot be run, or if it exits with a
/// non-success status (its captured stderr is included in the message).
pub async fn restart_mistralrs(config: &AgentConfig) -> anyhow::Result<()> {
    let unit = &config.systemd_unit;
    tracing::warn!(
        unit = %unit,
        "restarting mistralrs for VRAM defragmentation"
    );

    let mut systemctl = tokio::process::Command::new("systemctl");
    systemctl.arg("restart").arg(unit);
    let finished = systemctl.output().await?;

    // Guard clause: surface systemctl's own diagnostics on failure.
    if !finished.status.success() {
        let stderr = String::from_utf8_lossy(&finished.stderr);
        anyhow::bail!("systemctl restart failed: {stderr}");
    }

    tracing::info!(unit = %unit, "mistralrs restarted successfully");
    Ok(())
}
/// Query nvidia-smi for current VRAM usage on this node.
/// Returns (used_mb, total_mb) for each GPU.
pub async fn query_vram() -> anyhow::Result<Vec<(u64, u64)>> {
let output = tokio::process::Command::new("nvidia-smi")
.args([
"--query-gpu=memory.used,memory.total",
"--format=csv,noheader,nounits",
])
.output()
.await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("nvidia-smi failed: {stderr}");
}
let stdout = String::from_utf8_lossy(&output.stdout);
let mut gpus = Vec::new();
for line in stdout.lines() {
let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
if parts.len() == 2 {
let used: u64 = parts[0].parse().unwrap_or(0);
let total: u64 = parts[1].parse().unwrap_or(0);
gpus.push((used, total));
}
}
Ok(gpus)
}

View File

@@ -0,0 +1 @@
pub mod agent;

View File

@@ -0,0 +1,20 @@
[package]
name = "cortex-cli"
version.workspace = true
edition.workspace = true
license.workspace = true
[[bin]]
name = "cortex"
path = "src/main.rs"
[dependencies]
cortex-core.workspace = true
cortex-gateway.workspace = true
tokio.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
anyhow.workspace = true
reqwest.workspace = true
serde_json.workspace = true
clap = { version = "4", features = ["derive"] }

View File

@@ -0,0 +1,112 @@
use anyhow::Result;
use clap::{Parser, Subcommand};
use cortex_core::config::GatewayConfig;
use tracing_subscriber::EnvFilter;
// Top-level command-line interface of the `cortex` binary.
// NOTE: plain `//` comments are used on purpose — clap's derive turns `///`
// doc comments into help text, which would alter the CLI output.
#[derive(Parser)]
#[command(name = "cortex")]
#[command(about = "Unified inference gateway for multi-node mistral.rs clusters")]
#[command(version)]
struct Cli {
// The subcommand selected on the command line; dispatched on in `main`.
#[command(subcommand)]
command: Commands,
}
// Subcommands of the `cortex` binary. The existing `///` lines below double
// as clap help text, so additional notes are written as plain `//` comments.
#[derive(Subcommand)]
enum Commands {
/// Start the gateway server.
Serve {
/// Path to the gateway config file.
// Handed to `GatewayConfig::load` by `main`.
#[arg(short, long, default_value = "cortex.toml")]
config: String,
},
/// Print the fleet status (models, nodes, health).
Status {
/// Gateway API endpoint to query.
// Used by `print_status` as the base URL for `/health` and `/v1/models`.
#[arg(short, long, default_value = "http://localhost:8000")]
endpoint: String,
},
}
// Entry point: set up logging, parse the command line and dispatch to the
// selected subcommand.
#[tokio::main]
async fn main() -> Result<()> {
    // Honour the default env filter (e.g. RUST_LOG=cortex_gateway=debug) when
    // it is set and valid; otherwise fall back to a built-in filter.
    let filter = EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| EnvFilter::new("info,cortex_gateway=debug"));
    tracing_subscriber::fmt().with_env_filter(filter).init();

    let cli = Cli::parse();
    match cli.command {
        Commands::Status { endpoint } => print_status(&endpoint).await?,
        Commands::Serve { config } => {
            let cfg = match GatewayConfig::load(&config) {
                Ok(loaded) => loaded,
                Err(e) => {
                    return Err(anyhow::anyhow!(
                        "failed to load config from '{config}': {e}"
                    ));
                }
            };
            tracing::info!(
                nodes = cfg.nodes.len(),
                listen = %cfg.gateway.listen,
                "starting cortex"
            );

            // The Prometheus exporter listens on its own address, separate
            // from the API; it is installed before the gateway starts.
            cortex_gateway::metrics::install(&cfg.gateway.metrics_listen)?;
            cortex_gateway::run(cfg).await?;
        }
    }
    Ok(())
}
/// Print a human-readable fleet summary fetched from a running gateway.
///
/// Issues `GET {endpoint}/health` and `GET {endpoint}/v1/models`, pretty-prints
/// the health document and then lists every model with the nodes that host it
/// as `node(status)` pairs.
///
/// A trailing `/` on `endpoint` is ignored so that `http://host:8000/` does
/// not produce `//health`. Non-success HTTP statuses are reported as errors
/// instead of having their bodies printed as if they were valid status
/// documents.
///
/// # Errors
///
/// Fails if a request cannot be sent, the gateway answers with an error
/// status, or a response body is not valid JSON.
async fn print_status(endpoint: &str) -> Result<()> {
    let base = endpoint.trim_end_matches('/');
    let client = reqwest::Client::new();

    // Fetch health.
    let health: serde_json::Value = client
        .get(format!("{base}/health"))
        .send()
        .await?
        .error_for_status()?
        .json()
        .await?;
    println!("Fleet health: {}", serde_json::to_string_pretty(&health)?);

    // Fetch models.
    let models: serde_json::Value = client
        .get(format!("{base}/v1/models"))
        .send()
        .await?
        .error_for_status()?
        .json()
        .await?;
    println!("\nModels:");

    if let Some(data) = models.get("data").and_then(|d| d.as_array()) {
        for model in data {
            let id = model.get("id").and_then(|v| v.as_str()).unwrap_or("?");
            // Locations lacking a string `node` or `status` are left out of
            // the summary rather than aborting the whole listing.
            let locations = model
                .get("locations")
                .and_then(|v| v.as_array())
                .map(|arr| {
                    arr.iter()
                        .filter_map(|l| {
                            let node = l.get("node")?.as_str()?;
                            let status = l.get("status")?.as_str()?;
                            Some(format!("{node}({status})"))
                        })
                        .collect::<Vec<_>>()
                        .join(", ")
                })
                .unwrap_or_default();
            println!(" {id:40} {locations}");
        }
    }
    Ok(())
}

View File

@@ -0,0 +1,15 @@
[package]
name = "cortex-core"
version.workspace = true
edition.workspace = true
license.workspace = true
[dependencies]
serde.workspace = true
serde_json.workspace = true
toml.workspace = true
figment.workspace = true
chrono.workspace = true
anyhow.workspace = true
thiserror.workspace = true
tracing.workspace = true

View File

@@ -0,0 +1,87 @@
//! Anthropic Messages API request and response types.
//!
//! These mirror the `/v1/messages` format used by the Anthropic API.
//! The gateway accepts these, translates to OpenAI format, proxies to
//! mistral.rs, then translates the response back.
use serde::{Deserialize, Serialize};
use serde_json::Value;
// ── Messages request ─────────────────────────────────────────────────
/// Request body in the Anthropic `/v1/messages` format.
///
/// `model`, `messages` and `max_tokens` must be present when deserializing.
/// The `Option` fields are left out of serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagesRequest {
/// Identifier of the model the request targets.
pub model: String,
/// Conversation turns, in order.
pub messages: Vec<AnthropicMessage>,
/// Upper bound on generated tokens.
pub max_tokens: u64,
/// Optional system prompt, either plain text or a list of blocks.
#[serde(skip_serializing_if = "Option::is_none")]
pub system: Option<SystemPrompt>,
/// Optional sampling temperature.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// Optional nucleus-sampling parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
/// Whether the client asked for a streamed response.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// Every other top-level key of the request, collected via `flatten`.
#[serde(flatten)]
pub extra: Value,
}
/// The `system` field of a Messages request.
///
/// Untagged: a JSON string deserializes to `Text`, a JSON array to `Blocks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum SystemPrompt {
/// System prompt given as a single string.
Text(String),
/// System prompt given as an array of arbitrary JSON blocks (kept untyped).
Blocks(Vec<Value>),
}
/// One conversation turn of a Messages request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicMessage {
/// Speaker role, kept as a free-form string (not validated here).
pub role: String,
/// Message body: plain text or a list of content blocks.
pub content: AnthropicContent,
}
/// Content of an Anthropic message.
///
/// Untagged: a JSON string deserializes to `Text`, a JSON array of objects
/// to `Blocks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AnthropicContent {
/// Content given as a single string.
Text(String),
/// Content given as a list of typed blocks.
Blocks(Vec<ContentBlock>),
}
/// A single content block.
///
/// Only the `type` discriminator is modelled explicitly; all remaining keys
/// of the JSON object are kept untyped in `data`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContentBlock {
/// The block's `type` key (e.g. "text"), kept as a free-form string.
#[serde(rename = "type")]
pub block_type: String,
/// All other keys of the block object, collected via `flatten`.
#[serde(flatten)]
pub data: Value,
}
// ── Messages response ────────────────────────────────────────────────
/// Non-streaming response body in the Anthropic Messages format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagesResponse {
/// Response identifier.
pub id: String,
/// Serialized under the JSON key `type`.
#[serde(rename = "type")]
pub response_type: String,
/// Role of the responder, kept as a free-form string.
pub role: String,
/// Generated content blocks.
pub content: Vec<ContentBlock>,
/// Model that produced the response.
pub model: String,
/// Why generation stopped, if reported. Serialized as `null` when `None`
/// (no `skip_serializing_if` on this field).
pub stop_reason: Option<String>,
/// Token accounting for the request.
pub usage: AnthropicUsage,
/// Every other top-level key, collected via `flatten`.
#[serde(flatten)]
pub extra: Value,
}
/// Token usage as reported in an Anthropic-format response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicUsage {
/// Number of tokens in the prompt.
pub input_tokens: u64,
/// Number of tokens generated.
pub output_tokens: u64,
}
// ── Streaming events ─────────────────────────────────────────────────
/// A single streaming event in the Anthropic format.
///
/// Only the `type` key is modelled; the rest of the event payload is kept
/// untyped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamEvent {
/// The event's `type` key, kept as a free-form string.
#[serde(rename = "type")]
pub event_type: String,
/// All other keys of the event object, collected via `flatten`.
#[serde(flatten)]
pub data: Value,
}

View File

@@ -0,0 +1,79 @@
use figment::{
Figment,
providers::{Env, Format, Toml},
};
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Root of the gateway configuration.
///
/// All three sections are required when extracting the configuration
/// (none of them carries a serde default).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
/// Listener addresses for the API and the metrics exporter.
pub gateway: GatewaySettings,
/// Model eviction policy.
pub eviction: EvictionSettings,
/// The backend nodes that make up the fleet.
pub nodes: Vec<NodeConfig>,
}
/// Listener addresses of the gateway process.
///
/// Both values are stored as plain strings here; parsing into socket
/// addresses happens where they are used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewaySettings {
/// Address to listen on for API requests (e.g. "0.0.0.0:8000")
pub listen: String,
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:9100")
pub metrics_listen: String,
}
/// Settings that control model eviction and VRAM defragmentation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvictionSettings {
/// Eviction strategy: "lru" or "priority"
pub strategy: EvictionStrategy,
/// Restart the mistralrs process after this many load/unload cycles
/// to reclaim fragmented VRAM. 0 = never.
/// Defaults to 0 when the key is absent (`serde(default)`).
#[serde(default)]
pub defrag_after_cycles: u32,
}
/// Available eviction strategies.
///
/// Variants are (de)serialized in lowercase: `"lru"` and `"priority"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EvictionStrategy {
/// Least-recently-used eviction.
Lru,
/// Priority-based eviction.
Priority,
}
/// Static description of one backend node, as given in the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
/// Human-readable node name (e.g. "gpu-large")
pub name: String,
/// Base URL of the mistralrs HTTP server (e.g. "http://gpu-large.internal:8080")
pub endpoint: String,
/// Total VRAM in MB across all GPUs on this node
pub vram_mb: u64,
/// Model IDs that should never be evicted from this node
/// Defaults to an empty list when the key is absent (`serde(default)`).
#[serde(default)]
pub pinned: Vec<String>,
}
impl GatewayConfig {
    /// Build the configuration from a TOML file plus environment overrides.
    ///
    /// The TOML file at `path` is merged first; environment variables
    /// prefixed with `CORTEX_` are merged on top of it, with `__` separating
    /// nested keys (e.g. `CORTEX_GATEWAY__LISTEN=0.0.0.0:9000`).
    ///
    /// # Errors
    ///
    /// Returns the `figment::Error` produced when the merged data cannot be
    /// extracted into a `GatewayConfig`.
    // NOTE(review): how figment treats a missing file at `path` is not visible
    // from this code — confirm against the figment version in use.
    pub fn load(path: impl AsRef<Path>) -> Result<Self, figment::Error> {
        let file_layer = Toml::file(path);
        let env_layer = Env::prefixed("CORTEX_").split("__");

        let figment = Figment::new().merge(file_layer).merge(env_layer);
        figment.extract()
    }
}
impl Default for GatewayConfig {
    /// A configuration with no nodes: API on `0.0.0.0:8000`, metrics on
    /// `0.0.0.0:9100`, LRU eviction and a defrag threshold of 50 cycles.
    fn default() -> Self {
        let gateway = GatewaySettings {
            listen: String::from("0.0.0.0:8000"),
            metrics_listen: String::from("0.0.0.0:9100"),
        };
        let eviction = EvictionSettings {
            strategy: EvictionStrategy::Lru,
            defrag_after_cycles: 50,
        };
        Self {
            gateway,
            eviction,
            nodes: Vec::new(),
        }
    }
}

View File

@@ -0,0 +1,6 @@
pub mod anthropic;
pub mod config;
pub mod metrics;
pub mod node;
pub mod openai;
pub mod translate;

View File

@@ -0,0 +1,23 @@
//! Request-level metrics captured by the gateway proxy layer.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// Metrics captured for a single proxied request.
///
/// Plain data record; all fields are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestMetrics {
/// When the record was taken (UTC).
pub timestamp: DateTime<Utc>,
/// Model the request targeted.
pub model: String,
/// Node that served the request.
pub node: String,
/// Tokens in the prompt.
pub prompt_tokens: u64,
/// Tokens generated.
pub completion_tokens: u64,
/// Total token count as recorded (stored separately, not derived here).
pub total_tokens: u64,
/// Tokens per second for the generation phase.
pub tok_per_sec: f64,
/// Time from request start to first SSE chunk (streaming) or full response.
pub time_to_first_token_ms: u64,
/// Total request latency including proxy overhead.
pub total_latency_ms: u64,
/// Whether this request triggered a model load (cold start).
pub cold_start: bool,
}

View File

@@ -0,0 +1,74 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Runtime state of a single node in the fleet.
///
/// Unlike the configuration types this is not (de)serializable; it only
/// derives `Debug` and `Clone`.
#[derive(Debug, Clone)]
pub struct NodeState {
/// Node name.
pub name: String,
/// Base URL used to reach the node.
pub endpoint: String,
/// VRAM figure for the node, in MB.
pub vram_mb: u64,
/// Model IDs excluded from eviction on this node.
pub pinned: Vec<String>,
/// Health flag for the node.
pub healthy: bool,
/// Models known on this node, keyed by model ID.
pub models: HashMap<String, ModelEntry>,
/// Number of load/unload cycles since last process restart.
pub lifecycle_cycles: u32,
/// Timestamp of a poll of this node; `None` until one is recorded.
pub last_poll: Option<DateTime<Utc>>,
}
/// A model registered on a node, with its runtime status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelEntry {
/// Model identifier.
pub id: String,
/// Current lifecycle status of the model.
pub status: ModelStatus,
/// When this model was last used (for LRU eviction).
/// `None` means no access has been recorded.
pub last_accessed: Option<DateTime<Utc>>,
/// Estimated VRAM usage in MB when loaded.
pub vram_estimate_mb: Option<u64>,
}
/// Model lifecycle status, matching the mistral.rs API.
///
/// Variants are (de)serialized in lowercase: `"loaded"`, `"unloaded"`,
/// `"reloading"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelStatus {
/// The model is loaded.
Loaded,
/// The model is not loaded.
Unloaded,
/// The model is in the process of being reloaded.
Reloading,
}
/// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
/// Includes which node(s) host this model and their status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CortexModelEntry {
/// Model identifier.
pub id: String,
/// Object-type tag of the entry (a free-form string).
pub object: String,
/// Which nodes have this model (and their status).
pub locations: Vec<ModelLocation>,
}
/// One placement of a model: the node hosting it and its status there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLocation {
/// Name of the node.
pub node: String,
/// Status of the model on that node.
pub status: ModelStatus,
/// Estimated VRAM usage in MB, if known. Serialized as `null` when `None`.
pub vram_estimate_mb: Option<u64>,
}
/// Response from mistral.rs `GET /v1/models`.
/// This is the upstream format we parse when polling nodes.
/// Deserialize-only; keys other than `data` are ignored.
#[derive(Debug, Clone, Deserialize)]
pub struct MistralModelsResponse {
/// The models reported by the node.
pub data: Vec<MistralModelEntry>,
}
/// One model entry inside a `MistralModelsResponse`.
#[derive(Debug, Clone, Deserialize)]
pub struct MistralModelEntry {
/// Model identifier.
pub id: String,
/// Raw status string as sent upstream; `None` when the key is absent.
#[serde(default)]
pub status: Option<String>,
}
/// Request body for mistral.rs model lifecycle endpoints.
/// Serialize-only; serializes to `{"model_id": "..."}`.
#[derive(Debug, Clone, Serialize)]
pub struct ModelLifecycleRequest {
/// Identifier of the model the lifecycle operation applies to.
pub model_id: String,
}

View File

@@ -0,0 +1,122 @@
//! OpenAI-compatible request and response types.
//!
//! These are a subset sufficient for chat completions (streaming + non-streaming).
//! Fields not relevant to proxying are captured as `serde_json::Value` via
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
//! extension field mistral.rs supports.
use serde::{Deserialize, Serialize};
use serde_json::Value;
// ── Chat completion request ──────────────────────────────────────────
/// OpenAI-style chat completion request.
///
/// `model` and `messages` are required; the `Option` fields are left out of
/// serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
/// Identifier of the model the request targets.
pub model: String,
/// Conversation messages, in order.
pub messages: Vec<ChatMessage>,
/// Optional sampling temperature.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// Optional nucleus-sampling parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
/// Optional upper bound on generated tokens.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u64>,
/// Whether a streamed response is requested.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// All other fields (tools, response_format, mistral.rs extensions, etc.)
#[serde(flatten)]
pub extra: Value,
}
/// One message of an OpenAI-style conversation.
// NOTE(review): `content` is required and must be a JSON string or array; a
// message whose `content` is `null` or missing will fail to deserialize.
// Confirm whether upstream responses (e.g. tool-call messages) can carry
// a null content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// Speaker role, kept as a free-form string.
pub role: String,
/// Message body: plain text or a list of content parts.
pub content: MessageContent,
/// Every other key of the message object, collected via `flatten`.
#[serde(flatten)]
pub extra: Value,
}
/// Content can be a simple string or an array of content parts (for vision).
///
/// Untagged: a JSON string deserializes to `Text`, a JSON array to `Parts`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
/// Content given as a single string.
Text(String),
/// Content given as an array of arbitrary JSON parts (kept untyped).
Parts(Vec<Value>),
}
// ── Chat completion response (non-streaming) ─────────────────────────
/// OpenAI-style non-streaming chat completion response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
/// Response identifier.
pub id: String,
/// Object-type tag of the response (a free-form string).
pub object: String,
/// Creation time as an integer timestamp (unit not fixed by this type).
pub created: u64,
/// Model that produced the response.
pub model: String,
/// Generated alternatives.
pub choices: Vec<ChatCompletionChoice>,
/// Token accounting, if reported; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
/// Every other top-level key, collected via `flatten`.
#[serde(flatten)]
pub extra: Value,
}
/// One alternative inside a `ChatCompletionResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionChoice {
/// Position of this choice in the response.
pub index: u32,
/// The generated message.
pub message: ChatMessage,
/// Why generation stopped, if reported. Serialized as `null` when `None`.
pub finish_reason: Option<String>,
/// Every other key of the choice object, collected via `flatten`.
#[serde(flatten)]
pub extra: Value,
}
// ── Streaming chunk ──────────────────────────────────────────────────
/// One chunk of an OpenAI-style streamed chat completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
/// Identifier of the completion this chunk belongs to.
pub id: String,
/// Object-type tag of the chunk (a free-form string).
pub object: String,
/// Creation time as an integer timestamp (unit not fixed by this type).
pub created: u64,
/// Model that produced the chunk.
pub model: String,
/// Incremental updates, one per choice.
pub choices: Vec<ChunkChoice>,
/// Token accounting, if present on this chunk; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
/// Every other top-level key, collected via `flatten`.
#[serde(flatten)]
pub extra: Value,
}
/// One choice inside a streamed chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkChoice {
/// Position of the choice this delta applies to.
pub index: u32,
/// The incremental payload, kept as untyped JSON.
pub delta: Value,
/// Why generation stopped, if reported. Serialized as `null` when `None`.
pub finish_reason: Option<String>,
/// Every other key of the choice object, collected via `flatten`.
#[serde(flatten)]
pub extra: Value,
}
// ── Usage ────────────────────────────────────────────────────────────
/// Token usage as reported in an OpenAI-format response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
/// Tokens in the prompt.
pub prompt_tokens: u64,
/// Tokens generated.
pub completion_tokens: u64,
/// Total token count as reported (stored separately, not derived here).
pub total_tokens: u64,
}
// ── Models list response ─────────────────────────────────────────────
/// OpenAI-style response listing available models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelsResponse {
/// Object-type tag of the response (a free-form string).
pub object: String,
/// The listed models.
pub data: Vec<ModelObject>,
}
/// One model inside a `ModelsResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelObject {
/// Model identifier.
pub id: String,
/// Object-type tag of the entry (a free-form string).
pub object: String,
/// Owner of the model, if reported; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub owned_by: Option<String>,
/// Gateway extensions: which node(s) host this model.
#[serde(skip_serializing_if = "Option::is_none")]
pub locations: Option<Vec<super::node::ModelLocation>>,
/// Every other key of the model object, collected via `flatten`.
#[serde(flatten)]
pub extra: Value,
}

View File

@@ -0,0 +1,114 @@
//! Translation between OpenAI and Anthropic request/response envelopes.
//!
//! This is a stateless transformation — no context is carried between requests.
use crate::anthropic::{
AnthropicContent, AnthropicMessage, AnthropicUsage, ContentBlock, MessagesRequest,
MessagesResponse, SystemPrompt,
};
use crate::openai::{
ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, Usage,
MessageContent,
};
use serde_json::{json, Value};
/// Convert an Anthropic Messages request into an OpenAI ChatCompletion request.
///
/// Mapping:
/// - `system` becomes a leading message with role `"system"`. A plain-text
///   prompt is used as is. A block-form prompt is reduced to the `text` of
///   its blocks, joined with blank lines; blocks without a string `text`
///   key are skipped. (Previously the block array was serialized to a JSON
///   string, so the model received raw JSON as its system prompt.)
/// - Each message keeps its role. Plain-text content is passed through; a
///   single `text` block is unwrapped to its text; any other block list is
///   forwarded unchanged as content parts.
/// - `temperature`, `top_p`, `stream` and the flattened `extra` fields are
///   copied over; `max_tokens` is always set.
pub fn anthropic_to_openai(req: MessagesRequest) -> ChatCompletionRequest {
    let mut messages = Vec::with_capacity(req.messages.len() + 1);

    // Anthropic `system` field becomes a system message.
    if let Some(system) = req.system {
        let content = match system {
            SystemPrompt::Text(t) => t,
            SystemPrompt::Blocks(blocks) => blocks
                .iter()
                .filter_map(|block| block.get("text").and_then(Value::as_str))
                .collect::<Vec<_>>()
                .join("\n\n"),
        };
        messages.push(ChatMessage {
            role: "system".into(),
            content: MessageContent::Text(content),
            extra: Value::Null,
        });
    }

    // Convert message roles and content.
    for msg in req.messages {
        let content = match msg.content {
            AnthropicContent::Text(t) => MessageContent::Text(t),
            AnthropicContent::Blocks(blocks) => {
                // For simple text-only blocks, extract the text.
                // For mixed content (images, etc.), pass as parts.
                if blocks.len() == 1 && blocks[0].block_type == "text" {
                    let text = blocks[0]
                        .data
                        .get("text")
                        .and_then(|v| v.as_str())
                        .unwrap_or("")
                        .to_string();
                    MessageContent::Text(text)
                } else {
                    MessageContent::Parts(blocks.into_iter().map(|b| json!(b)).collect())
                }
            }
        };
        messages.push(ChatMessage {
            role: msg.role,
            content,
            extra: Value::Null,
        });
    }

    ChatCompletionRequest {
        model: req.model,
        messages,
        temperature: req.temperature,
        top_p: req.top_p,
        max_tokens: Some(req.max_tokens),
        stream: req.stream,
        extra: req.extra,
    }
}
/// Convert an OpenAI ChatCompletion response into an Anthropic Messages response.
///
/// Only the first choice is used; any further choices are dropped. The result
/// always contains exactly one `text` content block:
/// - string content is copied as is;
/// - multi-part content is reduced to the `text` of its parts, concatenated
///   in order (parts without a string `text` key are skipped). Previously
///   the part array was serialized to a JSON string and shown to the client
///   as the reply text;
/// - with no choices at all the text is empty and `stop_reason` is `None`.
///
/// `finish_reason` is mapped `"stop"` → `"end_turn"` and `"length"` →
/// `"max_tokens"`; any other value is passed through unchanged. Missing usage
/// is reported as zero tokens.
pub fn openai_to_anthropic(resp: ChatCompletionResponse) -> MessagesResponse {
    let choice = resp.choices.into_iter().next();
    let (content_text, stop_reason) = match choice {
        Some(c) => {
            let text = match c.message.content {
                MessageContent::Text(t) => t,
                MessageContent::Parts(parts) => parts
                    .iter()
                    .filter_map(|part| part.get("text").and_then(Value::as_str))
                    .collect::<String>(),
            };
            let stop = c.finish_reason.map(|r| match r.as_str() {
                "stop" => "end_turn".to_string(),
                "length" => "max_tokens".to_string(),
                other => other.to_string(),
            });
            (text, stop)
        }
        None => (String::new(), None),
    };

    let usage = resp.usage.unwrap_or(Usage {
        prompt_tokens: 0,
        completion_tokens: 0,
        total_tokens: 0,
    });

    MessagesResponse {
        id: resp.id,
        response_type: "message".into(),
        role: "assistant".into(),
        content: vec![ContentBlock {
            block_type: "text".into(),
            data: json!({ "text": content_text }),
        }],
        model: resp.model,
        stop_reason,
        usage: AnthropicUsage {
            input_tokens: usage.prompt_tokens,
            output_tokens: usage.completion_tokens,
        },
        extra: Value::Null,
    }
}

View File

@@ -0,0 +1,25 @@
[package]
name = "cortex-gateway"
version.workspace = true
edition.workspace = true
license.workspace = true
[dependencies]
cortex-core.workspace = true
tokio.workspace = true
axum.workspace = true
tower.workspace = true
tower-http.workspace = true
serde.workspace = true
serde_json.workspace = true
reqwest.workspace = true
tracing.workspace = true
metrics.workspace = true
metrics-exporter-prometheus.workspace = true
chrono.workspace = true
anyhow.workspace = true
thiserror.workspace = true
futures.workspace = true
tokio-stream.workspace = true
eventsource-stream.workspace = true
bytes = "1"

View File

@@ -0,0 +1,106 @@
//! Model eviction logic.
//!
//! The evictor runs as a background task. When the router determines that a
//! model needs to be loaded on a node but VRAM is tight, it can request
//! eviction via a channel. The evictor then:
//! 1. Identifies the LRU model on that node (excluding pinned models)
//! 2. Calls `POST /v1/models/unload` on the node
//! 3. Increments the lifecycle cycle counter (for defrag tracking)
use crate::state::CortexState;
use cortex_core::node::{ModelLifecycleRequest, ModelStatus};
use std::sync::Arc;
use std::time::Duration;
/// Background eviction task; never returns.
///
/// At present this is a placeholder: it only sleeps in 30-second intervals
/// and performs no eviction itself. The real work lives in
/// `evict_lru_on_node`, which is meant to be called on demand.
pub async fn eviction_loop(fleet: Arc<CortexState>) {
    // TODO: Replace this polling approach with a channel-driven design
    // where the router sends eviction requests when it detects that a
    // model load would exceed available VRAM.
    let idle_period = Duration::from_secs(30);

    // Not used yet; bound to an underscore name so the shared state stays
    // alive for as long as the task runs without an unused warning.
    let _fleet = fleet;

    loop {
        tokio::time::sleep(idle_period).await;
    }
}
/// Evict the least-recently-used model on a given node.
/// Returns the model ID that was evicted, or None if nothing could be evicted.
pub async fn evict_lru_on_node(
fleet: &CortexState,
node_name: &str,
) -> anyhow::Result<Option<String>> {
let (endpoint, candidate) = {
let nodes = fleet.nodes.read().await;
let Some(node) = nodes.get(node_name) else {
anyhow::bail!("node '{node_name}' not found");
};
// Find the loaded model with the oldest last_accessed, excluding pinned.
let candidate = node
.models
.values()
.filter(|m| m.status == ModelStatus::Loaded)
.filter(|m| !node.pinned.contains(&m.id))
.min_by_key(|m| m.last_accessed)
.map(|m| m.id.clone());
(node.endpoint.clone(), candidate)
};
let Some(model_id) = candidate else {
tracing::info!(node = node_name, "no evictable models found");
return Ok(None);
};
tracing::info!(node = node_name, model = %model_id, "evicting model");
let url = format!("{endpoint}/v1/models/unload");
let resp = fleet
.http_client
.post(&url)
.json(&ModelLifecycleRequest {
model_id: model_id.clone(),
})
.send()
.await?;
if resp.status().is_success() {
// Update local state.
let mut nodes = fleet.nodes.write().await;
if let Some(node) = nodes.get_mut(node_name) {
if let Some(entry) = node.models.get_mut(&model_id) {
entry.status = ModelStatus::Unloaded;
}
node.lifecycle_cycles += 1;
// Check if we should flag for defrag.
if fleet.eviction.defrag_after_cycles > 0
&& node.lifecycle_cycles >= fleet.eviction.defrag_after_cycles
{
tracing::warn!(
node = node_name,
cycles = node.lifecycle_cycles,
"VRAM fragmentation threshold reached — consider restarting mistralrs"
);
}
}
tracing::info!(node = node_name, model = %model_id, "model evicted");
Ok(Some(model_id))
} else {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
tracing::error!(
node = node_name,
model = %model_id,
status = %status,
body = %body,
"failed to evict model"
);
anyhow::bail!("eviction failed: {status} {body}");
}
}

View File

@@ -0,0 +1,207 @@
//! Axum HTTP handlers for the gateway API surface.
use crate::proxy;
use crate::router;
use crate::state::CortexState;
use axum::body::Bytes;
use axum::extract::State;
use axum::http::HeaderMap;
use axum::response::{IntoResponse, Json, Response};
use axum::routing::{get, post};
use axum::Router;
use cortex_core::node::{CortexModelEntry, ModelLocation};
use cortex_core::openai::ChatCompletionRequest;
use serde_json::{json, Value};
use std::sync::Arc;
/// Build the router for the gateway's public API.
///
/// Routes: `GET /` and `GET /health` (health summary), `GET /v1/models`,
/// `POST /v1/chat/completions`, `POST /v1/completions` and
/// `POST /v1/messages`.
pub fn api_routes() -> Router<Arc<CortexState>> {
    let base: Router<Arc<CortexState>> = Router::new();

    // Health endpoints (the root path is an alias for `/health`).
    let with_health = base.route("/", get(health)).route("/health", get(health));

    // OpenAI-compatible surface.
    let with_openai = with_health
        .route("/v1/models", get(list_models))
        .route("/v1/completions", post(completions))
        .route("/v1/chat/completions", post(chat_completions));

    // Anthropic-compatible surface.
    with_openai.route("/v1/messages", post(anthropic_messages))
}
/// `POST /v1/chat/completions` — proxy to the appropriate backend node.
///
/// Thin wrapper over `proxy_openai_request`; see there for the error mapping.
async fn chat_completions(
    State(fleet): State<Arc<CortexState>>,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    proxy_openai_request(&fleet, "/v1/chat/completions", headers, body).await
}

/// `POST /v1/completions` — proxy completions endpoint.
///
/// Thin wrapper over `proxy_openai_request`; see there for the error mapping.
async fn completions(
    State(fleet): State<Arc<CortexState>>,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    proxy_openai_request(&fleet, "/v1/completions", headers, body).await
}

/// Shared implementation of the OpenAI-format proxy endpoints.
///
/// Reads the `model` field from the JSON body, resolves it to a backend
/// route and forwards the untouched request to `path` on that backend.
///
/// Error mapping (unchanged from the previously duplicated handlers):
/// - no string `model` field in the body → 400;
/// - the model cannot be resolved to a route → 404;
/// - forwarding fails → whatever response the proxy error converts into.
async fn proxy_openai_request(
    fleet: &Arc<CortexState>,
    path: &'static str,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    let Some(model_id) = extract_model(&body) else {
        return error_response(400, "missing 'model' field in request body");
    };

    let route = match router::resolve(fleet, &model_id).await {
        Ok(r) => r,
        Err(e) => return error_response(404, &e.to_string()),
    };

    match proxy::forward_request(&fleet.http_client, &route, path, headers, body).await {
        Ok(resp) => resp,
        Err(e) => e.into_response(),
    }
}
/// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back.
async fn anthropic_messages(
State(fleet): State<Arc<CortexState>>,
headers: HeaderMap,
body: Bytes,
) -> Response {
// Parse as Anthropic request.
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
Ok(r) => r,
Err(e) => return error_response(400, &format!("invalid Anthropic request: {e}")),
};
let model_id = anth_req.model.clone();
let is_streaming = anth_req.stream.unwrap_or(false);
// Translate to OpenAI format.
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
let openai_body = match serde_json::to_vec(&openai_req) {
Ok(b) => Bytes::from(b),
Err(e) => return error_response(500, &format!("translation error: {e}")),
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => return error_response(404, &e.to_string()),
};
if is_streaming {
// TODO: streaming Anthropic translation requires converting SSE format.
// For now, proxy the OpenAI SSE stream directly (clients that can handle
// OpenAI SSE will work; full Anthropic SSE translation is a follow-up).
match proxy::forward_request(
&fleet.http_client,
&route,
"/v1/chat/completions",
headers,
openai_body,
)
.await
{
Ok(resp) => resp,
Err(e) => e.into_response(),
}
} else {
// Non-streaming: proxy, await full response, translate back.
match proxy::forward_request(
&fleet.http_client,
&route,
"/v1/chat/completions",
headers,
openai_body,
)
.await
{
Ok(resp) => {
// TODO: buffer response, parse as OpenAI ChatCompletionResponse,
// translate to Anthropic MessagesResponse.
// For now, return the OpenAI response as-is.
resp
}
Err(e) => e.into_response(),
}
}
}
/// `GET /v1/models` — aggregate models from all nodes.
///
/// Produces `{"object": "list", "data": [...]}` with one entry per distinct
/// model ID; each entry lists every node on which the model is known,
/// together with its status and VRAM estimate there.
///
/// The output is deterministic: models are ordered by ID and each model's
/// locations by node name. (Previously the order followed hash-map iteration
/// and could change from one call to the next.)
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
    let nodes = fleet.nodes.read().await;

    // An ordered map keyed by model ID keeps the listing stable.
    let mut by_id: std::collections::BTreeMap<String, CortexModelEntry> =
        std::collections::BTreeMap::new();

    for node in nodes.values() {
        for (model_id, entry) in &node.models {
            by_id
                .entry(model_id.clone())
                .or_insert_with(|| CortexModelEntry {
                    id: model_id.clone(),
                    object: "model".into(),
                    locations: Vec::new(),
                })
                .locations
                .push(ModelLocation {
                    node: node.name.clone(),
                    status: entry.status,
                    vram_estimate_mb: entry.vram_estimate_mb,
                });
        }
    }

    let data: Vec<Value> = by_id
        .into_values()
        .map(|mut model| {
            model.locations.sort_by(|a, b| a.node.cmp(&b.node));
            json!(model)
        })
        .collect();

    Json(json!({
        "object": "list",
        "data": data,
    }))
}
/// `GET /health` — summarize fleet health.
///
/// Reports `"status": "ok"` when at least one node is marked healthy and
/// `"degraded"` otherwise (including when there are no nodes), along with the
/// healthy and total node counts.
async fn health(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
    let nodes = fleet.nodes.read().await;

    let total_count = nodes.len();
    let healthy_count = nodes.values().filter(|node| node.healthy).count();
    let status = if healthy_count == 0 { "degraded" } else { "ok" };

    Json(json!({
        "status": status,
        "nodes": {
            "healthy": healthy_count,
            "total": total_count,
        }
    }))
}
// ── Helpers ──────────────────────────────────────────────────────────
/// Pull the top-level `model` string out of a JSON request body.
///
/// Returns `None` when the body is not valid JSON, has no `model` key, or
/// the value under that key is not a string.
fn extract_model(body: &[u8]) -> Option<String> {
    let parsed = serde_json::from_slice::<Value>(body).ok()?;
    let model = parsed.get("model")?.as_str()?;
    Some(model.to_string())
}
/// Build a JSON error response of the form
/// `{"error": {"message": ..., "type": "gateway_error"}}`.
///
/// `status` is the numeric HTTP status; a value that is not a valid status
/// code falls back to 500 Internal Server Error.
fn error_response(status: u16, message: &str) -> Response {
    let code = match axum::http::StatusCode::from_u16(status) {
        Ok(valid) => valid,
        Err(_) => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
    };
    let payload = json!({
        "error": {
            "message": message,
            "type": "gateway_error",
        }
    });
    (code, Json(payload)).into_response()
}

View File

@@ -0,0 +1,51 @@
pub mod evictor;
pub mod handlers;
pub mod metrics;
pub mod poller;
pub mod proxy;
pub mod router;
pub mod state;
use anyhow::Result;
use axum::Router;
use cortex_core::config::GatewayConfig;
use std::sync::Arc;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
/// Assemble the complete Axum application.
///
/// Merges the API routes, wraps them in a permissive CORS layer and an HTTP
/// trace layer (applied in that order), and binds the shared gateway state.
pub fn build_app(fleet: Arc<state::CortexState>) -> Router {
    let routes: Router<Arc<state::CortexState>> = Router::new().merge(handlers::api_routes());

    let layered = routes
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http());

    layered.with_state(fleet)
}
/// Run the gateway until the HTTP server stops.
///
/// Builds the shared state from `config`, spawns the poller and evictor
/// background tasks, then binds the configured listen address and serves the
/// API on it.
///
/// # Errors
///
/// Fails if the listen address cannot be parsed, the socket cannot be bound,
/// or the server returns an error.
pub async fn run(config: GatewayConfig) -> Result<()> {
    let fleet = Arc::new(state::CortexState::from_config(&config));

    // Background poller that refreshes node/model status.
    tokio::spawn(poller::poll_loop(Arc::clone(&fleet)));

    // Evictor (reacts to VRAM pressure events from the router).
    tokio::spawn(evictor::eviction_loop(Arc::clone(&fleet)));

    let app = build_app(Arc::clone(&fleet));

    let listen_addr: std::net::SocketAddr = config.gateway.listen.parse()?;
    tracing::info!("cortex listening on {listen_addr}");

    let listener = tokio::net::TcpListener::bind(listen_addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}

View File

@@ -0,0 +1,55 @@
//! Prometheus metrics exporter.
//!
//! Runs on a separate port from the main API, exposing `/metrics`
//! in Prometheus text format.
use anyhow::Result;
use metrics_exporter_prometheus::PrometheusBuilder;
use std::net::SocketAddr;
/// Install the Prometheus metrics recorder and start its HTTP listener.
///
/// `listen` is parsed as a socket address and handed to the exporter's
/// built-in HTTP listener. After installation the metrics used by the proxy
/// layer are described so that Prometheus gets proper HELP/TYPE lines.
///
/// # Errors
///
/// Fails if `listen` is not a valid socket address or the exporter cannot be
/// installed.
pub fn install(listen: &str) -> Result<()> {
    let addr = listen.parse::<SocketAddr>()?;

    let exporter = PrometheusBuilder::new().with_http_listener(addr);
    if let Err(e) = exporter.install() {
        return Err(anyhow::anyhow!(
            "failed to install Prometheus exporter: {e}"
        ));
    }
    tracing::info!("prometheus metrics exporter on {addr}");

    // The `metrics` crate creates metrics lazily on first use; describing
    // them up front only attaches the help text.

    // Counters.
    metrics::describe_counter!(
        "cortex_requests_total",
        "Total number of proxied requests"
    );
    metrics::describe_counter!(
        "cortex_request_errors_total",
        "Total number of failed proxy requests"
    );
    metrics::describe_counter!(
        "cortex_evictions_total",
        "Total number of model evictions"
    );
    metrics::describe_counter!(
        "cortex_cold_starts_total",
        "Total number of cold-start model loads"
    );

    // Histograms.
    metrics::describe_histogram!(
        "cortex_request_duration_seconds",
        "Total request latency in seconds"
    );
    metrics::describe_histogram!(
        "cortex_time_to_first_token_seconds",
        "Time to first token in seconds"
    );
    metrics::describe_histogram!(
        "cortex_tokens_per_second",
        "Generation throughput in tokens per second"
    );

    Ok(())
}

View File

@@ -0,0 +1,103 @@
//! Background poller that periodically queries each node's `/v1/models`
//! endpoint to refresh the fleet state.
use crate::state::CortexState;
use chrono::Utc;
use cortex_core::node::{MistralModelsResponse, ModelEntry, ModelStatus};
use std::sync::Arc;
use std::time::Duration;
/// Pause between two polling rounds. The poll loop sleeps for this long after
/// it has finished querying every node, so a full cycle takes the interval
/// plus however long the round itself took.
const POLL_INTERVAL: Duration = Duration::from_secs(10);
/// Never returns: repeatedly polls every configured node, one after another,
/// then sleeps for [`POLL_INTERVAL`] before starting the next round.
pub async fn poll_loop(fleet: Arc<CortexState>) {
    loop {
        for node_cfg in fleet.node_configs.iter() {
            let node_name = node_cfg.name.as_str();
            let node_endpoint = node_cfg.endpoint.as_str();
            poll_node(&fleet, node_name, node_endpoint).await;
        }

        tokio::time::sleep(POLL_INTERVAL).await;
    }
}
/// Poll a single node's `/v1/models` endpoint and fold the result into the
/// shared fleet state.
///
/// The function works in two phases so that the `fleet.nodes` write lock is
/// never held across network I/O:
///
/// 1. Send the request (5 second timeout) and decode the JSON body.
/// 2. Take the write lock and apply the outcome to the node's entry.
///
/// On success the node's model map is reconciled with the upstream list
/// (existing entries keep their local metadata, new ones are inserted, and
/// models no longer reported are dropped), the node is marked healthy and
/// `last_poll` is refreshed. On any failure the node is marked unhealthy and
/// its previously known models are left untouched.
///
/// If `name` has no entry in `fleet.nodes` the state is not modified.
async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
    // Tolerate a configured endpoint that ends with `/` so we never request
    // `//v1/models`.
    let url = format!("{}/v1/models", endpoint.trim_end_matches('/'));

    // Phase 1: network I/O only. `None` means "the node is unhealthy"; the
    // reason has already been logged by the time we leave this match.
    let fetched = match fleet
        .http_client
        .get(&url)
        .timeout(Duration::from_secs(5))
        .send()
        .await
    {
        Ok(resp) if resp.status().is_success() => {
            match resp.json::<MistralModelsResponse>().await {
                Ok(models_resp) => Some(models_resp),
                Err(e) => {
                    tracing::warn!(node = name, error = %e, "failed to parse /v1/models response");
                    None
                }
            }
        }
        Ok(resp) => {
            tracing::warn!(
                node = name,
                status = %resp.status(),
                "node returned non-success status"
            );
            None
        }
        Err(e) => {
            tracing::warn!(node = name, error = %e, "failed to reach node");
            None
        }
    };

    // Phase 2: short critical section, no awaits on the network from here on.
    let mut nodes = fleet.nodes.write().await;
    let Some(node) = nodes.get_mut(name) else {
        return;
    };

    let Some(models_resp) = fetched else {
        node.healthy = false;
        return;
    };

    // Merge the upstream model list into our state, preserving local metadata
    // (last_accessed, vram_estimate) on entries we already know about.
    let mut seen = std::collections::HashSet::new();
    for upstream in &models_resp.data {
        seen.insert(upstream.id.clone());
        let status = parse_status(upstream.status.as_deref());
        node.models
            .entry(upstream.id.clone())
            .and_modify(|e| {
                e.status = status;
            })
            .or_insert_with(|| ModelEntry {
                id: upstream.id.clone(),
                status,
                last_accessed: None,
                vram_estimate_mb: None,
            });
    }

    // Drop models the node no longer reports (e.g. after a config change or
    // restart).
    node.models.retain(|id, _| seen.contains(id));
    node.healthy = true;
    node.last_poll = Some(Utc::now());
    tracing::debug!(
        node = name,
        models = models_resp.data.len(),
        "poll ok"
    );
}
/// Translate the optional upstream `status` string into a [`ModelStatus`].
///
/// `"unloaded"` and `"reloading"` map to their respective variants. Everything
/// else is treated as `Loaded`: the literal `"loaded"`, a missing field (older
/// mistral.rs versions may not include it), and any unrecognised string.
fn parse_status(s: Option<&str>) -> ModelStatus {
    let Some(raw) = s else {
        return ModelStatus::Loaded;
    };

    match raw {
        "unloaded" => ModelStatus::Unloaded,
        "reloading" => ModelStatus::Reloading,
        _ => ModelStatus::Loaded,
    }
}

View File

@@ -0,0 +1,82 @@
//! Streaming HTTP reverse proxy to mistral.rs backends.
//!
//! For streaming requests, SSE chunks are forwarded as they arrive.
//! The proxy captures timing information for metrics but does not
//! buffer the full response.
use crate::router::RouteDecision;
use anyhow::Result;
use axum::body::Body;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use reqwest::Client;
/// Returns `true` for hop-by-hop headers: they describe a single transport
/// connection and must not be relayed by a proxy in either direction.
/// `name` is expected in lowercase, which is what `HeaderName::as_str` yields.
fn is_hop_by_hop(name: &str) -> bool {
    matches!(
        name,
        "connection"
            | "keep-alive"
            | "proxy-authenticate"
            | "proxy-authorization"
            | "proxy-connection"
            | "te"
            | "trailer"
            | "transfer-encoding"
            | "upgrade"
    )
}

/// Proxy a request body to the resolved backend node and stream the response.
///
/// The request is sent as a `POST` to `route.endpoint` joined with `path`.
/// Inbound headers are forwarded except `host` and `content-length` (set by
/// `reqwest` for the new request) and hop-by-hop headers. The upstream status
/// and end-to-end headers are copied onto the returned response, whose body
/// is the upstream byte stream passed through without buffering.
///
/// # Errors
///
/// * [`ProxyError::Upstream`] if sending the request to the backend fails.
/// * [`ProxyError::ResponseBuild`] if the downstream response cannot be
///   assembled.
pub async fn forward_request(
    client: &Client,
    route: &RouteDecision,
    path: &str,
    headers: HeaderMap,
    body: bytes::Bytes,
) -> Result<Response, ProxyError> {
    // Tolerate an endpoint configured with a trailing `/` so the joined URL
    // does not contain `//`.
    let url = format!("{}{}", route.endpoint.trim_end_matches('/'), path);
    tracing::info!(
        node = %route.node_name,
        url = %url,
        cold_start = route.cold_start,
        "proxying request"
    );

    let mut req_builder = client.post(&url).body(body);
    // Forward relevant headers.
    for (key, value) in headers.iter() {
        let name = key.as_str();
        if name == "host" || name == "content-length" {
            continue; // reqwest sets these
        }
        if is_hop_by_hop(name) {
            continue; // connection-scoped, meaningless on the upstream hop
        }
        req_builder = req_builder.header(key, value);
    }

    let upstream_resp = req_builder.send().await.map_err(ProxyError::Upstream)?;
    let status = StatusCode::from_u16(upstream_resp.status().as_u16())
        .unwrap_or(StatusCode::BAD_GATEWAY);

    // Copy headers before the response is consumed by `bytes_stream()`.
    // Hop-by-hop headers such as `transfer-encoding` describe the upstream
    // connection's framing; the server re-frames the streamed body itself.
    let mut response = Response::builder().status(status);
    for (key, value) in upstream_resp.headers().iter() {
        if is_hop_by_hop(key.as_str()) {
            continue;
        }
        response = response.header(key, value);
    }

    let body = Body::from_stream(upstream_resp.bytes_stream());
    response
        .body(body)
        .map_err(|e| ProxyError::ResponseBuild(e.to_string()))
}
/// Failures that can occur while proxying a request to a backend node.
#[derive(Debug, thiserror::Error)]
pub enum ProxyError {
/// Sending the request to the upstream node failed; wraps the `reqwest` error.
#[error("upstream request failed: {0}")]
Upstream(reqwest::Error),
/// The downstream response could not be assembled; carries the builder's
/// error rendered as text.
#[error("failed to build response: {0}")]
ResponseBuild(String),
}
/// Renders a [`ProxyError`] as a JSON error payload: upstream failures become
/// `502 Bad Gateway`, response-construction failures `500 Internal Server Error`.
impl IntoResponse for ProxyError {
    fn into_response(self) -> Response {
        // Capture the display text first; the match below consumes `self`.
        let message = self.to_string();

        let code = match self {
            ProxyError::ResponseBuild(_) => StatusCode::INTERNAL_SERVER_ERROR,
            ProxyError::Upstream(_) => StatusCode::BAD_GATEWAY,
        };

        let payload = serde_json::json!({
            "error": {
                "message": message,
                "type": "proxy_error",
            }
        });

        (code, axum::Json(payload)).into_response()
    }
}

View File

@@ -0,0 +1,74 @@
//! Model-to-node routing logic.
//!
//! Given a model ID from an inbound request, determine which node should
//! handle it. Priority:
//! 1. Node where the model is currently `Loaded` (or `Reloading`)
//! 2. Node where the model is `Unloaded` (will lazy-load on request)
//! 3. Error: model not found on any node
use crate::state::CortexState;
use cortex_core::node::ModelStatus;
use std::sync::Arc;
/// The routing decision: which node endpoint to proxy the request to.
#[derive(Debug, Clone)]
pub struct RouteDecision {
/// Name of the selected node, copied from its fleet-state entry.
pub node_name: String,
/// Endpoint of the selected node, copied from its fleet-state entry.
pub endpoint: String,
/// Whether the model will need to load (cold start). Set when the chosen
/// node reports the model as `Unloaded`.
pub cold_start: bool,
}
/// Reasons a model could not be routed to any node.
#[derive(Debug, thiserror::Error)]
pub enum RouteError {
/// At least one node is healthy, but no healthy node lists the requested
/// model. Carries the requested model ID.
#[error("model '{0}' not found on any node")]
ModelNotFound(String),
/// No node is currently marked healthy.
#[error("no healthy nodes available")]
NoHealthyNodes,
}
/// Pick the node that should serve a request for `model_id`.
///
/// Only healthy nodes are considered. A node that reports the model as
/// `Loaded` or `Reloading` wins (`cold_start = false`); otherwise the first
/// node that reports it as `Unloaded` is chosen (`cold_start = true`).
///
/// # Errors
///
/// * [`RouteError::ModelNotFound`] if some node is healthy but none of the
///   healthy nodes lists the model.
/// * [`RouteError::NoHealthyNodes`] if no node is healthy at all.
pub async fn resolve(fleet: &Arc<CortexState>, model_id: &str) -> Result<RouteDecision, RouteError> {
    let nodes = fleet.nodes.read().await;

    // First healthy node (in map iteration order) whose entry for the model
    // is in the wanted state: unloaded when `cold_start`, resident otherwise.
    let pick = |cold_start: bool| {
        nodes
            .values()
            .filter(|node| node.healthy)
            .find(|node| {
                match node.models.get(model_id).map(|entry| &entry.status) {
                    Some(ModelStatus::Unloaded) => cold_start,
                    Some(ModelStatus::Loaded | ModelStatus::Reloading) => !cold_start,
                    None => false,
                }
            })
            .map(|node| RouteDecision {
                node_name: node.name.clone(),
                endpoint: node.endpoint.clone(),
                cold_start,
            })
    };

    if let Some(decision) = pick(false).or_else(|| pick(true)) {
        return Ok(decision);
    }

    // Nothing matched: distinguish "unknown model" from "fleet is down".
    let any_healthy = nodes.values().any(|n| n.healthy);
    Err(if any_healthy {
        RouteError::ModelNotFound(model_id.to_string())
    } else {
        RouteError::NoHealthyNodes
    })
}

View File

@@ -0,0 +1,43 @@
use cortex_core::config::{EvictionSettings, GatewayConfig, NodeConfig};
use cortex_core::node::NodeState;
use std::collections::HashMap;
use tokio::sync::RwLock;
/// Shared fleet state. Only the per-node runtime state (`nodes`) is behind a
/// `RwLock`; the remaining fields are plain values set at construction.
pub struct CortexState {
/// Runtime state of each node, keyed by node name.
pub nodes: RwLock<HashMap<String, NodeState>>,
/// Node definitions copied from the gateway configuration.
pub node_configs: Vec<NodeConfig>,
/// Eviction settings copied from the gateway configuration.
pub eviction: EvictionSettings,
/// HTTP client stored alongside the fleet state for requests to the nodes.
pub http_client: reqwest::Client,
}
impl CortexState {
pub fn from_config(config: &GatewayConfig) -> Self {
let mut nodes = HashMap::new();
for nc in &config.nodes {
nodes.insert(
nc.name.clone(),
NodeState {
name: nc.name.clone(),
endpoint: nc.endpoint.clone(),
vram_mb: nc.vram_mb,
pinned: nc.pinned.clone(),
healthy: false, // will be set by first poll
models: HashMap::new(),
lifecycle_cycles: 0,
last_poll: None,
},
);
}
Self {
nodes: RwLock::new(nodes),
node_configs: config.nodes.clone(),
eviction: config.eviction.clone(),
http_client: reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300))
.build()
.expect("failed to build HTTP client"),
}
}
}