Adds harness/tp/sharded_linear.rs with ShardedLinear, a Megatron-LM
style sharded wrapper over candle_nn::Linear. Two constructors:

- load_column: splits the output dimension. Each rank r of N holds rows
  [r*out/N .. (r+1)*out/N) of the weight, plus its slice of the bias.
  Forward is a local matmul and the output is naturally sharded; the
  downstream consumer either accepts the shard (next layer is
  row-parallel) or merges via all-gather later.
- load_row: splits the input dimension. Each rank r of N holds cols
  [r*in/N .. (r+1)*in/N) of the weight; bias lives only on rank 0
  so the post-all_reduce sum carries it exactly once. Forward
  produces a partial output that the caller reduces via NCCL.

Both constructors bail with a clear error when the split dimension is
not divisible by the number of ranks. This is the precondition that
mistral.rs's first qwen3-next-tp commit made explicit. The path included
in the error is the VarBuilder prefix, so the operator sees exactly
which projection failed
("column-parallel 'model.layers.0.self_attn.q_proj': out_features=...").

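The shard ranges and the divisibility precondition described above can be sketched in plain Rust. This is only an illustration under stated assumptions: `shard_range` and its parameters are hypothetical names, not taken from sharded_linear.rs, and the error wording after `out_features=` is invented because the commit message elides it.

```rust
use std::ops::Range;

/// Returns the half-open index range owned by `rank` when `dim` is split
/// evenly across `world_size` ranks.
///
/// Hypothetical helper: `kind`, `path`, and `dim_name` only feed the error
/// message, mimicking the "kind 'prefix': dim=..." shape quoted in the
/// commit message.
fn shard_range(
    kind: &str,
    path: &str,
    dim_name: &str,
    dim: usize,
    rank: usize,
    world_size: usize,
) -> Result<Range<usize>, String> {
    if world_size == 0 || rank >= world_size {
        return Err(format!(
            "{kind} '{path}': rank {rank} is out of range for world_size={world_size}"
        ));
    }
    // The precondition: the split dimension must divide evenly.
    if dim % world_size != 0 {
        return Err(format!(
            "{kind} '{path}': {dim_name}={dim} is not divisible by world_size={world_size}"
        ));
    }
    let per_rank = dim / world_size;
    Ok(rank * per_rank..(rank + 1) * per_rank)
}
```

With these assumptions, rank 1 of 4 on a 4096-wide output would own rows 1024..2048, while a 10-wide input split four ways is rejected with a message that names the projection.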
5 unit tests on CPU verify the math against an unsharded reference:

- column shard produces the expected slice of the full matmul
- row partials sum to the unsharded result
- row bias appears only on rank 0
- divisibility violations bail (two tests: column and row)

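The properties those tests check can be reproduced without candle. The following is a minimal sketch on plain `Vec<f32>` data, not the commit's test code; every function name is hypothetical, and the values are chosen to be exactly representable so the comparisons can use equality.

```rust
/// Reference dense layer: y[i] = dot(w[i], x) + bias[i], with the weight
/// stored as `out` rows of `in` columns.
fn linear(w: &[Vec<f32>], bias: Option<&[f32]>, x: &[f32]) -> Vec<f32> {
    w.iter()
        .enumerate()
        .map(|(i, row)| {
            let dot: f32 = row.iter().zip(x).map(|(wi, xi)| wi * xi).sum();
            dot + bias.map_or(0.0, |b| b[i])
        })
        .collect()
}

/// Column-parallel: the rank keeps a block of weight rows and the matching
/// bias slice, and sees the full input.
fn column_shard_forward(w: &[Vec<f32>], bias: &[f32], x: &[f32], rank: usize, n: usize) -> Vec<f32> {
    let per_rank = w.len() / n;
    let rows = rank * per_rank..(rank + 1) * per_rank;
    linear(&w[rows.clone()], Some(&bias[rows]), x)
}

/// Row-parallel: the rank keeps a block of weight columns and the matching
/// input slice; only rank 0 adds the bias.
fn row_shard_forward(w: &[Vec<f32>], bias: &[f32], x: &[f32], rank: usize, n: usize) -> Vec<f32> {
    let per_rank = x.len() / n;
    let cols = rank * per_rank..(rank + 1) * per_rank;
    let w_shard: Vec<Vec<f32>> = w.iter().map(|row| row[cols.clone()].to_vec()).collect();
    let rank0_bias = if rank == 0 { Some(bias) } else { None };
    linear(&w_shard, rank0_bias, &x[cols])
}

/// Returns (unsharded reference, concatenated column shards, summed row partials).
fn demo(n: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let w: Vec<Vec<f32>> = vec![
        vec![1.0, 2.0, 3.0, 4.0],
        vec![5.0, 6.0, 7.0, 8.0],
        vec![9.0, 10.0, 11.0, 12.0],
        vec![13.0, 14.0, 15.0, 16.0],
    ];
    let bias: Vec<f32> = vec![0.5, -0.5, 1.0, 2.0];
    let x: Vec<f32> = vec![1.0, 0.0, -1.0, 2.0];

    let full = linear(&w, Some(bias.as_slice()), &x);
    // Stand-in for all-gather: concatenate the shards in rank order.
    let gathered: Vec<f32> = (0..n)
        .flat_map(|r| column_shard_forward(&w, &bias, &x, r, n))
        .collect();
    // Stand-in for all_reduce(Sum): add the partial outputs elementwise.
    let mut reduced = vec![0.0f32; w.len()];
    for r in 0..n {
        for (acc, part) in reduced.iter_mut().zip(row_shard_forward(&w, &bias, &x, r, n)) {
            *acc += part;
        }
    }
    (full, gathered, reduced)
}
```

Calling `demo(2)` or `demo(4)` should return three identical vectors. Moving the bias onto every rank in `row_shard_forward` would make the reduced result overshoot by `(n - 1) * bias`, which is the failure the rank-0-only rule prevents.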
forward_with_comm() is stubbed for row-parallel (CUDA-only). Wiring
the actual cudarc::nccl all_reduce against candle's Tensor lands in
7b-iii alongside the model assembly, where the model holds the Comm
in scope. ColumnParallel's forward_with_comm just delegates to the
local matmul (no collective needed).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
371 lines
15 KiB
Rust
//! Tensor-parallel inference plumbing.
//!
//! The leader process (the neuron daemon proper) drives one
//! subprocess per non-zero NCCL rank — `tokio::process::Command` on
//! `/proc/self/exe --worker --rank N --tp-size N --cuda-device N` —
//! and talks to each over a newline-delimited JSON RPC channel on
//! the worker's stdin/stdout (see `rpc.rs`).
//!
//! Sub-staging:
//!
//! - **7a-i (this commit):** process lifecycle. `WorkerPool::spawn`
//!   forks N workers; `ping` round-trips every worker to confirm
//!   they're alive; `shutdown` cleanly drains and reaps. `Init` /
//!   `NcclSanityCheck` are stubbed.
//! - **7a-ii:** real NCCL `Comm` setup via `Init`, sanity check via
//!   `NcclSanityCheck`. CUDA-gated.
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
//! - **7c:** crash detection, streaming SSE, graceful unload.

pub mod nccl_state;
pub mod rpc;
pub mod sharded_linear;
pub mod worker;

use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};

use rpc::{WorkerRequest, WorkerResponse};

/// One worker subprocess plus its bidirectional stdio handles.
struct Worker {
    rank: u32,
    /// Captured so the leader can log "spawned rank N on device M" and
    /// future stages can re-issue Init after a CUDA reset. Unused in
    /// the Stage 7a-i RPC paths themselves.
    #[allow(dead_code)]
    cuda_device: u32,
    child: Child,
    stdin: ChildStdin,
    stdout: Lines<BufReader<ChildStdout>>,
}

impl Worker {
    /// Send a request and wait for the response. Used for sequenced
    /// ops like `Ping` / `Shutdown` where the caller doesn't need to
    /// overlap the worker's execution with the leader's.
    async fn request(&mut self, req: &WorkerRequest) -> Result<WorkerResponse> {
        self.send_only(req).await?;
        self.recv_only().await
    }

    /// Write a request without awaiting its response. Pair with
    /// `recv_only` from the caller when leader and worker need to do
    /// work concurrently — e.g. during `Init`, where the leader
    /// itself calls `Comm::from_rank` on rank 0 in parallel with the
    /// workers, then collects `InitOk` after NCCL completes.
    async fn send_only(&mut self, req: &WorkerRequest) -> Result<()> {
        let mut line = serde_json::to_string(req).context("serialise WorkerRequest")?;
        line.push('\n');
        self.stdin
            .write_all(line.as_bytes())
            .await
            .with_context(|| format!("write request to rank {}", self.rank))?;
        self.stdin
            .flush()
            .await
            .with_context(|| format!("flush stdin to rank {}", self.rank))?;
        Ok(())
    }

    /// Read one newline-delimited reply from the worker's stdout and
    /// parse it. Counterpart to `send_only`.
    async fn recv_only(&mut self) -> Result<WorkerResponse> {
        let reply = self
            .stdout
            .next_line()
            .await
            .with_context(|| format!("read reply from rank {}", self.rank))?
            .ok_or_else(|| anyhow::anyhow!("rank {} stdout closed before reply", self.rank))?;
        serde_json::from_str(&reply)
            .with_context(|| format!("parse reply from rank {}: {reply:?}", self.rank))
    }
}

/// A live pool of worker subprocesses. Owns the `Child` handles so
/// dropping the pool kills the children; explicit `shutdown()` is
/// the graceful path.
pub struct WorkerPool {
    world_size: u32,
    workers: Vec<Worker>,
    /// Path to the neuron binary used to launch workers.
    #[allow(dead_code)]
    exe: PathBuf,
    /// Leader's own NCCL rank-0 state. Defaults to empty; populated by
    /// `init_nccl()`. Held here so the leader can participate in
    /// collectives (rank 0) without spawning a fourth subprocess.
    leader_nccl: nccl_state::NcclState,
}

impl WorkerPool {
    /// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
    /// leader (in-process) and is *not* spawned here — the leader
    /// holds rank 0's NCCL Comm and shard in its own address space.
    ///
    /// `binary` is the path to the neuron executable to run for each
    /// worker (production passes `/proc/self/exe`; tests pass the
    /// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`).
    /// `cuda_devices` is one entry per rank including rank 0. Worker
    /// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`.
    pub async fn spawn(binary: &Path, world_size: u32, cuda_devices: &[u32]) -> Result<Self> {
        if world_size < 2 {
            anyhow::bail!(
                "WorkerPool::spawn called with world_size={world_size}; \
                 use the single-process path for world_size < 2"
            );
        }
        if cuda_devices.len() as u32 != world_size {
            anyhow::bail!(
                "expected {world_size} cuda_devices entries, got {}",
                cuda_devices.len()
            );
        }
        let exe = binary.to_path_buf();

        let mut workers = Vec::with_capacity(world_size as usize - 1);
        // Rank 0 stays in-process. Spawn ranks 1..world_size.
        for rank in 1..world_size {
            let cuda_device = cuda_devices[rank as usize];
            let mut cmd = Command::new(&exe);
            cmd.arg("--worker")
                .arg("--rank")
                .arg(rank.to_string())
                .arg("--tp-size")
                .arg(world_size.to_string())
                .arg("--cuda-device")
                .arg(cuda_device.to_string())
                .stdin(Stdio::piped())
                .stdout(Stdio::piped())
                // Inherit stderr so worker tracing surfaces alongside
                // the leader's journalctl stream.
                .stderr(Stdio::inherit())
                .kill_on_drop(true);

            let mut child = cmd
                .spawn()
                .with_context(|| format!("spawn worker rank {rank}"))?;
            let stdin = child
                .stdin
                .take()
                .ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdin handle"))?;
            let stdout = child
                .stdout
                .take()
                .ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdout handle"))?;
            let stdout = BufReader::new(stdout).lines();

            workers.push(Worker {
                rank,
                cuda_device,
                child,
                stdin,
                stdout,
            });
            tracing::info!(rank, cuda_device, "spawned tp worker");
        }

        Ok(Self {
            world_size,
            workers,
            exe,
            leader_nccl: nccl_state::NcclState::new(),
        })
    }

    /// Establish the NCCL communicator across the leader (rank 0) and
    /// every worker subprocess. Rendezvous is via a freshly-generated
    /// `Id` broadcast over the RPC stream; the actual handshake blocks
    /// inside `Comm::from_rank` until all `world_size` ranks check in.
    ///
    /// `leader_cuda_device` is the CUDA device the leader binds rank 0
    /// to — typically the first entry of the `cuda_devices` slice
    /// originally passed to `spawn()`.
    ///
    /// On the non-cuda build this immediately fails because the leader
    /// can't generate an `Id` without libnccl. The same call works in
    /// the worker path (returning a no-cuda error response) so the
    /// failure surface is uniform.
    pub async fn init_nccl(&mut self, leader_cuda_device: u32) -> Result<()> {
        let comm_id = nccl_state::generate_comm_id_hex()
            .map_err(|m| anyhow::anyhow!("generate NCCL id: {m}"))?;

        // 1. Write Init to every worker's stdin without awaiting the
        //    response. Workers will parse and call Comm::from_rank
        //    concurrently with the leader below.
        for w in &mut self.workers {
            let req = WorkerRequest::Init {
                comm_id: comm_id.clone(),
            };
            w.send_only(&req).await?;
        }

        // 2. Leader rank 0 calls Comm::from_rank on its own device.
        //    Runs on spawn_blocking because NCCL's init blocks until
        //    every rank has called in — that's exactly the workers
        //    above. The leader's NcclState is moved through the
        //    blocking task and returned to the pool.
        let leader_cfg = worker::WorkerConfig {
            rank: 0,
            world_size: self.world_size,
            cuda_device: leader_cuda_device,
        };
        let comm_id_for_leader = comm_id.clone();
        // Swap out the leader's NcclState into a fresh empty one so we
        // can move it into spawn_blocking; restore after the task
        // returns. (NcclState isn't Clone — it owns a real NCCL Comm.)
        let mut leader_state = std::mem::take(&mut self.leader_nccl);
        let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
            let resp = leader_state.init(leader_cfg, &comm_id_for_leader);
            (leader_state, resp)
        })
        .await
        .context("leader NCCL init task panicked")?;
        self.leader_nccl = returned_state;
        match leader_resp {
            rpc::WorkerResponse::InitOk => {}
            rpc::WorkerResponse::Error { kind, message } => {
                anyhow::bail!("leader rank 0 init failed [{kind}]: {message}");
            }
            other => anyhow::bail!("leader rank 0 init: unexpected {other:?}"),
        }

        // 3. Read InitOk from each worker. By now every worker has
        //    completed its Comm::from_rank call (NCCL released them
        //    when the leader joined the handshake) and is writing its
        //    response.
        for w in &mut self.workers {
            let resp = w.recv_only().await?;
            match &resp {
                rpc::WorkerResponse::InitOk => {}
                rpc::WorkerResponse::Error { kind, message } => {
                    anyhow::bail!("worker rank {} init failed [{kind}]: {message}", w.rank);
                }
                other => anyhow::bail!(
                    "worker rank {} init: expected InitOk, got {other:?}",
                    w.rank
                ),
            }
        }
        tracing::info!(
            world_size = self.world_size,
            "NCCL communicator established across all ranks"
        );
        Ok(())
    }

    /// Validate the NCCL communicator: every rank `all_reduce`s a
    /// sentinel `1u32` with `ReduceOp::Sum`; the expected total is
    /// `world_size`. Confirms the handshake is live, not just
    /// configured.
    ///
    /// Must be called after `init_nccl()`; before that the leader has
    /// no Comm and the workers reply with `nccl_not_initialised`.
    pub async fn nccl_sanity_check(&mut self) -> Result<()> {
        // 1. Trigger the all_reduce on every worker (write-only).
        for w in &mut self.workers {
            w.send_only(&WorkerRequest::NcclSanityCheck).await?;
        }

        // 2. Leader's own all_reduce, in spawn_blocking. NCCL operations
        //    block until every rank participates.
        let mut leader_state = std::mem::take(&mut self.leader_nccl);
        let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
            let resp = leader_state.sanity_check();
            (leader_state, resp)
        })
        .await
        .context("leader NCCL sanity task panicked")?;
        self.leader_nccl = returned_state;

        let expected = self.world_size;
        let leader_sum = match leader_resp {
            rpc::WorkerResponse::NcclSanityResult { observed_sum } => observed_sum,
            rpc::WorkerResponse::Error { kind, message } => {
                anyhow::bail!("leader rank 0 sanity failed [{kind}]: {message}");
            }
            other => anyhow::bail!("leader rank 0 sanity: unexpected {other:?}"),
        };
        if leader_sum != expected {
            anyhow::bail!("leader observed_sum={leader_sum}, expected {expected}");
        }

        // 3. Read sanity result from each worker. All must match
        //    world_size — anything else means the collective didn't
        //    complete consistently across ranks.
        for w in &mut self.workers {
            let resp = w.recv_only().await?;
            match resp {
                rpc::WorkerResponse::NcclSanityResult { observed_sum }
                    if observed_sum == expected => {}
                rpc::WorkerResponse::NcclSanityResult { observed_sum } => {
                    anyhow::bail!(
                        "worker rank {} observed_sum={observed_sum}, expected {expected}",
                        w.rank
                    );
                }
                rpc::WorkerResponse::Error { kind, message } => {
                    anyhow::bail!("worker rank {} sanity failed [{kind}]: {message}", w.rank);
                }
                other => anyhow::bail!("worker rank {} sanity: unexpected {other:?}", w.rank),
            }
        }
        tracing::info!(
            world_size = expected,
            "NCCL sanity check OK across all ranks"
        );
        Ok(())
    }

    /// Ping every worker and return their Pong payloads in rank order.
    /// Useful right after `spawn` to confirm the lifecycle plumbing is
    /// intact before kicking off any heavier work.
    pub async fn ping_all(&mut self) -> Result<Vec<WorkerResponse>> {
        let mut out = Vec::with_capacity(self.workers.len());
        for w in &mut self.workers {
            let resp = w.request(&WorkerRequest::Ping).await?;
            match &resp {
                WorkerResponse::Pong { rank, .. } if *rank == w.rank => {}
                WorkerResponse::Pong { rank, .. } => {
                    anyhow::bail!("rank mismatch: expected {}, got {rank}", w.rank);
                }
                other => anyhow::bail!("expected Pong from rank {}, got {other:?}", w.rank),
            }
            out.push(resp);
        }
        Ok(out)
    }

    /// Send `Shutdown` to every worker, await each `Bye`, and reap the
    /// children. Best-effort — individual worker failures are logged
    /// but don't abort the rest of the sweep.
    pub async fn shutdown(mut self) -> Result<()> {
        for w in &mut self.workers {
            match w.request(&WorkerRequest::Shutdown).await {
                Ok(WorkerResponse::Bye) => {}
                Ok(other) => tracing::warn!(
                    rank = w.rank,
                    response = ?other,
                    "expected Bye on shutdown"
                ),
                Err(e) => tracing::warn!(rank = w.rank, error = %e, "shutdown request failed"),
            }
        }
        for w in &mut self.workers {
            match w.child.wait().await {
                Ok(status) => tracing::info!(rank = w.rank, %status, "worker exited"),
                Err(e) => tracing::warn!(rank = w.rank, error = %e, "wait on worker failed"),
            }
        }
        Ok(())
    }

    /// Total number of ranks, including the in-process leader (rank 0).
    pub fn world_size(&self) -> u32 {
        self.world_size
    }

    /// Path to the binary used to launch the workers.
    pub fn binary_path(&self) -> &PathBuf {
        &self.exe
    }
}
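
The three-phase ordering used by `init_nccl` and `nccl_sanity_check` (send to every worker without awaiting, let the leader join the blocking collective, then collect replies) can be illustrated with the standard library alone. This is a simulation under loud assumptions, not the pool's code: threads stand in for worker subprocesses, `mpsc` channels stand in for the stdin/stdout pipes, and `std::sync::Barrier` stands in for the blocking NCCL rendezvous. `simulated_init` is a hypothetical name.

```rust
use std::sync::mpsc;
use std::sync::{Arc, Barrier};
use std::thread;

/// Runs a simulated init handshake across `world_size` ranks and returns
/// the worker ranks that reported success, in rank order.
fn simulated_init(world_size: usize) -> Vec<usize> {
    // Stand-in for the blocking rendezvous: `wait()` returns only once
    // all `world_size` participants have arrived.
    let rendezvous = Arc::new(Barrier::new(world_size));

    let mut workers = Vec::new();
    for rank in 1..world_size {
        let (req_tx, req_rx) = mpsc::channel::<&'static str>();
        let (resp_tx, resp_rx) = mpsc::channel::<usize>();
        let barrier = Arc::clone(&rendezvous);
        let handle = thread::spawn(move || {
            // Worker: wait for the request, join the rendezvous, then reply.
            if req_rx.recv() == Ok("init") {
                barrier.wait();
                resp_tx.send(rank).expect("leader dropped the reply channel");
            }
        });
        workers.push((req_tx, resp_rx, handle));
    }

    // Phase 1: send to every worker without waiting for a reply. Waiting
    // here would deadlock, because no worker can reply until the leader
    // has joined the rendezvous in phase 2.
    for (req_tx, _, _) in &workers {
        req_tx.send("init").expect("worker exited early");
    }

    // Phase 2: the leader participates as rank 0.
    rendezvous.wait();

    // Phase 3: every worker has been released, so collect the replies.
    let mut ok_ranks = Vec::new();
    for (_, resp_rx, handle) in workers {
        ok_ranks.push(resp_rx.recv().expect("worker closed without replying"));
        handle.join().expect("worker thread panicked");
    }
    ok_ranks
}
```

Under these assumptions, `simulated_init(4)` returns the worker ranks 1, 2, 3. Replacing phase 1 with a send-then-receive per worker (the shape of `Worker::request`) would hang on the first worker, which is why the pool splits the call into `send_only` and `recv_only` for collectives.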