diff --git a/Cargo.lock b/Cargo.lock index 2c34803..0635e1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2114,6 +2114,7 @@ dependencies = [ "clap", "cortex-core", "figment", + "futures", "hf-hub", "reqwest", "serde", @@ -2121,6 +2122,7 @@ dependencies = [ "thiserror 2.0.18", "tokenizers", "tokio", + "tokio-stream", "toml", "tracing", "tracing-subscriber", diff --git a/crates/neuron/Cargo.toml b/crates/neuron/Cargo.toml index 862139c..94ef7f0 100644 --- a/crates/neuron/Cargo.toml +++ b/crates/neuron/Cargo.toml @@ -49,6 +49,8 @@ anyhow.workspace = true async-trait.workspace = true clap.workspace = true thiserror.workspace = true +futures.workspace = true +tokio-stream.workspace = true figment.workspace = true toml.workspace = true diff --git a/crates/neuron/src/api.rs b/crates/neuron/src/api.rs index 3ef267d..30399d2 100644 --- a/crates/neuron/src/api.rs +++ b/crates/neuron/src/api.rs @@ -6,14 +6,18 @@ use crate::health::HealthCache; use axum::Router; use axum::extract::{Path, State}; use axum::http::StatusCode; +use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::{IntoResponse, Json}; use axum::routing::{get, post}; use cortex_core::discovery::{DiscoveryResponse, HealthResponse}; use cortex_core::harness::ModelSpec; use cortex_core::openai::ChatCompletionRequest; +use futures::stream::{self, StreamExt}; use serde_json::{Value, json}; +use std::convert::Infallible; use std::sync::Arc; use tokio::sync::RwLock; +use tokio_stream::wrappers::ReceiverStream; /// Shared state for the neuron HTTP server. pub struct NeuronState { @@ -110,8 +114,9 @@ async fn model_endpoint( } } -/// OpenAI-compatible chat completions. Non-streaming for Stage 3; the -/// streaming path is added in Stage 4. +/// OpenAI-compatible chat completions. Dispatches to streaming SSE when +/// `stream: true` is set on the request; otherwise returns a single +/// `ChatCompletionResponse`. async fn chat_completions( State(state): State>, Json(req): Json, @@ -125,24 +130,44 @@ async fn chat_completions( }; if req.stream.unwrap_or(false) { - return ( - StatusCode::NOT_IMPLEMENTED, - Json(json!({"error": "streaming responses arrive in Stage 4"})), - ) - .into_response(); - } - - match candle.chat_completion(req).await { - Ok(resp) => Json(resp).into_response(), - Err(InferenceError::ModelNotLoaded(id)) => ( - StatusCode::NOT_FOUND, - Json(json!({"error": format!("model '{id}' not loaded on this neuron")})), - ) - .into_response(), - Err(InferenceError::Other(e)) => ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": e.to_string()})), - ) - .into_response(), + match candle.chat_completion_stream(req).await { + Ok(rx) => { + // Each chunk → one SSE `data: {json}` line. After the + // channel closes, append the OpenAI [DONE] terminator. + let body_stream = ReceiverStream::new(rx).map(|chunk| { + let body = serde_json::to_string(&chunk).unwrap_or_default(); + Ok::<_, Infallible>(Event::default().data(body)) + }); + let done_stream = + stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) }); + Sse::new(body_stream.chain(done_stream)) + .keep_alive(KeepAlive::default()) + .into_response() + } + Err(InferenceError::ModelNotLoaded(id)) => ( + StatusCode::NOT_FOUND, + Json(json!({"error": format!("model '{id}' not loaded on this neuron")})), + ) + .into_response(), + Err(InferenceError::Other(e)) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": e.to_string()})), + ) + .into_response(), + } + } else { + match candle.chat_completion(req).await { + Ok(resp) => Json(resp).into_response(), + Err(InferenceError::ModelNotLoaded(id)) => ( + StatusCode::NOT_FOUND, + Json(json!({"error": format!("model '{id}' not loaded on this neuron")})), + ) + .into_response(), + Err(InferenceError::Other(e)) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": e.to_string()})), + ) + .into_response(), + } } } diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index b3ee3cb..6482f4e 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -16,15 +16,16 @@ use candle_transformers::generation::{LogitsProcessor, Sampling}; use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights; use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec}; use cortex_core::openai::{ - ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, - MessageContent, Usage, + ChatCompletionChoice, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, + ChatMessage, ChunkChoice, MessageContent, Usage, }; +use serde_json::json; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; use tokenizers::Tokenizer; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::{Mutex, RwLock, mpsc}; /// In-process candle harness. Owns the loaded model registry. pub struct CandleHarness { @@ -212,6 +213,104 @@ impl CandleHarness { extra: serde_json::Value::Object(Default::default()), }) } + + /// Run a streaming chat completion against a loaded model. + /// + /// Returns an `mpsc::Receiver` that yields `ChatCompletionChunk`s in + /// OpenAI SSE format. The first chunk carries the assistant role; + /// subsequent chunks carry incremental `content` deltas; the final + /// chunk carries `finish_reason`. The handler is responsible for + /// wrapping these into an SSE response and appending the `[DONE]` + /// terminator. + /// + /// Token-by-token decoding tracks the cumulative decoded prefix so + /// BPE byte-fallback boundaries don't split a UTF-8 char across + /// chunks. + pub async fn chat_completion_stream( + &self, + request: ChatCompletionRequest, + ) -> Result, InferenceError> { + let loaded = { + let models = self.models.read().await; + models.get(&request.model).cloned() + }; + let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?; + + let prompt = format_qwen3_prompt(&request.messages); + let encoding = loaded + .tokenizer + .encode(prompt.as_str(), true) + .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?; + let prompt_tokens: Vec = encoding.get_ids().to_vec(); + + let temperature = request.temperature.unwrap_or(0.7); + let top_p = request.top_p; + let max_new = request.max_tokens.unwrap_or(512) as usize; + let seed = unix_subsec_nanos(); + + let eos_id = loaded + .tokenizer + .token_to_id("<|im_end|>") + .or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>")); + + let arch_arc = Arc::clone(&loaded.arch); + let device = loaded.device.clone(); + let tokenizer = loaded.tokenizer.clone(); + let model_id = request.model.clone(); + let id = format!("chatcmpl-{:x}", unix_subsec_nanos()); + let created = unix_now_secs(); + + // Bounded channel so the producer (blocking inference) is back- + // pressured by the consumer (SSE writer). 32 is generous — + // tokens arrive one at a time and the SSE writer is async. + let (tx, rx) = mpsc::channel::(32); + + // Lead chunk: announce the assistant role per OpenAI streaming + // conventions. Tools that auto-detect a streaming reply expect + // this before any content delta. + let role_chunk = ChatCompletionChunk { + id: id.clone(), + object: "chat.completion.chunk".into(), + created, + model: model_id.clone(), + choices: vec![ChunkChoice { + index: 0, + delta: json!({"role": "assistant"}), + finish_reason: None, + extra: serde_json::Value::Object(Default::default()), + }], + usage: None, + extra: serde_json::Value::Object(Default::default()), + }; + // If sending the role chunk fails the receiver is already gone; + // bail before kicking off the heavy blocking work. + tx.send(role_chunk) + .await + .map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?; + + tokio::task::spawn_blocking(move || { + let mut guard = arch_arc.blocking_lock(); + if let Err(e) = run_inference_streaming( + &mut guard, + &device, + &tokenizer, + &prompt_tokens, + max_new, + temperature, + top_p, + seed, + eos_id, + &id, + created, + &model_id, + &tx, + ) { + tracing::warn!(model = %model_id, error = %e, "streaming inference failed"); + } + }); + + Ok(rx) + } } #[async_trait] @@ -426,6 +525,130 @@ fn run_inference( Ok((generated, "length".into())) } +/// Streaming counterpart to `run_inference`. Emits chunks via `tx` as +/// tokens are generated and exits on EOS, max_new, or receiver drop. +/// +/// Detokenization tracks the cumulative decoded prefix so each chunk's +/// `content` delta is the substring appended since the last chunk — +/// safe across BPE byte-fallback boundaries. +#[allow(clippy::too_many_arguments)] +fn run_inference_streaming( + arch: &mut ModelArch, + device: &Device, + tokenizer: &Tokenizer, + prompt_tokens: &[u32], + max_new: usize, + temperature: f64, + top_p: Option, + seed: u64, + eos_id: Option, + id: &str, + created: u64, + model_id: &str, + tx: &mpsc::Sender, +) -> Result<()> { + let mut logits_processor = { + let sampling = if temperature <= 0.0 { + Sampling::ArgMax + } else { + match top_p { + Some(p) => Sampling::TopP { p, temperature }, + None => Sampling::All { temperature }, + } + }; + LogitsProcessor::from_sampling(seed, sampling) + }; + + let mut all_tokens: Vec = Vec::new(); + let mut decoded_prefix = String::new(); + let mut finish_reason = "length".to_string(); + + let mut next_token = match arch { + ModelArch::Qwen3Quantized(model) => { + model.clear_kv_cache(); + let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?; + let logits = model.forward(&input, 0)?; + let logits = logits.squeeze(0)?; + logits_processor.sample(&logits)? + } + }; + + let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result { + let full = tokenizer + .decode(all_tokens, true) + .map_err(|e| anyhow::anyhow!("decode: {e}"))?; + if full.len() > decoded_prefix.len() { + let delta = full[decoded_prefix.len()..].to_string(); + *decoded_prefix = full; + let chunk = ChatCompletionChunk { + id: id.into(), + object: "chat.completion.chunk".into(), + created, + model: model_id.into(), + choices: vec![ChunkChoice { + index: 0, + delta: json!({ "content": delta }), + finish_reason: None, + extra: serde_json::Value::Object(Default::default()), + }], + usage: None, + extra: serde_json::Value::Object(Default::default()), + }; + // blocking_send returns Err if the consumer hung up — signal + // the caller to stop generating. + if tx.blocking_send(chunk).is_err() { + return Ok(false); + } + } + Ok(true) + }; + + if Some(next_token) == eos_id { + finish_reason = "stop".into(); + } else { + all_tokens.push(next_token); + if !emit_token(&all_tokens, &mut decoded_prefix)? { + return Ok(()); + } + + for index in 0..max_new.saturating_sub(1) { + next_token = match arch { + ModelArch::Qwen3Quantized(model) => { + let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?; + let logits = model.forward(&input, prompt_tokens.len() + index)?; + let logits = logits.squeeze(0)?; + logits_processor.sample(&logits)? + } + }; + if Some(next_token) == eos_id { + finish_reason = "stop".into(); + break; + } + all_tokens.push(next_token); + if !emit_token(&all_tokens, &mut decoded_prefix)? { + return Ok(()); + } + } + } + + let final_chunk = ChatCompletionChunk { + id: id.into(), + object: "chat.completion.chunk".into(), + created, + model: model_id.into(), + choices: vec![ChunkChoice { + index: 0, + delta: serde_json::Value::Object(Default::default()), + finish_reason: Some(finish_reason), + extra: serde_json::Value::Object(Default::default()), + }], + usage: None, + extra: serde_json::Value::Object(Default::default()), + }; + let _ = tx.blocking_send(final_chunk); + Ok(()) +} + fn unix_now_secs() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) diff --git a/crates/neuron/tests/api.rs b/crates/neuron/tests/api.rs index 9e8f121..61a9a03 100644 --- a/crates/neuron/tests/api.rs +++ b/crates/neuron/tests/api.rs @@ -273,10 +273,11 @@ async fn test_chat_completions_model_not_loaded() { assert_eq!(resp.status(), 404); } -/// `/v1/chat/completions` with `stream: true` returns 501 until Stage 4 -/// wires up SSE. +/// `/v1/chat/completions` with `stream: true` returns 404 when the +/// model isn't loaded — same surface as the non-streaming path. The +/// streaming code only kicks in once the model lookup succeeds. #[tokio::test] -async fn test_chat_completions_streaming_not_yet_implemented() { +async fn test_chat_completions_streaming_model_not_loaded() { use cortex_core::harness::HarnessConfig; use neuron::config::HarnessSettings; @@ -306,12 +307,12 @@ async fn test_chat_completions_streaming_not_yet_implemented() { let resp = reqwest::Client::new() .post(format!("{url}/v1/chat/completions")) .json(&json!({ - "model": "anything", + "model": "definitely/not-loaded", "messages": [{"role": "user", "content": "hi"}], "stream": true })) .send() .await .unwrap(); - assert_eq!(resp.status(), 501); + assert_eq!(resp.status(), 404); }