//! Tensor-parallel inference plumbing. //! //! The leader process (the neuron daemon proper) drives one //! subprocess per non-zero NCCL rank — `tokio::process::Command` on //! `/proc/self/exe --worker --rank N --tp-size N --cuda-device N` — //! and talks to each over a newline-delimited JSON RPC channel on //! the worker's stdin/stdout (see `rpc.rs`). //! //! Sub-staging: //! //! - **7a-i (this commit):** process lifecycle. `WorkerPool::spawn` //! forks N workers; `ping` round-trips every worker to confirm //! they're alive; `shutdown` cleanly drains and reaps. `Init` / //! `NcclSanityCheck` are stubbed. //! - **7a-ii:** real NCCL `Comm` setup via `Init`, sanity check via //! `NcclSanityCheck`. CUDA-gated. //! - **7b:** TP-aware Qwen3 inference dispatched through the pool. //! - **7c:** crash detection, streaming SSE, graceful unload. pub mod all_reduce; pub mod nccl_state; pub mod rpc; pub mod tp_linear; pub mod tp_qwen3; pub mod tp_qwen3_5; pub mod worker; use anyhow::{Context, Result}; use std::path::{Path, PathBuf}; use std::process::Stdio; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines}; use tokio::process::{Child, ChildStdin, ChildStdout, Command}; use rpc::{WorkerRequest, WorkerResponse}; /// Leader-side handle for any TP-loaded model. The pool's /// `load_dense_shard` dispatches on `config.json#/model_type` to build /// the right variant; downstream callers (the harness's /// `chat_completion_tp` path, `generate_step`, `clear_kv_cache`, /// `unload_model`) all hold this enum and let the variant dispatch /// determine the concrete forward. /// /// Variants gated on `cuda` because the underlying TP models hold /// `Arc` references — irrelevant on CPU builds. #[cfg(feature = "cuda")] pub enum TpLeaderModel { Qwen3(tp_qwen3::TpQwen3ForCausalLM), Qwen3_5(tp_qwen3_5::TpQwen3_5ForCausalLM), } #[cfg(feature = "cuda")] impl TpLeaderModel { pub fn forward( &mut self, input: &candle_core::Tensor, offset: usize, ) -> candle_core::Result { match self { TpLeaderModel::Qwen3(m) => m.forward(input, offset), TpLeaderModel::Qwen3_5(m) => m.forward(input, offset), } } pub fn clear_kv_cache(&mut self) { match self { TpLeaderModel::Qwen3(m) => m.clear_kv_cache(), TpLeaderModel::Qwen3_5(m) => m.clear_kv_cache(), } } pub fn device(&self) -> &candle_core::Device { match self { TpLeaderModel::Qwen3(m) => m.device(), TpLeaderModel::Qwen3_5(m) => m.device(), } } } /// One worker subprocess plus its bidirectional stdio handles. struct Worker { rank: u32, /// Captured so the leader can log "spawned rank N on device M" and /// future stages can re-issue Init after a CUDA reset. Unused in /// the Stage 7a-i RPC paths themselves. #[allow(dead_code)] cuda_device: u32, child: Child, stdin: ChildStdin, stdout: Lines>, } impl Worker { /// Send a request and wait for the response. Used for sequenced /// ops like `Ping` / `Shutdown` where the caller doesn't need to /// overlap the worker's execution with the leader's. async fn request(&mut self, req: &WorkerRequest) -> Result { self.send_only(req).await?; self.recv_only().await } /// Write a request without awaiting its response. Pair with /// `recv_only` from the caller when leader and worker need to do /// work concurrently — e.g. during `Init`, where the leader /// itself calls `Comm::from_rank` on rank 0 in parallel with the /// workers, then collects `InitOk` after NCCL completes. async fn send_only(&mut self, req: &WorkerRequest) -> Result<()> { let mut line = serde_json::to_string(req).context("serialise WorkerRequest")?; line.push('\n'); self.stdin .write_all(line.as_bytes()) .await .with_context(|| format!("write request to rank {}", self.rank))?; self.stdin .flush() .await .with_context(|| format!("flush stdin to rank {}", self.rank))?; Ok(()) } async fn recv_only(&mut self) -> Result { let reply = self .stdout .next_line() .await .with_context(|| format!("read reply from rank {}", self.rank))? .ok_or_else(|| anyhow::anyhow!("rank {} stdout closed before reply", self.rank))?; serde_json::from_str(&reply) .with_context(|| format!("parse reply from rank {}: {reply:?}", self.rank)) } } /// Drain one response from every worker, classifying each via the /// supplied checker. Always reads from every worker — even if some /// fail — so the next call's recv doesn't pick up stale responses /// from this one (pipe-poisoning was the cause of the /// "ClearKvCache: expected KvCacheCleared, got GenerateStepOk" class /// of bugs). /// /// Returns a vector of `rank N: detail` strings for any worker that /// errored, expected-mismatched, or failed to respond. Caller decides /// how to combine these with the leader's outcome. async fn drain_workers( workers: &mut [Worker], mut check: impl FnMut(WorkerResponse) -> std::result::Result<(), String>, ) -> Vec { let mut errs = Vec::new(); for w in workers { match w.recv_only().await { Ok(resp) => { if let Err(detail) = check(resp) { errs.push(format!("rank {} {detail}", w.rank)); } } Err(e) => errs.push(format!("rank {} recv: {e:#}", w.rank)), } } errs } /// Combine a leader's `Result>` (the typical /// `spawn_blocking → JoinHandle>` shape) with the worker /// drain results into a single `Result`. Leader failures take /// precedence in the error message but worker errors get appended so /// the operator sees both halves. #[cfg(feature = "cuda")] fn combine_leader_workers( leader: Result>, worker_errors: Vec, op: &str, ) -> Result { match leader { Ok(Ok(value)) => { if worker_errors.is_empty() { Ok(value) } else { anyhow::bail!( "{op}: leader succeeded but workers failed: {}", worker_errors.join("; ") ) } } Ok(Err(e)) => { if worker_errors.is_empty() { Err(e.context(format!("{op}: leader forward failed"))) } else { Err(e.context(format!( "{op}: leader forward failed and workers also failed: {}", worker_errors.join("; ") ))) } } Err(panic_err) => { if worker_errors.is_empty() { Err(panic_err) } else { Err(panic_err.context(format!( "{op}: leader task panicked and workers failed: {}", worker_errors.join("; ") ))) } } } } /// A live pool of worker subprocesses. Owns the `Child` handles so /// dropping the pool kills the children; explicit `shutdown()` is /// the graceful path. pub struct WorkerPool { world_size: u32, workers: Vec, /// Path to the neuron binary used to launch workers. #[allow(dead_code)] exe: PathBuf, /// Leader's own NCCL rank-0 state. Defaults to empty; populated by /// `init_nccl()`. Held here so the leader can participate in /// collectives (rank 0) without spawning a fourth subprocess. leader_nccl: nccl_state::NcclState, } impl WorkerPool { /// Spawn `world_size - 1` worker subprocesses. Rank 0 is the /// leader (in-process) and is *not* spawned here — the leader /// holds rank 0's NCCL Comm and shard in its own address space. /// /// `binary` is the path to the neuron executable to run for each /// worker (production passes `/proc/self/exe`; tests pass the /// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`). /// `cuda_devices` is one entry per rank including rank 0. Worker /// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`. pub async fn spawn(binary: &Path, world_size: u32, cuda_devices: &[u32]) -> Result { if world_size < 2 { anyhow::bail!( "WorkerPool::spawn called with world_size={world_size}; \ use the single-process path for world_size < 2" ); } if cuda_devices.len() as u32 != world_size { anyhow::bail!( "expected {world_size} cuda_devices entries, got {}", cuda_devices.len() ); } let exe = binary.to_path_buf(); let mut workers = Vec::with_capacity(world_size as usize - 1); // Rank 0 stays in-process. Spawn ranks 1..world_size. for rank in 1..world_size { let cuda_device = cuda_devices[rank as usize]; let mut cmd = Command::new(&exe); cmd.arg("--worker") .arg("--rank") .arg(rank.to_string()) .arg("--tp-size") .arg(world_size.to_string()) .arg("--cuda-device") .arg(cuda_device.to_string()) .stdin(Stdio::piped()) .stdout(Stdio::piped()) // Inherit stderr so worker tracing surfaces alongside // the leader's journalctl stream. .stderr(Stdio::inherit()) .kill_on_drop(true); let mut child = cmd .spawn() .with_context(|| format!("spawn worker rank {rank}"))?; let stdin = child .stdin .take() .ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdin handle"))?; let stdout = child .stdout .take() .ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdout handle"))?; let stdout = BufReader::new(stdout).lines(); workers.push(Worker { rank, cuda_device, child, stdin, stdout, }); tracing::info!(rank, cuda_device, "spawned tp worker"); } Ok(Self { world_size, workers, exe, leader_nccl: nccl_state::NcclState::new(), }) } /// Establish the NCCL communicator across the leader (rank 0) and /// every worker subprocess. Rendezvous is via a freshly-generated /// `Id` broadcast over the RPC stream; the actual handshake blocks /// inside `Comm::from_rank` until all `world_size` ranks check in. /// /// `leader_cuda_device` is the CUDA device the leader binds rank 0 /// to — typically the first entry of the `cuda_devices` slice /// originally passed to `spawn()`. /// /// On the non-cuda build this immediately fails because the leader /// can't generate an `Id` without libnccl. The same call works in /// the worker path (returning a no-cuda error response) so the /// failure surface is uniform. pub async fn init_nccl(&mut self, leader_cuda_device: u32) -> Result<()> { let comm_id = nccl_state::generate_comm_id_hex() .map_err(|m| anyhow::anyhow!("generate NCCL id: {m}"))?; // 1. Write Init to every worker's stdin without awaiting the // response. Workers will parse and call Comm::from_rank // concurrently with the leader below. for w in &mut self.workers { let req = WorkerRequest::Init { comm_id: comm_id.clone(), }; w.send_only(&req).await?; } // 2. Leader rank 0 calls Comm::from_rank on its own device. // Runs on spawn_blocking because NCCL's init blocks until // every rank has called in — that's exactly the workers // above. The leader's NcclState is moved through the // blocking task and returned to the pool. let leader_cfg = worker::WorkerConfig { rank: 0, world_size: self.world_size, cuda_device: leader_cuda_device, }; let comm_id_for_leader = comm_id.clone(); // Swap out the leader's NcclState into a fresh empty one so we // can move it into spawn_blocking; restore after the task // returns. (NcclState isn't Clone — it owns a real NCCL Comm.) let mut leader_state = std::mem::take(&mut self.leader_nccl); let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || { let resp = leader_state.init(leader_cfg, &comm_id_for_leader); (leader_state, resp) }) .await .context("leader NCCL init task panicked")?; self.leader_nccl = returned_state; match leader_resp { rpc::WorkerResponse::InitOk => {} rpc::WorkerResponse::Error { kind, message } => { anyhow::bail!("leader rank 0 init failed [{kind}]: {message}"); } other => anyhow::bail!("leader rank 0 init: unexpected {other:?}"), } // 3. Read InitOk from each worker. By now every worker has // completed its Comm::from_rank call (NCCL released them // when the leader joined the handshake) and is writing its // response. for w in &mut self.workers { let resp = w.recv_only().await?; match &resp { rpc::WorkerResponse::InitOk => {} rpc::WorkerResponse::Error { kind, message } => { anyhow::bail!("worker rank {} init failed [{kind}]: {message}", w.rank); } other => anyhow::bail!( "worker rank {} init: expected InitOk, got {other:?}", w.rank ), } } tracing::info!( world_size = self.world_size, "NCCL communicator established across all ranks" ); Ok(()) } /// Validate the NCCL communicator: every rank `all_reduce`s a /// sentinel `1u32` with `ReduceOp::Sum`; the expected total is /// `world_size`. Confirms the handshake is live, not just /// configured. /// /// Must be called after `init_nccl()`; before that the leader has /// no Comm and the workers reply with `nccl_not_initialised`. pub async fn nccl_sanity_check(&mut self) -> Result<()> { // 1. Trigger the all_reduce on every worker (write-only). for w in &mut self.workers { w.send_only(&WorkerRequest::NcclSanityCheck).await?; } // 2. Leader's own all_reduce, in spawn_blocking. NCCL operations // block until every rank participates. let mut leader_state = std::mem::take(&mut self.leader_nccl); let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || { let resp = leader_state.sanity_check(); (leader_state, resp) }) .await .context("leader NCCL sanity task panicked")?; self.leader_nccl = returned_state; let expected = self.world_size; let leader_sum = match leader_resp { rpc::WorkerResponse::NcclSanityResult { observed_sum } => observed_sum, rpc::WorkerResponse::Error { kind, message } => { anyhow::bail!("leader rank 0 sanity failed [{kind}]: {message}"); } other => anyhow::bail!("leader rank 0 sanity: unexpected {other:?}"), }; if leader_sum != expected { anyhow::bail!("leader observed_sum={leader_sum}, expected {expected}"); } // 3. Read sanity result from each worker. All must match // world_size — anything else means the collective didn't // complete consistently across ranks. for w in &mut self.workers { let resp = w.recv_only().await?; match resp { rpc::WorkerResponse::NcclSanityResult { observed_sum } if observed_sum == expected => {} rpc::WorkerResponse::NcclSanityResult { observed_sum } => { anyhow::bail!( "worker rank {} observed_sum={observed_sum}, expected {expected}", w.rank ); } rpc::WorkerResponse::Error { kind, message } => { anyhow::bail!("worker rank {} sanity failed [{kind}]: {message}", w.rank); } other => anyhow::bail!("worker rank {} sanity: unexpected {other:?}", w.rank), } } tracing::info!( world_size = expected, "NCCL sanity check OK across all ranks" ); Ok(()) } /// Ping every worker and return their Pong payloads in rank order. /// Useful right after `spawn` to confirm the lifecycle plumbing is /// intact before kicking off any heavier work. pub async fn ping_all(&mut self) -> Result> { let mut out = Vec::with_capacity(self.workers.len()); for w in &mut self.workers { let resp = w.request(&WorkerRequest::Ping).await?; match &resp { WorkerResponse::Pong { rank, .. } if *rank == w.rank => {} WorkerResponse::Pong { rank, .. } => { anyhow::bail!("rank mismatch: expected {}, got {rank}", w.rank); } other => anyhow::bail!("expected Pong from rank {}, got {other:?}", w.rank), } out.push(resp); } Ok(out) } /// Load this rank's shard of a dense Qwen3 model on every rank. /// /// The leader builds rank 0's `TpQwen3ForCausalLM` directly into /// the returned `Arc>` — workers build their rank-local /// shards in their own address spaces and confirm via /// `LoadDenseShardOk`. All ranks see the same `safetensors_paths`; /// `ShardedVarBuilder` slices each tensor by rank at materialisation /// time, so the per-rank VRAM footprint is roughly `1/world_size` /// of the full model (plus the replicated embedding/norm/lm_head). /// /// `leader_device` is the candle `Device` the leader's shard lives /// on — typically `Device::new_cuda(leader_cuda_device)` matching /// the same index passed to `init_nccl`. `dtype` is the on-device /// element type; bf16 is the canonical Qwen3 distribution dtype. /// /// `init_nccl` must have completed first. Bails if the leader's /// NCCL comm isn't set up yet. #[cfg(feature = "cuda")] pub async fn load_dense_shard( &mut self, model_id: &str, config_json: &str, safetensors_paths: &[std::path::PathBuf], leader_device: &candle_core::Device, dtype: candle_core::DType, ) -> Result>> { use candle_nn::var_builder::ShardedSafeTensors; use std::sync::Arc; use tokio::sync::Mutex; // Wrap the comm in SendComm immediately so it stays Send across // the await points in this method — bare Arc would // poison the async fn's Send bound (Comm's raw NCCL pointer is // !Send). The wrapper's safety contract is satisfied by the // pool's outer Mutex serialising callers + the spawn_blocking // thread being the only place ops are issued. let leader_comm = nccl_state::SendComm(self.leader_nccl.comm().ok_or_else(|| { anyhow::anyhow!("leader NCCL not initialised; call init_nccl first") })?); let world_size = self.world_size; let safetensors_str: Vec = safetensors_paths .iter() .map(|p| p.to_string_lossy().into_owned()) .collect(); // 1. Fan out the LoadDenseShard request to every worker without // awaiting their replies — they'll build their shards in // parallel with the leader below. for w in &mut self.workers { w.send_only(&WorkerRequest::LoadDenseShard { model_id: model_id.to_string(), config_json: config_json.to_string(), safetensors_paths: safetensors_str.clone(), }) .await?; } // 2. Build rank 0's shard on the leader. Dispatch on model_type // — for `qwen3` we build a `TpQwen3ForCausalLM`, for // `qwen3_5` (Qwen3-Next, Qwen3.6's architecture) we build // `TpQwen3_5ForCausalLM`. Both end up wrapped in the // `TpLeaderModel` enum so downstream callers don't care. let model_type = serde_json::from_str::(config_json) .ok() .as_ref() .and_then(|v| v.get("model_type")) .and_then(|v| v.as_str()) .unwrap_or("") .to_string(); let paths_for_leader: Vec = safetensors_paths.to_vec(); let device_for_leader = leader_device.clone(); let comm_for_leader = leader_comm; let model_id_for_log = model_id.to_string(); let config_json_for_leader = config_json.to_string(); let leader_model = tokio::task::spawn_blocking(move || -> Result { // SAFETY: same invariant as the single-GPU dense path — // the HF cache files are treated as immutable while the // mmap is held. let vb = unsafe { ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader) .context("build ShardedVarBuilder over safetensors")? }; let comm = comm_for_leader.into_inner(); let loaded = match model_type.as_str() { "qwen3" => { let cfg: super::tp::tp_qwen3::Config = serde_json::from_str(&config_json_for_leader) .context("parse Qwen3 Config JSON for leader load")?; TpLeaderModel::Qwen3(super::tp::tp_qwen3::TpQwen3ForCausalLM::load( &cfg, &vb, 0, world_size, comm, )?) } "qwen3_5" => { let cfg: super::tp::tp_qwen3_5::Config = serde_json::from_str(&config_json_for_leader) .context("parse Qwen3-Next Config JSON for leader load")?; TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load( cfg, &vb, 0, world_size, comm, )?) } other => anyhow::bail!( "TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)" ), }; tracing::info!(rank = 0, model = %model_id_for_log, model_type = %model_type, "loaded TP shard (leader)"); Ok(loaded) }) .await .context("leader load task panicked")??; // 3. Collect worker confirmations. Anything other than // LoadDenseShardOk aborts the whole load — the leader's // already-loaded shard drops when this fn returns Err. for w in &mut self.workers { let resp = w.recv_only().await?; match resp { WorkerResponse::LoadDenseShardOk => {} WorkerResponse::Error { kind, message } => { anyhow::bail!("worker rank {} LoadDenseShard [{kind}]: {message}", w.rank) } other => anyhow::bail!( "worker rank {} LoadDenseShard: expected LoadDenseShardOk, got {other:?}", w.rank ), } } Ok(Arc::new(Mutex::new(leader_model))) } /// Run one forward step across every rank. The leader's forward /// returns the last-position logits as a candle Tensor on the /// leader's device; the caller does sampling out-of-band. Workers /// run their own forwards (the AllReduce inside row-parallel layers /// is what lets the leader's collective complete) and reply with /// `GenerateStepOk` — they do not ship logits over the wire. /// /// `tokens` is the input for this step (prompt for prefill, the /// previously-sampled token for decode). `offset` is the KV-cache /// position before this step. #[cfg(feature = "cuda")] pub async fn generate_step( &mut self, model_id: &str, leader_model: std::sync::Arc>, tokens: Vec, offset: usize, ) -> Result { let step_start = std::time::Instant::now(); let tokens_len = tokens.len(); tracing::debug!( model = %model_id, tokens = tokens_len, offset, "WorkerPool::generate_step: fan-out" ); // 1. Fan-out to workers. for w in &mut self.workers { w.send_only(&WorkerRequest::GenerateStep { model_id: model_id.to_string(), tokens: tokens.clone(), offset, }) .await?; } // 2. Leader's forward in spawn_blocking. The AllReduce CustomOps // inside the row-parallel layers block until every worker's // forward issues the matching collective. let leader_start = std::time::Instant::now(); let leader_result = tokio::task::spawn_blocking(move || -> Result { let mut model = leader_model.blocking_lock(); let device = model.device().clone(); let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; // ForCausalLM::forward returns [B, 1, V] — squeeze both // leading dims to the rank-1 vocab logits the sampler wants. let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?; Ok(logits) }) .await .context("leader forward task panicked"); let leader_ok = matches!(leader_result, Ok(Ok(_))); tracing::debug!( model = %model_id, tokens = tokens_len, leader_ms = leader_start.elapsed().as_millis(), leader_ok, "WorkerPool::generate_step: leader forward returned" ); // 3. ALWAYS drain worker responses, regardless of whether the // leader succeeded. Skipping this on the leader's error // path leaves stale GenerateStepOk replies in the worker // pipes that poison the NEXT request's recv (was seeing // "ClearKvCache: expected KvCacheCleared, got // GenerateStepOk" the call after any forward-time failure). let drain_start = std::time::Instant::now(); let worker_errors = drain_workers(&mut self.workers, |r| match r { WorkerResponse::GenerateStepOk => Ok(()), WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")), other => Err(format!("expected GenerateStepOk, got {other:?}")), }) .await; tracing::debug!( model = %model_id, drain_ms = drain_start.elapsed().as_millis(), errors = worker_errors.len(), total_ms = step_start.elapsed().as_millis(), "WorkerPool::generate_step: workers drained" ); combine_leader_workers(leader_result, worker_errors, "GenerateStep") } /// Reset the KV cache for `model_id` on every rank. Called at the /// start of every inference so a fresh request doesn't attend over /// the previous one's tokens. pub async fn clear_kv_cache( &mut self, model_id: &str, #[cfg(feature = "cuda")] leader_model: std::sync::Arc>, ) -> Result<()> { let start = std::time::Instant::now(); tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out"); for w in &mut self.workers { w.send_only(&WorkerRequest::ClearKvCache { model_id: model_id.to_string(), }) .await?; } #[cfg(feature = "cuda")] { let mut m = leader_model.lock().await; m.clear_kv_cache(); } // Drain workers — same rationale as `generate_step`. The // leader's clear_kv_cache is in-process and infallible, but we // still always drain so an error on one worker doesn't leave // pending responses for the others. let worker_errors = drain_workers(&mut self.workers, |r| match r { WorkerResponse::KvCacheCleared => Ok(()), WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")), other => Err(format!("expected KvCacheCleared, got {other:?}")), }) .await; tracing::debug!( model = %model_id, elapsed_ms = start.elapsed().as_millis(), errors = worker_errors.len(), "WorkerPool::clear_kv_cache: workers drained" ); if !worker_errors.is_empty() { anyhow::bail!("ClearKvCache: {}", worker_errors.join("; ")); } Ok(()) } /// Drop this model's shards on every rank. The leader's shard is /// expected to have been dropped by the caller (its `Arc` was held /// in the TpLoadedModel and goes away when that's removed). pub async fn unload_model(&mut self, model_id: &str) -> Result<()> { for w in &mut self.workers { w.send_only(&WorkerRequest::UnloadModel { model_id: model_id.to_string(), }) .await?; } for w in &mut self.workers { let resp = w.recv_only().await?; match resp { WorkerResponse::Unloaded => {} WorkerResponse::Error { kind, message } => { anyhow::bail!("worker rank {} UnloadModel [{kind}]: {message}", w.rank) } other => anyhow::bail!( "worker rank {} UnloadModel: expected Unloaded, got {other:?}", w.rank ), } } Ok(()) } /// Send `Shutdown` to every worker, await each `Bye`, and reap the /// children. Best-effort — individual worker failures are logged /// but don't abort the rest of the sweep. pub async fn shutdown(mut self) -> Result<()> { for w in &mut self.workers { match w.request(&WorkerRequest::Shutdown).await { Ok(WorkerResponse::Bye) => {} Ok(other) => tracing::warn!( rank = w.rank, response = ?other, "expected Bye on shutdown" ), Err(e) => tracing::warn!(rank = w.rank, error = %e, "shutdown request failed"), } } for w in &mut self.workers { match w.child.wait().await { Ok(status) => tracing::info!(rank = w.rank, %status, "worker exited"), Err(e) => tracing::warn!(rank = w.rank, error = %e, "wait on worker failed"), } } Ok(()) } pub fn world_size(&self) -> u32 { self.world_size } pub fn binary_path(&self) -> &PathBuf { &self.exe } }