feat: scaffold cortex workspace
Rust reverse-proxy for multi-node mistral.rs inference clusters. Includes crate structure (cortex-core, cortex-gateway, cortex-agent, cortex-cli), config loading, OpenAI/Anthropic translation stubs, model routing, eviction, polling, and streaming proxy scaffolding. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
/target
|
||||
*.swp
|
||||
*.swo
|
||||
.idea/
|
||||
.vscode/
|
||||
cortex.toml
|
||||
141
CLAUDE.md
Normal file
141
CLAUDE.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# CLAUDE.md — cortex
|
||||
|
||||
## Project overview
|
||||
|
||||
cortex is a Rust reverse-proxy that sits in front of multiple
|
||||
mistral.rs inference nodes and presents a unified OpenAI + Anthropic
|
||||
compatible API surface. It handles model routing, lifecycle management
|
||||
(load/unload/evict), request translation, and metrics collection.
|
||||
|
||||
## Repository layout
|
||||
|
||||
```
|
||||
cortex/
|
||||
├── Cargo.toml # workspace root
|
||||
├── cortex.toml # example gateway config
|
||||
├── README.md
|
||||
├── CLAUDE.md # ← you are here
|
||||
├── crates/
|
||||
│ ├── cortex-core/ # shared types, config, envelopes
|
||||
│ │ └── src/
|
||||
│ │ ├── lib.rs
|
||||
│ │ ├── config.rs # figment-based config structs
|
||||
│ │ ├── node.rs # NodeState, ModelStatus
|
||||
│ │ ├── openai.rs # OpenAI request/response types
|
||||
│ │ ├── anthropic.rs # Anthropic request/response types
|
||||
│ │ ├── translate.rs # OpenAI <-> Anthropic translation
|
||||
│ │ └── metrics.rs # RequestMetrics, histogram helpers
|
||||
│ ├── cortex-gateway/ # the HTTP proxy server
|
||||
│ │ └── src/
|
||||
│ │ ├── lib.rs
|
||||
│ │ ├── state.rs # CortexState: Arc<RwLock<...>>
|
||||
│ │ ├── router.rs # model -> node routing logic
|
||||
│ │ ├── proxy.rs # streaming HTTP proxy to backends
|
||||
│ │ ├── evictor.rs # LRU/priority eviction logic
|
||||
│ │ ├── poller.rs # background task polling node status
|
||||
│ │ ├── handlers.rs # axum handlers (chat, completions, models, etc.)
|
||||
│ │ └── metrics.rs # prometheus exporter endpoint
|
||||
│ ├── cortex-agent/ # per-node sidecar (future: defrag, restart)
|
||||
│ │ └── src/
|
||||
│ │ ├── lib.rs
|
||||
│ │ └── agent.rs # local node management
|
||||
│ └── cortex-cli/ # CLI entrypoint
|
||||
│ └── src/
|
||||
│ └── main.rs
|
||||
└── tests/ # integration tests (future)
|
||||
```
|
||||
|
||||
## Key design decisions
|
||||
|
||||
### mistral.rs HTTP API for model lifecycle
|
||||
mistral.rs (v0.8+) supports dynamic model loading/unloading at runtime:
|
||||
- `POST /v1/models/unload {"model_id": "..."}` — frees VRAM, preserves config
|
||||
- `POST /v1/models/reload {"model_id": "..."}` — explicitly reload
|
||||
- `POST /v1/models/status {"model_id": "..."}` — loaded/unloaded/reloading
|
||||
- `GET /v1/models` — lists all models with status field
|
||||
- Lazy loading: requests to unloaded models trigger automatic reload
|
||||
|
||||
The gateway does NOT manage systemd units for model swaps. It calls these
|
||||
HTTP endpoints directly. The only systemd interaction is for full-process
|
||||
restarts after VRAM fragmentation accumulates (defrag_after_cycles).
|
||||
|
||||
### Streaming proxy
|
||||
Chat completions are proxied as SSE streams. The gateway must:
|
||||
1. Parse the inbound request to extract the model name
|
||||
2. Route to the correct backend node
|
||||
3. Stream the response back, capturing token timing for metrics
|
||||
4. NOT buffer the full response — true streaming passthrough
|
||||
|
||||
### Anthropic translation
|
||||
When a request arrives at `/v1/messages` (Anthropic format), the gateway
|
||||
translates it to OpenAI format before proxying to mistral.rs, then
|
||||
translates the response back. This is stateless envelope transformation.
|
||||
|
||||
### Eviction
|
||||
The evictor runs as a background task. Before loading a model on a node
|
||||
where VRAM is tight:
|
||||
1. Check if the model is already loaded elsewhere → route there instead
|
||||
2. Find the LRU model on the target node (excluding pinned models)
|
||||
3. Call `/v1/models/unload` on that model
|
||||
4. The incoming request's lazy-load triggers the new model load
|
||||
|
||||
### Metrics
|
||||
Per-request: model, node, prompt_tokens, completion_tokens, total_tokens,
|
||||
tok_per_sec, time_to_first_token_ms, total_latency_ms.
|
||||
Exposed as Prometheus histograms/counters on a separate port.
|
||||
|
||||
## Tech stack
|
||||
|
||||
- **Rust 2024 edition** — workspace with 4 crates
|
||||
- **Axum 0.8** — HTTP framework (same as mistral.rs itself)
|
||||
- **reqwest** — HTTP client for proxying to backends
|
||||
- **figment** — config loading (TOML + env vars)
|
||||
- **tokio** — async runtime
|
||||
- **metrics + metrics-exporter-prometheus** — observability
|
||||
- **tracing** — structured logging
|
||||
|
||||
## Build commands
|
||||
|
||||
```sh
|
||||
cargo build --release # build all crates
|
||||
cargo run -p cortex-cli -- serve # run the gateway
|
||||
cargo test # run all tests
|
||||
cargo clippy --workspace # lint
|
||||
```
|
||||
|
||||
## Environment
|
||||
|
||||
- Targets Fedora 43 (systemd, SELinux enforcing)
|
||||
- Nodes communicate over a private network (e.g. WireGuard mesh)
|
||||
- One or more GPU nodes running mistral.rs on port 8080
|
||||
- Optionally a metrics-only node (no GPU) for Prometheus/Grafana
|
||||
- Each node runs `mistralrs serve` on port 8080
|
||||
- Gateway listens on port 8000 (API) and 9100 (metrics)
|
||||
- TLS terminated at gateway or via nginx; internal traffic is plaintext over WireGuard
|
||||
|
||||
## Conventions
|
||||
|
||||
- Error handling: `anyhow` for binaries, `thiserror` for library crates
|
||||
- No `unwrap()` in library code; `expect()` only with clear rationale
|
||||
- All public types derive `Debug, Clone, Serialize, Deserialize` where sensible
|
||||
- Config structs use `figment` with TOML as primary source, env vars as override
|
||||
- Prefer `Arc<RwLock<...>>` for shared fleet state; minimize lock duration
|
||||
- SSE streaming uses `tokio_stream` + `eventsource-stream` for parsing
|
||||
- Log at `info` for request routing, `debug` for proxy details, `warn` for
|
||||
eviction and node health, `error` for proxy failures
|
||||
|
||||
## Current status
|
||||
|
||||
**Scaffold phase** — crate structure, types, and handler stubs are in place.
|
||||
The following needs implementation:
|
||||
|
||||
1. **cortex-core**: Flesh out OpenAI/Anthropic envelope types with all fields
|
||||
needed for chat completions (streaming + non-streaming)
|
||||
2. **cortex-gateway/proxy.rs**: Implement streaming HTTP proxy with SSE passthrough
|
||||
3. **cortex-gateway/router.rs**: Model-to-node routing with fallback to least-loaded
|
||||
4. **cortex-gateway/evictor.rs**: LRU eviction with pinning support
|
||||
5. **cortex-gateway/poller.rs**: Background polling of node `/v1/models` endpoints
|
||||
6. **cortex-gateway/handlers.rs**: Wire up axum routes to proxy logic
|
||||
7. **cortex-core/translate.rs**: OpenAI <-> Anthropic request/response translation
|
||||
8. **cortex-agent**: Sidecar for VRAM defrag restarts (lower priority)
|
||||
9. **Integration tests**: Mock mistralrs backends, test routing + eviction
|
||||
2787
Cargo.lock
generated
Normal file
2787
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
57
Cargo.toml
Normal file
57
Cargo.toml
Normal file
@@ -0,0 +1,57 @@
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
members = [
|
||||
"crates/cortex-core",
|
||||
"crates/cortex-gateway",
|
||||
"crates/cortex-agent",
|
||||
"crates/cortex-cli",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
license = "GPL-3.0"
|
||||
repository = "https://git.lair.cafe/helexa/cortex"
|
||||
|
||||
[workspace.dependencies]
|
||||
# async runtime
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
||||
# web framework
|
||||
axum = { version = "0.8", features = ["macros"] }
|
||||
tower = "0.5"
|
||||
tower-http = { version = "0.6", features = ["cors", "trace", "timeout"] }
|
||||
|
||||
# serialization
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
toml = "0.8"
|
||||
|
||||
# http client (for proxying to mistralrs backends)
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
|
||||
# observability
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||
metrics = "0.24"
|
||||
metrics-exporter-prometheus = "0.16"
|
||||
|
||||
# time
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
|
||||
# config
|
||||
figment = { version = "0.10", features = ["toml", "env"] }
|
||||
|
||||
# error handling
|
||||
anyhow = "1"
|
||||
thiserror = "2"
|
||||
|
||||
# futures / streams (for SSE proxying)
|
||||
futures = "0.3"
|
||||
tokio-stream = "0.1"
|
||||
eventsource-stream = "0.2"
|
||||
|
||||
# workspace crates
|
||||
cortex-core = { path = "crates/cortex-core" }
|
||||
cortex-gateway = { path = "crates/cortex-gateway" }
|
||||
cortex-agent = { path = "crates/cortex-agent" }
|
||||
138
README.md
Normal file
138
README.md
Normal file
@@ -0,0 +1,138 @@
|
||||
# cortex
|
||||
|
||||
A Rust reverse-proxy and fleet management layer for multi-node
|
||||
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) inference clusters.
|
||||
|
||||
## Problem
|
||||
|
||||
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
|
||||
model affinities) requires a unified API surface that:
|
||||
|
||||
- Presents a **single `/v1/models` catalogue** merging every model across every
|
||||
node.
|
||||
- **Routes requests** to the correct node based on where a model is loaded (or
|
||||
*can* be loaded).
|
||||
- Manages **model lifecycle** — unload cold models, reload on demand, pin
|
||||
critical ones — using the mistral.rs
|
||||
`/v1/models/{unload,reload,status}` HTTP API (PR #1828+).
|
||||
- Translates between **OpenAI and Anthropic** request/response envelopes so
|
||||
every client in the homelab speaks whichever dialect it prefers.
|
||||
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
|
||||
them as Prometheus counters/histograms.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌──────────────┐ ┌──────────┐ ┌────────────┐ ┌────────────┐
|
||||
│ Claude Code │ │ Zed/IDE │ │ Tidal / mm │ │ curl / etc │
|
||||
└──────┬───────┘ └─────┬────┘ └──────┬─────┘ └──────┬─────┘
|
||||
│ │ │ │
|
||||
└────────────────┴──────┬───────┴───────────────┘
|
||||
│
|
||||
┌──────────▼──────────┐
|
||||
│ cortex │
|
||||
│ (cortex-gateway) │
|
||||
│ │
|
||||
│ Router · Metrics │
|
||||
│ Evictor · Translate│
|
||||
└──┬──────┬────────┬──┘
|
||||
│ │ │
|
||||
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
|
||||
│ gpu-large │ │gpu-med │ │ gpu-small │
|
||||
│ mistralrs │ │mistral │ │ mistralrs │
|
||||
│ serve │ │rs serve│ │ serve │
|
||||
│ :8080 │ │ :8080 │ │ :8080 │
|
||||
└───────────┘ └────────┘ └───────────┘
|
||||
private network (.internal)
|
||||
```
|
||||
|
||||
### Crates
|
||||
|
||||
| Crate | Purpose |
|
||||
|---|---|
|
||||
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic request/response envelopes |
|
||||
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, metrics exporter |
|
||||
| `cortex-agent` | Per-node sidecar: polls local mistralrs, reports to gateway, handles restart/defrag |
|
||||
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
|
||||
|
||||
## Node setup
|
||||
|
||||
Each GPU node runs `mistralrs serve` with a multi-model config. Models are
|
||||
declared but start **unloaded** — mistral.rs lazy-loads on first request and
|
||||
the gateway can explicitly unload/reload via the HTTP API.
|
||||
|
||||
Example node systemd unit:
|
||||
|
||||
```ini
|
||||
# /etc/systemd/system/mistralrs.service
|
||||
[Unit]
|
||||
Description=mistral.rs inference server
|
||||
After=network-online.target
|
||||
Wants=network-online.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart=/usr/local/bin/mistralrs serve \
|
||||
--from-config /etc/mistralrs/config.toml \
|
||||
--port 8080
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
Environment=CUDA_VISIBLE_DEVICES=0,1
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
## Gateway config
|
||||
|
||||
```toml
|
||||
# cortex.toml
|
||||
[gateway]
|
||||
listen = "0.0.0.0:8000"
|
||||
metrics_listen = "0.0.0.0:9100"
|
||||
|
||||
[eviction]
|
||||
strategy = "lru" # lru | priority
|
||||
defrag_after_cycles = 50
|
||||
|
||||
[[nodes]]
|
||||
name = "gpu-large"
|
||||
endpoint = "http://gpu-large.internal:8080"
|
||||
vram_mb = 49_152 # e.g. 2x RTX 4090
|
||||
pinned = ["your-org/large-model"]
|
||||
|
||||
[[nodes]]
|
||||
name = "gpu-medium"
|
||||
endpoint = "http://gpu-medium.internal:8080"
|
||||
vram_mb = 24_576 # e.g. RTX 4090
|
||||
pinned = ["your-org/medium-model"]
|
||||
|
||||
[[nodes]]
|
||||
name = "gpu-small"
|
||||
endpoint = "http://gpu-small.internal:8080"
|
||||
vram_mb = 12_288 # e.g. RTX 3060
|
||||
pinned = ["your-org/embedding-model"]
|
||||
```
|
||||
|
||||
## Building
|
||||
|
||||
```sh
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
## Running
|
||||
|
||||
```sh
|
||||
# start the gateway
|
||||
cortex serve --config cortex.toml
|
||||
|
||||
# check fleet status
|
||||
cortex status
|
||||
|
||||
# list all models across nodes
|
||||
curl http://localhost:8000/v1/models
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
GPL-3.0
|
||||
45
cortex.example.toml
Normal file
45
cortex.example.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
# cortex.example.toml — example configuration
|
||||
#
|
||||
# Copy to cortex.toml and adjust for your environment.
|
||||
#
|
||||
# Environment variable overrides use CORTEX_ prefix with __ separators:
|
||||
# CORTEX_GATEWAY__LISTEN=0.0.0.0:9000
|
||||
|
||||
[gateway]
|
||||
listen = "0.0.0.0:8000"
|
||||
metrics_listen = "0.0.0.0:9100"
|
||||
|
||||
[eviction]
|
||||
strategy = "lru"
|
||||
# Restart mistralrs after this many load/unload cycles to defragment VRAM.
|
||||
# Set to 0 to disable.
|
||||
defrag_after_cycles = 50
|
||||
|
||||
# -- Nodes ---------------------------------------------------------------
|
||||
# Each [[nodes]] entry declares a mistral.rs instance in the fleet.
|
||||
# Models are discovered by polling the node's /v1/models endpoint.
|
||||
# Pinned models are never evicted.
|
||||
|
||||
[[nodes]]
|
||||
name = "gpu-large"
|
||||
endpoint = "http://gpu-large.internal:8080"
|
||||
vram_mb = 49152 # e.g. 2x RTX 4090 (48 GB combined)
|
||||
pinned = [
|
||||
"your-org/large-model",
|
||||
]
|
||||
|
||||
[[nodes]]
|
||||
name = "gpu-medium"
|
||||
endpoint = "http://gpu-medium.internal:8080"
|
||||
vram_mb = 24576 # e.g. RTX 4090 (24 GB)
|
||||
pinned = [
|
||||
"your-org/medium-model",
|
||||
]
|
||||
|
||||
[[nodes]]
|
||||
name = "gpu-small"
|
||||
endpoint = "http://gpu-small.internal:8080"
|
||||
vram_mb = 12288 # e.g. RTX 3060 (12 GB)
|
||||
pinned = [
|
||||
"your-org/embedding-model",
|
||||
]
|
||||
14
crates/cortex-agent/Cargo.toml
Normal file
14
crates/cortex-agent/Cargo.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[package]
|
||||
name = "cortex-agent"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cortex-core.workspace = true
|
||||
tokio.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
reqwest.workspace = true
|
||||
tracing.workspace = true
|
||||
anyhow.workspace = true
|
||||
72
crates/cortex-agent/src/agent.rs
Normal file
72
crates/cortex-agent/src/agent.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
//! Per-node agent sidecar.
|
||||
//!
|
||||
//! This is a future component that runs on each GPU node alongside mistralrs.
|
||||
//! It handles:
|
||||
//! - VRAM defragmentation (restarting the mistralrs systemd unit when the
|
||||
//! gateway signals that lifecycle_cycles has exceeded the threshold)
|
||||
//! - Local nvidia-smi polling for actual VRAM usage reporting
|
||||
//! - Systemd unit management for mistralrs process restarts
|
||||
//!
|
||||
//! For now this is a stub. The gateway's poller + evictor handle the critical
|
||||
//! path (model lifecycle via the mistralrs HTTP API). The agent adds
|
||||
//! operational niceties that can be built incrementally.
|
||||
|
||||
/// Placeholder for agent configuration.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AgentConfig {
|
||||
/// The local mistralrs endpoint to monitor.
|
||||
pub mistralrs_endpoint: String,
|
||||
/// The systemd unit name for mistralrs (e.g. "mistralrs.service").
|
||||
pub systemd_unit: String,
|
||||
}
|
||||
|
||||
/// Restart the local mistralrs process via systemd.
|
||||
/// This is the nuclear option for VRAM defragmentation.
|
||||
pub async fn restart_mistralrs(config: &AgentConfig) -> anyhow::Result<()> {
|
||||
tracing::warn!(
|
||||
unit = %config.systemd_unit,
|
||||
"restarting mistralrs for VRAM defragmentation"
|
||||
);
|
||||
|
||||
let output = tokio::process::Command::new("systemctl")
|
||||
.args(["restart", &config.systemd_unit])
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
if output.status.success() {
|
||||
tracing::info!(unit = %config.systemd_unit, "mistralrs restarted successfully");
|
||||
Ok(())
|
||||
} else {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("systemctl restart failed: {stderr}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Query nvidia-smi for current VRAM usage on this node.
|
||||
/// Returns (used_mb, total_mb) for each GPU.
|
||||
pub async fn query_vram() -> anyhow::Result<Vec<(u64, u64)>> {
|
||||
let output = tokio::process::Command::new("nvidia-smi")
|
||||
.args([
|
||||
"--query-gpu=memory.used,memory.total",
|
||||
"--format=csv,noheader,nounits",
|
||||
])
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("nvidia-smi failed: {stderr}");
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let mut gpus = Vec::new();
|
||||
for line in stdout.lines() {
|
||||
let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
|
||||
if parts.len() == 2 {
|
||||
let used: u64 = parts[0].parse().unwrap_or(0);
|
||||
let total: u64 = parts[1].parse().unwrap_or(0);
|
||||
gpus.push((used, total));
|
||||
}
|
||||
}
|
||||
Ok(gpus)
|
||||
}
|
||||
1
crates/cortex-agent/src/lib.rs
Normal file
1
crates/cortex-agent/src/lib.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod agent;
|
||||
20
crates/cortex-cli/Cargo.toml
Normal file
20
crates/cortex-cli/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "cortex-cli"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "cortex"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
cortex-core.workspace = true
|
||||
cortex-gateway.workspace = true
|
||||
tokio.workspace = true
|
||||
tracing.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
anyhow.workspace = true
|
||||
reqwest.workspace = true
|
||||
serde_json.workspace = true
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
112
crates/cortex-cli/src/main.rs
Normal file
112
crates/cortex-cli/src/main.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, Subcommand};
|
||||
use cortex_core::config::GatewayConfig;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "cortex")]
|
||||
#[command(about = "Unified inference gateway for multi-node mistral.rs clusters")]
|
||||
#[command(version)]
|
||||
struct Cli {
|
||||
#[command(subcommand)]
|
||||
command: Commands,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum Commands {
|
||||
/// Start the gateway server.
|
||||
Serve {
|
||||
/// Path to the gateway config file.
|
||||
#[arg(short, long, default_value = "cortex.toml")]
|
||||
config: String,
|
||||
},
|
||||
/// Print the fleet status (models, nodes, health).
|
||||
Status {
|
||||
/// Gateway API endpoint to query.
|
||||
#[arg(short, long, default_value = "http://localhost:8000")]
|
||||
endpoint: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Initialize tracing with env filter (e.g. RUST_LOG=cortex_gateway=debug).
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("info,cortex_gateway=debug")),
|
||||
)
|
||||
.init();
|
||||
|
||||
let cli = Cli::parse();
|
||||
|
||||
match cli.command {
|
||||
Commands::Serve { config } => {
|
||||
let cfg = GatewayConfig::load(&config).map_err(|e| {
|
||||
anyhow::anyhow!("failed to load config from '{config}': {e}")
|
||||
})?;
|
||||
|
||||
tracing::info!(
|
||||
nodes = cfg.nodes.len(),
|
||||
listen = %cfg.gateway.listen,
|
||||
"starting cortex"
|
||||
);
|
||||
|
||||
// Install Prometheus metrics exporter on a separate port.
|
||||
cortex_gateway::metrics::install(&cfg.gateway.metrics_listen)?;
|
||||
|
||||
cortex_gateway::run(cfg).await?;
|
||||
}
|
||||
Commands::Status { endpoint } => {
|
||||
print_status(&endpoint).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn print_status(endpoint: &str) -> Result<()> {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Fetch health.
|
||||
let health: serde_json::Value = client
|
||||
.get(format!("{endpoint}/health"))
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
println!("Fleet health: {}", serde_json::to_string_pretty(&health)?);
|
||||
|
||||
// Fetch models.
|
||||
let models: serde_json::Value = client
|
||||
.get(format!("{endpoint}/v1/models"))
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
println!("\nModels:");
|
||||
if let Some(data) = models.get("data").and_then(|d| d.as_array()) {
|
||||
for model in data {
|
||||
let id = model.get("id").and_then(|v| v.as_str()).unwrap_or("?");
|
||||
let locations = model
|
||||
.get("locations")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|l| {
|
||||
let node = l.get("node")?.as_str()?;
|
||||
let status = l.get("status")?.as_str()?;
|
||||
Some(format!("{node}({status})"))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
})
|
||||
.unwrap_or_default();
|
||||
println!(" {id:40} {locations}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
15
crates/cortex-core/Cargo.toml
Normal file
15
crates/cortex-core/Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[package]
|
||||
name = "cortex-core"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
toml.workspace = true
|
||||
figment.workspace = true
|
||||
chrono.workspace = true
|
||||
anyhow.workspace = true
|
||||
thiserror.workspace = true
|
||||
tracing.workspace = true
|
||||
87
crates/cortex-core/src/anthropic.rs
Normal file
87
crates/cortex-core/src/anthropic.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
//! Anthropic Messages API request and response types.
|
||||
//!
|
||||
//! These mirror the `/v1/messages` format used by the Anthropic API.
|
||||
//! The gateway accepts these, translates to OpenAI format, proxies to
|
||||
//! mistral.rs, then translates the response back.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// ── Messages request ─────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MessagesRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<AnthropicMessage>,
|
||||
pub max_tokens: u64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system: Option<SystemPrompt>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream: Option<bool>,
|
||||
#[serde(flatten)]
|
||||
pub extra: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum SystemPrompt {
|
||||
Text(String),
|
||||
Blocks(Vec<Value>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AnthropicMessage {
|
||||
pub role: String,
|
||||
pub content: AnthropicContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum AnthropicContent {
|
||||
Text(String),
|
||||
Blocks(Vec<ContentBlock>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ContentBlock {
|
||||
#[serde(rename = "type")]
|
||||
pub block_type: String,
|
||||
#[serde(flatten)]
|
||||
pub data: Value,
|
||||
}
|
||||
|
||||
// ── Messages response ────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MessagesResponse {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub response_type: String,
|
||||
pub role: String,
|
||||
pub content: Vec<ContentBlock>,
|
||||
pub model: String,
|
||||
pub stop_reason: Option<String>,
|
||||
pub usage: AnthropicUsage,
|
||||
#[serde(flatten)]
|
||||
pub extra: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AnthropicUsage {
|
||||
pub input_tokens: u64,
|
||||
pub output_tokens: u64,
|
||||
}
|
||||
|
||||
// ── Streaming events ─────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StreamEvent {
|
||||
#[serde(rename = "type")]
|
||||
pub event_type: String,
|
||||
#[serde(flatten)]
|
||||
pub data: Value,
|
||||
}
|
||||
79
crates/cortex-core/src/config.rs
Normal file
79
crates/cortex-core/src/config.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
use figment::{
|
||||
Figment,
|
||||
providers::{Env, Format, Toml},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GatewayConfig {
|
||||
pub gateway: GatewaySettings,
|
||||
pub eviction: EvictionSettings,
|
||||
pub nodes: Vec<NodeConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GatewaySettings {
|
||||
/// Address to listen on for API requests (e.g. "0.0.0.0:8000")
|
||||
pub listen: String,
|
||||
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:9100")
|
||||
pub metrics_listen: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EvictionSettings {
|
||||
/// Eviction strategy: "lru" or "priority"
|
||||
pub strategy: EvictionStrategy,
|
||||
/// Restart the mistralrs process after this many load/unload cycles
|
||||
/// to reclaim fragmented VRAM. 0 = never.
|
||||
#[serde(default)]
|
||||
pub defrag_after_cycles: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum EvictionStrategy {
|
||||
Lru,
|
||||
Priority,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NodeConfig {
|
||||
/// Human-readable node name (e.g. "gpu-large")
|
||||
pub name: String,
|
||||
/// Base URL of the mistralrs HTTP server (e.g. "http://gpu-large.internal:8080")
|
||||
pub endpoint: String,
|
||||
/// Total VRAM in MB across all GPUs on this node
|
||||
pub vram_mb: u64,
|
||||
/// Model IDs that should never be evicted from this node
|
||||
#[serde(default)]
|
||||
pub pinned: Vec<String>,
|
||||
}
|
||||
|
||||
impl GatewayConfig {
|
||||
/// Load configuration from a TOML file, with environment variable overrides.
|
||||
/// Env vars are prefixed with `CORTEX_` and use `__` as a separator
|
||||
/// (e.g. `CORTEX_GATEWAY__LISTEN=0.0.0.0:9000`).
|
||||
pub fn load(path: impl AsRef<Path>) -> Result<Self, figment::Error> {
|
||||
Figment::new()
|
||||
.merge(Toml::file(path))
|
||||
.merge(Env::prefixed("CORTEX_").split("__"))
|
||||
.extract()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GatewayConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
gateway: GatewaySettings {
|
||||
listen: "0.0.0.0:8000".into(),
|
||||
metrics_listen: "0.0.0.0:9100".into(),
|
||||
},
|
||||
eviction: EvictionSettings {
|
||||
strategy: EvictionStrategy::Lru,
|
||||
defrag_after_cycles: 50,
|
||||
},
|
||||
nodes: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
6
crates/cortex-core/src/lib.rs
Normal file
6
crates/cortex-core/src/lib.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod anthropic;
|
||||
pub mod config;
|
||||
pub mod metrics;
|
||||
pub mod node;
|
||||
pub mod openai;
|
||||
pub mod translate;
|
||||
23
crates/cortex-core/src/metrics.rs
Normal file
23
crates/cortex-core/src/metrics.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
//! Request-level metrics captured by the gateway proxy layer.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Metrics captured for a single proxied request.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RequestMetrics {
|
||||
pub timestamp: DateTime<Utc>,
|
||||
pub model: String,
|
||||
pub node: String,
|
||||
pub prompt_tokens: u64,
|
||||
pub completion_tokens: u64,
|
||||
pub total_tokens: u64,
|
||||
/// Tokens per second for the generation phase.
|
||||
pub tok_per_sec: f64,
|
||||
/// Time from request start to first SSE chunk (streaming) or full response.
|
||||
pub time_to_first_token_ms: u64,
|
||||
/// Total request latency including proxy overhead.
|
||||
pub total_latency_ms: u64,
|
||||
/// Whether this request triggered a model load (cold start).
|
||||
pub cold_start: bool,
|
||||
}
|
||||
74
crates/cortex-core/src/node.rs
Normal file
74
crates/cortex-core/src/node.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Runtime state of a single node in the fleet.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NodeState {
|
||||
pub name: String,
|
||||
pub endpoint: String,
|
||||
pub vram_mb: u64,
|
||||
pub pinned: Vec<String>,
|
||||
pub healthy: bool,
|
||||
pub models: HashMap<String, ModelEntry>,
|
||||
/// Number of load/unload cycles since last process restart.
|
||||
pub lifecycle_cycles: u32,
|
||||
pub last_poll: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// A model registered on a node, with its runtime status.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelEntry {
|
||||
pub id: String,
|
||||
pub status: ModelStatus,
|
||||
/// When this model was last used (for LRU eviction).
|
||||
pub last_accessed: Option<DateTime<Utc>>,
|
||||
/// Estimated VRAM usage in MB when loaded.
|
||||
pub vram_estimate_mb: Option<u64>,
|
||||
}
|
||||
|
||||
/// Model lifecycle status, matching the mistral.rs API.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ModelStatus {
|
||||
Loaded,
|
||||
Unloaded,
|
||||
Reloading,
|
||||
}
|
||||
|
||||
/// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
|
||||
/// Includes which node(s) host this model and their status.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CortexModelEntry {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
/// Which nodes have this model (and their status).
|
||||
pub locations: Vec<ModelLocation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelLocation {
|
||||
pub node: String,
|
||||
pub status: ModelStatus,
|
||||
pub vram_estimate_mb: Option<u64>,
|
||||
}
|
||||
|
||||
/// Response from mistral.rs `GET /v1/models`.
|
||||
/// This is the upstream format we parse when polling nodes.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MistralModelsResponse {
|
||||
pub data: Vec<MistralModelEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MistralModelEntry {
|
||||
pub id: String,
|
||||
#[serde(default)]
|
||||
pub status: Option<String>,
|
||||
}
|
||||
|
||||
/// Request body for mistral.rs model lifecycle endpoints.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ModelLifecycleRequest {
|
||||
pub model_id: String,
|
||||
}
|
||||
122
crates/cortex-core/src/openai.rs
Normal file
122
crates/cortex-core/src/openai.rs
Normal file
@@ -0,0 +1,122 @@
|
||||
//! OpenAI-compatible request and response types.
|
||||
//!
|
||||
//! These are a subset sufficient for chat completions (streaming + non-streaming).
|
||||
//! Fields not relevant to proxying are captured as `serde_json::Value` via
|
||||
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
|
||||
//! extension field mistral.rs supports.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// ── Chat completion request ──────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream: Option<bool>,
|
||||
/// All other fields (tools, response_format, mistral.rs extensions, etc.)
|
||||
#[serde(flatten)]
|
||||
pub extra: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: MessageContent,
|
||||
#[serde(flatten)]
|
||||
pub extra: Value,
|
||||
}
|
||||
|
||||
/// Content can be a simple string or an array of content parts (for vision).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContent {
|
||||
Text(String),
|
||||
Parts(Vec<Value>),
|
||||
}
|
||||
|
||||
// ── Chat completion response (non-streaming) ─────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatCompletionChoice>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<Usage>,
|
||||
#[serde(flatten)]
|
||||
pub extra: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionChoice {
|
||||
pub index: u32,
|
||||
pub message: ChatMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub extra: Value,
|
||||
}
|
||||
|
||||
// ── Streaming chunk ──────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionChunk {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChunkChoice>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<Usage>,
|
||||
#[serde(flatten)]
|
||||
pub extra: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChunkChoice {
|
||||
pub index: u32,
|
||||
pub delta: Value,
|
||||
pub finish_reason: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub extra: Value,
|
||||
}
|
||||
|
||||
// ── Usage ────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u64,
|
||||
pub completion_tokens: u64,
|
||||
pub total_tokens: u64,
|
||||
}
|
||||
|
||||
// ── Models list response ─────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelsResponse {
|
||||
pub object: String,
|
||||
pub data: Vec<ModelObject>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelObject {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub owned_by: Option<String>,
|
||||
/// Gateway extensions: which node(s) host this model.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub locations: Option<Vec<super::node::ModelLocation>>,
|
||||
#[serde(flatten)]
|
||||
pub extra: Value,
|
||||
}
|
||||
114
crates/cortex-core/src/translate.rs
Normal file
114
crates/cortex-core/src/translate.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
//! Translation between OpenAI and Anthropic request/response envelopes.
|
||||
//!
|
||||
//! This is a stateless transformation — no context is carried between requests.
|
||||
|
||||
use crate::anthropic::{
|
||||
AnthropicContent, AnthropicMessage, AnthropicUsage, ContentBlock, MessagesRequest,
|
||||
MessagesResponse, SystemPrompt,
|
||||
};
|
||||
use crate::openai::{
|
||||
ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, Usage,
|
||||
MessageContent,
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Convert an Anthropic Messages request into an OpenAI ChatCompletion request.
|
||||
pub fn anthropic_to_openai(req: MessagesRequest) -> ChatCompletionRequest {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
// Anthropic `system` field becomes a system message.
|
||||
if let Some(system) = req.system {
|
||||
let content = match system {
|
||||
SystemPrompt::Text(t) => t,
|
||||
SystemPrompt::Blocks(blocks) => serde_json::to_string(&blocks).unwrap_or_default(),
|
||||
};
|
||||
messages.push(ChatMessage {
|
||||
role: "system".into(),
|
||||
content: MessageContent::Text(content),
|
||||
extra: Value::Null,
|
||||
});
|
||||
}
|
||||
|
||||
// Convert message roles and content.
|
||||
for msg in req.messages {
|
||||
let content = match msg.content {
|
||||
AnthropicContent::Text(t) => MessageContent::Text(t),
|
||||
AnthropicContent::Blocks(blocks) => {
|
||||
// For simple text-only blocks, extract the text.
|
||||
// For mixed content (images, etc.), pass as parts.
|
||||
if blocks.len() == 1 && blocks[0].block_type == "text" {
|
||||
let text = blocks[0]
|
||||
.data
|
||||
.get("text")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
MessageContent::Text(text)
|
||||
} else {
|
||||
MessageContent::Parts(
|
||||
blocks.into_iter().map(|b| json!(b)).collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
messages.push(ChatMessage {
|
||||
role: msg.role,
|
||||
content,
|
||||
extra: Value::Null,
|
||||
});
|
||||
}
|
||||
|
||||
ChatCompletionRequest {
|
||||
model: req.model,
|
||||
messages,
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
max_tokens: Some(req.max_tokens),
|
||||
stream: req.stream,
|
||||
extra: req.extra,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert an OpenAI ChatCompletion response into an Anthropic Messages response.
|
||||
pub fn openai_to_anthropic(resp: ChatCompletionResponse) -> MessagesResponse {
|
||||
let choice = resp.choices.into_iter().next();
|
||||
|
||||
let (content_text, stop_reason) = match choice {
|
||||
Some(c) => {
|
||||
let text = match c.message.content {
|
||||
MessageContent::Text(t) => t,
|
||||
MessageContent::Parts(parts) => serde_json::to_string(&parts).unwrap_or_default(),
|
||||
};
|
||||
let stop = c.finish_reason.map(|r| match r.as_str() {
|
||||
"stop" => "end_turn".to_string(),
|
||||
"length" => "max_tokens".to_string(),
|
||||
other => other.to_string(),
|
||||
});
|
||||
(text, stop)
|
||||
}
|
||||
None => (String::new(), None),
|
||||
};
|
||||
|
||||
let usage = resp.usage.unwrap_or(Usage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
});
|
||||
|
||||
MessagesResponse {
|
||||
id: resp.id,
|
||||
response_type: "message".into(),
|
||||
role: "assistant".into(),
|
||||
content: vec![ContentBlock {
|
||||
block_type: "text".into(),
|
||||
data: json!({ "text": content_text }),
|
||||
}],
|
||||
model: resp.model,
|
||||
stop_reason,
|
||||
usage: AnthropicUsage {
|
||||
input_tokens: usage.prompt_tokens,
|
||||
output_tokens: usage.completion_tokens,
|
||||
},
|
||||
extra: Value::Null,
|
||||
}
|
||||
}
|
||||
25
crates/cortex-gateway/Cargo.toml
Normal file
25
crates/cortex-gateway/Cargo.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "cortex-gateway"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cortex-core.workspace = true
|
||||
tokio.workspace = true
|
||||
axum.workspace = true
|
||||
tower.workspace = true
|
||||
tower-http.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
reqwest.workspace = true
|
||||
tracing.workspace = true
|
||||
metrics.workspace = true
|
||||
metrics-exporter-prometheus.workspace = true
|
||||
chrono.workspace = true
|
||||
anyhow.workspace = true
|
||||
thiserror.workspace = true
|
||||
futures.workspace = true
|
||||
tokio-stream.workspace = true
|
||||
eventsource-stream.workspace = true
|
||||
bytes = "1"
|
||||
106
crates/cortex-gateway/src/evictor.rs
Normal file
106
crates/cortex-gateway/src/evictor.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
//! Model eviction logic.
|
||||
//!
|
||||
//! The evictor runs as a background task. When the router determines that a
|
||||
//! model needs to be loaded on a node but VRAM is tight, it can request
|
||||
//! eviction via a channel. The evictor then:
|
||||
//! 1. Identifies the LRU model on that node (excluding pinned models)
|
||||
//! 2. Calls `POST /v1/models/unload` on the node
|
||||
//! 3. Increments the lifecycle cycle counter (for defrag tracking)
|
||||
|
||||
use crate::state::CortexState;
|
||||
use cortex_core::node::{ModelLifecycleRequest, ModelStatus};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Runs forever. Currently a placeholder that periodically checks for
|
||||
/// eviction opportunities. In the future, this will be driven by a
|
||||
/// channel from the router when VRAM pressure is detected.
|
||||
pub async fn eviction_loop(fleet: Arc<CortexState>) {
|
||||
// TODO: Replace this polling approach with a channel-driven design
|
||||
// where the router sends eviction requests when it detects that a
|
||||
// model load would exceed available VRAM.
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||
// Placeholder: the actual eviction logic is in `evict_lru_on_node`,
|
||||
// called on demand by the router.
|
||||
let _ = &fleet; // suppress unused warning
|
||||
}
|
||||
}
|
||||
|
||||
/// Evict the least-recently-used model on a given node.
|
||||
/// Returns the model ID that was evicted, or None if nothing could be evicted.
|
||||
pub async fn evict_lru_on_node(
|
||||
fleet: &CortexState,
|
||||
node_name: &str,
|
||||
) -> anyhow::Result<Option<String>> {
|
||||
let (endpoint, candidate) = {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let Some(node) = nodes.get(node_name) else {
|
||||
anyhow::bail!("node '{node_name}' not found");
|
||||
};
|
||||
|
||||
// Find the loaded model with the oldest last_accessed, excluding pinned.
|
||||
let candidate = node
|
||||
.models
|
||||
.values()
|
||||
.filter(|m| m.status == ModelStatus::Loaded)
|
||||
.filter(|m| !node.pinned.contains(&m.id))
|
||||
.min_by_key(|m| m.last_accessed)
|
||||
.map(|m| m.id.clone());
|
||||
|
||||
(node.endpoint.clone(), candidate)
|
||||
};
|
||||
|
||||
let Some(model_id) = candidate else {
|
||||
tracing::info!(node = node_name, "no evictable models found");
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
tracing::info!(node = node_name, model = %model_id, "evicting model");
|
||||
|
||||
let url = format!("{endpoint}/v1/models/unload");
|
||||
let resp = fleet
|
||||
.http_client
|
||||
.post(&url)
|
||||
.json(&ModelLifecycleRequest {
|
||||
model_id: model_id.clone(),
|
||||
})
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if resp.status().is_success() {
|
||||
// Update local state.
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
if let Some(node) = nodes.get_mut(node_name) {
|
||||
if let Some(entry) = node.models.get_mut(&model_id) {
|
||||
entry.status = ModelStatus::Unloaded;
|
||||
}
|
||||
node.lifecycle_cycles += 1;
|
||||
|
||||
// Check if we should flag for defrag.
|
||||
if fleet.eviction.defrag_after_cycles > 0
|
||||
&& node.lifecycle_cycles >= fleet.eviction.defrag_after_cycles
|
||||
{
|
||||
tracing::warn!(
|
||||
node = node_name,
|
||||
cycles = node.lifecycle_cycles,
|
||||
"VRAM fragmentation threshold reached — consider restarting mistralrs"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(node = node_name, model = %model_id, "model evicted");
|
||||
Ok(Some(model_id))
|
||||
} else {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
tracing::error!(
|
||||
node = node_name,
|
||||
model = %model_id,
|
||||
status = %status,
|
||||
body = %body,
|
||||
"failed to evict model"
|
||||
);
|
||||
anyhow::bail!("eviction failed: {status} {body}");
|
||||
}
|
||||
}
|
||||
207
crates/cortex-gateway/src/handlers.rs
Normal file
207
crates/cortex-gateway/src/handlers.rs
Normal file
@@ -0,0 +1,207 @@
|
||||
//! Axum HTTP handlers for the gateway API surface.
|
||||
|
||||
use crate::proxy;
|
||||
use crate::router;
|
||||
use crate::state::CortexState;
|
||||
use axum::body::Bytes;
|
||||
use axum::extract::State;
|
||||
use axum::http::HeaderMap;
|
||||
use axum::response::{IntoResponse, Json, Response};
|
||||
use axum::routing::{get, post};
|
||||
use axum::Router;
|
||||
use cortex_core::node::{CortexModelEntry, ModelLocation};
|
||||
use cortex_core::openai::ChatCompletionRequest;
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub fn api_routes() -> Router<Arc<CortexState>> {
|
||||
Router::new()
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/v1/completions", post(completions))
|
||||
.route("/v1/models", get(list_models))
|
||||
.route("/v1/messages", post(anthropic_messages))
|
||||
.route("/health", get(health))
|
||||
.route("/", get(health))
|
||||
}
|
||||
|
||||
/// `POST /v1/chat/completions` — proxy to the appropriate backend node.
|
||||
async fn chat_completions(
|
||||
State(fleet): State<Arc<CortexState>>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
let model_id = match extract_model(&body) {
|
||||
Some(m) => m,
|
||||
None => return error_response(400, "missing 'model' field in request body"),
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return error_response(404, &e.to_string()),
|
||||
};
|
||||
|
||||
match proxy::forward_request(&fleet.http_client, &route, "/v1/chat/completions", headers, body)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => e.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// `POST /v1/completions` — proxy completions endpoint.
|
||||
async fn completions(
|
||||
State(fleet): State<Arc<CortexState>>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
let model_id = match extract_model(&body) {
|
||||
Some(m) => m,
|
||||
None => return error_response(400, "missing 'model' field in request body"),
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return error_response(404, &e.to_string()),
|
||||
};
|
||||
|
||||
match proxy::forward_request(&fleet.http_client, &route, "/v1/completions", headers, body)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => e.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back.
|
||||
async fn anthropic_messages(
|
||||
State(fleet): State<Arc<CortexState>>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
// Parse as Anthropic request.
|
||||
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
|
||||
Ok(r) => r,
|
||||
Err(e) => return error_response(400, &format!("invalid Anthropic request: {e}")),
|
||||
};
|
||||
|
||||
let model_id = anth_req.model.clone();
|
||||
let is_streaming = anth_req.stream.unwrap_or(false);
|
||||
|
||||
// Translate to OpenAI format.
|
||||
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
|
||||
let openai_body = match serde_json::to_vec(&openai_req) {
|
||||
Ok(b) => Bytes::from(b),
|
||||
Err(e) => return error_response(500, &format!("translation error: {e}")),
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return error_response(404, &e.to_string()),
|
||||
};
|
||||
|
||||
if is_streaming {
|
||||
// TODO: streaming Anthropic translation requires converting SSE format.
|
||||
// For now, proxy the OpenAI SSE stream directly (clients that can handle
|
||||
// OpenAI SSE will work; full Anthropic SSE translation is a follow-up).
|
||||
match proxy::forward_request(
|
||||
&fleet.http_client,
|
||||
&route,
|
||||
"/v1/chat/completions",
|
||||
headers,
|
||||
openai_body,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => e.into_response(),
|
||||
}
|
||||
} else {
|
||||
// Non-streaming: proxy, await full response, translate back.
|
||||
match proxy::forward_request(
|
||||
&fleet.http_client,
|
||||
&route,
|
||||
"/v1/chat/completions",
|
||||
headers,
|
||||
openai_body,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
// TODO: buffer response, parse as OpenAI ChatCompletionResponse,
|
||||
// translate to Anthropic MessagesResponse.
|
||||
// For now, return the OpenAI response as-is.
|
||||
resp
|
||||
}
|
||||
Err(e) => e.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// `GET /v1/models` — aggregate models from all nodes.
|
||||
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let mut model_map: std::collections::HashMap<String, CortexModelEntry> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
for node in nodes.values() {
|
||||
for (model_id, entry) in &node.models {
|
||||
let location = ModelLocation {
|
||||
node: node.name.clone(),
|
||||
status: entry.status,
|
||||
vram_estimate_mb: entry.vram_estimate_mb,
|
||||
};
|
||||
model_map
|
||||
.entry(model_id.clone())
|
||||
.and_modify(|e| e.locations.push(location.clone()))
|
||||
.or_insert_with(|| CortexModelEntry {
|
||||
id: model_id.clone(),
|
||||
object: "model".into(),
|
||||
locations: vec![location],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let data: Vec<Value> = model_map
|
||||
.values()
|
||||
.map(|e| json!(e))
|
||||
.collect();
|
||||
|
||||
Json(json!({
|
||||
"object": "list",
|
||||
"data": data,
|
||||
}))
|
||||
}
|
||||
|
||||
/// `GET /health`
|
||||
async fn health(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let healthy_count = nodes.values().filter(|n| n.healthy).count();
|
||||
let total_count = nodes.len();
|
||||
|
||||
Json(json!({
|
||||
"status": if healthy_count > 0 { "ok" } else { "degraded" },
|
||||
"nodes": {
|
||||
"healthy": healthy_count,
|
||||
"total": total_count,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
fn extract_model(body: &[u8]) -> Option<String> {
|
||||
let v: Value = serde_json::from_slice(body).ok()?;
|
||||
v.get("model")?.as_str().map(|s| s.to_string())
|
||||
}
|
||||
|
||||
fn error_response(status: u16, message: &str) -> Response {
|
||||
let code = axum::http::StatusCode::from_u16(status)
|
||||
.unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let body = json!({
|
||||
"error": {
|
||||
"message": message,
|
||||
"type": "gateway_error",
|
||||
}
|
||||
});
|
||||
(code, Json(body)).into_response()
|
||||
}
|
||||
51
crates/cortex-gateway/src/lib.rs
Normal file
51
crates/cortex-gateway/src/lib.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
pub mod evictor;
|
||||
pub mod handlers;
|
||||
pub mod metrics;
|
||||
pub mod poller;
|
||||
pub mod proxy;
|
||||
pub mod router;
|
||||
pub mod state;
|
||||
|
||||
use anyhow::Result;
|
||||
use axum::Router;
|
||||
use cortex_core::config::GatewayConfig;
|
||||
use std::sync::Arc;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
/// Build the Axum application router with all routes wired up.
|
||||
pub fn build_app(fleet: Arc<state::CortexState>) -> Router {
|
||||
Router::new()
|
||||
.merge(handlers::api_routes())
|
||||
.layer(CorsLayer::permissive())
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.with_state(fleet)
|
||||
}
|
||||
|
||||
/// Start the gateway: build state from config, spawn background tasks,
|
||||
/// bind the HTTP server.
|
||||
pub async fn run(config: GatewayConfig) -> Result<()> {
|
||||
let fleet = Arc::new(state::CortexState::from_config(&config));
|
||||
|
||||
// Spawn the background poller that refreshes node/model status.
|
||||
let poller_fleet = Arc::clone(&fleet);
|
||||
tokio::spawn(async move {
|
||||
poller::poll_loop(poller_fleet).await;
|
||||
});
|
||||
|
||||
// Spawn the evictor (reacts to VRAM pressure events from the router).
|
||||
let evictor_fleet = Arc::clone(&fleet);
|
||||
tokio::spawn(async move {
|
||||
evictor::eviction_loop(evictor_fleet).await;
|
||||
});
|
||||
|
||||
let app = build_app(Arc::clone(&fleet));
|
||||
|
||||
let listen_addr = config.gateway.listen.parse::<std::net::SocketAddr>()?;
|
||||
tracing::info!("cortex listening on {listen_addr}");
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(listen_addr).await?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
55
crates/cortex-gateway/src/metrics.rs
Normal file
55
crates/cortex-gateway/src/metrics.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
//! Prometheus metrics exporter.
|
||||
//!
|
||||
//! Runs on a separate port from the main API, exposing `/metrics`
|
||||
//! in Prometheus text format.
|
||||
|
||||
use anyhow::Result;
|
||||
use metrics_exporter_prometheus::PrometheusBuilder;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
/// Install the Prometheus metrics recorder and return a handle.
|
||||
/// The `/metrics` endpoint is served by the exporter's built-in HTTP server.
|
||||
pub fn install(listen: &str) -> Result<()> {
|
||||
let addr: SocketAddr = listen.parse()?;
|
||||
|
||||
PrometheusBuilder::new()
|
||||
.with_http_listener(addr)
|
||||
.install()
|
||||
.map_err(|e| anyhow::anyhow!("failed to install Prometheus exporter: {e}"))?;
|
||||
|
||||
tracing::info!("prometheus metrics exporter on {addr}");
|
||||
|
||||
// Register histograms and counters used by the proxy layer.
|
||||
// The `metrics` crate lazily creates metrics on first use, but
|
||||
// describing them up front gives Prometheus proper HELP/TYPE lines.
|
||||
metrics::describe_histogram!(
|
||||
"cortex_request_duration_seconds",
|
||||
"Total request latency in seconds"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"cortex_time_to_first_token_seconds",
|
||||
"Time to first token in seconds"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"cortex_tokens_per_second",
|
||||
"Generation throughput in tokens per second"
|
||||
);
|
||||
metrics::describe_counter!(
|
||||
"cortex_requests_total",
|
||||
"Total number of proxied requests"
|
||||
);
|
||||
metrics::describe_counter!(
|
||||
"cortex_request_errors_total",
|
||||
"Total number of failed proxy requests"
|
||||
);
|
||||
metrics::describe_counter!(
|
||||
"cortex_evictions_total",
|
||||
"Total number of model evictions"
|
||||
);
|
||||
metrics::describe_counter!(
|
||||
"cortex_cold_starts_total",
|
||||
"Total number of cold-start model loads"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
103
crates/cortex-gateway/src/poller.rs
Normal file
103
crates/cortex-gateway/src/poller.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
//! Background poller that periodically queries each node's `/v1/models`
|
||||
//! endpoint to refresh the fleet state.
|
||||
|
||||
use crate::state::CortexState;
|
||||
use chrono::Utc;
|
||||
use cortex_core::node::{MistralModelsResponse, ModelEntry, ModelStatus};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
const POLL_INTERVAL: Duration = Duration::from_secs(10);
|
||||
|
||||
/// Runs forever, polling all nodes on a fixed interval.
|
||||
pub async fn poll_loop(fleet: Arc<CortexState>) {
|
||||
loop {
|
||||
for nc in &fleet.node_configs {
|
||||
poll_node(&fleet, &nc.name, &nc.endpoint).await;
|
||||
}
|
||||
tokio::time::sleep(POLL_INTERVAL).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
let url = format!("{endpoint}/v1/models");
|
||||
|
||||
let result = fleet
|
||||
.http_client
|
||||
.get(&url)
|
||||
.timeout(Duration::from_secs(5))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
let Some(node) = nodes.get_mut(name) else {
|
||||
return;
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
match resp.json::<MistralModelsResponse>().await {
|
||||
Ok(models_resp) => {
|
||||
// Merge upstream model list into our state, preserving
|
||||
// our local metadata (last_accessed, vram_estimate).
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
for upstream in &models_resp.data {
|
||||
seen.insert(upstream.id.clone());
|
||||
let status = parse_status(upstream.status.as_deref());
|
||||
|
||||
node.models
|
||||
.entry(upstream.id.clone())
|
||||
.and_modify(|e| {
|
||||
e.status = status;
|
||||
})
|
||||
.or_insert_with(|| ModelEntry {
|
||||
id: upstream.id.clone(),
|
||||
status,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Remove models that are no longer reported by the node
|
||||
// (e.g. after a config change / restart).
|
||||
node.models.retain(|id, _| seen.contains(id));
|
||||
|
||||
node.healthy = true;
|
||||
node.last_poll = Some(Utc::now());
|
||||
tracing::debug!(
|
||||
node = name,
|
||||
models = models_resp.data.len(),
|
||||
"poll ok"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(node = name, error = %e, "failed to parse /v1/models response");
|
||||
node.healthy = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(resp) => {
|
||||
tracing::warn!(
|
||||
node = name,
|
||||
status = %resp.status(),
|
||||
"node returned non-success status"
|
||||
);
|
||||
node.healthy = false;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(node = name, error = %e, "failed to reach node");
|
||||
node.healthy = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_status(s: Option<&str>) -> ModelStatus {
|
||||
match s {
|
||||
Some("loaded") => ModelStatus::Loaded,
|
||||
Some("unloaded") => ModelStatus::Unloaded,
|
||||
Some("reloading") => ModelStatus::Reloading,
|
||||
// If the status field is absent, assume loaded (older mistral.rs versions
|
||||
// may not include it).
|
||||
_ => ModelStatus::Loaded,
|
||||
}
|
||||
}
|
||||
82
crates/cortex-gateway/src/proxy.rs
Normal file
82
crates/cortex-gateway/src/proxy.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
//! Streaming HTTP reverse proxy to mistral.rs backends.
|
||||
//!
|
||||
//! For streaming requests, SSE chunks are forwarded as they arrive.
|
||||
//! The proxy captures timing information for metrics but does not
|
||||
//! buffer the full response.
|
||||
|
||||
use crate::router::RouteDecision;
|
||||
use anyhow::Result;
|
||||
use axum::body::Body;
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use reqwest::Client;
|
||||
|
||||
/// Proxy a request body to the resolved backend node and stream the response.
|
||||
pub async fn forward_request(
|
||||
client: &Client,
|
||||
route: &RouteDecision,
|
||||
path: &str,
|
||||
headers: HeaderMap,
|
||||
body: bytes::Bytes,
|
||||
) -> Result<Response, ProxyError> {
|
||||
let url = format!("{}{}", route.endpoint, path);
|
||||
tracing::info!(
|
||||
node = %route.node_name,
|
||||
url = %url,
|
||||
cold_start = route.cold_start,
|
||||
"proxying request"
|
||||
);
|
||||
|
||||
let mut req_builder = client.post(&url).body(body);
|
||||
|
||||
// Forward relevant headers.
|
||||
for (key, value) in headers.iter() {
|
||||
if key == "host" || key == "content-length" {
|
||||
continue; // reqwest sets these
|
||||
}
|
||||
req_builder = req_builder.header(key, value);
|
||||
}
|
||||
|
||||
let upstream_resp = req_builder.send().await.map_err(ProxyError::Upstream)?;
|
||||
|
||||
let status = StatusCode::from_u16(upstream_resp.status().as_u16())
|
||||
.unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
|
||||
let resp_headers = upstream_resp.headers().clone();
|
||||
let stream = upstream_resp.bytes_stream();
|
||||
|
||||
let body = Body::from_stream(stream);
|
||||
|
||||
let mut response = Response::builder().status(status);
|
||||
for (key, value) in resp_headers.iter() {
|
||||
response = response.header(key, value);
|
||||
}
|
||||
|
||||
response
|
||||
.body(body)
|
||||
.map_err(|e| ProxyError::ResponseBuild(e.to_string()))
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ProxyError {
|
||||
#[error("upstream request failed: {0}")]
|
||||
Upstream(reqwest::Error),
|
||||
#[error("failed to build response: {0}")]
|
||||
ResponseBuild(String),
|
||||
}
|
||||
|
||||
impl IntoResponse for ProxyError {
|
||||
fn into_response(self) -> Response {
|
||||
let status = match &self {
|
||||
ProxyError::Upstream(_) => StatusCode::BAD_GATEWAY,
|
||||
ProxyError::ResponseBuild(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
let body = serde_json::json!({
|
||||
"error": {
|
||||
"message": self.to_string(),
|
||||
"type": "proxy_error",
|
||||
}
|
||||
});
|
||||
(status, axum::Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
74
crates/cortex-gateway/src/router.rs
Normal file
74
crates/cortex-gateway/src/router.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
//! Model-to-node routing logic.
|
||||
//!
|
||||
//! Given a model ID from an inbound request, determine which node should
|
||||
//! handle it. Priority:
|
||||
//! 1. Node where the model is currently `Loaded`
|
||||
//! 2. Node where the model is `Unloaded` (will lazy-load on request)
|
||||
//! 3. Error: model not found on any node
|
||||
|
||||
use crate::state::CortexState;
|
||||
use cortex_core::node::ModelStatus;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// The routing decision: which node endpoint to proxy the request to.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RouteDecision {
|
||||
pub node_name: String,
|
||||
pub endpoint: String,
|
||||
/// Whether the model will need to load (cold start).
|
||||
pub cold_start: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RouteError {
|
||||
#[error("model '{0}' not found on any node")]
|
||||
ModelNotFound(String),
|
||||
#[error("no healthy nodes available")]
|
||||
NoHealthyNodes,
|
||||
}
|
||||
|
||||
/// Resolve which node should serve a request for the given model.
|
||||
pub async fn resolve(fleet: &Arc<CortexState>, model_id: &str) -> Result<RouteDecision, RouteError> {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
|
||||
// Pass 1: find a node where the model is already loaded.
|
||||
let mut loaded_candidate = None;
|
||||
let mut unloaded_candidate = None;
|
||||
|
||||
for node in nodes.values() {
|
||||
if !node.healthy {
|
||||
continue;
|
||||
}
|
||||
if let Some(entry) = node.models.get(model_id) {
|
||||
match entry.status {
|
||||
ModelStatus::Loaded | ModelStatus::Reloading => {
|
||||
loaded_candidate = Some(RouteDecision {
|
||||
node_name: node.name.clone(),
|
||||
endpoint: node.endpoint.clone(),
|
||||
cold_start: false,
|
||||
});
|
||||
break; // loaded is best, stop searching
|
||||
}
|
||||
ModelStatus::Unloaded => {
|
||||
if unloaded_candidate.is_none() {
|
||||
unloaded_candidate = Some(RouteDecision {
|
||||
node_name: node.name.clone(),
|
||||
endpoint: node.endpoint.clone(),
|
||||
cold_start: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loaded_candidate
|
||||
.or(unloaded_candidate)
|
||||
.ok_or_else(|| {
|
||||
if nodes.values().any(|n| n.healthy) {
|
||||
RouteError::ModelNotFound(model_id.to_string())
|
||||
} else {
|
||||
RouteError::NoHealthyNodes
|
||||
}
|
||||
})
|
||||
}
|
||||
43
crates/cortex-gateway/src/state.rs
Normal file
43
crates/cortex-gateway/src/state.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use cortex_core::config::{EvictionSettings, GatewayConfig, NodeConfig};
|
||||
use cortex_core::node::NodeState;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Shared fleet state, protected by a RwLock for concurrent reader access.
|
||||
pub struct CortexState {
|
||||
pub nodes: RwLock<HashMap<String, NodeState>>,
|
||||
pub node_configs: Vec<NodeConfig>,
|
||||
pub eviction: EvictionSettings,
|
||||
pub http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl CortexState {
|
||||
pub fn from_config(config: &GatewayConfig) -> Self {
|
||||
let mut nodes = HashMap::new();
|
||||
for nc in &config.nodes {
|
||||
nodes.insert(
|
||||
nc.name.clone(),
|
||||
NodeState {
|
||||
name: nc.name.clone(),
|
||||
endpoint: nc.endpoint.clone(),
|
||||
vram_mb: nc.vram_mb,
|
||||
pinned: nc.pinned.clone(),
|
||||
healthy: false, // will be set by first poll
|
||||
models: HashMap::new(),
|
||||
lifecycle_cycles: 0,
|
||||
last_poll: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Self {
|
||||
nodes: RwLock::new(nodes),
|
||||
node_configs: config.nodes.clone(),
|
||||
eviction: config.eviction.clone(),
|
||||
http_client: reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.build()
|
||||
.expect("failed to build HTTP client"),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user