Compare commits
3 Commits
feat/neuro
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
60f5598542
|
|||
|
7945240646
|
|||
|
0c74d89d15
|
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -905,7 +905,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cudarc"
|
||||
version = "0.19.7"
|
||||
source = "git+https://github.com/grenade/cudarc?rev=dbc425aa865c178f38a3ec838f1f7a4da3146358#dbc425aa865c178f38a3ec838f1f7a4da3146358"
|
||||
source = "git+https://github.com/grenade/cudarc?rev=63327a256059f8252641ae46c6bb9eefe707f382#63327a256059f8252641ae46c6bb9eefe707f382"
|
||||
dependencies = [
|
||||
"float8",
|
||||
"half",
|
||||
|
||||
@@ -69,4 +69,4 @@ cortex-gateway = { path = "crates/cortex-gateway" }
|
||||
# rebuild the comm). Pinned to a fork revision pending upstream review
|
||||
# (grenade/cudarc @ nccl-comm-abort).
|
||||
[patch.crates-io]
|
||||
cudarc = { git = "https://github.com/grenade/cudarc", rev = "dbc425aa865c178f38a3ec838f1f7a4da3146358" }
|
||||
cudarc = { git = "https://github.com/grenade/cudarc", rev = "63327a256059f8252641ae46c6bb9eefe707f382" }
|
||||
|
||||
@@ -119,40 +119,25 @@ mod cuda_impl {
|
||||
}
|
||||
}
|
||||
|
||||
/// `Arc<Comm>` doesn't impl `Send` because `Comm` wraps a raw
|
||||
/// `ncclComm_t` pointer. The NCCL contract is "operations against a
|
||||
/// given comm must be serialised", not "the handle must stay on the
|
||||
/// thread that created it" — so it's safe to move an `Arc<Comm>`
|
||||
/// across threads as long as no concurrent ops are issued. The
|
||||
/// pool's outer Mutex serialises us into `spawn_blocking`, so this
|
||||
/// wrapper at the move boundary is the only thing missing.
|
||||
/// Thin newtype over `Arc<Comm>`, kept for call-site clarity — it marks
|
||||
/// the points where a comm handle is intentionally moved across threads
|
||||
/// (e.g. cached async-side for the TP step watchdog's `ncclCommAbort`).
|
||||
///
|
||||
/// `Sync` is also marked safe because the `Arc<Comm>` clones held
|
||||
/// by the row-parallel layers are only used from the
|
||||
/// `spawn_blocking` thread driving the forward pass; concurrent
|
||||
/// access from another thread would still be a bug.
|
||||
/// `Send`/`Sync` are provided upstream by `cudarc`'s `Comm` (which
|
||||
/// asserts the NCCL thread-safety invariant, including aborting from a
|
||||
/// different thread than one inside a collective), so this type derives
|
||||
/// them automatically — no manual `unsafe impl` here.
|
||||
pub struct SendComm(pub Arc<Comm>);
|
||||
|
||||
// SAFETY: see the doc-comment above; the invariant is enforced at
|
||||
// the call site (pool Mutex + single spawn_blocking thread), not at
|
||||
// the type level.
|
||||
unsafe impl Send for SendComm {}
|
||||
unsafe impl Sync for SendComm {}
|
||||
|
||||
impl SendComm {
|
||||
pub fn into_inner(self) -> Arc<Comm> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer
|
||||
// (libnccl-allocated state). NCCL requires that operations against
|
||||
// one Comm be issued one at a time; we serialise access by storing
|
||||
// NcclState behind a Mutex in `WorkerPool`. The Comm itself is
|
||||
// move-safe — NCCL doesn't track the calling OS thread, only the
|
||||
// stream the operations are dispatched against.
|
||||
unsafe impl Send for NcclState {}
|
||||
unsafe impl Sync for NcclState {}
|
||||
// `NcclState`'s `Send`/`Sync` are auto-derived: its `Arc<Comm>` and
|
||||
// `Arc<CudaContext>` fields are now `Send`/`Sync` (cudarc asserts the
|
||||
// comm thread-safety invariant), so no manual `unsafe impl` is needed.
|
||||
|
||||
/// Generate a fresh NCCL `Id` and return it hex-encoded. Used by
|
||||
/// the leader to mint the shared communicator id which is then
|
||||
|
||||
Reference in New Issue
Block a user