From d1a4aad91d444d20dc2ca84af709fe61cd8debb1 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Thu, 21 May 2026 07:39:36 +0300 Subject: [PATCH] fix(tp): always drain worker responses on leader failure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The TP-2 inference probe against Qwen3.6-27B surfaced: worker rank 1 ClearKvCache: expected KvCacheCleared, got GenerateStepOk Caused by pipe poisoning. The previous shape of `generate_step`: for w in workers { w.send_only(GenerateStep) } // 1. fan-out let logits = spawn_blocking(leader.forward)??; // 2. early return on err for w in workers { w.recv_only() } // 3. drain (skipped on 2's err) When step 2 returned `Err` (e.g. a dtype mismatch we hadn't seen before, an OOM, a downstream squeeze that didn't match the shape), the function bailed before step 3 — but workers had already written `GenerateStepOk` to their stdout pipes, since their forwards (and the NCCL collectives inside) completed independently of the leader's post-collective Rust-side work. The next call (typically `ClearKvCache` at the start of the *next* inference request) would then send a fresh request and read those stale replies as if they were the new operation's. Once a pipe is poisoned, every subsequent call surfaces the same shape of error even though nothing's actually broken. Fix: introduce two helpers in `tp/mod.rs`: - `drain_workers(workers, check)` — reads exactly one response from every worker regardless of individual outcomes. Returns `Vec` of `rank N: detail` strings for any non-OK reply. - `combine_leader_workers(leader, worker_errs, op)` — folds the leader's `Result>` (the spawn_blocking shape) with the worker drain into a single `Result`. Leader failure takes precedence but worker errors get appended so both halves surface. `generate_step` and `clear_kv_cache` now use this pattern. Worst case: both halves fail and the operator sees a combined error message; either way the pipes are always drained so the next call's recv matches the request it sent. Note: the model is still poisoned in the current state — the operator needs to either `POST /models/unload` + reload, or `systemctl restart neuron`, to recover. The fix prevents *future* desync; it doesn't repair existing stale pipe state. Stage 7c-ii crash detection was tracked as the canonical solution to this class of issue; this is the minimum-viable subset. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/neuron/src/harness/tp/mod.rs | 135 +++++++++++++++++++++------- 1 file changed, 103 insertions(+), 32 deletions(-) diff --git a/crates/neuron/src/harness/tp/mod.rs b/crates/neuron/src/harness/tp/mod.rs index c05522f..dae6fb9 100644 --- a/crates/neuron/src/harness/tp/mod.rs +++ b/crates/neuron/src/harness/tp/mod.rs @@ -129,6 +129,79 @@ impl Worker { } } +/// Drain one response from every worker, classifying each via the +/// supplied checker. Always reads from every worker — even if some +/// fail — so the next call's recv doesn't pick up stale responses +/// from this one (pipe-poisoning was the cause of the +/// "ClearKvCache: expected KvCacheCleared, got GenerateStepOk" class +/// of bugs). +/// +/// Returns a vector of `rank N: detail` strings for any worker that +/// errored, expected-mismatched, or failed to respond. Caller decides +/// how to combine these with the leader's outcome. +async fn drain_workers( + workers: &mut [Worker], + mut check: impl FnMut(WorkerResponse) -> std::result::Result<(), String>, +) -> Vec { + let mut errs = Vec::new(); + for w in workers { + match w.recv_only().await { + Ok(resp) => { + if let Err(detail) = check(resp) { + errs.push(format!("rank {} {detail}", w.rank)); + } + } + Err(e) => errs.push(format!("rank {} recv: {e:#}", w.rank)), + } + } + errs +} + +/// Combine a leader's `Result>` (the typical +/// `spawn_blocking → JoinHandle>` shape) with the worker +/// drain results into a single `Result`. Leader failures take +/// precedence in the error message but worker errors get appended so +/// the operator sees both halves. +#[cfg(feature = "cuda")] +fn combine_leader_workers( + leader: Result>, + worker_errors: Vec, + op: &str, +) -> Result { + match leader { + Ok(Ok(value)) => { + if worker_errors.is_empty() { + Ok(value) + } else { + anyhow::bail!( + "{op}: leader succeeded but workers failed: {}", + worker_errors.join("; ") + ) + } + } + Ok(Err(e)) => { + if worker_errors.is_empty() { + Err(e.context(format!("{op}: leader forward failed"))) + } else { + Err(e.context(format!( + "{op}: leader forward failed and workers also failed: {}", + worker_errors.join("; ") + ))) + } + } + Err(panic_err) => { + if worker_errors.is_empty() { + Err(panic_err) + } else { + Err(panic_err.context(format!( + "{op}: leader task panicked and workers failed: {}", + worker_errors.join("; ") + ))) + } + } + } +} + /// A live pool of worker subprocesses. Owns the `Child` handles so /// dropping the pool kills the children; explicit `shutdown()` is /// the graceful path. @@ -544,34 +617,32 @@ impl WorkerPool { // 2. Leader's forward in spawn_blocking. The AllReduce CustomOps // inside the row-parallel layers block until every worker's // forward issues the matching collective. - let logits = tokio::task::spawn_blocking(move || -> Result { + let leader_result = tokio::task::spawn_blocking(move || -> Result { let mut model = leader_model.blocking_lock(); let device = model.device().clone(); let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; - // TpQwen3ForCausalLM::forward returns [B, 1, V] (it slices - // to the last position internally). Squeeze both leading - // dims to get the rank-1 vocab logits LogitsProcessor wants. + // ForCausalLM::forward returns [B, 1, V] — squeeze both + // leading dims to the rank-1 vocab logits the sampler wants. let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?; Ok(logits) }) .await - .context("leader forward task panicked")??; + .context("leader forward task panicked"); - // 3. Collect worker confirmations. - for w in &mut self.workers { - let resp = w.recv_only().await?; - match resp { - WorkerResponse::GenerateStepOk => {} - WorkerResponse::Error { kind, message } => { - anyhow::bail!("worker rank {} GenerateStep [{kind}]: {message}", w.rank) - } - other => anyhow::bail!( - "worker rank {} GenerateStep: expected GenerateStepOk, got {other:?}", - w.rank - ), - } - } - Ok(logits) + // 3. ALWAYS drain worker responses, regardless of whether the + // leader succeeded. Skipping this on the leader's error + // path leaves stale GenerateStepOk replies in the worker + // pipes that poison the NEXT request's recv (was seeing + // "ClearKvCache: expected KvCacheCleared, got + // GenerateStepOk" the call after any forward-time failure). + let worker_errors = drain_workers(&mut self.workers, |r| match r { + WorkerResponse::GenerateStepOk => Ok(()), + WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")), + other => Err(format!("expected GenerateStepOk, got {other:?}")), + }) + .await; + + combine_leader_workers(leader_result, worker_errors, "GenerateStep") } /// Reset the KV cache for `model_id` on every rank. Called at the @@ -593,18 +664,18 @@ impl WorkerPool { let mut m = leader_model.lock().await; m.clear_kv_cache(); } - for w in &mut self.workers { - let resp = w.recv_only().await?; - match resp { - WorkerResponse::KvCacheCleared => {} - WorkerResponse::Error { kind, message } => { - anyhow::bail!("worker rank {} ClearKvCache [{kind}]: {message}", w.rank) - } - other => anyhow::bail!( - "worker rank {} ClearKvCache: expected KvCacheCleared, got {other:?}", - w.rank - ), - } + // Drain workers — same rationale as `generate_step`. The + // leader's clear_kv_cache is in-process and infallible, but we + // still always drain so an error on one worker doesn't leave + // pending responses for the others. + let worker_errors = drain_workers(&mut self.workers, |r| match r { + WorkerResponse::KvCacheCleared => Ok(()), + WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")), + other => Err(format!("expected KvCacheCleared, got {other:?}")), + }) + .await; + if !worker_errors.is_empty() { + anyhow::bail!("ClearKvCache: {}", worker_errors.join("; ")); } Ok(()) }