feat(neuron): OpenAI-compatible non-streaming chat completion
Stage 3 of the candle-native pivot. neuron now serves POST /v1/chat/completions backed by candle's quantized_qwen3 forward pass on a per-model serialised generation loop, returning the standard OpenAI ChatCompletionResponse envelope. Pipeline per request: - Look up the LoadedModel by request.model (404 if absent). - Apply the Qwen3 chat template across all messages. - Tokenize, then spawn_blocking onto tokio's blocking pool to acquire the per-model arch lock and run prefill + greedy/temperature/top-p sampling via LogitsProcessor. - Stop on <|im_end|>/<|endoftext|> EOS or max_tokens (finish_reason "stop" vs "length"). - Decode with skip_special_tokens=true, build OpenAI response with prompt/completion/total usage counts. Supporting changes: - HarnessRegistry now stores Arc<dyn Harness> and caches a typed Arc<CandleHarness> so inference routes bypass dyn-Trait dispatch. - LoadedModel.arch becomes Arc<Mutex<ModelArch>> so the lock guard can be moved into spawn_blocking. - NeuronState gains an Option<Arc<CandleHarness>> field for the new inference route. - Typed InferenceError lets the handler map ModelNotLoaded → 404 and other failures → 500 without string-matching anyhow messages. - stream=true returns 501 until Stage 4 wires up SSE. - Two leftover mistral.rs string references in proxy.rs and cortex-cli (missed during the Stage 1 sweep) are corrected here. Three new default-feature tests cover the no-candle 503, model-not- loaded 404, and stream=true 501 paths. The cuda-integration test from Stage 2 still covers real load/unload; a streaming-feature gated test exercising actual generation will arrive with Stage 4. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2105,6 +2105,7 @@ dependencies = [
|
|||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"thiserror 2.0.18",
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"toml",
|
"toml",
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ use tracing_subscriber::EnvFilter;
|
|||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "cortex")]
|
#[command(name = "cortex")]
|
||||||
#[command(about = "Unified inference gateway for multi-node mistral.rs clusters")]
|
#[command(about = "Unified inference gateway for multi-node GPU clusters")]
|
||||||
#[command(version)]
|
#[command(version)]
|
||||||
struct Cli {
|
struct Cli {
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//! Streaming HTTP reverse proxy to mistral.rs backends.
|
//! Streaming HTTP reverse proxy to neuron backends.
|
||||||
//!
|
//!
|
||||||
//! For streaming requests, SSE chunks are forwarded as they arrive.
|
//! For streaming requests, SSE chunks are forwarded as they arrive.
|
||||||
//! The proxy captures timing information for metrics but does not
|
//! The proxy captures timing information for metrics but does not
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ tracing-subscriber.workspace = true
|
|||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
async-trait.workspace = true
|
async-trait.workspace = true
|
||||||
clap.workspace = true
|
clap.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
figment.workspace = true
|
figment.workspace = true
|
||||||
toml.workspace = true
|
toml.workspace = true
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
//! HTTP API handlers for the neuron daemon.
|
//! HTTP API handlers for the neuron daemon.
|
||||||
|
|
||||||
use crate::harness::HarnessRegistry;
|
use crate::harness::HarnessRegistry;
|
||||||
|
use crate::harness::candle::{CandleHarness, InferenceError};
|
||||||
use crate::health::HealthCache;
|
use crate::health::HealthCache;
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
use axum::extract::{Path, State};
|
use axum::extract::{Path, State};
|
||||||
@@ -9,6 +10,7 @@ use axum::response::{IntoResponse, Json};
|
|||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
||||||
use cortex_core::harness::ModelSpec;
|
use cortex_core::harness::ModelSpec;
|
||||||
|
use cortex_core::openai::ChatCompletionRequest;
|
||||||
use serde_json::{Value, json};
|
use serde_json::{Value, json};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
@@ -18,6 +20,10 @@ pub struct NeuronState {
|
|||||||
pub discovery: DiscoveryResponse,
|
pub discovery: DiscoveryResponse,
|
||||||
pub health_cache: Arc<HealthCache>,
|
pub health_cache: Arc<HealthCache>,
|
||||||
pub registry: RwLock<HarnessRegistry>,
|
pub registry: RwLock<HarnessRegistry>,
|
||||||
|
/// Typed handle to the candle harness for inference routes. Cached at
|
||||||
|
/// startup so `/v1/chat/completions` doesn't have to hold the registry
|
||||||
|
/// read lock or perform dyn-Trait dispatch per request.
|
||||||
|
pub candle: Option<Arc<CandleHarness>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build the neuron API router.
|
/// Build the neuron API router.
|
||||||
@@ -29,6 +35,7 @@ pub fn neuron_routes() -> Router<Arc<NeuronState>> {
|
|||||||
.route("/models/load", post(load_model))
|
.route("/models/load", post(load_model))
|
||||||
.route("/models/unload", post(unload_model))
|
.route("/models/unload", post(unload_model))
|
||||||
.route("/models/{model_id}/endpoint", get(model_endpoint))
|
.route("/models/{model_id}/endpoint", get(model_endpoint))
|
||||||
|
.route("/v1/chat/completions", post(chat_completions))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
|
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
|
||||||
@@ -102,3 +109,40 @@ async fn model_endpoint(
|
|||||||
.into_response(),
|
.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// OpenAI-compatible chat completions. Non-streaming for Stage 3; the
|
||||||
|
/// streaming path is added in Stage 4.
|
||||||
|
async fn chat_completions(
|
||||||
|
State(state): State<Arc<NeuronState>>,
|
||||||
|
Json(req): Json<ChatCompletionRequest>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
|
||||||
|
return (
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
Json(json!({"error": "candle harness not enabled on this neuron"})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
};
|
||||||
|
|
||||||
|
if req.stream.unwrap_or(false) {
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
Json(json!({"error": "streaming responses arrive in Stage 4"})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
match candle.chat_completion(req).await {
|
||||||
|
Ok(resp) => Json(resp).into_response(),
|
||||||
|
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(InferenceError::Other(e)) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({"error": e.to_string()})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,20 +1,28 @@
|
|||||||
//! Candle harness — in-process inference using huggingface/candle.
|
//! Candle harness — in-process inference using huggingface/candle.
|
||||||
//!
|
//!
|
||||||
//! This is the sole `Harness` implementation. Inference runs inside
|
//! This is the sole `Harness` implementation. Inference runs inside
|
||||||
//! the neuron process; there is no external subprocess. Stage 2 wires
|
//! the neuron process; there is no external subprocess.
|
||||||
//! up GGUF (currently Qwen3 only) model load/unload via
|
//!
|
||||||
//! `candle-transformers::models::quantized_qwen3`. Stage 3 adds the
|
//! - Stage 2 wired GGUF (Qwen3 only) load/unload via `quantized_qwen3`.
|
||||||
//! inference endpoint.
|
//! - Stage 3 (this) adds `chat_completion` — a non-streaming OpenAI
|
||||||
|
//! compatible chat completion routed to the loaded model's forward
|
||||||
|
//! pass on a per-model serialised generation loop.
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use candle_core::Device;
|
|
||||||
use candle_core::quantized::gguf_file;
|
use candle_core::quantized::gguf_file;
|
||||||
|
use candle_core::{Device, Tensor};
|
||||||
|
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||||
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
|
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
|
||||||
use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec};
|
use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec};
|
||||||
|
use cortex_core::openai::{
|
||||||
|
ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage,
|
||||||
|
MessageContent, Usage,
|
||||||
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tokio::sync::{Mutex, RwLock};
|
use tokio::sync::{Mutex, RwLock};
|
||||||
|
|
||||||
@@ -26,19 +34,20 @@ pub struct CandleHarness {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A loaded model with its tokenizer, device placement, and architecture-
|
/// A loaded model with its tokenizer, device placement, and architecture-
|
||||||
/// specific weights. The `arch` field is mutexed because future inference
|
/// specific weights. The `arch` is `Arc<Mutex<>>` so the lock guard can be
|
||||||
/// calls take `&mut self` on the underlying ModelWeights (KV cache state).
|
/// moved into `spawn_blocking` for synchronous candle forward passes.
|
||||||
pub struct LoadedModel {
|
pub struct LoadedModel {
|
||||||
pub model_id: String,
|
pub model_id: String,
|
||||||
pub arch: Mutex<ModelArch>,
|
pub arch: Arc<Mutex<ModelArch>>,
|
||||||
pub tokenizer: Tokenizer,
|
pub tokenizer: Tokenizer,
|
||||||
pub device: Device,
|
pub device: Device,
|
||||||
pub quant: Option<String>,
|
pub quant: Option<String>,
|
||||||
pub devices: Vec<u32>,
|
pub devices: Vec<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Architecture-specific weights. Stage 2 supports only Qwen3 quantized;
|
/// Architecture-specific weights. Stage 3 still supports only Qwen3
|
||||||
/// Stage 8 broadens this to additional families and non-quantized variants.
|
/// quantized; Stage 8 broadens this to additional families and
|
||||||
|
/// non-quantized variants.
|
||||||
pub enum ModelArch {
|
pub enum ModelArch {
|
||||||
Qwen3Quantized(QuantizedQwen3Weights),
|
Qwen3Quantized(QuantizedQwen3Weights),
|
||||||
}
|
}
|
||||||
@@ -117,6 +126,92 @@ impl CandleHarness {
|
|||||||
.context("fetch tokenizer.json")?;
|
.context("fetch tokenizer.json")?;
|
||||||
Ok((gguf_path, tokenizer_path))
|
Ok((gguf_path, tokenizer_path))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Run a non-streaming chat completion against a loaded model.
|
||||||
|
///
|
||||||
|
/// Returns a typed `InferenceError` when the model isn't loaded so the
|
||||||
|
/// handler can map to an appropriate HTTP status without string-matching.
|
||||||
|
pub async fn chat_completion(
|
||||||
|
&self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> Result<ChatCompletionResponse, InferenceError> {
|
||||||
|
let loaded = {
|
||||||
|
let models = self.models.read().await;
|
||||||
|
models.get(&request.model).cloned()
|
||||||
|
};
|
||||||
|
let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
||||||
|
|
||||||
|
let prompt = format_qwen3_prompt(&request.messages);
|
||||||
|
|
||||||
|
let encoding = loaded
|
||||||
|
.tokenizer
|
||||||
|
.encode(prompt.as_str(), true)
|
||||||
|
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
|
||||||
|
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
|
||||||
|
let prompt_len = prompt_tokens.len();
|
||||||
|
|
||||||
|
let temperature = request.temperature.unwrap_or(0.7);
|
||||||
|
let top_p = request.top_p;
|
||||||
|
let max_new = request.max_tokens.unwrap_or(512) as usize;
|
||||||
|
let seed = unix_subsec_nanos();
|
||||||
|
|
||||||
|
let eos_id = loaded
|
||||||
|
.tokenizer
|
||||||
|
.token_to_id("<|im_end|>")
|
||||||
|
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
|
||||||
|
|
||||||
|
let arch_arc = Arc::clone(&loaded.arch);
|
||||||
|
let device = loaded.device.clone();
|
||||||
|
let model_id = request.model.clone();
|
||||||
|
|
||||||
|
let (generated_ids, finish_reason) =
|
||||||
|
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
|
||||||
|
let mut guard = arch_arc.blocking_lock();
|
||||||
|
run_inference(
|
||||||
|
&mut guard,
|
||||||
|
&device,
|
||||||
|
&prompt_tokens,
|
||||||
|
max_new,
|
||||||
|
temperature,
|
||||||
|
top_p,
|
||||||
|
seed,
|
||||||
|
eos_id,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| InferenceError::Other(anyhow::anyhow!("inference task panicked: {e}")))?
|
||||||
|
.map_err(InferenceError::Other)?;
|
||||||
|
|
||||||
|
let completion_text = loaded
|
||||||
|
.tokenizer
|
||||||
|
.decode(&generated_ids, true)
|
||||||
|
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
|
||||||
|
|
||||||
|
let usage = Usage {
|
||||||
|
prompt_tokens: prompt_len as u64,
|
||||||
|
completion_tokens: generated_ids.len() as u64,
|
||||||
|
total_tokens: (prompt_len + generated_ids.len()) as u64,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(ChatCompletionResponse {
|
||||||
|
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
|
||||||
|
object: "chat.completion".into(),
|
||||||
|
created: unix_now_secs(),
|
||||||
|
model: model_id,
|
||||||
|
choices: vec![ChatCompletionChoice {
|
||||||
|
index: 0,
|
||||||
|
message: ChatMessage {
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: MessageContent::Text(completion_text),
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
},
|
||||||
|
finish_reason: Some(finish_reason),
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
}],
|
||||||
|
usage: Some(usage),
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@@ -193,7 +288,7 @@ impl Harness for CandleHarness {
|
|||||||
Ok(ModelArch::Qwen3Quantized(weights))
|
Ok(ModelArch::Qwen3Quantized(weights))
|
||||||
}
|
}
|
||||||
other => anyhow::bail!(
|
other => anyhow::bail!(
|
||||||
"unsupported GGUF architecture '{other}'; Stage 2 only supports qwen3"
|
"unsupported GGUF architecture '{other}'; Stage 3 only supports qwen3"
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -202,7 +297,7 @@ impl Harness for CandleHarness {
|
|||||||
|
|
||||||
let loaded = Arc::new(LoadedModel {
|
let loaded = Arc::new(LoadedModel {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: spec.model_id.clone(),
|
||||||
arch: Mutex::new(arch),
|
arch: Arc::new(Mutex::new(arch)),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
device,
|
device,
|
||||||
quant: spec.quant.clone(),
|
quant: spec.quant.clone(),
|
||||||
@@ -229,3 +324,118 @@ impl Harness for CandleHarness {
|
|||||||
models.contains_key(model_id).then(|| self.bind_url.clone())
|
models.contains_key(model_id).then(|| self.bind_url.clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Errors returned by `CandleHarness::chat_completion`. The
|
||||||
|
/// `ModelNotLoaded` variant lets the HTTP handler map cleanly to 404
|
||||||
|
/// without string-matching on anyhow messages.
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum InferenceError {
|
||||||
|
#[error("model '{0}' not loaded on this neuron")]
|
||||||
|
ModelNotLoaded(String),
|
||||||
|
#[error(transparent)]
|
||||||
|
Other(#[from] anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply the Qwen3 chat template:
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// <|im_start|>{role}\n{content}<|im_end|>\n
|
||||||
|
/// ...
|
||||||
|
/// <|im_start|>assistant\n
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// The trailing `<|im_start|>assistant\n` cues the model to begin a turn.
|
||||||
|
/// Non-text content parts (vision blocks) are joined as text only; full
|
||||||
|
/// multimodal handling is out of scope for Stage 3.
|
||||||
|
fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
|
||||||
|
let mut prompt = String::new();
|
||||||
|
for msg in messages {
|
||||||
|
let content = match &msg.content {
|
||||||
|
MessageContent::Text(s) => s.clone(),
|
||||||
|
MessageContent::Parts(parts) => parts
|
||||||
|
.iter()
|
||||||
|
.filter_map(|p| p.get("text").and_then(|v| v.as_str()))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(""),
|
||||||
|
};
|
||||||
|
prompt.push_str("<|im_start|>");
|
||||||
|
prompt.push_str(&msg.role);
|
||||||
|
prompt.push('\n');
|
||||||
|
prompt.push_str(&content);
|
||||||
|
prompt.push_str("<|im_end|>\n");
|
||||||
|
}
|
||||||
|
prompt.push_str("<|im_start|>assistant\n");
|
||||||
|
prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn run_inference(
|
||||||
|
arch: &mut ModelArch,
|
||||||
|
device: &Device,
|
||||||
|
prompt_tokens: &[u32],
|
||||||
|
max_new: usize,
|
||||||
|
temperature: f64,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
seed: u64,
|
||||||
|
eos_id: Option<u32>,
|
||||||
|
) -> Result<(Vec<u32>, String)> {
|
||||||
|
let mut logits_processor = {
|
||||||
|
let sampling = if temperature <= 0.0 {
|
||||||
|
Sampling::ArgMax
|
||||||
|
} else {
|
||||||
|
match top_p {
|
||||||
|
Some(p) => Sampling::TopP { p, temperature },
|
||||||
|
None => Sampling::All { temperature },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
LogitsProcessor::from_sampling(seed, sampling)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut generated: Vec<u32> = Vec::new();
|
||||||
|
|
||||||
|
let mut next_token = match arch {
|
||||||
|
ModelArch::Qwen3Quantized(model) => {
|
||||||
|
model.clear_kv_cache();
|
||||||
|
let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, 0)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
logits_processor.sample(&logits)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if Some(next_token) == eos_id {
|
||||||
|
return Ok((generated, "stop".into()));
|
||||||
|
}
|
||||||
|
generated.push(next_token);
|
||||||
|
|
||||||
|
for index in 0..max_new.saturating_sub(1) {
|
||||||
|
next_token = match arch {
|
||||||
|
ModelArch::Qwen3Quantized(model) => {
|
||||||
|
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
logits_processor.sample(&logits)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if Some(next_token) == eos_id {
|
||||||
|
return Ok((generated, "stop".into()));
|
||||||
|
}
|
||||||
|
generated.push(next_token);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((generated, "length".into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unix_now_secs() -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.map(|d| d.as_secs())
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unix_subsec_nanos() -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.map(|d| d.as_nanos() as u64)
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,10 +5,18 @@ pub mod candle;
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};
|
use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Registry of available harness implementations.
|
/// Registry of available harness implementations.
|
||||||
|
///
|
||||||
|
/// Holds an `Arc<dyn Harness>` per harness for generic lifecycle dispatch
|
||||||
|
/// (load/unload/list_models). When a candle harness is registered, a typed
|
||||||
|
/// `Arc<CandleHarness>` is also cached so inference routes can bypass the
|
||||||
|
/// dyn-Trait dispatch and reach harness-specific methods (chat completion,
|
||||||
|
/// streaming, etc.).
|
||||||
pub struct HarnessRegistry {
|
pub struct HarnessRegistry {
|
||||||
harnesses: HashMap<String, Box<dyn Harness>>,
|
harnesses: HashMap<String, Arc<dyn Harness>>,
|
||||||
|
candle: Option<Arc<candle::CandleHarness>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for HarnessRegistry {
|
impl Default for HarnessRegistry {
|
||||||
@@ -21,10 +29,11 @@ impl HarnessRegistry {
|
|||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
harnesses: HashMap::new(),
|
harnesses: HashMap::new(),
|
||||||
|
candle: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn register(&mut self, harness: Box<dyn Harness>) {
|
pub fn register(&mut self, harness: Arc<dyn Harness>) {
|
||||||
self.harnesses.insert(harness.name().to_string(), harness);
|
self.harnesses.insert(harness.name().to_string(), harness);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,6 +42,12 @@ impl HarnessRegistry {
|
|||||||
self.harnesses.keys().cloned().collect()
|
self.harnesses.keys().cloned().collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Typed handle to the candle harness, if registered. Used by inference
|
||||||
|
/// routes that need methods beyond the `Harness` trait surface.
|
||||||
|
pub fn candle(&self) -> Option<Arc<candle::CandleHarness>> {
|
||||||
|
self.candle.clone()
|
||||||
|
}
|
||||||
|
|
||||||
/// List models from all registered harnesses.
|
/// List models from all registered harnesses.
|
||||||
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
|
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
|
||||||
let mut all = Vec::new();
|
let mut all = Vec::new();
|
||||||
@@ -93,10 +108,12 @@ impl HarnessRegistry {
|
|||||||
for config in configs {
|
for config in configs {
|
||||||
match config.name.as_str() {
|
match config.name.as_str() {
|
||||||
"candle" => {
|
"candle" => {
|
||||||
registry.register(Box::new(candle::CandleHarness::new(
|
let harness = Arc::new(candle::CandleHarness::new(
|
||||||
bind_url.to_string(),
|
bind_url.to_string(),
|
||||||
settings.candle.hf_cache.clone(),
|
settings.candle.hf_cache.clone(),
|
||||||
)));
|
));
|
||||||
|
registry.candle = Some(Arc::clone(&harness));
|
||||||
|
registry.harnesses.insert("candle".into(), harness);
|
||||||
}
|
}
|
||||||
other => {
|
other => {
|
||||||
tracing::warn!(harness = other, "unknown harness type, skipping");
|
tracing::warn!(harness = other, "unknown harness type, skipping");
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ async fn main() -> Result<()> {
|
|||||||
// inference_endpoint.
|
// inference_endpoint.
|
||||||
let registry = HarnessRegistry::from_configs(&cfg.harnesses, &bind_url, &cfg.harness);
|
let registry = HarnessRegistry::from_configs(&cfg.harnesses, &bind_url, &cfg.harness);
|
||||||
discovery_result.harnesses = registry.names();
|
discovery_result.harnesses = registry.names();
|
||||||
|
let candle = registry.candle();
|
||||||
|
|
||||||
let health_cache = Arc::new(health::HealthCache::new());
|
let health_cache = Arc::new(health::HealthCache::new());
|
||||||
health_cache
|
health_cache
|
||||||
@@ -68,6 +69,7 @@ async fn main() -> Result<()> {
|
|||||||
discovery: discovery_result,
|
discovery: discovery_result,
|
||||||
health_cache,
|
health_cache,
|
||||||
registry: RwLock::new(registry),
|
registry: RwLock::new(registry),
|
||||||
|
candle,
|
||||||
});
|
});
|
||||||
|
|
||||||
let app = api::neuron_routes().with_state(state);
|
let app = api::neuron_routes().with_state(state);
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ async fn spawn_neuron(discovery: DiscoveryResponse) -> String {
|
|||||||
discovery,
|
discovery,
|
||||||
health_cache,
|
health_cache,
|
||||||
registry: RwLock::new(registry),
|
registry: RwLock::new(registry),
|
||||||
|
candle: None,
|
||||||
});
|
});
|
||||||
|
|
||||||
let app = api::neuron_routes().with_state(state);
|
let app = api::neuron_routes().with_state(state);
|
||||||
@@ -152,11 +153,13 @@ async fn test_candle_harness_registers_and_rejects_bogus_model() {
|
|||||||
&HarnessSettings::default(),
|
&HarnessSettings::default(),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let candle = registry.candle();
|
||||||
let health_cache = Arc::new(HealthCache::new());
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
let state = Arc::new(NeuronState {
|
let state = Arc::new(NeuronState {
|
||||||
discovery: fake_discovery(),
|
discovery: fake_discovery(),
|
||||||
health_cache,
|
health_cache,
|
||||||
registry: RwLock::new(registry),
|
registry: RwLock::new(registry),
|
||||||
|
candle,
|
||||||
});
|
});
|
||||||
|
|
||||||
let app = api::neuron_routes().with_state(state);
|
let app = api::neuron_routes().with_state(state);
|
||||||
@@ -197,3 +200,118 @@ async fn test_candle_harness_registers_and_rejects_bogus_model() {
|
|||||||
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
|
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||||
assert!(models.is_empty());
|
assert!(models.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// `/v1/chat/completions` returns 503 when no candle harness is registered.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_chat_completions_no_candle_harness() {
|
||||||
|
let registry = HarnessRegistry::new();
|
||||||
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
|
let state = Arc::new(NeuronState {
|
||||||
|
discovery: fake_discovery(),
|
||||||
|
health_cache,
|
||||||
|
registry: RwLock::new(registry),
|
||||||
|
candle: None,
|
||||||
|
});
|
||||||
|
let app = api::neuron_routes().with_state(state);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
let url = format!("http://{addr}");
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{url}/v1/chat/completions"))
|
||||||
|
.json(&json!({
|
||||||
|
"model": "anything",
|
||||||
|
"messages": [{"role": "user", "content": "hi"}]
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 503);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `/v1/chat/completions` returns 404 when the requested model isn't loaded.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_chat_completions_model_not_loaded() {
|
||||||
|
use cortex_core::harness::HarnessConfig;
|
||||||
|
use neuron::config::HarnessSettings;
|
||||||
|
|
||||||
|
let registry = HarnessRegistry::from_configs(
|
||||||
|
&[HarnessConfig {
|
||||||
|
name: "candle".into(),
|
||||||
|
}],
|
||||||
|
"http://localhost:0",
|
||||||
|
&HarnessSettings::default(),
|
||||||
|
);
|
||||||
|
let candle = registry.candle();
|
||||||
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
|
let state = Arc::new(NeuronState {
|
||||||
|
discovery: fake_discovery(),
|
||||||
|
health_cache,
|
||||||
|
registry: RwLock::new(registry),
|
||||||
|
candle,
|
||||||
|
});
|
||||||
|
let app = api::neuron_routes().with_state(state);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
let url = format!("http://{addr}");
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{url}/v1/chat/completions"))
|
||||||
|
.json(&json!({
|
||||||
|
"model": "definitely/not-loaded",
|
||||||
|
"messages": [{"role": "user", "content": "hi"}]
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 404);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `/v1/chat/completions` with `stream: true` returns 501 until Stage 4
|
||||||
|
/// wires up SSE.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_chat_completions_streaming_not_yet_implemented() {
|
||||||
|
use cortex_core::harness::HarnessConfig;
|
||||||
|
use neuron::config::HarnessSettings;
|
||||||
|
|
||||||
|
let registry = HarnessRegistry::from_configs(
|
||||||
|
&[HarnessConfig {
|
||||||
|
name: "candle".into(),
|
||||||
|
}],
|
||||||
|
"http://localhost:0",
|
||||||
|
&HarnessSettings::default(),
|
||||||
|
);
|
||||||
|
let candle = registry.candle();
|
||||||
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
|
let state = Arc::new(NeuronState {
|
||||||
|
discovery: fake_discovery(),
|
||||||
|
health_cache,
|
||||||
|
registry: RwLock::new(registry),
|
||||||
|
candle,
|
||||||
|
});
|
||||||
|
let app = api::neuron_routes().with_state(state);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
let url = format!("http://{addr}");
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{url}/v1/chat/completions"))
|
||||||
|
.json(&json!({
|
||||||
|
"model": "anything",
|
||||||
|
"messages": [{"role": "user", "content": "hi"}],
|
||||||
|
"stream": true
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 501);
|
||||||
|
}
|
||||||
|
|||||||
@@ -60,10 +60,7 @@ async fn test_candle_qwen3_load_unload_lifecycle() {
|
|||||||
.await
|
.await
|
||||||
.expect("load_model should succeed");
|
.expect("load_model should succeed");
|
||||||
|
|
||||||
let models = registry
|
let models = registry.list_all_models().await.expect("list_all_models");
|
||||||
.list_all_models()
|
|
||||||
.await
|
|
||||||
.expect("list_all_models");
|
|
||||||
assert_eq!(models.len(), 1, "expected exactly one loaded model");
|
assert_eq!(models.len(), 1, "expected exactly one loaded model");
|
||||||
assert_eq!(models[0].id, model_id);
|
assert_eq!(models[0].id, model_id);
|
||||||
assert_eq!(models[0].harness, "candle");
|
assert_eq!(models[0].harness, "candle");
|
||||||
|
|||||||
Reference in New Issue
Block a user