From 26e5e7ead8b93e714a2a7324f523fe1e9ec22cfd Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Wed, 15 Apr 2026 14:29:42 +0300 Subject: [PATCH] feat: implement mistral.rs harness and neuron model API - MistralRsHarness: Harness trait impl wrapping mistral.rs HTTP API (list/load/unload models, health check, start/stop via systemd) - HarnessRegistry: maps harness name -> Box, built from neuron.toml config - Neuron API endpoints: GET /models, POST /models/load, POST /models/unload, GET /models/:id/endpoint - NeuronConfig: figment-based config loading from neuron.toml - Integration test: full model lifecycle through mock mistral.rs Co-Authored-By: Claude Opus 4.6 (1M context) --- CLAUDE.md | 51 ++----- Cargo.lock | 3 + crates/neuron/Cargo.toml | 4 + crates/neuron/src/api.rs | 80 ++++++++++- crates/neuron/src/config.rs | 40 ++++++ crates/neuron/src/harness/mistralrs.rs | 164 +++++++++++++++++++++- crates/neuron/src/harness/mod.rs | 103 +++++++++++++- crates/neuron/src/lib.rs | 1 + crates/neuron/src/main.rs | 29 +++- crates/neuron/tests/api.rs | 186 +++++++++++++++++++------ 10 files changed, 562 insertions(+), 99 deletions(-) create mode 100644 crates/neuron/src/config.rs diff --git a/CLAUDE.md b/CLAUDE.md index 73928c1..0c9d86b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -556,50 +556,17 @@ serves `GET /discovery` and `GET /health`. Pure parsing functions separated from command execution for testability. 9 unit tests for nvidia-smi CSV parsing, 3 integration tests for the HTTP endpoints. -### Phase 8: neuron harness — mistral.rs implementation +### Phase 8: neuron harness — mistral.rs implementation ✅ -**Goal:** neuron can manage mistral.rs: start/stop the process, list -models, load/unload models, and report the inference endpoint. +Completed. Full `Harness` trait implementation for mistral.rs in +`neuron/src/harness/mistralrs.rs`: list_models, load_model, unload_model, +inference_endpoint, health, start/stop (systemd). `HarnessRegistry` in +`harness/mod.rs` maps harness name → `Box`, built from +`neuron.toml` config. Four new neuron API endpoints: `GET /models`, +`POST /models/load`, `POST /models/unload`, `GET /models/:id/endpoint`. -**Steps:** -1. In `crates/neuron/src/harness/mistralrs.rs`: - - Implement the `Harness` trait. - - `start()` — invoke `systemctl start mistralrs.service` (or a - configured unit name). Wait for the health endpoint to respond. - - `stop()` — `systemctl stop mistralrs.service`. - - `health()` — `GET {mistralrs_endpoint}/health`. - - `list_models()` — `GET {mistralrs_endpoint}/v1/models`, parse the - response including the `status` field. - - `load_model()` — `POST {mistralrs_endpoint}/v1/models/reload`. - - `unload_model()` — `POST {mistralrs_endpoint}/v1/models/unload`. - - `inference_endpoint()` — return `mistralrs_endpoint` (mistral.rs - routes internally by model name in the request body). -2. In `crates/neuron/src/harness/mod.rs`: - - A `HarnessRegistry` that maps harness name → `Box`. - - On neuron startup, register the mistralrs harness (configured with - the local mistralrs endpoint, e.g. `http://localhost:8080`). -3. Add neuron API endpoints: - - `GET /models` — aggregate across all registered harnesses. - - `POST /models/load` — dispatch to the correct harness. - - `POST /models/unload` — dispatch to the correct harness. - - `GET /models/{model_id}/endpoint` — ask the harness. -4. neuron config (`neuron.toml`): - ```toml - port = 9090 - - [[harnesses]] - name = "mistralrs" - endpoint = "http://localhost:8080" - systemd_unit = "mistralrs.service" - ``` -5. Tests: - - Mock HTTP server standing in for mistral.rs. Test that the harness - implementation correctly translates list/load/unload calls. - - Integration test: start neuron with mock mistralrs backend, call - `GET /models`, assert it returns models from the mock. - -**Done when:** neuron manages a (mock) mistral.rs instance. All API -endpoints return correct data. Tests pass. +Config via `neuron.toml` (figment + env override). Integration test +covers full model lifecycle through neuron → mock mistral.rs backend. ### Phase 9: cortex talks to neurons diff --git a/Cargo.lock b/Cargo.lock index 9b1fdc2..d102162 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -408,13 +408,16 @@ name = "cortex-neuron" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "axum", "clap", "cortex-core", + "figment", "reqwest", "serde", "serde_json", "tokio", + "toml", "tracing", "tracing-subscriber", ] diff --git a/crates/neuron/Cargo.toml b/crates/neuron/Cargo.toml index d3b2d83..25a264d 100644 --- a/crates/neuron/Cargo.toml +++ b/crates/neuron/Cargo.toml @@ -18,10 +18,14 @@ tokio.workspace = true axum.workspace = true serde.workspace = true serde_json.workspace = true +reqwest.workspace = true tracing.workspace = true tracing-subscriber.workspace = true anyhow.workspace = true +async-trait.workspace = true clap.workspace = true +figment.workspace = true +toml.workspace = true [dev-dependencies] tokio = { workspace = true, features = ["test-util"] } diff --git a/crates/neuron/src/api.rs b/crates/neuron/src/api.rs index feb9a0f..36f3dff 100644 --- a/crates/neuron/src/api.rs +++ b/crates/neuron/src/api.rs @@ -1,17 +1,23 @@ //! HTTP API handlers for the neuron daemon. +use crate::harness::HarnessRegistry; use crate::health::HealthCache; use axum::Router; -use axum::extract::State; -use axum::response::Json; -use axum::routing::get; +use axum::extract::{Path, State}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Json}; +use axum::routing::{get, post}; use cortex_core::discovery::{DiscoveryResponse, HealthResponse}; +use cortex_core::harness::ModelSpec; +use serde_json::{Value, json}; use std::sync::Arc; +use tokio::sync::RwLock; /// Shared state for the neuron HTTP server. pub struct NeuronState { pub discovery: DiscoveryResponse, pub health_cache: Arc, + pub registry: RwLock, } /// Build the neuron API router. @@ -19,6 +25,10 @@ pub fn neuron_routes() -> Router> { Router::new() .route("/discovery", get(discovery_handler)) .route("/health", get(health_handler)) + .route("/models", get(list_models)) + .route("/models/load", post(load_model)) + .route("/models/unload", post(unload_model)) + .route("/models/{model_id}/endpoint", get(model_endpoint)) } async fn discovery_handler(State(state): State>) -> Json { @@ -28,3 +38,67 @@ async fn discovery_handler(State(state): State>) -> Json>) -> Json { Json(state.health_cache.snapshot().await) } + +async fn list_models(State(state): State>) -> impl IntoResponse { + let registry = state.registry.read().await; + match registry.list_all_models().await { + Ok(models) => Json(json!(models)).into_response(), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": e.to_string()})), + ) + .into_response(), + } +} + +async fn load_model( + State(state): State>, + Json(spec): Json, +) -> impl IntoResponse { + let registry = state.registry.read().await; + match registry.load_model(&spec).await { + Ok(()) => Json(json!({"status": "loaded"})).into_response(), + Err(e) => ( + StatusCode::BAD_REQUEST, + Json(json!({"error": e.to_string()})), + ) + .into_response(), + } +} + +async fn unload_model( + State(state): State>, + Json(body): Json, +) -> impl IntoResponse { + let model_id = match body.get("model_id").and_then(|v| v.as_str()) { + Some(id) => id.to_string(), + None => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({"error": "missing model_id"})), + ) + .into_response(); + } + }; + + let registry = state.registry.read().await; + match registry.unload_model(&model_id).await { + Ok(()) => Json(json!({"status": "unloaded"})).into_response(), + Err(e) => (StatusCode::NOT_FOUND, Json(json!({"error": e.to_string()}))).into_response(), + } +} + +async fn model_endpoint( + State(state): State>, + Path(model_id): Path, +) -> impl IntoResponse { + let registry = state.registry.read().await; + match registry.inference_endpoint(&model_id).await { + Some(url) => Json(json!({"url": url})).into_response(), + None => ( + StatusCode::NOT_FOUND, + Json(json!({"error": format!("model '{}' not loaded", model_id)})), + ) + .into_response(), + } +} diff --git a/crates/neuron/src/config.rs b/crates/neuron/src/config.rs new file mode 100644 index 0000000..ff282ef --- /dev/null +++ b/crates/neuron/src/config.rs @@ -0,0 +1,40 @@ +//! Neuron configuration loaded from neuron.toml. + +use cortex_core::harness::HarnessConfig; +use figment::{ + Figment, + providers::{Env, Format, Toml}, +}; +use serde::{Deserialize, Serialize}; +use std::path::Path; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NeuronConfig { + #[serde(default = "default_port")] + pub port: u16, + #[serde(default)] + pub harnesses: Vec, +} + +fn default_port() -> u16 { + 9090 +} + +impl NeuronConfig { + pub fn load(path: impl AsRef) -> Result> { + Figment::new() + .merge(Toml::file(path)) + .merge(Env::prefixed("NEURON_").split("__")) + .extract() + .map_err(Box::new) + } +} + +impl Default for NeuronConfig { + fn default() -> Self { + Self { + port: 9090, + harnesses: vec![], + } + } +} diff --git a/crates/neuron/src/harness/mistralrs.rs b/crates/neuron/src/harness/mistralrs.rs index fad7847..c1d5780 100644 --- a/crates/neuron/src/harness/mistralrs.rs +++ b/crates/neuron/src/harness/mistralrs.rs @@ -1 +1,163 @@ -// mistral.rs harness implementation — Phase 8. +//! mistral.rs harness implementation. +//! +//! Wraps the mistral.rs HTTP API for model lifecycle management +//! and optionally manages the process via systemd. + +use anyhow::Result; +use async_trait::async_trait; +use cortex_core::harness::{Harness, HarnessConfig, HarnessHealth, ModelInfo, ModelSpec}; +use reqwest::Client; +use serde::Deserialize; + +pub struct MistralRsHarness { + endpoint: String, + systemd_unit: Option, + client: Client, +} + +impl MistralRsHarness { + pub fn new(endpoint: String, systemd_unit: Option) -> Self { + Self { + endpoint, + systemd_unit, + client: Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .expect("failed to build HTTP client"), + } + } +} + +/// Response from mistral.rs `GET /v1/models`. +#[derive(Debug, Deserialize)] +struct ModelsResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct ModelEntry { + id: String, + #[serde(default)] + status: Option, +} + +#[async_trait] +impl Harness for MistralRsHarness { + fn name(&self) -> &str { + "mistralrs" + } + + async fn start(&self, _config: &HarnessConfig) -> Result<()> { + let Some(unit) = &self.systemd_unit else { + anyhow::bail!("no systemd unit configured for mistralrs harness"); + }; + + let output = tokio::process::Command::new("systemctl") + .args(["start", unit]) + .output() + .await?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + anyhow::bail!("systemctl start {unit} failed: {stderr}"); + } + + // Wait for the health endpoint to respond (up to 30s). + let url = format!("{}/health", self.endpoint); + for _ in 0..30 { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + if self.client.get(&url).send().await.is_ok() { + tracing::info!(unit, "mistralrs started and healthy"); + return Ok(()); + } + } + anyhow::bail!("mistralrs started but health endpoint did not respond within 30s"); + } + + async fn stop(&self) -> Result<()> { + let Some(unit) = &self.systemd_unit else { + anyhow::bail!("no systemd unit configured for mistralrs harness"); + }; + + let output = tokio::process::Command::new("systemctl") + .args(["stop", unit]) + .output() + .await?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + anyhow::bail!("systemctl stop {unit} failed: {stderr}"); + } + Ok(()) + } + + async fn health(&self) -> HarnessHealth { + let url = format!("{}/health", self.endpoint); + let running = self.client.get(&url).send().await.is_ok(); + HarnessHealth { + name: "mistralrs".into(), + running, + uptime_secs: None, + } + } + + async fn list_models(&self) -> Result> { + let url = format!("{}/v1/models", self.endpoint); + let resp = self.client.get(&url).send().await?; + + if !resp.status().is_success() { + anyhow::bail!("GET /v1/models returned {}", resp.status()); + } + + let models_resp: ModelsResponse = resp.json().await?; + Ok(models_resp + .data + .into_iter() + .map(|m| ModelInfo { + id: m.id, + harness: "mistralrs".into(), + status: m.status.unwrap_or_else(|| "loaded".into()), + devices: vec![], + vram_used_mb: None, + }) + .collect()) + } + + async fn load_model(&self, spec: &ModelSpec) -> Result<()> { + let url = format!("{}/v1/models/reload", self.endpoint); + let resp = self + .client + .post(&url) + .json(&serde_json::json!({ "model_id": spec.model_id })) + .send() + .await?; + + if !resp.status().is_success() { + let body = resp.text().await.unwrap_or_default(); + anyhow::bail!("POST /v1/models/reload failed: {body}"); + } + Ok(()) + } + + async fn unload_model(&self, model_id: &str) -> Result<()> { + let url = format!("{}/v1/models/unload", self.endpoint); + let resp = self + .client + .post(&url) + .json(&serde_json::json!({ "model_id": model_id })) + .send() + .await?; + + if !resp.status().is_success() { + let body = resp.text().await.unwrap_or_default(); + anyhow::bail!("POST /v1/models/unload failed: {body}"); + } + Ok(()) + } + + async fn inference_endpoint(&self, _model_id: &str) -> Option { + // mistral.rs routes internally by model name in the request body, + // so the inference endpoint is always the base URL. + Some(self.endpoint.clone()) + } +} diff --git a/crates/neuron/src/harness/mod.rs b/crates/neuron/src/harness/mod.rs index c076199..285ea3a 100644 --- a/crates/neuron/src/harness/mod.rs +++ b/crates/neuron/src/harness/mod.rs @@ -1,4 +1,105 @@ -// Harness registry. Implementations added in Phase 8+. +//! Harness registry — maps harness names to trait implementations. pub mod llamacpp; pub mod mistralrs; + +use anyhow::Result; +use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec}; +use std::collections::HashMap; + +/// Registry of available harness implementations. +pub struct HarnessRegistry { + harnesses: HashMap>, +} + +impl Default for HarnessRegistry { + fn default() -> Self { + Self::new() + } +} + +impl HarnessRegistry { + pub fn new() -> Self { + Self { + harnesses: HashMap::new(), + } + } + + pub fn register(&mut self, harness: Box) { + self.harnesses.insert(harness.name().to_string(), harness); + } + + /// List all registered harness names. + pub fn names(&self) -> Vec { + self.harnesses.keys().cloned().collect() + } + + /// List models from all registered harnesses. + pub async fn list_all_models(&self) -> Result> { + let mut all = Vec::new(); + for harness in self.harnesses.values() { + match harness.list_models().await { + Ok(models) => all.extend(models), + Err(e) => { + tracing::warn!(harness = harness.name(), error = %e, "failed to list models"); + } + } + } + Ok(all) + } + + /// Load a model on the specified harness. + pub async fn load_model(&self, spec: &ModelSpec) -> Result<()> { + let harness = self + .harnesses + .get(&spec.harness) + .ok_or_else(|| anyhow::anyhow!("unknown harness: {}", spec.harness))?; + harness.load_model(spec).await + } + + /// Unload a model. Tries each harness until one claims it. + pub async fn unload_model(&self, model_id: &str) -> Result<()> { + for harness in self.harnesses.values() { + match harness.list_models().await { + Ok(models) if models.iter().any(|m| m.id == model_id) => { + return harness.unload_model(model_id).await; + } + _ => continue, + } + } + anyhow::bail!("model '{model_id}' not found on any harness") + } + + /// Get the inference endpoint for a model. + pub async fn inference_endpoint(&self, model_id: &str) -> Option { + for harness in self.harnesses.values() { + if let Some(url) = harness.inference_endpoint(model_id).await { + return Some(url); + } + } + None + } + + /// Build a registry from harness configs. + pub fn from_configs(configs: &[HarnessConfig]) -> Self { + let mut registry = Self::new(); + for config in configs { + match config.name.as_str() { + "mistralrs" => { + if let Some(endpoint) = &config.endpoint { + registry.register(Box::new(mistralrs::MistralRsHarness::new( + endpoint.clone(), + config.systemd_unit.clone(), + ))); + } else { + tracing::warn!("mistralrs harness missing endpoint, skipping"); + } + } + other => { + tracing::warn!(harness = other, "unknown harness type, skipping"); + } + } + } + registry + } +} diff --git a/crates/neuron/src/lib.rs b/crates/neuron/src/lib.rs index 1903cd9..d860b25 100644 --- a/crates/neuron/src/lib.rs +++ b/crates/neuron/src/lib.rs @@ -1,4 +1,5 @@ pub mod api; +pub mod config; pub mod discovery; pub mod harness; pub mod health; diff --git a/crates/neuron/src/main.rs b/crates/neuron/src/main.rs index 0a5449f..b706876 100644 --- a/crates/neuron/src/main.rs +++ b/crates/neuron/src/main.rs @@ -1,8 +1,9 @@ use anyhow::Result; use clap::Parser; -use cortex_neuron::{api, discovery, health}; +use cortex_neuron::{api, config::NeuronConfig, discovery, harness::HarnessRegistry, health}; use std::sync::Arc; use std::time::Instant; +use tokio::sync::RwLock; use tracing_subscriber::EnvFilter; #[derive(Parser)] @@ -10,9 +11,13 @@ use tracing_subscriber::EnvFilter; #[command(about = "Per-node daemon for cortex inference clusters")] #[command(version)] struct Args { - /// Port to listen on. - #[arg(short, long, default_value = "9090")] - port: u16, + /// Port to listen on (overrides config file). + #[arg(short, long)] + port: Option, + + /// Path to the neuron config file. + #[arg(short, long, default_value = "neuron.toml")] + config: String, } #[tokio::main] @@ -25,16 +30,27 @@ async fn main() -> Result<()> { .init(); let args = Args::parse(); + + let cfg = NeuronConfig::load(&args.config).unwrap_or_else(|e| { + tracing::warn!(path = %args.config, error = %e, "config not found, using defaults"); + NeuronConfig::default() + }); + + let port = args.port.unwrap_or(cfg.port); let start_time = Instant::now(); tracing::info!("running hardware discovery"); - let discovery_result = discovery::discover_system().await?; + let mut discovery_result = discovery::discover_system().await?; tracing::info!( hostname = %discovery_result.hostname, devices = discovery_result.devices.len(), "discovery complete" ); + // Build harness registry from config. + let registry = HarnessRegistry::from_configs(&cfg.harnesses); + discovery_result.harnesses = registry.names(); + let health_cache = Arc::new(health::HealthCache::new()); health_cache .set_has_gpus(!discovery_result.devices.is_empty()) @@ -48,10 +64,11 @@ async fn main() -> Result<()> { let state = Arc::new(api::NeuronState { discovery: discovery_result, health_cache, + registry: RwLock::new(registry), }); let app = api::neuron_routes().with_state(state); - let addr: std::net::SocketAddr = format!("0.0.0.0:{}", args.port).parse()?; + let addr: std::net::SocketAddr = format!("0.0.0.0:{port}").parse()?; tracing::info!("cortex-neuron listening on {addr}"); let listener = tokio::net::TcpListener::bind(addr).await?; axum::serve(listener, app).await?; diff --git a/crates/neuron/tests/api.rs b/crates/neuron/tests/api.rs index 2025024..9f00d8f 100644 --- a/crates/neuron/tests/api.rs +++ b/crates/neuron/tests/api.rs @@ -1,20 +1,19 @@ -use cortex_core::discovery::{DeviceHealth, DeviceInfo, DiscoveryResponse, HealthResponse}; +use cortex_core::discovery::{DeviceInfo, DiscoveryResponse}; use cortex_neuron::api::{self, NeuronState}; +use cortex_neuron::harness::HarnessRegistry; use cortex_neuron::health::HealthCache; +use serde_json::json; use std::sync::Arc; +use tokio::sync::RwLock; -async fn spawn_neuron(discovery: DiscoveryResponse, health: HealthResponse) -> String { +async fn spawn_neuron(discovery: DiscoveryResponse) -> String { let health_cache = Arc::new(HealthCache::new()); - // Pre-populate the health cache by writing through the snapshot mechanism. - // HealthCache doesn't expose a direct setter, so we'll build one with - // the data already in place via the NeuronState. - // For testing, we use the cache as-is (uptime 0, empty devices) unless - // we need specific values — see test_health_endpoint. - let _ = health; // used below via a different approach + let registry = HarnessRegistry::new(); let state = Arc::new(NeuronState { discovery, health_cache, + registry: RwLock::new(registry), }); let app = api::neuron_routes().with_state(state); @@ -51,32 +50,9 @@ fn fake_discovery() -> DiscoveryResponse { } } -fn fake_health() -> HealthResponse { - HealthResponse { - uptime_secs: 0, - devices: vec![ - DeviceHealth { - index: 0, - vram_used_mb: 8192, - vram_free_mb: 24422, - utilization_pct: 45, - temp_c: 62, - }, - DeviceHealth { - index: 1, - vram_used_mb: 4096, - vram_free_mb: 28518, - utilization_pct: 30, - temp_c: 58, - }, - ], - } -} - #[tokio::test] async fn test_discovery_endpoint() { - let disc = fake_discovery(); - let url = spawn_neuron(disc, fake_health()).await; + let url = spawn_neuron(fake_discovery()).await; let client = reqwest::Client::new(); let resp = client @@ -89,20 +65,17 @@ async fn test_discovery_endpoint() { let body: serde_json::Value = resp.json().await.unwrap(); assert_eq!(body["hostname"], "test-node"); - assert_eq!(body["os"], "Linux"); assert_eq!(body["cuda_version"], "12.8"); - assert_eq!(body["driver_version"], "570.86.16"); let devices = body["devices"].as_array().unwrap(); assert_eq!(devices.len(), 2); assert_eq!(devices[0]["name"], "NVIDIA GeForce RTX 5090"); assert_eq!(devices[0]["vram_total_mb"], 32614); - assert_eq!(devices[0]["compute_capability"], "12.0"); } #[tokio::test] async fn test_health_endpoint() { - let url = spawn_neuron(fake_discovery(), fake_health()).await; + let url = spawn_neuron(fake_discovery()).await; let client = reqwest::Client::new(); let resp = client @@ -114,9 +87,7 @@ async fn test_health_endpoint() { assert_eq!(resp.status(), 200); let body: serde_json::Value = resp.json().await.unwrap(); - // HealthCache starts with uptime 0 and empty devices (no poller running in test). assert_eq!(body["uptime_secs"], 0); - assert!(body["devices"].as_array().unwrap().is_empty()); } #[tokio::test] @@ -130,14 +101,7 @@ async fn test_discovery_no_gpus() { devices: vec![], harnesses: vec![], }; - let url = spawn_neuron( - disc, - HealthResponse { - uptime_secs: 0, - devices: vec![], - }, - ) - .await; + let url = spawn_neuron(disc).await; let client = reqwest::Client::new(); let resp = client @@ -153,3 +117,133 @@ async fn test_discovery_no_gpus() { assert!(body["cuda_version"].is_null()); assert!(body["devices"].as_array().unwrap().is_empty()); } + +#[tokio::test] +async fn test_models_empty_registry() { + let url = spawn_neuron(fake_discovery()).await; + + let client = reqwest::Client::new(); + let resp = client + .get(format!("{url}/models")) + .send() + .await + .expect("request should succeed"); + + assert_eq!(resp.status(), 200); + + let body: serde_json::Value = resp.json().await.unwrap(); + assert!(body.as_array().unwrap().is_empty()); +} + +/// Spawn a mock mistral.rs backend and a neuron with the mistralrs harness +/// pointing at it, then test the full model lifecycle through neuron's API. +#[tokio::test] +async fn test_models_via_mistralrs_harness() { + use axum::routing::{get, post}; + use axum::{Json, Router}; + use cortex_core::harness::HarnessConfig; + use serde_json::Value; + + // Mock mistral.rs backend. + let mock_app = Router::new() + .route( + "/v1/models", + get(|| async { + Json(json!({ + "data": [ + {"id": "test-model", "status": "loaded"}, + {"id": "other-model", "status": "unloaded"} + ] + })) + }), + ) + .route( + "/v1/models/unload", + post(|Json(_body): Json| async { Json(json!({"status": "ok"})) }), + ) + .route( + "/v1/models/reload", + post(|Json(_body): Json| async { Json(json!({"status": "ok"})) }), + ); + + let mock_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let mock_addr = mock_listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(mock_listener, mock_app).await.unwrap(); + }); + let mock_url = format!("http://{mock_addr}"); + + // Build neuron with mistralrs harness pointing at mock. + let registry = HarnessRegistry::from_configs(&[HarnessConfig { + name: "mistralrs".into(), + endpoint: Some(mock_url.clone()), + systemd_unit: None, + }]); + + let health_cache = Arc::new(HealthCache::new()); + let state = Arc::new(NeuronState { + discovery: fake_discovery(), + health_cache, + registry: RwLock::new(registry), + }); + + let app = api::neuron_routes().with_state(state); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let neuron_addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + let neuron_url = format!("http://{neuron_addr}"); + + let client = reqwest::Client::new(); + + // GET /models — should return models from mock mistralrs. + let resp = client + .get(format!("{neuron_url}/models")) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + let models: Vec = resp.json().await.unwrap(); + assert_eq!(models.len(), 2); + assert_eq!(models[0]["id"], "test-model"); + assert_eq!(models[0]["harness"], "mistralrs"); + assert_eq!(models[0]["status"], "loaded"); + assert_eq!(models[1]["id"], "other-model"); + assert_eq!(models[1]["status"], "unloaded"); + + // GET /models/test-model/endpoint — should return mock URL. + let resp = client + .get(format!("{neuron_url}/models/test-model/endpoint")) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["url"], mock_url); + + // POST /models/unload — should succeed. + let resp = client + .post(format!("{neuron_url}/models/unload")) + .json(&json!({"model_id": "test-model"})) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["status"], "unloaded"); + + // POST /models/load — should succeed. + let resp = client + .post(format!("{neuron_url}/models/load")) + .json(&json!({ + "model_id": "test-model", + "harness": "mistralrs" + })) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["status"], "loaded"); +}