diff --git a/CLAUDE.md b/CLAUDE.md index 298da6b..4933bb0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -219,22 +219,14 @@ for SSE out of the box. 2 integration tests in `cortex-gateway/tests/streaming.r incremental delivery (time spread between first and last chunk) - `test_streaming_done_terminator` — verifies `data: [DONE]` is forwarded -### Phase 3: Poller + live `/v1/models` +### Phase 3: Poller + live `/v1/models` ✅ -**Goal:** The background poller refreshes node state from real (or mock) -mistral.rs instances. `GET /v1/models` returns live, aggregated data. - -**Files to change:** -- `cortex-gateway/src/poller.rs` — already implemented but needs testing -- `cortex-gateway/src/handlers.rs` — the `list_models` handler reads - from `CortexState`; verify it reflects poller updates -- `tests/` — test that: - 1. Mock backend serves `/v1/models` with 2 models (1 loaded, 1 unloaded) - 2. After poller runs, `GET /v1/models` on cortex returns both with - correct status and node attribution - -**Done when:** Poller test passes. The router in Phase 1 now routes -based on live-polled state instead of seed data. +Completed. Extracted `poll_once()` from `poll_loop()` for testability. +4 tests in `cortex-gateway/tests/poller.rs`: +- `test_poller_discovers_models` — 2 models (loaded + unloaded) discovered with correct status +- `test_poller_updates_gateway_models_endpoint` — `/v1/models` reflects polled state with node attribution +- `test_poller_marks_unreachable_node_unhealthy` — unreachable node flipped to unhealthy +- `test_poller_removes_stale_models` — model removed from upstream is pruned from state ### Phase 4: Eviction diff --git a/crates/cortex-gateway/src/poller.rs b/crates/cortex-gateway/src/poller.rs index 60b5666..27a047c 100644 --- a/crates/cortex-gateway/src/poller.rs +++ b/crates/cortex-gateway/src/poller.rs @@ -12,13 +12,18 @@ const POLL_INTERVAL: Duration = Duration::from_secs(10); /// Runs forever, polling all nodes on a fixed interval. 
pub async fn poll_loop(fleet: Arc<CortexState>) { loop { - for nc in &fleet.node_configs { - poll_node(&fleet, &nc.name, &nc.endpoint).await; - } + poll_once(&fleet).await; tokio::time::sleep(POLL_INTERVAL).await; } } +/// Poll all nodes once. Used by `poll_loop` and available for testing. +pub async fn poll_once(fleet: &CortexState) { + for nc in &fleet.node_configs { + poll_node(fleet, &nc.name, &nc.endpoint).await; + } +} + async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) { let url = format!("{endpoint}/v1/models"); diff --git a/crates/cortex-gateway/tests/common/mod.rs b/crates/cortex-gateway/tests/common/mod.rs index 05be7f2..db60fff 100644 --- a/crates/cortex-gateway/tests/common/mod.rs +++ b/crates/cortex-gateway/tests/common/mod.rs @@ -130,10 +130,38 @@ pub async fn spawn_streaming_mock_backend(chunk_count: usize, chunk_delay: Durat format!("http://{addr}") } +/// Spawns a mock backend with a custom `/v1/models` response. +pub async fn spawn_mock_backend_with_models(models_response: Value) -> String { + let app = Router::new() + .route("/v1/chat/completions", post(mock_chat_completions)) + .route( + "/v1/models", + get(move || { + let resp = models_response.clone(); + async move { Json(resp) } + }), + ); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + format!("http://{addr}") +} + /// Spawns the cortex gateway with a single node pointing at `mock_url`. /// The node is pre-seeded as healthy with one loaded model ("test-model"). /// Returns the gateway's base URL. pub async fn spawn_gateway(mock_url: &str) -> String { + let (_, url) = spawn_gateway_with_state(mock_url).await; + url +} + +/// Like `spawn_gateway` but also returns the shared `CortexState` so tests +/// can call `poll_once` or inspect state directly. 
+pub async fn spawn_gateway_with_state(mock_url: &str) -> (Arc<CortexState>, String) { let config = GatewayConfig { gateway: GatewaySettings { listen: "127.0.0.1:0".into(), @@ -170,7 +198,7 @@ pub async fn spawn_gateway(mock_url: &str) -> String { ); } - let app = cortex_gateway::build_app(fleet); + let app = cortex_gateway::build_app(Arc::clone(&fleet)); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -178,5 +206,5 @@ pub async fn spawn_gateway(mock_url: &str) -> String { axum::serve(listener, app).await.unwrap(); }); - format!("http://{addr}") + (fleet, format!("http://{addr}")) } diff --git a/crates/cortex-gateway/tests/poller.rs b/crates/cortex-gateway/tests/poller.rs new file mode 100644 index 0000000..1fd741b --- /dev/null +++ b/crates/cortex-gateway/tests/poller.rs @@ -0,0 +1,271 @@ +mod common; + +use cortex_core::config::{ + EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NodeConfig, +}; +use cortex_core::node::ModelStatus; +use cortex_gateway::state::CortexState; +use serde_json::json; +use std::sync::Arc; + +#[tokio::test] +async fn test_poller_discovers_models() { + // Mock backend reports 2 models: one loaded, one unloaded. + let mock_url = common::spawn_mock_backend_with_models(json!({ + "object": "list", + "data": [ + { "id": "model-a", "object": "model", "status": "loaded" }, + { "id": "model-b", "object": "model", "status": "unloaded" } + ] + })) + .await; + + let config = GatewayConfig { + gateway: GatewaySettings { + listen: "127.0.0.1:0".into(), + metrics_listen: "127.0.0.1:0".into(), + }, + eviction: EvictionSettings { + strategy: EvictionStrategy::Lru, + defrag_after_cycles: 0, + }, + nodes: vec![NodeConfig { + name: "test-node".into(), + endpoint: mock_url, + vram_mb: 24000, + pinned: vec![], + }], + }; + + let fleet = Arc::new(CortexState::from_config(&config)); + + // Before polling: node is unhealthy, no models. 
+ { + let nodes = fleet.nodes.read().await; + let node = nodes.get("test-node").unwrap(); + assert!(!node.healthy); + assert!(node.models.is_empty()); + } + + // Poll once. + cortex_gateway::poller::poll_once(&fleet).await; + + // After polling: node is healthy, both models discovered with correct status. + { + let nodes = fleet.nodes.read().await; + let node = nodes.get("test-node").unwrap(); + assert!(node.healthy); + assert_eq!(node.models.len(), 2); + + let model_a = node.models.get("model-a").expect("model-a should exist"); + assert_eq!(model_a.status, ModelStatus::Loaded); + + let model_b = node.models.get("model-b").expect("model-b should exist"); + assert_eq!(model_b.status, ModelStatus::Unloaded); + + assert!(node.last_poll.is_some()); + } +} + +#[tokio::test] +async fn test_poller_updates_gateway_models_endpoint() { + // Mock backend with 2 models. + let mock_url = common::spawn_mock_backend_with_models(json!({ + "object": "list", + "data": [ + { "id": "model-x", "object": "model", "status": "loaded" }, + { "id": "model-y", "object": "model", "status": "loaded" } + ] + })) + .await; + + let config = GatewayConfig { + gateway: GatewaySettings { + listen: "127.0.0.1:0".into(), + metrics_listen: "127.0.0.1:0".into(), + }, + eviction: EvictionSettings { + strategy: EvictionStrategy::Lru, + defrag_after_cycles: 0, + }, + nodes: vec![NodeConfig { + name: "poll-node".into(), + endpoint: mock_url, + vram_mb: 24000, + pinned: vec![], + }], + }; + + let fleet = Arc::new(CortexState::from_config(&config)); + + // Poll to discover models and mark node healthy. + cortex_gateway::poller::poll_once(&fleet).await; + + // Start gateway with the polled state. + let app = cortex_gateway::build_app(Arc::clone(&fleet)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + // Query /v1/models on the gateway. 
+ let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{addr}/v1/models")) + .send() + .await + .expect("request should succeed"); + + assert_eq!(resp.status(), 200); + + let body: serde_json::Value = resp.json().await.unwrap(); + let data = body["data"].as_array().expect("data should be an array"); + assert_eq!(data.len(), 2); + + let ids: Vec<&str> = data.iter().filter_map(|m| m["id"].as_str()).collect(); + assert!(ids.contains(&"model-x")); + assert!(ids.contains(&"model-y")); + + // Verify node attribution in locations. + for model in data { + let locations = model["locations"].as_array().expect("locations array"); + assert_eq!(locations.len(), 1); + assert_eq!(locations[0]["node"], "poll-node"); + } +} + +#[tokio::test] +async fn test_poller_marks_unreachable_node_unhealthy() { + let config = GatewayConfig { + gateway: GatewaySettings { + listen: "127.0.0.1:0".into(), + metrics_listen: "127.0.0.1:0".into(), + }, + eviction: EvictionSettings { + strategy: EvictionStrategy::Lru, + defrag_after_cycles: 0, + }, + nodes: vec![NodeConfig { + name: "dead-node".into(), + endpoint: "http://127.0.0.1:1".into(), // unreachable + vram_mb: 24000, + pinned: vec![], + }], + }; + + let fleet = Arc::new(CortexState::from_config(&config)); + + // Manually mark healthy to verify poller flips it. + { + let mut nodes = fleet.nodes.write().await; + nodes.get_mut("dead-node").unwrap().healthy = true; + } + + cortex_gateway::poller::poll_once(&fleet).await; + + let nodes = fleet.nodes.read().await; + assert!(!nodes.get("dead-node").unwrap().healthy); +} + +#[tokio::test] +async fn test_poller_removes_stale_models() { + // Start with a mock that reports 2 models. 
+ let mock_url = common::spawn_mock_backend_with_models(json!({ + "object": "list", + "data": [ + { "id": "keep-me", "object": "model", "status": "loaded" }, + { "id": "drop-me", "object": "model", "status": "loaded" } + ] + })) + .await; + + let config = GatewayConfig { + gateway: GatewaySettings { + listen: "127.0.0.1:0".into(), + metrics_listen: "127.0.0.1:0".into(), + }, + eviction: EvictionSettings { + strategy: EvictionStrategy::Lru, + defrag_after_cycles: 0, + }, + nodes: vec![NodeConfig { + name: "test-node".into(), + endpoint: mock_url, + vram_mb: 24000, + pinned: vec![], + }], + }; + + let fleet = Arc::new(CortexState::from_config(&config)); + cortex_gateway::poller::poll_once(&fleet).await; + + // Verify both models exist. + { + let nodes = fleet.nodes.read().await; + assert_eq!(nodes.get("test-node").unwrap().models.len(), 2); + } + + // Now spin up a new mock that only reports one model, and re-point the node. + let new_mock_url = common::spawn_mock_backend_with_models(json!({ + "object": "list", + "data": [ + { "id": "keep-me", "object": "model", "status": "loaded" } + ] + })) + .await; + + // Update the node endpoint to point at the new mock. + // We can't change node_configs (they're immutable), so instead we'll + // create a new fleet with the updated endpoint and poll that. + let config2 = GatewayConfig { + gateway: GatewaySettings { + listen: "127.0.0.1:0".into(), + metrics_listen: "127.0.0.1:0".into(), + }, + eviction: EvictionSettings { + strategy: EvictionStrategy::Lru, + defrag_after_cycles: 0, + }, + nodes: vec![NodeConfig { + name: "test-node".into(), + endpoint: new_mock_url, + vram_mb: 24000, + pinned: vec![], + }], + }; + + let fleet2 = Arc::new(CortexState::from_config(&config2)); + + // Seed the stale model so we can verify it gets removed. 
+ { + let mut nodes = fleet2.nodes.write().await; + let node = nodes.get_mut("test-node").unwrap(); + node.models.insert( + "keep-me".into(), + cortex_core::node::ModelEntry { + id: "keep-me".into(), + status: ModelStatus::Loaded, + last_accessed: None, + vram_estimate_mb: None, + }, + ); + node.models.insert( + "drop-me".into(), + cortex_core::node::ModelEntry { + id: "drop-me".into(), + status: ModelStatus::Loaded, + last_accessed: None, + vram_estimate_mb: None, + }, + ); + } + + cortex_gateway::poller::poll_once(&fleet2).await; + + let nodes = fleet2.nodes.read().await; + let node = nodes.get("test-node").unwrap(); + assert_eq!(node.models.len(), 1); + assert!(node.models.contains_key("keep-me")); + assert!(!node.models.contains_key("drop-me")); +}