feat(neuron): OpenAI-compatible SSE streaming chat completions

Stage 4 of the candle-native pivot. /v1/chat/completions now switches
to text/event-stream when the request sets stream: true, emitting one
chat.completion.chunk per generated token followed by the OpenAI
[DONE] terminator.

Pipeline:
- chat_completion_stream creates a bounded mpsc::channel<ChatCompletionChunk>(32),
  sends the leading role chunk, then spawns a blocking task that
  acquires the per-model arch lock and runs the streaming generation
  loop.
- run_inference_streaming tracks a cumulative decoded prefix so each
  chunk's delta.content is the substring added since the last chunk —
  safe across BPE byte-fallback boundaries that would otherwise split
  multi-byte UTF-8 chars.
- The blocking task aborts cleanly if blocking_send fails (client
  disconnected), so generation stops when the SSE consumer hangs up.
- Final chunk carries finish_reason ("stop" on EOS, "length" on
  max_tokens). The handler appends data: [DONE] after the channel
  closes.

The Stage 3 streaming 501 placeholder test is repurposed: with the
streaming path live, an unloaded model now hits the same 404 surface
as the non-streaming path (the model lookup happens first).

cortex-gateway's existing proxy is unchanged — it already forwards
SSE bytes verbatim from Phase 2 work, so the candle SSE format passes
through unmodified.

Neuron Cargo.toml gains futures + tokio-stream (both already in
workspace deps) for ReceiverStream and stream combinators.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-18 17:53:14 +03:00
parent 249c9442e8
commit 84f5662df1
5 changed files with 282 additions and 29 deletions

View File

@@ -16,15 +16,16 @@ use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec};
use cortex_core::openai::{
ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage,
MessageContent, Usage,
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse,
ChatMessage, ChunkChoice, MessageContent, Usage,
};
use serde_json::json;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokenizers::Tokenizer;
use tokio::sync::{Mutex, RwLock};
use tokio::sync::{Mutex, RwLock, mpsc};
/// In-process candle harness. Owns the loaded model registry.
pub struct CandleHarness {
@@ -212,6 +213,104 @@ impl CandleHarness {
extra: serde_json::Value::Object(Default::default()),
})
}
/// Run a streaming chat completion against a loaded model.
///
/// Returns an `mpsc::Receiver` that yields `ChatCompletionChunk`s in
/// OpenAI SSE format. The first chunk carries the assistant role;
/// subsequent chunks carry incremental `content` deltas; the final
/// chunk carries `finish_reason`. The handler is responsible for
/// wrapping these into an SSE response and appending the `[DONE]`
/// terminator.
///
/// Token-by-token decoding tracks the cumulative decoded prefix so
/// BPE byte-fallback boundaries don't split a UTF-8 char across
/// chunks.
pub async fn chat_completion_stream(
&self,
request: ChatCompletionRequest,
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
let loaded = {
let models = self.models.read().await;
models.get(&request.model).cloned()
};
let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
let prompt = format_qwen3_prompt(&request.messages);
let encoding = loaded
.tokenizer
.encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(512) as usize;
let seed = unix_subsec_nanos();
let eos_id = loaded
.tokenizer
.token_to_id("<|im_end|>")
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
let arch_arc = Arc::clone(&loaded.arch);
let device = loaded.device.clone();
let tokenizer = loaded.tokenizer.clone();
let model_id = request.model.clone();
let id = format!("chatcmpl-{:x}", unix_subsec_nanos());
let created = unix_now_secs();
// Bounded channel so the producer (blocking inference) is back-
// pressured by the consumer (SSE writer). 32 is generous —
// tokens arrive one at a time and the SSE writer is async.
let (tx, rx) = mpsc::channel::<ChatCompletionChunk>(32);
// Lead chunk: announce the assistant role per OpenAI streaming
// conventions. Tools that auto-detect a streaming reply expect
// this before any content delta.
let role_chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".into(),
created,
model: model_id.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: json!({"role": "assistant"}),
finish_reason: None,
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
// If sending the role chunk fails the receiver is already gone;
// bail before kicking off the heavy blocking work.
tx.send(role_chunk)
.await
.map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?;
tokio::task::spawn_blocking(move || {
let mut guard = arch_arc.blocking_lock();
if let Err(e) = run_inference_streaming(
&mut guard,
&device,
&tokenizer,
&prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
&id,
created,
&model_id,
&tx,
) {
tracing::warn!(model = %model_id, error = %e, "streaming inference failed");
}
});
Ok(rx)
}
}
#[async_trait]
@@ -426,6 +525,130 @@ fn run_inference(
Ok((generated, "length".into()))
}
/// Streaming counterpart to `run_inference`. Emits chunks via `tx` as
/// tokens are generated and exits on EOS, max_new, or receiver drop.
///
/// Detokenization tracks the cumulative decoded prefix so each chunk's
/// `content` delta is the substring appended since the last chunk —
/// safe across BPE byte-fallback boundaries.
#[allow(clippy::too_many_arguments)]
fn run_inference_streaming(
arch: &mut ModelArch,
device: &Device,
tokenizer: &Tokenizer,
prompt_tokens: &[u32],
max_new: usize,
temperature: f64,
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
id: &str,
created: u64,
model_id: &str,
tx: &mpsc::Sender<ChatCompletionChunk>,
) -> Result<()> {
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut all_tokens: Vec<u32> = Vec::new();
let mut decoded_prefix = String::new();
let mut finish_reason = "length".to_string();
let mut next_token = match arch {
ModelArch::Qwen3Quantized(model) => {
model.clear_kv_cache();
let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
}
};
let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result<bool> {
let full = tokenizer
.decode(all_tokens, true)
.map_err(|e| anyhow::anyhow!("decode: {e}"))?;
if full.len() > decoded_prefix.len() {
let delta = full[decoded_prefix.len()..].to_string();
*decoded_prefix = full;
let chunk = ChatCompletionChunk {
id: id.into(),
object: "chat.completion.chunk".into(),
created,
model: model_id.into(),
choices: vec![ChunkChoice {
index: 0,
delta: json!({ "content": delta }),
finish_reason: None,
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
// blocking_send returns Err if the consumer hung up — signal
// the caller to stop generating.
if tx.blocking_send(chunk).is_err() {
return Ok(false);
}
}
Ok(true)
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
} else {
all_tokens.push(next_token);
if !emit_token(&all_tokens, &mut decoded_prefix)? {
return Ok(());
}
for index in 0..max_new.saturating_sub(1) {
next_token = match arch {
ModelArch::Qwen3Quantized(model) => {
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
}
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
}
all_tokens.push(next_token);
if !emit_token(&all_tokens, &mut decoded_prefix)? {
return Ok(());
}
}
}
let final_chunk = ChatCompletionChunk {
id: id.into(),
object: "chat.completion.chunk".into(),
created,
model: model_id.into(),
choices: vec![ChunkChoice {
index: 0,
delta: serde_json::Value::Object(Default::default()),
finish_reason: Some(finish_reason),
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
let _ = tx.blocking_send(final_chunk);
Ok(())
}
fn unix_now_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)