feat: scaffold cortex workspace
Rust reverse-proxy for multi-node mistral.rs inference clusters. Includes crate structure (cortex-core, cortex-gateway, cortex-agent, cortex-cli), config loading, OpenAI/Anthropic translation stubs, model routing, eviction, polling, and streaming proxy scaffolding. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
25
crates/cortex-gateway/Cargo.toml
Normal file
25
crates/cortex-gateway/Cargo.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "cortex-gateway"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cortex-core.workspace = true
|
||||
tokio.workspace = true
|
||||
axum.workspace = true
|
||||
tower.workspace = true
|
||||
tower-http.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
reqwest.workspace = true
|
||||
tracing.workspace = true
|
||||
metrics.workspace = true
|
||||
metrics-exporter-prometheus.workspace = true
|
||||
chrono.workspace = true
|
||||
anyhow.workspace = true
|
||||
thiserror.workspace = true
|
||||
futures.workspace = true
|
||||
tokio-stream.workspace = true
|
||||
eventsource-stream.workspace = true
|
||||
bytes = "1"
|
||||
106
crates/cortex-gateway/src/evictor.rs
Normal file
106
crates/cortex-gateway/src/evictor.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
//! Model eviction logic.
|
||||
//!
|
||||
//! The evictor runs as a background task. When the router determines that a
|
||||
//! model needs to be loaded on a node but VRAM is tight, it can request
|
||||
//! eviction via a channel. The evictor then:
|
||||
//! 1. Identifies the LRU model on that node (excluding pinned models)
|
||||
//! 2. Calls `POST /v1/models/unload` on the node
|
||||
//! 3. Increments the lifecycle cycle counter (for defrag tracking)
|
||||
|
||||
use crate::state::CortexState;
|
||||
use cortex_core::node::{ModelLifecycleRequest, ModelStatus};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Runs forever. Currently a placeholder that periodically checks for
|
||||
/// eviction opportunities. In the future, this will be driven by a
|
||||
/// channel from the router when VRAM pressure is detected.
|
||||
pub async fn eviction_loop(fleet: Arc<CortexState>) {
|
||||
// TODO: Replace this polling approach with a channel-driven design
|
||||
// where the router sends eviction requests when it detects that a
|
||||
// model load would exceed available VRAM.
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||
// Placeholder: the actual eviction logic is in `evict_lru_on_node`,
|
||||
// called on demand by the router.
|
||||
let _ = &fleet; // suppress unused warning
|
||||
}
|
||||
}
|
||||
|
||||
/// Evict the least-recently-used model on a given node.
|
||||
/// Returns the model ID that was evicted, or None if nothing could be evicted.
|
||||
pub async fn evict_lru_on_node(
|
||||
fleet: &CortexState,
|
||||
node_name: &str,
|
||||
) -> anyhow::Result<Option<String>> {
|
||||
let (endpoint, candidate) = {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let Some(node) = nodes.get(node_name) else {
|
||||
anyhow::bail!("node '{node_name}' not found");
|
||||
};
|
||||
|
||||
// Find the loaded model with the oldest last_accessed, excluding pinned.
|
||||
let candidate = node
|
||||
.models
|
||||
.values()
|
||||
.filter(|m| m.status == ModelStatus::Loaded)
|
||||
.filter(|m| !node.pinned.contains(&m.id))
|
||||
.min_by_key(|m| m.last_accessed)
|
||||
.map(|m| m.id.clone());
|
||||
|
||||
(node.endpoint.clone(), candidate)
|
||||
};
|
||||
|
||||
let Some(model_id) = candidate else {
|
||||
tracing::info!(node = node_name, "no evictable models found");
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
tracing::info!(node = node_name, model = %model_id, "evicting model");
|
||||
|
||||
let url = format!("{endpoint}/v1/models/unload");
|
||||
let resp = fleet
|
||||
.http_client
|
||||
.post(&url)
|
||||
.json(&ModelLifecycleRequest {
|
||||
model_id: model_id.clone(),
|
||||
})
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if resp.status().is_success() {
|
||||
// Update local state.
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
if let Some(node) = nodes.get_mut(node_name) {
|
||||
if let Some(entry) = node.models.get_mut(&model_id) {
|
||||
entry.status = ModelStatus::Unloaded;
|
||||
}
|
||||
node.lifecycle_cycles += 1;
|
||||
|
||||
// Check if we should flag for defrag.
|
||||
if fleet.eviction.defrag_after_cycles > 0
|
||||
&& node.lifecycle_cycles >= fleet.eviction.defrag_after_cycles
|
||||
{
|
||||
tracing::warn!(
|
||||
node = node_name,
|
||||
cycles = node.lifecycle_cycles,
|
||||
"VRAM fragmentation threshold reached — consider restarting mistralrs"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(node = node_name, model = %model_id, "model evicted");
|
||||
Ok(Some(model_id))
|
||||
} else {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
tracing::error!(
|
||||
node = node_name,
|
||||
model = %model_id,
|
||||
status = %status,
|
||||
body = %body,
|
||||
"failed to evict model"
|
||||
);
|
||||
anyhow::bail!("eviction failed: {status} {body}");
|
||||
}
|
||||
}
|
||||
207
crates/cortex-gateway/src/handlers.rs
Normal file
207
crates/cortex-gateway/src/handlers.rs
Normal file
@@ -0,0 +1,207 @@
|
||||
//! Axum HTTP handlers for the gateway API surface.
|
||||
|
||||
use crate::proxy;
|
||||
use crate::router;
|
||||
use crate::state::CortexState;
|
||||
use axum::body::Bytes;
|
||||
use axum::extract::State;
|
||||
use axum::http::HeaderMap;
|
||||
use axum::response::{IntoResponse, Json, Response};
|
||||
use axum::routing::{get, post};
|
||||
use axum::Router;
|
||||
use cortex_core::node::{CortexModelEntry, ModelLocation};
|
||||
use cortex_core::openai::ChatCompletionRequest;
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub fn api_routes() -> Router<Arc<CortexState>> {
|
||||
Router::new()
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/v1/completions", post(completions))
|
||||
.route("/v1/models", get(list_models))
|
||||
.route("/v1/messages", post(anthropic_messages))
|
||||
.route("/health", get(health))
|
||||
.route("/", get(health))
|
||||
}
|
||||
|
||||
/// `POST /v1/chat/completions` — proxy to the appropriate backend node.
|
||||
async fn chat_completions(
|
||||
State(fleet): State<Arc<CortexState>>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
let model_id = match extract_model(&body) {
|
||||
Some(m) => m,
|
||||
None => return error_response(400, "missing 'model' field in request body"),
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return error_response(404, &e.to_string()),
|
||||
};
|
||||
|
||||
match proxy::forward_request(&fleet.http_client, &route, "/v1/chat/completions", headers, body)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => e.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// `POST /v1/completions` — proxy completions endpoint.
|
||||
async fn completions(
|
||||
State(fleet): State<Arc<CortexState>>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
let model_id = match extract_model(&body) {
|
||||
Some(m) => m,
|
||||
None => return error_response(400, "missing 'model' field in request body"),
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return error_response(404, &e.to_string()),
|
||||
};
|
||||
|
||||
match proxy::forward_request(&fleet.http_client, &route, "/v1/completions", headers, body)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => e.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back.
|
||||
async fn anthropic_messages(
|
||||
State(fleet): State<Arc<CortexState>>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
// Parse as Anthropic request.
|
||||
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
|
||||
Ok(r) => r,
|
||||
Err(e) => return error_response(400, &format!("invalid Anthropic request: {e}")),
|
||||
};
|
||||
|
||||
let model_id = anth_req.model.clone();
|
||||
let is_streaming = anth_req.stream.unwrap_or(false);
|
||||
|
||||
// Translate to OpenAI format.
|
||||
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
|
||||
let openai_body = match serde_json::to_vec(&openai_req) {
|
||||
Ok(b) => Bytes::from(b),
|
||||
Err(e) => return error_response(500, &format!("translation error: {e}")),
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return error_response(404, &e.to_string()),
|
||||
};
|
||||
|
||||
if is_streaming {
|
||||
// TODO: streaming Anthropic translation requires converting SSE format.
|
||||
// For now, proxy the OpenAI SSE stream directly (clients that can handle
|
||||
// OpenAI SSE will work; full Anthropic SSE translation is a follow-up).
|
||||
match proxy::forward_request(
|
||||
&fleet.http_client,
|
||||
&route,
|
||||
"/v1/chat/completions",
|
||||
headers,
|
||||
openai_body,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => e.into_response(),
|
||||
}
|
||||
} else {
|
||||
// Non-streaming: proxy, await full response, translate back.
|
||||
match proxy::forward_request(
|
||||
&fleet.http_client,
|
||||
&route,
|
||||
"/v1/chat/completions",
|
||||
headers,
|
||||
openai_body,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
// TODO: buffer response, parse as OpenAI ChatCompletionResponse,
|
||||
// translate to Anthropic MessagesResponse.
|
||||
// For now, return the OpenAI response as-is.
|
||||
resp
|
||||
}
|
||||
Err(e) => e.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// `GET /v1/models` — aggregate models from all nodes.
|
||||
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let mut model_map: std::collections::HashMap<String, CortexModelEntry> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
for node in nodes.values() {
|
||||
for (model_id, entry) in &node.models {
|
||||
let location = ModelLocation {
|
||||
node: node.name.clone(),
|
||||
status: entry.status,
|
||||
vram_estimate_mb: entry.vram_estimate_mb,
|
||||
};
|
||||
model_map
|
||||
.entry(model_id.clone())
|
||||
.and_modify(|e| e.locations.push(location.clone()))
|
||||
.or_insert_with(|| CortexModelEntry {
|
||||
id: model_id.clone(),
|
||||
object: "model".into(),
|
||||
locations: vec![location],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let data: Vec<Value> = model_map
|
||||
.values()
|
||||
.map(|e| json!(e))
|
||||
.collect();
|
||||
|
||||
Json(json!({
|
||||
"object": "list",
|
||||
"data": data,
|
||||
}))
|
||||
}
|
||||
|
||||
/// `GET /health`
|
||||
async fn health(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let healthy_count = nodes.values().filter(|n| n.healthy).count();
|
||||
let total_count = nodes.len();
|
||||
|
||||
Json(json!({
|
||||
"status": if healthy_count > 0 { "ok" } else { "degraded" },
|
||||
"nodes": {
|
||||
"healthy": healthy_count,
|
||||
"total": total_count,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
fn extract_model(body: &[u8]) -> Option<String> {
|
||||
let v: Value = serde_json::from_slice(body).ok()?;
|
||||
v.get("model")?.as_str().map(|s| s.to_string())
|
||||
}
|
||||
|
||||
fn error_response(status: u16, message: &str) -> Response {
|
||||
let code = axum::http::StatusCode::from_u16(status)
|
||||
.unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let body = json!({
|
||||
"error": {
|
||||
"message": message,
|
||||
"type": "gateway_error",
|
||||
}
|
||||
});
|
||||
(code, Json(body)).into_response()
|
||||
}
|
||||
51
crates/cortex-gateway/src/lib.rs
Normal file
51
crates/cortex-gateway/src/lib.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
pub mod evictor;
|
||||
pub mod handlers;
|
||||
pub mod metrics;
|
||||
pub mod poller;
|
||||
pub mod proxy;
|
||||
pub mod router;
|
||||
pub mod state;
|
||||
|
||||
use anyhow::Result;
|
||||
use axum::Router;
|
||||
use cortex_core::config::GatewayConfig;
|
||||
use std::sync::Arc;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
/// Build the Axum application router with all routes wired up.
|
||||
pub fn build_app(fleet: Arc<state::CortexState>) -> Router {
|
||||
Router::new()
|
||||
.merge(handlers::api_routes())
|
||||
.layer(CorsLayer::permissive())
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.with_state(fleet)
|
||||
}
|
||||
|
||||
/// Start the gateway: build state from config, spawn background tasks,
|
||||
/// bind the HTTP server.
|
||||
pub async fn run(config: GatewayConfig) -> Result<()> {
|
||||
let fleet = Arc::new(state::CortexState::from_config(&config));
|
||||
|
||||
// Spawn the background poller that refreshes node/model status.
|
||||
let poller_fleet = Arc::clone(&fleet);
|
||||
tokio::spawn(async move {
|
||||
poller::poll_loop(poller_fleet).await;
|
||||
});
|
||||
|
||||
// Spawn the evictor (reacts to VRAM pressure events from the router).
|
||||
let evictor_fleet = Arc::clone(&fleet);
|
||||
tokio::spawn(async move {
|
||||
evictor::eviction_loop(evictor_fleet).await;
|
||||
});
|
||||
|
||||
let app = build_app(Arc::clone(&fleet));
|
||||
|
||||
let listen_addr = config.gateway.listen.parse::<std::net::SocketAddr>()?;
|
||||
tracing::info!("cortex listening on {listen_addr}");
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(listen_addr).await?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
55
crates/cortex-gateway/src/metrics.rs
Normal file
55
crates/cortex-gateway/src/metrics.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
//! Prometheus metrics exporter.
|
||||
//!
|
||||
//! Runs on a separate port from the main API, exposing `/metrics`
|
||||
//! in Prometheus text format.
|
||||
|
||||
use anyhow::Result;
|
||||
use metrics_exporter_prometheus::PrometheusBuilder;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
/// Install the Prometheus metrics recorder and return a handle.
|
||||
/// The `/metrics` endpoint is served by the exporter's built-in HTTP server.
|
||||
pub fn install(listen: &str) -> Result<()> {
|
||||
let addr: SocketAddr = listen.parse()?;
|
||||
|
||||
PrometheusBuilder::new()
|
||||
.with_http_listener(addr)
|
||||
.install()
|
||||
.map_err(|e| anyhow::anyhow!("failed to install Prometheus exporter: {e}"))?;
|
||||
|
||||
tracing::info!("prometheus metrics exporter on {addr}");
|
||||
|
||||
// Register histograms and counters used by the proxy layer.
|
||||
// The `metrics` crate lazily creates metrics on first use, but
|
||||
// describing them up front gives Prometheus proper HELP/TYPE lines.
|
||||
metrics::describe_histogram!(
|
||||
"cortex_request_duration_seconds",
|
||||
"Total request latency in seconds"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"cortex_time_to_first_token_seconds",
|
||||
"Time to first token in seconds"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"cortex_tokens_per_second",
|
||||
"Generation throughput in tokens per second"
|
||||
);
|
||||
metrics::describe_counter!(
|
||||
"cortex_requests_total",
|
||||
"Total number of proxied requests"
|
||||
);
|
||||
metrics::describe_counter!(
|
||||
"cortex_request_errors_total",
|
||||
"Total number of failed proxy requests"
|
||||
);
|
||||
metrics::describe_counter!(
|
||||
"cortex_evictions_total",
|
||||
"Total number of model evictions"
|
||||
);
|
||||
metrics::describe_counter!(
|
||||
"cortex_cold_starts_total",
|
||||
"Total number of cold-start model loads"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
103
crates/cortex-gateway/src/poller.rs
Normal file
103
crates/cortex-gateway/src/poller.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
//! Background poller that periodically queries each node's `/v1/models`
|
||||
//! endpoint to refresh the fleet state.
|
||||
|
||||
use crate::state::CortexState;
|
||||
use chrono::Utc;
|
||||
use cortex_core::node::{MistralModelsResponse, ModelEntry, ModelStatus};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
const POLL_INTERVAL: Duration = Duration::from_secs(10);
|
||||
|
||||
/// Runs forever, polling all nodes on a fixed interval.
|
||||
pub async fn poll_loop(fleet: Arc<CortexState>) {
|
||||
loop {
|
||||
for nc in &fleet.node_configs {
|
||||
poll_node(&fleet, &nc.name, &nc.endpoint).await;
|
||||
}
|
||||
tokio::time::sleep(POLL_INTERVAL).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
let url = format!("{endpoint}/v1/models");
|
||||
|
||||
let result = fleet
|
||||
.http_client
|
||||
.get(&url)
|
||||
.timeout(Duration::from_secs(5))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
let Some(node) = nodes.get_mut(name) else {
|
||||
return;
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
match resp.json::<MistralModelsResponse>().await {
|
||||
Ok(models_resp) => {
|
||||
// Merge upstream model list into our state, preserving
|
||||
// our local metadata (last_accessed, vram_estimate).
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
for upstream in &models_resp.data {
|
||||
seen.insert(upstream.id.clone());
|
||||
let status = parse_status(upstream.status.as_deref());
|
||||
|
||||
node.models
|
||||
.entry(upstream.id.clone())
|
||||
.and_modify(|e| {
|
||||
e.status = status;
|
||||
})
|
||||
.or_insert_with(|| ModelEntry {
|
||||
id: upstream.id.clone(),
|
||||
status,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Remove models that are no longer reported by the node
|
||||
// (e.g. after a config change / restart).
|
||||
node.models.retain(|id, _| seen.contains(id));
|
||||
|
||||
node.healthy = true;
|
||||
node.last_poll = Some(Utc::now());
|
||||
tracing::debug!(
|
||||
node = name,
|
||||
models = models_resp.data.len(),
|
||||
"poll ok"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(node = name, error = %e, "failed to parse /v1/models response");
|
||||
node.healthy = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(resp) => {
|
||||
tracing::warn!(
|
||||
node = name,
|
||||
status = %resp.status(),
|
||||
"node returned non-success status"
|
||||
);
|
||||
node.healthy = false;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(node = name, error = %e, "failed to reach node");
|
||||
node.healthy = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_status(s: Option<&str>) -> ModelStatus {
|
||||
match s {
|
||||
Some("loaded") => ModelStatus::Loaded,
|
||||
Some("unloaded") => ModelStatus::Unloaded,
|
||||
Some("reloading") => ModelStatus::Reloading,
|
||||
// If the status field is absent, assume loaded (older mistral.rs versions
|
||||
// may not include it).
|
||||
_ => ModelStatus::Loaded,
|
||||
}
|
||||
}
|
||||
82
crates/cortex-gateway/src/proxy.rs
Normal file
82
crates/cortex-gateway/src/proxy.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
//! Streaming HTTP reverse proxy to mistral.rs backends.
|
||||
//!
|
||||
//! For streaming requests, SSE chunks are forwarded as they arrive.
|
||||
//! The proxy captures timing information for metrics but does not
|
||||
//! buffer the full response.
|
||||
|
||||
use crate::router::RouteDecision;
|
||||
use anyhow::Result;
|
||||
use axum::body::Body;
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use reqwest::Client;
|
||||
|
||||
/// Proxy a request body to the resolved backend node and stream the response.
|
||||
pub async fn forward_request(
|
||||
client: &Client,
|
||||
route: &RouteDecision,
|
||||
path: &str,
|
||||
headers: HeaderMap,
|
||||
body: bytes::Bytes,
|
||||
) -> Result<Response, ProxyError> {
|
||||
let url = format!("{}{}", route.endpoint, path);
|
||||
tracing::info!(
|
||||
node = %route.node_name,
|
||||
url = %url,
|
||||
cold_start = route.cold_start,
|
||||
"proxying request"
|
||||
);
|
||||
|
||||
let mut req_builder = client.post(&url).body(body);
|
||||
|
||||
// Forward relevant headers.
|
||||
for (key, value) in headers.iter() {
|
||||
if key == "host" || key == "content-length" {
|
||||
continue; // reqwest sets these
|
||||
}
|
||||
req_builder = req_builder.header(key, value);
|
||||
}
|
||||
|
||||
let upstream_resp = req_builder.send().await.map_err(ProxyError::Upstream)?;
|
||||
|
||||
let status = StatusCode::from_u16(upstream_resp.status().as_u16())
|
||||
.unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
|
||||
let resp_headers = upstream_resp.headers().clone();
|
||||
let stream = upstream_resp.bytes_stream();
|
||||
|
||||
let body = Body::from_stream(stream);
|
||||
|
||||
let mut response = Response::builder().status(status);
|
||||
for (key, value) in resp_headers.iter() {
|
||||
response = response.header(key, value);
|
||||
}
|
||||
|
||||
response
|
||||
.body(body)
|
||||
.map_err(|e| ProxyError::ResponseBuild(e.to_string()))
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ProxyError {
|
||||
#[error("upstream request failed: {0}")]
|
||||
Upstream(reqwest::Error),
|
||||
#[error("failed to build response: {0}")]
|
||||
ResponseBuild(String),
|
||||
}
|
||||
|
||||
impl IntoResponse for ProxyError {
|
||||
fn into_response(self) -> Response {
|
||||
let status = match &self {
|
||||
ProxyError::Upstream(_) => StatusCode::BAD_GATEWAY,
|
||||
ProxyError::ResponseBuild(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
let body = serde_json::json!({
|
||||
"error": {
|
||||
"message": self.to_string(),
|
||||
"type": "proxy_error",
|
||||
}
|
||||
});
|
||||
(status, axum::Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
74
crates/cortex-gateway/src/router.rs
Normal file
74
crates/cortex-gateway/src/router.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
//! Model-to-node routing logic.
|
||||
//!
|
||||
//! Given a model ID from an inbound request, determine which node should
|
||||
//! handle it. Priority:
|
||||
//! 1. Node where the model is currently `Loaded`
|
||||
//! 2. Node where the model is `Unloaded` (will lazy-load on request)
|
||||
//! 3. Error: model not found on any node
|
||||
|
||||
use crate::state::CortexState;
|
||||
use cortex_core::node::ModelStatus;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// The routing decision: which node endpoint to proxy the request to.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RouteDecision {
|
||||
pub node_name: String,
|
||||
pub endpoint: String,
|
||||
/// Whether the model will need to load (cold start).
|
||||
pub cold_start: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RouteError {
|
||||
#[error("model '{0}' not found on any node")]
|
||||
ModelNotFound(String),
|
||||
#[error("no healthy nodes available")]
|
||||
NoHealthyNodes,
|
||||
}
|
||||
|
||||
/// Resolve which node should serve a request for the given model.
|
||||
pub async fn resolve(fleet: &Arc<CortexState>, model_id: &str) -> Result<RouteDecision, RouteError> {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
|
||||
// Pass 1: find a node where the model is already loaded.
|
||||
let mut loaded_candidate = None;
|
||||
let mut unloaded_candidate = None;
|
||||
|
||||
for node in nodes.values() {
|
||||
if !node.healthy {
|
||||
continue;
|
||||
}
|
||||
if let Some(entry) = node.models.get(model_id) {
|
||||
match entry.status {
|
||||
ModelStatus::Loaded | ModelStatus::Reloading => {
|
||||
loaded_candidate = Some(RouteDecision {
|
||||
node_name: node.name.clone(),
|
||||
endpoint: node.endpoint.clone(),
|
||||
cold_start: false,
|
||||
});
|
||||
break; // loaded is best, stop searching
|
||||
}
|
||||
ModelStatus::Unloaded => {
|
||||
if unloaded_candidate.is_none() {
|
||||
unloaded_candidate = Some(RouteDecision {
|
||||
node_name: node.name.clone(),
|
||||
endpoint: node.endpoint.clone(),
|
||||
cold_start: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loaded_candidate
|
||||
.or(unloaded_candidate)
|
||||
.ok_or_else(|| {
|
||||
if nodes.values().any(|n| n.healthy) {
|
||||
RouteError::ModelNotFound(model_id.to_string())
|
||||
} else {
|
||||
RouteError::NoHealthyNodes
|
||||
}
|
||||
})
|
||||
}
|
||||
43
crates/cortex-gateway/src/state.rs
Normal file
43
crates/cortex-gateway/src/state.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use cortex_core::config::{EvictionSettings, GatewayConfig, NodeConfig};
|
||||
use cortex_core::node::NodeState;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Shared fleet state, protected by a RwLock for concurrent reader access.
|
||||
pub struct CortexState {
|
||||
pub nodes: RwLock<HashMap<String, NodeState>>,
|
||||
pub node_configs: Vec<NodeConfig>,
|
||||
pub eviction: EvictionSettings,
|
||||
pub http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl CortexState {
|
||||
pub fn from_config(config: &GatewayConfig) -> Self {
|
||||
let mut nodes = HashMap::new();
|
||||
for nc in &config.nodes {
|
||||
nodes.insert(
|
||||
nc.name.clone(),
|
||||
NodeState {
|
||||
name: nc.name.clone(),
|
||||
endpoint: nc.endpoint.clone(),
|
||||
vram_mb: nc.vram_mb,
|
||||
pinned: nc.pinned.clone(),
|
||||
healthy: false, // will be set by first poll
|
||||
models: HashMap::new(),
|
||||
lifecycle_cycles: 0,
|
||||
last_poll: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Self {
|
||||
nodes: RwLock::new(nodes),
|
||||
node_configs: config.nodes.clone(),
|
||||
eviction: config.eviction.clone(),
|
||||
http_client: reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.build()
|
||||
.expect("failed to build HTTP client"),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user