Stage 7a-ii: real NCCL handshake behind the worker pool
Some checks failed
CI / Format (push) Failing after 38s
build-prerelease / Resolve version stamps (push) Successful in 42s
CI / Clippy (push) Successful in 2m18s
build-prerelease / Build neuron-blackwell (push) Failing after 3m33s
CI / Test (push) Successful in 4m27s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m31s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Build neuron-ampere (push) Failing after 4m19s
build-prerelease / Build neuron-ada (push) Failing after 4m56s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Some checks failed
CI / Format (push) Failing after 38s
build-prerelease / Resolve version stamps (push) Successful in 42s
CI / Clippy (push) Successful in 2m18s
build-prerelease / Build neuron-blackwell (push) Failing after 3m33s
CI / Test (push) Successful in 4m27s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m31s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Build neuron-ampere (push) Failing after 4m19s
build-prerelease / Build neuron-ada (push) Failing after 4m56s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Wires cudarc::nccl into the TP worker lifecycle introduced in 7a-i.
With --features cuda the leader and its workers now establish a live
NCCL communicator end-to-end; without the feature the same code paths
return Error{kind="cuda_feature_not_enabled"} so a misconfigured
build is obvious instead of silently no-op.
NCCL state machine (harness/tp/nccl_state.rs) is shared between the
worker process and the leader's pool:
- generate_comm_id_hex() mints an Id::new() on the leader.
- NcclState::init parses 256 hex chars → [c_char; 128] → Id::uninit,
opens a CudaContext on the configured device, calls Comm::from_rank
with the supplied (rank, world_size, id). NCCL blocks until every
rank has joined.
- NcclState::sanity_check runs one all_reduce(1u32, Sum); the leader
asserts every rank reports observed_sum == world_size.
- NCCL handles serialised under Mutex; unsafe impl Send/Sync gates
the Comm across spawn_blocking boundaries (NCCL is move-safe; only
concurrent op issuance is unsafe).
WorkerPool::init_nccl orchestrates the rendezvous:
1. Write Init { comm_id } to every worker's stdin (no await yet).
2. Leader rank 0 calls its own Comm::from_rank in spawn_blocking,
concurrently with workers.
3. NCCL handshake completes for all ranks simultaneously.
4. Leader collects InitOk responses.
WorkerPool::nccl_sanity_check follows the same pattern over
all_reduce, validating world_size == observed_sum on every rank.
Worker.send_only / Worker.recv_only split out from the previous
monolithic Worker.request so the leader can interleave its own NCCL
work with the worker calls — required because NCCL blocks during
init.
Tests:
- 4 hex roundtrip unit tests for the wire encoding.
- The 7a-i "not implemented" expectation now reads
"cuda_feature_not_enabled" on the local dev box (no CUDA), or
accepts InitOk on a cuda-built test binary.
- New cuda-integration test in tp_worker_lifecycle_cuda.rs covers
the real init + sanity round-trip; gated on the cuda-integration
feature so default CI doesn't try to NCCL.
Verifiable on beast (2× RTX 5090):
cargo test -p neuron --features cuda-integration \
--test tp_worker_lifecycle_cuda
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
|
||||
//! - **7c:** crash detection, streaming SSE, graceful unload.
|
||||
|
||||
pub mod nccl_state;
|
||||
pub mod rpc;
|
||||
pub mod worker;
|
||||
|
||||
@@ -42,7 +43,20 @@ struct Worker {
|
||||
}
|
||||
|
||||
impl Worker {
|
||||
/// Send a request and wait for the response. Used for sequenced
|
||||
/// ops like `Ping` / `Shutdown` where the caller doesn't need to
|
||||
/// overlap the worker's execution with the leader's.
|
||||
async fn request(&mut self, req: &WorkerRequest) -> Result<WorkerResponse> {
|
||||
self.send_only(req).await?;
|
||||
self.recv_only().await
|
||||
}
|
||||
|
||||
/// Write a request without awaiting its response. Pair with
|
||||
/// `recv_only` from the caller when leader and worker need to do
|
||||
/// work concurrently — e.g. during `Init`, where the leader
|
||||
/// itself calls `Comm::from_rank` on rank 0 in parallel with the
|
||||
/// workers, then collects `InitOk` after NCCL completes.
|
||||
async fn send_only(&mut self, req: &WorkerRequest) -> Result<()> {
|
||||
let mut line = serde_json::to_string(req).context("serialise WorkerRequest")?;
|
||||
line.push('\n');
|
||||
self.stdin
|
||||
@@ -53,7 +67,10 @@ impl Worker {
|
||||
.flush()
|
||||
.await
|
||||
.with_context(|| format!("flush stdin to rank {}", self.rank))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv_only(&mut self) -> Result<WorkerResponse> {
|
||||
let reply = self
|
||||
.stdout
|
||||
.next_line()
|
||||
@@ -71,10 +88,13 @@ impl Worker {
|
||||
pub struct WorkerPool {
|
||||
world_size: u32,
|
||||
workers: Vec<Worker>,
|
||||
/// Path to the neuron binary used to launch workers — captured at
|
||||
/// `spawn()` time via `/proc/self/exe` so the workers run the same
|
||||
/// binary the leader is running.
|
||||
/// Path to the neuron binary used to launch workers.
|
||||
#[allow(dead_code)]
|
||||
exe: PathBuf,
|
||||
/// Leader's own NCCL rank-0 state. Defaults to empty; populated by
|
||||
/// `init_nccl()`. Held here so the leader can participate in
|
||||
/// collectives (rank 0) without spawning a fourth subprocess.
|
||||
leader_nccl: nccl_state::NcclState,
|
||||
}
|
||||
|
||||
impl WorkerPool {
|
||||
@@ -148,9 +168,156 @@ impl WorkerPool {
|
||||
world_size,
|
||||
workers,
|
||||
exe,
|
||||
leader_nccl: nccl_state::NcclState::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Establish the NCCL communicator across the leader (rank 0) and
|
||||
/// every worker subprocess. Rendezvous is via a freshly-generated
|
||||
/// `Id` broadcast over the RPC stream; the actual handshake blocks
|
||||
/// inside `Comm::from_rank` until all `world_size` ranks check in.
|
||||
///
|
||||
/// `leader_cuda_device` is the CUDA device the leader binds rank 0
|
||||
/// to — typically the first entry of the `cuda_devices` slice
|
||||
/// originally passed to `spawn()`.
|
||||
///
|
||||
/// On the non-cuda build this immediately fails because the leader
|
||||
/// can't generate an `Id` without libnccl. The same call works in
|
||||
/// the worker path (returning a no-cuda error response) so the
|
||||
/// failure surface is uniform.
|
||||
pub async fn init_nccl(&mut self, leader_cuda_device: u32) -> Result<()> {
|
||||
let comm_id = nccl_state::generate_comm_id_hex()
|
||||
.map_err(|m| anyhow::anyhow!("generate NCCL id: {m}"))?;
|
||||
|
||||
// 1. Write Init to every worker's stdin without awaiting the
|
||||
// response. Workers will parse and call Comm::from_rank
|
||||
// concurrently with the leader below.
|
||||
for w in &mut self.workers {
|
||||
let req = WorkerRequest::Init {
|
||||
comm_id: comm_id.clone(),
|
||||
};
|
||||
w.send_only(&req).await?;
|
||||
}
|
||||
|
||||
// 2. Leader rank 0 calls Comm::from_rank on its own device.
|
||||
// Runs on spawn_blocking because NCCL's init blocks until
|
||||
// every rank has called in — that's exactly the workers
|
||||
// above. The leader's NcclState is moved through the
|
||||
// blocking task and returned to the pool.
|
||||
let leader_cfg = worker::WorkerConfig {
|
||||
rank: 0,
|
||||
world_size: self.world_size,
|
||||
cuda_device: leader_cuda_device,
|
||||
};
|
||||
let comm_id_for_leader = comm_id.clone();
|
||||
// Swap out the leader's NcclState into a fresh empty one so we
|
||||
// can move it into spawn_blocking; restore after the task
|
||||
// returns. (NcclState isn't Clone — it owns a real NCCL Comm.)
|
||||
let mut leader_state =
|
||||
std::mem::take(&mut self.leader_nccl);
|
||||
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
|
||||
let resp = leader_state.init(leader_cfg, &comm_id_for_leader);
|
||||
(leader_state, resp)
|
||||
})
|
||||
.await
|
||||
.context("leader NCCL init task panicked")?;
|
||||
self.leader_nccl = returned_state;
|
||||
match leader_resp {
|
||||
rpc::WorkerResponse::InitOk => {}
|
||||
rpc::WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("leader rank 0 init failed [{kind}]: {message}");
|
||||
}
|
||||
other => anyhow::bail!("leader rank 0 init: unexpected {other:?}"),
|
||||
}
|
||||
|
||||
// 3. Read InitOk from each worker. By now every worker has
|
||||
// completed its Comm::from_rank call (NCCL released them
|
||||
// when the leader joined the handshake) and is writing its
|
||||
// response.
|
||||
for w in &mut self.workers {
|
||||
let resp = w.recv_only().await?;
|
||||
match &resp {
|
||||
rpc::WorkerResponse::InitOk => {}
|
||||
rpc::WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("worker rank {} init failed [{kind}]: {message}", w.rank);
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"worker rank {} init: expected InitOk, got {other:?}",
|
||||
w.rank
|
||||
),
|
||||
}
|
||||
}
|
||||
tracing::info!(
|
||||
world_size = self.world_size,
|
||||
"NCCL communicator established across all ranks"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate the NCCL communicator: every rank `all_reduce`s a
|
||||
/// sentinel `1u32` with `ReduceOp::Sum`; the expected total is
|
||||
/// `world_size`. Confirms the handshake is live, not just
|
||||
/// configured.
|
||||
///
|
||||
/// Must be called after `init_nccl()`; before that the leader has
|
||||
/// no Comm and the workers reply with `nccl_not_initialised`.
|
||||
pub async fn nccl_sanity_check(&mut self) -> Result<()> {
|
||||
// 1. Trigger the all_reduce on every worker (write-only).
|
||||
for w in &mut self.workers {
|
||||
w.send_only(&WorkerRequest::NcclSanityCheck).await?;
|
||||
}
|
||||
|
||||
// 2. Leader's own all_reduce, in spawn_blocking. NCCL operations
|
||||
// block until every rank participates.
|
||||
let mut leader_state =
|
||||
std::mem::take(&mut self.leader_nccl);
|
||||
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
|
||||
let resp = leader_state.sanity_check();
|
||||
(leader_state, resp)
|
||||
})
|
||||
.await
|
||||
.context("leader NCCL sanity task panicked")?;
|
||||
self.leader_nccl = returned_state;
|
||||
|
||||
let expected = self.world_size;
|
||||
let leader_sum = match leader_resp {
|
||||
rpc::WorkerResponse::NcclSanityResult { observed_sum } => observed_sum,
|
||||
rpc::WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("leader rank 0 sanity failed [{kind}]: {message}");
|
||||
}
|
||||
other => anyhow::bail!("leader rank 0 sanity: unexpected {other:?}"),
|
||||
};
|
||||
if leader_sum != expected {
|
||||
anyhow::bail!("leader observed_sum={leader_sum}, expected {expected}");
|
||||
}
|
||||
|
||||
// 3. Read sanity result from each worker. All must match
|
||||
// world_size — anything else means the collective didn't
|
||||
// complete consistently across ranks.
|
||||
for w in &mut self.workers {
|
||||
let resp = w.recv_only().await?;
|
||||
match resp {
|
||||
rpc::WorkerResponse::NcclSanityResult { observed_sum }
|
||||
if observed_sum == expected => {}
|
||||
rpc::WorkerResponse::NcclSanityResult { observed_sum } => {
|
||||
anyhow::bail!(
|
||||
"worker rank {} observed_sum={observed_sum}, expected {expected}",
|
||||
w.rank
|
||||
);
|
||||
}
|
||||
rpc::WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("worker rank {} sanity failed [{kind}]: {message}", w.rank);
|
||||
}
|
||||
other => anyhow::bail!("worker rank {} sanity: unexpected {other:?}", w.rank),
|
||||
}
|
||||
}
|
||||
tracing::info!(
|
||||
world_size = expected,
|
||||
"NCCL sanity check OK across all ranks"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Ping every worker and return their Pong payloads in rank order.
|
||||
/// Useful right after `spawn` to confirm the lifecycle plumbing is
|
||||
/// intact before kicking off any heavier work.
|
||||
|
||||
Reference in New Issue
Block a user