feat(tp): Stage 7c-i — streaming SSE through TP
Some checks failed
build-prerelease / Package cortex RPM (push) Blocked by required conditions
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 37s
CI / Clippy (push) Successful in 2m12s
CI / Test (push) Successful in 5m3s
build-prerelease / Build neuron-blackwell (push) Successful in 3m39s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 5m7s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
Some checks failed
build-prerelease / Package cortex RPM (push) Blocked by required conditions
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 37s
CI / Clippy (push) Successful in 2m12s
CI / Test (push) Successful in 5m3s
build-prerelease / Build neuron-blackwell (push) Successful in 3m39s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 5m7s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
`chat_completion_stream` no longer returns an error for TP loads. The new `chat_completion_tp_stream` mirrors the non-streaming TP path (clear_kv_cache, prefill, sample, decode loop) but emits one `ChatCompletionChunk` per generated token over an mpsc channel so the handler can write a streaming SSE response. Unlike the single-GPU streaming path (which runs candle's forward inside `spawn_blocking` and uses `blocking_send`), the TP loop is itself async — every `pool.generate_step` already awaits the leader's own spawn_blocking forward plus every worker's recv_only. So the orchestration runs as a plain `tokio::spawn` task using `Sender::send`. The shared `emit_chunk` helper tracks the cumulative decoded prefix and emits the delta — same UTF-8-safe BPE boundary handling as the single-GPU streaming path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -526,15 +526,8 @@ impl CandleHarness {
|
|||||||
let loaded = match handle {
|
let loaded = match handle {
|
||||||
LoadedHandle::Single(m) => m,
|
LoadedHandle::Single(m) => m,
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
LoadedHandle::Tp(_) => {
|
LoadedHandle::Tp(m) => {
|
||||||
// Streaming through TP is Stage 7c work — the
|
return self.chat_completion_tp_stream(m, request).await;
|
||||||
// non-streaming path drives the same forwards through
|
|
||||||
// the pool but doesn't have to interleave SSE writes
|
|
||||||
// with spawn_blocking forwards.
|
|
||||||
return Err(InferenceError::Other(anyhow::anyhow!(
|
|
||||||
"streaming chat completions through TP are not yet supported; \
|
|
||||||
retry with stream=false"
|
|
||||||
)));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -961,6 +954,258 @@ impl CandleHarness {
|
|||||||
extra: serde_json::Value::Object(Default::default()),
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Streaming counterpart to `chat_completion_tp`. Same per-step
|
||||||
|
/// orchestration (clear cache, prefill, sample, decode loop) but
|
||||||
|
/// emits one `ChatCompletionChunk` per token over an mpsc channel
|
||||||
|
/// so the handler can write an SSE stream.
|
||||||
|
///
|
||||||
|
/// Unlike the single-GPU streaming path (which runs the candle
|
||||||
|
/// forward inside `spawn_blocking` and uses `blocking_send`), the
|
||||||
|
/// TP loop is itself async — every `pool.generate_step` awaits the
|
||||||
|
/// leader's spawn_blocking forward plus every worker's recv_only.
|
||||||
|
/// So we `tokio::spawn` the orchestration task and use plain
|
||||||
|
/// `Sender::send`.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
async fn chat_completion_tp_stream(
|
||||||
|
&self,
|
||||||
|
tp: Arc<TpLoadedModel>,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
|
||||||
|
let prompt = format_qwen3_prompt(&request.messages);
|
||||||
|
let encoding = tp
|
||||||
|
.tokenizer
|
||||||
|
.encode(prompt.as_str(), true)
|
||||||
|
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
|
||||||
|
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
|
||||||
|
let prompt_len = prompt_tokens.len();
|
||||||
|
|
||||||
|
let temperature = request.temperature.unwrap_or(0.7);
|
||||||
|
let top_p = request.top_p;
|
||||||
|
let max_new = request.max_tokens.unwrap_or(512) as usize;
|
||||||
|
let seed = unix_subsec_nanos();
|
||||||
|
|
||||||
|
let eos_id = tp
|
||||||
|
.tokenizer
|
||||||
|
.token_to_id("<|im_end|>")
|
||||||
|
.or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
|
||||||
|
|
||||||
|
let model_id = request.model.clone();
|
||||||
|
let id = format!("chatcmpl-{:x}", unix_subsec_nanos());
|
||||||
|
let created = unix_now_secs();
|
||||||
|
let tokenizer = tp.tokenizer.clone();
|
||||||
|
|
||||||
|
// Bounded channel — back-pressures the producer when the SSE
|
||||||
|
// writer is slow.
|
||||||
|
let (tx, rx) = mpsc::channel::<ChatCompletionChunk>(32);
|
||||||
|
|
||||||
|
// Role chunk first, before kicking off the heavy work — if the
|
||||||
|
// receiver is gone by now there's no point starting inference.
|
||||||
|
let role_chunk = ChatCompletionChunk {
|
||||||
|
id: id.clone(),
|
||||||
|
object: "chat.completion.chunk".into(),
|
||||||
|
created,
|
||||||
|
model: model_id.clone(),
|
||||||
|
choices: vec![ChunkChoice {
|
||||||
|
index: 0,
|
||||||
|
delta: json!({"role": "assistant"}),
|
||||||
|
finish_reason: None,
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
};
|
||||||
|
tx.send(role_chunk)
|
||||||
|
.await
|
||||||
|
.map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?;
|
||||||
|
|
||||||
|
// The orchestration task. Holds the pool lock for the lifetime
|
||||||
|
// of this inference; concurrent requests against the same TP
|
||||||
|
// model serialise behind it.
|
||||||
|
let tp_for_task = Arc::clone(&tp);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut pool = tp_for_task.pool.lock().await;
|
||||||
|
let leader_arc = tp_for_task.leader_model.clone();
|
||||||
|
|
||||||
|
if let Err(e) = pool.clear_kv_cache(&model_id, leader_arc.clone()).await {
|
||||||
|
tracing::warn!(model = %model_id, error = %e, "TP stream: clear_kv_cache failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut logits_processor = {
|
||||||
|
let sampling = if temperature <= 0.0 {
|
||||||
|
Sampling::ArgMax
|
||||||
|
} else {
|
||||||
|
match top_p {
|
||||||
|
Some(p) => Sampling::TopP { p, temperature },
|
||||||
|
None => Sampling::All { temperature },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
LogitsProcessor::from_sampling(seed, sampling)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut all_tokens: Vec<u32> = Vec::new();
|
||||||
|
let mut decoded_prefix = String::new();
|
||||||
|
let mut finish_reason = "length".to_string();
|
||||||
|
|
||||||
|
// Prefill — every rank embeds the prompt, offset = 0.
|
||||||
|
let logits = match pool
|
||||||
|
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(l) => l,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(model = %model_id, error = %e, "TP stream: prefill failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let mut next_token = match sample_with_penalty(
|
||||||
|
&logits,
|
||||||
|
&all_tokens,
|
||||||
|
&mut logits_processor,
|
||||||
|
) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(model = %model_id, error = %e, "TP stream: prefill sample failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if Some(next_token) == eos_id {
|
||||||
|
finish_reason = "stop".into();
|
||||||
|
} else {
|
||||||
|
all_tokens.push(next_token);
|
||||||
|
if !emit_chunk(
|
||||||
|
&all_tokens,
|
||||||
|
&mut decoded_prefix,
|
||||||
|
&tokenizer,
|
||||||
|
&tx,
|
||||||
|
&id,
|
||||||
|
created,
|
||||||
|
&model_id,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for index in 0..max_new.saturating_sub(1) {
|
||||||
|
let logits = match pool
|
||||||
|
.generate_step(
|
||||||
|
&model_id,
|
||||||
|
leader_arc.clone(),
|
||||||
|
vec![next_token],
|
||||||
|
prompt_len + index,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(l) => l,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
model = %model_id,
|
||||||
|
error = %e,
|
||||||
|
"TP stream: decode step failed"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
next_token =
|
||||||
|
match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
model = %model_id,
|
||||||
|
error = %e,
|
||||||
|
"TP stream: decode sample failed"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if Some(next_token) == eos_id {
|
||||||
|
finish_reason = "stop".into();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
all_tokens.push(next_token);
|
||||||
|
if !emit_chunk(
|
||||||
|
&all_tokens,
|
||||||
|
&mut decoded_prefix,
|
||||||
|
&tokenizer,
|
||||||
|
&tx,
|
||||||
|
&id,
|
||||||
|
created,
|
||||||
|
&model_id,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final chunk carrying finish_reason.
|
||||||
|
let final_chunk = ChatCompletionChunk {
|
||||||
|
id: id.clone(),
|
||||||
|
object: "chat.completion.chunk".into(),
|
||||||
|
created,
|
||||||
|
model: model_id.clone(),
|
||||||
|
choices: vec![ChunkChoice {
|
||||||
|
index: 0,
|
||||||
|
delta: serde_json::Value::Object(Default::default()),
|
||||||
|
finish_reason: Some(finish_reason),
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
};
|
||||||
|
let _ = tx.send(final_chunk).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(rx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode the cumulative token list, emit the delta (substring appended
|
||||||
|
/// since the last chunk) as a `chat.completion.chunk`. Returns `false`
|
||||||
|
/// if the receiver has hung up — the caller should bail.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
async fn emit_chunk(
|
||||||
|
all_tokens: &[u32],
|
||||||
|
decoded_prefix: &mut String,
|
||||||
|
tokenizer: &Tokenizer,
|
||||||
|
tx: &mpsc::Sender<ChatCompletionChunk>,
|
||||||
|
id: &str,
|
||||||
|
created: u64,
|
||||||
|
model_id: &str,
|
||||||
|
) -> bool {
|
||||||
|
let full = match tokenizer.decode(all_tokens, true) {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "TP stream: decode failed");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if full.len() > decoded_prefix.len() {
|
||||||
|
let delta = full[decoded_prefix.len()..].to_string();
|
||||||
|
*decoded_prefix = full;
|
||||||
|
let chunk = ChatCompletionChunk {
|
||||||
|
id: id.into(),
|
||||||
|
object: "chat.completion.chunk".into(),
|
||||||
|
created,
|
||||||
|
model: model_id.into(),
|
||||||
|
choices: vec![ChunkChoice {
|
||||||
|
index: 0,
|
||||||
|
delta: json!({ "content": delta }),
|
||||||
|
finish_reason: None,
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
};
|
||||||
|
if tx.send(chunk).await.is_err() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Errors returned by `CandleHarness::chat_completion`. The
|
/// Errors returned by `CandleHarness::chat_completion`. The
|
||||||
|
|||||||
Reference in New Issue
Block a user