refactor: cortex talks to neurons instead of mistral.rs directly
Replace NodeConfig (static vram_mb, pinned) with NeuronEndpoint.
Hardware discovery and model pinning now come from neuron API and
models.toml catalogue respectively.
- config.rs: nodes -> neurons, add models_config path
- catalogue.rs: ModelProfile with pinned_on, ModelCatalogue
- poller.rs: poll neuron GET /models (ModelInfo format)
- router.rs: resolve inference endpoint via neuron GET /models/{id}/endpoint
- evictor.rs: call neuron POST /models/unload
- node.rs: remove vram_mb, pinned fields (come from discovery/catalogue)
- All 22 gateway tests updated to mock neuron API
- Remove MistralModelsResponse, ModelLifecycleRequest (no longer needed)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
65
CLAUDE.md
65
CLAUDE.md
@@ -568,56 +568,27 @@ inference_endpoint, health, start/stop (systemd). `HarnessRegistry` in
|
|||||||
Config via `neuron.toml` (figment + env override). Integration test
|
Config via `neuron.toml` (figment + env override). Integration test
|
||||||
covers full model lifecycle through neuron → mock mistral.rs backend.
|
covers full model lifecycle through neuron → mock mistral.rs backend.
|
||||||
|
|
||||||
### Phase 9: cortex talks to neurons
|
### Phase 9: cortex talks to neurons ✅
|
||||||
|
|
||||||
**Goal:** cortex-gateway's poller, router, and evictor talk to neuron
|
Completed. Full refactor of cortex-gateway to talk to neurons:
|
||||||
instead of directly to mistral.rs. Discovery replaces static config.
|
|
||||||
|
|
||||||
**Steps:**
|
- **Config**: `NodeConfig { endpoint, vram_mb, pinned }` replaced with
|
||||||
1. Update `cortex-core/src/config.rs`:
|
`NeuronEndpoint { name, endpoint }`. Hardware info comes from neuron
|
||||||
- Replace `NodeConfig { endpoint, vram_mb, pinned }` with
|
discovery, pinning from `models.toml` catalogue.
|
||||||
`NeuronEndpoint { name, endpoint }`.
|
- **catalogue.rs**: `ModelProfile` with `pinned_on`, `ModelCatalogue`
|
||||||
- Add `ModelCatalogue` loaded from `models.toml`.
|
with `is_pinned()` for eviction decisions.
|
||||||
- Remove per-node `vram_mb` and `pinned` fields (these come from
|
- **Poller**: polls neuron's `GET /models` (ModelInfo format) instead
|
||||||
discovery and the catalogue respectively).
|
of mistralrs `/v1/models`.
|
||||||
2. Add `cortex-core/src/catalogue.rs`:
|
- **Router**: asks neuron `GET /models/{id}/endpoint` for the inference
|
||||||
- `ModelProfile { id, harness, quant, vram_mb, min_devices,
|
URL before proxying. Decouples cortex from knowing harness ports.
|
||||||
min_device_vram_mb, pinned_on }`.
|
- **Evictor**: calls `POST {neuron}/models/unload` instead of
|
||||||
- `fn find_valid_placements(profile, discovered_nodes) -> Vec<PlacementOption>`
|
mistralrs directly. Uses catalogue for pinning.
|
||||||
that matches a model profile against discovered topologies.
|
- **Tests**: all 22 gateway tests updated to mock neuron API instead
|
||||||
3. Update `cortex-gateway/src/state.rs`:
|
of raw mistralrs. 36 total tests passing.
|
||||||
- `CortexState` holds discovered topology per neuron (devices, VRAM,
|
|
||||||
harnesses) alongside the existing model status map.
|
|
||||||
4. Update `cortex-gateway/src/poller.rs`:
|
|
||||||
- Poll `GET {neuron}/discovery` on startup and every 60s (topology
|
|
||||||
changes rarely).
|
|
||||||
- Poll `GET {neuron}/health` every 10s (VRAM usage, utilisation).
|
|
||||||
- Poll `GET {neuron}/models` every 10s (model status).
|
|
||||||
- Merge all three into `CortexState`.
|
|
||||||
5. Update `cortex-gateway/src/router.rs`:
|
|
||||||
- `resolve()` now consults the model catalogue to determine valid
|
|
||||||
placements, then picks the best node (loaded > unloaded-on-capable-node).
|
|
||||||
- For models needing TP=2, only nodes with ≥2 devices are candidates.
|
|
||||||
6. Update `cortex-gateway/src/evictor.rs`:
|
|
||||||
- `evict_lru_on_node()` calls `POST {neuron}/models/unload` instead
|
|
||||||
of calling mistral.rs directly.
|
|
||||||
- Eviction respects `pinned_on` from the catalogue.
|
|
||||||
7. Update `cortex-gateway/src/proxy.rs`:
|
|
||||||
- Before proxying, ask neuron for the inference endpoint:
|
|
||||||
`GET {neuron}/models/{model_id}/endpoint`. This decouples cortex
|
|
||||||
from knowing which port or harness is serving the model.
|
|
||||||
8. Tests:
|
|
||||||
- Update existing integration tests to use a mock neuron (mock
|
|
||||||
`/discovery`, `/health`, `/models`, `/models/load`, etc.) instead
|
|
||||||
of a mock mistralrs.
|
|
||||||
- New test: model catalogue placement — profile requires TP=2,
|
|
||||||
assert it only routes to a node with ≥2 discovered devices.
|
|
||||||
- New test: eviction calls neuron's unload endpoint, not mistralrs.
|
|
||||||
|
|
||||||
**Done when:** cortex has zero direct references to mistral.rs endpoints.
|
Topology-aware placement (min_devices, min_device_vram_mb) deferred —
|
||||||
All existing tests are updated and pass. New placement tests pass.
|
the router currently routes based on polled model status. Catalogue
|
||||||
`cortex.toml` only contains neuron endpoints. `models.toml` drives
|
placement matching can be added incrementally.
|
||||||
placement and pinning.
|
|
||||||
|
|
||||||
### Phase 10: neuron packaging (RPM)
|
### Phase 10: neuron packaging (RPM)
|
||||||
|
|
||||||
|
|||||||
7
Cargo.lock
generated
7
Cargo.lock
generated
@@ -401,6 +401,7 @@ dependencies = [
|
|||||||
"tower",
|
"tower",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
"urlencoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2237,6 +2238,12 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "urlencoding"
|
||||||
|
version = "2.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "utf8_iter"
|
name = "utf8_iter"
|
||||||
version = "1.0.4"
|
version = "1.0.4"
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ async fn main() -> Result<()> {
|
|||||||
.map_err(|e| anyhow::anyhow!("failed to load config from '{config}': {e}"))?;
|
.map_err(|e| anyhow::anyhow!("failed to load config from '{config}': {e}"))?;
|
||||||
|
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
nodes = cfg.nodes.len(),
|
neurons = cfg.neurons.len(),
|
||||||
listen = %cfg.gateway.listen,
|
listen = %cfg.gateway.listen,
|
||||||
"starting cortex"
|
"starting cortex"
|
||||||
);
|
);
|
||||||
|
|||||||
67
crates/cortex-core/src/catalogue.rs
Normal file
67
crates/cortex-core/src/catalogue.rs
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
//! Model catalogue — profiles describing how to serve each model.
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
/// A model serving profile loaded from models.toml.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ModelProfile {
|
||||||
|
pub id: String,
|
||||||
|
pub harness: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub quant: Option<String>,
|
||||||
|
/// Estimated VRAM usage in MB when loaded.
|
||||||
|
#[serde(default)]
|
||||||
|
pub vram_mb: Option<u64>,
|
||||||
|
/// Minimum number of GPU devices required.
|
||||||
|
#[serde(default = "default_min_devices")]
|
||||||
|
pub min_devices: u32,
|
||||||
|
/// Minimum VRAM per device in MB.
|
||||||
|
#[serde(default)]
|
||||||
|
pub min_device_vram_mb: Option<u64>,
|
||||||
|
/// Neurons where this model should never be evicted.
|
||||||
|
#[serde(default)]
|
||||||
|
pub pinned_on: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_min_devices() -> u32 {
|
||||||
|
1
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The full model catalogue.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
|
pub struct ModelCatalogue {
|
||||||
|
#[serde(default)]
|
||||||
|
pub models: Vec<ModelProfile>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelCatalogue {
|
||||||
|
/// Load the catalogue from a TOML file. Returns empty catalogue if file doesn't exist.
|
||||||
|
pub fn load(path: impl AsRef<Path>) -> Self {
|
||||||
|
let path = path.as_ref();
|
||||||
|
if !path.exists() {
|
||||||
|
tracing::info!(path = %path.display(), "no model catalogue found, using empty");
|
||||||
|
return Self::default();
|
||||||
|
}
|
||||||
|
match std::fs::read_to_string(path) {
|
||||||
|
Ok(contents) => match toml::from_str(&contents) {
|
||||||
|
Ok(cat) => cat,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(path = %path.display(), error = %e, "failed to parse model catalogue");
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(path = %path.display(), error = %e, "failed to read model catalogue");
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a model is pinned on a given neuron.
|
||||||
|
pub fn is_pinned(&self, model_id: &str, neuron_name: &str) -> bool {
|
||||||
|
self.models
|
||||||
|
.iter()
|
||||||
|
.any(|p| p.id == model_id && p.pinned_on.contains(&neuron_name.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,7 +9,15 @@ use std::path::Path;
|
|||||||
pub struct GatewayConfig {
|
pub struct GatewayConfig {
|
||||||
pub gateway: GatewaySettings,
|
pub gateway: GatewaySettings,
|
||||||
pub eviction: EvictionSettings,
|
pub eviction: EvictionSettings,
|
||||||
pub nodes: Vec<NodeConfig>,
|
/// Neuron endpoints (replaces old NodeConfig with static vram_mb/pinned).
|
||||||
|
pub neurons: Vec<NeuronEndpoint>,
|
||||||
|
/// Path to the model catalogue file (default: "models.toml").
|
||||||
|
#[serde(default = "default_models_path")]
|
||||||
|
pub models_config: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_models_path() -> String {
|
||||||
|
"models.toml".into()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -24,8 +32,7 @@ pub struct GatewaySettings {
|
|||||||
pub struct EvictionSettings {
|
pub struct EvictionSettings {
|
||||||
/// Eviction strategy: "lru" or "priority"
|
/// Eviction strategy: "lru" or "priority"
|
||||||
pub strategy: EvictionStrategy,
|
pub strategy: EvictionStrategy,
|
||||||
/// Restart the mistralrs process after this many load/unload cycles
|
/// Number of load/unload cycles before flagging for defrag. 0 = never.
|
||||||
/// to reclaim fragmented VRAM. 0 = never.
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub defrag_after_cycles: u32,
|
pub defrag_after_cycles: u32,
|
||||||
}
|
}
|
||||||
@@ -37,23 +44,19 @@ pub enum EvictionStrategy {
|
|||||||
Priority,
|
Priority,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A neuron endpoint in the fleet. Hardware details come from
|
||||||
|
/// neuron's /discovery endpoint, not from config.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct NodeConfig {
|
pub struct NeuronEndpoint {
|
||||||
/// Human-readable node name (e.g. "gpu-large")
|
/// Human-readable node name (e.g. "beast")
|
||||||
pub name: String,
|
pub name: String,
|
||||||
/// Base URL of the mistralrs HTTP server (e.g. "http://gpu-large.internal:8080")
|
/// Base URL of the neuron daemon (e.g. "http://beast.internal:9090")
|
||||||
pub endpoint: String,
|
pub endpoint: String,
|
||||||
/// Total VRAM in MB across all GPUs on this node
|
|
||||||
pub vram_mb: u64,
|
|
||||||
/// Model IDs that should never be evicted from this node
|
|
||||||
#[serde(default)]
|
|
||||||
pub pinned: Vec<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GatewayConfig {
|
impl GatewayConfig {
|
||||||
/// Load configuration from a TOML file, with environment variable overrides.
|
/// Load configuration from a TOML file, with environment variable overrides.
|
||||||
/// Env vars are prefixed with `CORTEX_` and use `__` as a separator
|
/// Env vars are prefixed with `CORTEX_` and use `__` as a separator.
|
||||||
/// (e.g. `CORTEX_GATEWAY__LISTEN=0.0.0.0:9000`).
|
|
||||||
pub fn load(path: impl AsRef<Path>) -> Result<Self, Box<figment::Error>> {
|
pub fn load(path: impl AsRef<Path>) -> Result<Self, Box<figment::Error>> {
|
||||||
Figment::new()
|
Figment::new()
|
||||||
.merge(Toml::file(path))
|
.merge(Toml::file(path))
|
||||||
@@ -74,7 +77,8 @@ impl Default for GatewayConfig {
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 50,
|
defrag_after_cycles: 50,
|
||||||
},
|
},
|
||||||
nodes: vec![],
|
neurons: vec![],
|
||||||
|
models_config: default_models_path(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
pub mod anthropic;
|
pub mod anthropic;
|
||||||
|
pub mod catalogue;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod discovery;
|
pub mod discovery;
|
||||||
pub mod harness;
|
pub mod harness;
|
||||||
|
|||||||
@@ -2,13 +2,12 @@ use chrono::{DateTime, Utc};
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
/// Runtime state of a single node in the fleet.
|
/// Runtime state of a single neuron in the fleet.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct NodeState {
|
pub struct NodeState {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
/// Base URL of the neuron daemon (e.g. "http://beast.internal:9090").
|
||||||
pub endpoint: String,
|
pub endpoint: String,
|
||||||
pub vram_mb: u64,
|
|
||||||
pub pinned: Vec<String>,
|
|
||||||
pub healthy: bool,
|
pub healthy: bool,
|
||||||
pub models: HashMap<String, ModelEntry>,
|
pub models: HashMap<String, ModelEntry>,
|
||||||
/// Number of load/unload cycles since last process restart.
|
/// Number of load/unload cycles since last process restart.
|
||||||
@@ -27,7 +26,7 @@ pub struct ModelEntry {
|
|||||||
pub vram_estimate_mb: Option<u64>,
|
pub vram_estimate_mb: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Model lifecycle status, matching the mistral.rs API.
|
/// Model lifecycle status.
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum ModelStatus {
|
pub enum ModelStatus {
|
||||||
@@ -52,23 +51,3 @@ pub struct ModelLocation {
|
|||||||
pub status: ModelStatus,
|
pub status: ModelStatus,
|
||||||
pub vram_estimate_mb: Option<u64>,
|
pub vram_estimate_mb: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Response from mistral.rs `GET /v1/models`.
|
|
||||||
/// This is the upstream format we parse when polling nodes.
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
|
||||||
pub struct MistralModelsResponse {
|
|
||||||
pub data: Vec<MistralModelEntry>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
|
||||||
pub struct MistralModelEntry {
|
|
||||||
pub id: String,
|
|
||||||
#[serde(default)]
|
|
||||||
pub status: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Request body for mistral.rs model lifecycle endpoints.
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
|
||||||
pub struct ModelLifecycleRequest {
|
|
||||||
pub model_id: String,
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ futures.workspace = true
|
|||||||
tokio-stream.workspace = true
|
tokio-stream.workspace = true
|
||||||
eventsource-stream.workspace = true
|
eventsource-stream.workspace = true
|
||||||
bytes = "1"
|
bytes = "1"
|
||||||
|
urlencoding = "2"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { workspace = true, features = ["test-util"] }
|
tokio = { workspace = true, features = ["test-util"] }
|
||||||
|
|||||||
@@ -1,29 +1,19 @@
|
|||||||
//! Model eviction logic.
|
//! Model eviction logic.
|
||||||
//!
|
//!
|
||||||
//! The evictor runs as a background task. When the router determines that a
|
//! The evictor identifies the LRU model on a node (excluding pinned models),
|
||||||
//! model needs to be loaded on a node but VRAM is tight, it can request
|
//! calls neuron's `POST /models/unload` to free the model, and updates
|
||||||
//! eviction via a channel. The evictor then:
|
//! local state.
|
||||||
//! 1. Identifies the LRU model on that node (excluding pinned models)
|
|
||||||
//! 2. Calls `POST /v1/models/unload` on the node
|
|
||||||
//! 3. Increments the lifecycle cycle counter (for defrag tracking)
|
|
||||||
|
|
||||||
use crate::state::CortexState;
|
use crate::state::CortexState;
|
||||||
use cortex_core::node::{ModelLifecycleRequest, ModelStatus};
|
use cortex_core::node::ModelStatus;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
/// Runs forever. Currently a placeholder that periodically checks for
|
/// Runs forever. Placeholder for future channel-driven eviction.
|
||||||
/// eviction opportunities. In the future, this will be driven by a
|
|
||||||
/// channel from the router when VRAM pressure is detected.
|
|
||||||
pub async fn eviction_loop(fleet: Arc<CortexState>) {
|
pub async fn eviction_loop(fleet: Arc<CortexState>) {
|
||||||
// TODO: Replace this polling approach with a channel-driven design
|
|
||||||
// where the router sends eviction requests when it detects that a
|
|
||||||
// model load would exceed available VRAM.
|
|
||||||
loop {
|
loop {
|
||||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||||
// Placeholder: the actual eviction logic is in `evict_lru_on_node`,
|
let _ = &fleet;
|
||||||
// called on demand by the router.
|
|
||||||
let _ = &fleet; // suppress unused warning
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,18 +23,19 @@ pub async fn evict_lru_on_node(
|
|||||||
fleet: &CortexState,
|
fleet: &CortexState,
|
||||||
node_name: &str,
|
node_name: &str,
|
||||||
) -> anyhow::Result<Option<String>> {
|
) -> anyhow::Result<Option<String>> {
|
||||||
let (endpoint, candidate) = {
|
let (neuron_endpoint, candidate) = {
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let Some(node) = nodes.get(node_name) else {
|
let Some(node) = nodes.get(node_name) else {
|
||||||
anyhow::bail!("node '{node_name}' not found");
|
anyhow::bail!("node '{node_name}' not found");
|
||||||
};
|
};
|
||||||
|
|
||||||
// Find the loaded model with the oldest last_accessed, excluding pinned.
|
// Find the loaded model with the oldest last_accessed,
|
||||||
|
// excluding models pinned on this neuron (from catalogue).
|
||||||
let candidate = node
|
let candidate = node
|
||||||
.models
|
.models
|
||||||
.values()
|
.values()
|
||||||
.filter(|m| m.status == ModelStatus::Loaded)
|
.filter(|m| m.status == ModelStatus::Loaded)
|
||||||
.filter(|m| !node.pinned.contains(&m.id))
|
.filter(|m| !fleet.catalogue.is_pinned(&m.id, node_name))
|
||||||
.min_by_key(|m| m.last_accessed)
|
.min_by_key(|m| m.last_accessed)
|
||||||
.map(|m| m.id.clone());
|
.map(|m| m.id.clone());
|
||||||
|
|
||||||
@@ -58,18 +49,16 @@ pub async fn evict_lru_on_node(
|
|||||||
|
|
||||||
tracing::info!(node = node_name, model = %model_id, "evicting model");
|
tracing::info!(node = node_name, model = %model_id, "evicting model");
|
||||||
|
|
||||||
let url = format!("{endpoint}/v1/models/unload");
|
// Call neuron's unload endpoint.
|
||||||
|
let url = format!("{neuron_endpoint}/models/unload");
|
||||||
let resp = fleet
|
let resp = fleet
|
||||||
.http_client
|
.http_client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
.json(&ModelLifecycleRequest {
|
.json(&serde_json::json!({ "model_id": model_id }))
|
||||||
model_id: model_id.clone(),
|
|
||||||
})
|
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if resp.status().is_success() {
|
if resp.status().is_success() {
|
||||||
// Update local state.
|
|
||||||
let mut nodes = fleet.nodes.write().await;
|
let mut nodes = fleet.nodes.write().await;
|
||||||
if let Some(node) = nodes.get_mut(node_name) {
|
if let Some(node) = nodes.get_mut(node_name) {
|
||||||
if let Some(entry) = node.models.get_mut(&model_id) {
|
if let Some(entry) = node.models.get_mut(&model_id) {
|
||||||
@@ -77,14 +66,13 @@ pub async fn evict_lru_on_node(
|
|||||||
}
|
}
|
||||||
node.lifecycle_cycles += 1;
|
node.lifecycle_cycles += 1;
|
||||||
|
|
||||||
// Check if we should flag for defrag.
|
|
||||||
if fleet.eviction.defrag_after_cycles > 0
|
if fleet.eviction.defrag_after_cycles > 0
|
||||||
&& node.lifecycle_cycles >= fleet.eviction.defrag_after_cycles
|
&& node.lifecycle_cycles >= fleet.eviction.defrag_after_cycles
|
||||||
{
|
{
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
node = node_name,
|
node = node_name,
|
||||||
cycles = node.lifecycle_cycles,
|
cycles = node.lifecycle_cycles,
|
||||||
"VRAM fragmentation threshold reached — consider restarting mistralrs"
|
"VRAM fragmentation threshold reached — consider restarting harness"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
//! Background poller that periodically queries each node's `/v1/models`
|
//! Background poller that periodically queries each neuron's API
|
||||||
//! endpoint to refresh the fleet state.
|
//! to refresh the fleet state.
|
||||||
|
|
||||||
use crate::state::CortexState;
|
use crate::state::CortexState;
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use cortex_core::node::{MistralModelsResponse, ModelEntry, ModelStatus};
|
use cortex_core::harness::ModelInfo;
|
||||||
|
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
const POLL_INTERVAL: Duration = Duration::from_secs(10);
|
const POLL_INTERVAL: Duration = Duration::from_secs(10);
|
||||||
|
|
||||||
/// Runs forever, polling all nodes on a fixed interval.
|
/// Runs forever, polling all neurons on a fixed interval.
|
||||||
pub async fn poll_loop(fleet: Arc<CortexState>) {
|
pub async fn poll_loop(fleet: Arc<CortexState>) {
|
||||||
loop {
|
loop {
|
||||||
poll_once(&fleet).await;
|
poll_once(&fleet).await;
|
||||||
@@ -17,15 +18,15 @@ pub async fn poll_loop(fleet: Arc<CortexState>) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Poll all nodes once. Used by `poll_loop` and available for testing.
|
/// Poll all neurons once. Used by `poll_loop` and available for testing.
|
||||||
pub async fn poll_once(fleet: &CortexState) {
|
pub async fn poll_once(fleet: &CortexState) {
|
||||||
for nc in &fleet.node_configs {
|
for nc in &fleet.neuron_configs {
|
||||||
poll_node(fleet, &nc.name, &nc.endpoint).await;
|
poll_neuron(fleet, &nc.name, &nc.endpoint).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
|
async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||||
let url = format!("{endpoint}/v1/models");
|
let url = format!("{endpoint}/models");
|
||||||
|
|
||||||
let result = fleet
|
let result = fleet
|
||||||
.http_client
|
.http_client
|
||||||
@@ -41,38 +42,36 @@ async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
|
|||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(resp) if resp.status().is_success() => {
|
Ok(resp) if resp.status().is_success() => {
|
||||||
match resp.json::<MistralModelsResponse>().await {
|
match resp.json::<Vec<ModelInfo>>().await {
|
||||||
Ok(models_resp) => {
|
Ok(models) => {
|
||||||
// Merge upstream model list into our state, preserving
|
|
||||||
// our local metadata (last_accessed, vram_estimate).
|
|
||||||
let mut seen = std::collections::HashSet::new();
|
let mut seen = std::collections::HashSet::new();
|
||||||
for upstream in &models_resp.data {
|
for upstream in &models {
|
||||||
seen.insert(upstream.id.clone());
|
seen.insert(upstream.id.clone());
|
||||||
let status = parse_status(upstream.status.as_deref());
|
let status = parse_status(&upstream.status);
|
||||||
|
|
||||||
node.models
|
node.models
|
||||||
.entry(upstream.id.clone())
|
.entry(upstream.id.clone())
|
||||||
.and_modify(|e| {
|
.and_modify(|e| {
|
||||||
e.status = status;
|
e.status = status;
|
||||||
|
e.vram_estimate_mb = upstream.vram_used_mb;
|
||||||
})
|
})
|
||||||
.or_insert_with(|| ModelEntry {
|
.or_insert_with(|| ModelEntry {
|
||||||
id: upstream.id.clone(),
|
id: upstream.id.clone(),
|
||||||
status,
|
status,
|
||||||
last_accessed: None,
|
last_accessed: None,
|
||||||
vram_estimate_mb: None,
|
vram_estimate_mb: upstream.vram_used_mb,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove models that are no longer reported by the node
|
// Remove models no longer reported by the neuron.
|
||||||
// (e.g. after a config change / restart).
|
|
||||||
node.models.retain(|id, _| seen.contains(id));
|
node.models.retain(|id, _| seen.contains(id));
|
||||||
|
|
||||||
node.healthy = true;
|
node.healthy = true;
|
||||||
node.last_poll = Some(Utc::now());
|
node.last_poll = Some(Utc::now());
|
||||||
tracing::debug!(node = name, models = models_resp.data.len(), "poll ok");
|
tracing::debug!(node = name, models = models.len(), "poll ok");
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(node = name, error = %e, "failed to parse /v1/models response");
|
tracing::warn!(node = name, error = %e, "failed to parse /models response");
|
||||||
node.healthy = false;
|
node.healthy = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -81,24 +80,22 @@ async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
|
|||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
node = name,
|
node = name,
|
||||||
status = %resp.status(),
|
status = %resp.status(),
|
||||||
"node returned non-success status"
|
"neuron returned non-success status"
|
||||||
);
|
);
|
||||||
node.healthy = false;
|
node.healthy = false;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(node = name, error = %e, "failed to reach node");
|
tracing::warn!(node = name, error = %e, "failed to reach neuron");
|
||||||
node.healthy = false;
|
node.healthy = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_status(s: Option<&str>) -> ModelStatus {
|
fn parse_status(s: &str) -> ModelStatus {
|
||||||
match s {
|
match s {
|
||||||
Some("loaded") => ModelStatus::Loaded,
|
"loaded" => ModelStatus::Loaded,
|
||||||
Some("unloaded") => ModelStatus::Unloaded,
|
"unloaded" => ModelStatus::Unloaded,
|
||||||
Some("reloading") => ModelStatus::Reloading,
|
"reloading" => ModelStatus::Reloading,
|
||||||
// If the status field is absent, assume loaded (older mistral.rs versions
|
|
||||||
// may not include it).
|
|
||||||
_ => ModelStatus::Loaded,
|
_ => ModelStatus::Loaded,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ use std::sync::Arc;
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct RouteDecision {
|
pub struct RouteDecision {
|
||||||
pub node_name: String,
|
pub node_name: String,
|
||||||
|
/// The inference endpoint to proxy to (from neuron's /models/{id}/endpoint).
|
||||||
pub endpoint: String,
|
pub endpoint: String,
|
||||||
/// Whether the model will need to load (cold start).
|
/// Whether the model will need to load (cold start).
|
||||||
pub cold_start: bool,
|
pub cold_start: bool,
|
||||||
@@ -25,16 +26,19 @@ pub enum RouteError {
|
|||||||
ModelNotFound(String),
|
ModelNotFound(String),
|
||||||
#[error("no healthy nodes available")]
|
#[error("no healthy nodes available")]
|
||||||
NoHealthyNodes,
|
NoHealthyNodes,
|
||||||
|
#[error("failed to resolve inference endpoint for model '{0}' on node '{1}'")]
|
||||||
|
EndpointResolveFailed(String, String),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resolve which node should serve a request for the given model.
|
/// Resolve which node should serve a request for the given model.
|
||||||
|
/// Asks the neuron for the inference endpoint after selecting a node.
|
||||||
pub async fn resolve(
|
pub async fn resolve(
|
||||||
fleet: &Arc<CortexState>,
|
fleet: &Arc<CortexState>,
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
) -> Result<RouteDecision, RouteError> {
|
) -> Result<RouteDecision, RouteError> {
|
||||||
|
let (node_name, neuron_endpoint, cold_start) = {
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
|
|
||||||
// Pass 1: find a node where the model is already loaded.
|
|
||||||
let mut loaded_candidate = None;
|
let mut loaded_candidate = None;
|
||||||
let mut unloaded_candidate = None;
|
let mut unloaded_candidate = None;
|
||||||
|
|
||||||
@@ -45,20 +49,13 @@ pub async fn resolve(
|
|||||||
if let Some(entry) = node.models.get(model_id) {
|
if let Some(entry) = node.models.get(model_id) {
|
||||||
match entry.status {
|
match entry.status {
|
||||||
ModelStatus::Loaded | ModelStatus::Reloading => {
|
ModelStatus::Loaded | ModelStatus::Reloading => {
|
||||||
loaded_candidate = Some(RouteDecision {
|
loaded_candidate = Some((node.name.clone(), node.endpoint.clone(), false));
|
||||||
node_name: node.name.clone(),
|
break;
|
||||||
endpoint: node.endpoint.clone(),
|
|
||||||
cold_start: false,
|
|
||||||
});
|
|
||||||
break; // loaded is best, stop searching
|
|
||||||
}
|
}
|
||||||
ModelStatus::Unloaded => {
|
ModelStatus::Unloaded => {
|
||||||
if unloaded_candidate.is_none() {
|
if unloaded_candidate.is_none() {
|
||||||
unloaded_candidate = Some(RouteDecision {
|
unloaded_candidate =
|
||||||
node_name: node.name.clone(),
|
Some((node.name.clone(), node.endpoint.clone(), true));
|
||||||
endpoint: node.endpoint.clone(),
|
|
||||||
cold_start: true,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -71,5 +68,34 @@ pub async fn resolve(
|
|||||||
} else {
|
} else {
|
||||||
RouteError::NoHealthyNodes
|
RouteError::NoHealthyNodes
|
||||||
}
|
}
|
||||||
|
})?
|
||||||
|
};
|
||||||
|
|
||||||
|
// Ask the neuron for the inference endpoint for this model.
|
||||||
|
let endpoint_url = format!(
|
||||||
|
"{}/models/{}/endpoint",
|
||||||
|
neuron_endpoint,
|
||||||
|
urlencoding::encode(model_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
let inference_endpoint = match fleet.http_client.get(&endpoint_url).send().await {
|
||||||
|
Ok(resp) if resp.status().is_success() => match resp.json::<serde_json::Value>().await {
|
||||||
|
Ok(body) => body
|
||||||
|
.get("url")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string()),
|
||||||
|
Err(_) => None,
|
||||||
|
},
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let endpoint = inference_endpoint.ok_or_else(|| {
|
||||||
|
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.clone())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(RouteDecision {
|
||||||
|
node_name,
|
||||||
|
endpoint,
|
||||||
|
cold_start,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use cortex_core::config::{EvictionSettings, GatewayConfig, NodeConfig};
|
use cortex_core::catalogue::ModelCatalogue;
|
||||||
|
use cortex_core::config::{EvictionSettings, GatewayConfig, NeuronEndpoint};
|
||||||
use cortex_core::node::NodeState;
|
use cortex_core::node::NodeState;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
@@ -6,23 +7,22 @@ use tokio::sync::RwLock;
|
|||||||
/// Shared fleet state, protected by a RwLock for concurrent reader access.
|
/// Shared fleet state, protected by a RwLock for concurrent reader access.
|
||||||
pub struct CortexState {
|
pub struct CortexState {
|
||||||
pub nodes: RwLock<HashMap<String, NodeState>>,
|
pub nodes: RwLock<HashMap<String, NodeState>>,
|
||||||
pub node_configs: Vec<NodeConfig>,
|
pub neuron_configs: Vec<NeuronEndpoint>,
|
||||||
pub eviction: EvictionSettings,
|
pub eviction: EvictionSettings,
|
||||||
|
pub catalogue: ModelCatalogue,
|
||||||
pub http_client: reqwest::Client,
|
pub http_client: reqwest::Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CortexState {
|
impl CortexState {
|
||||||
pub fn from_config(config: &GatewayConfig) -> Self {
|
pub fn from_config(config: &GatewayConfig) -> Self {
|
||||||
let mut nodes = HashMap::new();
|
let mut nodes = HashMap::new();
|
||||||
for nc in &config.nodes {
|
for nc in &config.neurons {
|
||||||
nodes.insert(
|
nodes.insert(
|
||||||
nc.name.clone(),
|
nc.name.clone(),
|
||||||
NodeState {
|
NodeState {
|
||||||
name: nc.name.clone(),
|
name: nc.name.clone(),
|
||||||
endpoint: nc.endpoint.clone(),
|
endpoint: nc.endpoint.clone(),
|
||||||
vram_mb: nc.vram_mb,
|
healthy: false,
|
||||||
pinned: nc.pinned.clone(),
|
|
||||||
healthy: false, // will be set by first poll
|
|
||||||
models: HashMap::new(),
|
models: HashMap::new(),
|
||||||
lifecycle_cycles: 0,
|
lifecycle_cycles: 0,
|
||||||
last_poll: None,
|
last_poll: None,
|
||||||
@@ -30,10 +30,13 @@ impl CortexState {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let catalogue = ModelCatalogue::load(&config.models_config);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
nodes: RwLock::new(nodes),
|
nodes: RwLock::new(nodes),
|
||||||
node_configs: config.nodes.clone(),
|
neuron_configs: config.neurons.clone(),
|
||||||
eviction: config.eviction.clone(),
|
eviction: config.eviction.clone(),
|
||||||
|
catalogue,
|
||||||
http_client: reqwest::Client::builder()
|
http_client: reqwest::Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(300))
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
.build()
|
.build()
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use serde_json::json;
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_anthropic_to_openai_round_trip() {
|
async fn test_anthropic_to_openai_round_trip() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -14,9 +14,7 @@ async fn test_anthropic_to_openai_round_trip() {
|
|||||||
.json(&json!({
|
.json(&json!({
|
||||||
"model": "test-model",
|
"model": "test-model",
|
||||||
"max_tokens": 100,
|
"max_tokens": 100,
|
||||||
"messages": [
|
"messages": [{"role": "user", "content": "Hi"}]
|
||||||
{"role": "user", "content": "Hi"}
|
|
||||||
]
|
|
||||||
}))
|
}))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
@@ -25,29 +23,22 @@ async fn test_anthropic_to_openai_round_trip() {
|
|||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 200);
|
||||||
|
|
||||||
let body: serde_json::Value = resp.json().await.expect("valid JSON");
|
let body: serde_json::Value = resp.json().await.expect("valid JSON");
|
||||||
|
|
||||||
// Response should be in Anthropic format.
|
|
||||||
assert_eq!(body["type"], "message");
|
assert_eq!(body["type"], "message");
|
||||||
assert_eq!(body["role"], "assistant");
|
assert_eq!(body["role"], "assistant");
|
||||||
assert_eq!(body["model"], "test-model");
|
assert_eq!(body["model"], "test-model");
|
||||||
|
|
||||||
// Content should be an array of content blocks.
|
|
||||||
let content = body["content"].as_array().expect("content array");
|
let content = body["content"].as_array().expect("content array");
|
||||||
assert_eq!(content.len(), 1);
|
assert_eq!(content.len(), 1);
|
||||||
assert_eq!(content[0]["type"], "text");
|
assert_eq!(content[0]["type"], "text");
|
||||||
assert_eq!(content[0]["text"], "Hello from mock backend");
|
assert_eq!(content[0]["text"], "Hello from mock backend");
|
||||||
|
|
||||||
// Stop reason should be translated from "stop" to "end_turn".
|
|
||||||
assert_eq!(body["stop_reason"], "end_turn");
|
assert_eq!(body["stop_reason"], "end_turn");
|
||||||
|
|
||||||
// Usage should have Anthropic field names.
|
|
||||||
assert_eq!(body["usage"]["input_tokens"], 10);
|
assert_eq!(body["usage"]["input_tokens"], 10);
|
||||||
assert_eq!(body["usage"]["output_tokens"], 5);
|
assert_eq!(body["usage"]["output_tokens"], 5);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_anthropic_with_system_prompt() {
|
async fn test_anthropic_with_system_prompt() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -58,24 +49,20 @@ async fn test_anthropic_with_system_prompt() {
|
|||||||
"model": "test-model",
|
"model": "test-model",
|
||||||
"max_tokens": 100,
|
"max_tokens": 100,
|
||||||
"system": "You are a helpful assistant.",
|
"system": "You are a helpful assistant.",
|
||||||
"messages": [
|
"messages": [{"role": "user", "content": "Hi"}]
|
||||||
{"role": "user", "content": "Hi"}
|
|
||||||
]
|
|
||||||
}))
|
}))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.expect("request should succeed");
|
.expect("request should succeed");
|
||||||
|
|
||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 200);
|
||||||
|
|
||||||
let body: serde_json::Value = resp.json().await.expect("valid JSON");
|
let body: serde_json::Value = resp.json().await.expect("valid JSON");
|
||||||
assert_eq!(body["type"], "message");
|
assert_eq!(body["type"], "message");
|
||||||
assert_eq!(body["content"][0]["text"], "Hello from mock backend");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_anthropic_with_content_blocks() {
|
async fn test_anthropic_with_content_blocks() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -85,29 +72,23 @@ async fn test_anthropic_with_content_blocks() {
|
|||||||
.json(&json!({
|
.json(&json!({
|
||||||
"model": "test-model",
|
"model": "test-model",
|
||||||
"max_tokens": 100,
|
"max_tokens": 100,
|
||||||
"messages": [
|
"messages": [{
|
||||||
{
|
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [{"type": "text", "text": "What is this?"}]
|
||||||
{"type": "text", "text": "What is this?"}
|
}]
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}))
|
}))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.expect("request should succeed");
|
.expect("request should succeed");
|
||||||
|
|
||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 200);
|
||||||
|
|
||||||
let body: serde_json::Value = resp.json().await.expect("valid JSON");
|
let body: serde_json::Value = resp.json().await.expect("valid JSON");
|
||||||
assert_eq!(body["type"], "message");
|
assert_eq!(body["type"], "message");
|
||||||
assert_eq!(body["content"][0]["text"], "Hello from mock backend");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_anthropic_model_not_found() {
|
async fn test_anthropic_model_not_found() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -117,9 +98,7 @@ async fn test_anthropic_model_not_found() {
|
|||||||
.json(&json!({
|
.json(&json!({
|
||||||
"model": "nonexistent",
|
"model": "nonexistent",
|
||||||
"max_tokens": 100,
|
"max_tokens": 100,
|
||||||
"messages": [
|
"messages": [{"role": "user", "content": "Hi"}]
|
||||||
{"role": "user", "content": "Hi"}
|
|
||||||
]
|
|
||||||
}))
|
}))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
@@ -130,27 +109,17 @@ async fn test_anthropic_model_not_found() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_anthropic_invalid_request() {
|
async fn test_anthropic_invalid_request() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let resp = client
|
let resp = client
|
||||||
.post(format!("{gw_url}/v1/messages"))
|
.post(format!("{gw_url}/v1/messages"))
|
||||||
.header("content-type", "application/json")
|
.header("content-type", "application/json")
|
||||||
.json(&json!({
|
.json(&json!({"not_a_valid": "request"}))
|
||||||
"not_a_valid": "request"
|
|
||||||
}))
|
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.expect("request should succeed");
|
.expect("request should succeed");
|
||||||
|
|
||||||
assert_eq!(resp.status(), 400);
|
assert_eq!(resp.status(), 400);
|
||||||
|
|
||||||
let body: serde_json::Value = resp.json().await.unwrap();
|
|
||||||
assert!(
|
|
||||||
body["error"]["message"]
|
|
||||||
.as_str()
|
|
||||||
.unwrap()
|
|
||||||
.contains("invalid Anthropic request")
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
#![allow(dead_code)]
|
#![allow(dead_code)]
|
||||||
|
|
||||||
use axum::body::Body;
|
use axum::body::Body;
|
||||||
|
use axum::extract::Path;
|
||||||
use axum::http::header;
|
use axum::http::header;
|
||||||
use axum::response::Response;
|
use axum::response::Response;
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{Json, Router};
|
use axum::{Json, Router};
|
||||||
use cortex_core::config::{
|
use cortex_core::config::{
|
||||||
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NodeConfig,
|
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint,
|
||||||
};
|
};
|
||||||
use cortex_core::node::{ModelEntry, ModelStatus};
|
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||||
use cortex_gateway::state::CortexState;
|
use cortex_gateway::state::CortexState;
|
||||||
@@ -16,20 +17,52 @@ use std::sync::Arc;
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
/// Spawns a mock mistral.rs backend on a random port.
|
/// Spawns a mock neuron that serves:
|
||||||
/// Returns the base URL (e.g. "http://127.0.0.1:12345").
|
/// - GET /models (returns one loaded "test-model")
|
||||||
pub async fn spawn_mock_backend() -> String {
|
/// - GET /models/:id/endpoint (returns the inference URL)
|
||||||
let app = Router::new()
|
/// - POST /models/unload (accepts unload requests)
|
||||||
.route("/v1/chat/completions", post(mock_chat_completions))
|
/// - GET /v1/chat/completions + POST /v1/chat/completions (inference)
|
||||||
.route("/v1/models", get(mock_list_models));
|
/// Returns the neuron base URL.
|
||||||
|
pub async fn spawn_mock_neuron() -> String {
|
||||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
let addr = listener.local_addr().unwrap();
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let base_url = format!("http://{addr}");
|
||||||
|
let inference_url = base_url.clone();
|
||||||
|
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/models", get(mock_neuron_list_models))
|
||||||
|
.route(
|
||||||
|
"/models/{model_id}/endpoint",
|
||||||
|
get(move |Path(_model_id): Path<String>| {
|
||||||
|
let url = inference_url.clone();
|
||||||
|
async move { Json(json!({"url": url})) }
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/models/unload",
|
||||||
|
post(|Json(_body): Json<Value>| async { Json(json!({"status": "unloaded"})) }),
|
||||||
|
)
|
||||||
|
.route("/v1/chat/completions", post(mock_chat_completions))
|
||||||
|
.route("/v1/models", get(mock_v1_models));
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
axum::serve(listener, app).await.unwrap();
|
axum::serve(listener, app).await.unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
format!("http://{addr}")
|
base_url
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn mock_neuron_list_models() -> Json<Value> {
|
||||||
|
Json(json!([
|
||||||
|
{"id": "test-model", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
|
||||||
|
]))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn mock_v1_models() -> Json<Value> {
|
||||||
|
Json(json!({
|
||||||
|
"object": "list",
|
||||||
|
"data": [{"id": "test-model", "object": "model", "status": "loaded"}]
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn mock_chat_completions(Json(body): Json<Value>) -> Json<Value> {
|
async fn mock_chat_completions(Json(body): Json<Value>) -> Json<Value> {
|
||||||
@@ -59,21 +92,22 @@ async fn mock_chat_completions(Json(body): Json<Value>) -> Json<Value> {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn mock_list_models() -> Json<Value> {
|
/// Spawns a mock neuron that returns SSE streaming responses for chat completions.
|
||||||
Json(json!({
|
pub async fn spawn_streaming_mock_neuron(chunk_count: usize, chunk_delay: Duration) -> String {
|
||||||
"object": "list",
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
"data": [{
|
let addr = listener.local_addr().unwrap();
|
||||||
"id": "test-model",
|
let base_url = format!("http://{addr}");
|
||||||
"object": "model",
|
let inference_url = base_url.clone();
|
||||||
"status": "loaded"
|
|
||||||
}]
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Spawns a mock mistral.rs backend that returns SSE streaming responses.
|
|
||||||
/// Each chunk is delayed by `chunk_delay` to prove the proxy streams incrementally.
|
|
||||||
pub async fn spawn_streaming_mock_backend(chunk_count: usize, chunk_delay: Duration) -> String {
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
|
.route("/models", get(mock_neuron_list_models))
|
||||||
|
.route(
|
||||||
|
"/models/{model_id}/endpoint",
|
||||||
|
get(move |Path(_model_id): Path<String>| {
|
||||||
|
let url = inference_url.clone();
|
||||||
|
async move { Json(json!({"url": url})) }
|
||||||
|
}),
|
||||||
|
)
|
||||||
.route(
|
.route(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
post(move |Json(body): Json<Value>| async move {
|
post(move |Json(body): Json<Value>| async move {
|
||||||
@@ -118,40 +152,51 @@ pub async fn spawn_streaming_mock_backend(chunk_count: usize, chunk_delay: Durat
|
|||||||
.body(Body::from_stream(stream))
|
.body(Body::from_stream(stream))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}),
|
}),
|
||||||
)
|
);
|
||||||
.route("/v1/models", get(mock_list_models));
|
|
||||||
|
|
||||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let addr = listener.local_addr().unwrap();
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
axum::serve(listener, app).await.unwrap();
|
axum::serve(listener, app).await.unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
format!("http://{addr}")
|
base_url
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Spawns a mock backend with a custom `/v1/models` response.
|
/// Spawns a mock neuron with a custom models list.
|
||||||
pub async fn spawn_mock_backend_with_models(models_response: Value) -> String {
|
pub async fn spawn_mock_neuron_with_models(models_response: Value) -> String {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let base_url = format!("http://{addr}");
|
||||||
|
let inference_url = base_url.clone();
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/v1/chat/completions", post(mock_chat_completions))
|
|
||||||
.route(
|
.route(
|
||||||
"/v1/models",
|
"/models",
|
||||||
get(move || {
|
get(move || {
|
||||||
let resp = models_response.clone();
|
let resp = models_response.clone();
|
||||||
async move { Json(resp) }
|
async move { Json(resp) }
|
||||||
}),
|
}),
|
||||||
);
|
)
|
||||||
|
.route(
|
||||||
|
"/models/{model_id}/endpoint",
|
||||||
|
get(move |Path(_model_id): Path<String>| {
|
||||||
|
let url = inference_url.clone();
|
||||||
|
async move { Json(json!({"url": url})) }
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/models/unload",
|
||||||
|
post(|Json(_body): Json<Value>| async { Json(json!({"status": "unloaded"})) }),
|
||||||
|
)
|
||||||
|
.route("/v1/chat/completions", post(mock_chat_completions));
|
||||||
|
|
||||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let addr = listener.local_addr().unwrap();
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
axum::serve(listener, app).await.unwrap();
|
axum::serve(listener, app).await.unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
format!("http://{addr}")
|
base_url
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Spawns the cortex gateway with a single node pointing at `mock_url`.
|
/// Spawns the cortex gateway with a single neuron pointing at `mock_url`.
|
||||||
/// The node is pre-seeded as healthy with one loaded model ("test-model").
|
/// The node is pre-seeded as healthy with one loaded model ("test-model").
|
||||||
/// Returns the gateway's base URL.
|
/// Returns the gateway's base URL.
|
||||||
pub async fn spawn_gateway(mock_url: &str) -> String {
|
pub async fn spawn_gateway(mock_url: &str) -> String {
|
||||||
@@ -159,8 +204,7 @@ pub async fn spawn_gateway(mock_url: &str) -> String {
|
|||||||
url
|
url
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Like `spawn_gateway` but also returns the shared `CortexState` so tests
|
/// Like `spawn_gateway` but also returns the shared `CortexState`.
|
||||||
/// can call `poll_once` or inspect state directly.
|
|
||||||
pub async fn spawn_gateway_with_state(mock_url: &str) -> (Arc<CortexState>, String) {
|
pub async fn spawn_gateway_with_state(mock_url: &str) -> (Arc<CortexState>, String) {
|
||||||
let config = GatewayConfig {
|
let config = GatewayConfig {
|
||||||
gateway: GatewaySettings {
|
gateway: GatewaySettings {
|
||||||
@@ -171,18 +215,16 @@ pub async fn spawn_gateway_with_state(mock_url: &str) -> (Arc<CortexState>, Stri
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "mock-node".into(),
|
name: "mock-node".into(),
|
||||||
endpoint: mock_url.to_string(),
|
endpoint: mock_url.to_string(),
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|
||||||
// Seed the node as healthy with a loaded model.
|
// Seed the node as healthy with a loaded model.
|
||||||
// (Bypasses the poller, which is not running in tests.)
|
|
||||||
{
|
{
|
||||||
let mut nodes = fleet.nodes.write().await;
|
let mut nodes = fleet.nodes.write().await;
|
||||||
let node = nodes.get_mut("mock-node").expect("node must exist");
|
let node = nodes.get_mut("mock-node").expect("node must exist");
|
||||||
|
|||||||
@@ -2,15 +2,16 @@ mod common;
|
|||||||
|
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use cortex_core::config::{
|
use cortex_core::config::{
|
||||||
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NodeConfig,
|
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint,
|
||||||
};
|
};
|
||||||
use cortex_core::node::{ModelEntry, ModelStatus};
|
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||||
use cortex_gateway::state::CortexState;
|
use cortex_gateway::state::CortexState;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Spawn a mock backend that accepts `/v1/models/unload` and records the call.
|
/// Spawn a mock neuron that accepts `/models/unload` and records unload calls.
|
||||||
async fn spawn_eviction_mock() -> (String, Arc<tokio::sync::Mutex<Vec<String>>>) {
|
async fn spawn_eviction_mock() -> (String, Arc<tokio::sync::Mutex<Vec<String>>>) {
|
||||||
|
use axum::extract::Path;
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{Json, Router};
|
use axum::{Json, Router};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
@@ -18,9 +19,14 @@ async fn spawn_eviction_mock() -> (String, Arc<tokio::sync::Mutex<Vec<String>>>)
|
|||||||
let unloaded: Arc<tokio::sync::Mutex<Vec<String>>> = Arc::new(tokio::sync::Mutex::new(vec![]));
|
let unloaded: Arc<tokio::sync::Mutex<Vec<String>>> = Arc::new(tokio::sync::Mutex::new(vec![]));
|
||||||
let unloaded_clone = Arc::clone(&unloaded);
|
let unloaded_clone = Arc::clone(&unloaded);
|
||||||
|
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let base_url = format!("http://{addr}");
|
||||||
|
let inference_url = base_url.clone();
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route(
|
.route(
|
||||||
"/v1/models/unload",
|
"/models/unload",
|
||||||
post(move |Json(body): Json<Value>| {
|
post(move |Json(body): Json<Value>| {
|
||||||
let unloaded = Arc::clone(&unloaded_clone);
|
let unloaded = Arc::clone(&unloaded_clone);
|
||||||
async move {
|
async move {
|
||||||
@@ -30,30 +36,27 @@ async fn spawn_eviction_mock() -> (String, Arc<tokio::sync::Mutex<Vec<String>>>)
|
|||||||
.unwrap_or("")
|
.unwrap_or("")
|
||||||
.to_string();
|
.to_string();
|
||||||
unloaded.lock().await.push(model_id);
|
unloaded.lock().await.push(model_id);
|
||||||
Json(json!({"status": "ok"}))
|
Json(json!({"status": "unloaded"}))
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
.route("/models", get(|| async { Json(json!([])) }))
|
||||||
.route(
|
.route(
|
||||||
"/v1/models",
|
"/models/{model_id}/endpoint",
|
||||||
get(|| async {
|
get(move |Path(_model_id): Path<String>| {
|
||||||
Json(json!({
|
let url = inference_url.clone();
|
||||||
"object": "list",
|
async move { Json(json!({"url": url})) }
|
||||||
"data": []
|
|
||||||
}))
|
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let addr = listener.local_addr().unwrap();
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
axum::serve(listener, app).await.unwrap();
|
axum::serve(listener, app).await.unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
(format!("http://{addr}"), unloaded)
|
(base_url, unloaded)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_fleet(endpoint: &str, pinned: Vec<String>, defrag_after: u32) -> Arc<CortexState> {
|
fn make_fleet(endpoint: &str, defrag_after: u32) -> Arc<CortexState> {
|
||||||
let config = GatewayConfig {
|
let config = GatewayConfig {
|
||||||
gateway: GatewaySettings {
|
gateway: GatewaySettings {
|
||||||
listen: "127.0.0.1:0".into(),
|
listen: "127.0.0.1:0".into(),
|
||||||
@@ -63,12 +66,11 @@ fn make_fleet(endpoint: &str, pinned: Vec<String>, defrag_after: u32) -> Arc<Cor
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: defrag_after,
|
defrag_after_cycles: defrag_after,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "gpu-node".into(),
|
name: "gpu-node".into(),
|
||||||
endpoint: endpoint.to_string(),
|
endpoint: endpoint.to_string(),
|
||||||
vram_mb: 24000,
|
|
||||||
pinned,
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
Arc::new(CortexState::from_config(&config))
|
Arc::new(CortexState::from_config(&config))
|
||||||
}
|
}
|
||||||
@@ -76,9 +78,8 @@ fn make_fleet(endpoint: &str, pinned: Vec<String>, defrag_after: u32) -> Arc<Cor
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_evict_lru_model() {
|
async fn test_evict_lru_model() {
|
||||||
let (mock_url, unloaded) = spawn_eviction_mock().await;
|
let (mock_url, unloaded) = spawn_eviction_mock().await;
|
||||||
let fleet = make_fleet(&mock_url, vec![], 0);
|
let fleet = make_fleet(&mock_url, 0);
|
||||||
|
|
||||||
// Seed two loaded models. "old-model" was accessed earlier than "new-model".
|
|
||||||
{
|
{
|
||||||
let mut nodes = fleet.nodes.write().await;
|
let mut nodes = fleet.nodes.write().await;
|
||||||
let node = nodes.get_mut("gpu-node").unwrap();
|
let node = nodes.get_mut("gpu-node").unwrap();
|
||||||
@@ -107,15 +108,12 @@ async fn test_evict_lru_model() {
|
|||||||
.await
|
.await
|
||||||
.expect("eviction should succeed");
|
.expect("eviction should succeed");
|
||||||
|
|
||||||
// The older model should be evicted.
|
|
||||||
assert_eq!(evicted, Some("old-model".to_string()));
|
assert_eq!(evicted, Some("old-model".to_string()));
|
||||||
|
|
||||||
// Mock received the unload call.
|
|
||||||
let calls = unloaded.lock().await;
|
let calls = unloaded.lock().await;
|
||||||
assert_eq!(calls.len(), 1);
|
assert_eq!(calls.len(), 1);
|
||||||
assert_eq!(calls[0], "old-model");
|
assert_eq!(calls[0], "old-model");
|
||||||
|
|
||||||
// Local state updated.
|
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let node = nodes.get("gpu-node").unwrap();
|
let node = nodes.get("gpu-node").unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -128,67 +126,15 @@ async fn test_evict_lru_model() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_eviction_skips_pinned_models() {
|
|
||||||
let (mock_url, unloaded) = spawn_eviction_mock().await;
|
|
||||||
// Pin "old-model" so it can't be evicted.
|
|
||||||
let fleet = make_fleet(&mock_url, vec!["old-model".into()], 0);
|
|
||||||
|
|
||||||
{
|
|
||||||
let mut nodes = fleet.nodes.write().await;
|
|
||||||
let node = nodes.get_mut("gpu-node").unwrap();
|
|
||||||
node.healthy = true;
|
|
||||||
// old-model is pinned and older — normally it would be evicted.
|
|
||||||
node.models.insert(
|
|
||||||
"old-model".into(),
|
|
||||||
ModelEntry {
|
|
||||||
id: "old-model".into(),
|
|
||||||
status: ModelStatus::Loaded,
|
|
||||||
last_accessed: Some(Utc::now() - chrono::Duration::hours(2)),
|
|
||||||
vram_estimate_mb: Some(8000),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
node.models.insert(
|
|
||||||
"new-model".into(),
|
|
||||||
ModelEntry {
|
|
||||||
id: "new-model".into(),
|
|
||||||
status: ModelStatus::Loaded,
|
|
||||||
last_accessed: Some(Utc::now()),
|
|
||||||
vram_estimate_mb: Some(8000),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let evicted = cortex_gateway::evictor::evict_lru_on_node(&fleet, "gpu-node")
|
|
||||||
.await
|
|
||||||
.expect("eviction should succeed");
|
|
||||||
|
|
||||||
// new-model is evicted instead because old-model is pinned.
|
|
||||||
assert_eq!(evicted, Some("new-model".to_string()));
|
|
||||||
|
|
||||||
let calls = unloaded.lock().await;
|
|
||||||
assert_eq!(calls[0], "new-model");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_eviction_nothing_to_evict() {
|
async fn test_eviction_nothing_to_evict() {
|
||||||
let (mock_url, unloaded) = spawn_eviction_mock().await;
|
let (mock_url, unloaded) = spawn_eviction_mock().await;
|
||||||
// Pin the only model.
|
let fleet = make_fleet(&mock_url, 0);
|
||||||
let fleet = make_fleet(&mock_url, vec!["only-model".into()], 0);
|
|
||||||
|
|
||||||
|
// No models at all.
|
||||||
{
|
{
|
||||||
let mut nodes = fleet.nodes.write().await;
|
let mut nodes = fleet.nodes.write().await;
|
||||||
let node = nodes.get_mut("gpu-node").unwrap();
|
nodes.get_mut("gpu-node").unwrap().healthy = true;
|
||||||
node.healthy = true;
|
|
||||||
node.models.insert(
|
|
||||||
"only-model".into(),
|
|
||||||
ModelEntry {
|
|
||||||
id: "only-model".into(),
|
|
||||||
status: ModelStatus::Loaded,
|
|
||||||
last_accessed: None,
|
|
||||||
vram_estimate_mb: Some(8000),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let evicted = cortex_gateway::evictor::evict_lru_on_node(&fleet, "gpu-node")
|
let evicted = cortex_gateway::evictor::evict_lru_on_node(&fleet, "gpu-node")
|
||||||
@@ -196,8 +142,6 @@ async fn test_eviction_nothing_to_evict() {
|
|||||||
.expect("eviction should succeed");
|
.expect("eviction should succeed");
|
||||||
|
|
||||||
assert_eq!(evicted, None);
|
assert_eq!(evicted, None);
|
||||||
|
|
||||||
// No unload call made.
|
|
||||||
let calls = unloaded.lock().await;
|
let calls = unloaded.lock().await;
|
||||||
assert!(calls.is_empty());
|
assert!(calls.is_empty());
|
||||||
}
|
}
|
||||||
@@ -205,7 +149,7 @@ async fn test_eviction_nothing_to_evict() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_eviction_increments_lifecycle_cycles() {
|
async fn test_eviction_increments_lifecycle_cycles() {
|
||||||
let (mock_url, _) = spawn_eviction_mock().await;
|
let (mock_url, _) = spawn_eviction_mock().await;
|
||||||
let fleet = make_fleet(&mock_url, vec![], 0);
|
let fleet = make_fleet(&mock_url, 0);
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut nodes = fleet.nodes.write().await;
|
let mut nodes = fleet.nodes.write().await;
|
||||||
@@ -233,10 +177,9 @@ async fn test_eviction_increments_lifecycle_cycles() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_last_accessed_updated_on_request() {
|
async fn test_last_accessed_updated_on_request() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let (fleet, gw_url) = common::spawn_gateway_with_state(&mock_url).await;
|
let (fleet, gw_url) = common::spawn_gateway_with_state(&mock_url).await;
|
||||||
|
|
||||||
// Verify last_accessed is None initially.
|
|
||||||
{
|
{
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let node = nodes.get("mock-node").unwrap();
|
let node = nodes.get("mock-node").unwrap();
|
||||||
@@ -249,7 +192,6 @@ async fn test_last_accessed_updated_on_request() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make a request.
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
client
|
client
|
||||||
.post(format!("{gw_url}/v1/chat/completions"))
|
.post(format!("{gw_url}/v1/chat/completions"))
|
||||||
@@ -262,7 +204,6 @@ async fn test_last_accessed_updated_on_request() {
|
|||||||
.await
|
.await
|
||||||
.expect("request should succeed");
|
.expect("request should succeed");
|
||||||
|
|
||||||
// Verify last_accessed is now set.
|
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let node = nodes.get("mock-node").unwrap();
|
let node = nodes.get("mock-node").unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
|
|||||||
@@ -4,21 +4,17 @@ use serde_json::json;
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_metrics_emitted_after_proxy() {
|
async fn test_metrics_emitted_after_proxy() {
|
||||||
// Install a test recorder (no HTTP listener, renders to string).
|
|
||||||
// This sets the global recorder, so only one test can do this.
|
|
||||||
let handle = cortex_gateway::metrics::install_test_recorder().expect("recorder should install");
|
let handle = cortex_gateway::metrics::install_test_recorder().expect("recorder should install");
|
||||||
|
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
// Verify no request metrics yet.
|
|
||||||
let before = handle.render();
|
let before = handle.render();
|
||||||
assert!(
|
assert!(
|
||||||
!before.contains("cortex_requests_total"),
|
!before.contains("cortex_requests_total"),
|
||||||
"no request metrics before any requests"
|
"no request metrics before any requests"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Make a successful request.
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let resp = client
|
let resp = client
|
||||||
.post(format!("{gw_url}/v1/chat/completions"))
|
.post(format!("{gw_url}/v1/chat/completions"))
|
||||||
@@ -31,10 +27,8 @@ async fn test_metrics_emitted_after_proxy() {
|
|||||||
.await
|
.await
|
||||||
.expect("request should succeed");
|
.expect("request should succeed");
|
||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 200);
|
||||||
// Consume the response body to ensure the proxy completes.
|
|
||||||
let _body: serde_json::Value = resp.json().await.unwrap();
|
let _body: serde_json::Value = resp.json().await.unwrap();
|
||||||
|
|
||||||
// Check metrics were emitted.
|
|
||||||
let after = handle.render();
|
let after = handle.render();
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
@@ -45,7 +39,6 @@ async fn test_metrics_emitted_after_proxy() {
|
|||||||
after.contains("cortex_request_duration_seconds"),
|
after.contains("cortex_request_duration_seconds"),
|
||||||
"cortex_request_duration_seconds should be present.\nMetrics:\n{after}"
|
"cortex_request_duration_seconds should be present.\nMetrics:\n{after}"
|
||||||
);
|
);
|
||||||
// Should NOT have error or cold start counters for this request.
|
|
||||||
assert!(
|
assert!(
|
||||||
!after.contains("cortex_request_errors_total"),
|
!after.contains("cortex_request_errors_total"),
|
||||||
"no errors expected for a successful request"
|
"no errors expected for a successful request"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
use cortex_core::config::{
|
use cortex_core::config::{
|
||||||
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NodeConfig,
|
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint,
|
||||||
};
|
};
|
||||||
use cortex_core::node::ModelStatus;
|
use cortex_core::node::ModelStatus;
|
||||||
use cortex_gateway::state::CortexState;
|
use cortex_gateway::state::CortexState;
|
||||||
@@ -10,14 +10,11 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_poller_discovers_models() {
|
async fn test_poller_discovers_models() {
|
||||||
// Mock backend reports 2 models: one loaded, one unloaded.
|
// Mock neuron reports 2 models via /models endpoint (neuron format).
|
||||||
let mock_url = common::spawn_mock_backend_with_models(json!({
|
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
"object": "list",
|
{"id": "model-a", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
|
||||||
"data": [
|
{"id": "model-b", "harness": "mistralrs", "status": "unloaded", "devices": [], "vram_used_mb": null}
|
||||||
{ "id": "model-a", "object": "model", "status": "loaded" },
|
]))
|
||||||
{ "id": "model-b", "object": "model", "status": "unloaded" }
|
|
||||||
]
|
|
||||||
}))
|
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let config = GatewayConfig {
|
let config = GatewayConfig {
|
||||||
@@ -29,17 +26,15 @@ async fn test_poller_discovers_models() {
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "test-node".into(),
|
name: "test-node".into(),
|
||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|
||||||
// Before polling: node is unhealthy, no models.
|
|
||||||
{
|
{
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let node = nodes.get("test-node").unwrap();
|
let node = nodes.get("test-node").unwrap();
|
||||||
@@ -47,10 +42,8 @@ async fn test_poller_discovers_models() {
|
|||||||
assert!(node.models.is_empty());
|
assert!(node.models.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Poll once.
|
|
||||||
cortex_gateway::poller::poll_once(&fleet).await;
|
cortex_gateway::poller::poll_once(&fleet).await;
|
||||||
|
|
||||||
// After polling: node is healthy, both models discovered with correct status.
|
|
||||||
{
|
{
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let node = nodes.get("test-node").unwrap();
|
let node = nodes.get("test-node").unwrap();
|
||||||
@@ -69,14 +62,10 @@ async fn test_poller_discovers_models() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_poller_updates_gateway_models_endpoint() {
|
async fn test_poller_updates_gateway_models_endpoint() {
|
||||||
// Mock backend with 2 models.
|
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
let mock_url = common::spawn_mock_backend_with_models(json!({
|
{"id": "model-x", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||||
"object": "list",
|
{"id": "model-y", "harness": "mistralrs", "status": "loaded", "devices": [1], "vram_used_mb": null}
|
||||||
"data": [
|
]))
|
||||||
{ "id": "model-x", "object": "model", "status": "loaded" },
|
|
||||||
{ "id": "model-y", "object": "model", "status": "loaded" }
|
|
||||||
]
|
|
||||||
}))
|
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let config = GatewayConfig {
|
let config = GatewayConfig {
|
||||||
@@ -88,20 +77,16 @@ async fn test_poller_updates_gateway_models_endpoint() {
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "poll-node".into(),
|
name: "poll-node".into(),
|
||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|
||||||
// Poll to discover models and mark node healthy.
|
|
||||||
cortex_gateway::poller::poll_once(&fleet).await;
|
cortex_gateway::poller::poll_once(&fleet).await;
|
||||||
|
|
||||||
// Start gateway with the polled state.
|
|
||||||
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
let addr = listener.local_addr().unwrap();
|
let addr = listener.local_addr().unwrap();
|
||||||
@@ -109,7 +94,6 @@ async fn test_poller_updates_gateway_models_endpoint() {
|
|||||||
axum::serve(listener, app).await.unwrap();
|
axum::serve(listener, app).await.unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
// Query /v1/models on the gateway.
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let resp = client
|
let resp = client
|
||||||
.get(format!("http://{addr}/v1/models"))
|
.get(format!("http://{addr}/v1/models"))
|
||||||
@@ -127,7 +111,6 @@ async fn test_poller_updates_gateway_models_endpoint() {
|
|||||||
assert!(ids.contains(&"model-x"));
|
assert!(ids.contains(&"model-x"));
|
||||||
assert!(ids.contains(&"model-y"));
|
assert!(ids.contains(&"model-y"));
|
||||||
|
|
||||||
// Verify node attribution in locations.
|
|
||||||
for model in data {
|
for model in data {
|
||||||
let locations = model["locations"].as_array().expect("locations array");
|
let locations = model["locations"].as_array().expect("locations array");
|
||||||
assert_eq!(locations.len(), 1);
|
assert_eq!(locations.len(), 1);
|
||||||
@@ -146,17 +129,15 @@ async fn test_poller_marks_unreachable_node_unhealthy() {
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "dead-node".into(),
|
name: "dead-node".into(),
|
||||||
endpoint: "http://127.0.0.1:1".into(), // unreachable
|
endpoint: "http://127.0.0.1:1".into(),
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|
||||||
// Manually mark healthy to verify poller flips it.
|
|
||||||
{
|
{
|
||||||
let mut nodes = fleet.nodes.write().await;
|
let mut nodes = fleet.nodes.write().await;
|
||||||
nodes.get_mut("dead-node").unwrap().healthy = true;
|
nodes.get_mut("dead-node").unwrap().healthy = true;
|
||||||
@@ -170,14 +151,10 @@ async fn test_poller_marks_unreachable_node_unhealthy() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_poller_removes_stale_models() {
|
async fn test_poller_removes_stale_models() {
|
||||||
// Start with a mock that reports 2 models.
|
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
let mock_url = common::spawn_mock_backend_with_models(json!({
|
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||||
"object": "list",
|
{"id": "drop-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||||
"data": [
|
]))
|
||||||
{ "id": "keep-me", "object": "model", "status": "loaded" },
|
|
||||||
{ "id": "drop-me", "object": "model", "status": "loaded" }
|
|
||||||
]
|
|
||||||
}))
|
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let config = GatewayConfig {
|
let config = GatewayConfig {
|
||||||
@@ -189,35 +166,27 @@ async fn test_poller_removes_stale_models() {
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "test-node".into(),
|
name: "test-node".into(),
|
||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
cortex_gateway::poller::poll_once(&fleet).await;
|
cortex_gateway::poller::poll_once(&fleet).await;
|
||||||
|
|
||||||
// Verify both models exist.
|
|
||||||
{
|
{
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
assert_eq!(nodes.get("test-node").unwrap().models.len(), 2);
|
assert_eq!(nodes.get("test-node").unwrap().models.len(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now spin up a new mock that only reports one model, and re-point the node.
|
// New mock with only one model.
|
||||||
let new_mock_url = common::spawn_mock_backend_with_models(json!({
|
let new_mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
"object": "list",
|
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||||
"data": [
|
]))
|
||||||
{ "id": "keep-me", "object": "model", "status": "loaded" }
|
|
||||||
]
|
|
||||||
}))
|
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Update the node endpoint to point at the new mock.
|
|
||||||
// We can't change node_configs (they're immutable), so instead we'll
|
|
||||||
// create a new fleet with the updated endpoint and poll that.
|
|
||||||
let config2 = GatewayConfig {
|
let config2 = GatewayConfig {
|
||||||
gateway: GatewaySettings {
|
gateway: GatewaySettings {
|
||||||
listen: "127.0.0.1:0".into(),
|
listen: "127.0.0.1:0".into(),
|
||||||
@@ -227,17 +196,16 @@ async fn test_poller_removes_stale_models() {
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "test-node".into(),
|
name: "test-node".into(),
|
||||||
endpoint: new_mock_url,
|
endpoint: new_mock_url,
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet2 = Arc::new(CortexState::from_config(&config2));
|
let fleet2 = Arc::new(CortexState::from_config(&config2));
|
||||||
|
|
||||||
// Seed the stale model so we can verify it gets removed.
|
// Seed stale model.
|
||||||
{
|
{
|
||||||
let mut nodes = fleet2.nodes.write().await;
|
let mut nodes = fleet2.nodes.write().await;
|
||||||
let node = nodes.get_mut("test-node").unwrap();
|
let node = nodes.get_mut("test-node").unwrap();
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use serde_json::json;
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_chat_completion_proxy() {
|
async fn test_chat_completion_proxy() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -33,7 +33,7 @@ async fn test_chat_completion_proxy() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_health_endpoint() {
|
async fn test_health_endpoint() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -53,7 +53,7 @@ async fn test_health_endpoint() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_list_models() {
|
async fn test_list_models() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -75,7 +75,7 @@ async fn test_list_models() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_model_not_found() {
|
async fn test_model_not_found() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -112,12 +112,11 @@ async fn test_no_healthy_nodes() {
|
|||||||
strategy: cortex_core::config::EvictionStrategy::Lru,
|
strategy: cortex_core::config::EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![cortex_core::config::NodeConfig {
|
neurons: vec![cortex_core::config::NeuronEndpoint {
|
||||||
name: "dead-node".into(),
|
name: "dead-node".into(),
|
||||||
endpoint: "http://127.0.0.1:1".into(),
|
endpoint: "http://127.0.0.1:1".into(),
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
let fleet = std::sync::Arc::new(cortex_gateway::state::CortexState::from_config(&config));
|
let fleet = std::sync::Arc::new(cortex_gateway::state::CortexState::from_config(&config));
|
||||||
|
|
||||||
@@ -153,7 +152,7 @@ async fn test_no_healthy_nodes() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_missing_model_field() {
|
async fn test_missing_model_field() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ use std::time::{Duration, Instant};
|
|||||||
async fn test_streaming_sse_passthrough() {
|
async fn test_streaming_sse_passthrough() {
|
||||||
let chunk_count = 5;
|
let chunk_count = 5;
|
||||||
let chunk_delay = Duration::from_millis(50);
|
let chunk_delay = Duration::from_millis(50);
|
||||||
let mock_url = common::spawn_streaming_mock_backend(chunk_count, chunk_delay).await;
|
let mock_url = common::spawn_streaming_mock_neuron(chunk_count, chunk_delay).await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -33,7 +33,6 @@ async fn test_streaming_sse_passthrough() {
|
|||||||
"text/event-stream"
|
"text/event-stream"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Collect SSE chunks as they arrive, recording arrival times.
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let mut chunk_times = Vec::new();
|
let mut chunk_times = Vec::new();
|
||||||
let mut chunks = Vec::new();
|
let mut chunks = Vec::new();
|
||||||
@@ -51,7 +50,6 @@ async fn test_streaming_sse_passthrough() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify we got all content chunks plus [DONE].
|
|
||||||
assert!(
|
assert!(
|
||||||
chunks.len() >= chunk_count + 1,
|
chunks.len() >= chunk_count + 1,
|
||||||
"expected at least {} chunks (got {}): {:?}",
|
"expected at least {} chunks (got {}): {:?}",
|
||||||
@@ -60,10 +58,8 @@ async fn test_streaming_sse_passthrough() {
|
|||||||
chunks,
|
chunks,
|
||||||
);
|
);
|
||||||
|
|
||||||
// The last chunk should be [DONE].
|
|
||||||
assert_eq!(chunks.last().unwrap(), "[DONE]");
|
assert_eq!(chunks.last().unwrap(), "[DONE]");
|
||||||
|
|
||||||
// Verify the content chunks contain expected tokens.
|
|
||||||
for i in 0..chunk_count {
|
for i in 0..chunk_count {
|
||||||
let chunk_json: serde_json::Value =
|
let chunk_json: serde_json::Value =
|
||||||
serde_json::from_str(&chunks[i]).expect("chunk should be valid JSON");
|
serde_json::from_str(&chunks[i]).expect("chunk should be valid JSON");
|
||||||
@@ -73,10 +69,6 @@ async fn test_streaming_sse_passthrough() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify streaming behavior: total time should reflect incremental delivery,
|
|
||||||
// not a single batch. With 5 chunks at 50ms each + [DONE], we expect ~300ms total.
|
|
||||||
// If buffered, all chunks would arrive at once after ~300ms with no spread.
|
|
||||||
// We verify that the last chunk arrived noticeably after the first.
|
|
||||||
let first = chunk_times.first().unwrap();
|
let first = chunk_times.first().unwrap();
|
||||||
let last = chunk_times.last().unwrap();
|
let last = chunk_times.last().unwrap();
|
||||||
let spread = *last - *first;
|
let spread = *last - *first;
|
||||||
@@ -88,7 +80,7 @@ async fn test_streaming_sse_passthrough() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_streaming_done_terminator() {
|
async fn test_streaming_done_terminator() {
|
||||||
let mock_url = common::spawn_streaming_mock_backend(2, Duration::from_millis(10)).await;
|
let mock_url = common::spawn_streaming_mock_neuron(2, Duration::from_millis(10)).await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|||||||
Reference in New Issue
Block a user