feat: implement mistral.rs harness and neuron model API
All checks were successful
CI / Format, lint, build, test (push) Successful in 2m30s
CI / Build SRPM (push) Has been skipped
CI / Publish to COPR (push) Has been skipped

- MistralRsHarness: Harness trait impl wrapping mistral.rs HTTP API
  (list/load/unload models, health check, start/stop via systemd)
- HarnessRegistry: maps harness name -> Box<dyn Harness>, built from
  neuron.toml config
- Neuron API endpoints: GET /models, POST /models/load,
  POST /models/unload, GET /models/:id/endpoint
- NeuronConfig: figment-based config loading from neuron.toml
- Integration test: full model lifecycle through mock mistral.rs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-15 14:29:42 +03:00
parent 6dc717ebcd
commit 26e5e7ead8
10 changed files with 562 additions and 99 deletions

View File

@@ -1,17 +1,23 @@
//! HTTP API handlers for the neuron daemon.
use crate::harness::HarnessRegistry;
use crate::health::HealthCache;
use axum::Router;
use axum::extract::State;
use axum::response::Json;
use axum::routing::get;
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json};
use axum::routing::{get, post};
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
use cortex_core::harness::ModelSpec;
use serde_json::{Value, json};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Shared state for the neuron HTTP server.
pub struct NeuronState {
pub discovery: DiscoveryResponse,
pub health_cache: Arc<HealthCache>,
pub registry: RwLock<HarnessRegistry>,
}
/// Build the neuron API router.
@@ -19,6 +25,10 @@ pub fn neuron_routes() -> Router<Arc<NeuronState>> {
Router::new()
.route("/discovery", get(discovery_handler))
.route("/health", get(health_handler))
.route("/models", get(list_models))
.route("/models/load", post(load_model))
.route("/models/unload", post(unload_model))
.route("/models/{model_id}/endpoint", get(model_endpoint))
}
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
@@ -28,3 +38,67 @@ async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<Discov
async fn health_handler(State(state): State<Arc<NeuronState>>) -> Json<HealthResponse> {
Json(state.health_cache.snapshot().await)
}
async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse {
let registry = state.registry.read().await;
match registry.list_all_models().await {
Ok(models) => Json(json!(models)).into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": e.to_string()})),
)
.into_response(),
}
}
async fn load_model(
State(state): State<Arc<NeuronState>>,
Json(spec): Json<ModelSpec>,
) -> impl IntoResponse {
let registry = state.registry.read().await;
match registry.load_model(&spec).await {
Ok(()) => Json(json!({"status": "loaded"})).into_response(),
Err(e) => (
StatusCode::BAD_REQUEST,
Json(json!({"error": e.to_string()})),
)
.into_response(),
}
}
async fn unload_model(
State(state): State<Arc<NeuronState>>,
Json(body): Json<Value>,
) -> impl IntoResponse {
let model_id = match body.get("model_id").and_then(|v| v.as_str()) {
Some(id) => id.to_string(),
None => {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "missing model_id"})),
)
.into_response();
}
};
let registry = state.registry.read().await;
match registry.unload_model(&model_id).await {
Ok(()) => Json(json!({"status": "unloaded"})).into_response(),
Err(e) => (StatusCode::NOT_FOUND, Json(json!({"error": e.to_string()}))).into_response(),
}
}
async fn model_endpoint(
State(state): State<Arc<NeuronState>>,
Path(model_id): Path<String>,
) -> impl IntoResponse {
let registry = state.registry.read().await;
match registry.inference_endpoint(&model_id).await {
Some(url) => Json(json!({"url": url})).into_response(),
None => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("model '{}' not loaded", model_id)})),
)
.into_response(),
}
}

View File

@@ -0,0 +1,40 @@
//! Neuron configuration loaded from neuron.toml.
use cortex_core::harness::HarnessConfig;
use figment::{
Figment,
providers::{Env, Format, Toml},
};
use serde::{Deserialize, Serialize};
use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NeuronConfig {
#[serde(default = "default_port")]
pub port: u16,
#[serde(default)]
pub harnesses: Vec<HarnessConfig>,
}
fn default_port() -> u16 {
9090
}
impl NeuronConfig {
pub fn load(path: impl AsRef<Path>) -> Result<Self, Box<figment::Error>> {
Figment::new()
.merge(Toml::file(path))
.merge(Env::prefixed("NEURON_").split("__"))
.extract()
.map_err(Box::new)
}
}
impl Default for NeuronConfig {
fn default() -> Self {
Self {
port: 9090,
harnesses: vec![],
}
}
}

View File

@@ -1 +1,163 @@
// mistral.rs harness implementation — Phase 8.
//! mistral.rs harness implementation.
//!
//! Wraps the mistral.rs HTTP API for model lifecycle management
//! and optionally manages the process via systemd.
use anyhow::Result;
use async_trait::async_trait;
use cortex_core::harness::{Harness, HarnessConfig, HarnessHealth, ModelInfo, ModelSpec};
use reqwest::Client;
use serde::Deserialize;
pub struct MistralRsHarness {
endpoint: String,
systemd_unit: Option<String>,
client: Client,
}
impl MistralRsHarness {
pub fn new(endpoint: String, systemd_unit: Option<String>) -> Self {
Self {
endpoint,
systemd_unit,
client: Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("failed to build HTTP client"),
}
}
}
/// Response from mistral.rs `GET /v1/models`.
#[derive(Debug, Deserialize)]
struct ModelsResponse {
data: Vec<ModelEntry>,
}
#[derive(Debug, Deserialize)]
struct ModelEntry {
id: String,
#[serde(default)]
status: Option<String>,
}
#[async_trait]
impl Harness for MistralRsHarness {
fn name(&self) -> &str {
"mistralrs"
}
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
let Some(unit) = &self.systemd_unit else {
anyhow::bail!("no systemd unit configured for mistralrs harness");
};
let output = tokio::process::Command::new("systemctl")
.args(["start", unit])
.output()
.await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("systemctl start {unit} failed: {stderr}");
}
// Wait for the health endpoint to respond (up to 30s).
let url = format!("{}/health", self.endpoint);
for _ in 0..30 {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
if self.client.get(&url).send().await.is_ok() {
tracing::info!(unit, "mistralrs started and healthy");
return Ok(());
}
}
anyhow::bail!("mistralrs started but health endpoint did not respond within 30s");
}
async fn stop(&self) -> Result<()> {
let Some(unit) = &self.systemd_unit else {
anyhow::bail!("no systemd unit configured for mistralrs harness");
};
let output = tokio::process::Command::new("systemctl")
.args(["stop", unit])
.output()
.await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("systemctl stop {unit} failed: {stderr}");
}
Ok(())
}
async fn health(&self) -> HarnessHealth {
let url = format!("{}/health", self.endpoint);
let running = self.client.get(&url).send().await.is_ok();
HarnessHealth {
name: "mistralrs".into(),
running,
uptime_secs: None,
}
}
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
let url = format!("{}/v1/models", self.endpoint);
let resp = self.client.get(&url).send().await?;
if !resp.status().is_success() {
anyhow::bail!("GET /v1/models returned {}", resp.status());
}
let models_resp: ModelsResponse = resp.json().await?;
Ok(models_resp
.data
.into_iter()
.map(|m| ModelInfo {
id: m.id,
harness: "mistralrs".into(),
status: m.status.unwrap_or_else(|| "loaded".into()),
devices: vec![],
vram_used_mb: None,
})
.collect())
}
async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
let url = format!("{}/v1/models/reload", self.endpoint);
let resp = self
.client
.post(&url)
.json(&serde_json::json!({ "model_id": spec.model_id }))
.send()
.await?;
if !resp.status().is_success() {
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("POST /v1/models/reload failed: {body}");
}
Ok(())
}
async fn unload_model(&self, model_id: &str) -> Result<()> {
let url = format!("{}/v1/models/unload", self.endpoint);
let resp = self
.client
.post(&url)
.json(&serde_json::json!({ "model_id": model_id }))
.send()
.await?;
if !resp.status().is_success() {
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("POST /v1/models/unload failed: {body}");
}
Ok(())
}
async fn inference_endpoint(&self, _model_id: &str) -> Option<String> {
// mistral.rs routes internally by model name in the request body,
// so the inference endpoint is always the base URL.
Some(self.endpoint.clone())
}
}

View File

@@ -1,4 +1,105 @@
// Harness registry. Implementations added in Phase 8+.
//! Harness registry — maps harness names to trait implementations.
pub mod llamacpp;
pub mod mistralrs;
use anyhow::Result;
use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};
use std::collections::HashMap;
/// Registry of available harness implementations.
pub struct HarnessRegistry {
harnesses: HashMap<String, Box<dyn Harness>>,
}
impl Default for HarnessRegistry {
fn default() -> Self {
Self::new()
}
}
impl HarnessRegistry {
pub fn new() -> Self {
Self {
harnesses: HashMap::new(),
}
}
pub fn register(&mut self, harness: Box<dyn Harness>) {
self.harnesses.insert(harness.name().to_string(), harness);
}
/// List all registered harness names.
pub fn names(&self) -> Vec<String> {
self.harnesses.keys().cloned().collect()
}
/// List models from all registered harnesses.
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
let mut all = Vec::new();
for harness in self.harnesses.values() {
match harness.list_models().await {
Ok(models) => all.extend(models),
Err(e) => {
tracing::warn!(harness = harness.name(), error = %e, "failed to list models");
}
}
}
Ok(all)
}
/// Load a model on the specified harness.
pub async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
let harness = self
.harnesses
.get(&spec.harness)
.ok_or_else(|| anyhow::anyhow!("unknown harness: {}", spec.harness))?;
harness.load_model(spec).await
}
/// Unload a model. Tries each harness until one claims it.
pub async fn unload_model(&self, model_id: &str) -> Result<()> {
for harness in self.harnesses.values() {
match harness.list_models().await {
Ok(models) if models.iter().any(|m| m.id == model_id) => {
return harness.unload_model(model_id).await;
}
_ => continue,
}
}
anyhow::bail!("model '{model_id}' not found on any harness")
}
/// Get the inference endpoint for a model.
pub async fn inference_endpoint(&self, model_id: &str) -> Option<String> {
for harness in self.harnesses.values() {
if let Some(url) = harness.inference_endpoint(model_id).await {
return Some(url);
}
}
None
}
/// Build a registry from harness configs.
pub fn from_configs(configs: &[HarnessConfig]) -> Self {
let mut registry = Self::new();
for config in configs {
match config.name.as_str() {
"mistralrs" => {
if let Some(endpoint) = &config.endpoint {
registry.register(Box::new(mistralrs::MistralRsHarness::new(
endpoint.clone(),
config.systemd_unit.clone(),
)));
} else {
tracing::warn!("mistralrs harness missing endpoint, skipping");
}
}
other => {
tracing::warn!(harness = other, "unknown harness type, skipping");
}
}
}
registry
}
}

View File

@@ -1,4 +1,5 @@
pub mod api;
pub mod config;
pub mod discovery;
pub mod harness;
pub mod health;

View File

@@ -1,8 +1,9 @@
use anyhow::Result;
use clap::Parser;
use cortex_neuron::{api, discovery, health};
use cortex_neuron::{api, config::NeuronConfig, discovery, harness::HarnessRegistry, health};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing_subscriber::EnvFilter;
#[derive(Parser)]
@@ -10,9 +11,13 @@ use tracing_subscriber::EnvFilter;
#[command(about = "Per-node daemon for cortex inference clusters")]
#[command(version)]
struct Args {
/// Port to listen on.
#[arg(short, long, default_value = "9090")]
port: u16,
/// Port to listen on (overrides config file).
#[arg(short, long)]
port: Option<u16>,
/// Path to the neuron config file.
#[arg(short, long, default_value = "neuron.toml")]
config: String,
}
#[tokio::main]
@@ -25,16 +30,27 @@ async fn main() -> Result<()> {
.init();
let args = Args::parse();
let cfg = NeuronConfig::load(&args.config).unwrap_or_else(|e| {
tracing::warn!(path = %args.config, error = %e, "config not found, using defaults");
NeuronConfig::default()
});
let port = args.port.unwrap_or(cfg.port);
let start_time = Instant::now();
tracing::info!("running hardware discovery");
let discovery_result = discovery::discover_system().await?;
let mut discovery_result = discovery::discover_system().await?;
tracing::info!(
hostname = %discovery_result.hostname,
devices = discovery_result.devices.len(),
"discovery complete"
);
// Build harness registry from config.
let registry = HarnessRegistry::from_configs(&cfg.harnesses);
discovery_result.harnesses = registry.names();
let health_cache = Arc::new(health::HealthCache::new());
health_cache
.set_has_gpus(!discovery_result.devices.is_empty())
@@ -48,10 +64,11 @@ async fn main() -> Result<()> {
let state = Arc::new(api::NeuronState {
discovery: discovery_result,
health_cache,
registry: RwLock::new(registry),
});
let app = api::neuron_routes().with_state(state);
let addr: std::net::SocketAddr = format!("0.0.0.0:{}", args.port).parse()?;
let addr: std::net::SocketAddr = format!("0.0.0.0:{port}").parse()?;
tracing::info!("cortex-neuron listening on {addr}");
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;