From 76ab24d98c3518ca924e755120ff32fdba2dc472 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Wed, 27 May 2026 10:16:02 +0300 Subject: [PATCH] =?UTF-8?q?refactor(neuron):=20phase=203=20=E2=80=94=20TP?= =?UTF-8?q?=20forward=20+=20NCCL=20state=20move=20onto=20device=20worker?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Third slice of the per-device CUDA context-ownership refactor planned at ~/.claude/plans/plan-the-per-device-worker-abstract-micali.md. The leader's `NcclState`, every `Comm::all_reduce` issued by the TP layers, the leader-side KV cache reset, and the TP forward step itself now all run on the per-device worker thread — the same OS thread that bound the leader's `CudaContext` at startup. What this phase changes: - `Job` gains `NcclInit`, `NcclSanity`, `CloneLeaderComm` (Phase 3 bridge — Phase 4 removes), `TransferInTp`, `DropTp`, `TpClearKv`, `TpForwardLogits`. Plus a new `TpHandle(u64)` opaque key. - `DeviceWorkerState` gains `nccl: NcclState` and `tp_models: HashMap>` (+ counter). - `WorkerPool` loses its `leader_nccl` field; gains a `leader_worker: Arc` passed at construction. `init_nccl`, `nccl_sanity_check`, `load_dense_shard`, `generate_step`, `clear_kv_cache` all route their leader-side ops through `Job::Nccl*` / `Job::Tp*` instead of spawn_blocking against a Mutex-wrapped state. `generate_step` returns `Vec` instead of a device-resident `Tensor` — the worker copies logits to CPU before reply so the async caller can sample on a CPU candle tensor with zero device-context touch. - `TpLoadedModel.leader_model: Arc>` → opaque `leader_handle: TpHandle`. The boxed `TpLeaderModel` lives in the worker thread's slab; both the model's CUDA tensors and the embedded `Arc` clones release on the same thread that allocated them (the Drop semantics constraint cudarc forces). - `Job::CloneLeaderComm` is a Phase 3 bridge: the TP shard load still runs in spawn_blocking and needs the leader's `Arc` to build the row-parallel layers' AllReduce ops. The Job clones the Comm out of the worker's NcclState and ships it back as `SendComm`. Phase 4 deletes this bridge when the load itself moves onto the worker. - `Job::NcclInit` and `Job::NcclSanity` are ungated by `cuda` so the no-cuda `NcclState` stubs (which reply with `cuda_feature_not_enabled`) still flow through the same channel uniformly; the cuda-only TP variants (CloneLeaderComm, Transfer/Drop/Clear/Forward Tp) remain gated. What this phase doesn't touch (yet): - TP shard load itself — still spawn_blocking, bridged via `CloneLeaderComm`. Phase 4 moves it to `Job::TpLoadShard` and reads `state.nccl.comm()` directly inside the worker. - Single-GPU model loads — still spawn_blocking, transferred via `Job::TransferIn`. Phase 4 moves them. - `device_vram_mb` / `cuda_mem_mb` / `log_construction_complete` helpers — still present, used inside spawn_blocking load closures. Phase 4 cleanup folds them into `dispatch.rs`. `tp/mod.rs::WorkerPool::spawn` gained a required `leader_worker: Arc` argument. Three external callers were updated: `CandleHarness::load_tp` (passes the cached device worker), `main.rs::tp_smoke` (spawns a fresh worker), and the two `tp_worker_lifecycle*.rs` integration tests. Public API unchanged. fmt + clippy clean; 37 lib tests + all integration tests pass. CUDA-only TP integration smoke deferred to the next deploy on beast. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/neuron/src/harness/candle.rs | 136 +++++++---- .../src/harness/device_worker/dispatch.rs | 165 +++++++++++++ .../neuron/src/harness/device_worker/jobs.rs | 77 ++++++ .../neuron/src/harness/device_worker/mod.rs | 188 ++++++++++++++ crates/neuron/src/harness/tp/mod.rs | 229 +++++++++++------- crates/neuron/src/main.rs | 8 +- crates/neuron/tests/tp_worker_lifecycle.rs | 7 +- .../neuron/tests/tp_worker_lifecycle_cuda.rs | 4 +- 8 files changed, 676 insertions(+), 138 deletions(-) diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 23d5a71..9beae87 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -163,16 +163,22 @@ pub struct TpLoadedModel { pub model_id: String, pub tokenizer: Tokenizer, pub devices: Vec, - /// One end-to-end gate: the pool's RPC stream isn't safe to use - /// concurrently and the leader shard's KV cache mutates with every - /// step. The same Mutex covers both for the simplest correctness - /// story. + /// One end-to-end gate: the pool's RPC stream to the subprocess + /// workers isn't safe to use concurrently. After Phase 3 the + /// leader's `TpLeaderModel` lives in the worker thread's slab, + /// so this Mutex no longer covers the leader's KV cache; it just + /// serialises subprocess RPC traffic on the pool's + /// `Vec` channels. pub pool: tokio::sync::Mutex, - pub leader_model: Arc>, - /// Candle device for rank 0. Mirrors what `leader_model.device()` - /// would return, but stored separately so the request path can - /// query VRAM without locking the leader (which would contend with - /// the in-flight forward). + /// Handle into the leader device worker's TP slab. The boxed + /// `TpLeaderModel` (with its embedded `Arc` clones and + /// per-rank CUDA tensors) lives on the worker thread; we hold an + /// opaque index. Forward / clear_kv / unload all route through + /// `Job::Tp*` against this handle. + pub leader_handle: super::device_worker::TpHandle, + /// Candle device for rank 0. Mirrors what + /// `TpLeaderModel::device()` would return, kept on the struct so + /// the request path can name the device without an RPC. pub leader_device: Device, /// Same poisoning gate as [`LoadedModel::poisoned`]. A TP forward /// failure (CUDA OOM on any rank, NCCL desync, illegal address) is @@ -180,9 +186,8 @@ pub struct TpLoadedModel { /// reliably reset without restarting the worker subprocesses. pub poisoned: AtomicBool, /// Worker thread for the leader's CUDA device. Owns the leader's - /// `CudaContext` for the daemon's lifetime. VRAM queries route - /// through it; in later refactor phases the forward, kv-cache - /// clear, and shard unload route through it too. + /// `CudaContext`, `NcclState`, and the boxed `TpLeaderModel` + /// referenced by `leader_handle`. pub worker: Arc, } @@ -1642,6 +1647,16 @@ impl Harness for CandleHarness { anyhow::bail!("cannot unload '{model_id}': inference still in flight"); } }; + // Drop the leader's TpLeaderModel on the device worker + // thread (CUDA tensors and Arc clones release on + // the same OS thread that allocated them). + if let Err(e) = tp.worker.drop_tp(tp.leader_handle).await { + tracing::warn!( + model = %model_id, + error = %e, + "TP unload: DropTp RPC failed (leader model may leak in worker slab)" + ); + } let mut pool = tp.pool.into_inner(); if let Err(e) = pool.unload_model(model_id).await { tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed"); @@ -1715,9 +1730,14 @@ impl CandleHarness { // 2. Spawn the worker pool. Rank 0 stays in-process; ranks // 1..tp_size are subprocesses, one per device after the - // leader's own. + // leader's own. The leader's device worker thread is + // spawned (or reused) here and passed into the pool so + // `init_nccl`, the load, every TP forward, and KV-cache + // clears all dispatch from the same OS thread. let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?; - let mut pool = super::tp::WorkerPool::spawn(&exe, tp_size, &devices).await?; + let leader_worker = self.ensure_device_worker(devices[0]).await?; + let mut pool = + super::tp::WorkerPool::spawn(&exe, tp_size, &devices, leader_worker.clone()).await?; // 3. NCCL handshake across all ranks. let leader_device_idx = devices[0]; @@ -1727,8 +1747,11 @@ impl CandleHarness { let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize) .context("Device::new_cuda for TP leader")?; - // 5. Load this rank's shard on every rank. - let leader_model = pool + // 5. Load this rank's shard on every rank. After Phase 3 + // `load_dense_shard` transfers the freshly-built + // `TpLeaderModel` into the device worker's TP slab and + // returns the resulting handle. + let leader_handle = pool .load_dense_shard( &spec.model_id, &config_json, @@ -1743,21 +1766,18 @@ impl CandleHarness { let tokenizer = Tokenizer::from_file(&tokenizer_path) .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?; - // 7. Worker thread for the leader's CUDA device. TP always - // runs on CUDA — the harness rejects TP without the cuda - // feature earlier in this function — so we always have a - // device to own. - let worker = self.ensure_device_worker(devices[0]).await?; - let tp_loaded = StdArc::new(TpLoadedModel { model_id: spec.model_id.clone(), tokenizer, devices: devices.clone(), pool: TMutex::new(pool), - leader_model, + leader_handle, leader_device: leader_device.clone(), poisoned: AtomicBool::new(false), - worker, + // Same `leader_worker` we passed into the pool above — + // single `Arc` shared between WorkerPool and + // TpLoadedModel so they reference the same thread. + worker: leader_worker, }); let mut models = self.models.write().await; @@ -1932,14 +1952,14 @@ impl CandleHarness { async move { let mut failure: Option = None; let mut pool = acquire_pool_lock(&tp_for_task.pool, &model_id).await; - let leader_arc = tp_for_task.leader_model.clone(); + let leader_handle = tp_for_task.leader_handle; let mut all_tokens: Vec = Vec::new(); let mut decoded_prefix = String::new(); let mut finish_reason = "length".to_string(); 'work: { - if let Err(e) = pool.clear_kv_cache(&model_id, leader_arc.clone()).await { + if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await { failure = Some(format!("clear_kv_cache: {e:#}")); break 'work; } @@ -1957,8 +1977,8 @@ impl CandleHarness { }; // Prefill — every rank embeds the prompt, offset = 0. - let logits = match pool - .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0) + let logits_vec = match pool + .generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0) .await { Ok(l) => l, @@ -1974,11 +1994,18 @@ impl CandleHarness { vram_free_mb = post_prefill_vram_free_mb, "TP chat_completion (stream): prefill complete" ); + let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) { + Ok(t) => t, + Err(e) => { + failure = Some(format!("prefill build cpu logits: {e:#}")); + break 'work; + } + }; let mut next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) { Ok(t) => t, Err(e) => { - let health = logits_health(&logits); + let health = logits_health_slice(&logits_vec); tracing::warn!( model = %model_id, ?health, @@ -2010,10 +2037,10 @@ impl CandleHarness { } for index in 0..max_new.saturating_sub(1) { - let logits = match pool + let logits_vec = match pool .generate_step( &model_id, - leader_arc.clone(), + leader_handle, vec![next_token], prompt_len + index, ) @@ -2025,6 +2052,14 @@ impl CandleHarness { break 'work; } }; + let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) { + Ok(t) => t, + Err(e) => { + failure = + Some(format!("decode build cpu logits {index}: {e:#}")); + break 'work; + } + }; next_token = match sample_with_penalty( &logits, &all_tokens, @@ -2032,7 +2067,7 @@ impl CandleHarness { ) { Ok(t) => t, Err(e) => { - let health = logits_health(&logits); + let health = logits_health_slice(&logits_vec); tracing::warn!( model = %model_id, step = index, @@ -2180,20 +2215,19 @@ async fn chat_completion_tp_inner( "TP chat_completion: starting" ); - // Acquire the pool lock for the duration of the request. The - // leader_model's own Mutex is acquired step-by-step inside - // pool.generate_step (so spawn_blocking can grab it without - // holding the pool lock across the blocking_lock call). - // `acquire_pool_lock` warns periodically while we wait so a - // stuck holder doesn't make the queueing requests look like - // silence in the journal. + // Acquire the pool lock for the duration of the request. After + // Phase 3 the leader's TpLeaderModel lives in the device worker + // thread, so the pool lock now serialises only subprocess RPC + // traffic — but holding it for the whole request still keeps + // concurrent chat_completions against the same TP model from + // interleaving prefill/decode jobs. let mut pool = acquire_pool_lock(&tp.pool, &model_id).await; - let leader_arc = tp.leader_model.clone(); + let leader_handle = tp.leader_handle; // Reset every rank's KV cache so this request doesn't attend // over the previous request's tokens. let clear_start = std::time::Instant::now(); - pool.clear_kv_cache(&model_id, leader_arc.clone()) + pool.clear_kv_cache(&model_id, leader_handle) .await .map_err(InferenceError::Other)?; tracing::debug!( @@ -2219,8 +2253,8 @@ async fn chat_completion_tp_inner( // Prefill: every rank embeds the whole prompt, offset = 0. let prefill_start = std::time::Instant::now(); - let logits = pool - .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0) + let logits_vec = pool + .generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0) .await .map_err(InferenceError::Other)?; let (post_prefill_vram_free_mb, _) = tp.query_vram().await; @@ -2231,6 +2265,11 @@ async fn chat_completion_tp_inner( vram_free_mb = post_prefill_vram_free_mb, "TP chat_completion: prefill complete" ); + // Wrap the CPU-side logits in a CPU candle Tensor for sampling. + // No device touch on the async caller's thread — sampling reads + // from CPU memory only. + let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu) + .map_err(|e| InferenceError::Other(anyhow::anyhow!("build cpu logits: {e}")))?; let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) { Ok(t) => t, Err(e) => { @@ -2239,7 +2278,7 @@ async fn chat_completion_tp_inner( // this WARN sits just above that and carries the actual // numerical state so an operator can tell at a glance // whether it was a NaN cascade, an Inf, or something else. - let health = logits_health(&logits); + let health = logits_health_slice(&logits_vec); tracing::warn!( model = %model_id, ?health, @@ -2256,19 +2295,22 @@ async fn chat_completion_tp_inner( let decode_start = std::time::Instant::now(); for index in 0..max_new.saturating_sub(1) { let step_start = std::time::Instant::now(); - let logits = pool + let logits_vec = pool .generate_step( &model_id, - leader_arc.clone(), + leader_handle, vec![next_token], prompt_len + index, ) .await .map_err(InferenceError::Other)?; + let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu).map_err(|e| { + InferenceError::Other(anyhow::anyhow!("build cpu logits step {index}: {e}")) + })?; next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) { Ok(t) => t, Err(e) => { - let health = logits_health(&logits); + let health = logits_health_slice(&logits_vec); tracing::warn!( model = %model_id, step = index, diff --git a/crates/neuron/src/harness/device_worker/dispatch.rs b/crates/neuron/src/harness/device_worker/dispatch.rs index a6e2e3a..89196bf 100644 --- a/crates/neuron/src/harness/device_worker/dispatch.rs +++ b/crates/neuron/src/harness/device_worker/dispatch.rs @@ -14,7 +14,12 @@ //! `tp_models: HashMap>`. use crate::harness::candle::ModelArch; +#[cfg(feature = "cuda")] +use crate::harness::device_worker::jobs::TpHandle; use crate::harness::device_worker::jobs::{ArchHandle, Job}; +#[cfg(feature = "cuda")] +use crate::harness::tp::TpLeaderModel; +use crate::harness::tp::nccl_state::NcclState; use std::collections::HashMap; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; @@ -40,6 +45,20 @@ struct DeviceWorkerState { /// increments and returns the new value. Wraps at u64::MAX after /// ~10^19 model loads — not a practical concern. next_handle: u64, + /// Leader's NCCL state. Populated by `Job::NcclInit`; the + /// underlying `Comm`'s libnccl handle lives bound to this thread + /// for its entire lifetime. Subprocess workers maintain their own + /// `NcclState` in their own processes — that's not visible from + /// here. + #[allow(dead_code)] // Read only via methods on NcclState + nccl: NcclState, + /// TP leader model slab. Same lifecycle as `models`; separate + /// namespace so `ArchHandle` and `TpHandle` can't collide. + #[cfg(feature = "cuda")] + tp_models: HashMap>, + /// Counter for minting fresh `TpHandle`s. + #[cfg(feature = "cuda")] + next_tp_handle: u64, #[cfg(feature = "cuda")] #[allow(dead_code)] /// `None` only if `CudaContext::new()` failed — in that case the @@ -128,15 +147,93 @@ pub(crate) fn run(device_index: u32, rx: Receiver, poisoned: Arc { + let resp = state.nccl.init(cfg, &comm_id_hex); + let _ = reply.send(resp); + } + Job::NcclSanity { reply } => { + let resp = state.nccl.sanity_check(); + let _ = reply.send(resp); + } + #[cfg(feature = "cuda")] + Job::CloneLeaderComm { reply } => { + let result = match state.nccl.comm() { + Some(comm) => Ok(crate::harness::tp::nccl_state::SendComm(comm)), + None => Err(anyhow::anyhow!( + "CloneLeaderComm: NcclState has no Comm; call NcclInit first" + )), + }; + let _ = reply.send(result); + } + #[cfg(feature = "cuda")] + Job::TransferInTp { model, reply } => { + let handle = TpHandle(state.next_tp_handle); + state.next_tp_handle = state.next_tp_handle.wrapping_add(1); + state.tp_models.insert(handle, model); + tracing::debug!( + device_index, + tp_handle = handle.0, + slab_size = state.tp_models.len(), + "device worker: TP model transferred in" + ); + let _ = reply.send(Ok(handle)); + } + #[cfg(feature = "cuda")] + Job::DropTp { handle, reply } => { + let removed = state.tp_models.remove(&handle); + let was_present = removed.is_some(); + drop(removed); + tracing::debug!( + device_index, + tp_handle = handle.0, + was_present, + slab_size = state.tp_models.len(), + "device worker: TP model dropped" + ); + let _ = reply.send(()); + } + #[cfg(feature = "cuda")] + Job::TpClearKv { handle, reply } => { + let result = match state.tp_models.get_mut(&handle) { + Some(model) => { + model.clear_kv_cache(); + Ok(()) + } + None => Err(anyhow::anyhow!( + "TpClearKv: no TP model for handle {}", + handle.0 + )), + }; + let _ = reply.send(result); + } + #[cfg(feature = "cuda")] + Job::TpForwardLogits { + handle, + tokens, + offset, + reply, + } => { + let result = tp_forward_logits(&mut state, handle, &tokens, offset); + let _ = reply.send(result); + } // Handled by the matches!() check above; reaching here // means a Shutdown slipped past which is a bug. Job::Shutdown => unreachable!("Shutdown should break above"), } } + #[cfg(feature = "cuda")] + let tp_slab_size = state.tp_models.len(); + #[cfg(not(feature = "cuda"))] + let tp_slab_size = 0_usize; tracing::info!( device_index, slab_size = state.models.len(), + tp_slab_size, "device worker exiting; dropping remaining models" ); // Drops every model in the slab on this thread before the function @@ -193,6 +290,9 @@ fn init_state(device_index: u32) -> DeviceWorkerState { device, models: HashMap::new(), next_handle: 1, + nccl: NcclState::new(), + tp_models: HashMap::new(), + next_tp_handle: 1, ctx, } } @@ -203,6 +303,7 @@ fn init_state(device_index: u32) -> DeviceWorkerState { device: candle_core::Device::Cpu, models: HashMap::new(), next_handle: 1, + nccl: NcclState::new(), } } } @@ -231,6 +332,38 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> { Ok((0, 0)) } +/// TP-equivalent of [`forward_logits`]: looks up the leader's +/// [`TpLeaderModel`] in the slab, runs its forward, copies the +/// `[vocab]` logits to a CPU `Vec`. The leader's `Arc` +/// clones embedded in the TP layers' AllReduce ops fire from this +/// thread — same thread that bound the CUDA context and that holds +/// the `Comm` in `state.nccl`. +#[cfg(feature = "cuda")] +fn tp_forward_logits( + state: &mut DeviceWorkerState, + handle: TpHandle, + tokens: &[u32], + offset: usize, +) -> anyhow::Result> { + use candle_core::{DType, Tensor}; + + let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?; + + let model = state + .tp_models + .get_mut(&handle) + .ok_or_else(|| anyhow::anyhow!("TpForwardLogits: no model for handle {}", handle.0))?; + + let logits = model.forward(&input, offset)?; + // ForCausalLM forward returns [B, 1, V] after the trailing + // .i((.., l - 1.., ..))?.apply(lm_head); squeeze both leading + // singleton dims to a rank-1 [V] tensor for sampling. + let logits = logits.squeeze(0)?.squeeze(0)?; + let logits = logits.to_dtype(DType::F32)?.flatten_all()?; + let values = logits.to_vec1::()?; + Ok(values) +} + /// Forward step + copy the `[vocab]` logits to a CPU `Vec` ready /// for sampling on the async caller. The model's `device()` (CUDA or /// CPU) determines where the kernel runs; this fn doesn't care. @@ -297,6 +430,38 @@ fn drain_poisoned(job: Job, device_index: u32) { Job::ForwardLogits { reply, .. } => { let _ = reply.send(Err(err())); } + Job::NcclInit { reply, .. } => { + let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error { + kind: "device_worker_poisoned".into(), + message: format!("device worker {device_index} poisoned"), + }); + } + Job::NcclSanity { reply } => { + let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error { + kind: "device_worker_poisoned".into(), + message: format!("device worker {device_index} poisoned"), + }); + } + #[cfg(feature = "cuda")] + Job::CloneLeaderComm { reply } => { + let _ = reply.send(Err(err())); + } + #[cfg(feature = "cuda")] + Job::TransferInTp { reply, .. } => { + let _ = reply.send(Err(err())); + } + #[cfg(feature = "cuda")] + Job::DropTp { reply, .. } => { + let _ = reply.send(()); + } + #[cfg(feature = "cuda")] + Job::TpClearKv { reply, .. } => { + let _ = reply.send(Err(err())); + } + #[cfg(feature = "cuda")] + Job::TpForwardLogits { reply, .. } => { + let _ = reply.send(Err(err())); + } Job::Shutdown => { // Filtered by the matches!() guard in run(); reaching // here would be a logic error. diff --git a/crates/neuron/src/harness/device_worker/jobs.rs b/crates/neuron/src/harness/device_worker/jobs.rs index 2cfbb70..240323e 100644 --- a/crates/neuron/src/harness/device_worker/jobs.rs +++ b/crates/neuron/src/harness/device_worker/jobs.rs @@ -19,6 +19,15 @@ use tokio::sync::oneshot; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct ArchHandle(pub u64); +/// Opaque handle to a `TpLeaderModel` stored in the worker thread's +/// state slab. Same shape as [`ArchHandle`] but in a separate +/// namespace so the two slabs can coexist without ambiguity. Phase 3 +/// introduces it; Phase 4 may unify the two slabs after the TP forward +/// path proves out. +#[cfg(feature = "cuda")] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TpHandle(pub u64); + /// One unit of work for the device worker. /// /// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the @@ -73,6 +82,74 @@ pub enum Job { offset: usize, reply: oneshot::Sender>>, }, + /// Initialize the leader's NCCL communicator. The worker's + /// `NcclState` mints the `Comm` here so its underlying + /// `ncclComm_t` and `CudaContext` live on the same thread as + /// every later `Comm::all_reduce` call. Reply is the worker + /// response shape used by the subprocess workers (`InitOk` on + /// success, `Error` on failure) so the calling + /// `WorkerPool::init_nccl` orchestration stays uniform. + /// + /// Available on both cuda and no-cuda builds — the dispatch + /// handler calls `NcclState::init` which has a no-cuda stub that + /// replies with `cuda_feature_not_enabled`. Keeping the Job + /// variant ungated lets `WorkerPool::init_nccl` stay uniform. + NcclInit { + cfg: crate::harness::tp::worker::WorkerConfig, + comm_id_hex: String, + reply: oneshot::Sender, + }, + /// Run NCCL's all_reduce sanity check on the leader's rank 0. + /// Same response shape as `NcclInit`; also available on both + /// builds via the no-cuda `NcclState::sanity_check` stub. + NcclSanity { + reply: oneshot::Sender, + }, + /// Clone the leader's `Arc` out of the worker's `NcclState` + /// so a spawn_blocking-based load (Phase 3 bridge) can hand it to + /// the row-parallel layers. Wrapped in `SendComm` because + /// `Arc` is `!Send` at the type level (the NCCL contract + /// requires serialised access, which we provide structurally). + /// Phase 4 eliminates this when `TpLoadShard` becomes a Job and + /// the load runs entirely on the worker thread. + #[cfg(feature = "cuda")] + CloneLeaderComm { + reply: oneshot::Sender>, + }, + /// Move a freshly-built `TpLeaderModel` into the worker's tp slab. + /// Returns a `TpHandle` the caller stores on `TpLoadedModel`. + #[cfg(feature = "cuda")] + TransferInTp { + model: Box, + reply: oneshot::Sender>, + }, + /// Drop the TP leader model on the worker thread. CUDA tensors + /// and `Arc` clones held inside the model release on the + /// thread that allocated them. + #[cfg(feature = "cuda")] + DropTp { + handle: TpHandle, + reply: oneshot::Sender<()>, + }, + /// Reset the leader's KV cache for a TP model. Mirrors `ClearKv` + /// for single-GPU. + #[cfg(feature = "cuda")] + TpClearKv { + handle: TpHandle, + reply: oneshot::Sender>, + }, + /// Run one TP forward step on the leader's shard. Returns CPU- + /// side logits as a `Vec` so the async caller can sample + /// without holding a device tensor. The caller is also + /// responsible for fan-out to subprocess ranks and drain — only + /// the leader's forward moves into the worker thread. + #[cfg(feature = "cuda")] + TpForwardLogits { + handle: TpHandle, + tokens: Vec, + offset: usize, + reply: oneshot::Sender>>, + }, /// Tell the worker to break its dispatch loop and exit. Any jobs /// queued after this in the channel reply `Err` to their oneshot /// senders (the senders are dropped on the worker's exit, which diff --git a/crates/neuron/src/harness/device_worker/mod.rs b/crates/neuron/src/harness/device_worker/mod.rs index 497eeb3..5717457 100644 --- a/crates/neuron/src/harness/device_worker/mod.rs +++ b/crates/neuron/src/harness/device_worker/mod.rs @@ -49,6 +49,8 @@ use std::sync::mpsc::{self, Sender}; use std::thread::JoinHandle; use tokio::sync::oneshot; +#[cfg(feature = "cuda")] +pub use jobs::TpHandle; pub use jobs::{ArchHandle, Job}; /// Errors returned by `DeviceWorkerHandle` submit methods. @@ -277,6 +279,192 @@ impl DeviceWorkerHandle { } } + /// Initialise the leader's NCCL communicator. The reply uses + /// `WorkerResponse` (same shape subprocess workers use over stdio + /// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader + + /// subprocess responses uniformly. Available on no-cuda builds + /// too — the dispatch handler calls the no-cuda `NcclState::init` + /// stub which replies `cuda_feature_not_enabled`. + pub async fn nccl_init( + &self, + cfg: crate::harness::tp::worker::WorkerConfig, + comm_id_hex: String, + ) -> Result { + if self.poisoned.load(Ordering::Acquire) { + return Err(WorkerError::Poisoned { + device_index: self.device_index, + }); + } + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::NcclInit { + cfg, + comm_id_hex, + reply: reply_tx, + }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + reply_rx.await.map_err(|_| WorkerError::Gone { + device_index: self.device_index, + }) + } + + /// Run an NCCL sanity all_reduce on the leader's rank 0. + /// Available on no-cuda builds; replies with an error response. + pub async fn nccl_sanity( + &self, + ) -> Result { + if self.poisoned.load(Ordering::Acquire) { + return Err(WorkerError::Poisoned { + device_index: self.device_index, + }); + } + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::NcclSanity { reply: reply_tx }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + reply_rx.await.map_err(|_| WorkerError::Gone { + device_index: self.device_index, + }) + } + + /// Clone the leader's `Arc` so a spawn_blocking-based load + /// (Phase 3 bridge) can pass it to the row-parallel layers. + /// Phase 4 eliminates this once the TP load runs on this thread. + #[cfg(feature = "cuda")] + pub async fn clone_leader_comm( + &self, + ) -> Result { + if self.poisoned.load(Ordering::Acquire) { + return Err(WorkerError::Poisoned { + device_index: self.device_index, + }); + } + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::CloneLeaderComm { reply: reply_tx }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + match reply_rx.await { + Ok(result) => result.map_err(WorkerError::from), + Err(_) => Err(WorkerError::Gone { + device_index: self.device_index, + }), + } + } + + /// Move a freshly-built `TpLeaderModel` into the worker's TP slab. + #[cfg(feature = "cuda")] + pub async fn transfer_in_tp( + &self, + model: Box, + ) -> Result { + if self.poisoned.load(Ordering::Acquire) { + return Err(WorkerError::Poisoned { + device_index: self.device_index, + }); + } + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::TransferInTp { + model, + reply: reply_tx, + }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + match reply_rx.await { + Ok(result) => result.map_err(WorkerError::from), + Err(_) => Err(WorkerError::Gone { + device_index: self.device_index, + }), + } + } + + /// Drop the TP model at `handle` on the worker thread. + #[cfg(feature = "cuda")] + pub async fn drop_tp(&self, handle: TpHandle) -> Result<(), WorkerError> { + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::DropTp { + handle, + reply: reply_tx, + }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + match reply_rx.await { + Ok(()) => Ok(()), + Err(_) => Err(WorkerError::Gone { + device_index: self.device_index, + }), + } + } + + /// Reset the leader's KV cache for a TP model. + #[cfg(feature = "cuda")] + pub async fn tp_clear_kv(&self, handle: TpHandle) -> Result<(), WorkerError> { + if self.poisoned.load(Ordering::Acquire) { + return Err(WorkerError::Poisoned { + device_index: self.device_index, + }); + } + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::TpClearKv { + handle, + reply: reply_tx, + }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + match reply_rx.await { + Ok(result) => result.map_err(WorkerError::from), + Err(_) => Err(WorkerError::Gone { + device_index: self.device_index, + }), + } + } + + /// Run one TP forward step on the leader's shard. Returns CPU-side + /// logits as `Vec` ready for sampling. The caller is + /// responsible for fan-out / drain of the subprocess workers + /// concurrently with this call. + #[cfg(feature = "cuda")] + pub async fn tp_forward_logits( + &self, + handle: TpHandle, + tokens: Vec, + offset: usize, + ) -> Result, WorkerError> { + if self.poisoned.load(Ordering::Acquire) { + return Err(WorkerError::Poisoned { + device_index: self.device_index, + }); + } + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::TpForwardLogits { + handle, + tokens, + offset, + reply: reply_tx, + }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + match reply_rx.await { + Ok(result) => result.map_err(WorkerError::from), + Err(_) => Err(WorkerError::Gone { + device_index: self.device_index, + }), + } + } + /// Send `Job::Shutdown` and join the thread. Idempotent — calling /// twice is a no-op the second time. pub fn shutdown(&self) -> anyhow::Result<()> { diff --git a/crates/neuron/src/harness/tp/mod.rs b/crates/neuron/src/harness/tp/mod.rs index 1f65a23..6c67e4c 100644 --- a/crates/neuron/src/harness/tp/mod.rs +++ b/crates/neuron/src/harness/tp/mod.rs @@ -212,10 +212,15 @@ pub struct WorkerPool { /// Path to the neuron binary used to launch workers. #[allow(dead_code)] exe: PathBuf, - /// Leader's own NCCL rank-0 state. Defaults to empty; populated by - /// `init_nccl()`. Held here so the leader can participate in - /// collectives (rank 0) without spawning a fourth subprocess. - leader_nccl: nccl_state::NcclState, + /// The leader's per-device CUDA worker thread. Phase 3 moved the + /// leader's `NcclState` (rank-0 NCCL Comm) into this thread, so + /// every NCCL op (init, sanity, all_reduce inside forward) issues + /// from one OS thread for the daemon's lifetime. The handle is + /// also used by `load_dense_shard` to clone the leader's + /// `Arc` for the row-parallel layers' AllReduce ops; in + /// Phase 4 the load itself moves onto the worker and that bridge + /// goes away. + pub(crate) leader_worker: std::sync::Arc, } impl WorkerPool { @@ -228,7 +233,12 @@ impl WorkerPool { /// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`). /// `cuda_devices` is one entry per rank including rank 0. Worker /// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`. - pub async fn spawn(binary: &Path, world_size: u32, cuda_devices: &[u32]) -> Result { + pub async fn spawn( + binary: &Path, + world_size: u32, + cuda_devices: &[u32], + leader_worker: std::sync::Arc, + ) -> Result { if world_size < 2 { anyhow::bail!( "WorkerPool::spawn called with world_size={world_size}; \ @@ -289,7 +299,7 @@ impl WorkerPool { world_size, workers, exe, - leader_nccl: nccl_state::NcclState::new(), + leader_worker, }) } @@ -321,27 +331,26 @@ impl WorkerPool { } // 2. Leader rank 0 calls Comm::from_rank on its own device. - // Runs on spawn_blocking because NCCL's init blocks until - // every rank has called in — that's exactly the workers - // above. The leader's NcclState is moved through the - // blocking task and returned to the pool. + // Phase 3 moved this from spawn_blocking onto the leader's + // device worker thread (`Job::NcclInit`); the underlying + // `Comm` now lives on the same OS thread for its entire + // lifetime, including every later `Comm::all_reduce` issued + // by the row-parallel layers during forward. + // + // NCCL's init blocks until every rank has called in — the + // subprocess workers above and the leader's device worker + // here. The Job's reply unblocks when the leader's + // Comm::from_rank returns. let leader_cfg = worker::WorkerConfig { rank: 0, world_size: self.world_size, cuda_device: leader_cuda_device, }; - let comm_id_for_leader = comm_id.clone(); - // Swap out the leader's NcclState into a fresh empty one so we - // can move it into spawn_blocking; restore after the task - // returns. (NcclState isn't Clone — it owns a real NCCL Comm.) - let mut leader_state = std::mem::take(&mut self.leader_nccl); - let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || { - let resp = leader_state.init(leader_cfg, &comm_id_for_leader); - (leader_state, resp) - }) - .await - .context("leader NCCL init task panicked")?; - self.leader_nccl = returned_state; + let leader_resp = self + .leader_worker + .nccl_init(leader_cfg, comm_id.clone()) + .await + .map_err(|e| anyhow::anyhow!("leader NCCL init via device worker: {e}"))?; match leader_resp { rpc::WorkerResponse::InitOk => {} rpc::WorkerResponse::Error { kind, message } => { @@ -387,16 +396,16 @@ impl WorkerPool { w.send_only(&WorkerRequest::NcclSanityCheck).await?; } - // 2. Leader's own all_reduce, in spawn_blocking. NCCL operations - // block until every rank participates. - let mut leader_state = std::mem::take(&mut self.leader_nccl); - let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || { - let resp = leader_state.sanity_check(); - (leader_state, resp) - }) - .await - .context("leader NCCL sanity task panicked")?; - self.leader_nccl = returned_state; + // 2. Leader's own all_reduce, on its device worker thread. + // NCCL operations block until every rank participates; + // Job::NcclSanity returns once the leader's side completes + // (which happens when every subprocess worker reaches its + // all_reduce call too). + let leader_resp = self + .leader_worker + .nccl_sanity() + .await + .map_err(|e| anyhow::anyhow!("leader NCCL sanity via device worker: {e}"))?; let expected = self.world_size; let leader_sum = match leader_resp { @@ -483,21 +492,24 @@ impl WorkerPool { leader_device: &candle_core::Device, dtype: candle_core::DType, quant: Option, - ) -> Result>> { + ) -> Result { use candle_nn::var_builder::ShardedSafeTensors; - use std::sync::Arc; - use tokio::sync::Mutex; - // Wrap the comm in SendComm immediately so it stays Send across - // the await points in this method — bare Arc would - // poison the async fn's Send bound (Comm's raw NCCL pointer is - // !Send). The wrapper's safety contract is satisfied by the - // pool's outer Mutex serialising callers + the spawn_blocking - // thread being the only place ops are issued. - let leader_comm = - nccl_state::SendComm(self.leader_nccl.comm().ok_or_else(|| { - anyhow::anyhow!("leader NCCL not initialised; call init_nccl first") - })?); + // Ask the leader's device worker for an `Arc` clone. + // Phase 3 moved `NcclState` ownership onto the worker thread, + // so the spawn_blocking load below can no longer reach the + // Comm directly. The reply is wrapped in `SendComm` because + // the underlying `Arc` is `!Send` at the type level; + // the safety contract (only one thread issues NCCL ops at a + // time) is preserved because the load runs on a single + // spawn_blocking thread and AllReduce ops fire only from the + // device worker thread later. Phase 4 eliminates this bridge + // when the load itself moves onto the worker. + let leader_comm = self + .leader_worker + .clone_leader_comm() + .await + .map_err(|e| anyhow::anyhow!("clone leader Comm via device worker: {e}"))?; let world_size = self.world_size; let safetensors_str: Vec = safetensors_paths .iter() @@ -601,15 +613,32 @@ impl WorkerPool { } } - Ok(Arc::new(Mutex::new(leader_model))) + // Phase 3: move the leader's freshly-built `TpLeaderModel` + // into the device worker's TP slab. The model holds + // `Arc` clones (in its AllReduce ops) plus CUDA + // tensors — both need to live on the device worker thread so + // every `Comm::all_reduce` and tensor op during forward + // dispatches from the same OS thread that bound the CUDA + // context. + let handle = self + .leader_worker + .transfer_in_tp(Box::new(leader_model)) + .await + .map_err(|e| anyhow::anyhow!("transfer TP leader model into device worker: {e}"))?; + Ok(handle) } /// Run one forward step across every rank. The leader's forward - /// returns the last-position logits as a candle Tensor on the - /// leader's device; the caller does sampling out-of-band. Workers - /// run their own forwards (the AllReduce inside row-parallel layers - /// is what lets the leader's collective complete) and reply with - /// `GenerateStepOk` — they do not ship logits over the wire. + /// runs on the device worker thread via `Job::TpForwardLogits` and + /// returns CPU-side `[vocab]` logits as `Vec`; the async + /// caller wraps them in a CPU tensor for `apply_repeat_penalty` + + /// sampling without holding a device-resident tensor on a tokio + /// thread. + /// + /// Subprocess workers run their own forwards in parallel (the + /// AllReduce CustomOps inside row-parallel layers are what let + /// the leader's collective complete) and reply with + /// `GenerateStepOk` over the RPC stream — they do not ship logits. /// /// `tokens` is the input for this step (prompt for prefill, the /// previously-sampled token for decode). `offset` is the KV-cache @@ -618,10 +647,10 @@ impl WorkerPool { pub async fn generate_step( &mut self, model_id: &str, - leader_model: std::sync::Arc>, + leader_handle: super::device_worker::TpHandle, tokens: Vec, offset: usize, - ) -> Result { + ) -> Result> { let step_start = std::time::Instant::now(); let tokens_len = tokens.len(); tracing::debug!( @@ -630,7 +659,7 @@ impl WorkerPool { offset, "WorkerPool::generate_step: fan-out" ); - // 1. Fan-out to workers. + // 1. Fan-out to subprocess workers. for w in &mut self.workers { w.send_only(&WorkerRequest::GenerateStep { model_id: model_id.to_string(), @@ -640,35 +669,30 @@ impl WorkerPool { .await?; } - // 2. Leader's forward in spawn_blocking. The AllReduce CustomOps - // inside the row-parallel layers block until every worker's - // forward issues the matching collective. + // 2. Leader's forward on its device worker thread. The + // AllReduce CustomOps inside the row-parallel layers block + // until every subprocess worker's forward issues the + // matching collective. Returning CPU-side `Vec` keeps + // the device tensor from escaping the worker thread — + // that's the invariant the whole refactor exists to + // preserve. let leader_start = std::time::Instant::now(); - let leader_result = tokio::task::spawn_blocking(move || -> Result { - let mut model = leader_model.blocking_lock(); - let device = model.device().clone(); - let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; - // ForCausalLM::forward returns [B, 1, V] — squeeze both - // leading dims to the rank-1 vocab logits the sampler wants. - let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?; - Ok(logits) - }) - .await - .context("leader forward task panicked"); - let leader_ok = matches!(leader_result, Ok(Ok(_))); + let leader_result = self + .leader_worker + .tp_forward_logits(leader_handle, tokens, offset) + .await; + let leader_ok = leader_result.is_ok(); let leader_ms = leader_start.elapsed().as_millis(); - // Surface the leader's own error at WARN. Previously this was - // silently coerced to `leader_ok=false` while only worker - // ranks' errors got logged — when both the leader and a worker - // fail together (the typical "CUDA context is now poisoned" - // pattern after an OOM), the operator could see only the - // worker side and had to guess what hit rank 0. + // Surface the leader's own error at WARN before draining + // workers so the operator can correlate it with whatever the + // subprocess workers logged. Previously this was silently + // coerced to a bool. if !leader_ok { - let detail = match &leader_result { - Ok(Err(e)) => format!("{e:#}"), - Err(e) => format!("task: {e:#}"), - Ok(Ok(_)) => unreachable!("leader_ok=false implies an error path"), - }; + let detail = leader_result + .as_ref() + .err() + .map(|e| format!("{e:#}")) + .unwrap_or_default(); tracing::warn!( model = %model_id, tokens = tokens_len, @@ -707,7 +731,33 @@ impl WorkerPool { "WorkerPool::generate_step: workers drained" ); - combine_leader_workers(leader_result, worker_errors, "GenerateStep") + // Combine the leader's Result + the workers' string-error + // list. Phase 3 inlines this because the upstream + // `combine_leader_workers` expects the spawn_blocking-shaped + // `Result>`; the new device-worker path produces a + // single `Result` instead. + match leader_result { + Ok(values) => { + if worker_errors.is_empty() { + Ok(values) + } else { + anyhow::bail!( + "GenerateStep: leader succeeded but workers failed: {}", + worker_errors.join("; ") + ) + } + } + Err(e) => { + if worker_errors.is_empty() { + Err(anyhow::Error::new(e).context("GenerateStep: leader forward failed")) + } else { + Err(anyhow::Error::new(e).context(format!( + "GenerateStep: leader forward failed and workers also failed: {}", + worker_errors.join("; ") + ))) + } + } + } } /// Reset the KV cache for `model_id` on every rank. Called at the @@ -716,7 +766,7 @@ impl WorkerPool { pub async fn clear_kv_cache( &mut self, model_id: &str, - #[cfg(feature = "cuda")] leader_model: std::sync::Arc>, + #[cfg(feature = "cuda")] leader_handle: super::device_worker::TpHandle, ) -> Result<()> { let start = std::time::Instant::now(); tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out"); @@ -728,13 +778,18 @@ impl WorkerPool { } #[cfg(feature = "cuda")] { - let mut m = leader_model.lock().await; - m.clear_kv_cache(); + // Leader-side clear on the device worker thread — + // `TpLeaderModel::clear_kv_cache` is infallible but still + // routes through Job::TpClearKv so the cache reset runs + // on the same thread that owns the model's CUDA tensors. + if let Err(e) = self.leader_worker.tp_clear_kv(leader_handle).await { + anyhow::bail!("leader TP clear_kv_cache via device worker: {e}"); + } } // Drain workers — same rationale as `generate_step`. The - // leader's clear_kv_cache is in-process and infallible, but we - // still always drain so an error on one worker doesn't leave - // pending responses for the others. + // leader's clear_kv_cache is now async-via-channel but still + // returns before the drain so the workers' KvCacheCleared + // replies are processed in order. let worker_errors = drain_workers(&mut self.workers, |r| match r { WorkerResponse::KvCacheCleared => Ok(()), WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")), diff --git a/crates/neuron/src/main.rs b/crates/neuron/src/main.rs index 25f9e09..8d7729d 100644 --- a/crates/neuron/src/main.rs +++ b/crates/neuron/src/main.rs @@ -118,7 +118,13 @@ async fn tp_smoke(tp_size: u32, cuda_devices: Vec) -> Result<()> { binary = %exe.display(), "tp-smoke: spawning worker pool" ); - let mut pool = tp::WorkerPool::spawn(&exe, tp_size, &cuda_devices).await?; + // tp_smoke is a diagnostic tool; spawn the leader's device worker + // directly. (In the daemon path, CandleHarness::ensure_device_worker + // caches one per device.) + let leader_worker = neuron::harness::device_worker::DeviceWorkerHandle::spawn(leader_device) + .context("spawn leader device worker for tp-smoke")?; + let mut pool = + tp::WorkerPool::spawn(&exe, tp_size, &cuda_devices, leader_worker.clone()).await?; tracing::info!("tp-smoke: pinging every worker"); let pongs = pool.ping_all().await?; diff --git a/crates/neuron/tests/tp_worker_lifecycle.rs b/crates/neuron/tests/tp_worker_lifecycle.rs index f5e8cbd..7c41c05 100644 --- a/crates/neuron/tests/tp_worker_lifecycle.rs +++ b/crates/neuron/tests/tp_worker_lifecycle.rs @@ -5,6 +5,7 @@ //! `Init` and `NcclSanityCheck` are stubbed in 7a-i, so this test //! runs on any host the workspace builds on. +use neuron::harness::device_worker::DeviceWorkerHandle; use neuron::harness::tp::{WorkerPool, rpc::WorkerResponse}; /// Path to the neuron binary built by cargo for this test process. @@ -19,7 +20,8 @@ const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron"); async fn test_spawn_ping_shutdown() { // cuda_devices: rank 0 → device 0 (leader, unused here), // rank 1 → device 1 (worker; not actually opened in 7a-i). - let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1]) + let leader_worker = DeviceWorkerHandle::spawn(0).expect("spawn device worker"); + let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1], leader_worker) .await .expect("spawn worker pool"); @@ -44,7 +46,8 @@ async fn test_spawn_ping_shutdown() { /// Three workers — exercise the loop in `ping_all` / `shutdown`. #[tokio::test] async fn test_spawn_three_workers() { - let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 3, &[0, 1, 2]) + let leader_worker = DeviceWorkerHandle::spawn(0).expect("spawn device worker"); + let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 3, &[0, 1, 2], leader_worker) .await .expect("spawn worker pool"); diff --git a/crates/neuron/tests/tp_worker_lifecycle_cuda.rs b/crates/neuron/tests/tp_worker_lifecycle_cuda.rs index 262f5c8..9219393 100644 --- a/crates/neuron/tests/tp_worker_lifecycle_cuda.rs +++ b/crates/neuron/tests/tp_worker_lifecycle_cuda.rs @@ -25,7 +25,9 @@ async fn test_init_and_sanity_check_two_ranks() { .try_init(); // 2 ranks: leader = rank 0 on device 0, worker = rank 1 on device 1. - let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1]) + let leader_worker = neuron::harness::device_worker::DeviceWorkerHandle::spawn(0) + .expect("spawn leader device worker"); + let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1], leader_worker) .await .expect("spawn worker pool");