Compare commits
5 Commits
feat/47-ph
...
feat/47-ph
| Author | SHA1 | Date | |
|---|---|---|---|
|
b2bd86bfa5
|
|||
|
cdf87284af
|
|||
|
4f16b8c541
|
|||
|
486d7e9a8f
|
|||
|
bc74e0e95f
|
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -793,6 +793,7 @@ name = "cortex-gateway"
|
|||||||
version = "0.1.16"
|
version = "0.1.16"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"async-trait",
|
||||||
"axum",
|
"axum",
|
||||||
"bytes",
|
"bytes",
|
||||||
"chrono",
|
"chrono",
|
||||||
|
|||||||
@@ -48,3 +48,45 @@ vram_mb = 12288 # e.g. RTX 3060 (12 GB)
|
|||||||
pinned = [
|
pinned = [
|
||||||
"your-org/embedding-model",
|
"your-org/embedding-model",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# -- Entitlements (multi-tenant governance, #47) -------------------------
|
||||||
|
# Identity + per-key token budgets. Omit this section entirely for the
|
||||||
|
# legacy single-operator behaviour: requests are anonymous and uncapped.
|
||||||
|
#
|
||||||
|
# The local/static provider below is the source of truth for accounts,
|
||||||
|
# keys, and hard caps until the upstream clearing house exists. Identity
|
||||||
|
# rides standard bearer auth only — clients send
|
||||||
|
# Authorization: Bearer <key>
|
||||||
|
# no custom headers or body fields.
|
||||||
|
|
||||||
|
[entitlements]
|
||||||
|
# Reject unauthenticated requests with 401 invalid_api_key. Leave false
|
||||||
|
# (allow-anonymous) during rollout; flip to true once keys are issued.
|
||||||
|
require_auth = false
|
||||||
|
|
||||||
|
# One entry per API key.
|
||||||
|
[[entitlements.keys]]
|
||||||
|
key = "sk-example-rolling" # the bearer token the client sends
|
||||||
|
account_id = "team-research" # billable account (keys may share one)
|
||||||
|
key_id = "research-ci" # stable label for ledger/metrics (optional)
|
||||||
|
hard_cap = 5_000_000 # hard token cap over the window
|
||||||
|
# Rolling window that resets — over-cap requests get 429 rate_limit_exceeded
|
||||||
|
# + Retry-After, so well-behaved clients (opencode/AI SDK) back off and retry.
|
||||||
|
window = { kind = "rolling", seconds = 3600 }
|
||||||
|
|
||||||
|
[[entitlements.keys]]
|
||||||
|
key = "sk-example-balance"
|
||||||
|
account_id = "team-research"
|
||||||
|
key_id = "research-prepaid"
|
||||||
|
hard_cap = 20_000_000
|
||||||
|
# Hard balance, no reset — exhaustion returns 429 insufficient_quota
|
||||||
|
# (the client surfaces and stops). This is the default when `window` is
|
||||||
|
# omitted. Never 402.
|
||||||
|
window = { kind = "balance" }
|
||||||
|
|
||||||
|
[[entitlements.keys]]
|
||||||
|
key = "sk-example-infra"
|
||||||
|
account_id = "operator"
|
||||||
|
key_id = "infra"
|
||||||
|
# No hard_cap → uncapped operator infra key (own fleet, own use). Still
|
||||||
|
# metered for visibility.
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use crate::entitlements::CapWindow;
|
||||||
use figment::{
|
use figment::{
|
||||||
Figment,
|
Figment,
|
||||||
providers::{Env, Format, Toml},
|
providers::{Env, Format, Toml},
|
||||||
@@ -16,6 +17,46 @@ pub struct GatewayConfig {
|
|||||||
/// non-packaged / local runs.
|
/// non-packaged / local runs.
|
||||||
#[serde(default = "default_models_path")]
|
#[serde(default = "default_models_path")]
|
||||||
pub models_config: String,
|
pub models_config: String,
|
||||||
|
/// Multi-tenant governance: auth + per-key token budgets (#47). Empty
|
||||||
|
/// by default — anonymous, uncapped — so existing single-operator
|
||||||
|
/// setups keep working until keys are configured.
|
||||||
|
#[serde(default)]
|
||||||
|
pub entitlements: EntitlementsConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `[entitlements]` — the local/static [`crate::entitlements::EntitlementProvider`]
|
||||||
|
/// source of truth (#50). Accounts, keys, and hard caps live here; the
|
||||||
|
/// future upstream client (#57) ignores this section.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
|
pub struct EntitlementsConfig {
|
||||||
|
/// Reject unauthenticated requests with `401 invalid_api_key` when
|
||||||
|
/// true. Default `false` (allow-anonymous) for dev / single-operator
|
||||||
|
/// continuity.
|
||||||
|
#[serde(default)]
|
||||||
|
pub require_auth: bool,
|
||||||
|
/// Static API keys and their budgets, consumed by the local provider.
|
||||||
|
#[serde(default)]
|
||||||
|
pub keys: Vec<ApiKeyConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One configured API key: the bearer token, the account it bills to, and
|
||||||
|
/// its hard cap. `[[entitlements.keys]]` in TOML.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ApiKeyConfig {
|
||||||
|
/// The bearer token clients send in `Authorization: Bearer <key>`.
|
||||||
|
pub key: String,
|
||||||
|
/// Billable account. Multiple keys may share one account.
|
||||||
|
pub account_id: String,
|
||||||
|
/// Stable per-key identifier for ledger/metrics labels. Defaults to
|
||||||
|
/// `account_id` when omitted, so the secret is never used as a label.
|
||||||
|
#[serde(default)]
|
||||||
|
pub key_id: Option<String>,
|
||||||
|
/// Hard token cap. `None`/omitted = uncapped (e.g. operator infra key).
|
||||||
|
#[serde(default)]
|
||||||
|
pub hard_cap: Option<u64>,
|
||||||
|
/// Cap-window semantics. Default: a non-resetting [`CapWindow::Balance`].
|
||||||
|
#[serde(default)]
|
||||||
|
pub window: CapWindow,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_models_path() -> String {
|
fn default_models_path() -> String {
|
||||||
@@ -87,6 +128,7 @@ impl Default for GatewayConfig {
|
|||||||
},
|
},
|
||||||
neurons: vec![],
|
neurons: vec![],
|
||||||
models_config: default_models_path(),
|
models_config: default_models_path(),
|
||||||
|
entitlements: EntitlementsConfig::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
145
crates/cortex-core/src/entitlements.rs
Normal file
145
crates/cortex-core/src/entitlements.rs
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
//! Identity and entitlement primitives for multi-tenant governance (#47).
|
||||||
|
//!
|
||||||
|
//! Identity is the shared substrate the whole epic hangs off:
|
||||||
|
//! `identity (principal) → accounting (spend) → policy → enforcement`. This
|
||||||
|
//! module defines the seam — the [`EntitlementProvider`] trait and its data
|
||||||
|
//! types — so the local/static provider (operator-config caps, in
|
||||||
|
//! cortex-gateway) can land the auth + per-key-cap + amplification fix
|
||||||
|
//! *before* any upstream clearing house exists. The future helexa-upstream
|
||||||
|
//! client (#57) is just another impl of this trait.
|
||||||
|
//!
|
||||||
|
//! The provider owns three jobs:
|
||||||
|
//! 1. **resolve** a bearer key to a [`Principal`] (drives auth, #49);
|
||||||
|
//! 2. **reserve → settle/release** token budget around a request so spend
|
||||||
|
//! can never overshoot a hard cap under concurrency (drives budget
|
||||||
|
//! enforcement, #52);
|
||||||
|
//! 3. expose a [`BudgetSnapshot`] for metering/metrics (#51).
|
||||||
|
//!
|
||||||
|
//! [`BudgetError`] carries the cap-window semantics so the caller can pick
|
||||||
|
//! the correct #63 rejection (`rate_limit_exceeded` + `Retry-After` for a
|
||||||
|
//! resetting window vs `insufficient_quota` for a hard balance) without the
|
||||||
|
//! provider knowing anything about HTTP.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Internal header carrying the resolved account id from cortex to neuron.
|
||||||
|
/// neuron trusts these over the WireGuard link (#54); cortex **strips** any
|
||||||
|
/// client-supplied copy before stamping the authoritative value, so a client
|
||||||
|
/// can never assert a principal directly.
|
||||||
|
pub const HEADER_ACCOUNT_ID: &str = "x-helexa-account-id";
|
||||||
|
/// Internal header carrying the resolved key id from cortex to neuron.
|
||||||
|
pub const HEADER_KEY_ID: &str = "x-helexa-key-id";
|
||||||
|
|
||||||
|
/// Who a request is for. Resolved once at the edge from the bearer key and
|
||||||
|
/// carried through the request context. `account_id` is the billable owner
|
||||||
|
/// (spendable at any operator, by decision); `key_id` identifies the
|
||||||
|
/// specific API key for per-key hard caps and ledger/metrics labels.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub struct Principal {
|
||||||
|
pub account_id: String,
|
||||||
|
pub key_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Cap-window semantics for a key's hard cap. Determines which #63 code an
|
||||||
|
/// over-cap reservation maps to.
|
||||||
|
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "kind", rename_all = "snake_case")]
|
||||||
|
pub enum CapWindow {
|
||||||
|
/// Hard balance — the cap never resets. Exhaustion is permanent
|
||||||
|
/// (`429 insufficient_quota`, no `Retry-After`).
|
||||||
|
#[default]
|
||||||
|
Balance,
|
||||||
|
/// Rolling window of `seconds` that resets. Exhaustion is transient
|
||||||
|
/// (`429 rate_limit_exceeded` + `Retry-After` until reset).
|
||||||
|
Rolling { seconds: u64 },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An outstanding budget reservation. The caller holds this opaque handle
|
||||||
|
/// between [`EntitlementProvider::reserve`] and exactly one of
|
||||||
|
/// [`EntitlementProvider::settle`] / [`EntitlementProvider::release`]. Not
|
||||||
|
/// `Clone` — a reservation is consumed once.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Reservation {
|
||||||
|
/// Provider-local handle; opaque to the caller.
|
||||||
|
pub id: u64,
|
||||||
|
/// The principal this reservation belongs to.
|
||||||
|
pub principal: Principal,
|
||||||
|
/// Tokens reserved against the cap.
|
||||||
|
pub reserved: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A point-in-time view of a key's budget, for metering and metrics (#51).
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct BudgetSnapshot {
|
||||||
|
/// Hard cap in tokens. `None` means uncapped (e.g. an operator infra
|
||||||
|
/// key, #58).
|
||||||
|
pub hard_cap: Option<u64>,
|
||||||
|
/// Settled spend in the current window.
|
||||||
|
pub spent: u64,
|
||||||
|
/// Sum of outstanding (un-settled) reservations.
|
||||||
|
pub reserved: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Authentication failure — the bearer key could not be resolved. Maps to
|
||||||
|
/// `401 invalid_api_key` (#49/#63).
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum AuthError {
|
||||||
|
#[error("invalid or unknown API key")]
|
||||||
|
InvalidKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Why a reservation was refused. Carries enough for the caller to build the
|
||||||
|
/// correct #63 envelope without the provider touching HTTP.
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum BudgetError {
|
||||||
|
/// A resetting window is exhausted → `429 rate_limit_exceeded` +
|
||||||
|
/// `Retry-After: retry_after_secs`.
|
||||||
|
#[error(
|
||||||
|
"rolling-window budget exhausted ({requested} requested, {available} available); \
|
||||||
|
resets in {retry_after_secs}s"
|
||||||
|
)]
|
||||||
|
RateLimited {
|
||||||
|
requested: u64,
|
||||||
|
available: u64,
|
||||||
|
retry_after_secs: u64,
|
||||||
|
},
|
||||||
|
/// A hard balance is exhausted → `429 insufficient_quota` (no
|
||||||
|
/// `Retry-After`; the client surfaces and stops). Never `402`.
|
||||||
|
#[error("hard balance exhausted ({requested} requested, {available} available)")]
|
||||||
|
InsufficientQuota { requested: u64, available: u64 },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The seam between cortex's enforcement and whatever decides entitlement —
|
||||||
|
/// a local/static config provider today (#50), the helexa-upstream client
|
||||||
|
/// later (#57). All methods are async so the upstream impl can do network
|
||||||
|
/// I/O; the local impl resolves in-process.
|
||||||
|
#[async_trait]
|
||||||
|
pub trait EntitlementProvider: Send + Sync {
|
||||||
|
/// Resolve a bearer API key to its principal. `Err(InvalidKey)` for an
|
||||||
|
/// unknown/empty key.
|
||||||
|
async fn resolve(&self, api_key: &str) -> Result<Principal, AuthError>;
|
||||||
|
|
||||||
|
/// Reserve up to `max_tokens` against the principal's cap. Returns a
|
||||||
|
/// handle on success, or a [`BudgetError`] (which the caller maps to a
|
||||||
|
/// #63 `429`) if the reservation would exceed the cap. Reserving the
|
||||||
|
/// *maximum* a request could consume before dispatch is what prevents
|
||||||
|
/// overshoot under concurrency.
|
||||||
|
async fn reserve(
|
||||||
|
&self,
|
||||||
|
principal: &Principal,
|
||||||
|
max_tokens: u64,
|
||||||
|
) -> Result<Reservation, BudgetError>;
|
||||||
|
|
||||||
|
/// Settle a reservation with the tokens actually consumed, releasing the
|
||||||
|
/// unused remainder back to the cap.
|
||||||
|
async fn settle(&self, reservation: Reservation, actual_tokens: u64);
|
||||||
|
|
||||||
|
/// Release a reservation in full — e.g. dispatch failed before any
|
||||||
|
/// tokens were consumed.
|
||||||
|
async fn release(&self, reservation: Reservation);
|
||||||
|
|
||||||
|
/// Current budget snapshot for a principal, for metering/metrics.
|
||||||
|
/// `None` if the provider doesn't track this principal.
|
||||||
|
async fn snapshot(&self, principal: &Principal) -> Option<BudgetSnapshot>;
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ pub mod build_info;
|
|||||||
pub mod catalogue;
|
pub mod catalogue;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod discovery;
|
pub mod discovery;
|
||||||
|
pub mod entitlements;
|
||||||
pub mod error_envelope;
|
pub mod error_envelope;
|
||||||
pub mod harness;
|
pub mod harness;
|
||||||
pub mod metrics;
|
pub mod metrics;
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ license.workspace = true
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
cortex-core.workspace = true
|
cortex-core.workspace = true
|
||||||
|
async-trait.workspace = true
|
||||||
tokio.workspace = true
|
tokio.workspace = true
|
||||||
axum.workspace = true
|
axum.workspace = true
|
||||||
tower.workspace = true
|
tower.workspace = true
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ pub async fn stream_translated(
|
|||||||
openai_body: axum::body::Bytes,
|
openai_body: axum::body::Bytes,
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
node_name: &str,
|
node_name: &str,
|
||||||
|
inbound_headers: &axum::http::HeaderMap,
|
||||||
|
usage_sink: Option<crate::metering::UsageSink>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let url = format!("{endpoint}/v1/chat/completions");
|
let url = format!("{endpoint}/v1/chat/completions");
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
@@ -42,13 +44,14 @@ pub async fn stream_translated(
|
|||||||
"proxying streaming request (anthropic SSE translation)"
|
"proxying streaming request (anthropic SSE translation)"
|
||||||
);
|
);
|
||||||
|
|
||||||
let upstream = match client
|
let request = crate::auth::forward_principal_headers(
|
||||||
.post(&url)
|
client
|
||||||
.header("content-type", "application/json")
|
.post(&url)
|
||||||
.body(openai_body)
|
.header("content-type", "application/json")
|
||||||
.send()
|
.body(openai_body),
|
||||||
.await
|
inbound_headers,
|
||||||
{
|
);
|
||||||
|
let upstream = match request.send().await {
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
@@ -94,6 +97,10 @@ pub async fn stream_translated(
|
|||||||
let mut saw_tool_call = false;
|
let mut saw_tool_call = false;
|
||||||
let mut last_finish: Option<String> = None;
|
let mut last_finish: Option<String> = None;
|
||||||
let mut frames = 0u64;
|
let mut frames = 0u64;
|
||||||
|
// Engine-truth usage for metering (#51), scanned from the upstream
|
||||||
|
// frames (neuron emits a final `usage` object on the stream, #48).
|
||||||
|
let mut usage_prompt = 0u64;
|
||||||
|
let mut usage_completion = 0u64;
|
||||||
|
|
||||||
'outer: while let Some(block) = upstream.next().await {
|
'outer: while let Some(block) = upstream.next().await {
|
||||||
let block = match block {
|
let block = match block {
|
||||||
@@ -121,6 +128,15 @@ pub async fn stream_translated(
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
tracing::trace!(node = %node, frame = %data, "anthropic stream: upstream frame");
|
tracing::trace!(node = %node, frame = %data, "anthropic stream: upstream frame");
|
||||||
|
// Capture usage for metering before translation — the
|
||||||
|
// usage object rides on a late frame (often after the
|
||||||
|
// last content delta).
|
||||||
|
if let Some(p) = crate::proxy::last_count_for(data, "prompt_tokens") {
|
||||||
|
usage_prompt = p;
|
||||||
|
}
|
||||||
|
if let Some(c) = crate::proxy::last_count_for(data, "completion_tokens") {
|
||||||
|
usage_completion = c;
|
||||||
|
}
|
||||||
let Ok(chunk) = serde_json::from_str::<ChatCompletionChunk>(data) else {
|
let Ok(chunk) = serde_json::from_str::<ChatCompletionChunk>(data) else {
|
||||||
tracing::debug!(node = %node, "anthropic stream: unparsable upstream frame skipped");
|
tracing::debug!(node = %node, "anthropic stream: unparsable upstream frame skipped");
|
||||||
continue;
|
continue;
|
||||||
@@ -162,6 +178,14 @@ pub async fn stream_translated(
|
|||||||
terminated = done,
|
terminated = done,
|
||||||
"anthropic stream complete"
|
"anthropic stream complete"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Settle metering with the observed usage (#51). Runs on every exit
|
||||||
|
// path of the pump — clean end, early break, or upstream error — so
|
||||||
|
// the reservation is always resolved. `(0, 0)` when no usage frame
|
||||||
|
// was seen, which releases without recording spend.
|
||||||
|
if let Some(sink) = usage_sink {
|
||||||
|
sink(usage_prompt, usage_completion);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
Response::builder()
|
Response::builder()
|
||||||
|
|||||||
119
crates/cortex-gateway/src/auth.rs
Normal file
119
crates/cortex-gateway/src/auth.rs
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
//! API-key authentication + principal resolution (#49).
|
||||||
|
//!
|
||||||
|
//! Identity rides standard bearer auth only — `Authorization: Bearer <key>`
|
||||||
|
//! — which is what keeps every tier OpenAI-compatible by construction (no
|
||||||
|
//! custom required headers or body fields, per #47). The middleware resolves
|
||||||
|
//! the key to a [`Principal`] via the [`EntitlementProvider`], carries it in
|
||||||
|
//! the request extensions for cortex-side metering/enforcement (#51/#52), and
|
||||||
|
//! stamps it as internal headers on the request so it reaches neuron, which
|
||||||
|
//! trusts cortex's assertion over WireGuard (#54).
|
||||||
|
//!
|
||||||
|
//! Anti-spoofing: any client-supplied principal header is **stripped** before
|
||||||
|
//! the authoritative value is stamped, so a client can never assert a
|
||||||
|
//! principal it didn't authenticate as.
|
||||||
|
//!
|
||||||
|
//! Rejection contract (#63): missing key under `require_auth`, or any present
|
||||||
|
//! but unresolvable key, yields `401 invalid_api_key` in the #60 envelope.
|
||||||
|
|
||||||
|
use crate::error::envelope_response;
|
||||||
|
use crate::state::CortexState;
|
||||||
|
use axum::extract::{Request, State};
|
||||||
|
use axum::http::header::AUTHORIZATION;
|
||||||
|
use axum::http::{HeaderMap, HeaderValue};
|
||||||
|
use axum::middleware::Next;
|
||||||
|
use axum::response::Response;
|
||||||
|
use cortex_core::entitlements::{HEADER_ACCOUNT_ID, HEADER_KEY_ID};
|
||||||
|
use cortex_core::error_envelope::OpenAiError;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Endpoints that never require auth: liveness/readiness probes. Everything
|
||||||
|
/// else flows through resolution.
|
||||||
|
fn is_public(path: &str) -> bool {
|
||||||
|
path == "/health" || path == "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the bearer token from an `Authorization` header value, if present
|
||||||
|
/// and well-formed. Scheme match is case-insensitive per RFC 7235.
|
||||||
|
fn parse_bearer(headers: &HeaderMap) -> Option<String> {
|
||||||
|
let raw = headers.get(AUTHORIZATION)?.to_str().ok()?;
|
||||||
|
let (scheme, token) = raw.split_once(' ')?;
|
||||||
|
if scheme.eq_ignore_ascii_case("bearer") {
|
||||||
|
let token = token.trim();
|
||||||
|
(!token.is_empty()).then(|| token.to_string())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Axum middleware: resolve the bearer key, attach the principal, stamp the
|
||||||
|
/// internal headers. Wired in `build_app` via `from_fn_with_state`.
|
||||||
|
pub async fn require_principal(
|
||||||
|
State(fleet): State<Arc<CortexState>>,
|
||||||
|
mut req: Request,
|
||||||
|
next: Next,
|
||||||
|
) -> Response {
|
||||||
|
if is_public(req.uri().path()) {
|
||||||
|
return next.run(req).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Anti-spoof: drop any client-supplied principal headers up front.
|
||||||
|
{
|
||||||
|
let headers = req.headers_mut();
|
||||||
|
headers.remove(HEADER_ACCOUNT_ID);
|
||||||
|
headers.remove(HEADER_KEY_ID);
|
||||||
|
}
|
||||||
|
|
||||||
|
match parse_bearer(req.headers()) {
|
||||||
|
Some(key) => match fleet.entitlements.resolve(&key).await {
|
||||||
|
Ok(principal) => {
|
||||||
|
// Stamp the authoritative principal for neuron. Account/key
|
||||||
|
// ids come from operator config, so they're valid header
|
||||||
|
// values; guard anyway and skip a malformed one rather than
|
||||||
|
// panic.
|
||||||
|
if let (Ok(account), Ok(key_id)) = (
|
||||||
|
HeaderValue::from_str(&principal.account_id),
|
||||||
|
HeaderValue::from_str(&principal.key_id),
|
||||||
|
) {
|
||||||
|
let headers = req.headers_mut();
|
||||||
|
headers.insert(HEADER_ACCOUNT_ID, account);
|
||||||
|
headers.insert(HEADER_KEY_ID, key_id);
|
||||||
|
}
|
||||||
|
// Carry the typed principal for cortex-side metering (#51)
|
||||||
|
// and budget enforcement (#52).
|
||||||
|
req.extensions_mut().insert(principal);
|
||||||
|
next.run(req).await
|
||||||
|
}
|
||||||
|
// A present-but-invalid credential is always an error, even when
|
||||||
|
// anonymous access is otherwise allowed.
|
||||||
|
Err(_) => unauthorized("invalid API key"),
|
||||||
|
},
|
||||||
|
None => {
|
||||||
|
if fleet.require_auth {
|
||||||
|
unauthorized("missing API key; supply 'Authorization: Bearer <key>'")
|
||||||
|
} else {
|
||||||
|
next.run(req).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `401 invalid_api_key` in the standard envelope (#63).
|
||||||
|
fn unauthorized(message: &str) -> Response {
|
||||||
|
envelope_response(OpenAiError::invalid_api_key(message))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copy the cortex-stamped principal headers from an inbound [`HeaderMap`]
|
||||||
|
/// onto an outbound reqwest builder. Used by the Anthropic proxy paths,
|
||||||
|
/// which construct their own upstream requests instead of going through
|
||||||
|
/// [`crate::proxy::forward_request`] (which forwards all headers verbatim).
|
||||||
|
pub fn forward_principal_headers(
|
||||||
|
mut builder: reqwest::RequestBuilder,
|
||||||
|
headers: &HeaderMap,
|
||||||
|
) -> reqwest::RequestBuilder {
|
||||||
|
for name in [HEADER_ACCOUNT_ID, HEADER_KEY_ID] {
|
||||||
|
if let Some(value) = headers.get(name) {
|
||||||
|
builder = builder.header(name, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
builder
|
||||||
|
}
|
||||||
317
crates/cortex-gateway/src/entitlements_local.rs
Normal file
317
crates/cortex-gateway/src/entitlements_local.rs
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
//! The local/static [`EntitlementProvider`] (#50).
|
||||||
|
//!
|
||||||
|
//! Accounts, keys, and hard caps come from operator config
|
||||||
|
//! ([`cortex_core::config::EntitlementsConfig`]); reservations and settled
|
||||||
|
//! spend are tracked in-process. This lands auth + per-key caps + the
|
||||||
|
//! amplification fix before any upstream clearing house exists; the future
|
||||||
|
//! helexa-upstream client (#57) implements the same trait.
|
||||||
|
//!
|
||||||
|
//! Budget math is serialized under a single [`std::sync::Mutex`] so
|
||||||
|
//! reserve/settle/release are atomic — a key's `spent + reserved` can never
|
||||||
|
//! exceed its hard cap even under concurrent requests (the #52 guarantee).
|
||||||
|
//! The lock is held only for the in-memory arithmetic, never across an
|
||||||
|
//! await.
|
||||||
|
|
||||||
|
use cortex_core::config::{ApiKeyConfig, EntitlementsConfig};
|
||||||
|
use cortex_core::entitlements::{
|
||||||
|
AuthError, BudgetError, BudgetSnapshot, CapWindow, EntitlementProvider, Principal, Reservation,
|
||||||
|
};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
/// Per-key budget configuration (resolved from [`ApiKeyConfig`]).
|
||||||
|
struct Budget {
|
||||||
|
hard_cap: Option<u64>,
|
||||||
|
window: CapWindow,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Live, mutable accounting for one key over its current window.
|
||||||
|
#[derive(Default)]
|
||||||
|
struct Ledger {
|
||||||
|
/// Settled spend in the current window.
|
||||||
|
spent: u64,
|
||||||
|
/// Sum of outstanding (un-settled) reservations.
|
||||||
|
reserved: u64,
|
||||||
|
/// Start of the current rolling window; `None` until the first reserve.
|
||||||
|
/// Unused for [`CapWindow::Balance`].
|
||||||
|
window_start: Option<Instant>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct LocalEntitlementProvider {
|
||||||
|
/// Bearer token → principal.
|
||||||
|
keys: HashMap<String, Principal>,
|
||||||
|
/// `key_id` → budget config.
|
||||||
|
budgets: HashMap<String, Budget>,
|
||||||
|
/// `key_id` → live ledger.
|
||||||
|
ledgers: Mutex<HashMap<String, Ledger>>,
|
||||||
|
/// Monotonic source of opaque reservation handles.
|
||||||
|
next_id: AtomicU64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LocalEntitlementProvider {
|
||||||
|
/// Build from the `[entitlements]` config. A key without an explicit
|
||||||
|
/// `key_id` is tracked at `account_id` granularity (its secret is never
|
||||||
|
/// used as a label).
|
||||||
|
pub fn from_config(config: &EntitlementsConfig) -> Self {
|
||||||
|
let mut keys = HashMap::new();
|
||||||
|
let mut budgets = HashMap::new();
|
||||||
|
for ApiKeyConfig {
|
||||||
|
key,
|
||||||
|
account_id,
|
||||||
|
key_id,
|
||||||
|
hard_cap,
|
||||||
|
window,
|
||||||
|
} in &config.keys
|
||||||
|
{
|
||||||
|
let key_id = key_id.clone().unwrap_or_else(|| account_id.clone());
|
||||||
|
keys.insert(
|
||||||
|
key.clone(),
|
||||||
|
Principal {
|
||||||
|
account_id: account_id.clone(),
|
||||||
|
key_id: key_id.clone(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
budgets.insert(
|
||||||
|
key_id,
|
||||||
|
Budget {
|
||||||
|
hard_cap: *hard_cap,
|
||||||
|
window: window.clone(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Self {
|
||||||
|
keys,
|
||||||
|
budgets,
|
||||||
|
ledgers: Mutex::new(HashMap::new()),
|
||||||
|
next_id: AtomicU64::new(1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tokens still available under `cap` given current `spent`/`reserved`.
|
||||||
|
/// `None` cap = unlimited.
|
||||||
|
fn available(cap: Option<u64>, spent: u64, reserved: u64) -> Option<u64> {
|
||||||
|
cap.map(|c| c.saturating_sub(spent).saturating_sub(reserved))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl EntitlementProvider for LocalEntitlementProvider {
|
||||||
|
async fn resolve(&self, api_key: &str) -> Result<Principal, AuthError> {
|
||||||
|
self.keys.get(api_key).cloned().ok_or(AuthError::InvalidKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn reserve(
|
||||||
|
&self,
|
||||||
|
principal: &Principal,
|
||||||
|
max_tokens: u64,
|
||||||
|
) -> Result<Reservation, BudgetError> {
|
||||||
|
// A principal with no configured budget (or an uncapped one) always
|
||||||
|
// reserves; we still track spend for metrics.
|
||||||
|
let budget = self.budgets.get(&principal.key_id);
|
||||||
|
let (cap, window) = match budget {
|
||||||
|
Some(b) => (b.hard_cap, b.window.clone()),
|
||||||
|
None => (None, CapWindow::Balance),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut ledgers = self.ledgers.lock().expect("ledger mutex poisoned");
|
||||||
|
let ledger = ledgers.entry(principal.key_id.clone()).or_default();
|
||||||
|
|
||||||
|
// Lazily reset a rolling window that has elapsed before checking.
|
||||||
|
let mut retry_after_secs = 0;
|
||||||
|
if let CapWindow::Rolling { seconds } = window {
|
||||||
|
let now = Instant::now();
|
||||||
|
match ledger.window_start {
|
||||||
|
Some(start) if now.duration_since(start).as_secs() < seconds => {
|
||||||
|
retry_after_secs = seconds - now.duration_since(start).as_secs();
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// First reserve, or the window has fully elapsed: reset.
|
||||||
|
ledger.spent = 0;
|
||||||
|
ledger.window_start = Some(now);
|
||||||
|
retry_after_secs = seconds;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(avail) = available(cap, ledger.spent, ledger.reserved)
|
||||||
|
&& max_tokens > avail
|
||||||
|
{
|
||||||
|
return Err(match window {
|
||||||
|
CapWindow::Rolling { .. } => BudgetError::RateLimited {
|
||||||
|
requested: max_tokens,
|
||||||
|
available: avail,
|
||||||
|
// At least 1s so clients don't hot-loop on a sub-second
|
||||||
|
// remainder.
|
||||||
|
retry_after_secs: retry_after_secs.max(1),
|
||||||
|
},
|
||||||
|
CapWindow::Balance => BudgetError::InsufficientQuota {
|
||||||
|
requested: max_tokens,
|
||||||
|
available: avail,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
ledger.reserved += max_tokens;
|
||||||
|
Ok(Reservation {
|
||||||
|
id: self.next_id.fetch_add(1, Ordering::Relaxed),
|
||||||
|
principal: principal.clone(),
|
||||||
|
reserved: max_tokens,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn settle(&self, reservation: Reservation, actual_tokens: u64) {
|
||||||
|
let mut ledgers = self.ledgers.lock().expect("ledger mutex poisoned");
|
||||||
|
if let Some(ledger) = ledgers.get_mut(&reservation.principal.key_id) {
|
||||||
|
ledger.reserved = ledger.reserved.saturating_sub(reservation.reserved);
|
||||||
|
ledger.spent += actual_tokens;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn release(&self, reservation: Reservation) {
|
||||||
|
let mut ledgers = self.ledgers.lock().expect("ledger mutex poisoned");
|
||||||
|
if let Some(ledger) = ledgers.get_mut(&reservation.principal.key_id) {
|
||||||
|
ledger.reserved = ledger.reserved.saturating_sub(reservation.reserved);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn snapshot(&self, principal: &Principal) -> Option<BudgetSnapshot> {
|
||||||
|
let ledgers = self.ledgers.lock().expect("ledger mutex poisoned");
|
||||||
|
let (spent, reserved) = ledgers
|
||||||
|
.get(&principal.key_id)
|
||||||
|
.map(|l| (l.spent, l.reserved))
|
||||||
|
.unwrap_or((0, 0));
|
||||||
|
let hard_cap = self.budgets.get(&principal.key_id).and_then(|b| b.hard_cap);
|
||||||
|
Some(BudgetSnapshot {
|
||||||
|
hard_cap,
|
||||||
|
spent,
|
||||||
|
reserved,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn provider() -> LocalEntitlementProvider {
|
||||||
|
let config = EntitlementsConfig {
|
||||||
|
require_auth: true,
|
||||||
|
keys: vec![
|
||||||
|
ApiKeyConfig {
|
||||||
|
key: "sk-balance".into(),
|
||||||
|
account_id: "acct-a".into(),
|
||||||
|
key_id: Some("key-balance".into()),
|
||||||
|
hard_cap: Some(1_000),
|
||||||
|
window: CapWindow::Balance,
|
||||||
|
},
|
||||||
|
ApiKeyConfig {
|
||||||
|
key: "sk-rolling".into(),
|
||||||
|
account_id: "acct-b".into(),
|
||||||
|
key_id: Some("key-rolling".into()),
|
||||||
|
hard_cap: Some(500),
|
||||||
|
window: CapWindow::Rolling { seconds: 3_600 },
|
||||||
|
},
|
||||||
|
ApiKeyConfig {
|
||||||
|
key: "sk-infra".into(),
|
||||||
|
account_id: "operator".into(),
|
||||||
|
key_id: Some("key-infra".into()),
|
||||||
|
hard_cap: None,
|
||||||
|
window: CapWindow::Balance,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
LocalEntitlementProvider::from_config(&config)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn resolves_configured_key_to_principal() {
|
||||||
|
let p = provider();
|
||||||
|
let principal = p.resolve("sk-balance").await.expect("known key resolves");
|
||||||
|
assert_eq!(principal.account_id, "acct-a");
|
||||||
|
assert_eq!(principal.key_id, "key-balance");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn unknown_key_is_invalid() {
|
||||||
|
let p = provider();
|
||||||
|
assert!(matches!(
|
||||||
|
p.resolve("sk-nope").await,
|
||||||
|
Err(AuthError::InvalidKey)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reserve_settle_release_round_trip() {
|
||||||
|
let p = provider();
|
||||||
|
let principal = p.resolve("sk-balance").await.unwrap();
|
||||||
|
|
||||||
|
let r = p.reserve(&principal, 400).await.expect("within cap");
|
||||||
|
// Reserved, not yet spent.
|
||||||
|
let snap = p.snapshot(&principal).await.unwrap();
|
||||||
|
assert_eq!(snap.hard_cap, Some(1_000));
|
||||||
|
assert_eq!(snap.reserved, 400);
|
||||||
|
assert_eq!(snap.spent, 0);
|
||||||
|
|
||||||
|
// Used fewer tokens than reserved → remainder released, spend exact.
|
||||||
|
p.settle(r, 250).await;
|
||||||
|
let snap = p.snapshot(&principal).await.unwrap();
|
||||||
|
assert_eq!(snap.reserved, 0);
|
||||||
|
assert_eq!(snap.spent, 250);
|
||||||
|
|
||||||
|
// A reservation that is released contributes no spend.
|
||||||
|
let r2 = p.reserve(&principal, 100).await.unwrap();
|
||||||
|
p.release(r2).await;
|
||||||
|
let snap = p.snapshot(&principal).await.unwrap();
|
||||||
|
assert_eq!(snap.reserved, 0);
|
||||||
|
assert_eq!(snap.spent, 250);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn balance_over_cap_is_insufficient_quota() {
|
||||||
|
let p = provider();
|
||||||
|
let principal = p.resolve("sk-balance").await.unwrap();
|
||||||
|
// Reserve most of the cap, then ask for more than remains.
|
||||||
|
let _r = p.reserve(&principal, 900).await.unwrap();
|
||||||
|
let err = p.reserve(&principal, 200).await.expect_err("over cap");
|
||||||
|
match err {
|
||||||
|
BudgetError::InsufficientQuota {
|
||||||
|
requested,
|
||||||
|
available,
|
||||||
|
} => {
|
||||||
|
assert_eq!(requested, 200);
|
||||||
|
assert_eq!(available, 100);
|
||||||
|
}
|
||||||
|
other => panic!("expected InsufficientQuota, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn rolling_over_cap_is_rate_limited_with_retry_after() {
|
||||||
|
let p = provider();
|
||||||
|
let principal = p.resolve("sk-rolling").await.unwrap();
|
||||||
|
let _r = p.reserve(&principal, 500).await.unwrap();
|
||||||
|
let err = p.reserve(&principal, 1).await.expect_err("over cap");
|
||||||
|
match err {
|
||||||
|
BudgetError::RateLimited {
|
||||||
|
retry_after_secs, ..
|
||||||
|
} => {
|
||||||
|
assert!(retry_after_secs >= 1, "must advertise a retry hint");
|
||||||
|
assert!(retry_after_secs <= 3_600);
|
||||||
|
}
|
||||||
|
other => panic!("expected RateLimited, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn uncapped_infra_key_never_refuses() {
|
||||||
|
let p = provider();
|
||||||
|
let principal = p.resolve("sk-infra").await.unwrap();
|
||||||
|
let r = p.reserve(&principal, 10_000_000).await.expect("uncapped");
|
||||||
|
p.settle(r, 10_000_000).await;
|
||||||
|
let snap = p.snapshot(&principal).await.unwrap();
|
||||||
|
assert_eq!(snap.hard_cap, None);
|
||||||
|
assert_eq!(snap.spent, 10_000_000);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -190,7 +190,7 @@ async fn completions(
|
|||||||
/// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back.
|
/// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back.
|
||||||
async fn anthropic_messages(
|
async fn anthropic_messages(
|
||||||
State(fleet): State<Arc<CortexState>>,
|
State(fleet): State<Arc<CortexState>>,
|
||||||
_headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
// Parse as Anthropic request.
|
// Parse as Anthropic request.
|
||||||
@@ -306,6 +306,29 @@ async fn anthropic_messages(
|
|||||||
}
|
}
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
|
// Per-request metering + budget enforcement (#51/#52), same lifecycle as
|
||||||
|
// the OpenAI paths. Estimate from the translated OpenAI body (what neuron
|
||||||
|
// sees). Refuse over-cap before dispatch via the #63 envelope; otherwise
|
||||||
|
// build the sink consumed by whichever branch runs below.
|
||||||
|
let usage_sink = match crate::metering::principal_from_headers(&headers) {
|
||||||
|
Some(principal) => {
|
||||||
|
let advertised =
|
||||||
|
advertised_output_limit(&fleet, &route.node_name, &route.resolved_model_id).await;
|
||||||
|
let max_tokens = crate::metering::reservation_estimate(&openai_body, advertised);
|
||||||
|
match crate::metering::reserve_or_reject(
|
||||||
|
Arc::clone(&fleet.entitlements),
|
||||||
|
&principal,
|
||||||
|
max_tokens,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(guard) => Some(crate::metering::usage_sink(principal, guard)),
|
||||||
|
Err(env) => return crate::error::envelope_response(env),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
if is_streaming {
|
if is_streaming {
|
||||||
// Anthropic SSE translation (#24): upstream speaks OpenAI SSE;
|
// Anthropic SSE translation (#24): upstream speaks OpenAI SSE;
|
||||||
// re-frame it event-by-event into Anthropic's message_start /
|
// re-frame it event-by-event into Anthropic's message_start /
|
||||||
@@ -316,6 +339,8 @@ async fn anthropic_messages(
|
|||||||
openai_body,
|
openai_body,
|
||||||
&model_id,
|
&model_id,
|
||||||
&route.node_name,
|
&route.node_name,
|
||||||
|
&headers,
|
||||||
|
usage_sink,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
metrics::histogram!("cortex_request_duration_seconds", &labels)
|
metrics::histogram!("cortex_request_duration_seconds", &labels)
|
||||||
@@ -335,13 +360,16 @@ async fn anthropic_messages(
|
|||||||
cold_start = route.cold_start,
|
cold_start = route.cold_start,
|
||||||
"proxying request"
|
"proxying request"
|
||||||
);
|
);
|
||||||
let upstream_resp = fleet
|
let upstream_resp = crate::auth::forward_principal_headers(
|
||||||
.http_client
|
fleet
|
||||||
.post(&target_url)
|
.http_client
|
||||||
.body(openai_body)
|
.post(&target_url)
|
||||||
.header("content-type", "application/json")
|
.body(openai_body)
|
||||||
.send()
|
.header("content-type", "application/json"),
|
||||||
.await;
|
&headers,
|
||||||
|
)
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
|
||||||
let upstream_resp = match upstream_resp {
|
let upstream_resp = match upstream_resp {
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
@@ -437,6 +465,15 @@ async fn anthropic_messages(
|
|||||||
|
|
||||||
metrics::histogram!("cortex_request_duration_seconds", &labels)
|
metrics::histogram!("cortex_request_duration_seconds", &labels)
|
||||||
.record(start.elapsed().as_secs_f64());
|
.record(start.elapsed().as_secs_f64());
|
||||||
|
// Settle metering with the upstream usage (#51). Scanned from the
|
||||||
|
// raw body — same engine-truth source as the streaming path — so we
|
||||||
|
// don't depend on the typed usage struct's optionality.
|
||||||
|
if let Some(sink) = usage_sink {
|
||||||
|
let tail = String::from_utf8_lossy(&body_bytes);
|
||||||
|
let prompt = proxy::last_count_for(&tail, "prompt_tokens").unwrap_or(0);
|
||||||
|
let completion = proxy::last_count_for(&tail, "completion_tokens").unwrap_or(0);
|
||||||
|
sink(prompt, completion);
|
||||||
|
}
|
||||||
// Did the model actually produce a structured tool call, or just
|
// Did the model actually produce a structured tool call, or just
|
||||||
// text? This is the single most useful signal for "is tool
|
// text? This is the single most useful signal for "is tool
|
||||||
// calling working end-to-end" — a `false` here alongside a
|
// calling working end-to-end" — a `false` here alongside a
|
||||||
@@ -734,9 +771,42 @@ async fn proxy_with_metrics(
|
|||||||
metrics::counter!("cortex_cold_starts_total", &labels).increment(1);
|
metrics::counter!("cortex_cold_starts_total", &labels).increment(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Per-request metering + budget enforcement (#51/#52): reconstruct the
|
||||||
|
// principal from the middleware-stamped headers, reserve the request's
|
||||||
|
// upper-bound cost (prompt estimate + max output), and build the
|
||||||
|
// completion sink that settles actual spend when the response finishes.
|
||||||
|
// A reservation over the hard cap is refused *before* dispatch with the
|
||||||
|
// #63 envelope. Anonymous requests skip all of this. Must happen before
|
||||||
|
// `headers`/`body` are moved into the proxy.
|
||||||
|
let usage_sink = match crate::metering::principal_from_headers(&headers) {
|
||||||
|
Some(principal) => {
|
||||||
|
let advertised = advertised_output_limit(fleet, &route.node_name, model_id).await;
|
||||||
|
let max_tokens = crate::metering::reservation_estimate(&body, advertised);
|
||||||
|
match crate::metering::reserve_or_reject(
|
||||||
|
Arc::clone(&fleet.entitlements),
|
||||||
|
&principal,
|
||||||
|
max_tokens,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(guard) => Some(crate::metering::usage_sink(principal, guard)),
|
||||||
|
Err(env) => return crate::error::envelope_response(env),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let result =
|
let result = proxy::forward_request(
|
||||||
proxy::forward_request(&fleet.http_client, route, path, headers, body, model_id).await;
|
&fleet.http_client,
|
||||||
|
route,
|
||||||
|
path,
|
||||||
|
headers,
|
||||||
|
body,
|
||||||
|
model_id,
|
||||||
|
usage_sink,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
let duration = start.elapsed();
|
let duration = start.elapsed();
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
@@ -755,6 +825,25 @@ async fn proxy_with_metrics(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The model's advertised `limit.output` (#62) on a given node, used as the
|
||||||
|
/// default output budget for budget reservations (#52) when the request
|
||||||
|
/// omits `max_(completion_)tokens`. `None` when the node/model/limit is
|
||||||
|
/// unknown — callers fall back to [`crate::metering::FALLBACK_MAX_OUTPUT`].
|
||||||
|
async fn advertised_output_limit(
|
||||||
|
fleet: &CortexState,
|
||||||
|
node_name: &str,
|
||||||
|
model_id: &str,
|
||||||
|
) -> Option<u64> {
|
||||||
|
let nodes = fleet.nodes.read().await;
|
||||||
|
nodes
|
||||||
|
.get(node_name)?
|
||||||
|
.models
|
||||||
|
.get(model_id)?
|
||||||
|
.limit
|
||||||
|
.as_ref()
|
||||||
|
.map(|l| l.output as u64)
|
||||||
|
}
|
||||||
|
|
||||||
/// Update `last_accessed` timestamp for a model on a node (drives LRU eviction).
|
/// Update `last_accessed` timestamp for a model on a node (drives LRU eviction).
|
||||||
async fn touch_model(fleet: &CortexState, node_name: &str, model_id: &str) {
|
async fn touch_model(fleet: &CortexState, node_name: &str, model_id: &str) {
|
||||||
let mut nodes = fleet.nodes.write().await;
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
pub mod anthropic_sse;
|
pub mod anthropic_sse;
|
||||||
|
pub mod auth;
|
||||||
|
pub mod entitlements_local;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod evictor;
|
pub mod evictor;
|
||||||
pub mod handlers;
|
pub mod handlers;
|
||||||
|
pub mod metering;
|
||||||
pub mod metrics;
|
pub mod metrics;
|
||||||
pub mod poller;
|
pub mod poller;
|
||||||
pub mod proxy;
|
pub mod proxy;
|
||||||
@@ -10,15 +13,26 @@ pub mod state;
|
|||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
|
use axum::middleware::from_fn_with_state;
|
||||||
use cortex_core::config::GatewayConfig;
|
use cortex_core::config::GatewayConfig;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tower_http::cors::CorsLayer;
|
use tower_http::cors::CorsLayer;
|
||||||
use tower_http::trace::TraceLayer;
|
use tower_http::trace::TraceLayer;
|
||||||
|
|
||||||
/// Build the Axum application router with all routes wired up.
|
/// Build the Axum application router with all routes wired up.
|
||||||
|
///
|
||||||
|
/// Layer order (outermost first): trace → CORS → auth → handlers. CORS is
|
||||||
|
/// outer to auth so preflight `OPTIONS` short-circuits before resolution;
|
||||||
|
/// auth (`require_principal`) resolves the bearer key, attaches the
|
||||||
|
/// principal, and stamps the internal principal headers before any handler
|
||||||
|
/// runs.
|
||||||
pub fn build_app(fleet: Arc<state::CortexState>) -> Router {
|
pub fn build_app(fleet: Arc<state::CortexState>) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.merge(handlers::api_routes())
|
.merge(handlers::api_routes())
|
||||||
|
.layer(from_fn_with_state(
|
||||||
|
Arc::clone(&fleet),
|
||||||
|
auth::require_principal,
|
||||||
|
))
|
||||||
.layer(CorsLayer::permissive())
|
.layer(CorsLayer::permissive())
|
||||||
.layer(TraceLayer::new_for_http())
|
.layer(TraceLayer::new_for_http())
|
||||||
.with_state(fleet)
|
.with_state(fleet)
|
||||||
|
|||||||
219
crates/cortex-gateway/src/metering.rs
Normal file
219
crates/cortex-gateway/src/metering.rs
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
//! Per-request token metering (#51).
|
||||||
|
//!
|
||||||
|
//! Captures the real `(prompt, completion)` usage of every request and feeds
|
||||||
|
//! it to two places: the [`EntitlementProvider`] spend ledger (via
|
||||||
|
//! reserve→settle) and per-principal Prometheus counters. The principal is
|
||||||
|
//! reconstructed from the internal headers the auth middleware stamped (#49),
|
||||||
|
//! so this works uniformly across every proxy path without threading the
|
||||||
|
//! typed principal through each handler.
|
||||||
|
//!
|
||||||
|
//! The reserve→settle lifecycle is established here but, in this phase,
|
||||||
|
//! reserves **zero** tokens — metering only, no enforcement. Budget
|
||||||
|
//! enforcement (#52) flips the reserved amount to the real
|
||||||
|
//! `prompt + max_output` and handles the [`BudgetError`] rejection; the
|
||||||
|
//! settle/release plumbing is identical, so that change is localized.
|
||||||
|
//!
|
||||||
|
//! [`ReservationGuard`] makes leaks impossible: settling records actual
|
||||||
|
//! spend and releases the unused remainder; dropping a guard that was never
|
||||||
|
//! settled releases the whole reservation. So an early return, error path,
|
||||||
|
//! or dropped stream can't strand a reservation.
|
||||||
|
|
||||||
|
use axum::http::HeaderMap;
|
||||||
|
use cortex_core::entitlements::{
|
||||||
|
BudgetError, EntitlementProvider, HEADER_ACCOUNT_ID, HEADER_KEY_ID, Principal,
|
||||||
|
};
|
||||||
|
use cortex_core::error_envelope::OpenAiError;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Fallback output-token budget when neither the request nor the model's
|
||||||
|
/// advertised limit gives one. Bounds the reservation so a capped key is
|
||||||
|
/// still gated even on under-specified requests (#52).
|
||||||
|
pub const FALLBACK_MAX_OUTPUT: u64 = 4096;
|
||||||
|
|
||||||
|
/// Invoked exactly once at request completion with best-effort
|
||||||
|
/// `(prompt_tokens, completion_tokens)`. When no usage could be observed
|
||||||
|
/// (e.g. a pre-dispatch failure or a dropped stream) it is dropped unused —
|
||||||
|
/// which releases the held reservation via [`ReservationGuard`]'s `Drop`.
|
||||||
|
pub type UsageSink = Box<dyn FnOnce(u64, u64) + Send>;
|
||||||
|
|
||||||
|
/// Reconstruct the principal from the cortex-stamped internal headers. The
|
||||||
|
/// auth middleware strips any client copy and stamps the authoritative value,
|
||||||
|
/// so these headers are trustworthy within cortex. `None` for anonymous
|
||||||
|
/// (unauthenticated) requests.
|
||||||
|
pub fn principal_from_headers(headers: &HeaderMap) -> Option<Principal> {
|
||||||
|
let account_id = headers.get(HEADER_ACCOUNT_ID)?.to_str().ok()?.to_string();
|
||||||
|
let key_id = headers.get(HEADER_KEY_ID)?.to_str().ok()?.to_string();
|
||||||
|
Some(Principal { account_id, key_id })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Emit per-principal spend counters (#51). Labelled by account/key only —
|
||||||
|
/// both are operator-bounded, so cardinality is controlled.
|
||||||
|
pub fn record_spend(principal: &Principal, prompt: u64, completion: u64) {
|
||||||
|
let labels = [
|
||||||
|
("account", principal.account_id.clone()),
|
||||||
|
("key", principal.key_id.clone()),
|
||||||
|
];
|
||||||
|
metrics::counter!("cortex_spend_tokens_total", &labels).increment(prompt + completion);
|
||||||
|
metrics::counter!("cortex_spend_prompt_tokens_total", &labels).increment(prompt);
|
||||||
|
metrics::counter!("cortex_spend_completion_tokens_total", &labels).increment(completion);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Holds a budget reservation for the life of a request. [`settle`] records
|
||||||
|
/// actual spend and releases the remainder; an un-settled guard releases the
|
||||||
|
/// whole reservation when dropped. Anonymous requests carry an empty guard,
|
||||||
|
/// where every operation is a no-op.
|
||||||
|
///
|
||||||
|
/// [`settle`]: ReservationGuard::settle
|
||||||
|
pub struct ReservationGuard {
|
||||||
|
provider: Arc<dyn EntitlementProvider>,
|
||||||
|
reservation: Option<cortex_core::entitlements::Reservation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReservationGuard {
|
||||||
|
/// An empty guard for an anonymous request — no reservation to resolve.
|
||||||
|
pub fn anonymous(provider: Arc<dyn EntitlementProvider>) -> Self {
|
||||||
|
Self {
|
||||||
|
provider,
|
||||||
|
reservation: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wrap an already-acquired reservation.
|
||||||
|
fn held(
|
||||||
|
provider: Arc<dyn EntitlementProvider>,
|
||||||
|
reservation: cortex_core::entitlements::Reservation,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
provider,
|
||||||
|
reservation: Some(reservation),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Settle with the tokens actually consumed, disarming the drop-release.
|
||||||
|
/// Spawns the (fast, in-process for the local provider) settle so the
|
||||||
|
/// caller — which may be a sync stream-completion callback — needn't
|
||||||
|
/// await.
|
||||||
|
pub fn settle(mut self, actual_tokens: u64) {
|
||||||
|
if let Some(reservation) = self.reservation.take() {
|
||||||
|
let provider = Arc::clone(&self.provider);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
provider.settle(reservation, actual_tokens).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for ReservationGuard {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Some(reservation) = self.reservation.take() {
|
||||||
|
let provider = Arc::clone(&self.provider);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
provider.release(reservation).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build the completion sink for an authenticated request: record spend and
|
||||||
|
/// settle the reservation with the observed total. Dropping it unused (no
|
||||||
|
/// usage observed) releases the reservation via the guard.
|
||||||
|
pub fn usage_sink(principal: Principal, guard: ReservationGuard) -> UsageSink {
|
||||||
|
Box::new(move |prompt, completion| {
|
||||||
|
record_spend(&principal, prompt, completion);
|
||||||
|
guard.settle(prompt + completion);
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reserve the request's upper-bound token cost for the principal, refusing
|
||||||
|
/// *before* dispatch if it would exceed the hard cap (#52). On success
|
||||||
|
/// returns a guard the caller settles with actual usage; on refusal returns
|
||||||
|
/// the #63 envelope (`rate_limit_exceeded` + `Retry-After` for a resetting
|
||||||
|
/// window, `insufficient_quota` for a hard balance — never `402`).
|
||||||
|
pub async fn reserve_or_reject(
|
||||||
|
provider: Arc<dyn EntitlementProvider>,
|
||||||
|
principal: &Principal,
|
||||||
|
max_tokens: u64,
|
||||||
|
) -> Result<ReservationGuard, OpenAiError> {
|
||||||
|
match provider.reserve(principal, max_tokens).await {
|
||||||
|
Ok(reservation) => Ok(ReservationGuard::held(provider, reservation)),
|
||||||
|
Err(err) => Err(budget_error_to_envelope(err)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map a [`BudgetError`] to the #63 envelope. The provider chose the window
|
||||||
|
/// semantics; this only translates them to HTTP.
|
||||||
|
fn budget_error_to_envelope(err: BudgetError) -> OpenAiError {
|
||||||
|
match err {
|
||||||
|
BudgetError::RateLimited {
|
||||||
|
retry_after_secs, ..
|
||||||
|
} => OpenAiError::rate_limit_exceeded(err.to_string(), retry_after_secs),
|
||||||
|
BudgetError::InsufficientQuota { .. } => OpenAiError::insufficient_quota(err.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Upper-bound tokens to reserve for a request (#52): an over-estimate of
|
||||||
|
/// the prompt plus the maximum output. `advertised_output` is the model's
|
||||||
|
/// `limit.output` (#62), used when the request omits `max_(completion_)tokens`.
|
||||||
|
/// Over-reserving is safe — settle corrects spend to the actual usage.
|
||||||
|
pub fn reservation_estimate(body: &[u8], advertised_output: Option<u64>) -> u64 {
|
||||||
|
let max_output = requested_max_output(body)
|
||||||
|
.or(advertised_output)
|
||||||
|
.unwrap_or(FALLBACK_MAX_OUTPUT);
|
||||||
|
estimate_prompt_tokens(body).saturating_add(max_output)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The client's requested output cap, from `max_completion_tokens` (or the
|
||||||
|
/// legacy `max_tokens`). `None` when unspecified.
|
||||||
|
fn requested_max_output(body: &[u8]) -> Option<u64> {
|
||||||
|
let v: serde_json::Value = serde_json::from_slice(body).ok()?;
|
||||||
|
v.get("max_completion_tokens")
|
||||||
|
.or_else(|| v.get("max_tokens"))
|
||||||
|
.and_then(serde_json::Value::as_u64)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rough prompt-token estimate at ~4 chars/token over the whole body. cortex
|
||||||
|
/// has no tokenizer; JSON overhead makes this a conservative over-estimate,
|
||||||
|
/// and neuron remains the exact context wall (#56/#60). Settle reconciles to
|
||||||
|
/// the real usage afterward.
|
||||||
|
fn estimate_prompt_tokens(body: &[u8]) -> u64 {
|
||||||
|
(body.len() as u64 / 4).max(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn requested_max_output_prefers_max_completion_tokens() {
|
||||||
|
let body = br#"{"model":"m","max_completion_tokens":256,"max_tokens":99}"#;
|
||||||
|
assert_eq!(requested_max_output(body), Some(256));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn requested_max_output_falls_back_to_legacy_max_tokens() {
|
||||||
|
let body = br#"{"model":"m","max_tokens":128}"#;
|
||||||
|
assert_eq!(requested_max_output(body), Some(128));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn estimate_uses_requested_output_when_present() {
|
||||||
|
// Requested output dominates; prompt estimate is small for a tiny body.
|
||||||
|
let body = br#"{"model":"m","max_tokens":1000}"#;
|
||||||
|
let est = reservation_estimate(body, Some(8192));
|
||||||
|
assert!(est >= 1000 && est < 1100, "est was {est}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn estimate_uses_advertised_output_when_request_omits_it() {
|
||||||
|
let body = br#"{"model":"m","messages":[]}"#;
|
||||||
|
let est = reservation_estimate(body, Some(8192));
|
||||||
|
assert!(est >= 8192, "est was {est}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn estimate_falls_back_when_nothing_advertised() {
|
||||||
|
let body = br#"{"model":"m"}"#;
|
||||||
|
let est = reservation_estimate(body, None);
|
||||||
|
assert!(est >= FALLBACK_MAX_OUTPUT, "est was {est}");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -63,4 +63,16 @@ fn describe_metrics() {
|
|||||||
"cortex_cold_starts_total",
|
"cortex_cold_starts_total",
|
||||||
"Total number of cold-start model loads"
|
"Total number of cold-start model loads"
|
||||||
);
|
);
|
||||||
|
metrics::describe_counter!(
|
||||||
|
"cortex_spend_tokens_total",
|
||||||
|
"Total metered tokens (prompt + completion) per principal, labelled by account/key (#51)"
|
||||||
|
);
|
||||||
|
metrics::describe_counter!(
|
||||||
|
"cortex_spend_prompt_tokens_total",
|
||||||
|
"Metered prompt tokens per principal, labelled by account/key (#51)"
|
||||||
|
);
|
||||||
|
metrics::describe_counter!(
|
||||||
|
"cortex_spend_completion_tokens_total",
|
||||||
|
"Metered completion tokens per principal, labelled by account/key (#51)"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ pub async fn forward_request(
|
|||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
body: bytes::Bytes,
|
body: bytes::Bytes,
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
|
usage_sink: Option<crate::metering::UsageSink>,
|
||||||
) -> Result<Response, ProxyError> {
|
) -> Result<Response, ProxyError> {
|
||||||
let request_start = Instant::now();
|
let request_start = Instant::now();
|
||||||
let url = format!("{}{}", route.endpoint, path);
|
let url = format!("{}{}", route.endpoint, path);
|
||||||
@@ -82,7 +83,7 @@ pub async fn forward_request(
|
|||||||
let resp_headers = upstream_resp.headers().clone();
|
let resp_headers = upstream_resp.headers().clone();
|
||||||
let stream = TokenMetricsStream::new(
|
let stream = TokenMetricsStream::new(
|
||||||
Box::pin(upstream_resp.bytes_stream()),
|
Box::pin(upstream_resp.bytes_stream()),
|
||||||
TokenMetrics::new(model_id, &route.node_name, request_start),
|
TokenMetrics::new(model_id, &route.node_name, request_start, usage_sink),
|
||||||
);
|
);
|
||||||
|
|
||||||
let body = Body::from_stream(stream);
|
let body = Body::from_stream(stream);
|
||||||
@@ -186,10 +187,19 @@ struct TokenMetrics {
|
|||||||
last_chunk: Option<Instant>,
|
last_chunk: Option<Instant>,
|
||||||
tail: String,
|
tail: String,
|
||||||
finished: bool,
|
finished: bool,
|
||||||
|
/// Per-principal metering hook (#51). Invoked exactly once in `finish`
|
||||||
|
/// with the observed `(prompt, completion)` so the reservation can be
|
||||||
|
/// settled and spend recorded. `None` for anonymous requests.
|
||||||
|
usage_sink: Option<crate::metering::UsageSink>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TokenMetrics {
|
impl TokenMetrics {
|
||||||
fn new(model_id: &str, node_name: &str, request_start: Instant) -> Self {
|
fn new(
|
||||||
|
model_id: &str,
|
||||||
|
node_name: &str,
|
||||||
|
request_start: Instant,
|
||||||
|
usage_sink: Option<crate::metering::UsageSink>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
labels: [
|
labels: [
|
||||||
("model", model_id.to_string()),
|
("model", model_id.to_string()),
|
||||||
@@ -200,6 +210,7 @@ impl TokenMetrics {
|
|||||||
last_chunk: None,
|
last_chunk: None,
|
||||||
tail: String::new(),
|
tail: String::new(),
|
||||||
finished: false,
|
finished: false,
|
||||||
|
usage_sink,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,36 +238,45 @@ impl TokenMetrics {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
self.finished = true;
|
self.finished = true;
|
||||||
let Some(first) = self.first_chunk else {
|
|
||||||
return; // no body ever arrived — nothing to record
|
|
||||||
};
|
|
||||||
let ttft = first.duration_since(self.request_start).as_secs_f64();
|
|
||||||
metrics::histogram!("cortex_time_to_first_token_seconds", &self.labels).record(ttft);
|
|
||||||
|
|
||||||
if let Some(prompt) = last_count_for(&self.tail, "prompt_tokens") {
|
let prompt = last_count_for(&self.tail, "prompt_tokens");
|
||||||
metrics::counter!("cortex_prompt_tokens_total", &self.labels).increment(prompt);
|
let completion = last_count_for(&self.tail, "completion_tokens");
|
||||||
}
|
|
||||||
let Some(completion) = last_count_for(&self.tail, "completion_tokens") else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
if completion == 0 {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
metrics::counter!("cortex_completion_tokens_total", &self.labels).increment(completion);
|
|
||||||
|
|
||||||
let last = self.last_chunk.unwrap_or(first);
|
// Per-model metrics — only when body chunks actually arrived.
|
||||||
let decode_window = last.duration_since(first).as_secs_f64();
|
if let Some(first) = self.first_chunk {
|
||||||
// Streaming: rate over the decode window (first→last chunk).
|
let ttft = first.duration_since(self.request_start).as_secs_f64();
|
||||||
// Non-streaming bodies arrive as ~one chunk (window ≈ 0), where
|
metrics::histogram!("cortex_time_to_first_token_seconds", &self.labels).record(ttft);
|
||||||
// the only honest denominator is the full request duration.
|
|
||||||
let secs = if decode_window >= 0.1 {
|
if let Some(prompt) = prompt {
|
||||||
decode_window
|
metrics::counter!("cortex_prompt_tokens_total", &self.labels).increment(prompt);
|
||||||
} else {
|
}
|
||||||
last.duration_since(self.request_start).as_secs_f64()
|
if let Some(completion) = completion.filter(|c| *c > 0) {
|
||||||
};
|
metrics::counter!("cortex_completion_tokens_total", &self.labels)
|
||||||
if secs > 0.0 {
|
.increment(completion);
|
||||||
metrics::histogram!("cortex_tokens_per_second", &self.labels)
|
|
||||||
.record(completion as f64 / secs);
|
let last = self.last_chunk.unwrap_or(first);
|
||||||
|
let decode_window = last.duration_since(first).as_secs_f64();
|
||||||
|
// Streaming: rate over the decode window (first→last chunk).
|
||||||
|
// Non-streaming bodies arrive as ~one chunk (window ≈ 0),
|
||||||
|
// where the only honest denominator is the full request
|
||||||
|
// duration.
|
||||||
|
let secs = if decode_window >= 0.1 {
|
||||||
|
decode_window
|
||||||
|
} else {
|
||||||
|
last.duration_since(self.request_start).as_secs_f64()
|
||||||
|
};
|
||||||
|
if secs > 0.0 {
|
||||||
|
metrics::histogram!("cortex_tokens_per_second", &self.labels)
|
||||||
|
.record(completion as f64 / secs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per-principal metering + reservation settle (#51). Always runs so
|
||||||
|
// the reservation is resolved even when no usage/body was observed
|
||||||
|
// (sink with (0, 0) → settle 0 → release).
|
||||||
|
if let Some(sink) = self.usage_sink.take() {
|
||||||
|
sink(prompt.unwrap_or(0), completion.unwrap_or(0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
|
use crate::entitlements_local::LocalEntitlementProvider;
|
||||||
use cortex_core::catalogue::ModelCatalogue;
|
use cortex_core::catalogue::ModelCatalogue;
|
||||||
use cortex_core::config::{EvictionSettings, GatewayConfig, NeuronEndpoint};
|
use cortex_core::config::{EvictionSettings, GatewayConfig, NeuronEndpoint};
|
||||||
|
use cortex_core::entitlements::EntitlementProvider;
|
||||||
use cortex_core::node::NodeState;
|
use cortex_core::node::NodeState;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
/// Shared fleet state, protected by a RwLock for concurrent reader access.
|
/// Shared fleet state, protected by a RwLock for concurrent reader access.
|
||||||
@@ -11,6 +14,12 @@ pub struct CortexState {
|
|||||||
pub eviction: EvictionSettings,
|
pub eviction: EvictionSettings,
|
||||||
pub catalogue: ModelCatalogue,
|
pub catalogue: ModelCatalogue,
|
||||||
pub http_client: reqwest::Client,
|
pub http_client: reqwest::Client,
|
||||||
|
/// Resolves bearer keys to principals and enforces token budgets (#47).
|
||||||
|
/// A local/static provider today (#50); the upstream client later (#57).
|
||||||
|
pub entitlements: Arc<dyn EntitlementProvider>,
|
||||||
|
/// Whether to reject unauthenticated requests (#49). Read by the auth
|
||||||
|
/// middleware once it lands.
|
||||||
|
pub require_auth: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CortexState {
|
impl CortexState {
|
||||||
@@ -34,6 +43,9 @@ impl CortexState {
|
|||||||
|
|
||||||
let catalogue = ModelCatalogue::load(&config.models_config);
|
let catalogue = ModelCatalogue::load(&config.models_config);
|
||||||
|
|
||||||
|
let entitlements: Arc<dyn EntitlementProvider> =
|
||||||
|
Arc::new(LocalEntitlementProvider::from_config(&config.entitlements));
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
nodes: RwLock::new(nodes),
|
nodes: RwLock::new(nodes),
|
||||||
neuron_configs: config.neurons.clone(),
|
neuron_configs: config.neurons.clone(),
|
||||||
@@ -43,6 +55,8 @@ impl CortexState {
|
|||||||
.timeout(std::time::Duration::from_secs(300))
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
.build()
|
.build()
|
||||||
.expect("failed to build HTTP client"),
|
.expect("failed to build HTTP client"),
|
||||||
|
entitlements,
|
||||||
|
require_auth: config.entitlements.require_auth,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ async fn test_alias_resolves_in_chat_completions() {
|
|||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
}],
|
}],
|
||||||
models_config: models_path.to_string_lossy().to_string(),
|
models_config: models_path.to_string_lossy().to_string(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
@@ -141,6 +142,7 @@ async fn test_aliases_surface_in_v1_models() {
|
|||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
}],
|
}],
|
||||||
models_config: models_path.to_string_lossy().to_string(),
|
models_config: models_path.to_string_lossy().to_string(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
@@ -229,6 +231,7 @@ async fn test_alias_falls_through_for_unmapped_model() {
|
|||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
}],
|
}],
|
||||||
models_config: models_path.to_string_lossy().to_string(),
|
models_config: models_path.to_string_lossy().to_string(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|||||||
250
crates/cortex-gateway/tests/auth.rs
Normal file
250
crates/cortex-gateway/tests/auth.rs
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
//! Integration tests for API-key auth + principal resolution (#49).
|
||||||
|
//!
|
||||||
|
//! Verifies the #63 rejection contract (401 invalid_api_key via the #60
|
||||||
|
//! envelope) and that an authenticated request reaches neuron carrying the
|
||||||
|
//! internal principal headers — while a client-supplied principal header is
|
||||||
|
//! stripped (anti-spoofing).
|
||||||
|
|
||||||
|
use axum::Json;
|
||||||
|
use axum::extract::Path;
|
||||||
|
use axum::http::HeaderMap;
|
||||||
|
use axum::routing::{get, post};
|
||||||
|
use cortex_core::config::{
|
||||||
|
ApiKeyConfig, EntitlementsConfig, EvictionSettings, EvictionStrategy, GatewayConfig,
|
||||||
|
GatewaySettings, NeuronEndpoint,
|
||||||
|
};
|
||||||
|
use cortex_core::entitlements::{CapWindow, HEADER_ACCOUNT_ID, HEADER_KEY_ID};
|
||||||
|
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||||
|
use cortex_gateway::state::CortexState;
|
||||||
|
use serde_json::{Value, json};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
/// What the mock neuron observed on the inbound `/v1/chat/completions`
|
||||||
|
/// request: the principal headers cortex stamped (or didn't).
|
||||||
|
#[derive(Default)]
|
||||||
|
struct Seen {
|
||||||
|
account_id: Option<String>,
|
||||||
|
key_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spawn a mock neuron that records the principal headers it receives and
|
||||||
|
/// returns a trivial chat completion. Returns (base_url, observed).
|
||||||
|
async fn spawn_capturing_neuron() -> (String, Arc<Mutex<Seen>>) {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let base_url = format!("http://{addr}");
|
||||||
|
let inference_url = base_url.clone();
|
||||||
|
let seen: Arc<Mutex<Seen>> = Arc::new(Mutex::new(Seen::default()));
|
||||||
|
let sink = Arc::clone(&seen);
|
||||||
|
|
||||||
|
let app = axum::Router::new()
|
||||||
|
.route(
|
||||||
|
"/models/{model_id}/endpoint",
|
||||||
|
get(move |Path(_): Path<String>| {
|
||||||
|
let url = inference_url.clone();
|
||||||
|
async move { Json(json!({ "url": url })) }
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
post(move |headers: HeaderMap, Json(body): Json<Value>| {
|
||||||
|
let sink = Arc::clone(&sink);
|
||||||
|
async move {
|
||||||
|
{
|
||||||
|
let mut s = sink.lock().unwrap();
|
||||||
|
s.account_id = headers
|
||||||
|
.get(HEADER_ACCOUNT_ID)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(str::to_string);
|
||||||
|
s.key_id = headers
|
||||||
|
.get(HEADER_KEY_ID)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(str::to_string);
|
||||||
|
}
|
||||||
|
let model = body.get("model").and_then(Value::as_str).unwrap_or("m");
|
||||||
|
Json(json!({
|
||||||
|
"id": "chatcmpl-auth-001",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1700000000_u64,
|
||||||
|
"model": model,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": "ok"},
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}],
|
||||||
|
"usage": {"prompt_tokens": 3, "completion_tokens": 1, "total_tokens": 4}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.with_state(());
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
(base_url, seen)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spawn a gateway with the given entitlements config, a single neuron, and
|
||||||
|
/// `test-model` seeded as loaded (build_app spawns no poller).
|
||||||
|
async fn spawn_gateway(neuron_url: &str, entitlements: EntitlementsConfig) -> String {
|
||||||
|
let config = GatewayConfig {
|
||||||
|
gateway: GatewaySettings {
|
||||||
|
listen: "127.0.0.1:0".into(),
|
||||||
|
metrics_listen: "127.0.0.1:0".into(),
|
||||||
|
},
|
||||||
|
eviction: EvictionSettings {
|
||||||
|
strategy: EvictionStrategy::Lru,
|
||||||
|
defrag_after_cycles: 0,
|
||||||
|
},
|
||||||
|
neurons: vec![NeuronEndpoint {
|
||||||
|
name: "mock-node".into(),
|
||||||
|
endpoint: neuron_url.to_string(),
|
||||||
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements,
|
||||||
|
};
|
||||||
|
|
||||||
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
{
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
let node = nodes.get_mut("mock-node").unwrap();
|
||||||
|
node.healthy = true;
|
||||||
|
node.models.insert(
|
||||||
|
"test-model".into(),
|
||||||
|
ModelEntry {
|
||||||
|
id: "test-model".into(),
|
||||||
|
status: ModelStatus::Loaded,
|
||||||
|
last_accessed: None,
|
||||||
|
vram_estimate_mb: Some(8000),
|
||||||
|
capabilities: Vec::new(),
|
||||||
|
tool_call: false,
|
||||||
|
reasoning: false,
|
||||||
|
limit: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
format!("http://{addr}")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn one_key_config(require_auth: bool) -> EntitlementsConfig {
|
||||||
|
EntitlementsConfig {
|
||||||
|
require_auth,
|
||||||
|
keys: vec![ApiKeyConfig {
|
||||||
|
key: "sk-good".into(),
|
||||||
|
account_id: "acct-1".into(),
|
||||||
|
key_id: Some("key-1".into()),
|
||||||
|
hard_cap: None,
|
||||||
|
window: CapWindow::Balance,
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn chat_body() -> Value {
|
||||||
|
json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [{"role": "user", "content": "hi"}]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn missing_key_when_required_is_401_invalid_api_key() {
|
||||||
|
let (neuron, _seen) = spawn_capturing_neuron().await;
|
||||||
|
let gateway = spawn_gateway(&neuron, one_key_config(true)).await;
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{gateway}/v1/chat/completions"))
|
||||||
|
.json(&chat_body())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::UNAUTHORIZED);
|
||||||
|
let body: Value = resp.json().await.unwrap();
|
||||||
|
assert_eq!(body["error"]["code"], "invalid_api_key");
|
||||||
|
assert_eq!(body["error"]["type"], "invalid_request_error");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn invalid_key_is_401_even_when_auth_not_required() {
|
||||||
|
let (neuron, seen) = spawn_capturing_neuron().await;
|
||||||
|
// A present-but-wrong credential is always an error.
|
||||||
|
let gateway = spawn_gateway(&neuron, one_key_config(false)).await;
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{gateway}/v1/chat/completions"))
|
||||||
|
.bearer_auth("sk-wrong")
|
||||||
|
.json(&chat_body())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::UNAUTHORIZED);
|
||||||
|
let body: Value = resp.json().await.unwrap();
|
||||||
|
assert_eq!(body["error"]["code"], "invalid_api_key");
|
||||||
|
// Rejected before dispatch — neuron never saw the request.
|
||||||
|
assert!(seen.lock().unwrap().account_id.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn valid_key_reaches_neuron_with_principal_headers() {
|
||||||
|
let (neuron, seen) = spawn_capturing_neuron().await;
|
||||||
|
let gateway = spawn_gateway(&neuron, one_key_config(true)).await;
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{gateway}/v1/chat/completions"))
|
||||||
|
.bearer_auth("sk-good")
|
||||||
|
// A spoofed principal header must be stripped, not forwarded.
|
||||||
|
.header(HEADER_ACCOUNT_ID, "attacker")
|
||||||
|
.json(&chat_body())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::OK);
|
||||||
|
let s = seen.lock().unwrap();
|
||||||
|
assert_eq!(s.account_id.as_deref(), Some("acct-1"));
|
||||||
|
assert_eq!(s.key_id.as_deref(), Some("key-1"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn anonymous_allowed_when_auth_not_required() {
|
||||||
|
let (neuron, seen) = spawn_capturing_neuron().await;
|
||||||
|
let gateway = spawn_gateway(&neuron, EntitlementsConfig::default()).await;
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{gateway}/v1/chat/completions"))
|
||||||
|
.json(&chat_body())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::OK);
|
||||||
|
// No principal resolved → no principal headers stamped.
|
||||||
|
let s = seen.lock().unwrap();
|
||||||
|
assert!(s.account_id.is_none());
|
||||||
|
assert!(s.key_id.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn health_is_public_even_when_auth_required() {
|
||||||
|
let (neuron, _seen) = spawn_capturing_neuron().await;
|
||||||
|
let gateway = spawn_gateway(&neuron, one_key_config(true)).await;
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.get(format!("{gateway}/health"))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::OK);
|
||||||
|
}
|
||||||
253
crates/cortex-gateway/tests/budget_enforcement.rs
Normal file
253
crates/cortex-gateway/tests/budget_enforcement.rs
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
//! Integration tests for budget enforcement (#52) — the A0 seatbelt.
|
||||||
|
//!
|
||||||
|
//! A reservation over the key's hard cap is refused *before* neuron is hit,
|
||||||
|
//! with the #63 code matching the cap-window semantics (rate_limit_exceeded
|
||||||
|
//! + Retry-After for a resetting window, insufficient_quota for a hard
|
||||||
|
//! balance). Spend never exceeds the cap. No 402, ever.
|
||||||
|
|
||||||
|
use axum::Json;
|
||||||
|
use axum::extract::Path;
|
||||||
|
use axum::routing::{get, post};
|
||||||
|
use cortex_core::config::{
|
||||||
|
ApiKeyConfig, EntitlementsConfig, EvictionSettings, EvictionStrategy, GatewayConfig,
|
||||||
|
GatewaySettings, NeuronEndpoint,
|
||||||
|
};
|
||||||
|
use cortex_core::entitlements::{CapWindow, Principal};
|
||||||
|
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||||
|
use cortex_gateway::state::CortexState;
|
||||||
|
use serde_json::{Value, json};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
/// Mock neuron with a hit counter on the inference path, so a test can prove
|
||||||
|
/// a request was (or wasn't) dispatched.
|
||||||
|
async fn spawn_counting_neuron() -> (String, Arc<AtomicU64>) {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let base_url = format!("http://{addr}");
|
||||||
|
let inference_url = base_url.clone();
|
||||||
|
let hits = Arc::new(AtomicU64::new(0));
|
||||||
|
let sink = Arc::clone(&hits);
|
||||||
|
|
||||||
|
let app = axum::Router::new()
|
||||||
|
.route(
|
||||||
|
"/models/{model_id}/endpoint",
|
||||||
|
get(move |Path(_): Path<String>| {
|
||||||
|
let url = inference_url.clone();
|
||||||
|
async move { Json(json!({ "url": url })) }
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
post(move |Json(body): Json<Value>| {
|
||||||
|
let sink = Arc::clone(&sink);
|
||||||
|
async move {
|
||||||
|
sink.fetch_add(1, Ordering::SeqCst);
|
||||||
|
let model = body.get("model").and_then(Value::as_str).unwrap_or("m");
|
||||||
|
Json(json!({
|
||||||
|
"id": "chatcmpl-budget",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1700000000_u64,
|
||||||
|
"model": model,
|
||||||
|
"choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "stop"}],
|
||||||
|
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
(base_url, hits)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn spawn_gateway(neuron_url: &str, key: ApiKeyConfig) -> (Arc<CortexState>, String) {
|
||||||
|
let config = GatewayConfig {
|
||||||
|
gateway: GatewaySettings {
|
||||||
|
listen: "127.0.0.1:0".into(),
|
||||||
|
metrics_listen: "127.0.0.1:0".into(),
|
||||||
|
},
|
||||||
|
eviction: EvictionSettings {
|
||||||
|
strategy: EvictionStrategy::Lru,
|
||||||
|
defrag_after_cycles: 0,
|
||||||
|
},
|
||||||
|
neurons: vec![NeuronEndpoint {
|
||||||
|
name: "mock-node".into(),
|
||||||
|
endpoint: neuron_url.to_string(),
|
||||||
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: EntitlementsConfig {
|
||||||
|
require_auth: true,
|
||||||
|
keys: vec![key],
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
{
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
let node = nodes.get_mut("mock-node").unwrap();
|
||||||
|
node.healthy = true;
|
||||||
|
node.models.insert(
|
||||||
|
"test-model".into(),
|
||||||
|
ModelEntry {
|
||||||
|
id: "test-model".into(),
|
||||||
|
status: ModelStatus::Loaded,
|
||||||
|
last_accessed: None,
|
||||||
|
vram_estimate_mb: Some(8000),
|
||||||
|
capabilities: Vec::new(),
|
||||||
|
tool_call: false,
|
||||||
|
reasoning: false,
|
||||||
|
limit: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
(fleet, format!("http://{addr}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn key(window: CapWindow, hard_cap: u64) -> ApiKeyConfig {
|
||||||
|
ApiKeyConfig {
|
||||||
|
key: "sk-cap".into(),
|
||||||
|
account_id: "acct-cap".into(),
|
||||||
|
key_id: Some("key-cap".into()),
|
||||||
|
hard_cap: Some(hard_cap),
|
||||||
|
window,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn chat(max_tokens: u64) -> Value {
|
||||||
|
json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"messages": [{"role": "user", "content": "hi"}]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn balance_over_cap_is_429_insufficient_quota_before_dispatch() {
|
||||||
|
let (neuron, hits) = spawn_counting_neuron().await;
|
||||||
|
// Cap far below a single request's reservation (max_tokens 1000).
|
||||||
|
let (_fleet, gateway) = spawn_gateway(&neuron, key(CapWindow::Balance, 10)).await;
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{gateway}/v1/chat/completions"))
|
||||||
|
.bearer_auth("sk-cap")
|
||||||
|
.json(&chat(1000))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::TOO_MANY_REQUESTS);
|
||||||
|
// Hard balance → no Retry-After.
|
||||||
|
assert!(resp.headers().get(reqwest::header::RETRY_AFTER).is_none());
|
||||||
|
let body: Value = resp.json().await.unwrap();
|
||||||
|
assert_eq!(body["error"]["code"], "insufficient_quota");
|
||||||
|
// Refused before dispatch — neuron never saw it.
|
||||||
|
assert_eq!(hits.load(Ordering::SeqCst), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn rolling_over_cap_is_429_rate_limited_with_retry_after() {
|
||||||
|
let (neuron, hits) = spawn_counting_neuron().await;
|
||||||
|
let (_fleet, gateway) =
|
||||||
|
spawn_gateway(&neuron, key(CapWindow::Rolling { seconds: 3600 }, 10)).await;
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{gateway}/v1/chat/completions"))
|
||||||
|
.bearer_auth("sk-cap")
|
||||||
|
.json(&chat(1000))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::TOO_MANY_REQUESTS);
|
||||||
|
let retry = resp
|
||||||
|
.headers()
|
||||||
|
.get(reqwest::header::RETRY_AFTER)
|
||||||
|
.expect("rolling-window rejection must carry Retry-After");
|
||||||
|
assert!(retry.to_str().unwrap().parse::<u64>().unwrap() >= 1);
|
||||||
|
let body: Value = resp.json().await.unwrap();
|
||||||
|
assert_eq!(body["error"]["code"], "rate_limit_exceeded");
|
||||||
|
assert_eq!(hits.load(Ordering::SeqCst), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn within_cap_is_served() {
|
||||||
|
let (neuron, hits) = spawn_counting_neuron().await;
|
||||||
|
let (_fleet, gateway) = spawn_gateway(&neuron, key(CapWindow::Balance, 1_000_000)).await;
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{gateway}/v1/chat/completions"))
|
||||||
|
.bearer_auth("sk-cap")
|
||||||
|
.json(&chat(50))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::OK);
|
||||||
|
let _ = resp.bytes().await.unwrap();
|
||||||
|
assert_eq!(hits.load(Ordering::SeqCst), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn a0_seatbelt_caps_a_runaway_fan_out() {
|
||||||
|
// An Agent-Zero-style key with a modest cap: a burst of requests drains
|
||||||
|
// it, then further requests are refused — the account stops draining and
|
||||||
|
// spend never exceeds the cap.
|
||||||
|
let (neuron, hits) = spawn_counting_neuron().await;
|
||||||
|
let (fleet, gateway) = spawn_gateway(&neuron, key(CapWindow::Balance, 100)).await;
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
let mut ok = 0;
|
||||||
|
let mut refused = 0;
|
||||||
|
for _ in 0..20 {
|
||||||
|
let resp = client
|
||||||
|
.post(format!("{gateway}/v1/chat/completions"))
|
||||||
|
.bearer_auth("sk-cap")
|
||||||
|
.json(&chat(20))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
match resp.status() {
|
||||||
|
reqwest::StatusCode::OK => {
|
||||||
|
ok += 1;
|
||||||
|
let _ = resp.bytes().await.unwrap();
|
||||||
|
}
|
||||||
|
reqwest::StatusCode::TOO_MANY_REQUESTS => {
|
||||||
|
refused += 1;
|
||||||
|
let body: Value = resp.json().await.unwrap();
|
||||||
|
assert_eq!(body["error"]["code"], "insufficient_quota");
|
||||||
|
}
|
||||||
|
other => panic!("unexpected status {other}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(ok >= 1, "some requests should be served");
|
||||||
|
assert!(refused >= 1, "the cap must eventually refuse the fan-out");
|
||||||
|
assert_eq!(
|
||||||
|
hits.load(Ordering::SeqCst),
|
||||||
|
ok,
|
||||||
|
"refused requests never dispatched"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Spend never exceeded the hard cap (reservation prevents overshoot).
|
||||||
|
// Poll briefly for in-flight settles to land.
|
||||||
|
let principal = Principal {
|
||||||
|
account_id: "acct-cap".into(),
|
||||||
|
key_id: "key-cap".into(),
|
||||||
|
};
|
||||||
|
for _ in 0..50 {
|
||||||
|
let snap = fleet.entitlements.snapshot(&principal).await.unwrap();
|
||||||
|
if snap.reserved == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
|
||||||
|
}
|
||||||
|
let snap = fleet.entitlements.snapshot(&principal).await.unwrap();
|
||||||
|
assert!(snap.spent <= 100, "spent {} exceeded cap", snap.spent);
|
||||||
|
}
|
||||||
@@ -429,6 +429,7 @@ pub async fn spawn_gateway_with_state(mock_url: &str) -> (Arc<CortexState>, Stri
|
|||||||
endpoint: mock_url.to_string(),
|
endpoint: mock_url.to_string(),
|
||||||
}],
|
}],
|
||||||
models_config: "/dev/null".into(),
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|||||||
@@ -88,6 +88,7 @@ async fn error_response_no_healthy_nodes() {
|
|||||||
endpoint: "http://127.0.0.1:1".into(),
|
endpoint: "http://127.0.0.1:1".into(),
|
||||||
}],
|
}],
|
||||||
models_config: "/dev/null".into(),
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(cortex_gateway::state::CortexState::from_config(&config));
|
let fleet = Arc::new(cortex_gateway::state::CortexState::from_config(&config));
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ fn make_fleet(endpoint: &str, defrag_after: u32) -> Arc<CortexState> {
|
|||||||
endpoint: endpoint.to_string(),
|
endpoint: endpoint.to_string(),
|
||||||
}],
|
}],
|
||||||
models_config: "/dev/null".into(),
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
Arc::new(CortexState::from_config(&config))
|
Arc::new(CortexState::from_config(&config))
|
||||||
}
|
}
|
||||||
|
|||||||
207
crates/cortex-gateway/tests/metering.rs
Normal file
207
crates/cortex-gateway/tests/metering.rs
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
//! Integration tests for per-request token metering (#51).
|
||||||
|
//!
|
||||||
|
//! Drives authenticated requests through the gateway to a mock neuron that
|
||||||
|
//! reports a fixed `usage` object, then asserts the EntitlementProvider's
|
||||||
|
//! spend ledger reflects cumulative per-key spend and that reservations
|
||||||
|
//! settle to actual (no outstanding reserved tokens once requests complete).
|
||||||
|
|
||||||
|
mod common;
|
||||||
|
|
||||||
|
use cortex_core::config::{
|
||||||
|
ApiKeyConfig, EntitlementsConfig, EvictionSettings, EvictionStrategy, GatewayConfig,
|
||||||
|
GatewaySettings, NeuronEndpoint,
|
||||||
|
};
|
||||||
|
use cortex_core::entitlements::{CapWindow, Principal};
|
||||||
|
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||||
|
use cortex_gateway::state::CortexState;
|
||||||
|
use serde_json::json;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
const ACCOUNT: &str = "acct-meter";
|
||||||
|
const KEY_ID: &str = "key-meter";
|
||||||
|
const BEARER: &str = "sk-meter";
|
||||||
|
|
||||||
|
/// The mock neuron (common::spawn_mock_neuron) reports this fixed usage on
|
||||||
|
/// every chat completion.
|
||||||
|
const PROMPT_PER_REQ: u64 = 10;
|
||||||
|
const COMPLETION_PER_REQ: u64 = 5;
|
||||||
|
|
||||||
|
async fn spawn_metered_gateway(neuron_url: &str) -> (Arc<CortexState>, String) {
|
||||||
|
let config = GatewayConfig {
|
||||||
|
gateway: GatewaySettings {
|
||||||
|
listen: "127.0.0.1:0".into(),
|
||||||
|
metrics_listen: "127.0.0.1:0".into(),
|
||||||
|
},
|
||||||
|
eviction: EvictionSettings {
|
||||||
|
strategy: EvictionStrategy::Lru,
|
||||||
|
defrag_after_cycles: 0,
|
||||||
|
},
|
||||||
|
neurons: vec![NeuronEndpoint {
|
||||||
|
name: "mock-node".into(),
|
||||||
|
endpoint: neuron_url.to_string(),
|
||||||
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: EntitlementsConfig {
|
||||||
|
require_auth: true,
|
||||||
|
keys: vec![ApiKeyConfig {
|
||||||
|
key: BEARER.into(),
|
||||||
|
account_id: ACCOUNT.into(),
|
||||||
|
key_id: Some(KEY_ID.into()),
|
||||||
|
hard_cap: Some(1_000_000),
|
||||||
|
window: CapWindow::Balance,
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
{
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
let node = nodes.get_mut("mock-node").unwrap();
|
||||||
|
node.healthy = true;
|
||||||
|
node.models.insert(
|
||||||
|
"test-model".into(),
|
||||||
|
ModelEntry {
|
||||||
|
id: "test-model".into(),
|
||||||
|
status: ModelStatus::Loaded,
|
||||||
|
last_accessed: None,
|
||||||
|
vram_estimate_mb: Some(8000),
|
||||||
|
capabilities: Vec::new(),
|
||||||
|
tool_call: false,
|
||||||
|
reasoning: false,
|
||||||
|
limit: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
(fleet, format!("http://{addr}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn principal() -> Principal {
|
||||||
|
Principal {
|
||||||
|
account_id: ACCOUNT.into(),
|
||||||
|
key_id: KEY_ID.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Poll the provider ledger until settled spend reaches `expected` (settle
|
||||||
|
/// runs in a spawned task after the response stream finishes) or time out.
|
||||||
|
async fn await_spent(fleet: &CortexState, expected: u64) -> u64 {
|
||||||
|
let principal = principal();
|
||||||
|
for _ in 0..100 {
|
||||||
|
let snap = fleet.entitlements.snapshot(&principal).await.unwrap();
|
||||||
|
if snap.spent >= expected {
|
||||||
|
return snap.spent;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||||
|
}
|
||||||
|
fleet.entitlements.snapshot(&principal).await.unwrap().spent
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn cumulative_spend_is_metered_per_key() {
|
||||||
|
let neuron = common::spawn_mock_neuron().await;
|
||||||
|
let (fleet, gateway) = spawn_metered_gateway(&neuron).await;
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
const N: u64 = 3;
|
||||||
|
for _ in 0..N {
|
||||||
|
let resp = client
|
||||||
|
.post(format!("{gateway}/v1/chat/completions"))
|
||||||
|
.bearer_auth(BEARER)
|
||||||
|
.json(&json!({"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::OK);
|
||||||
|
// Drain the body so the response stream finishes and metering settles.
|
||||||
|
let _ = resp.bytes().await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let expected = N * (PROMPT_PER_REQ + COMPLETION_PER_REQ);
|
||||||
|
let spent = await_spent(&fleet, expected).await;
|
||||||
|
assert_eq!(
|
||||||
|
spent, expected,
|
||||||
|
"ledger must reflect cumulative per-key spend"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Reservations settled to actual — nothing left outstanding.
|
||||||
|
let snap = fleet.entitlements.snapshot(&principal()).await.unwrap();
|
||||||
|
assert_eq!(snap.reserved, 0, "all reservations must settle/release");
|
||||||
|
assert_eq!(snap.hard_cap, Some(1_000_000));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn anonymous_request_records_no_spend() {
|
||||||
|
// require_auth=false so the unauthenticated request is served, but with
|
||||||
|
// no principal it must not touch any ledger.
|
||||||
|
let neuron = common::spawn_mock_neuron().await;
|
||||||
|
let config = GatewayConfig {
|
||||||
|
gateway: GatewaySettings {
|
||||||
|
listen: "127.0.0.1:0".into(),
|
||||||
|
metrics_listen: "127.0.0.1:0".into(),
|
||||||
|
},
|
||||||
|
eviction: EvictionSettings {
|
||||||
|
strategy: EvictionStrategy::Lru,
|
||||||
|
defrag_after_cycles: 0,
|
||||||
|
},
|
||||||
|
neurons: vec![NeuronEndpoint {
|
||||||
|
name: "mock-node".into(),
|
||||||
|
endpoint: neuron.clone(),
|
||||||
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: EntitlementsConfig::default(),
|
||||||
|
};
|
||||||
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
{
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
let node = nodes.get_mut("mock-node").unwrap();
|
||||||
|
node.healthy = true;
|
||||||
|
node.models.insert(
|
||||||
|
"test-model".into(),
|
||||||
|
ModelEntry {
|
||||||
|
id: "test-model".into(),
|
||||||
|
status: ModelStatus::Loaded,
|
||||||
|
last_accessed: None,
|
||||||
|
vram_estimate_mb: Some(8000),
|
||||||
|
capabilities: Vec::new(),
|
||||||
|
tool_call: false,
|
||||||
|
reasoning: false,
|
||||||
|
limit: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("http://{addr}/v1/chat/completions"))
|
||||||
|
.json(&json!({"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::OK);
|
||||||
|
let _ = resp.bytes().await.unwrap();
|
||||||
|
|
||||||
|
// An unconfigured principal has a zeroed snapshot — nothing was metered.
|
||||||
|
let snap = fleet
|
||||||
|
.entitlements
|
||||||
|
.snapshot(&Principal {
|
||||||
|
account_id: "nobody".into(),
|
||||||
|
key_id: "nobody".into(),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(snap.spent, 0);
|
||||||
|
}
|
||||||
@@ -54,6 +54,7 @@ capabilities = ["text"]
|
|||||||
endpoint: "http://127.0.0.1:1".into(),
|
endpoint: "http://127.0.0.1:1".into(),
|
||||||
}],
|
}],
|
||||||
models_config: cat_path.to_string_lossy().into_owned(),
|
models_config: cat_path.to_string_lossy().into_owned(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ async fn test_poller_discovers_models() {
|
|||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
}],
|
}],
|
||||||
models_config: "/dev/null".into(),
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
@@ -82,6 +83,7 @@ async fn test_poller_updates_gateway_models_endpoint() {
|
|||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
}],
|
}],
|
||||||
models_config: "/dev/null".into(),
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
@@ -153,6 +155,7 @@ async fn test_models_endpoint_unions_capabilities_across_nodes() {
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
models_config: "/dev/null".into(),
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
@@ -215,6 +218,7 @@ async fn test_poller_marks_unreachable_node_unhealthy() {
|
|||||||
endpoint: "http://127.0.0.1:1".into(),
|
endpoint: "http://127.0.0.1:1".into(),
|
||||||
}],
|
}],
|
||||||
models_config: "/dev/null".into(),
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
@@ -252,6 +256,7 @@ async fn test_poller_removes_stale_models() {
|
|||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
}],
|
}],
|
||||||
models_config: "/dev/null".into(),
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
@@ -282,6 +287,7 @@ async fn test_poller_removes_stale_models() {
|
|||||||
endpoint: new_mock_url,
|
endpoint: new_mock_url,
|
||||||
}],
|
}],
|
||||||
models_config: "/dev/null".into(),
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet2 = Arc::new(CortexState::from_config(&config2));
|
let fleet2 = Arc::new(CortexState::from_config(&config2));
|
||||||
@@ -363,6 +369,7 @@ async fn test_poller_captures_activation_from_health() {
|
|||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
}],
|
}],
|
||||||
models_config: "/dev/null".into(),
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
@@ -407,6 +414,7 @@ async fn test_poller_parses_recovering_status() {
|
|||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
}],
|
}],
|
||||||
models_config: "/dev/null".into(),
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|||||||
@@ -117,6 +117,7 @@ async fn test_no_healthy_nodes() {
|
|||||||
endpoint: "http://127.0.0.1:1".into(),
|
endpoint: "http://127.0.0.1:1".into(),
|
||||||
}],
|
}],
|
||||||
models_config: "/dev/null".into(),
|
models_config: "/dev/null".into(),
|
||||||
|
entitlements: Default::default(),
|
||||||
};
|
};
|
||||||
let fleet = std::sync::Arc::new(cortex_gateway::state::CortexState::from_config(&config));
|
let fleet = std::sync::Arc::new(cortex_gateway::state::CortexState::from_config(&config));
|
||||||
|
|
||||||
|
|||||||
@@ -486,6 +486,15 @@ fn inference_error_response(err: InferenceError) -> axum::response::Response {
|
|||||||
"template_render_failed",
|
"template_render_failed",
|
||||||
format!("chat template could not render this request: {detail}"),
|
format!("chat template could not render this request: {detail}"),
|
||||||
),
|
),
|
||||||
|
// Admission control refused (#53): a fast, retryable "busy" signal.
|
||||||
|
// 503 (service busy) + Retry-After; opencode/AI SDK back off.
|
||||||
|
InferenceError::Overloaded { retry_after_secs } => OpenAiError::new(
|
||||||
|
503,
|
||||||
|
"rate_limit_error",
|
||||||
|
"rate_limit_exceeded",
|
||||||
|
"model is busy (admission queue full); retry shortly",
|
||||||
|
)
|
||||||
|
.with_retry_after(retry_after_secs),
|
||||||
InferenceError::Other(e) => OpenAiError::without_code(500, "api_error", format!("{e:#}")),
|
InferenceError::Other(e) => OpenAiError::without_code(500, "api_error", format!("{e:#}")),
|
||||||
};
|
};
|
||||||
envelope_response(env)
|
envelope_response(env)
|
||||||
@@ -660,6 +669,26 @@ mod error_envelope_tests {
|
|||||||
assert_eq!(error["required_mb"], 8_192);
|
assert_eq!(error["required_mb"], 8_192);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn overloaded_is_503_rate_limited_with_retry_after() {
|
||||||
|
// Admission rejection (#53) → fast, retryable backpressure.
|
||||||
|
let resp = inference_error_response(InferenceError::Overloaded {
|
||||||
|
retry_after_secs: 7,
|
||||||
|
});
|
||||||
|
assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||||
|
let retry = resp
|
||||||
|
.headers()
|
||||||
|
.get(axum::http::header::RETRY_AFTER)
|
||||||
|
.expect("admission rejection must advertise Retry-After");
|
||||||
|
assert_eq!(retry.to_str().unwrap(), "7");
|
||||||
|
|
||||||
|
let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let body: Value = serde_json::from_slice(&bytes).unwrap();
|
||||||
|
assert_eq!(body["error"]["code"], "rate_limit_exceeded");
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn insufficient_vram_carries_retry_after() {
|
async fn insufficient_vram_carries_retry_after() {
|
||||||
// Transient 503 — VRAM frees as in-flight requests finish, so the
|
// Transient 503 — VRAM frees as in-flight requests finish, so the
|
||||||
|
|||||||
@@ -85,6 +85,56 @@ pub struct CandleHarnessConfig {
|
|||||||
/// `/models`, and enforces it. These knobs tune that derivation.
|
/// `/models`, and enforces it. These knobs tune that derivation.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub context_limit: ContextLimitConfig,
|
pub context_limit: ContextLimitConfig,
|
||||||
|
|
||||||
|
/// Admission control (#53): bounds the per-model wait queue so a busy
|
||||||
|
/// model returns a fast, retryable `429`/`503` instead of stalling new
|
||||||
|
/// requests until their client times out.
|
||||||
|
#[serde(default)]
|
||||||
|
pub admission: AdmissionConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `[harness.candle.admission]` settings (#53).
|
||||||
|
///
|
||||||
|
/// Inference is batch-1, so `max_in_flight` is 1 in practice; the queue
|
||||||
|
/// (`max_queue_depth`) absorbs short bursts, and `max_wait_secs` caps how
|
||||||
|
/// long a queued request waits before it's refused with backpressure.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct AdmissionConfig {
|
||||||
|
/// Concurrent running requests per model. Batch-1 inference → 1.
|
||||||
|
#[serde(default = "default_admission_max_in_flight")]
|
||||||
|
pub max_in_flight: usize,
|
||||||
|
/// Queued (waiting) requests allowed beyond the in-flight one. The
|
||||||
|
/// `(max_in_flight + max_queue_depth + 1)`-th request is refused
|
||||||
|
/// immediately with `429`/`503` + `Retry-After`.
|
||||||
|
#[serde(default = "default_admission_max_queue_depth")]
|
||||||
|
pub max_queue_depth: usize,
|
||||||
|
/// Maximum seconds a queued request waits for the in-flight slot before
|
||||||
|
/// it is refused (turns the old ~300s client-side hang into a fast,
|
||||||
|
/// honest signal).
|
||||||
|
#[serde(default = "default_admission_max_wait_secs")]
|
||||||
|
pub max_wait_secs: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AdmissionConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_in_flight: default_admission_max_in_flight(),
|
||||||
|
max_queue_depth: default_admission_max_queue_depth(),
|
||||||
|
max_wait_secs: default_admission_max_wait_secs(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_admission_max_in_flight() -> usize {
|
||||||
|
1
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_admission_max_queue_depth() -> usize {
|
||||||
|
8
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_admission_max_wait_secs() -> u64 {
|
||||||
|
30
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `[harness.candle.prefix_cache]` settings.
|
/// `[harness.candle.prefix_cache]` settings.
|
||||||
|
|||||||
202
crates/neuron/src/harness/admission.rs
Normal file
202
crates/neuron/src/harness/admission.rs
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
//! Per-model admission control (#53).
|
||||||
|
//!
|
||||||
|
//! Inference against a loaded model is batch-1: one request runs at a time,
|
||||||
|
//! serialized by the model's `inference_lock` (single-GPU) / `pool` mutex
|
||||||
|
//! (TP). Before this, the wait for that lock was an **unbounded FIFO of
|
||||||
|
//! mutex waiters with no timeout** — a busy model made every new request
|
||||||
|
//! hang until its client gave up (~300s) with an opaque error.
|
||||||
|
//!
|
||||||
|
//! [`AdmissionController`] replaces that implicit unbounded wait with an
|
||||||
|
//! explicit bounded scheduler: at most `max_in_flight` running (1, batch-1)
|
||||||
|
//! plus a bounded queue of `max_queue_depth` waiters, each waiting at most
|
||||||
|
//! `max_wait`. When the queue is full or the wait elapses, the request is
|
||||||
|
//! rejected *immediately* — an honest, fast, retryable "busy" signal
|
||||||
|
//! (`429`/`503` + `Retry-After` per #63) instead of a silent stall.
|
||||||
|
//!
|
||||||
|
//! The controller is pure async (no CUDA), so the inference paths just call
|
||||||
|
//! [`AdmissionController::enter`] before taking the inference lock and hold
|
||||||
|
//! the returned [`AdmissionPermit`] for the request's lifetime. Its counters
|
||||||
|
//! ([`in_flight`](AdmissionController::in_flight) /
|
||||||
|
//! [`queue_depth`](AdmissionController::queue_depth)) are lock-free, so
|
||||||
|
//! `/health` can read live load without contending with inference.
|
||||||
|
|
||||||
|
use crate::config::AdmissionConfig;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||||
|
|
||||||
|
/// Why admission was refused. Both map to the #63 backpressure envelope
|
||||||
|
/// (`429`/`503` + `rate_limit_exceeded` + `Retry-After`); they differ only
|
||||||
|
/// in cause, for logging.
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub enum AdmissionRejection {
|
||||||
|
/// The bounded wait queue was already full.
|
||||||
|
QueueFull { retry_after_secs: u64 },
|
||||||
|
/// A queue slot was taken but the in-flight slot didn't free within
|
||||||
|
/// `max_wait`.
|
||||||
|
Timeout { retry_after_secs: u64 },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AdmissionRejection {
|
||||||
|
pub fn retry_after_secs(&self) -> u64 {
|
||||||
|
match self {
|
||||||
|
AdmissionRejection::QueueFull { retry_after_secs }
|
||||||
|
| AdmissionRejection::Timeout { retry_after_secs } => *retry_after_secs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Bounded batch-1 scheduler for one loaded model.
|
||||||
|
pub struct AdmissionController {
|
||||||
|
/// In-flight slots — `max_in_flight` permits (1 for batch-1).
|
||||||
|
slots: Arc<Semaphore>,
|
||||||
|
/// Queued + in-flight count, for fast rejection and load reporting.
|
||||||
|
pending: Arc<AtomicUsize>,
|
||||||
|
/// `max_in_flight + max_queue_depth` — the rejection threshold.
|
||||||
|
max_pending: usize,
|
||||||
|
max_in_flight: usize,
|
||||||
|
max_wait: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AdmissionController {
|
||||||
|
pub fn new(cfg: &AdmissionConfig) -> Self {
|
||||||
|
// A controller with zero in-flight slots would deadlock; clamp.
|
||||||
|
let max_in_flight = cfg.max_in_flight.max(1);
|
||||||
|
Self {
|
||||||
|
slots: Arc::new(Semaphore::new(max_in_flight)),
|
||||||
|
pending: Arc::new(AtomicUsize::new(0)),
|
||||||
|
max_pending: max_in_flight + cfg.max_queue_depth,
|
||||||
|
max_in_flight,
|
||||||
|
max_wait: Duration::from_secs(cfg.max_wait_secs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Admit a request: reserve a queue slot (fast-rejecting if full), then
|
||||||
|
/// wait up to `max_wait` for an in-flight slot. The returned permit must
|
||||||
|
/// be held for the request's lifetime; dropping it frees both slots.
|
||||||
|
pub async fn enter(&self) -> Result<AdmissionPermit, AdmissionRejection> {
|
||||||
|
// Reserve a pending slot up front so concurrent callers can't all
|
||||||
|
// slip past the threshold check. Roll back if we're over capacity.
|
||||||
|
let prev = self.pending.fetch_add(1, Ordering::AcqRel);
|
||||||
|
if prev >= self.max_pending {
|
||||||
|
self.pending.fetch_sub(1, Ordering::AcqRel);
|
||||||
|
return Err(AdmissionRejection::QueueFull {
|
||||||
|
retry_after_secs: self.retry_hint(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
match tokio::time::timeout(self.max_wait, Arc::clone(&self.slots).acquire_owned()).await {
|
||||||
|
Ok(Ok(permit)) => Ok(AdmissionPermit {
|
||||||
|
_permit: permit,
|
||||||
|
pending: Arc::clone(&self.pending),
|
||||||
|
}),
|
||||||
|
// Semaphore is never closed; treat a closed/elapsed wait the same.
|
||||||
|
Ok(Err(_)) | Err(_) => {
|
||||||
|
self.pending.fetch_sub(1, Ordering::AcqRel);
|
||||||
|
Err(AdmissionRejection::Timeout {
|
||||||
|
retry_after_secs: self.retry_hint(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Requests currently running (holding an in-flight slot).
|
||||||
|
pub fn in_flight(&self) -> usize {
|
||||||
|
self.max_in_flight
|
||||||
|
.saturating_sub(self.slots.available_permits())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Requests waiting for an in-flight slot.
|
||||||
|
pub fn queue_depth(&self) -> usize {
|
||||||
|
self.pending
|
||||||
|
.load(Ordering::Acquire)
|
||||||
|
.saturating_sub(self.in_flight())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rough `Retry-After`: scale with how backed-up the model is, clamped to
|
||||||
|
/// a sane band. Without per-request timing this is a heuristic, but it
|
||||||
|
/// gives well-behaved clients (opencode/AI SDK) a sensible backoff.
|
||||||
|
fn retry_hint(&self) -> u64 {
|
||||||
|
((self.queue_depth() as u64 + 1) * 2).clamp(1, 120)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Held for a request's lifetime; frees the in-flight + queue slot on drop.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AdmissionPermit {
|
||||||
|
_permit: OwnedSemaphorePermit,
|
||||||
|
pending: Arc<AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for AdmissionPermit {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.pending.fetch_sub(1, Ordering::AcqRel);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn cfg(max_in_flight: usize, max_queue_depth: usize, max_wait_secs: u64) -> AdmissionConfig {
|
||||||
|
AdmissionConfig {
|
||||||
|
max_in_flight,
|
||||||
|
max_queue_depth,
|
||||||
|
max_wait_secs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn admits_up_to_in_flight_and_reports_load() {
|
||||||
|
let ctrl = AdmissionController::new(&cfg(1, 4, 30));
|
||||||
|
assert_eq!(ctrl.in_flight(), 0);
|
||||||
|
let p = ctrl.enter().await.expect("first admits");
|
||||||
|
assert_eq!(ctrl.in_flight(), 1);
|
||||||
|
assert_eq!(ctrl.queue_depth(), 0);
|
||||||
|
drop(p);
|
||||||
|
assert_eq!(ctrl.in_flight(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn rejects_when_queue_full() {
|
||||||
|
// 1 in-flight + 1 queue slot = capacity 2; the 3rd is refused fast.
|
||||||
|
let ctrl = Arc::new(AdmissionController::new(&cfg(1, 1, 30)));
|
||||||
|
let _running = ctrl.enter().await.expect("admit running");
|
||||||
|
|
||||||
|
// Fill the single queue slot with a waiter that parks on the semaphore.
|
||||||
|
let ctrl2 = Arc::clone(&ctrl);
|
||||||
|
let waiter = tokio::spawn(async move { ctrl2.enter().await.map(|p| drop(p)) });
|
||||||
|
// Give the waiter a moment to occupy the queue slot.
|
||||||
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
assert_eq!(ctrl.queue_depth(), 1);
|
||||||
|
|
||||||
|
// Queue full → immediate QueueFull with a Retry-After hint.
|
||||||
|
match ctrl.enter().await {
|
||||||
|
Err(AdmissionRejection::QueueFull { retry_after_secs }) => {
|
||||||
|
assert!(retry_after_secs >= 1)
|
||||||
|
}
|
||||||
|
other => panic!("expected QueueFull, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release the runner so the parked waiter can proceed and finish.
|
||||||
|
drop(_running);
|
||||||
|
waiter.await.unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn rejects_on_wait_timeout() {
|
||||||
|
// Zero queue depth + a runner holding the only slot → a second
|
||||||
|
// request can't even queue, so it's QueueFull, not Timeout. Use a
|
||||||
|
// queue of 1 and a tiny max_wait to exercise the timeout path.
|
||||||
|
let ctrl = Arc::new(AdmissionController::new(&cfg(1, 1, 0)));
|
||||||
|
let _running = ctrl.enter().await.expect("admit running");
|
||||||
|
// max_wait 0 → the queued request times out almost immediately.
|
||||||
|
match ctrl.enter().await {
|
||||||
|
Err(AdmissionRejection::Timeout { .. }) => {}
|
||||||
|
other => panic!("expected Timeout, got {other:?}"),
|
||||||
|
}
|
||||||
|
// The timed-out request released its queue slot.
|
||||||
|
assert_eq!(ctrl.queue_depth(), 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -81,6 +81,9 @@ pub struct CandleHarness {
|
|||||||
/// Context-limit derivation settings (#67), read in `list_models`
|
/// Context-limit derivation settings (#67), read in `list_models`
|
||||||
/// to compute each model's advertised `limit{context,input,output}`.
|
/// to compute each model's advertised `limit{context,input,output}`.
|
||||||
context_limit_cfg: crate::config::ContextLimitConfig,
|
context_limit_cfg: crate::config::ContextLimitConfig,
|
||||||
|
/// Admission-control settings (#53), used to build each loaded model's
|
||||||
|
/// [`super::admission::AdmissionController`] at load time.
|
||||||
|
admission_cfg: crate::config::AdmissionConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Devices/capabilities snapshot of a model entering auto-recovery
|
/// Devices/capabilities snapshot of a model entering auto-recovery
|
||||||
@@ -305,6 +308,10 @@ pub struct LoadedModel {
|
|||||||
/// for the TP path (which already had this invariant by accident
|
/// for the TP path (which already had this invariant by accident
|
||||||
/// because the pool lock covered the same window).
|
/// because the pool lock covered the same window).
|
||||||
pub inference_lock: tokio::sync::Mutex<()>,
|
pub inference_lock: tokio::sync::Mutex<()>,
|
||||||
|
/// Bounded admission scheduler (#53). Gated *before* `inference_lock`
|
||||||
|
/// so a busy model refuses overflow fast instead of growing an
|
||||||
|
/// unbounded, untimed queue of lock waiters.
|
||||||
|
pub admission: super::admission::AdmissionController,
|
||||||
/// Open/close token IDs for the reasoning marker this model
|
/// Open/close token IDs for the reasoning marker this model
|
||||||
/// emits, populated once at load time by probing the tokenizer's
|
/// emits, populated once at load time by probing the tokenizer's
|
||||||
/// added-tokens table. `None` for non-reasoning models or
|
/// added-tokens table. `None` for non-reasoning models or
|
||||||
@@ -422,6 +429,10 @@ pub struct TpLoadedModel {
|
|||||||
/// serialises subprocess RPC traffic on the pool's
|
/// serialises subprocess RPC traffic on the pool's
|
||||||
/// `Vec<Worker>` channels.
|
/// `Vec<Worker>` channels.
|
||||||
pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
|
pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
|
||||||
|
/// Bounded admission scheduler (#53), mirroring the single-GPU path.
|
||||||
|
/// Gated before the pool lock so an overloaded TP model returns fast
|
||||||
|
/// backpressure instead of an unbounded, untimed wait.
|
||||||
|
pub admission: super::admission::AdmissionController,
|
||||||
/// Handle into the leader device worker's TP slab. The boxed
|
/// Handle into the leader device worker's TP slab. The boxed
|
||||||
/// `TpLeaderModel` (with its embedded `Arc<Comm>` clones and
|
/// `TpLeaderModel` (with its embedded `Arc<Comm>` clones and
|
||||||
/// per-rank CUDA tensors) lives on the worker thread; we hold an
|
/// per-rank CUDA tensors) lives on the worker thread; we hold an
|
||||||
@@ -1565,6 +1576,7 @@ impl CandleHarness {
|
|||||||
recovery_tx,
|
recovery_tx,
|
||||||
prefix_cache_cfg: config.prefix_cache.clone(),
|
prefix_cache_cfg: config.prefix_cache.clone(),
|
||||||
context_limit_cfg: config.context_limit.clone(),
|
context_limit_cfg: config.context_limit.clone(),
|
||||||
|
admission_cfg: config.admission.clone(),
|
||||||
});
|
});
|
||||||
// Background auto-recovery task (#17). Holds a `Weak` so it can't
|
// Background auto-recovery task (#17). Holds a `Weak` so it can't
|
||||||
// keep the harness alive. Spawned only when a tokio runtime is
|
// keep the harness alive. Spawned only when a tokio runtime is
|
||||||
@@ -2059,6 +2071,15 @@ impl CandleHarness {
|
|||||||
return Err(self.trigger_recovery(&model_id).await);
|
return Err(self.trigger_recovery(&model_id).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Admission control (#53): refuse fast if the bounded queue is full
|
||||||
|
// or the wait elapses, rather than joining an unbounded lock-wait.
|
||||||
|
// The permit is held for the whole request (released on drop).
|
||||||
|
let _admit = loaded
|
||||||
|
.admission
|
||||||
|
.enter()
|
||||||
|
.await
|
||||||
|
.map_err(InferenceError::from)?;
|
||||||
|
|
||||||
// Serialise concurrent requests against this model. Holds for
|
// Serialise concurrent requests against this model. Holds for
|
||||||
// the duration of clear_kv_cache → prefill → decode so two
|
// the duration of clear_kv_cache → prefill → decode so two
|
||||||
// requests' chunked-prefill sequences can't interleave on the
|
// requests' chunked-prefill sequences can't interleave on the
|
||||||
@@ -2610,6 +2631,15 @@ impl CandleHarness {
|
|||||||
// role chunk was already sent above, so the client sees
|
// role chunk was already sent above, so the client sees
|
||||||
// immediate "stream open" feedback even when this request
|
// immediate "stream open" feedback even when this request
|
||||||
// queues behind another for the lock.
|
// queues behind another for the lock.
|
||||||
|
// Admission control (#53): refuse before opening the stream if the
|
||||||
|
// model's bounded queue is full / the wait elapses. The permit moves
|
||||||
|
// into the inference task and is held until it completes.
|
||||||
|
let admit = loaded
|
||||||
|
.admission
|
||||||
|
.enter()
|
||||||
|
.await
|
||||||
|
.map_err(InferenceError::from)?;
|
||||||
|
|
||||||
let tool_schemas = build_tool_schemas(&request);
|
let tool_schemas = build_tool_schemas(&request);
|
||||||
if let (Some(worker), Some(handle)) = (loaded.worker.clone(), loaded.arch_handle) {
|
if let (Some(worker), Some(handle)) = (loaded.worker.clone(), loaded.arch_handle) {
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
@@ -2620,6 +2650,7 @@ impl CandleHarness {
|
|||||||
let tool_schemas_inner = tool_schemas.clone();
|
let tool_schemas_inner = tool_schemas.clone();
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
async move {
|
async move {
|
||||||
|
let _admit = admit;
|
||||||
let _inference_guard = loaded_for_task.inference_lock.lock().await;
|
let _inference_guard = loaded_for_task.inference_lock.lock().await;
|
||||||
match stream_inference_via_worker(
|
match stream_inference_via_worker(
|
||||||
worker,
|
worker,
|
||||||
@@ -2680,6 +2711,7 @@ impl CandleHarness {
|
|||||||
let tool_call_tokens_inner = loaded.tool_call_tokens.clone();
|
let tool_call_tokens_inner = loaded.tool_call_tokens.clone();
|
||||||
let tool_schemas_inner = tool_schemas.clone();
|
let tool_schemas_inner = tool_schemas.clone();
|
||||||
tokio::task::spawn_blocking(move || {
|
tokio::task::spawn_blocking(move || {
|
||||||
|
let _admit = admit;
|
||||||
let _g = span_for_task.enter();
|
let _g = span_for_task.enter();
|
||||||
// `blocking_lock` is safe here: spawn_blocking runs on
|
// `blocking_lock` is safe here: spawn_blocking runs on
|
||||||
// a dedicated thread, not on the async runtime, so
|
// a dedicated thread, not on the async runtime, so
|
||||||
@@ -3128,6 +3160,7 @@ impl Harness for CandleHarness {
|
|||||||
worker,
|
worker,
|
||||||
arch_handle,
|
arch_handle,
|
||||||
inference_lock: tokio::sync::Mutex::new(()),
|
inference_lock: tokio::sync::Mutex::new(()),
|
||||||
|
admission: super::admission::AdmissionController::new(&self.admission_cfg),
|
||||||
reasoning_tokens,
|
reasoning_tokens,
|
||||||
tool_call_tokens,
|
tool_call_tokens,
|
||||||
chat_template,
|
chat_template,
|
||||||
@@ -3372,6 +3405,7 @@ impl CandleHarness {
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
devices: devices.clone(),
|
devices: devices.clone(),
|
||||||
pool: TMutex::new(pool),
|
pool: TMutex::new(pool),
|
||||||
|
admission: super::admission::AdmissionController::new(&self.admission_cfg),
|
||||||
leader_handle,
|
leader_handle,
|
||||||
leader_device: leader_device.clone(),
|
leader_device: leader_device.clone(),
|
||||||
poisoned: AtomicBool::new(false),
|
poisoned: AtomicBool::new(false),
|
||||||
@@ -3690,10 +3724,15 @@ impl CandleHarness {
|
|||||||
validate_vision_prefill(prompt_len, vram_free_mb)?;
|
validate_vision_prefill(prompt_len, vram_free_mb)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Admission control (#53): refuse before opening the stream; the
|
||||||
|
// permit moves into the orchestration task and is held for its life.
|
||||||
|
let admit = tp.admission.enter().await.map_err(InferenceError::from)?;
|
||||||
|
|
||||||
let tool_schemas = build_tool_schemas(&request);
|
let tool_schemas = build_tool_schemas(&request);
|
||||||
let tp_for_task = Arc::clone(&tp);
|
let tp_for_task = Arc::clone(&tp);
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
async move {
|
async move {
|
||||||
|
let _admit = admit;
|
||||||
let mut failure: Option<String> = None;
|
let mut failure: Option<String> = None;
|
||||||
let mut pool = acquire_pool_lock(&tp_for_task.pool, &model_id).await;
|
let mut pool = acquire_pool_lock(&tp_for_task.pool, &model_id).await;
|
||||||
let leader_handle = tp_for_task.leader_handle;
|
let leader_handle = tp_for_task.leader_handle;
|
||||||
@@ -4284,6 +4323,10 @@ async fn chat_completion_tp_inner(
|
|||||||
validate_vision_prefill(prompt_len, vram_free_mb)?;
|
validate_vision_prefill(prompt_len, vram_free_mb)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Admission control (#53): bounded queue + fast reject before joining
|
||||||
|
// the pool-lock wait. Held for the whole request (released on drop).
|
||||||
|
let _admit = tp.admission.enter().await.map_err(InferenceError::from)?;
|
||||||
|
|
||||||
// Acquire the pool lock for the duration of the request. After
|
// Acquire the pool lock for the duration of the request. After
|
||||||
// Phase 3 the leader's TpLeaderModel lives in the device worker
|
// Phase 3 the leader's TpLeaderModel lives in the device worker
|
||||||
// thread, so the pool lock now serialises only subprocess RPC
|
// thread, so the pool lock now serialises only subprocess RPC
|
||||||
@@ -4826,10 +4869,23 @@ pub enum InferenceError {
|
|||||||
/// failure mode that hid several client-compat bugs. Maps to 422.
|
/// failure mode that hid several client-compat bugs. Maps to 422.
|
||||||
#[error("chat template could not render this request: {detail}")]
|
#[error("chat template could not render this request: {detail}")]
|
||||||
TemplateRenderFailed { detail: String },
|
TemplateRenderFailed { detail: String },
|
||||||
|
/// Admission control (#53) refused the request: the model's bounded
|
||||||
|
/// queue is full or the wait elapsed. Maps to `429 rate_limit_exceeded`
|
||||||
|
/// + `Retry-After` — a fast, retryable "busy" signal, not a stall.
|
||||||
|
#[error("model is busy; retry after {retry_after_secs}s")]
|
||||||
|
Overloaded { retry_after_secs: u64 },
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Other(#[from] anyhow::Error),
|
Other(#[from] anyhow::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<super::admission::AdmissionRejection> for InferenceError {
|
||||||
|
fn from(rejection: super::admission::AdmissionRejection) -> Self {
|
||||||
|
InferenceError::Overloaded {
|
||||||
|
retry_after_secs: rejection.retry_after_secs(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Build the model's prompt from a [`ChatCompletionRequest`].
|
/// Build the model's prompt from a [`ChatCompletionRequest`].
|
||||||
///
|
///
|
||||||
/// Prefers the model's own `chat_template` when one was loaded
|
/// Prefers the model's own `chat_template` when one was loaded
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
//! Harness registry — maps harness names to trait implementations.
|
//! Harness registry — maps harness names to trait implementations.
|
||||||
|
|
||||||
|
pub mod admission;
|
||||||
pub mod arch;
|
pub mod arch;
|
||||||
pub mod candle;
|
pub mod candle;
|
||||||
pub mod chat_template;
|
pub mod chat_template;
|
||||||
|
|||||||
Reference in New Issue
Block a user