feat: scaffold cortex workspace

Rust reverse-proxy for multi-node mistral.rs inference clusters.
Includes crate structure (cortex-core, cortex-gateway, cortex-agent,
cortex-cli), config loading, OpenAI/Anthropic translation stubs,
model routing, eviction, polling, and streaming proxy scaffolding.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-14 18:13:30 +03:00
commit 0da68833af
28 changed files with 4659 additions and 0 deletions

6
.gitignore vendored Normal file
View File

@@ -0,0 +1,6 @@
/target
*.swp
*.swo
.idea/
.vscode/
cortex.toml

141
CLAUDE.md Normal file
View File

@@ -0,0 +1,141 @@
# CLAUDE.md — cortex
## Project overview
cortex is a Rust reverse-proxy that sits in front of multiple
mistral.rs inference nodes and presents a unified OpenAI + Anthropic
compatible API surface. It handles model routing, lifecycle management
(load/unload/evict), request translation, and metrics collection.
## Repository layout
```
cortex/
├── Cargo.toml # workspace root
├── cortex.toml # example gateway config
├── README.md
├── CLAUDE.md # ← you are here
├── crates/
│ ├── cortex-core/ # shared types, config, envelopes
│ │ └── src/
│ │ ├── lib.rs
│ │ ├── config.rs # figment-based config structs
│ │ ├── node.rs # NodeState, ModelStatus
│ │ ├── openai.rs # OpenAI request/response types
│ │ ├── anthropic.rs # Anthropic request/response types
│ │ ├── translate.rs # OpenAI <-> Anthropic translation
│ │ └── metrics.rs # RequestMetrics, histogram helpers
│ ├── cortex-gateway/ # the HTTP proxy server
│ │ └── src/
│ │ ├── lib.rs
│ │ ├── state.rs # CortexState: Arc<RwLock<...>>
│ │ ├── router.rs # model -> node routing logic
│ │ ├── proxy.rs # streaming HTTP proxy to backends
│ │ ├── evictor.rs # LRU/priority eviction logic
│ │ ├── poller.rs # background task polling node status
│ │ ├── handlers.rs # axum handlers (chat, completions, models, etc.)
│ │ └── metrics.rs # prometheus exporter endpoint
│ ├── cortex-agent/ # per-node sidecar (future: defrag, restart)
│ │ └── src/
│ │ ├── lib.rs
│ │ └── agent.rs # local node management
│ └── cortex-cli/ # CLI entrypoint
│ └── src/
│ └── main.rs
└── tests/ # integration tests (future)
```
## Key design decisions
### mistral.rs HTTP API for model lifecycle
mistral.rs (v0.8+) supports dynamic model loading/unloading at runtime:
- `POST /v1/models/unload {"model_id": "..."}` — frees VRAM, preserves config
- `POST /v1/models/reload {"model_id": "..."}` — explicitly reload
- `POST /v1/models/status {"model_id": "..."}` — loaded/unloaded/reloading
- `GET /v1/models` — lists all models with status field
- Lazy loading: requests to unloaded models trigger automatic reload
The gateway does NOT manage systemd units for model swaps. It calls these
HTTP endpoints directly. The only systemd interaction is for full-process
restarts after VRAM fragmentation accumulates (defrag_after_cycles).
### Streaming proxy
Chat completions are proxied as SSE streams. The gateway must:
1. Parse the inbound request to extract the model name
2. Route to the correct backend node
3. Stream the response back, capturing token timing for metrics
4. NOT buffer the full response — true streaming passthrough
### Anthropic translation
When a request arrives at `/v1/messages` (Anthropic format), the gateway
translates it to OpenAI format before proxying to mistral.rs, then
translates the response back. This is stateless envelope transformation.
### Eviction
The evictor runs as a background task. Before loading a model on a node
where VRAM is tight:
1. Check if the model is already loaded elsewhere → route there instead
2. Find the LRU model on the target node (excluding pinned models)
3. Call `/v1/models/unload` on that model
4. The incoming request's lazy-load triggers the new model load
### Metrics
Per-request: model, node, prompt_tokens, completion_tokens, total_tokens,
tok_per_sec, time_to_first_token_ms, total_latency_ms.
Exposed as Prometheus histograms/counters on a separate port.
## Tech stack
- **Rust 2024 edition** — workspace with 4 crates
- **Axum 0.8** — HTTP framework (same as mistral.rs itself)
- **reqwest** — HTTP client for proxying to backends
- **figment** — config loading (TOML + env vars)
- **tokio** — async runtime
- **metrics + metrics-exporter-prometheus** — observability
- **tracing** — structured logging
## Build commands
```sh
cargo build --release # build all crates
cargo run -p cortex-cli -- serve # run the gateway
cargo test # run all tests
cargo clippy --workspace # lint
```
## Environment
- Targets Fedora 43 (systemd, SELinux enforcing)
- Nodes communicate over a private network (e.g. WireGuard mesh)
- One or more GPU nodes running mistral.rs on port 8080
- Optionally a metrics-only node (no GPU) for Prometheus/Grafana
- Each node runs `mistralrs serve` on port 8080
- Gateway listens on port 8000 (API) and 9100 (metrics)
- TLS terminated at gateway or via nginx; internal traffic is plaintext over WireGuard
## Conventions
- Error handling: `anyhow` for binaries, `thiserror` for library crates
- No `unwrap()` in library code; `expect()` only with clear rationale
- All public types derive `Debug, Clone, Serialize, Deserialize` where sensible
- Config structs use `figment` with TOML as primary source, env vars as override
- Prefer `Arc<RwLock<...>>` for shared fleet state; minimize lock duration
- SSE streaming uses `tokio_stream` + `eventsource-stream` for parsing
- Log at `info` for request routing, `debug` for proxy details, `warn` for
eviction and node health, `error` for proxy failures
## Current status
**Scaffold phase** — crate structure, types, and handler stubs are in place.
The following needs implementation:
1. **cortex-core**: Flesh out OpenAI/Anthropic envelope types with all fields
needed for chat completions (streaming + non-streaming)
2. **cortex-gateway/proxy.rs**: Implement streaming HTTP proxy with SSE passthrough
3. **cortex-gateway/router.rs**: Model-to-node routing with fallback to least-loaded
4. **cortex-gateway/evictor.rs**: LRU eviction with pinning support
5. **cortex-gateway/poller.rs**: Background polling of node `/v1/models` endpoints
6. **cortex-gateway/handlers.rs**: Wire up axum routes to proxy logic
7. **cortex-core/translate.rs**: OpenAI <-> Anthropic request/response translation
8. **cortex-agent**: Sidecar for VRAM defrag restarts (lower priority)
9. **Integration tests**: Mock mistralrs backends, test routing + eviction

2787
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

57
Cargo.toml Normal file
View File

@@ -0,0 +1,57 @@
[workspace]
resolver = "2"
members = [
"crates/cortex-core",
"crates/cortex-gateway",
"crates/cortex-agent",
"crates/cortex-cli",
]
[workspace.package]
version = "0.1.0"
edition = "2024"
license = "GPL-3.0"
repository = "https://git.lair.cafe/helexa/cortex"
[workspace.dependencies]
# async runtime
tokio = { version = "1", features = ["full"] }
# web framework
axum = { version = "0.8", features = ["macros"] }
tower = "0.5"
tower-http = { version = "0.6", features = ["cors", "trace", "timeout"] }
# serialization
serde = { version = "1", features = ["derive"] }
serde_json = "1"
toml = "0.8"
# http client (for proxying to mistralrs backends)
reqwest = { version = "0.12", features = ["json", "stream"] }
# observability
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
metrics = "0.24"
metrics-exporter-prometheus = "0.16"
# time
chrono = { version = "0.4", features = ["serde"] }
# config
figment = { version = "0.10", features = ["toml", "env"] }
# error handling
anyhow = "1"
thiserror = "2"
# futures / streams (for SSE proxying)
futures = "0.3"
tokio-stream = "0.1"
eventsource-stream = "0.2"
# workspace crates
cortex-core = { path = "crates/cortex-core" }
cortex-gateway = { path = "crates/cortex-gateway" }
cortex-agent = { path = "crates/cortex-agent" }

138
README.md Normal file
View File

@@ -0,0 +1,138 @@
# cortex
A Rust reverse-proxy and fleet management layer for multi-node
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) inference clusters.
## Problem
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
model affinities) requires a unified API surface that:
- Presents a **single `/v1/models` catalogue** merging every model across every
node.
- **Routes requests** to the correct node based on where a model is loaded (or
*can* be loaded).
- Manages **model lifecycle** — unload cold models, reload on demand, pin
critical ones — using the mistral.rs
`/v1/models/{unload,reload,status}` HTTP API (PR #1828+).
- Translates between **OpenAI and Anthropic** request/response envelopes so
every client in the homelab speaks whichever dialect it prefers.
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
them as Prometheus counters/histograms.
## Architecture
```
┌──────────────┐ ┌──────────┐ ┌────────────┐ ┌────────────┐
│ Claude Code │ │ Zed/IDE │ │ Tidal / mm │ │ curl / etc │
└──────┬───────┘ └─────┬────┘ └──────┬─────┘ └──────┬─────┘
│ │ │ │
└────────────────┴──────┬───────┴───────────────┘
┌──────────▼──────────┐
│ cortex │
│ (cortex-gateway) │
│ │
│ Router · Metrics │
│ Evictor · Translate│
└──┬──────┬────────┬──┘
│ │ │
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
│ gpu-large │ │gpu-med │ │ gpu-small │
│ mistralrs │ │mistral │ │ mistralrs │
│ serve │ │rs serve│ │ serve │
│ :8080 │ │ :8080 │ │ :8080 │
└───────────┘ └────────┘ └───────────┘
private network (.internal)
```
### Crates
| Crate | Purpose |
|---|---|
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic request/response envelopes |
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, metrics exporter |
| `cortex-agent` | Per-node sidecar: polls local mistralrs, reports to gateway, handles restart/defrag |
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
## Node setup
Each GPU node runs `mistralrs serve` with a multi-model config. Models are
declared but start **unloaded** — mistral.rs lazy-loads on first request and
the gateway can explicitly unload/reload via the HTTP API.
Example node systemd unit:
```ini
# /etc/systemd/system/mistralrs.service
[Unit]
Description=mistral.rs inference server
After=network-online.target
Wants=network-online.target
[Service]
Type=simple
ExecStart=/usr/local/bin/mistralrs serve \
--from-config /etc/mistralrs/config.toml \
--port 8080
Restart=on-failure
RestartSec=5
Environment=CUDA_VISIBLE_DEVICES=0,1
[Install]
WantedBy=multi-user.target
```
## Gateway config
```toml
# cortex.toml
[gateway]
listen = "0.0.0.0:8000"
metrics_listen = "0.0.0.0:9100"
[eviction]
strategy = "lru" # lru | priority
defrag_after_cycles = 50
[[nodes]]
name = "gpu-large"
endpoint = "http://gpu-large.internal:8080"
vram_mb = 49_152 # e.g. 2x RTX 4090
pinned = ["your-org/large-model"]
[[nodes]]
name = "gpu-medium"
endpoint = "http://gpu-medium.internal:8080"
vram_mb = 24_576 # e.g. RTX 4090
pinned = ["your-org/medium-model"]
[[nodes]]
name = "gpu-small"
endpoint = "http://gpu-small.internal:8080"
vram_mb = 12_288 # e.g. RTX 3060
pinned = ["your-org/embedding-model"]
```
## Building
```sh
cargo build --release
```
## Running
```sh
# start the gateway
cortex serve --config cortex.toml
# check fleet status
cortex status
# list all models across nodes
curl http://localhost:8000/v1/models
```
## License
GPL-3.0

45
cortex.example.toml Normal file
View File

@@ -0,0 +1,45 @@
# cortex.example.toml — example configuration
#
# Copy to cortex.toml and adjust for your environment.
#
# Environment variable overrides use CORTEX_ prefix with __ separators:
# CORTEX_GATEWAY__LISTEN=0.0.0.0:9000
[gateway]
listen = "0.0.0.0:8000"
metrics_listen = "0.0.0.0:9100"
[eviction]
strategy = "lru"
# Restart mistralrs after this many load/unload cycles to defragment VRAM.
# Set to 0 to disable.
defrag_after_cycles = 50
# -- Nodes ---------------------------------------------------------------
# Each [[nodes]] entry declares a mistral.rs instance in the fleet.
# Models are discovered by polling the node's /v1/models endpoint.
# Pinned models are never evicted.
[[nodes]]
name = "gpu-large"
endpoint = "http://gpu-large.internal:8080"
vram_mb = 49152 # e.g. 2x RTX 4090 (48 GB combined)
pinned = [
"your-org/large-model",
]
[[nodes]]
name = "gpu-medium"
endpoint = "http://gpu-medium.internal:8080"
vram_mb = 24576 # e.g. RTX 4090 (24 GB)
pinned = [
"your-org/medium-model",
]
[[nodes]]
name = "gpu-small"
endpoint = "http://gpu-small.internal:8080"
vram_mb = 12288 # e.g. RTX 3060 (12 GB)
pinned = [
"your-org/embedding-model",
]

View File

@@ -0,0 +1,14 @@
[package]
name = "cortex-agent"
version.workspace = true
edition.workspace = true
license.workspace = true
[dependencies]
cortex-core.workspace = true
tokio.workspace = true
serde.workspace = true
serde_json.workspace = true
reqwest.workspace = true
tracing.workspace = true
anyhow.workspace = true

View File

@@ -0,0 +1,72 @@
//! Per-node agent sidecar.
//!
//! This is a future component that runs on each GPU node alongside mistralrs.
//! It handles:
//! - VRAM defragmentation (restarting the mistralrs systemd unit when the
//! gateway signals that lifecycle_cycles has exceeded the threshold)
//! - Local nvidia-smi polling for actual VRAM usage reporting
//! - Systemd unit management for mistralrs process restarts
//!
//! For now this is a stub. The gateway's poller + evictor handle the critical
//! path (model lifecycle via the mistralrs HTTP API). The agent adds
//! operational niceties that can be built incrementally.
/// Placeholder for agent configuration.
#[derive(Debug, Clone)]
pub struct AgentConfig {
/// The local mistralrs endpoint to monitor.
pub mistralrs_endpoint: String,
/// The systemd unit name for mistralrs (e.g. "mistralrs.service").
pub systemd_unit: String,
}
/// Restart the local mistralrs process via systemd.
/// This is the nuclear option for VRAM defragmentation.
pub async fn restart_mistralrs(config: &AgentConfig) -> anyhow::Result<()> {
tracing::warn!(
unit = %config.systemd_unit,
"restarting mistralrs for VRAM defragmentation"
);
let output = tokio::process::Command::new("systemctl")
.args(["restart", &config.systemd_unit])
.output()
.await?;
if output.status.success() {
tracing::info!(unit = %config.systemd_unit, "mistralrs restarted successfully");
Ok(())
} else {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("systemctl restart failed: {stderr}");
}
}
/// Query nvidia-smi for current VRAM usage on this node.
/// Returns (used_mb, total_mb) for each GPU.
pub async fn query_vram() -> anyhow::Result<Vec<(u64, u64)>> {
let output = tokio::process::Command::new("nvidia-smi")
.args([
"--query-gpu=memory.used,memory.total",
"--format=csv,noheader,nounits",
])
.output()
.await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("nvidia-smi failed: {stderr}");
}
let stdout = String::from_utf8_lossy(&output.stdout);
let mut gpus = Vec::new();
for line in stdout.lines() {
let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
if parts.len() == 2 {
let used: u64 = parts[0].parse().unwrap_or(0);
let total: u64 = parts[1].parse().unwrap_or(0);
gpus.push((used, total));
}
}
Ok(gpus)
}

View File

@@ -0,0 +1 @@
pub mod agent;

View File

@@ -0,0 +1,20 @@
[package]
name = "cortex-cli"
version.workspace = true
edition.workspace = true
license.workspace = true
[[bin]]
name = "cortex"
path = "src/main.rs"
[dependencies]
cortex-core.workspace = true
cortex-gateway.workspace = true
tokio.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
anyhow.workspace = true
reqwest.workspace = true
serde_json.workspace = true
clap = { version = "4", features = ["derive"] }

View File

@@ -0,0 +1,112 @@
use anyhow::Result;
use clap::{Parser, Subcommand};
use cortex_core::config::GatewayConfig;
use tracing_subscriber::EnvFilter;
#[derive(Parser)]
#[command(name = "cortex")]
#[command(about = "Unified inference gateway for multi-node mistral.rs clusters")]
#[command(version)]
struct Cli {
#[command(subcommand)]
command: Commands,
}
#[derive(Subcommand)]
enum Commands {
/// Start the gateway server.
Serve {
/// Path to the gateway config file.
#[arg(short, long, default_value = "cortex.toml")]
config: String,
},
/// Print the fleet status (models, nodes, health).
Status {
/// Gateway API endpoint to query.
#[arg(short, long, default_value = "http://localhost:8000")]
endpoint: String,
},
}
#[tokio::main]
async fn main() -> Result<()> {
// Initialize tracing with env filter (e.g. RUST_LOG=cortex_gateway=debug).
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info,cortex_gateway=debug")),
)
.init();
let cli = Cli::parse();
match cli.command {
Commands::Serve { config } => {
let cfg = GatewayConfig::load(&config).map_err(|e| {
anyhow::anyhow!("failed to load config from '{config}': {e}")
})?;
tracing::info!(
nodes = cfg.nodes.len(),
listen = %cfg.gateway.listen,
"starting cortex"
);
// Install Prometheus metrics exporter on a separate port.
cortex_gateway::metrics::install(&cfg.gateway.metrics_listen)?;
cortex_gateway::run(cfg).await?;
}
Commands::Status { endpoint } => {
print_status(&endpoint).await?;
}
}
Ok(())
}
async fn print_status(endpoint: &str) -> Result<()> {
let client = reqwest::Client::new();
// Fetch health.
let health: serde_json::Value = client
.get(format!("{endpoint}/health"))
.send()
.await?
.json()
.await?;
println!("Fleet health: {}", serde_json::to_string_pretty(&health)?);
// Fetch models.
let models: serde_json::Value = client
.get(format!("{endpoint}/v1/models"))
.send()
.await?
.json()
.await?;
println!("\nModels:");
if let Some(data) = models.get("data").and_then(|d| d.as_array()) {
for model in data {
let id = model.get("id").and_then(|v| v.as_str()).unwrap_or("?");
let locations = model
.get("locations")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|l| {
let node = l.get("node")?.as_str()?;
let status = l.get("status")?.as_str()?;
Some(format!("{node}({status})"))
})
.collect::<Vec<_>>()
.join(", ")
})
.unwrap_or_default();
println!(" {id:40} {locations}");
}
}
Ok(())
}

View File

@@ -0,0 +1,15 @@
[package]
name = "cortex-core"
version.workspace = true
edition.workspace = true
license.workspace = true
[dependencies]
serde.workspace = true
serde_json.workspace = true
toml.workspace = true
figment.workspace = true
chrono.workspace = true
anyhow.workspace = true
thiserror.workspace = true
tracing.workspace = true

View File

@@ -0,0 +1,87 @@
//! Anthropic Messages API request and response types.
//!
//! These mirror the `/v1/messages` format used by the Anthropic API.
//! The gateway accepts these, translates to OpenAI format, proxies to
//! mistral.rs, then translates the response back.
use serde::{Deserialize, Serialize};
use serde_json::Value;
// ── Messages request ─────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagesRequest {
pub model: String,
pub messages: Vec<AnthropicMessage>,
pub max_tokens: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub system: Option<SystemPrompt>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(flatten)]
pub extra: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum SystemPrompt {
Text(String),
Blocks(Vec<Value>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicMessage {
pub role: String,
pub content: AnthropicContent,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AnthropicContent {
Text(String),
Blocks(Vec<ContentBlock>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContentBlock {
#[serde(rename = "type")]
pub block_type: String,
#[serde(flatten)]
pub data: Value,
}
// ── Messages response ────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagesResponse {
pub id: String,
#[serde(rename = "type")]
pub response_type: String,
pub role: String,
pub content: Vec<ContentBlock>,
pub model: String,
pub stop_reason: Option<String>,
pub usage: AnthropicUsage,
#[serde(flatten)]
pub extra: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicUsage {
pub input_tokens: u64,
pub output_tokens: u64,
}
// ── Streaming events ─────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamEvent {
#[serde(rename = "type")]
pub event_type: String,
#[serde(flatten)]
pub data: Value,
}

View File

@@ -0,0 +1,79 @@
use figment::{
Figment,
providers::{Env, Format, Toml},
};
use serde::{Deserialize, Serialize};
use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
pub gateway: GatewaySettings,
pub eviction: EvictionSettings,
pub nodes: Vec<NodeConfig>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewaySettings {
/// Address to listen on for API requests (e.g. "0.0.0.0:8000")
pub listen: String,
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:9100")
pub metrics_listen: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvictionSettings {
/// Eviction strategy: "lru" or "priority"
pub strategy: EvictionStrategy,
/// Restart the mistralrs process after this many load/unload cycles
/// to reclaim fragmented VRAM. 0 = never.
#[serde(default)]
pub defrag_after_cycles: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EvictionStrategy {
Lru,
Priority,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
/// Human-readable node name (e.g. "gpu-large")
pub name: String,
/// Base URL of the mistralrs HTTP server (e.g. "http://gpu-large.internal:8080")
pub endpoint: String,
/// Total VRAM in MB across all GPUs on this node
pub vram_mb: u64,
/// Model IDs that should never be evicted from this node
#[serde(default)]
pub pinned: Vec<String>,
}
impl GatewayConfig {
/// Load configuration from a TOML file, with environment variable overrides.
/// Env vars are prefixed with `CORTEX_` and use `__` as a separator
/// (e.g. `CORTEX_GATEWAY__LISTEN=0.0.0.0:9000`).
pub fn load(path: impl AsRef<Path>) -> Result<Self, figment::Error> {
Figment::new()
.merge(Toml::file(path))
.merge(Env::prefixed("CORTEX_").split("__"))
.extract()
}
}
impl Default for GatewayConfig {
fn default() -> Self {
Self {
gateway: GatewaySettings {
listen: "0.0.0.0:8000".into(),
metrics_listen: "0.0.0.0:9100".into(),
},
eviction: EvictionSettings {
strategy: EvictionStrategy::Lru,
defrag_after_cycles: 50,
},
nodes: vec![],
}
}
}

View File

@@ -0,0 +1,6 @@
pub mod anthropic;
pub mod config;
pub mod metrics;
pub mod node;
pub mod openai;
pub mod translate;

View File

@@ -0,0 +1,23 @@
//! Request-level metrics captured by the gateway proxy layer.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// Metrics captured for a single proxied request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestMetrics {
pub timestamp: DateTime<Utc>,
pub model: String,
pub node: String,
pub prompt_tokens: u64,
pub completion_tokens: u64,
pub total_tokens: u64,
/// Tokens per second for the generation phase.
pub tok_per_sec: f64,
/// Time from request start to first SSE chunk (streaming) or full response.
pub time_to_first_token_ms: u64,
/// Total request latency including proxy overhead.
pub total_latency_ms: u64,
/// Whether this request triggered a model load (cold start).
pub cold_start: bool,
}

View File

@@ -0,0 +1,74 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Runtime state of a single node in the fleet.
#[derive(Debug, Clone)]
pub struct NodeState {
pub name: String,
pub endpoint: String,
pub vram_mb: u64,
pub pinned: Vec<String>,
pub healthy: bool,
pub models: HashMap<String, ModelEntry>,
/// Number of load/unload cycles since last process restart.
pub lifecycle_cycles: u32,
pub last_poll: Option<DateTime<Utc>>,
}
/// A model registered on a node, with its runtime status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelEntry {
pub id: String,
pub status: ModelStatus,
/// When this model was last used (for LRU eviction).
pub last_accessed: Option<DateTime<Utc>>,
/// Estimated VRAM usage in MB when loaded.
pub vram_estimate_mb: Option<u64>,
}
/// Model lifecycle status, matching the mistral.rs API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelStatus {
Loaded,
Unloaded,
Reloading,
}
/// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
/// Includes which node(s) host this model and their status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CortexModelEntry {
pub id: String,
pub object: String,
/// Which nodes have this model (and their status).
pub locations: Vec<ModelLocation>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLocation {
pub node: String,
pub status: ModelStatus,
pub vram_estimate_mb: Option<u64>,
}
/// Response from mistral.rs `GET /v1/models`.
/// This is the upstream format we parse when polling nodes.
#[derive(Debug, Clone, Deserialize)]
pub struct MistralModelsResponse {
pub data: Vec<MistralModelEntry>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct MistralModelEntry {
pub id: String,
#[serde(default)]
pub status: Option<String>,
}
/// Request body for mistral.rs model lifecycle endpoints.
#[derive(Debug, Clone, Serialize)]
pub struct ModelLifecycleRequest {
pub model_id: String,
}

View File

@@ -0,0 +1,122 @@
//! OpenAI-compatible request and response types.
//!
//! These are a subset sufficient for chat completions (streaming + non-streaming).
//! Fields not relevant to proxying are captured as `serde_json::Value` via
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
//! extension field mistral.rs supports.
use serde::{Deserialize, Serialize};
use serde_json::Value;
// ── Chat completion request ──────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// All other fields (tools, response_format, mistral.rs extensions, etc.)
#[serde(flatten)]
pub extra: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub content: MessageContent,
#[serde(flatten)]
pub extra: Value,
}
/// Content can be a simple string or an array of content parts (for vision).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
Text(String),
Parts(Vec<Value>),
}
// ── Chat completion response (non-streaming) ─────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatCompletionChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(flatten)]
pub extra: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionChoice {
pub index: u32,
pub message: ChatMessage,
pub finish_reason: Option<String>,
#[serde(flatten)]
pub extra: Value,
}
// ── Streaming chunk ──────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChunkChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(flatten)]
pub extra: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkChoice {
pub index: u32,
pub delta: Value,
pub finish_reason: Option<String>,
#[serde(flatten)]
pub extra: Value,
}
// ── Usage ────────────────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: u64,
pub completion_tokens: u64,
pub total_tokens: u64,
}
// ── Models list response ─────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelsResponse {
pub object: String,
pub data: Vec<ModelObject>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelObject {
pub id: String,
pub object: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub owned_by: Option<String>,
/// Gateway extensions: which node(s) host this model.
#[serde(skip_serializing_if = "Option::is_none")]
pub locations: Option<Vec<super::node::ModelLocation>>,
#[serde(flatten)]
pub extra: Value,
}

View File

@@ -0,0 +1,114 @@
//! Translation between OpenAI and Anthropic request/response envelopes.
//!
//! This is a stateless transformation — no context is carried between requests.
use crate::anthropic::{
AnthropicContent, AnthropicMessage, AnthropicUsage, ContentBlock, MessagesRequest,
MessagesResponse, SystemPrompt,
};
use crate::openai::{
ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, Usage,
MessageContent,
};
use serde_json::{json, Value};
/// Convert an Anthropic Messages request into an OpenAI ChatCompletion request.
pub fn anthropic_to_openai(req: MessagesRequest) -> ChatCompletionRequest {
let mut messages = Vec::new();
// Anthropic `system` field becomes a system message.
if let Some(system) = req.system {
let content = match system {
SystemPrompt::Text(t) => t,
SystemPrompt::Blocks(blocks) => serde_json::to_string(&blocks).unwrap_or_default(),
};
messages.push(ChatMessage {
role: "system".into(),
content: MessageContent::Text(content),
extra: Value::Null,
});
}
// Convert message roles and content.
for msg in req.messages {
let content = match msg.content {
AnthropicContent::Text(t) => MessageContent::Text(t),
AnthropicContent::Blocks(blocks) => {
// For simple text-only blocks, extract the text.
// For mixed content (images, etc.), pass as parts.
if blocks.len() == 1 && blocks[0].block_type == "text" {
let text = blocks[0]
.data
.get("text")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
MessageContent::Text(text)
} else {
MessageContent::Parts(
blocks.into_iter().map(|b| json!(b)).collect(),
)
}
}
};
messages.push(ChatMessage {
role: msg.role,
content,
extra: Value::Null,
});
}
ChatCompletionRequest {
model: req.model,
messages,
temperature: req.temperature,
top_p: req.top_p,
max_tokens: Some(req.max_tokens),
stream: req.stream,
extra: req.extra,
}
}
/// Convert an OpenAI ChatCompletion response into an Anthropic Messages response.
pub fn openai_to_anthropic(resp: ChatCompletionResponse) -> MessagesResponse {
let choice = resp.choices.into_iter().next();
let (content_text, stop_reason) = match choice {
Some(c) => {
let text = match c.message.content {
MessageContent::Text(t) => t,
MessageContent::Parts(parts) => serde_json::to_string(&parts).unwrap_or_default(),
};
let stop = c.finish_reason.map(|r| match r.as_str() {
"stop" => "end_turn".to_string(),
"length" => "max_tokens".to_string(),
other => other.to_string(),
});
(text, stop)
}
None => (String::new(), None),
};
let usage = resp.usage.unwrap_or(Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
});
MessagesResponse {
id: resp.id,
response_type: "message".into(),
role: "assistant".into(),
content: vec![ContentBlock {
block_type: "text".into(),
data: json!({ "text": content_text }),
}],
model: resp.model,
stop_reason,
usage: AnthropicUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
},
extra: Value::Null,
}
}

View File

@@ -0,0 +1,25 @@
[package]
name = "cortex-gateway"
version.workspace = true
edition.workspace = true
license.workspace = true
[dependencies]
cortex-core.workspace = true
tokio.workspace = true
axum.workspace = true
tower.workspace = true
tower-http.workspace = true
serde.workspace = true
serde_json.workspace = true
reqwest.workspace = true
tracing.workspace = true
metrics.workspace = true
metrics-exporter-prometheus.workspace = true
chrono.workspace = true
anyhow.workspace = true
thiserror.workspace = true
futures.workspace = true
tokio-stream.workspace = true
eventsource-stream.workspace = true
bytes = "1"

View File

@@ -0,0 +1,106 @@
//! Model eviction logic.
//!
//! The evictor runs as a background task. When the router determines that a
//! model needs to be loaded on a node but VRAM is tight, it can request
//! eviction via a channel. The evictor then:
//! 1. Identifies the LRU model on that node (excluding pinned models)
//! 2. Calls `POST /v1/models/unload` on the node
//! 3. Increments the lifecycle cycle counter (for defrag tracking)
use crate::state::CortexState;
use cortex_core::node::{ModelLifecycleRequest, ModelStatus};
use std::sync::Arc;
use std::time::Duration;
/// Runs forever. Currently a placeholder that periodically checks for
/// eviction opportunities. In the future, this will be driven by a
/// channel from the router when VRAM pressure is detected.
pub async fn eviction_loop(fleet: Arc<CortexState>) {
// TODO: Replace this polling approach with a channel-driven design
// where the router sends eviction requests when it detects that a
// model load would exceed available VRAM.
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
// Placeholder: the actual eviction logic is in `evict_lru_on_node`,
// called on demand by the router.
let _ = &fleet; // suppress unused warning
}
}
/// Evict the least-recently-used model on a given node.
/// Returns the model ID that was evicted, or None if nothing could be evicted.
pub async fn evict_lru_on_node(
fleet: &CortexState,
node_name: &str,
) -> anyhow::Result<Option<String>> {
let (endpoint, candidate) = {
let nodes = fleet.nodes.read().await;
let Some(node) = nodes.get(node_name) else {
anyhow::bail!("node '{node_name}' not found");
};
// Find the loaded model with the oldest last_accessed, excluding pinned.
let candidate = node
.models
.values()
.filter(|m| m.status == ModelStatus::Loaded)
.filter(|m| !node.pinned.contains(&m.id))
.min_by_key(|m| m.last_accessed)
.map(|m| m.id.clone());
(node.endpoint.clone(), candidate)
};
let Some(model_id) = candidate else {
tracing::info!(node = node_name, "no evictable models found");
return Ok(None);
};
tracing::info!(node = node_name, model = %model_id, "evicting model");
let url = format!("{endpoint}/v1/models/unload");
let resp = fleet
.http_client
.post(&url)
.json(&ModelLifecycleRequest {
model_id: model_id.clone(),
})
.send()
.await?;
if resp.status().is_success() {
// Update local state.
let mut nodes = fleet.nodes.write().await;
if let Some(node) = nodes.get_mut(node_name) {
if let Some(entry) = node.models.get_mut(&model_id) {
entry.status = ModelStatus::Unloaded;
}
node.lifecycle_cycles += 1;
// Check if we should flag for defrag.
if fleet.eviction.defrag_after_cycles > 0
&& node.lifecycle_cycles >= fleet.eviction.defrag_after_cycles
{
tracing::warn!(
node = node_name,
cycles = node.lifecycle_cycles,
"VRAM fragmentation threshold reached — consider restarting mistralrs"
);
}
}
tracing::info!(node = node_name, model = %model_id, "model evicted");
Ok(Some(model_id))
} else {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
tracing::error!(
node = node_name,
model = %model_id,
status = %status,
body = %body,
"failed to evict model"
);
anyhow::bail!("eviction failed: {status} {body}");
}
}

View File

@@ -0,0 +1,207 @@
//! Axum HTTP handlers for the gateway API surface.
use crate::proxy;
use crate::router;
use crate::state::CortexState;
use axum::body::Bytes;
use axum::extract::State;
use axum::http::HeaderMap;
use axum::response::{IntoResponse, Json, Response};
use axum::routing::{get, post};
use axum::Router;
use cortex_core::node::{CortexModelEntry, ModelLocation};
use cortex_core::openai::ChatCompletionRequest;
use serde_json::{json, Value};
use std::sync::Arc;
pub fn api_routes() -> Router<Arc<CortexState>> {
Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/completions", post(completions))
.route("/v1/models", get(list_models))
.route("/v1/messages", post(anthropic_messages))
.route("/health", get(health))
.route("/", get(health))
}
/// `POST /v1/chat/completions` — proxy to the appropriate backend node.
async fn chat_completions(
State(fleet): State<Arc<CortexState>>,
headers: HeaderMap,
body: Bytes,
) -> Response {
let model_id = match extract_model(&body) {
Some(m) => m,
None => return error_response(400, "missing 'model' field in request body"),
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => return error_response(404, &e.to_string()),
};
match proxy::forward_request(&fleet.http_client, &route, "/v1/chat/completions", headers, body)
.await
{
Ok(resp) => resp,
Err(e) => e.into_response(),
}
}
/// `POST /v1/completions` — proxy completions endpoint.
async fn completions(
State(fleet): State<Arc<CortexState>>,
headers: HeaderMap,
body: Bytes,
) -> Response {
let model_id = match extract_model(&body) {
Some(m) => m,
None => return error_response(400, "missing 'model' field in request body"),
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => return error_response(404, &e.to_string()),
};
match proxy::forward_request(&fleet.http_client, &route, "/v1/completions", headers, body)
.await
{
Ok(resp) => resp,
Err(e) => e.into_response(),
}
}
/// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back.
async fn anthropic_messages(
State(fleet): State<Arc<CortexState>>,
headers: HeaderMap,
body: Bytes,
) -> Response {
// Parse as Anthropic request.
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
Ok(r) => r,
Err(e) => return error_response(400, &format!("invalid Anthropic request: {e}")),
};
let model_id = anth_req.model.clone();
let is_streaming = anth_req.stream.unwrap_or(false);
// Translate to OpenAI format.
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
let openai_body = match serde_json::to_vec(&openai_req) {
Ok(b) => Bytes::from(b),
Err(e) => return error_response(500, &format!("translation error: {e}")),
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => return error_response(404, &e.to_string()),
};
if is_streaming {
// TODO: streaming Anthropic translation requires converting SSE format.
// For now, proxy the OpenAI SSE stream directly (clients that can handle
// OpenAI SSE will work; full Anthropic SSE translation is a follow-up).
match proxy::forward_request(
&fleet.http_client,
&route,
"/v1/chat/completions",
headers,
openai_body,
)
.await
{
Ok(resp) => resp,
Err(e) => e.into_response(),
}
} else {
// Non-streaming: proxy, await full response, translate back.
match proxy::forward_request(
&fleet.http_client,
&route,
"/v1/chat/completions",
headers,
openai_body,
)
.await
{
Ok(resp) => {
// TODO: buffer response, parse as OpenAI ChatCompletionResponse,
// translate to Anthropic MessagesResponse.
// For now, return the OpenAI response as-is.
resp
}
Err(e) => e.into_response(),
}
}
}
/// `GET /v1/models` — aggregate models from all nodes.
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
let nodes = fleet.nodes.read().await;
let mut model_map: std::collections::HashMap<String, CortexModelEntry> =
std::collections::HashMap::new();
for node in nodes.values() {
for (model_id, entry) in &node.models {
let location = ModelLocation {
node: node.name.clone(),
status: entry.status,
vram_estimate_mb: entry.vram_estimate_mb,
};
model_map
.entry(model_id.clone())
.and_modify(|e| e.locations.push(location.clone()))
.or_insert_with(|| CortexModelEntry {
id: model_id.clone(),
object: "model".into(),
locations: vec![location],
});
}
}
let data: Vec<Value> = model_map
.values()
.map(|e| json!(e))
.collect();
Json(json!({
"object": "list",
"data": data,
}))
}
/// `GET /health`
async fn health(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
let nodes = fleet.nodes.read().await;
let healthy_count = nodes.values().filter(|n| n.healthy).count();
let total_count = nodes.len();
Json(json!({
"status": if healthy_count > 0 { "ok" } else { "degraded" },
"nodes": {
"healthy": healthy_count,
"total": total_count,
}
}))
}
// ── Helpers ──────────────────────────────────────────────────────────
fn extract_model(body: &[u8]) -> Option<String> {
let v: Value = serde_json::from_slice(body).ok()?;
v.get("model")?.as_str().map(|s| s.to_string())
}
fn error_response(status: u16, message: &str) -> Response {
let code = axum::http::StatusCode::from_u16(status)
.unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
let body = json!({
"error": {
"message": message,
"type": "gateway_error",
}
});
(code, Json(body)).into_response()
}

View File

@@ -0,0 +1,51 @@
pub mod evictor;
pub mod handlers;
pub mod metrics;
pub mod poller;
pub mod proxy;
pub mod router;
pub mod state;
use anyhow::Result;
use axum::Router;
use cortex_core::config::GatewayConfig;
use std::sync::Arc;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
/// Build the Axum application router with all routes wired up.
pub fn build_app(fleet: Arc<state::CortexState>) -> Router {
Router::new()
.merge(handlers::api_routes())
.layer(CorsLayer::permissive())
.layer(TraceLayer::new_for_http())
.with_state(fleet)
}
/// Start the gateway: build state from config, spawn background tasks,
/// bind the HTTP server.
pub async fn run(config: GatewayConfig) -> Result<()> {
let fleet = Arc::new(state::CortexState::from_config(&config));
// Spawn the background poller that refreshes node/model status.
let poller_fleet = Arc::clone(&fleet);
tokio::spawn(async move {
poller::poll_loop(poller_fleet).await;
});
// Spawn the evictor (reacts to VRAM pressure events from the router).
let evictor_fleet = Arc::clone(&fleet);
tokio::spawn(async move {
evictor::eviction_loop(evictor_fleet).await;
});
let app = build_app(Arc::clone(&fleet));
let listen_addr = config.gateway.listen.parse::<std::net::SocketAddr>()?;
tracing::info!("cortex listening on {listen_addr}");
let listener = tokio::net::TcpListener::bind(listen_addr).await?;
axum::serve(listener, app).await?;
Ok(())
}

View File

@@ -0,0 +1,55 @@
//! Prometheus metrics exporter.
//!
//! Runs on a separate port from the main API, exposing `/metrics`
//! in Prometheus text format.
use anyhow::Result;
use metrics_exporter_prometheus::PrometheusBuilder;
use std::net::SocketAddr;
/// Install the Prometheus metrics recorder and return a handle.
/// The `/metrics` endpoint is served by the exporter's built-in HTTP server.
pub fn install(listen: &str) -> Result<()> {
let addr: SocketAddr = listen.parse()?;
PrometheusBuilder::new()
.with_http_listener(addr)
.install()
.map_err(|e| anyhow::anyhow!("failed to install Prometheus exporter: {e}"))?;
tracing::info!("prometheus metrics exporter on {addr}");
// Register histograms and counters used by the proxy layer.
// The `metrics` crate lazily creates metrics on first use, but
// describing them up front gives Prometheus proper HELP/TYPE lines.
metrics::describe_histogram!(
"cortex_request_duration_seconds",
"Total request latency in seconds"
);
metrics::describe_histogram!(
"cortex_time_to_first_token_seconds",
"Time to first token in seconds"
);
metrics::describe_histogram!(
"cortex_tokens_per_second",
"Generation throughput in tokens per second"
);
metrics::describe_counter!(
"cortex_requests_total",
"Total number of proxied requests"
);
metrics::describe_counter!(
"cortex_request_errors_total",
"Total number of failed proxy requests"
);
metrics::describe_counter!(
"cortex_evictions_total",
"Total number of model evictions"
);
metrics::describe_counter!(
"cortex_cold_starts_total",
"Total number of cold-start model loads"
);
Ok(())
}

View File

@@ -0,0 +1,103 @@
//! Background poller that periodically queries each node's `/v1/models`
//! endpoint to refresh the fleet state.
use crate::state::CortexState;
use chrono::Utc;
use cortex_core::node::{MistralModelsResponse, ModelEntry, ModelStatus};
use std::sync::Arc;
use std::time::Duration;
const POLL_INTERVAL: Duration = Duration::from_secs(10);
/// Runs forever, polling all nodes on a fixed interval.
pub async fn poll_loop(fleet: Arc<CortexState>) {
loop {
for nc in &fleet.node_configs {
poll_node(&fleet, &nc.name, &nc.endpoint).await;
}
tokio::time::sleep(POLL_INTERVAL).await;
}
}
async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
let url = format!("{endpoint}/v1/models");
let result = fleet
.http_client
.get(&url)
.timeout(Duration::from_secs(5))
.send()
.await;
let mut nodes = fleet.nodes.write().await;
let Some(node) = nodes.get_mut(name) else {
return;
};
match result {
Ok(resp) if resp.status().is_success() => {
match resp.json::<MistralModelsResponse>().await {
Ok(models_resp) => {
// Merge upstream model list into our state, preserving
// our local metadata (last_accessed, vram_estimate).
let mut seen = std::collections::HashSet::new();
for upstream in &models_resp.data {
seen.insert(upstream.id.clone());
let status = parse_status(upstream.status.as_deref());
node.models
.entry(upstream.id.clone())
.and_modify(|e| {
e.status = status;
})
.or_insert_with(|| ModelEntry {
id: upstream.id.clone(),
status,
last_accessed: None,
vram_estimate_mb: None,
});
}
// Remove models that are no longer reported by the node
// (e.g. after a config change / restart).
node.models.retain(|id, _| seen.contains(id));
node.healthy = true;
node.last_poll = Some(Utc::now());
tracing::debug!(
node = name,
models = models_resp.data.len(),
"poll ok"
);
}
Err(e) => {
tracing::warn!(node = name, error = %e, "failed to parse /v1/models response");
node.healthy = false;
}
}
}
Ok(resp) => {
tracing::warn!(
node = name,
status = %resp.status(),
"node returned non-success status"
);
node.healthy = false;
}
Err(e) => {
tracing::warn!(node = name, error = %e, "failed to reach node");
node.healthy = false;
}
}
}
fn parse_status(s: Option<&str>) -> ModelStatus {
match s {
Some("loaded") => ModelStatus::Loaded,
Some("unloaded") => ModelStatus::Unloaded,
Some("reloading") => ModelStatus::Reloading,
// If the status field is absent, assume loaded (older mistral.rs versions
// may not include it).
_ => ModelStatus::Loaded,
}
}

View File

@@ -0,0 +1,82 @@
//! Streaming HTTP reverse proxy to mistral.rs backends.
//!
//! For streaming requests, SSE chunks are forwarded as they arrive.
//! The proxy captures timing information for metrics but does not
//! buffer the full response.
use crate::router::RouteDecision;
use anyhow::Result;
use axum::body::Body;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use reqwest::Client;
/// Proxy a request body to the resolved backend node and stream the response.
pub async fn forward_request(
client: &Client,
route: &RouteDecision,
path: &str,
headers: HeaderMap,
body: bytes::Bytes,
) -> Result<Response, ProxyError> {
let url = format!("{}{}", route.endpoint, path);
tracing::info!(
node = %route.node_name,
url = %url,
cold_start = route.cold_start,
"proxying request"
);
let mut req_builder = client.post(&url).body(body);
// Forward relevant headers.
for (key, value) in headers.iter() {
if key == "host" || key == "content-length" {
continue; // reqwest sets these
}
req_builder = req_builder.header(key, value);
}
let upstream_resp = req_builder.send().await.map_err(ProxyError::Upstream)?;
let status = StatusCode::from_u16(upstream_resp.status().as_u16())
.unwrap_or(StatusCode::BAD_GATEWAY);
let resp_headers = upstream_resp.headers().clone();
let stream = upstream_resp.bytes_stream();
let body = Body::from_stream(stream);
let mut response = Response::builder().status(status);
for (key, value) in resp_headers.iter() {
response = response.header(key, value);
}
response
.body(body)
.map_err(|e| ProxyError::ResponseBuild(e.to_string()))
}
#[derive(Debug, thiserror::Error)]
pub enum ProxyError {
#[error("upstream request failed: {0}")]
Upstream(reqwest::Error),
#[error("failed to build response: {0}")]
ResponseBuild(String),
}
impl IntoResponse for ProxyError {
fn into_response(self) -> Response {
let status = match &self {
ProxyError::Upstream(_) => StatusCode::BAD_GATEWAY,
ProxyError::ResponseBuild(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
let body = serde_json::json!({
"error": {
"message": self.to_string(),
"type": "proxy_error",
}
});
(status, axum::Json(body)).into_response()
}
}

View File

@@ -0,0 +1,74 @@
//! Model-to-node routing logic.
//!
//! Given a model ID from an inbound request, determine which node should
//! handle it. Priority:
//! 1. Node where the model is currently `Loaded`
//! 2. Node where the model is `Unloaded` (will lazy-load on request)
//! 3. Error: model not found on any node
use crate::state::CortexState;
use cortex_core::node::ModelStatus;
use std::sync::Arc;
/// The routing decision: which node endpoint to proxy the request to.
#[derive(Debug, Clone)]
pub struct RouteDecision {
pub node_name: String,
pub endpoint: String,
/// Whether the model will need to load (cold start).
pub cold_start: bool,
}
#[derive(Debug, thiserror::Error)]
pub enum RouteError {
#[error("model '{0}' not found on any node")]
ModelNotFound(String),
#[error("no healthy nodes available")]
NoHealthyNodes,
}
/// Resolve which node should serve a request for the given model.
pub async fn resolve(fleet: &Arc<CortexState>, model_id: &str) -> Result<RouteDecision, RouteError> {
let nodes = fleet.nodes.read().await;
// Pass 1: find a node where the model is already loaded.
let mut loaded_candidate = None;
let mut unloaded_candidate = None;
for node in nodes.values() {
if !node.healthy {
continue;
}
if let Some(entry) = node.models.get(model_id) {
match entry.status {
ModelStatus::Loaded | ModelStatus::Reloading => {
loaded_candidate = Some(RouteDecision {
node_name: node.name.clone(),
endpoint: node.endpoint.clone(),
cold_start: false,
});
break; // loaded is best, stop searching
}
ModelStatus::Unloaded => {
if unloaded_candidate.is_none() {
unloaded_candidate = Some(RouteDecision {
node_name: node.name.clone(),
endpoint: node.endpoint.clone(),
cold_start: true,
});
}
}
}
}
}
loaded_candidate
.or(unloaded_candidate)
.ok_or_else(|| {
if nodes.values().any(|n| n.healthy) {
RouteError::ModelNotFound(model_id.to_string())
} else {
RouteError::NoHealthyNodes
}
})
}

View File

@@ -0,0 +1,43 @@
use cortex_core::config::{EvictionSettings, GatewayConfig, NodeConfig};
use cortex_core::node::NodeState;
use std::collections::HashMap;
use tokio::sync::RwLock;
/// Shared fleet state, protected by a RwLock for concurrent reader access.
pub struct CortexState {
pub nodes: RwLock<HashMap<String, NodeState>>,
pub node_configs: Vec<NodeConfig>,
pub eviction: EvictionSettings,
pub http_client: reqwest::Client,
}
impl CortexState {
pub fn from_config(config: &GatewayConfig) -> Self {
let mut nodes = HashMap::new();
for nc in &config.nodes {
nodes.insert(
nc.name.clone(),
NodeState {
name: nc.name.clone(),
endpoint: nc.endpoint.clone(),
vram_mb: nc.vram_mb,
pinned: nc.pinned.clone(),
healthy: false, // will be set by first poll
models: HashMap::new(),
lifecycle_cycles: 0,
last_poll: None,
},
);
}
Self {
nodes: RwLock::new(nodes),
node_configs: config.nodes.clone(),
eviction: config.eviction.clone(),
http_client: reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300))
.build()
.expect("failed to build HTTP client"),
}
}
}