From f72dee094f404dfb98de99d0082f256f515fa305 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Wed, 20 May 2026 07:32:46 +0300 Subject: [PATCH] =?UTF-8?q?feat(tp):=20Stage=207c-i=20=E2=80=94=20streamin?= =?UTF-8?q?g=20SSE=20through=20TP?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `chat_completion_stream` no longer returns an error for TP loads. The new `chat_completion_tp_stream` mirrors the non-streaming TP path (clear_kv_cache, prefill, sample, decode loop) but emits one `ChatCompletionChunk` per generated token over an mpsc channel so the handler can write a streaming SSE response. Unlike the single-GPU streaming path (which runs candle's forward inside `spawn_blocking` and uses `blocking_send`), the TP loop is itself async — every `pool.generate_step` already awaits the leader's own spawn_blocking forward plus every worker's recv_only. So the orchestration runs as a plain `tokio::spawn` task using `Sender::send`. The shared `emit_chunk` helper tracks the cumulative decoded prefix and emits the delta — same UTF-8-safe BPE boundary handling as the single-GPU streaming path. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/neuron/src/harness/candle.rs | 263 +++++++++++++++++++++++++++- 1 file changed, 254 insertions(+), 9 deletions(-) diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index cc1ffd1..f87d74b 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -526,15 +526,8 @@ impl CandleHarness { let loaded = match handle { LoadedHandle::Single(m) => m, #[cfg(feature = "cuda")] - LoadedHandle::Tp(_) => { - // Streaming through TP is Stage 7c work — the - // non-streaming path drives the same forwards through - // the pool but doesn't have to interleave SSE writes - // with spawn_blocking forwards. - return Err(InferenceError::Other(anyhow::anyhow!( - "streaming chat completions through TP are not yet supported; \ - retry with stream=false" - ))); + LoadedHandle::Tp(m) => { + return self.chat_completion_tp_stream(m, request).await; } }; @@ -961,6 +954,258 @@ impl CandleHarness { extra: serde_json::Value::Object(Default::default()), }) } + + /// Streaming counterpart to `chat_completion_tp`. Same per-step + /// orchestration (clear cache, prefill, sample, decode loop) but + /// emits one `ChatCompletionChunk` per token over an mpsc channel + /// so the handler can write an SSE stream. + /// + /// Unlike the single-GPU streaming path (which runs the candle + /// forward inside `spawn_blocking` and uses `blocking_send`), the + /// TP loop is itself async — every `pool.generate_step` awaits the + /// leader's spawn_blocking forward plus every worker's recv_only. + /// So we `tokio::spawn` the orchestration task and use plain + /// `Sender::send`. + #[cfg(feature = "cuda")] + async fn chat_completion_tp_stream( + &self, + tp: Arc, + request: ChatCompletionRequest, + ) -> Result, InferenceError> { + let prompt = format_qwen3_prompt(&request.messages); + let encoding = tp + .tokenizer + .encode(prompt.as_str(), true) + .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?; + let prompt_tokens: Vec = encoding.get_ids().to_vec(); + let prompt_len = prompt_tokens.len(); + + let temperature = request.temperature.unwrap_or(0.7); + let top_p = request.top_p; + let max_new = request.max_tokens.unwrap_or(512) as usize; + let seed = unix_subsec_nanos(); + + let eos_id = tp + .tokenizer + .token_to_id("<|im_end|>") + .or_else(|| tp.tokenizer.token_to_id("<|endoftext|>")); + + let model_id = request.model.clone(); + let id = format!("chatcmpl-{:x}", unix_subsec_nanos()); + let created = unix_now_secs(); + let tokenizer = tp.tokenizer.clone(); + + // Bounded channel — back-pressures the producer when the SSE + // writer is slow. + let (tx, rx) = mpsc::channel::(32); + + // Role chunk first, before kicking off the heavy work — if the + // receiver is gone by now there's no point starting inference. + let role_chunk = ChatCompletionChunk { + id: id.clone(), + object: "chat.completion.chunk".into(), + created, + model: model_id.clone(), + choices: vec![ChunkChoice { + index: 0, + delta: json!({"role": "assistant"}), + finish_reason: None, + extra: serde_json::Value::Object(Default::default()), + }], + usage: None, + extra: serde_json::Value::Object(Default::default()), + }; + tx.send(role_chunk) + .await + .map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?; + + // The orchestration task. Holds the pool lock for the lifetime + // of this inference; concurrent requests against the same TP + // model serialise behind it. + let tp_for_task = Arc::clone(&tp); + tokio::spawn(async move { + let mut pool = tp_for_task.pool.lock().await; + let leader_arc = tp_for_task.leader_model.clone(); + + if let Err(e) = pool.clear_kv_cache(&model_id, leader_arc.clone()).await { + tracing::warn!(model = %model_id, error = %e, "TP stream: clear_kv_cache failed"); + return; + } + + let mut logits_processor = { + let sampling = if temperature <= 0.0 { + Sampling::ArgMax + } else { + match top_p { + Some(p) => Sampling::TopP { p, temperature }, + None => Sampling::All { temperature }, + } + }; + LogitsProcessor::from_sampling(seed, sampling) + }; + + let mut all_tokens: Vec = Vec::new(); + let mut decoded_prefix = String::new(); + let mut finish_reason = "length".to_string(); + + // Prefill — every rank embeds the prompt, offset = 0. + let logits = match pool + .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0) + .await + { + Ok(l) => l, + Err(e) => { + tracing::warn!(model = %model_id, error = %e, "TP stream: prefill failed"); + return; + } + }; + let mut next_token = match sample_with_penalty( + &logits, + &all_tokens, + &mut logits_processor, + ) { + Ok(t) => t, + Err(e) => { + tracing::warn!(model = %model_id, error = %e, "TP stream: prefill sample failed"); + return; + } + }; + + if Some(next_token) == eos_id { + finish_reason = "stop".into(); + } else { + all_tokens.push(next_token); + if !emit_chunk( + &all_tokens, + &mut decoded_prefix, + &tokenizer, + &tx, + &id, + created, + &model_id, + ) + .await + { + return; + } + + for index in 0..max_new.saturating_sub(1) { + let logits = match pool + .generate_step( + &model_id, + leader_arc.clone(), + vec![next_token], + prompt_len + index, + ) + .await + { + Ok(l) => l, + Err(e) => { + tracing::warn!( + model = %model_id, + error = %e, + "TP stream: decode step failed" + ); + return; + } + }; + next_token = + match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) { + Ok(t) => t, + Err(e) => { + tracing::warn!( + model = %model_id, + error = %e, + "TP stream: decode sample failed" + ); + return; + } + }; + if Some(next_token) == eos_id { + finish_reason = "stop".into(); + break; + } + all_tokens.push(next_token); + if !emit_chunk( + &all_tokens, + &mut decoded_prefix, + &tokenizer, + &tx, + &id, + created, + &model_id, + ) + .await + { + return; + } + } + } + + // Final chunk carrying finish_reason. + let final_chunk = ChatCompletionChunk { + id: id.clone(), + object: "chat.completion.chunk".into(), + created, + model: model_id.clone(), + choices: vec![ChunkChoice { + index: 0, + delta: serde_json::Value::Object(Default::default()), + finish_reason: Some(finish_reason), + extra: serde_json::Value::Object(Default::default()), + }], + usage: None, + extra: serde_json::Value::Object(Default::default()), + }; + let _ = tx.send(final_chunk).await; + }); + + Ok(rx) + } +} + +/// Decode the cumulative token list, emit the delta (substring appended +/// since the last chunk) as a `chat.completion.chunk`. Returns `false` +/// if the receiver has hung up — the caller should bail. +#[cfg(feature = "cuda")] +async fn emit_chunk( + all_tokens: &[u32], + decoded_prefix: &mut String, + tokenizer: &Tokenizer, + tx: &mpsc::Sender, + id: &str, + created: u64, + model_id: &str, +) -> bool { + let full = match tokenizer.decode(all_tokens, true) { + Ok(s) => s, + Err(e) => { + tracing::warn!(error = %e, "TP stream: decode failed"); + return false; + } + }; + if full.len() > decoded_prefix.len() { + let delta = full[decoded_prefix.len()..].to_string(); + *decoded_prefix = full; + let chunk = ChatCompletionChunk { + id: id.into(), + object: "chat.completion.chunk".into(), + created, + model: model_id.into(), + choices: vec![ChunkChoice { + index: 0, + delta: json!({ "content": delta }), + finish_reason: None, + extra: serde_json::Value::Object(Default::default()), + }], + usage: None, + extra: serde_json::Value::Object(Default::default()), + }; + if tx.send(chunk).await.is_err() { + return false; + } + } + true } /// Errors returned by `CandleHarness::chat_completion`. The