From 99920dd3228de2c9de507e7a4ae0cc86c98275f1 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Mon, 8 Jun 2026 14:15:29 +0300 Subject: [PATCH] feat(neuron): TP step watchdog aborts wedged collectives (#17 Stage 2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Make a hung NCCL collective recoverable instead of a permanent brick. Today a wedged collective hangs the in-process leader thread forever, and even Stage 1's recovery can't help — its unload's DropTp queues behind the stuck thread and hangs too. - Cache the leader's NCCL Comm handle async-side at init (new cuda-gated Job::GetLeaderComm → DeviceWorkerHandle::get_leader_comm → stored on WorkerPool.leader_comm). Fetched while the thread is responsive — a wedged thread can't service the fetch, which is why it's cached up front. - Wrap the leader forward in both generate_step and generate_step_with_images in tokio::time::timeout (default 120s, NEURON_TP_STEP_TIMEOUT_S). On expiry the watchdog calls Comm::abort() (ncclCommAbort) on the cached handle from the async thread — the one NCCL op sanctioned concurrently with an in-flight collective — which unblocks the leader thread, then fails the step WITHOUT draining (workers are wedged too; recovery's unload kills them). The error is a device fault → poison → Stage 1 auto-recovery, which now completes because the leader thread is responsive again. - Bumps the cudarc patch to dbc425a (adds the Drop-must-not-panic fix so the post-abort comm teardown during recovery doesn't double-abort-panic). Logs the whole sequence at ERROR with greppable `tp watchdog:` / `ncclCommAbort` markers so a real-world hang leaves a forensic trail — verification is by inspecting journals after real hangs, not a synthetic harness. cuda-gated → validated by the blackwell build. Co-Authored-By: Claude Opus 4.8 (1M context) --- Cargo.lock | 2 +- Cargo.toml | 2 +- .../src/harness/device_worker/dispatch.rs | 14 ++ .../neuron/src/harness/device_worker/jobs.rs | 11 ++ .../neuron/src/harness/device_worker/mod.rs | 21 +++ crates/neuron/src/harness/tp/mod.rs | 134 ++++++++++++++++-- 6 files changed, 168 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4690de1..7d07668 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -905,7 +905,7 @@ dependencies = [ [[package]] name = "cudarc" version = "0.19.7" -source = "git+https://github.com/grenade/cudarc?rev=4dff0be72d8a685d6691a6a53d4c95e1fe932277#4dff0be72d8a685d6691a6a53d4c95e1fe932277" +source = "git+https://github.com/grenade/cudarc?rev=dbc425aa865c178f38a3ec838f1f7a4da3146358#dbc425aa865c178f38a3ec838f1f7a4da3146358" dependencies = [ "float8", "half", diff --git a/Cargo.toml b/Cargo.toml index b69d9d4..df3741d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,4 +69,4 @@ cortex-gateway = { path = "crates/cortex-gateway" } # rebuild the comm). Pinned to a fork revision pending upstream review # (grenade/cudarc @ nccl-comm-abort). [patch.crates-io] -cudarc = { git = "https://github.com/grenade/cudarc", rev = "4dff0be72d8a685d6691a6a53d4c95e1fe932277" } +cudarc = { git = "https://github.com/grenade/cudarc", rev = "dbc425aa865c178f38a3ec838f1f7a4da3146358" } diff --git a/crates/neuron/src/harness/device_worker/dispatch.rs b/crates/neuron/src/harness/device_worker/dispatch.rs index 6df69ef..397636e 100644 --- a/crates/neuron/src/harness/device_worker/dispatch.rs +++ b/crates/neuron/src/harness/device_worker/dispatch.rs @@ -201,6 +201,16 @@ pub(crate) fn run(device_index: u32, rx: Receiver, poisoned: Arc { + // Clone the leader's Arc out for the async-side + // watchdog. `None` before NcclInit. (#17 Stage 2) + let comm = state + .nccl + .comm() + .map(crate::harness::tp::nccl_state::SendComm); + let _ = reply.send(comm); + } + #[cfg(feature = "cuda")] Job::TpLoadShard { model_id, config_json, @@ -1004,6 +1014,10 @@ fn drain_poisoned(job: Job, device_index: u32) { message: format!("device worker {device_index} poisoned"), }); } + #[cfg(feature = "cuda")] + Job::GetLeaderComm { reply } => { + let _ = reply.send(None); + } Job::NcclSanity { reply } => { let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error { kind: "device_worker_poisoned".into(), diff --git a/crates/neuron/src/harness/device_worker/jobs.rs b/crates/neuron/src/harness/device_worker/jobs.rs index d53826e..273bbb1 100644 --- a/crates/neuron/src/harness/device_worker/jobs.rs +++ b/crates/neuron/src/harness/device_worker/jobs.rs @@ -192,6 +192,17 @@ pub enum Job { NcclSanity { reply: oneshot::Sender, }, + /// Hand a clonable handle to the leader's NCCL `Comm` back to the + /// async side, so the TP step watchdog can call `ncclCommAbort` on + /// it from a *different* thread to unblock a wedged collective + /// (#17 Stage 2). Fetched once at init while the worker thread is + /// still responsive — a thread already wedged in a collective can't + /// service this job, which is exactly why the handle is cached + /// up front. Replies `None` before `NcclInit` has run. + #[cfg(feature = "cuda")] + GetLeaderComm { + reply: oneshot::Sender>, + }, /// Load the leader's TP shard on the worker thread. The dispatch /// handler reads `state.nccl.comm()` directly (no cross-thread /// `Arc` transfer, no `SendComm` wrapper) and builds the diff --git a/crates/neuron/src/harness/device_worker/mod.rs b/crates/neuron/src/harness/device_worker/mod.rs index 7305787..2c617db 100644 --- a/crates/neuron/src/harness/device_worker/mod.rs +++ b/crates/neuron/src/harness/device_worker/mod.rs @@ -161,6 +161,27 @@ impl DeviceWorkerHandle { } } + /// Fetch a clonable handle to the leader's NCCL `Comm` (#17 Stage 2). + /// The TP step watchdog caches this at init so it can call + /// `ncclCommAbort` from the async thread to unblock a wedged + /// collective. Returns `None` if uninitialised, poisoned, or gone — + /// the caller treats a missing handle as "can't abort" and logs it. + #[cfg(feature = "cuda")] + pub async fn get_leader_comm(&self) -> Option { + if self.poisoned.load(Ordering::Acquire) { + return None; + } + let (reply_tx, reply_rx) = oneshot::channel(); + if self + .tx + .send(Job::GetLeaderComm { reply: reply_tx }) + .is_err() + { + return None; + } + reply_rx.await.ok().flatten() + } + /// Load a GGUF (pre-quantized) single-GPU model on the worker /// thread. The hf-hub resolution happens on the async caller; the /// resolved local `gguf_path` plus the spec's model_id are sent diff --git a/crates/neuron/src/harness/tp/mod.rs b/crates/neuron/src/harness/tp/mod.rs index be83208..63e1383 100644 --- a/crates/neuron/src/harness/tp/mod.rs +++ b/crates/neuron/src/harness/tp/mod.rs @@ -245,9 +245,67 @@ pub struct WorkerPool { /// Phase 4 the load itself moves onto the worker and that bridge /// goes away. pub(crate) leader_worker: std::sync::Arc, + /// Cached handle to the leader's NCCL `Comm`, fetched at `init_nccl` + /// while the worker thread is responsive. The TP step watchdog uses + /// it to `ncclCommAbort` a wedged collective from the async thread — + /// the one NCCL op allowed concurrently with an in-flight collective, + /// and the only way to unblock the in-process leader thread so + /// recovery's `unload` doesn't itself hang (#17 Stage 2). `None` if + /// init couldn't cache it; the watchdog then logs that it can't abort. + #[cfg(feature = "cuda")] + leader_comm: Option, +} + +/// Per-step deadline for a TP forward (#17 Stage 2). A healthy decode +/// step or chunked prefill completes in well under a second; a wedged +/// NCCL collective never returns. Generous default so no legitimate step +/// trips it; overridable via `NEURON_TP_STEP_TIMEOUT_S` (seconds). +#[cfg(feature = "cuda")] +fn tp_step_timeout() -> std::time::Duration { + let secs = std::env::var("NEURON_TP_STEP_TIMEOUT_S") + .ok() + .and_then(|v| v.trim().parse::().ok()) + .filter(|&s| s > 0) + .unwrap_or(120); + std::time::Duration::from_secs(secs) } impl WorkerPool { + /// Abort the leader's NCCL comm to unblock a collective the watchdog + /// found wedged (#17 Stage 2). Logs the whole sequence loudly so a + /// real-world hang leaves a greppable forensic trail + /// (`tp watchdog:` / `ncclCommAbort`). Calling abort from this async + /// thread while the worker thread is blocked inside the collective is + /// the one concurrent NCCL op the library sanctions — it is how a + /// stuck/failed collective is unblocked. + #[cfg(feature = "cuda")] + fn watchdog_abort_leader_comm(&self, model_id: &str, secs: u64) { + tracing::error!( + model = %model_id, + timeout_s = secs, + "tp watchdog: leader forward exceeded deadline — NCCL collective wedged; \ + aborting comm to unblock the leader thread for auto-recovery" + ); + match &self.leader_comm { + Some(c) => match c.0.abort() { + Ok(()) => tracing::error!( + model = %model_id, + "tp watchdog: ncclCommAbort succeeded — wedged collective unblocked; \ + failing the step so the model auto-recovers (unload+reload)" + ), + Err(e) => tracing::error!( + model = %model_id, error = ?e, + "tp watchdog: ncclCommAbort failed — recovery may stall until a process restart" + ), + }, + None => tracing::error!( + model = %model_id, + "tp watchdog: no cached leader comm handle — cannot abort; recovery will rely \ + on a process restart" + ), + } + } + /// Spawn `world_size - 1` worker subprocesses. Rank 0 is the /// leader (in-process) and is *not* spawned here — the leader /// holds rank 0's NCCL Comm and shard in its own address space. @@ -324,6 +382,8 @@ impl WorkerPool { workers, exe, leader_worker, + #[cfg(feature = "cuda")] + leader_comm: None, }) } @@ -404,6 +464,23 @@ impl WorkerPool { world_size = self.world_size, "NCCL communicator established across all ranks" ); + + // Cache the leader's Comm handle now, while the worker thread is + // responsive, so the TP step watchdog can abort a wedged + // collective later (it can't fetch it then — the thread is stuck). + // (#17 Stage 2.) + #[cfg(feature = "cuda")] + { + self.leader_comm = self.leader_worker.get_leader_comm().await; + if self.leader_comm.is_some() { + tracing::debug!("cached leader NCCL comm handle for the TP step watchdog"); + } else { + tracing::warn!( + "could not cache leader NCCL comm handle; the TP step watchdog will be \ + unable to abort a wedged collective (a hang would need a process restart)" + ); + } + } Ok(()) } @@ -628,10 +705,27 @@ impl WorkerPool { // that's the invariant the whole refactor exists to // preserve. let leader_start = std::time::Instant::now(); - let leader_result = self + let timeout = tp_step_timeout(); + let leader_fut = self .leader_worker - .tp_forward_logits(leader_handle, tokens, offset) - .await; + .tp_forward_logits(leader_handle, tokens, offset); + let leader_result = match tokio::time::timeout(timeout, leader_fut).await { + Ok(r) => r, + Err(_elapsed) => { + // Watchdog (#17 Stage 2): the NCCL collective is wedged. + // Abort the leader comm to unblock its thread, then fail + // the step WITHOUT draining (the subprocess workers are + // wedged too; recovery's unload kills them). The error + // poisons the model → auto-recovery, which no longer hangs + // because the leader thread is now responsive. + self.watchdog_abort_leader_comm(model_id, timeout.as_secs()); + anyhow::bail!( + "tp watchdog: leader forward exceeded {}s deadline; aborted wedged NCCL \ + comm — model will auto-recover", + timeout.as_secs() + ); + } + }; let leader_ok = leader_result.is_ok(); let leader_ms = leader_start.elapsed().as_millis(); // Surface the leader's own error at WARN before draining @@ -767,17 +861,29 @@ impl WorkerPool { // matching collective; CPU-side logits keep the device tensor // from escaping the worker thread. let leader_start = std::time::Instant::now(); - let leader_result = self - .leader_worker - .tp_forward_logits_with_images( - leader_handle, - tokens, - offset, - image_token_id, - image_data_uris, - chunk_size, - ) - .await; + let timeout = tp_step_timeout(); + let leader_fut = self.leader_worker.tp_forward_logits_with_images( + leader_handle, + tokens, + offset, + image_token_id, + image_data_uris, + chunk_size, + ); + let leader_result = match tokio::time::timeout(timeout, leader_fut).await { + Ok(r) => r, + Err(_elapsed) => { + // Watchdog (#17 Stage 2) — see generate_step. Vision + // prefill is still well under the deadline on healthy + // hardware; a timeout means a wedged collective. + self.watchdog_abort_leader_comm(model_id, timeout.as_secs()); + anyhow::bail!( + "tp watchdog: leader image forward exceeded {}s deadline; aborted wedged \ + NCCL comm — model will auto-recover", + timeout.as_secs() + ); + } + }; let leader_ok = leader_result.is_ok(); let leader_ms = leader_start.elapsed().as_millis(); if !leader_ok {