diff --git a/crates/neuron/src/harness/tp/mod.rs b/crates/neuron/src/harness/tp/mod.rs index c05522f..dae6fb9 100644 --- a/crates/neuron/src/harness/tp/mod.rs +++ b/crates/neuron/src/harness/tp/mod.rs @@ -129,6 +129,79 @@ impl Worker { } } +/// Drain one response from every worker, classifying each via the +/// supplied checker. Always reads from every worker — even if some +/// fail — so the next call's recv doesn't pick up stale responses +/// from this one (pipe-poisoning was the cause of the +/// "ClearKvCache: expected KvCacheCleared, got GenerateStepOk" class +/// of bugs). +/// +/// Returns a vector of `rank N: detail` strings for any worker that +/// errored, expected-mismatched, or failed to respond. Caller decides +/// how to combine these with the leader's outcome. +async fn drain_workers( + workers: &mut [Worker], + mut check: impl FnMut(WorkerResponse) -> std::result::Result<(), String>, +) -> Vec { + let mut errs = Vec::new(); + for w in workers { + match w.recv_only().await { + Ok(resp) => { + if let Err(detail) = check(resp) { + errs.push(format!("rank {} {detail}", w.rank)); + } + } + Err(e) => errs.push(format!("rank {} recv: {e:#}", w.rank)), + } + } + errs +} + +/// Combine a leader's `Result>` (the typical +/// `spawn_blocking → JoinHandle>` shape) with the worker +/// drain results into a single `Result`. Leader failures take +/// precedence in the error message but worker errors get appended so +/// the operator sees both halves. +#[cfg(feature = "cuda")] +fn combine_leader_workers( + leader: Result>, + worker_errors: Vec, + op: &str, +) -> Result { + match leader { + Ok(Ok(value)) => { + if worker_errors.is_empty() { + Ok(value) + } else { + anyhow::bail!( + "{op}: leader succeeded but workers failed: {}", + worker_errors.join("; ") + ) + } + } + Ok(Err(e)) => { + if worker_errors.is_empty() { + Err(e.context(format!("{op}: leader forward failed"))) + } else { + Err(e.context(format!( + "{op}: leader forward failed and workers also failed: {}", + worker_errors.join("; ") + ))) + } + } + Err(panic_err) => { + if worker_errors.is_empty() { + Err(panic_err) + } else { + Err(panic_err.context(format!( + "{op}: leader task panicked and workers failed: {}", + worker_errors.join("; ") + ))) + } + } + } +} + /// A live pool of worker subprocesses. Owns the `Child` handles so /// dropping the pool kills the children; explicit `shutdown()` is /// the graceful path. @@ -544,34 +617,32 @@ impl WorkerPool { // 2. Leader's forward in spawn_blocking. The AllReduce CustomOps // inside the row-parallel layers block until every worker's // forward issues the matching collective. - let logits = tokio::task::spawn_blocking(move || -> Result { + let leader_result = tokio::task::spawn_blocking(move || -> Result { let mut model = leader_model.blocking_lock(); let device = model.device().clone(); let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; - // TpQwen3ForCausalLM::forward returns [B, 1, V] (it slices - // to the last position internally). Squeeze both leading - // dims to get the rank-1 vocab logits LogitsProcessor wants. + // ForCausalLM::forward returns [B, 1, V] — squeeze both + // leading dims to the rank-1 vocab logits the sampler wants. let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?; Ok(logits) }) .await - .context("leader forward task panicked")??; + .context("leader forward task panicked"); - // 3. Collect worker confirmations. - for w in &mut self.workers { - let resp = w.recv_only().await?; - match resp { - WorkerResponse::GenerateStepOk => {} - WorkerResponse::Error { kind, message } => { - anyhow::bail!("worker rank {} GenerateStep [{kind}]: {message}", w.rank) - } - other => anyhow::bail!( - "worker rank {} GenerateStep: expected GenerateStepOk, got {other:?}", - w.rank - ), - } - } - Ok(logits) + // 3. ALWAYS drain worker responses, regardless of whether the + // leader succeeded. Skipping this on the leader's error + // path leaves stale GenerateStepOk replies in the worker + // pipes that poison the NEXT request's recv (was seeing + // "ClearKvCache: expected KvCacheCleared, got + // GenerateStepOk" the call after any forward-time failure). + let worker_errors = drain_workers(&mut self.workers, |r| match r { + WorkerResponse::GenerateStepOk => Ok(()), + WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")), + other => Err(format!("expected GenerateStepOk, got {other:?}")), + }) + .await; + + combine_leader_workers(leader_result, worker_errors, "GenerateStep") } /// Reset the KV cache for `model_id` on every rank. Called at the @@ -593,18 +664,18 @@ impl WorkerPool { let mut m = leader_model.lock().await; m.clear_kv_cache(); } - for w in &mut self.workers { - let resp = w.recv_only().await?; - match resp { - WorkerResponse::KvCacheCleared => {} - WorkerResponse::Error { kind, message } => { - anyhow::bail!("worker rank {} ClearKvCache [{kind}]: {message}", w.rank) - } - other => anyhow::bail!( - "worker rank {} ClearKvCache: expected KvCacheCleared, got {other:?}", - w.rank - ), - } + // Drain workers — same rationale as `generate_step`. The + // leader's clear_kv_cache is in-process and infallible, but we + // still always drain so an error on one worker doesn't leave + // pending responses for the others. + let worker_errors = drain_workers(&mut self.workers, |r| match r { + WorkerResponse::KvCacheCleared => Ok(()), + WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")), + other => Err(format!("expected KvCacheCleared, got {other:?}")), + }) + .await; + if !worker_errors.is_empty() { + anyhow::bail!("ClearKvCache: {}", worker_errors.join("; ")); } Ok(()) }