diff --git a/Cargo.lock b/Cargo.lock index 7d07668..9db1775 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -905,7 +905,7 @@ dependencies = [ [[package]] name = "cudarc" version = "0.19.7" -source = "git+https://github.com/grenade/cudarc?rev=dbc425aa865c178f38a3ec838f1f7a4da3146358#dbc425aa865c178f38a3ec838f1f7a4da3146358" +source = "git+https://github.com/grenade/cudarc?rev=63327a256059f8252641ae46c6bb9eefe707f382#63327a256059f8252641ae46c6bb9eefe707f382" dependencies = [ "float8", "half", diff --git a/Cargo.toml b/Cargo.toml index df3741d..f0ab917 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,4 +69,4 @@ cortex-gateway = { path = "crates/cortex-gateway" } # rebuild the comm). Pinned to a fork revision pending upstream review # (grenade/cudarc @ nccl-comm-abort). [patch.crates-io] -cudarc = { git = "https://github.com/grenade/cudarc", rev = "dbc425aa865c178f38a3ec838f1f7a4da3146358" } +cudarc = { git = "https://github.com/grenade/cudarc", rev = "63327a256059f8252641ae46c6bb9eefe707f382" } diff --git a/crates/neuron/src/harness/tp/nccl_state.rs b/crates/neuron/src/harness/tp/nccl_state.rs index 402e218..e47e04b 100644 --- a/crates/neuron/src/harness/tp/nccl_state.rs +++ b/crates/neuron/src/harness/tp/nccl_state.rs @@ -119,40 +119,25 @@ mod cuda_impl { } } - /// `Arc` doesn't impl `Send` because `Comm` wraps a raw - /// `ncclComm_t` pointer. The NCCL contract is "operations against a - /// given comm must be serialised", not "the handle must stay on the - /// thread that created it" — so it's safe to move an `Arc` - /// across threads as long as no concurrent ops are issued. The - /// pool's outer Mutex serialises us into `spawn_blocking`, so this - /// wrapper at the move boundary is the only thing missing. + /// Thin newtype over `Arc`, kept for call-site clarity — it marks + /// the points where a comm handle is intentionally moved across threads + /// (e.g. cached async-side for the TP step watchdog's `ncclCommAbort`). /// - /// `Sync` is also marked safe because the `Arc` clones held - /// by the row-parallel layers are only used from the - /// `spawn_blocking` thread driving the forward pass; concurrent - /// access from another thread would still be a bug. + /// `Send`/`Sync` are provided upstream by `cudarc`'s `Comm` (which + /// asserts the NCCL thread-safety invariant, including aborting from a + /// different thread than one inside a collective), so this type derives + /// them automatically — no manual `unsafe impl` here. pub struct SendComm(pub Arc); - // SAFETY: see the doc-comment above; the invariant is enforced at - // the call site (pool Mutex + single spawn_blocking thread), not at - // the type level. - unsafe impl Send for SendComm {} - unsafe impl Sync for SendComm {} - impl SendComm { pub fn into_inner(self) -> Arc { self.0 } } - // SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer - // (libnccl-allocated state). NCCL requires that operations against - // one Comm be issued one at a time; we serialise access by storing - // NcclState behind a Mutex in `WorkerPool`. The Comm itself is - // move-safe — NCCL doesn't track the calling OS thread, only the - // stream the operations are dispatched against. - unsafe impl Send for NcclState {} - unsafe impl Sync for NcclState {} + // `NcclState`'s `Send`/`Sync` are auto-derived: its `Arc` and + // `Arc` fields are now `Send`/`Sync` (cudarc asserts the + // comm thread-safety invariant), so no manual `unsafe impl` is needed. /// Generate a fresh NCCL `Id` and return it hex-encoded. Used by /// the leader to mint the shared communicator id which is then