feat: scaffold cortex workspace
Rust reverse-proxy for multi-node mistral.rs inference clusters. Includes crate structure (cortex-core, cortex-gateway, cortex-agent, cortex-cli), config loading, OpenAI/Anthropic translation stubs, model routing, eviction, polling, and streaming proxy scaffolding. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
/target
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
cortex.toml
|
||||||
141
CLAUDE.md
Normal file
141
CLAUDE.md
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# CLAUDE.md — cortex
|
||||||
|
|
||||||
|
## Project overview
|
||||||
|
|
||||||
|
cortex is a Rust reverse-proxy that sits in front of multiple
|
||||||
|
mistral.rs inference nodes and presents a unified OpenAI + Anthropic
|
||||||
|
compatible API surface. It handles model routing, lifecycle management
|
||||||
|
(load/unload/evict), request translation, and metrics collection.
|
||||||
|
|
||||||
|
## Repository layout
|
||||||
|
|
||||||
|
```
|
||||||
|
cortex/
|
||||||
|
├── Cargo.toml # workspace root
|
||||||
|
├── cortex.toml # example gateway config
|
||||||
|
├── README.md
|
||||||
|
├── CLAUDE.md # ← you are here
|
||||||
|
├── crates/
|
||||||
|
│ ├── cortex-core/ # shared types, config, envelopes
|
||||||
|
│ │ └── src/
|
||||||
|
│ │ ├── lib.rs
|
||||||
|
│ │ ├── config.rs # figment-based config structs
|
||||||
|
│ │ ├── node.rs # NodeState, ModelStatus
|
||||||
|
│ │ ├── openai.rs # OpenAI request/response types
|
||||||
|
│ │ ├── anthropic.rs # Anthropic request/response types
|
||||||
|
│ │ ├── translate.rs # OpenAI <-> Anthropic translation
|
||||||
|
│ │ └── metrics.rs # RequestMetrics, histogram helpers
|
||||||
|
│ ├── cortex-gateway/ # the HTTP proxy server
|
||||||
|
│ │ └── src/
|
||||||
|
│ │ ├── lib.rs
|
||||||
|
│ │ ├── state.rs # CortexState: Arc<RwLock<...>>
|
||||||
|
│ │ ├── router.rs # model -> node routing logic
|
||||||
|
│ │ ├── proxy.rs # streaming HTTP proxy to backends
|
||||||
|
│ │ ├── evictor.rs # LRU/priority eviction logic
|
||||||
|
│ │ ├── poller.rs # background task polling node status
|
||||||
|
│ │ ├── handlers.rs # axum handlers (chat, completions, models, etc.)
|
||||||
|
│ │ └── metrics.rs # prometheus exporter endpoint
|
||||||
|
│ ├── cortex-agent/ # per-node sidecar (future: defrag, restart)
|
||||||
|
│ │ └── src/
|
||||||
|
│ │ ├── lib.rs
|
||||||
|
│ │ └── agent.rs # local node management
|
||||||
|
│ └── cortex-cli/ # CLI entrypoint
|
||||||
|
│ └── src/
|
||||||
|
│ └── main.rs
|
||||||
|
└── tests/ # integration tests (future)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Key design decisions
|
||||||
|
|
||||||
|
### mistral.rs HTTP API for model lifecycle
|
||||||
|
mistral.rs (v0.8+) supports dynamic model loading/unloading at runtime:
|
||||||
|
- `POST /v1/models/unload {"model_id": "..."}` — frees VRAM, preserves config
|
||||||
|
- `POST /v1/models/reload {"model_id": "..."}` — explicitly reload
|
||||||
|
- `POST /v1/models/status {"model_id": "..."}` — loaded/unloaded/reloading
|
||||||
|
- `GET /v1/models` — lists all models with status field
|
||||||
|
- Lazy loading: requests to unloaded models trigger automatic reload
|
||||||
|
|
||||||
|
The gateway does NOT manage systemd units for model swaps. It calls these
|
||||||
|
HTTP endpoints directly. The only systemd interaction is for full-process
|
||||||
|
restarts after VRAM fragmentation accumulates (defrag_after_cycles).
|
||||||
|
|
||||||
|
### Streaming proxy
|
||||||
|
Chat completions are proxied as SSE streams. The gateway must:
|
||||||
|
1. Parse the inbound request to extract the model name
|
||||||
|
2. Route to the correct backend node
|
||||||
|
3. Stream the response back, capturing token timing for metrics
|
||||||
|
4. NOT buffer the full response — true streaming passthrough
|
||||||
|
|
||||||
|
### Anthropic translation
|
||||||
|
When a request arrives at `/v1/messages` (Anthropic format), the gateway
|
||||||
|
translates it to OpenAI format before proxying to mistral.rs, then
|
||||||
|
translates the response back. This is stateless envelope transformation.
|
||||||
|
|
||||||
|
### Eviction
|
||||||
|
The evictor runs as a background task. Before loading a model on a node
|
||||||
|
where VRAM is tight:
|
||||||
|
1. Check if the model is already loaded elsewhere → route there instead
|
||||||
|
2. Find the LRU model on the target node (excluding pinned models)
|
||||||
|
3. Call `/v1/models/unload` on that model
|
||||||
|
4. The incoming request's lazy-load triggers the new model load
|
||||||
|
|
||||||
|
### Metrics
|
||||||
|
Per-request: model, node, prompt_tokens, completion_tokens, total_tokens,
|
||||||
|
tok_per_sec, time_to_first_token_ms, total_latency_ms.
|
||||||
|
Exposed as Prometheus histograms/counters on a separate port.
|
||||||
|
|
||||||
|
## Tech stack
|
||||||
|
|
||||||
|
- **Rust 2024 edition** — workspace with 4 crates
|
||||||
|
- **Axum 0.8** — HTTP framework (same as mistral.rs itself)
|
||||||
|
- **reqwest** — HTTP client for proxying to backends
|
||||||
|
- **figment** — config loading (TOML + env vars)
|
||||||
|
- **tokio** — async runtime
|
||||||
|
- **metrics + metrics-exporter-prometheus** — observability
|
||||||
|
- **tracing** — structured logging
|
||||||
|
|
||||||
|
## Build commands
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cargo build --release # build all crates
|
||||||
|
cargo run -p cortex-cli -- serve # run the gateway
|
||||||
|
cargo test # run all tests
|
||||||
|
cargo clippy --workspace # lint
|
||||||
|
```
|
||||||
|
|
||||||
|
## Environment
|
||||||
|
|
||||||
|
- Targets Fedora 43 (systemd, SELinux enforcing)
|
||||||
|
- Nodes communicate over a private network (e.g. WireGuard mesh)
|
||||||
|
- One or more GPU nodes running mistral.rs on port 8080
|
||||||
|
- Optionally a metrics-only node (no GPU) for Prometheus/Grafana
|
||||||
|
- Each node runs `mistralrs serve` on port 8080
|
||||||
|
- Gateway listens on port 8000 (API) and 9100 (metrics)
|
||||||
|
- TLS terminated at gateway or via nginx; internal traffic is plaintext over WireGuard
|
||||||
|
|
||||||
|
## Conventions
|
||||||
|
|
||||||
|
- Error handling: `anyhow` for binaries, `thiserror` for library crates
|
||||||
|
- No `unwrap()` in library code; `expect()` only with clear rationale
|
||||||
|
- All public types derive `Debug, Clone, Serialize, Deserialize` where sensible
|
||||||
|
- Config structs use `figment` with TOML as primary source, env vars as override
|
||||||
|
- Prefer `Arc<RwLock<...>>` for shared fleet state; minimize lock duration
|
||||||
|
- SSE streaming uses `tokio_stream` + `eventsource-stream` for parsing
|
||||||
|
- Log at `info` for request routing, `debug` for proxy details, `warn` for
|
||||||
|
eviction and node health, `error` for proxy failures
|
||||||
|
|
||||||
|
## Current status
|
||||||
|
|
||||||
|
**Scaffold phase** — crate structure, types, and handler stubs are in place.
|
||||||
|
The following needs implementation:
|
||||||
|
|
||||||
|
1. **cortex-core**: Flesh out OpenAI/Anthropic envelope types with all fields
|
||||||
|
needed for chat completions (streaming + non-streaming)
|
||||||
|
2. **cortex-gateway/proxy.rs**: Implement streaming HTTP proxy with SSE passthrough
|
||||||
|
3. **cortex-gateway/router.rs**: Model-to-node routing with fallback to least-loaded
|
||||||
|
4. **cortex-gateway/evictor.rs**: LRU eviction with pinning support
|
||||||
|
5. **cortex-gateway/poller.rs**: Background polling of node `/v1/models` endpoints
|
||||||
|
6. **cortex-gateway/handlers.rs**: Wire up axum routes to proxy logic
|
||||||
|
7. **cortex-core/translate.rs**: OpenAI <-> Anthropic request/response translation
|
||||||
|
8. **cortex-agent**: Sidecar for VRAM defrag restarts (lower priority)
|
||||||
|
9. **Integration tests**: Mock mistralrs backends, test routing + eviction
|
||||||
2787
Cargo.lock
generated
Normal file
2787
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
57
Cargo.toml
Normal file
57
Cargo.toml
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
[workspace]
|
||||||
|
resolver = "2"
|
||||||
|
members = [
|
||||||
|
"crates/cortex-core",
|
||||||
|
"crates/cortex-gateway",
|
||||||
|
"crates/cortex-agent",
|
||||||
|
"crates/cortex-cli",
|
||||||
|
]
|
||||||
|
|
||||||
|
[workspace.package]
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
license = "GPL-3.0"
|
||||||
|
repository = "https://git.lair.cafe/helexa/cortex"
|
||||||
|
|
||||||
|
[workspace.dependencies]
|
||||||
|
# async runtime
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
|
||||||
|
# web framework
|
||||||
|
axum = { version = "0.8", features = ["macros"] }
|
||||||
|
tower = "0.5"
|
||||||
|
tower-http = { version = "0.6", features = ["cors", "trace", "timeout"] }
|
||||||
|
|
||||||
|
# serialization
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
toml = "0.8"
|
||||||
|
|
||||||
|
# http client (for proxying to mistralrs backends)
|
||||||
|
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||||
|
|
||||||
|
# observability
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||||
|
metrics = "0.24"
|
||||||
|
metrics-exporter-prometheus = "0.16"
|
||||||
|
|
||||||
|
# time
|
||||||
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
|
|
||||||
|
# config
|
||||||
|
figment = { version = "0.10", features = ["toml", "env"] }
|
||||||
|
|
||||||
|
# error handling
|
||||||
|
anyhow = "1"
|
||||||
|
thiserror = "2"
|
||||||
|
|
||||||
|
# futures / streams (for SSE proxying)
|
||||||
|
futures = "0.3"
|
||||||
|
tokio-stream = "0.1"
|
||||||
|
eventsource-stream = "0.2"
|
||||||
|
|
||||||
|
# workspace crates
|
||||||
|
cortex-core = { path = "crates/cortex-core" }
|
||||||
|
cortex-gateway = { path = "crates/cortex-gateway" }
|
||||||
|
cortex-agent = { path = "crates/cortex-agent" }
|
||||||
138
README.md
Normal file
138
README.md
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
# cortex
|
||||||
|
|
||||||
|
A Rust reverse-proxy and fleet management layer for multi-node
|
||||||
|
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) inference clusters.
|
||||||
|
|
||||||
|
## Problem
|
||||||
|
|
||||||
|
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
|
||||||
|
model affinities) requires a unified API surface that:
|
||||||
|
|
||||||
|
- Presents a **single `/v1/models` catalogue** merging every model across every
|
||||||
|
node.
|
||||||
|
- **Routes requests** to the correct node based on where a model is loaded (or
|
||||||
|
*can* be loaded).
|
||||||
|
- Manages **model lifecycle** — unload cold models, reload on demand, pin
|
||||||
|
critical ones — using the mistral.rs
|
||||||
|
`/v1/models/{unload,reload,status}` HTTP API (PR #1828+).
|
||||||
|
- Translates between **OpenAI and Anthropic** request/response envelopes so
|
||||||
|
every client in the homelab speaks whichever dialect it prefers.
|
||||||
|
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
|
||||||
|
them as Prometheus counters/histograms.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────┐ ┌──────────┐ ┌────────────┐ ┌────────────┐
|
||||||
|
│ Claude Code │ │ Zed/IDE │ │ Tidal / mm │ │ curl / etc │
|
||||||
|
└──────┬───────┘ └─────┬────┘ └──────┬─────┘ └──────┬─────┘
|
||||||
|
│ │ │ │
|
||||||
|
└────────────────┴──────┬───────┴───────────────┘
|
||||||
|
│
|
||||||
|
┌──────────▼──────────┐
|
||||||
|
│ cortex │
|
||||||
|
│ (cortex-gateway) │
|
||||||
|
│ │
|
||||||
|
│ Router · Metrics │
|
||||||
|
│ Evictor · Translate│
|
||||||
|
└──┬──────┬────────┬──┘
|
||||||
|
│ │ │
|
||||||
|
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
|
||||||
|
│ gpu-large │ │gpu-med │ │ gpu-small │
|
||||||
|
│ mistralrs │ │mistral │ │ mistralrs │
|
||||||
|
│ serve │ │rs serve│ │ serve │
|
||||||
|
│ :8080 │ │ :8080 │ │ :8080 │
|
||||||
|
└───────────┘ └────────┘ └───────────┘
|
||||||
|
private network (.internal)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Crates
|
||||||
|
|
||||||
|
| Crate | Purpose |
|
||||||
|
|---|---|
|
||||||
|
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic request/response envelopes |
|
||||||
|
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, metrics exporter |
|
||||||
|
| `cortex-agent` | Per-node sidecar: polls local mistralrs, reports to gateway, handles restart/defrag |
|
||||||
|
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
|
||||||
|
|
||||||
|
## Node setup
|
||||||
|
|
||||||
|
Each GPU node runs `mistralrs serve` with a multi-model config. Models are
|
||||||
|
declared but start **unloaded** — mistral.rs lazy-loads on first request and
|
||||||
|
the gateway can explicitly unload/reload via the HTTP API.
|
||||||
|
|
||||||
|
Example node systemd unit:
|
||||||
|
|
||||||
|
```ini
|
||||||
|
# /etc/systemd/system/mistralrs.service
|
||||||
|
[Unit]
|
||||||
|
Description=mistral.rs inference server
|
||||||
|
After=network-online.target
|
||||||
|
Wants=network-online.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
ExecStart=/usr/local/bin/mistralrs serve \
|
||||||
|
--from-config /etc/mistralrs/config.toml \
|
||||||
|
--port 8080
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=5
|
||||||
|
Environment=CUDA_VISIBLE_DEVICES=0,1
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
|
```
|
||||||
|
|
||||||
|
## Gateway config
|
||||||
|
|
||||||
|
```toml
|
||||||
|
# cortex.toml
|
||||||
|
[gateway]
|
||||||
|
listen = "0.0.0.0:8000"
|
||||||
|
metrics_listen = "0.0.0.0:9100"
|
||||||
|
|
||||||
|
[eviction]
|
||||||
|
strategy = "lru" # lru | priority
|
||||||
|
defrag_after_cycles = 50
|
||||||
|
|
||||||
|
[[nodes]]
|
||||||
|
name = "gpu-large"
|
||||||
|
endpoint = "http://gpu-large.internal:8080"
|
||||||
|
vram_mb = 49_152 # e.g. 2x RTX 4090
|
||||||
|
pinned = ["your-org/large-model"]
|
||||||
|
|
||||||
|
[[nodes]]
|
||||||
|
name = "gpu-medium"
|
||||||
|
endpoint = "http://gpu-medium.internal:8080"
|
||||||
|
vram_mb = 24_576 # e.g. RTX 4090
|
||||||
|
pinned = ["your-org/medium-model"]
|
||||||
|
|
||||||
|
[[nodes]]
|
||||||
|
name = "gpu-small"
|
||||||
|
endpoint = "http://gpu-small.internal:8080"
|
||||||
|
vram_mb = 12_288 # e.g. RTX 3060
|
||||||
|
pinned = ["your-org/embedding-model"]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Building
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cargo build --release
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# start the gateway
|
||||||
|
cortex serve --config cortex.toml
|
||||||
|
|
||||||
|
# check fleet status
|
||||||
|
cortex status
|
||||||
|
|
||||||
|
# list all models across nodes
|
||||||
|
curl http://localhost:8000/v1/models
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
GPL-3.0
|
||||||
45
cortex.example.toml
Normal file
45
cortex.example.toml
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# cortex.example.toml — example configuration
|
||||||
|
#
|
||||||
|
# Copy to cortex.toml and adjust for your environment.
|
||||||
|
#
|
||||||
|
# Environment variable overrides use CORTEX_ prefix with __ separators:
|
||||||
|
# CORTEX_GATEWAY__LISTEN=0.0.0.0:9000
|
||||||
|
|
||||||
|
[gateway]
|
||||||
|
listen = "0.0.0.0:8000"
|
||||||
|
metrics_listen = "0.0.0.0:9100"
|
||||||
|
|
||||||
|
[eviction]
|
||||||
|
strategy = "lru"
|
||||||
|
# Restart mistralrs after this many load/unload cycles to defragment VRAM.
|
||||||
|
# Set to 0 to disable.
|
||||||
|
defrag_after_cycles = 50
|
||||||
|
|
||||||
|
# -- Nodes ---------------------------------------------------------------
|
||||||
|
# Each [[nodes]] entry declares a mistral.rs instance in the fleet.
|
||||||
|
# Models are discovered by polling the node's /v1/models endpoint.
|
||||||
|
# Pinned models are never evicted.
|
||||||
|
|
||||||
|
[[nodes]]
|
||||||
|
name = "gpu-large"
|
||||||
|
endpoint = "http://gpu-large.internal:8080"
|
||||||
|
vram_mb = 49152 # e.g. 2x RTX 4090 (48 GB combined)
|
||||||
|
pinned = [
|
||||||
|
"your-org/large-model",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[nodes]]
|
||||||
|
name = "gpu-medium"
|
||||||
|
endpoint = "http://gpu-medium.internal:8080"
|
||||||
|
vram_mb = 24576 # e.g. RTX 4090 (24 GB)
|
||||||
|
pinned = [
|
||||||
|
"your-org/medium-model",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[nodes]]
|
||||||
|
name = "gpu-small"
|
||||||
|
endpoint = "http://gpu-small.internal:8080"
|
||||||
|
vram_mb = 12288 # e.g. RTX 3060 (12 GB)
|
||||||
|
pinned = [
|
||||||
|
"your-org/embedding-model",
|
||||||
|
]
|
||||||
14
crates/cortex-agent/Cargo.toml
Normal file
14
crates/cortex-agent/Cargo.toml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
[package]
|
||||||
|
name = "cortex-agent"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
cortex-core.workspace = true
|
||||||
|
tokio.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
reqwest.workspace = true
|
||||||
|
tracing.workspace = true
|
||||||
|
anyhow.workspace = true
|
||||||
72
crates/cortex-agent/src/agent.rs
Normal file
72
crates/cortex-agent/src/agent.rs
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
//! Per-node agent sidecar.
|
||||||
|
//!
|
||||||
|
//! This is a future component that runs on each GPU node alongside mistralrs.
|
||||||
|
//! It handles:
|
||||||
|
//! - VRAM defragmentation (restarting the mistralrs systemd unit when the
|
||||||
|
//! gateway signals that lifecycle_cycles has exceeded the threshold)
|
||||||
|
//! - Local nvidia-smi polling for actual VRAM usage reporting
|
||||||
|
//! - Systemd unit management for mistralrs process restarts
|
||||||
|
//!
|
||||||
|
//! For now this is a stub. The gateway's poller + evictor handle the critical
|
||||||
|
//! path (model lifecycle via the mistralrs HTTP API). The agent adds
|
||||||
|
//! operational niceties that can be built incrementally.
|
||||||
|
|
||||||
|
/// Placeholder for agent configuration.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct AgentConfig {
|
||||||
|
/// The local mistralrs endpoint to monitor.
|
||||||
|
pub mistralrs_endpoint: String,
|
||||||
|
/// The systemd unit name for mistralrs (e.g. "mistralrs.service").
|
||||||
|
pub systemd_unit: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Restart the local mistralrs process via systemd.
|
||||||
|
/// This is the nuclear option for VRAM defragmentation.
|
||||||
|
pub async fn restart_mistralrs(config: &AgentConfig) -> anyhow::Result<()> {
|
||||||
|
tracing::warn!(
|
||||||
|
unit = %config.systemd_unit,
|
||||||
|
"restarting mistralrs for VRAM defragmentation"
|
||||||
|
);
|
||||||
|
|
||||||
|
let output = tokio::process::Command::new("systemctl")
|
||||||
|
.args(["restart", &config.systemd_unit])
|
||||||
|
.output()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if output.status.success() {
|
||||||
|
tracing::info!(unit = %config.systemd_unit, "mistralrs restarted successfully");
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
|
anyhow::bail!("systemctl restart failed: {stderr}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Query nvidia-smi for current VRAM usage on this node.
|
||||||
|
/// Returns (used_mb, total_mb) for each GPU.
|
||||||
|
pub async fn query_vram() -> anyhow::Result<Vec<(u64, u64)>> {
|
||||||
|
let output = tokio::process::Command::new("nvidia-smi")
|
||||||
|
.args([
|
||||||
|
"--query-gpu=memory.used,memory.total",
|
||||||
|
"--format=csv,noheader,nounits",
|
||||||
|
])
|
||||||
|
.output()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !output.status.success() {
|
||||||
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
|
anyhow::bail!("nvidia-smi failed: {stderr}");
|
||||||
|
}
|
||||||
|
|
||||||
|
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||||
|
let mut gpus = Vec::new();
|
||||||
|
for line in stdout.lines() {
|
||||||
|
let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
|
||||||
|
if parts.len() == 2 {
|
||||||
|
let used: u64 = parts[0].parse().unwrap_or(0);
|
||||||
|
let total: u64 = parts[1].parse().unwrap_or(0);
|
||||||
|
gpus.push((used, total));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(gpus)
|
||||||
|
}
|
||||||
1
crates/cortex-agent/src/lib.rs
Normal file
1
crates/cortex-agent/src/lib.rs
Normal file
@@ -0,0 +1 @@
|
|||||||
|
pub mod agent;
|
||||||
20
crates/cortex-cli/Cargo.toml
Normal file
20
crates/cortex-cli/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
[package]
|
||||||
|
name = "cortex-cli"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "cortex"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
cortex-core.workspace = true
|
||||||
|
cortex-gateway.workspace = true
|
||||||
|
tokio.workspace = true
|
||||||
|
tracing.workspace = true
|
||||||
|
tracing-subscriber.workspace = true
|
||||||
|
anyhow.workspace = true
|
||||||
|
reqwest.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
clap = { version = "4", features = ["derive"] }
|
||||||
112
crates/cortex-cli/src/main.rs
Normal file
112
crates/cortex-cli/src/main.rs
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use clap::{Parser, Subcommand};
|
||||||
|
use cortex_core::config::GatewayConfig;
|
||||||
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
#[command(name = "cortex")]
|
||||||
|
#[command(about = "Unified inference gateway for multi-node mistral.rs clusters")]
|
||||||
|
#[command(version)]
|
||||||
|
struct Cli {
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: Commands,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Subcommand)]
|
||||||
|
enum Commands {
|
||||||
|
/// Start the gateway server.
|
||||||
|
Serve {
|
||||||
|
/// Path to the gateway config file.
|
||||||
|
#[arg(short, long, default_value = "cortex.toml")]
|
||||||
|
config: String,
|
||||||
|
},
|
||||||
|
/// Print the fleet status (models, nodes, health).
|
||||||
|
Status {
|
||||||
|
/// Gateway API endpoint to query.
|
||||||
|
#[arg(short, long, default_value = "http://localhost:8000")]
|
||||||
|
endpoint: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
// Initialize tracing with env filter (e.g. RUST_LOG=cortex_gateway=debug).
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_env_filter(
|
||||||
|
EnvFilter::try_from_default_env()
|
||||||
|
.unwrap_or_else(|_| EnvFilter::new("info,cortex_gateway=debug")),
|
||||||
|
)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
let cli = Cli::parse();
|
||||||
|
|
||||||
|
match cli.command {
|
||||||
|
Commands::Serve { config } => {
|
||||||
|
let cfg = GatewayConfig::load(&config).map_err(|e| {
|
||||||
|
anyhow::anyhow!("failed to load config from '{config}': {e}")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
nodes = cfg.nodes.len(),
|
||||||
|
listen = %cfg.gateway.listen,
|
||||||
|
"starting cortex"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Install Prometheus metrics exporter on a separate port.
|
||||||
|
cortex_gateway::metrics::install(&cfg.gateway.metrics_listen)?;
|
||||||
|
|
||||||
|
cortex_gateway::run(cfg).await?;
|
||||||
|
}
|
||||||
|
Commands::Status { endpoint } => {
|
||||||
|
print_status(&endpoint).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn print_status(endpoint: &str) -> Result<()> {
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
// Fetch health.
|
||||||
|
let health: serde_json::Value = client
|
||||||
|
.get(format!("{endpoint}/health"))
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.json()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
println!("Fleet health: {}", serde_json::to_string_pretty(&health)?);
|
||||||
|
|
||||||
|
// Fetch models.
|
||||||
|
let models: serde_json::Value = client
|
||||||
|
.get(format!("{endpoint}/v1/models"))
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.json()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
println!("\nModels:");
|
||||||
|
if let Some(data) = models.get("data").and_then(|d| d.as_array()) {
|
||||||
|
for model in data {
|
||||||
|
let id = model.get("id").and_then(|v| v.as_str()).unwrap_or("?");
|
||||||
|
let locations = model
|
||||||
|
.get("locations")
|
||||||
|
.and_then(|v| v.as_array())
|
||||||
|
.map(|arr| {
|
||||||
|
arr.iter()
|
||||||
|
.filter_map(|l| {
|
||||||
|
let node = l.get("node")?.as_str()?;
|
||||||
|
let status = l.get("status")?.as_str()?;
|
||||||
|
Some(format!("{node}({status})"))
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ")
|
||||||
|
})
|
||||||
|
.unwrap_or_default();
|
||||||
|
println!(" {id:40} {locations}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
15
crates/cortex-core/Cargo.toml
Normal file
15
crates/cortex-core/Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
[package]
|
||||||
|
name = "cortex-core"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
toml.workspace = true
|
||||||
|
figment.workspace = true
|
||||||
|
chrono.workspace = true
|
||||||
|
anyhow.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
|
tracing.workspace = true
|
||||||
87
crates/cortex-core/src/anthropic.rs
Normal file
87
crates/cortex-core/src/anthropic.rs
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
//! Anthropic Messages API request and response types.
|
||||||
|
//!
|
||||||
|
//! These mirror the `/v1/messages` format used by the Anthropic API.
|
||||||
|
//! The gateway accepts these, translates to OpenAI format, proxies to
|
||||||
|
//! mistral.rs, then translates the response back.
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
// ── Messages request ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct MessagesRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub messages: Vec<AnthropicMessage>,
|
||||||
|
pub max_tokens: u64,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system: Option<SystemPrompt>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stream: Option<bool>,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub extra: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum SystemPrompt {
|
||||||
|
Text(String),
|
||||||
|
Blocks(Vec<Value>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct AnthropicMessage {
|
||||||
|
pub role: String,
|
||||||
|
pub content: AnthropicContent,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum AnthropicContent {
|
||||||
|
Text(String),
|
||||||
|
Blocks(Vec<ContentBlock>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ContentBlock {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub block_type: String,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub data: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Messages response ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct MessagesResponse {
|
||||||
|
pub id: String,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub response_type: String,
|
||||||
|
pub role: String,
|
||||||
|
pub content: Vec<ContentBlock>,
|
||||||
|
pub model: String,
|
||||||
|
pub stop_reason: Option<String>,
|
||||||
|
pub usage: AnthropicUsage,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub extra: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct AnthropicUsage {
|
||||||
|
pub input_tokens: u64,
|
||||||
|
pub output_tokens: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Streaming events ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct StreamEvent {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub event_type: String,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub data: Value,
|
||||||
|
}
|
||||||
79
crates/cortex-core/src/config.rs
Normal file
79
crates/cortex-core/src/config.rs
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
use figment::{
|
||||||
|
Figment,
|
||||||
|
providers::{Env, Format, Toml},
|
||||||
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct GatewayConfig {
|
||||||
|
pub gateway: GatewaySettings,
|
||||||
|
pub eviction: EvictionSettings,
|
||||||
|
pub nodes: Vec<NodeConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct GatewaySettings {
|
||||||
|
/// Address to listen on for API requests (e.g. "0.0.0.0:8000")
|
||||||
|
pub listen: String,
|
||||||
|
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:9100")
|
||||||
|
pub metrics_listen: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct EvictionSettings {
|
||||||
|
/// Eviction strategy: "lru" or "priority"
|
||||||
|
pub strategy: EvictionStrategy,
|
||||||
|
/// Restart the mistralrs process after this many load/unload cycles
|
||||||
|
/// to reclaim fragmented VRAM. 0 = never.
|
||||||
|
#[serde(default)]
|
||||||
|
pub defrag_after_cycles: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum EvictionStrategy {
|
||||||
|
Lru,
|
||||||
|
Priority,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct NodeConfig {
|
||||||
|
/// Human-readable node name (e.g. "gpu-large")
|
||||||
|
pub name: String,
|
||||||
|
/// Base URL of the mistralrs HTTP server (e.g. "http://gpu-large.internal:8080")
|
||||||
|
pub endpoint: String,
|
||||||
|
/// Total VRAM in MB across all GPUs on this node
|
||||||
|
pub vram_mb: u64,
|
||||||
|
/// Model IDs that should never be evicted from this node
|
||||||
|
#[serde(default)]
|
||||||
|
pub pinned: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GatewayConfig {
|
||||||
|
/// Load configuration from a TOML file, with environment variable overrides.
|
||||||
|
/// Env vars are prefixed with `CORTEX_` and use `__` as a separator
|
||||||
|
/// (e.g. `CORTEX_GATEWAY__LISTEN=0.0.0.0:9000`).
|
||||||
|
pub fn load(path: impl AsRef<Path>) -> Result<Self, figment::Error> {
|
||||||
|
Figment::new()
|
||||||
|
.merge(Toml::file(path))
|
||||||
|
.merge(Env::prefixed("CORTEX_").split("__"))
|
||||||
|
.extract()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for GatewayConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
gateway: GatewaySettings {
|
||||||
|
listen: "0.0.0.0:8000".into(),
|
||||||
|
metrics_listen: "0.0.0.0:9100".into(),
|
||||||
|
},
|
||||||
|
eviction: EvictionSettings {
|
||||||
|
strategy: EvictionStrategy::Lru,
|
||||||
|
defrag_after_cycles: 50,
|
||||||
|
},
|
||||||
|
nodes: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
6
crates/cortex-core/src/lib.rs
Normal file
6
crates/cortex-core/src/lib.rs
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
pub mod anthropic;
|
||||||
|
pub mod config;
|
||||||
|
pub mod metrics;
|
||||||
|
pub mod node;
|
||||||
|
pub mod openai;
|
||||||
|
pub mod translate;
|
||||||
23
crates/cortex-core/src/metrics.rs
Normal file
23
crates/cortex-core/src/metrics.rs
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
//! Request-level metrics captured by the gateway proxy layer.
|
||||||
|
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Metrics captured for a single proxied request.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct RequestMetrics {
|
||||||
|
pub timestamp: DateTime<Utc>,
|
||||||
|
pub model: String,
|
||||||
|
pub node: String,
|
||||||
|
pub prompt_tokens: u64,
|
||||||
|
pub completion_tokens: u64,
|
||||||
|
pub total_tokens: u64,
|
||||||
|
/// Tokens per second for the generation phase.
|
||||||
|
pub tok_per_sec: f64,
|
||||||
|
/// Time from request start to first SSE chunk (streaming) or full response.
|
||||||
|
pub time_to_first_token_ms: u64,
|
||||||
|
/// Total request latency including proxy overhead.
|
||||||
|
pub total_latency_ms: u64,
|
||||||
|
/// Whether this request triggered a model load (cold start).
|
||||||
|
pub cold_start: bool,
|
||||||
|
}
|
||||||
74
crates/cortex-core/src/node.rs
Normal file
74
crates/cortex-core/src/node.rs
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
/// Runtime state of a single node in the fleet.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct NodeState {
|
||||||
|
pub name: String,
|
||||||
|
pub endpoint: String,
|
||||||
|
pub vram_mb: u64,
|
||||||
|
pub pinned: Vec<String>,
|
||||||
|
pub healthy: bool,
|
||||||
|
pub models: HashMap<String, ModelEntry>,
|
||||||
|
/// Number of load/unload cycles since last process restart.
|
||||||
|
pub lifecycle_cycles: u32,
|
||||||
|
pub last_poll: Option<DateTime<Utc>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A model registered on a node, with its runtime status.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ModelEntry {
|
||||||
|
pub id: String,
|
||||||
|
pub status: ModelStatus,
|
||||||
|
/// When this model was last used (for LRU eviction).
|
||||||
|
pub last_accessed: Option<DateTime<Utc>>,
|
||||||
|
/// Estimated VRAM usage in MB when loaded.
|
||||||
|
pub vram_estimate_mb: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Model lifecycle status, matching the mistral.rs API.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum ModelStatus {
|
||||||
|
Loaded,
|
||||||
|
Unloaded,
|
||||||
|
Reloading,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
|
||||||
|
/// Includes which node(s) host this model and their status.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct CortexModelEntry {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
/// Which nodes have this model (and their status).
|
||||||
|
pub locations: Vec<ModelLocation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ModelLocation {
|
||||||
|
pub node: String,
|
||||||
|
pub status: ModelStatus,
|
||||||
|
pub vram_estimate_mb: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response from mistral.rs `GET /v1/models`.
|
||||||
|
/// This is the upstream format we parse when polling nodes.
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct MistralModelsResponse {
|
||||||
|
pub data: Vec<MistralModelEntry>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct MistralModelEntry {
|
||||||
|
pub id: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub status: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Request body for mistral.rs model lifecycle endpoints.
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct ModelLifecycleRequest {
|
||||||
|
pub model_id: String,
|
||||||
|
}
|
||||||
122
crates/cortex-core/src/openai.rs
Normal file
122
crates/cortex-core/src/openai.rs
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
//! OpenAI-compatible request and response types.
|
||||||
|
//!
|
||||||
|
//! These are a subset sufficient for chat completions (streaming + non-streaming).
|
||||||
|
//! Fields not relevant to proxying are captured as `serde_json::Value` via
|
||||||
|
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
|
||||||
|
//! extension field mistral.rs supports.
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
// ── Chat completion request ──────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub messages: Vec<ChatMessage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stream: Option<bool>,
|
||||||
|
/// All other fields (tools, response_format, mistral.rs extensions, etc.)
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub extra: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatMessage {
|
||||||
|
pub role: String,
|
||||||
|
pub content: MessageContent,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub extra: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Content can be a simple string or an array of content parts (for vision).
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum MessageContent {
|
||||||
|
Text(String),
|
||||||
|
Parts(Vec<Value>),
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Chat completion response (non-streaming) ─────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<ChatCompletionChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub extra: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionChoice {
|
||||||
|
pub index: u32,
|
||||||
|
pub message: ChatMessage,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub extra: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Streaming chunk ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionChunk {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<ChunkChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub extra: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChunkChoice {
|
||||||
|
pub index: u32,
|
||||||
|
pub delta: Value,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub extra: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Usage ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: u64,
|
||||||
|
pub completion_tokens: u64,
|
||||||
|
pub total_tokens: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Models list response ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ModelsResponse {
|
||||||
|
pub object: String,
|
||||||
|
pub data: Vec<ModelObject>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ModelObject {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub owned_by: Option<String>,
|
||||||
|
/// Gateway extensions: which node(s) host this model.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub locations: Option<Vec<super::node::ModelLocation>>,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub extra: Value,
|
||||||
|
}
|
||||||
114
crates/cortex-core/src/translate.rs
Normal file
114
crates/cortex-core/src/translate.rs
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
//! Translation between OpenAI and Anthropic request/response envelopes.
|
||||||
|
//!
|
||||||
|
//! This is a stateless transformation — no context is carried between requests.
|
||||||
|
|
||||||
|
use crate::anthropic::{
|
||||||
|
AnthropicContent, AnthropicMessage, AnthropicUsage, ContentBlock, MessagesRequest,
|
||||||
|
MessagesResponse, SystemPrompt,
|
||||||
|
};
|
||||||
|
use crate::openai::{
|
||||||
|
ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, Usage,
|
||||||
|
MessageContent,
|
||||||
|
};
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
|
/// Convert an Anthropic Messages request into an OpenAI ChatCompletion request.
|
||||||
|
pub fn anthropic_to_openai(req: MessagesRequest) -> ChatCompletionRequest {
|
||||||
|
let mut messages = Vec::new();
|
||||||
|
|
||||||
|
// Anthropic `system` field becomes a system message.
|
||||||
|
if let Some(system) = req.system {
|
||||||
|
let content = match system {
|
||||||
|
SystemPrompt::Text(t) => t,
|
||||||
|
SystemPrompt::Blocks(blocks) => serde_json::to_string(&blocks).unwrap_or_default(),
|
||||||
|
};
|
||||||
|
messages.push(ChatMessage {
|
||||||
|
role: "system".into(),
|
||||||
|
content: MessageContent::Text(content),
|
||||||
|
extra: Value::Null,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert message roles and content.
|
||||||
|
for msg in req.messages {
|
||||||
|
let content = match msg.content {
|
||||||
|
AnthropicContent::Text(t) => MessageContent::Text(t),
|
||||||
|
AnthropicContent::Blocks(blocks) => {
|
||||||
|
// For simple text-only blocks, extract the text.
|
||||||
|
// For mixed content (images, etc.), pass as parts.
|
||||||
|
if blocks.len() == 1 && blocks[0].block_type == "text" {
|
||||||
|
let text = blocks[0]
|
||||||
|
.data
|
||||||
|
.get("text")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
MessageContent::Text(text)
|
||||||
|
} else {
|
||||||
|
MessageContent::Parts(
|
||||||
|
blocks.into_iter().map(|b| json!(b)).collect(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
messages.push(ChatMessage {
|
||||||
|
role: msg.role,
|
||||||
|
content,
|
||||||
|
extra: Value::Null,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
ChatCompletionRequest {
|
||||||
|
model: req.model,
|
||||||
|
messages,
|
||||||
|
temperature: req.temperature,
|
||||||
|
top_p: req.top_p,
|
||||||
|
max_tokens: Some(req.max_tokens),
|
||||||
|
stream: req.stream,
|
||||||
|
extra: req.extra,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert an OpenAI ChatCompletion response into an Anthropic Messages response.
|
||||||
|
pub fn openai_to_anthropic(resp: ChatCompletionResponse) -> MessagesResponse {
|
||||||
|
let choice = resp.choices.into_iter().next();
|
||||||
|
|
||||||
|
let (content_text, stop_reason) = match choice {
|
||||||
|
Some(c) => {
|
||||||
|
let text = match c.message.content {
|
||||||
|
MessageContent::Text(t) => t,
|
||||||
|
MessageContent::Parts(parts) => serde_json::to_string(&parts).unwrap_or_default(),
|
||||||
|
};
|
||||||
|
let stop = c.finish_reason.map(|r| match r.as_str() {
|
||||||
|
"stop" => "end_turn".to_string(),
|
||||||
|
"length" => "max_tokens".to_string(),
|
||||||
|
other => other.to_string(),
|
||||||
|
});
|
||||||
|
(text, stop)
|
||||||
|
}
|
||||||
|
None => (String::new(), None),
|
||||||
|
};
|
||||||
|
|
||||||
|
let usage = resp.usage.unwrap_or(Usage {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
completion_tokens: 0,
|
||||||
|
total_tokens: 0,
|
||||||
|
});
|
||||||
|
|
||||||
|
MessagesResponse {
|
||||||
|
id: resp.id,
|
||||||
|
response_type: "message".into(),
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: vec![ContentBlock {
|
||||||
|
block_type: "text".into(),
|
||||||
|
data: json!({ "text": content_text }),
|
||||||
|
}],
|
||||||
|
model: resp.model,
|
||||||
|
stop_reason,
|
||||||
|
usage: AnthropicUsage {
|
||||||
|
input_tokens: usage.prompt_tokens,
|
||||||
|
output_tokens: usage.completion_tokens,
|
||||||
|
},
|
||||||
|
extra: Value::Null,
|
||||||
|
}
|
||||||
|
}
|
||||||
25
crates/cortex-gateway/Cargo.toml
Normal file
25
crates/cortex-gateway/Cargo.toml
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
[package]
|
||||||
|
name = "cortex-gateway"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
cortex-core.workspace = true
|
||||||
|
tokio.workspace = true
|
||||||
|
axum.workspace = true
|
||||||
|
tower.workspace = true
|
||||||
|
tower-http.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
reqwest.workspace = true
|
||||||
|
tracing.workspace = true
|
||||||
|
metrics.workspace = true
|
||||||
|
metrics-exporter-prometheus.workspace = true
|
||||||
|
chrono.workspace = true
|
||||||
|
anyhow.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
|
tokio-stream.workspace = true
|
||||||
|
eventsource-stream.workspace = true
|
||||||
|
bytes = "1"
|
||||||
106
crates/cortex-gateway/src/evictor.rs
Normal file
106
crates/cortex-gateway/src/evictor.rs
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
//! Model eviction logic.
|
||||||
|
//!
|
||||||
|
//! The evictor runs as a background task. When the router determines that a
|
||||||
|
//! model needs to be loaded on a node but VRAM is tight, it can request
|
||||||
|
//! eviction via a channel. The evictor then:
|
||||||
|
//! 1. Identifies the LRU model on that node (excluding pinned models)
|
||||||
|
//! 2. Calls `POST /v1/models/unload` on the node
|
||||||
|
//! 3. Increments the lifecycle cycle counter (for defrag tracking)
|
||||||
|
|
||||||
|
use crate::state::CortexState;
|
||||||
|
use cortex_core::node::{ModelLifecycleRequest, ModelStatus};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
/// Runs forever. Currently a placeholder that periodically checks for
|
||||||
|
/// eviction opportunities. In the future, this will be driven by a
|
||||||
|
/// channel from the router when VRAM pressure is detected.
|
||||||
|
pub async fn eviction_loop(fleet: Arc<CortexState>) {
|
||||||
|
// TODO: Replace this polling approach with a channel-driven design
|
||||||
|
// where the router sends eviction requests when it detects that a
|
||||||
|
// model load would exceed available VRAM.
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||||
|
// Placeholder: the actual eviction logic is in `evict_lru_on_node`,
|
||||||
|
// called on demand by the router.
|
||||||
|
let _ = &fleet; // suppress unused warning
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Evict the least-recently-used model on a given node.
|
||||||
|
/// Returns the model ID that was evicted, or None if nothing could be evicted.
|
||||||
|
pub async fn evict_lru_on_node(
|
||||||
|
fleet: &CortexState,
|
||||||
|
node_name: &str,
|
||||||
|
) -> anyhow::Result<Option<String>> {
|
||||||
|
let (endpoint, candidate) = {
|
||||||
|
let nodes = fleet.nodes.read().await;
|
||||||
|
let Some(node) = nodes.get(node_name) else {
|
||||||
|
anyhow::bail!("node '{node_name}' not found");
|
||||||
|
};
|
||||||
|
|
||||||
|
// Find the loaded model with the oldest last_accessed, excluding pinned.
|
||||||
|
let candidate = node
|
||||||
|
.models
|
||||||
|
.values()
|
||||||
|
.filter(|m| m.status == ModelStatus::Loaded)
|
||||||
|
.filter(|m| !node.pinned.contains(&m.id))
|
||||||
|
.min_by_key(|m| m.last_accessed)
|
||||||
|
.map(|m| m.id.clone());
|
||||||
|
|
||||||
|
(node.endpoint.clone(), candidate)
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(model_id) = candidate else {
|
||||||
|
tracing::info!(node = node_name, "no evictable models found");
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::info!(node = node_name, model = %model_id, "evicting model");
|
||||||
|
|
||||||
|
let url = format!("{endpoint}/v1/models/unload");
|
||||||
|
let resp = fleet
|
||||||
|
.http_client
|
||||||
|
.post(&url)
|
||||||
|
.json(&ModelLifecycleRequest {
|
||||||
|
model_id: model_id.clone(),
|
||||||
|
})
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if resp.status().is_success() {
|
||||||
|
// Update local state.
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
if let Some(node) = nodes.get_mut(node_name) {
|
||||||
|
if let Some(entry) = node.models.get_mut(&model_id) {
|
||||||
|
entry.status = ModelStatus::Unloaded;
|
||||||
|
}
|
||||||
|
node.lifecycle_cycles += 1;
|
||||||
|
|
||||||
|
// Check if we should flag for defrag.
|
||||||
|
if fleet.eviction.defrag_after_cycles > 0
|
||||||
|
&& node.lifecycle_cycles >= fleet.eviction.defrag_after_cycles
|
||||||
|
{
|
||||||
|
tracing::warn!(
|
||||||
|
node = node_name,
|
||||||
|
cycles = node.lifecycle_cycles,
|
||||||
|
"VRAM fragmentation threshold reached — consider restarting mistralrs"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(node = node_name, model = %model_id, "model evicted");
|
||||||
|
Ok(Some(model_id))
|
||||||
|
} else {
|
||||||
|
let status = resp.status();
|
||||||
|
let body = resp.text().await.unwrap_or_default();
|
||||||
|
tracing::error!(
|
||||||
|
node = node_name,
|
||||||
|
model = %model_id,
|
||||||
|
status = %status,
|
||||||
|
body = %body,
|
||||||
|
"failed to evict model"
|
||||||
|
);
|
||||||
|
anyhow::bail!("eviction failed: {status} {body}");
|
||||||
|
}
|
||||||
|
}
|
||||||
207
crates/cortex-gateway/src/handlers.rs
Normal file
207
crates/cortex-gateway/src/handlers.rs
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
//! Axum HTTP handlers for the gateway API surface.
|
||||||
|
|
||||||
|
use crate::proxy;
|
||||||
|
use crate::router;
|
||||||
|
use crate::state::CortexState;
|
||||||
|
use axum::body::Bytes;
|
||||||
|
use axum::extract::State;
|
||||||
|
use axum::http::HeaderMap;
|
||||||
|
use axum::response::{IntoResponse, Json, Response};
|
||||||
|
use axum::routing::{get, post};
|
||||||
|
use axum::Router;
|
||||||
|
use cortex_core::node::{CortexModelEntry, ModelLocation};
|
||||||
|
use cortex_core::openai::ChatCompletionRequest;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
pub fn api_routes() -> Router<Arc<CortexState>> {
|
||||||
|
Router::new()
|
||||||
|
.route("/v1/chat/completions", post(chat_completions))
|
||||||
|
.route("/v1/completions", post(completions))
|
||||||
|
.route("/v1/models", get(list_models))
|
||||||
|
.route("/v1/messages", post(anthropic_messages))
|
||||||
|
.route("/health", get(health))
|
||||||
|
.route("/", get(health))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `POST /v1/chat/completions` — proxy to the appropriate backend node.
|
||||||
|
async fn chat_completions(
|
||||||
|
State(fleet): State<Arc<CortexState>>,
|
||||||
|
headers: HeaderMap,
|
||||||
|
body: Bytes,
|
||||||
|
) -> Response {
|
||||||
|
let model_id = match extract_model(&body) {
|
||||||
|
Some(m) => m,
|
||||||
|
None => return error_response(400, "missing 'model' field in request body"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let route = match router::resolve(&fleet, &model_id).await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => return error_response(404, &e.to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
match proxy::forward_request(&fleet.http_client, &route, "/v1/chat/completions", headers, body)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(resp) => resp,
|
||||||
|
Err(e) => e.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `POST /v1/completions` — proxy completions endpoint.
|
||||||
|
async fn completions(
|
||||||
|
State(fleet): State<Arc<CortexState>>,
|
||||||
|
headers: HeaderMap,
|
||||||
|
body: Bytes,
|
||||||
|
) -> Response {
|
||||||
|
let model_id = match extract_model(&body) {
|
||||||
|
Some(m) => m,
|
||||||
|
None => return error_response(400, "missing 'model' field in request body"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let route = match router::resolve(&fleet, &model_id).await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => return error_response(404, &e.to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
match proxy::forward_request(&fleet.http_client, &route, "/v1/completions", headers, body)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(resp) => resp,
|
||||||
|
Err(e) => e.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back.
|
||||||
|
async fn anthropic_messages(
|
||||||
|
State(fleet): State<Arc<CortexState>>,
|
||||||
|
headers: HeaderMap,
|
||||||
|
body: Bytes,
|
||||||
|
) -> Response {
|
||||||
|
// Parse as Anthropic request.
|
||||||
|
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => return error_response(400, &format!("invalid Anthropic request: {e}")),
|
||||||
|
};
|
||||||
|
|
||||||
|
let model_id = anth_req.model.clone();
|
||||||
|
let is_streaming = anth_req.stream.unwrap_or(false);
|
||||||
|
|
||||||
|
// Translate to OpenAI format.
|
||||||
|
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
|
||||||
|
let openai_body = match serde_json::to_vec(&openai_req) {
|
||||||
|
Ok(b) => Bytes::from(b),
|
||||||
|
Err(e) => return error_response(500, &format!("translation error: {e}")),
|
||||||
|
};
|
||||||
|
|
||||||
|
let route = match router::resolve(&fleet, &model_id).await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => return error_response(404, &e.to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
if is_streaming {
|
||||||
|
// TODO: streaming Anthropic translation requires converting SSE format.
|
||||||
|
// For now, proxy the OpenAI SSE stream directly (clients that can handle
|
||||||
|
// OpenAI SSE will work; full Anthropic SSE translation is a follow-up).
|
||||||
|
match proxy::forward_request(
|
||||||
|
&fleet.http_client,
|
||||||
|
&route,
|
||||||
|
"/v1/chat/completions",
|
||||||
|
headers,
|
||||||
|
openai_body,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(resp) => resp,
|
||||||
|
Err(e) => e.into_response(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Non-streaming: proxy, await full response, translate back.
|
||||||
|
match proxy::forward_request(
|
||||||
|
&fleet.http_client,
|
||||||
|
&route,
|
||||||
|
"/v1/chat/completions",
|
||||||
|
headers,
|
||||||
|
openai_body,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(resp) => {
|
||||||
|
// TODO: buffer response, parse as OpenAI ChatCompletionResponse,
|
||||||
|
// translate to Anthropic MessagesResponse.
|
||||||
|
// For now, return the OpenAI response as-is.
|
||||||
|
resp
|
||||||
|
}
|
||||||
|
Err(e) => e.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `GET /v1/models` — aggregate models from all nodes.
|
||||||
|
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
||||||
|
let nodes = fleet.nodes.read().await;
|
||||||
|
let mut model_map: std::collections::HashMap<String, CortexModelEntry> =
|
||||||
|
std::collections::HashMap::new();
|
||||||
|
|
||||||
|
for node in nodes.values() {
|
||||||
|
for (model_id, entry) in &node.models {
|
||||||
|
let location = ModelLocation {
|
||||||
|
node: node.name.clone(),
|
||||||
|
status: entry.status,
|
||||||
|
vram_estimate_mb: entry.vram_estimate_mb,
|
||||||
|
};
|
||||||
|
model_map
|
||||||
|
.entry(model_id.clone())
|
||||||
|
.and_modify(|e| e.locations.push(location.clone()))
|
||||||
|
.or_insert_with(|| CortexModelEntry {
|
||||||
|
id: model_id.clone(),
|
||||||
|
object: "model".into(),
|
||||||
|
locations: vec![location],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let data: Vec<Value> = model_map
|
||||||
|
.values()
|
||||||
|
.map(|e| json!(e))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Json(json!({
|
||||||
|
"object": "list",
|
||||||
|
"data": data,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `GET /health`
|
||||||
|
async fn health(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
||||||
|
let nodes = fleet.nodes.read().await;
|
||||||
|
let healthy_count = nodes.values().filter(|n| n.healthy).count();
|
||||||
|
let total_count = nodes.len();
|
||||||
|
|
||||||
|
Json(json!({
|
||||||
|
"status": if healthy_count > 0 { "ok" } else { "degraded" },
|
||||||
|
"nodes": {
|
||||||
|
"healthy": healthy_count,
|
||||||
|
"total": total_count,
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
fn extract_model(body: &[u8]) -> Option<String> {
|
||||||
|
let v: Value = serde_json::from_slice(body).ok()?;
|
||||||
|
v.get("model")?.as_str().map(|s| s.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn error_response(status: u16, message: &str) -> Response {
|
||||||
|
let code = axum::http::StatusCode::from_u16(status)
|
||||||
|
.unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
let body = json!({
|
||||||
|
"error": {
|
||||||
|
"message": message,
|
||||||
|
"type": "gateway_error",
|
||||||
|
}
|
||||||
|
});
|
||||||
|
(code, Json(body)).into_response()
|
||||||
|
}
|
||||||
51
crates/cortex-gateway/src/lib.rs
Normal file
51
crates/cortex-gateway/src/lib.rs
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
pub mod evictor;
|
||||||
|
pub mod handlers;
|
||||||
|
pub mod metrics;
|
||||||
|
pub mod poller;
|
||||||
|
pub mod proxy;
|
||||||
|
pub mod router;
|
||||||
|
pub mod state;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use axum::Router;
|
||||||
|
use cortex_core::config::GatewayConfig;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tower_http::cors::CorsLayer;
|
||||||
|
use tower_http::trace::TraceLayer;
|
||||||
|
|
||||||
|
/// Build the Axum application router with all routes wired up.
|
||||||
|
pub fn build_app(fleet: Arc<state::CortexState>) -> Router {
|
||||||
|
Router::new()
|
||||||
|
.merge(handlers::api_routes())
|
||||||
|
.layer(CorsLayer::permissive())
|
||||||
|
.layer(TraceLayer::new_for_http())
|
||||||
|
.with_state(fleet)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start the gateway: build state from config, spawn background tasks,
|
||||||
|
/// bind the HTTP server.
|
||||||
|
pub async fn run(config: GatewayConfig) -> Result<()> {
|
||||||
|
let fleet = Arc::new(state::CortexState::from_config(&config));
|
||||||
|
|
||||||
|
// Spawn the background poller that refreshes node/model status.
|
||||||
|
let poller_fleet = Arc::clone(&fleet);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
poller::poll_loop(poller_fleet).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Spawn the evictor (reacts to VRAM pressure events from the router).
|
||||||
|
let evictor_fleet = Arc::clone(&fleet);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
evictor::eviction_loop(evictor_fleet).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let app = build_app(Arc::clone(&fleet));
|
||||||
|
|
||||||
|
let listen_addr = config.gateway.listen.parse::<std::net::SocketAddr>()?;
|
||||||
|
tracing::info!("cortex listening on {listen_addr}");
|
||||||
|
|
||||||
|
let listener = tokio::net::TcpListener::bind(listen_addr).await?;
|
||||||
|
axum::serve(listener, app).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
55
crates/cortex-gateway/src/metrics.rs
Normal file
55
crates/cortex-gateway/src/metrics.rs
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
//! Prometheus metrics exporter.
|
||||||
|
//!
|
||||||
|
//! Runs on a separate port from the main API, exposing `/metrics`
|
||||||
|
//! in Prometheus text format.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use metrics_exporter_prometheus::PrometheusBuilder;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
|
/// Install the Prometheus metrics recorder and return a handle.
|
||||||
|
/// The `/metrics` endpoint is served by the exporter's built-in HTTP server.
|
||||||
|
pub fn install(listen: &str) -> Result<()> {
|
||||||
|
let addr: SocketAddr = listen.parse()?;
|
||||||
|
|
||||||
|
PrometheusBuilder::new()
|
||||||
|
.with_http_listener(addr)
|
||||||
|
.install()
|
||||||
|
.map_err(|e| anyhow::anyhow!("failed to install Prometheus exporter: {e}"))?;
|
||||||
|
|
||||||
|
tracing::info!("prometheus metrics exporter on {addr}");
|
||||||
|
|
||||||
|
// Register histograms and counters used by the proxy layer.
|
||||||
|
// The `metrics` crate lazily creates metrics on first use, but
|
||||||
|
// describing them up front gives Prometheus proper HELP/TYPE lines.
|
||||||
|
metrics::describe_histogram!(
|
||||||
|
"cortex_request_duration_seconds",
|
||||||
|
"Total request latency in seconds"
|
||||||
|
);
|
||||||
|
metrics::describe_histogram!(
|
||||||
|
"cortex_time_to_first_token_seconds",
|
||||||
|
"Time to first token in seconds"
|
||||||
|
);
|
||||||
|
metrics::describe_histogram!(
|
||||||
|
"cortex_tokens_per_second",
|
||||||
|
"Generation throughput in tokens per second"
|
||||||
|
);
|
||||||
|
metrics::describe_counter!(
|
||||||
|
"cortex_requests_total",
|
||||||
|
"Total number of proxied requests"
|
||||||
|
);
|
||||||
|
metrics::describe_counter!(
|
||||||
|
"cortex_request_errors_total",
|
||||||
|
"Total number of failed proxy requests"
|
||||||
|
);
|
||||||
|
metrics::describe_counter!(
|
||||||
|
"cortex_evictions_total",
|
||||||
|
"Total number of model evictions"
|
||||||
|
);
|
||||||
|
metrics::describe_counter!(
|
||||||
|
"cortex_cold_starts_total",
|
||||||
|
"Total number of cold-start model loads"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
103
crates/cortex-gateway/src/poller.rs
Normal file
103
crates/cortex-gateway/src/poller.rs
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
//! Background poller that periodically queries each node's `/v1/models`
|
||||||
|
//! endpoint to refresh the fleet state.
|
||||||
|
|
||||||
|
use crate::state::CortexState;
|
||||||
|
use chrono::Utc;
|
||||||
|
use cortex_core::node::{MistralModelsResponse, ModelEntry, ModelStatus};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
const POLL_INTERVAL: Duration = Duration::from_secs(10);
|
||||||
|
|
||||||
|
/// Runs forever, polling all nodes on a fixed interval.
|
||||||
|
pub async fn poll_loop(fleet: Arc<CortexState>) {
|
||||||
|
loop {
|
||||||
|
for nc in &fleet.node_configs {
|
||||||
|
poll_node(&fleet, &nc.name, &nc.endpoint).await;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(POLL_INTERVAL).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||||
|
let url = format!("{endpoint}/v1/models");
|
||||||
|
|
||||||
|
let result = fleet
|
||||||
|
.http_client
|
||||||
|
.get(&url)
|
||||||
|
.timeout(Duration::from_secs(5))
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
let Some(node) = nodes.get_mut(name) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(resp) if resp.status().is_success() => {
|
||||||
|
match resp.json::<MistralModelsResponse>().await {
|
||||||
|
Ok(models_resp) => {
|
||||||
|
// Merge upstream model list into our state, preserving
|
||||||
|
// our local metadata (last_accessed, vram_estimate).
|
||||||
|
let mut seen = std::collections::HashSet::new();
|
||||||
|
for upstream in &models_resp.data {
|
||||||
|
seen.insert(upstream.id.clone());
|
||||||
|
let status = parse_status(upstream.status.as_deref());
|
||||||
|
|
||||||
|
node.models
|
||||||
|
.entry(upstream.id.clone())
|
||||||
|
.and_modify(|e| {
|
||||||
|
e.status = status;
|
||||||
|
})
|
||||||
|
.or_insert_with(|| ModelEntry {
|
||||||
|
id: upstream.id.clone(),
|
||||||
|
status,
|
||||||
|
last_accessed: None,
|
||||||
|
vram_estimate_mb: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove models that are no longer reported by the node
|
||||||
|
// (e.g. after a config change / restart).
|
||||||
|
node.models.retain(|id, _| seen.contains(id));
|
||||||
|
|
||||||
|
node.healthy = true;
|
||||||
|
node.last_poll = Some(Utc::now());
|
||||||
|
tracing::debug!(
|
||||||
|
node = name,
|
||||||
|
models = models_resp.data.len(),
|
||||||
|
"poll ok"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(node = name, error = %e, "failed to parse /v1/models response");
|
||||||
|
node.healthy = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(resp) => {
|
||||||
|
tracing::warn!(
|
||||||
|
node = name,
|
||||||
|
status = %resp.status(),
|
||||||
|
"node returned non-success status"
|
||||||
|
);
|
||||||
|
node.healthy = false;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(node = name, error = %e, "failed to reach node");
|
||||||
|
node.healthy = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_status(s: Option<&str>) -> ModelStatus {
|
||||||
|
match s {
|
||||||
|
Some("loaded") => ModelStatus::Loaded,
|
||||||
|
Some("unloaded") => ModelStatus::Unloaded,
|
||||||
|
Some("reloading") => ModelStatus::Reloading,
|
||||||
|
// If the status field is absent, assume loaded (older mistral.rs versions
|
||||||
|
// may not include it).
|
||||||
|
_ => ModelStatus::Loaded,
|
||||||
|
}
|
||||||
|
}
|
||||||
82
crates/cortex-gateway/src/proxy.rs
Normal file
82
crates/cortex-gateway/src/proxy.rs
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
//! Streaming HTTP reverse proxy to mistral.rs backends.
|
||||||
|
//!
|
||||||
|
//! For streaming requests, SSE chunks are forwarded as they arrive.
|
||||||
|
//! The proxy captures timing information for metrics but does not
|
||||||
|
//! buffer the full response.
|
||||||
|
|
||||||
|
use crate::router::RouteDecision;
|
||||||
|
use anyhow::Result;
|
||||||
|
use axum::body::Body;
|
||||||
|
use axum::http::{HeaderMap, StatusCode};
|
||||||
|
use axum::response::{IntoResponse, Response};
|
||||||
|
use reqwest::Client;
|
||||||
|
|
||||||
|
/// Proxy a request body to the resolved backend node and stream the response.
|
||||||
|
pub async fn forward_request(
|
||||||
|
client: &Client,
|
||||||
|
route: &RouteDecision,
|
||||||
|
path: &str,
|
||||||
|
headers: HeaderMap,
|
||||||
|
body: bytes::Bytes,
|
||||||
|
) -> Result<Response, ProxyError> {
|
||||||
|
let url = format!("{}{}", route.endpoint, path);
|
||||||
|
tracing::info!(
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %url,
|
||||||
|
cold_start = route.cold_start,
|
||||||
|
"proxying request"
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut req_builder = client.post(&url).body(body);
|
||||||
|
|
||||||
|
// Forward relevant headers.
|
||||||
|
for (key, value) in headers.iter() {
|
||||||
|
if key == "host" || key == "content-length" {
|
||||||
|
continue; // reqwest sets these
|
||||||
|
}
|
||||||
|
req_builder = req_builder.header(key, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
let upstream_resp = req_builder.send().await.map_err(ProxyError::Upstream)?;
|
||||||
|
|
||||||
|
let status = StatusCode::from_u16(upstream_resp.status().as_u16())
|
||||||
|
.unwrap_or(StatusCode::BAD_GATEWAY);
|
||||||
|
|
||||||
|
let resp_headers = upstream_resp.headers().clone();
|
||||||
|
let stream = upstream_resp.bytes_stream();
|
||||||
|
|
||||||
|
let body = Body::from_stream(stream);
|
||||||
|
|
||||||
|
let mut response = Response::builder().status(status);
|
||||||
|
for (key, value) in resp_headers.iter() {
|
||||||
|
response = response.header(key, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
response
|
||||||
|
.body(body)
|
||||||
|
.map_err(|e| ProxyError::ResponseBuild(e.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ProxyError {
|
||||||
|
#[error("upstream request failed: {0}")]
|
||||||
|
Upstream(reqwest::Error),
|
||||||
|
#[error("failed to build response: {0}")]
|
||||||
|
ResponseBuild(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for ProxyError {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
let status = match &self {
|
||||||
|
ProxyError::Upstream(_) => StatusCode::BAD_GATEWAY,
|
||||||
|
ProxyError::ResponseBuild(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
};
|
||||||
|
let body = serde_json::json!({
|
||||||
|
"error": {
|
||||||
|
"message": self.to_string(),
|
||||||
|
"type": "proxy_error",
|
||||||
|
}
|
||||||
|
});
|
||||||
|
(status, axum::Json(body)).into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
74
crates/cortex-gateway/src/router.rs
Normal file
74
crates/cortex-gateway/src/router.rs
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
//! Model-to-node routing logic.
|
||||||
|
//!
|
||||||
|
//! Given a model ID from an inbound request, determine which node should
|
||||||
|
//! handle it. Priority:
|
||||||
|
//! 1. Node where the model is currently `Loaded`
|
||||||
|
//! 2. Node where the model is `Unloaded` (will lazy-load on request)
|
||||||
|
//! 3. Error: model not found on any node
|
||||||
|
|
||||||
|
use crate::state::CortexState;
|
||||||
|
use cortex_core::node::ModelStatus;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// The routing decision: which node endpoint to proxy the request to.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RouteDecision {
|
||||||
|
pub node_name: String,
|
||||||
|
pub endpoint: String,
|
||||||
|
/// Whether the model will need to load (cold start).
|
||||||
|
pub cold_start: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum RouteError {
|
||||||
|
#[error("model '{0}' not found on any node")]
|
||||||
|
ModelNotFound(String),
|
||||||
|
#[error("no healthy nodes available")]
|
||||||
|
NoHealthyNodes,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve which node should serve a request for the given model.
|
||||||
|
pub async fn resolve(fleet: &Arc<CortexState>, model_id: &str) -> Result<RouteDecision, RouteError> {
|
||||||
|
let nodes = fleet.nodes.read().await;
|
||||||
|
|
||||||
|
// Pass 1: find a node where the model is already loaded.
|
||||||
|
let mut loaded_candidate = None;
|
||||||
|
let mut unloaded_candidate = None;
|
||||||
|
|
||||||
|
for node in nodes.values() {
|
||||||
|
if !node.healthy {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if let Some(entry) = node.models.get(model_id) {
|
||||||
|
match entry.status {
|
||||||
|
ModelStatus::Loaded | ModelStatus::Reloading => {
|
||||||
|
loaded_candidate = Some(RouteDecision {
|
||||||
|
node_name: node.name.clone(),
|
||||||
|
endpoint: node.endpoint.clone(),
|
||||||
|
cold_start: false,
|
||||||
|
});
|
||||||
|
break; // loaded is best, stop searching
|
||||||
|
}
|
||||||
|
ModelStatus::Unloaded => {
|
||||||
|
if unloaded_candidate.is_none() {
|
||||||
|
unloaded_candidate = Some(RouteDecision {
|
||||||
|
node_name: node.name.clone(),
|
||||||
|
endpoint: node.endpoint.clone(),
|
||||||
|
cold_start: true,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded_candidate
|
||||||
|
.or(unloaded_candidate)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
if nodes.values().any(|n| n.healthy) {
|
||||||
|
RouteError::ModelNotFound(model_id.to_string())
|
||||||
|
} else {
|
||||||
|
RouteError::NoHealthyNodes
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
43
crates/cortex-gateway/src/state.rs
Normal file
43
crates/cortex-gateway/src/state.rs
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
use cortex_core::config::{EvictionSettings, GatewayConfig, NodeConfig};
|
||||||
|
use cortex_core::node::NodeState;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
/// Shared fleet state, protected by a RwLock for concurrent reader access.
|
||||||
|
pub struct CortexState {
|
||||||
|
pub nodes: RwLock<HashMap<String, NodeState>>,
|
||||||
|
pub node_configs: Vec<NodeConfig>,
|
||||||
|
pub eviction: EvictionSettings,
|
||||||
|
pub http_client: reqwest::Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CortexState {
|
||||||
|
pub fn from_config(config: &GatewayConfig) -> Self {
|
||||||
|
let mut nodes = HashMap::new();
|
||||||
|
for nc in &config.nodes {
|
||||||
|
nodes.insert(
|
||||||
|
nc.name.clone(),
|
||||||
|
NodeState {
|
||||||
|
name: nc.name.clone(),
|
||||||
|
endpoint: nc.endpoint.clone(),
|
||||||
|
vram_mb: nc.vram_mb,
|
||||||
|
pinned: nc.pinned.clone(),
|
||||||
|
healthy: false, // will be set by first poll
|
||||||
|
models: HashMap::new(),
|
||||||
|
lifecycle_cycles: 0,
|
||||||
|
last_poll: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
nodes: RwLock::new(nodes),
|
||||||
|
node_configs: config.nodes.clone(),
|
||||||
|
eviction: config.eviction.clone(),
|
||||||
|
http_client: reqwest::Client::builder()
|
||||||
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
|
.build()
|
||||||
|
.expect("failed to build HTTP client"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user