Files
cortex/crates/neuron/src/api.rs
rob thijssen 729317d1ef feat(neuron): OpenAI-compatible non-streaming chat completion
Stage 3 of the candle-native pivot. neuron now serves
POST /v1/chat/completions backed by candle's quantized_qwen3 forward
pass on a per-model serialised generation loop, returning the standard
OpenAI ChatCompletionResponse envelope.

Pipeline per request:
- Look up the LoadedModel by request.model (404 if absent).
- Apply the Qwen3 chat template across all messages.
- Tokenize, then spawn_blocking onto tokio's blocking pool to acquire
  the per-model arch lock and run prefill + greedy/temperature/top-p
  sampling via LogitsProcessor.
- Stop on <|im_end|>/<|endoftext|> EOS or max_tokens (finish_reason
  "stop" vs "length").
- Decode with skip_special_tokens=true, build OpenAI response with
  prompt/completion/total usage counts.

Supporting changes:
- HarnessRegistry now stores Arc<dyn Harness> and caches a typed
  Arc<CandleHarness> so inference routes bypass dyn-Trait dispatch.
- LoadedModel.arch becomes Arc<Mutex<ModelArch>> so the lock guard
  can be moved into spawn_blocking.
- NeuronState gains an Option<Arc<CandleHarness>> field for the new
  inference route.
- Typed InferenceError lets the handler map ModelNotLoaded → 404 and
  other failures → 500 without string-matching anyhow messages.
- stream=true returns 501 until Stage 4 wires up SSE.
- Two leftover mistral.rs string references in proxy.rs and cortex-cli
  (missed during the Stage 1 sweep) are corrected here.

Three new default-feature tests cover the no-candle 503, model-not-
loaded 404, and stream=true 501 paths. The cuda-integration test from
Stage 2 still covers real load/unload; a streaming-feature gated test
exercising actual generation will arrive with Stage 4.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 16:47:58 +03:00

149 lines
4.9 KiB
Rust

//! HTTP API handlers for the neuron daemon.
use crate::harness::HarnessRegistry;
use crate::harness::candle::{CandleHarness, InferenceError};
use crate::health::HealthCache;
use axum::Router;
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json};
use axum::routing::{get, post};
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
use cortex_core::harness::ModelSpec;
use cortex_core::openai::ChatCompletionRequest;
use serde_json::{Value, json};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Shared state for the neuron HTTP server.
///
/// Every handler in this module receives it as `State<Arc<NeuronState>>`.
pub struct NeuronState {
/// Discovery payload; the `/discovery` handler responds with a clone of it.
pub discovery: DiscoveryResponse,
/// Health cache; the `/health` handler responds with its `snapshot()`.
pub health_cache: Arc<HealthCache>,
/// Harness registry behind an async `RwLock`. The model-management
/// handlers in this module acquire the read side for each request.
pub registry: RwLock<HarnessRegistry>,
/// Typed handle to the candle harness for inference routes. Cached at
/// startup so `/v1/chat/completions` doesn't have to hold the registry
/// read lock or perform dyn-Trait dispatch per request.
///
/// When this is `None`, `/v1/chat/completions` answers 503.
pub candle: Option<Arc<CandleHarness>>,
}
/// Assemble the router that exposes every neuron HTTP endpoint.
///
/// The returned router is typed over `Arc<NeuronState>`, which each handler
/// extracts through `State`. Registered routes:
///
/// - `GET  /discovery`
/// - `GET  /health`
/// - `GET  /models`
/// - `POST /models/load`
/// - `POST /models/unload`
/// - `GET  /models/{model_id}/endpoint`
/// - `POST /v1/chat/completions`
pub fn neuron_routes() -> Router<Arc<NeuronState>> {
    // Node introspection.
    let router = Router::new()
        .route("/discovery", get(discovery_handler))
        .route("/health", get(health_handler));

    // Model lifecycle management.
    let router = router
        .route("/models", get(list_models))
        .route("/models/load", post(load_model))
        .route("/models/unload", post(unload_model))
        .route("/models/{model_id}/endpoint", get(model_endpoint));

    // OpenAI-compatible inference.
    router.route("/v1/chat/completions", post(chat_completions))
}
/// `GET /discovery` — respond with a copy of the discovery payload held in
/// the shared state.
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
    let payload = state.discovery.clone();
    Json(payload)
}
/// `GET /health` — respond with the current snapshot from the health cache.
async fn health_handler(State(state): State<Arc<NeuronState>>) -> Json<HealthResponse> {
    let snapshot = state.health_cache.snapshot().await;
    Json(snapshot)
}
async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse {
let registry = state.registry.read().await;
match registry.list_all_models().await {
Ok(models) => Json(json!(models)).into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": e.to_string()})),
)
.into_response(),
}
}
async fn load_model(
State(state): State<Arc<NeuronState>>,
Json(spec): Json<ModelSpec>,
) -> impl IntoResponse {
let registry = state.registry.read().await;
match registry.load_model(&spec).await {
Ok(()) => Json(json!({"status": "loaded"})).into_response(),
Err(e) => (
StatusCode::BAD_REQUEST,
Json(json!({"error": e.to_string()})),
)
.into_response(),
}
}
async fn unload_model(
State(state): State<Arc<NeuronState>>,
Json(body): Json<Value>,
) -> impl IntoResponse {
let model_id = match body.get("model_id").and_then(|v| v.as_str()) {
Some(id) => id.to_string(),
None => {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "missing model_id"})),
)
.into_response();
}
};
let registry = state.registry.read().await;
match registry.unload_model(&model_id).await {
Ok(()) => Json(json!({"status": "unloaded"})).into_response(),
Err(e) => (StatusCode::NOT_FOUND, Json(json!({"error": e.to_string()}))).into_response(),
}
}
async fn model_endpoint(
State(state): State<Arc<NeuronState>>,
Path(model_id): Path<String>,
) -> impl IntoResponse {
let registry = state.registry.read().await;
match registry.inference_endpoint(&model_id).await {
Some(url) => Json(json!({"url": url})).into_response(),
None => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("model '{}' not loaded", model_id)})),
)
.into_response(),
}
}
/// OpenAI-compatible chat completions. Non-streaming for Stage 3; the
/// streaming path is added in Stage 4.
async fn chat_completions(
State(state): State<Arc<NeuronState>>,
Json(req): Json<ChatCompletionRequest>,
) -> impl IntoResponse {
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({"error": "candle harness not enabled on this neuron"})),
)
.into_response();
};
if req.stream.unwrap_or(false) {
return (
StatusCode::NOT_IMPLEMENTED,
Json(json!({"error": "streaming responses arrive in Stage 4"})),
)
.into_response();
}
match candle.chat_completion(req).await {
Ok(resp) => Json(resp).into_response(),
Err(InferenceError::ModelNotLoaded(id)) => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
)
.into_response(),
Err(InferenceError::Other(e)) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": e.to_string()})),
)
.into_response(),
}
}