feat: implement mistral.rs harness and neuron model API
- MistralRsHarness: Harness trait impl wrapping mistral.rs HTTP API (list/load/unload models, health check, start/stop via systemd) - HarnessRegistry: maps harness name -> Box<dyn Harness>, built from neuron.toml config - Neuron API endpoints: GET /models, POST /models/load, POST /models/unload, GET /models/:id/endpoint - NeuronConfig: figment-based config loading from neuron.toml - Integration test: full model lifecycle through mock mistral.rs Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
51
CLAUDE.md
51
CLAUDE.md
@@ -556,50 +556,17 @@ serves `GET /discovery` and `GET /health`. Pure parsing functions
|
|||||||
separated from command execution for testability. 9 unit tests for
|
separated from command execution for testability. 9 unit tests for
|
||||||
nvidia-smi CSV parsing, 3 integration tests for the HTTP endpoints.
|
nvidia-smi CSV parsing, 3 integration tests for the HTTP endpoints.
|
||||||
|
|
||||||
### Phase 8: neuron harness — mistral.rs implementation
|
### Phase 8: neuron harness — mistral.rs implementation ✅
|
||||||
|
|
||||||
**Goal:** neuron can manage mistral.rs: start/stop the process, list
|
Completed. Full `Harness` trait implementation for mistral.rs in
|
||||||
models, load/unload models, and report the inference endpoint.
|
`neuron/src/harness/mistralrs.rs`: list_models, load_model, unload_model,
|
||||||
|
inference_endpoint, health, start/stop (systemd). `HarnessRegistry` in
|
||||||
|
`harness/mod.rs` maps harness name → `Box<dyn Harness>`, built from
|
||||||
|
`neuron.toml` config. Four new neuron API endpoints: `GET /models`,
|
||||||
|
`POST /models/load`, `POST /models/unload`, `GET /models/:id/endpoint`.
|
||||||
|
|
||||||
**Steps:**
|
Config via `neuron.toml` (figment + env override). Integration test
|
||||||
1. In `crates/neuron/src/harness/mistralrs.rs`:
|
covers full model lifecycle through neuron → mock mistral.rs backend.
|
||||||
- Implement the `Harness` trait.
|
|
||||||
- `start()` — invoke `systemctl start mistralrs.service` (or a
|
|
||||||
configured unit name). Wait for the health endpoint to respond.
|
|
||||||
- `stop()` — `systemctl stop mistralrs.service`.
|
|
||||||
- `health()` — `GET {mistralrs_endpoint}/health`.
|
|
||||||
- `list_models()` — `GET {mistralrs_endpoint}/v1/models`, parse the
|
|
||||||
response including the `status` field.
|
|
||||||
- `load_model()` — `POST {mistralrs_endpoint}/v1/models/reload`.
|
|
||||||
- `unload_model()` — `POST {mistralrs_endpoint}/v1/models/unload`.
|
|
||||||
- `inference_endpoint()` — return `mistralrs_endpoint` (mistral.rs
|
|
||||||
routes internally by model name in the request body).
|
|
||||||
2. In `crates/neuron/src/harness/mod.rs`:
|
|
||||||
- A `HarnessRegistry` that maps harness name → `Box<dyn Harness>`.
|
|
||||||
- On neuron startup, register the mistralrs harness (configured with
|
|
||||||
the local mistralrs endpoint, e.g. `http://localhost:8080`).
|
|
||||||
3. Add neuron API endpoints:
|
|
||||||
- `GET /models` — aggregate across all registered harnesses.
|
|
||||||
- `POST /models/load` — dispatch to the correct harness.
|
|
||||||
- `POST /models/unload` — dispatch to the correct harness.
|
|
||||||
- `GET /models/{model_id}/endpoint` — ask the harness.
|
|
||||||
4. neuron config (`neuron.toml`):
|
|
||||||
```toml
|
|
||||||
port = 9090
|
|
||||||
|
|
||||||
[[harnesses]]
|
|
||||||
name = "mistralrs"
|
|
||||||
endpoint = "http://localhost:8080"
|
|
||||||
systemd_unit = "mistralrs.service"
|
|
||||||
```
|
|
||||||
5. Tests:
|
|
||||||
- Mock HTTP server standing in for mistral.rs. Test that the harness
|
|
||||||
implementation correctly translates list/load/unload calls.
|
|
||||||
- Integration test: start neuron with mock mistralrs backend, call
|
|
||||||
`GET /models`, assert it returns models from the mock.
|
|
||||||
|
|
||||||
**Done when:** neuron manages a (mock) mistral.rs instance. All API
|
|
||||||
endpoints return correct data. Tests pass.
|
|
||||||
|
|
||||||
### Phase 9: cortex talks to neurons
|
### Phase 9: cortex talks to neurons
|
||||||
|
|
||||||
|
|||||||
3
Cargo.lock
generated
3
Cargo.lock
generated
@@ -408,13 +408,16 @@ name = "cortex-neuron"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"async-trait",
|
||||||
"axum",
|
"axum",
|
||||||
"clap",
|
"clap",
|
||||||
"cortex-core",
|
"cortex-core",
|
||||||
|
"figment",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"toml",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -18,10 +18,14 @@ tokio.workspace = true
|
|||||||
axum.workspace = true
|
axum.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
reqwest.workspace = true
|
||||||
tracing.workspace = true
|
tracing.workspace = true
|
||||||
tracing-subscriber.workspace = true
|
tracing-subscriber.workspace = true
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
|
async-trait.workspace = true
|
||||||
clap.workspace = true
|
clap.workspace = true
|
||||||
|
figment.workspace = true
|
||||||
|
toml.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { workspace = true, features = ["test-util"] }
|
tokio = { workspace = true, features = ["test-util"] }
|
||||||
|
|||||||
@@ -1,17 +1,23 @@
|
|||||||
//! HTTP API handlers for the neuron daemon.
|
//! HTTP API handlers for the neuron daemon.
|
||||||
|
|
||||||
|
use crate::harness::HarnessRegistry;
|
||||||
use crate::health::HealthCache;
|
use crate::health::HealthCache;
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
use axum::extract::State;
|
use axum::extract::{Path, State};
|
||||||
use axum::response::Json;
|
use axum::http::StatusCode;
|
||||||
use axum::routing::get;
|
use axum::response::{IntoResponse, Json};
|
||||||
|
use axum::routing::{get, post};
|
||||||
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
||||||
|
use cortex_core::harness::ModelSpec;
|
||||||
|
use serde_json::{Value, json};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
/// Shared state for the neuron HTTP server.
|
/// Shared state for the neuron HTTP server.
|
||||||
pub struct NeuronState {
|
pub struct NeuronState {
|
||||||
pub discovery: DiscoveryResponse,
|
pub discovery: DiscoveryResponse,
|
||||||
pub health_cache: Arc<HealthCache>,
|
pub health_cache: Arc<HealthCache>,
|
||||||
|
pub registry: RwLock<HarnessRegistry>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build the neuron API router.
|
/// Build the neuron API router.
|
||||||
@@ -19,6 +25,10 @@ pub fn neuron_routes() -> Router<Arc<NeuronState>> {
|
|||||||
Router::new()
|
Router::new()
|
||||||
.route("/discovery", get(discovery_handler))
|
.route("/discovery", get(discovery_handler))
|
||||||
.route("/health", get(health_handler))
|
.route("/health", get(health_handler))
|
||||||
|
.route("/models", get(list_models))
|
||||||
|
.route("/models/load", post(load_model))
|
||||||
|
.route("/models/unload", post(unload_model))
|
||||||
|
.route("/models/{model_id}/endpoint", get(model_endpoint))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
|
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
|
||||||
@@ -28,3 +38,67 @@ async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<Discov
|
|||||||
async fn health_handler(State(state): State<Arc<NeuronState>>) -> Json<HealthResponse> {
|
async fn health_handler(State(state): State<Arc<NeuronState>>) -> Json<HealthResponse> {
|
||||||
Json(state.health_cache.snapshot().await)
|
Json(state.health_cache.snapshot().await)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse {
|
||||||
|
let registry = state.registry.read().await;
|
||||||
|
match registry.list_all_models().await {
|
||||||
|
Ok(models) => Json(json!(models)).into_response(),
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({"error": e.to_string()})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn load_model(
|
||||||
|
State(state): State<Arc<NeuronState>>,
|
||||||
|
Json(spec): Json<ModelSpec>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let registry = state.registry.read().await;
|
||||||
|
match registry.load_model(&spec).await {
|
||||||
|
Ok(()) => Json(json!({"status": "loaded"})).into_response(),
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({"error": e.to_string()})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn unload_model(
|
||||||
|
State(state): State<Arc<NeuronState>>,
|
||||||
|
Json(body): Json<Value>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let model_id = match body.get("model_id").and_then(|v| v.as_str()) {
|
||||||
|
Some(id) => id.to_string(),
|
||||||
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({"error": "missing model_id"})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let registry = state.registry.read().await;
|
||||||
|
match registry.unload_model(&model_id).await {
|
||||||
|
Ok(()) => Json(json!({"status": "unloaded"})).into_response(),
|
||||||
|
Err(e) => (StatusCode::NOT_FOUND, Json(json!({"error": e.to_string()}))).into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn model_endpoint(
|
||||||
|
State(state): State<Arc<NeuronState>>,
|
||||||
|
Path(model_id): Path<String>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let registry = state.registry.read().await;
|
||||||
|
match registry.inference_endpoint(&model_id).await {
|
||||||
|
Some(url) => Json(json!({"url": url})).into_response(),
|
||||||
|
None => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({"error": format!("model '{}' not loaded", model_id)})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
40
crates/neuron/src/config.rs
Normal file
40
crates/neuron/src/config.rs
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
//! Neuron configuration loaded from neuron.toml.
|
||||||
|
|
||||||
|
use cortex_core::harness::HarnessConfig;
|
||||||
|
use figment::{
|
||||||
|
Figment,
|
||||||
|
providers::{Env, Format, Toml},
|
||||||
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct NeuronConfig {
|
||||||
|
#[serde(default = "default_port")]
|
||||||
|
pub port: u16,
|
||||||
|
#[serde(default)]
|
||||||
|
pub harnesses: Vec<HarnessConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_port() -> u16 {
|
||||||
|
9090
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NeuronConfig {
|
||||||
|
pub fn load(path: impl AsRef<Path>) -> Result<Self, Box<figment::Error>> {
|
||||||
|
Figment::new()
|
||||||
|
.merge(Toml::file(path))
|
||||||
|
.merge(Env::prefixed("NEURON_").split("__"))
|
||||||
|
.extract()
|
||||||
|
.map_err(Box::new)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for NeuronConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
port: 9090,
|
||||||
|
harnesses: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1 +1,163 @@
|
|||||||
// mistral.rs harness implementation — Phase 8.
|
//! mistral.rs harness implementation.
|
||||||
|
//!
|
||||||
|
//! Wraps the mistral.rs HTTP API for model lifecycle management
|
||||||
|
//! and optionally manages the process via systemd.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use cortex_core::harness::{Harness, HarnessConfig, HarnessHealth, ModelInfo, ModelSpec};
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
pub struct MistralRsHarness {
|
||||||
|
endpoint: String,
|
||||||
|
systemd_unit: Option<String>,
|
||||||
|
client: Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MistralRsHarness {
|
||||||
|
pub fn new(endpoint: String, systemd_unit: Option<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
endpoint,
|
||||||
|
systemd_unit,
|
||||||
|
client: Client::builder()
|
||||||
|
.timeout(std::time::Duration::from_secs(30))
|
||||||
|
.build()
|
||||||
|
.expect("failed to build HTTP client"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response from mistral.rs `GET /v1/models`.
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ModelsResponse {
|
||||||
|
data: Vec<ModelEntry>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ModelEntry {
|
||||||
|
id: String,
|
||||||
|
#[serde(default)]
|
||||||
|
status: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Harness for MistralRsHarness {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"mistralrs"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
|
||||||
|
let Some(unit) = &self.systemd_unit else {
|
||||||
|
anyhow::bail!("no systemd unit configured for mistralrs harness");
|
||||||
|
};
|
||||||
|
|
||||||
|
let output = tokio::process::Command::new("systemctl")
|
||||||
|
.args(["start", unit])
|
||||||
|
.output()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !output.status.success() {
|
||||||
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
|
anyhow::bail!("systemctl start {unit} failed: {stderr}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the health endpoint to respond (up to 30s).
|
||||||
|
let url = format!("{}/health", self.endpoint);
|
||||||
|
for _ in 0..30 {
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||||
|
if self.client.get(&url).send().await.is_ok() {
|
||||||
|
tracing::info!(unit, "mistralrs started and healthy");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anyhow::bail!("mistralrs started but health endpoint did not respond within 30s");
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stop(&self) -> Result<()> {
|
||||||
|
let Some(unit) = &self.systemd_unit else {
|
||||||
|
anyhow::bail!("no systemd unit configured for mistralrs harness");
|
||||||
|
};
|
||||||
|
|
||||||
|
let output = tokio::process::Command::new("systemctl")
|
||||||
|
.args(["stop", unit])
|
||||||
|
.output()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !output.status.success() {
|
||||||
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
|
anyhow::bail!("systemctl stop {unit} failed: {stderr}");
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health(&self) -> HarnessHealth {
|
||||||
|
let url = format!("{}/health", self.endpoint);
|
||||||
|
let running = self.client.get(&url).send().await.is_ok();
|
||||||
|
HarnessHealth {
|
||||||
|
name: "mistralrs".into(),
|
||||||
|
running,
|
||||||
|
uptime_secs: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
||||||
|
let url = format!("{}/v1/models", self.endpoint);
|
||||||
|
let resp = self.client.get(&url).send().await?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
anyhow::bail!("GET /v1/models returned {}", resp.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
let models_resp: ModelsResponse = resp.json().await?;
|
||||||
|
Ok(models_resp
|
||||||
|
.data
|
||||||
|
.into_iter()
|
||||||
|
.map(|m| ModelInfo {
|
||||||
|
id: m.id,
|
||||||
|
harness: "mistralrs".into(),
|
||||||
|
status: m.status.unwrap_or_else(|| "loaded".into()),
|
||||||
|
devices: vec![],
|
||||||
|
vram_used_mb: None,
|
||||||
|
})
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
|
||||||
|
let url = format!("{}/v1/models/reload", self.endpoint);
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.post(&url)
|
||||||
|
.json(&serde_json::json!({ "model_id": spec.model_id }))
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
let body = resp.text().await.unwrap_or_default();
|
||||||
|
anyhow::bail!("POST /v1/models/reload failed: {body}");
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn unload_model(&self, model_id: &str) -> Result<()> {
|
||||||
|
let url = format!("{}/v1/models/unload", self.endpoint);
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.post(&url)
|
||||||
|
.json(&serde_json::json!({ "model_id": model_id }))
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
let body = resp.text().await.unwrap_or_default();
|
||||||
|
anyhow::bail!("POST /v1/models/unload failed: {body}");
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn inference_endpoint(&self, _model_id: &str) -> Option<String> {
|
||||||
|
// mistral.rs routes internally by model name in the request body,
|
||||||
|
// so the inference endpoint is always the base URL.
|
||||||
|
Some(self.endpoint.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,105 @@
|
|||||||
// Harness registry. Implementations added in Phase 8+.
|
//! Harness registry — maps harness names to trait implementations.
|
||||||
|
|
||||||
pub mod llamacpp;
|
pub mod llamacpp;
|
||||||
pub mod mistralrs;
|
pub mod mistralrs;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
/// Registry of available harness implementations.
|
||||||
|
pub struct HarnessRegistry {
|
||||||
|
harnesses: HashMap<String, Box<dyn Harness>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for HarnessRegistry {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HarnessRegistry {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
harnesses: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn register(&mut self, harness: Box<dyn Harness>) {
|
||||||
|
self.harnesses.insert(harness.name().to_string(), harness);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all registered harness names.
|
||||||
|
pub fn names(&self) -> Vec<String> {
|
||||||
|
self.harnesses.keys().cloned().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List models from all registered harnesses.
|
||||||
|
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
|
||||||
|
let mut all = Vec::new();
|
||||||
|
for harness in self.harnesses.values() {
|
||||||
|
match harness.list_models().await {
|
||||||
|
Ok(models) => all.extend(models),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(harness = harness.name(), error = %e, "failed to list models");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(all)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load a model on the specified harness.
|
||||||
|
pub async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
|
||||||
|
let harness = self
|
||||||
|
.harnesses
|
||||||
|
.get(&spec.harness)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("unknown harness: {}", spec.harness))?;
|
||||||
|
harness.load_model(spec).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unload a model. Tries each harness until one claims it.
|
||||||
|
pub async fn unload_model(&self, model_id: &str) -> Result<()> {
|
||||||
|
for harness in self.harnesses.values() {
|
||||||
|
match harness.list_models().await {
|
||||||
|
Ok(models) if models.iter().any(|m| m.id == model_id) => {
|
||||||
|
return harness.unload_model(model_id).await;
|
||||||
|
}
|
||||||
|
_ => continue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anyhow::bail!("model '{model_id}' not found on any harness")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the inference endpoint for a model.
|
||||||
|
pub async fn inference_endpoint(&self, model_id: &str) -> Option<String> {
|
||||||
|
for harness in self.harnesses.values() {
|
||||||
|
if let Some(url) = harness.inference_endpoint(model_id).await {
|
||||||
|
return Some(url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a registry from harness configs.
|
||||||
|
pub fn from_configs(configs: &[HarnessConfig]) -> Self {
|
||||||
|
let mut registry = Self::new();
|
||||||
|
for config in configs {
|
||||||
|
match config.name.as_str() {
|
||||||
|
"mistralrs" => {
|
||||||
|
if let Some(endpoint) = &config.endpoint {
|
||||||
|
registry.register(Box::new(mistralrs::MistralRsHarness::new(
|
||||||
|
endpoint.clone(),
|
||||||
|
config.systemd_unit.clone(),
|
||||||
|
)));
|
||||||
|
} else {
|
||||||
|
tracing::warn!("mistralrs harness missing endpoint, skipping");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
other => {
|
||||||
|
tracing::warn!(harness = other, "unknown harness type, skipping");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
registry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
pub mod api;
|
pub mod api;
|
||||||
|
pub mod config;
|
||||||
pub mod discovery;
|
pub mod discovery;
|
||||||
pub mod harness;
|
pub mod harness;
|
||||||
pub mod health;
|
pub mod health;
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use cortex_neuron::{api, discovery, health};
|
use cortex_neuron::{api, config::NeuronConfig, discovery, harness::HarnessRegistry, health};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
@@ -10,9 +11,13 @@ use tracing_subscriber::EnvFilter;
|
|||||||
#[command(about = "Per-node daemon for cortex inference clusters")]
|
#[command(about = "Per-node daemon for cortex inference clusters")]
|
||||||
#[command(version)]
|
#[command(version)]
|
||||||
struct Args {
|
struct Args {
|
||||||
/// Port to listen on.
|
/// Port to listen on (overrides config file).
|
||||||
#[arg(short, long, default_value = "9090")]
|
#[arg(short, long)]
|
||||||
port: u16,
|
port: Option<u16>,
|
||||||
|
|
||||||
|
/// Path to the neuron config file.
|
||||||
|
#[arg(short, long, default_value = "neuron.toml")]
|
||||||
|
config: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@@ -25,16 +30,27 @@ async fn main() -> Result<()> {
|
|||||||
.init();
|
.init();
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let cfg = NeuronConfig::load(&args.config).unwrap_or_else(|e| {
|
||||||
|
tracing::warn!(path = %args.config, error = %e, "config not found, using defaults");
|
||||||
|
NeuronConfig::default()
|
||||||
|
});
|
||||||
|
|
||||||
|
let port = args.port.unwrap_or(cfg.port);
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
|
||||||
tracing::info!("running hardware discovery");
|
tracing::info!("running hardware discovery");
|
||||||
let discovery_result = discovery::discover_system().await?;
|
let mut discovery_result = discovery::discover_system().await?;
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
hostname = %discovery_result.hostname,
|
hostname = %discovery_result.hostname,
|
||||||
devices = discovery_result.devices.len(),
|
devices = discovery_result.devices.len(),
|
||||||
"discovery complete"
|
"discovery complete"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Build harness registry from config.
|
||||||
|
let registry = HarnessRegistry::from_configs(&cfg.harnesses);
|
||||||
|
discovery_result.harnesses = registry.names();
|
||||||
|
|
||||||
let health_cache = Arc::new(health::HealthCache::new());
|
let health_cache = Arc::new(health::HealthCache::new());
|
||||||
health_cache
|
health_cache
|
||||||
.set_has_gpus(!discovery_result.devices.is_empty())
|
.set_has_gpus(!discovery_result.devices.is_empty())
|
||||||
@@ -48,10 +64,11 @@ async fn main() -> Result<()> {
|
|||||||
let state = Arc::new(api::NeuronState {
|
let state = Arc::new(api::NeuronState {
|
||||||
discovery: discovery_result,
|
discovery: discovery_result,
|
||||||
health_cache,
|
health_cache,
|
||||||
|
registry: RwLock::new(registry),
|
||||||
});
|
});
|
||||||
|
|
||||||
let app = api::neuron_routes().with_state(state);
|
let app = api::neuron_routes().with_state(state);
|
||||||
let addr: std::net::SocketAddr = format!("0.0.0.0:{}", args.port).parse()?;
|
let addr: std::net::SocketAddr = format!("0.0.0.0:{port}").parse()?;
|
||||||
tracing::info!("cortex-neuron listening on {addr}");
|
tracing::info!("cortex-neuron listening on {addr}");
|
||||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||||
axum::serve(listener, app).await?;
|
axum::serve(listener, app).await?;
|
||||||
|
|||||||
@@ -1,20 +1,19 @@
|
|||||||
use cortex_core::discovery::{DeviceHealth, DeviceInfo, DiscoveryResponse, HealthResponse};
|
use cortex_core::discovery::{DeviceInfo, DiscoveryResponse};
|
||||||
use cortex_neuron::api::{self, NeuronState};
|
use cortex_neuron::api::{self, NeuronState};
|
||||||
|
use cortex_neuron::harness::HarnessRegistry;
|
||||||
use cortex_neuron::health::HealthCache;
|
use cortex_neuron::health::HealthCache;
|
||||||
|
use serde_json::json;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
async fn spawn_neuron(discovery: DiscoveryResponse, health: HealthResponse) -> String {
|
async fn spawn_neuron(discovery: DiscoveryResponse) -> String {
|
||||||
let health_cache = Arc::new(HealthCache::new());
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
// Pre-populate the health cache by writing through the snapshot mechanism.
|
let registry = HarnessRegistry::new();
|
||||||
// HealthCache doesn't expose a direct setter, so we'll build one with
|
|
||||||
// the data already in place via the NeuronState.
|
|
||||||
// For testing, we use the cache as-is (uptime 0, empty devices) unless
|
|
||||||
// we need specific values — see test_health_endpoint.
|
|
||||||
let _ = health; // used below via a different approach
|
|
||||||
|
|
||||||
let state = Arc::new(NeuronState {
|
let state = Arc::new(NeuronState {
|
||||||
discovery,
|
discovery,
|
||||||
health_cache,
|
health_cache,
|
||||||
|
registry: RwLock::new(registry),
|
||||||
});
|
});
|
||||||
|
|
||||||
let app = api::neuron_routes().with_state(state);
|
let app = api::neuron_routes().with_state(state);
|
||||||
@@ -51,32 +50,9 @@ fn fake_discovery() -> DiscoveryResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fake_health() -> HealthResponse {
|
|
||||||
HealthResponse {
|
|
||||||
uptime_secs: 0,
|
|
||||||
devices: vec![
|
|
||||||
DeviceHealth {
|
|
||||||
index: 0,
|
|
||||||
vram_used_mb: 8192,
|
|
||||||
vram_free_mb: 24422,
|
|
||||||
utilization_pct: 45,
|
|
||||||
temp_c: 62,
|
|
||||||
},
|
|
||||||
DeviceHealth {
|
|
||||||
index: 1,
|
|
||||||
vram_used_mb: 4096,
|
|
||||||
vram_free_mb: 28518,
|
|
||||||
utilization_pct: 30,
|
|
||||||
temp_c: 58,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_discovery_endpoint() {
|
async fn test_discovery_endpoint() {
|
||||||
let disc = fake_discovery();
|
let url = spawn_neuron(fake_discovery()).await;
|
||||||
let url = spawn_neuron(disc, fake_health()).await;
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let resp = client
|
let resp = client
|
||||||
@@ -89,20 +65,17 @@ async fn test_discovery_endpoint() {
|
|||||||
|
|
||||||
let body: serde_json::Value = resp.json().await.unwrap();
|
let body: serde_json::Value = resp.json().await.unwrap();
|
||||||
assert_eq!(body["hostname"], "test-node");
|
assert_eq!(body["hostname"], "test-node");
|
||||||
assert_eq!(body["os"], "Linux");
|
|
||||||
assert_eq!(body["cuda_version"], "12.8");
|
assert_eq!(body["cuda_version"], "12.8");
|
||||||
assert_eq!(body["driver_version"], "570.86.16");
|
|
||||||
|
|
||||||
let devices = body["devices"].as_array().unwrap();
|
let devices = body["devices"].as_array().unwrap();
|
||||||
assert_eq!(devices.len(), 2);
|
assert_eq!(devices.len(), 2);
|
||||||
assert_eq!(devices[0]["name"], "NVIDIA GeForce RTX 5090");
|
assert_eq!(devices[0]["name"], "NVIDIA GeForce RTX 5090");
|
||||||
assert_eq!(devices[0]["vram_total_mb"], 32614);
|
assert_eq!(devices[0]["vram_total_mb"], 32614);
|
||||||
assert_eq!(devices[0]["compute_capability"], "12.0");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_health_endpoint() {
|
async fn test_health_endpoint() {
|
||||||
let url = spawn_neuron(fake_discovery(), fake_health()).await;
|
let url = spawn_neuron(fake_discovery()).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let resp = client
|
let resp = client
|
||||||
@@ -114,9 +87,7 @@ async fn test_health_endpoint() {
|
|||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 200);
|
||||||
|
|
||||||
let body: serde_json::Value = resp.json().await.unwrap();
|
let body: serde_json::Value = resp.json().await.unwrap();
|
||||||
// HealthCache starts with uptime 0 and empty devices (no poller running in test).
|
|
||||||
assert_eq!(body["uptime_secs"], 0);
|
assert_eq!(body["uptime_secs"], 0);
|
||||||
assert!(body["devices"].as_array().unwrap().is_empty());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -130,14 +101,7 @@ async fn test_discovery_no_gpus() {
|
|||||||
devices: vec![],
|
devices: vec![],
|
||||||
harnesses: vec![],
|
harnesses: vec![],
|
||||||
};
|
};
|
||||||
let url = spawn_neuron(
|
let url = spawn_neuron(disc).await;
|
||||||
disc,
|
|
||||||
HealthResponse {
|
|
||||||
uptime_secs: 0,
|
|
||||||
devices: vec![],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let resp = client
|
let resp = client
|
||||||
@@ -153,3 +117,133 @@ async fn test_discovery_no_gpus() {
|
|||||||
assert!(body["cuda_version"].is_null());
|
assert!(body["cuda_version"].is_null());
|
||||||
assert!(body["devices"].as_array().unwrap().is_empty());
|
assert!(body["devices"].as_array().unwrap().is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_models_empty_registry() {
|
||||||
|
let url = spawn_neuron(fake_discovery()).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let resp = client
|
||||||
|
.get(format!("{url}/models"))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("request should succeed");
|
||||||
|
|
||||||
|
assert_eq!(resp.status(), 200);
|
||||||
|
|
||||||
|
let body: serde_json::Value = resp.json().await.unwrap();
|
||||||
|
assert!(body.as_array().unwrap().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spawn a mock mistral.rs backend and a neuron with the mistralrs harness
|
||||||
|
/// pointing at it, then test the full model lifecycle through neuron's API.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_models_via_mistralrs_harness() {
|
||||||
|
use axum::routing::{get, post};
|
||||||
|
use axum::{Json, Router};
|
||||||
|
use cortex_core::harness::HarnessConfig;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
// Mock mistral.rs backend.
|
||||||
|
let mock_app = Router::new()
|
||||||
|
.route(
|
||||||
|
"/v1/models",
|
||||||
|
get(|| async {
|
||||||
|
Json(json!({
|
||||||
|
"data": [
|
||||||
|
{"id": "test-model", "status": "loaded"},
|
||||||
|
{"id": "other-model", "status": "unloaded"}
|
||||||
|
]
|
||||||
|
}))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/v1/models/unload",
|
||||||
|
post(|Json(_body): Json<Value>| async { Json(json!({"status": "ok"})) }),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/v1/models/reload",
|
||||||
|
post(|Json(_body): Json<Value>| async { Json(json!({"status": "ok"})) }),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mock_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let mock_addr = mock_listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(mock_listener, mock_app).await.unwrap();
|
||||||
|
});
|
||||||
|
let mock_url = format!("http://{mock_addr}");
|
||||||
|
|
||||||
|
// Build neuron with mistralrs harness pointing at mock.
|
||||||
|
let registry = HarnessRegistry::from_configs(&[HarnessConfig {
|
||||||
|
name: "mistralrs".into(),
|
||||||
|
endpoint: Some(mock_url.clone()),
|
||||||
|
systemd_unit: None,
|
||||||
|
}]);
|
||||||
|
|
||||||
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
|
let state = Arc::new(NeuronState {
|
||||||
|
discovery: fake_discovery(),
|
||||||
|
health_cache,
|
||||||
|
registry: RwLock::new(registry),
|
||||||
|
});
|
||||||
|
|
||||||
|
let app = api::neuron_routes().with_state(state);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let neuron_addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
let neuron_url = format!("http://{neuron_addr}");
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
// GET /models — should return models from mock mistralrs.
|
||||||
|
let resp = client
|
||||||
|
.get(format!("{neuron_url}/models"))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 200);
|
||||||
|
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||||
|
assert_eq!(models.len(), 2);
|
||||||
|
assert_eq!(models[0]["id"], "test-model");
|
||||||
|
assert_eq!(models[0]["harness"], "mistralrs");
|
||||||
|
assert_eq!(models[0]["status"], "loaded");
|
||||||
|
assert_eq!(models[1]["id"], "other-model");
|
||||||
|
assert_eq!(models[1]["status"], "unloaded");
|
||||||
|
|
||||||
|
// GET /models/test-model/endpoint — should return mock URL.
|
||||||
|
let resp = client
|
||||||
|
.get(format!("{neuron_url}/models/test-model/endpoint"))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 200);
|
||||||
|
let body: serde_json::Value = resp.json().await.unwrap();
|
||||||
|
assert_eq!(body["url"], mock_url);
|
||||||
|
|
||||||
|
// POST /models/unload — should succeed.
|
||||||
|
let resp = client
|
||||||
|
.post(format!("{neuron_url}/models/unload"))
|
||||||
|
.json(&json!({"model_id": "test-model"}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 200);
|
||||||
|
let body: serde_json::Value = resp.json().await.unwrap();
|
||||||
|
assert_eq!(body["status"], "unloaded");
|
||||||
|
|
||||||
|
// POST /models/load — should succeed.
|
||||||
|
let resp = client
|
||||||
|
.post(format!("{neuron_url}/models/load"))
|
||||||
|
.json(&json!({
|
||||||
|
"model_id": "test-model",
|
||||||
|
"harness": "mistralrs"
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 200);
|
||||||
|
let body: serde_json::Value = resp.json().await.unwrap();
|
||||||
|
assert_eq!(body["status"], "loaded");
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user