refactor(neuron): phase 3 — TP forward + NCCL state move onto device worker

Third slice of the per-device CUDA context-ownership refactor planned at
~/.claude/plans/plan-the-per-device-worker-abstract-micali.md. The
leader's `NcclState`, every `Comm::all_reduce` issued by the TP layers,
the leader-side KV cache reset, and the TP forward step itself now all
run on the per-device worker thread — the same OS thread that bound
the leader's `CudaContext` at startup.
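
A minimal sketch of the ownership model described above, using only the standard library. The `Job` variants, `spawn_worker`, and `run_jobs` are hypothetical names for illustration, and `std::sync::mpsc` stands in for the `oneshot` reply channels so the example has no dependencies. It only demonstrates the property the refactor relies on: every job submitted through the channel executes on one fixed OS thread.

```rust
use std::sync::mpsc;
use std::thread::{self, ThreadId};

/// Hypothetical job type: each job carries its own reply channel.
enum Job {
    WhoAmI { reply: mpsc::Sender<ThreadId> },
    Shutdown,
}

/// Spawn one worker thread and return the sending half of its job queue.
fn spawn_worker() -> (mpsc::Sender<Job>, thread::JoinHandle<()>) {
    let (tx, rx) = mpsc::channel::<Job>();
    let join = thread::spawn(move || {
        // Thread-affine state (a CUDA context in the real code) would be
        // created here and would never leave this closure.
        for job in rx {
            match job {
                Job::WhoAmI { reply } => {
                    let _ = reply.send(thread::current().id());
                }
                Job::Shutdown => break,
            }
        }
    });
    (tx, join)
}

/// Submit `n` jobs and collect the id of the thread each one ran on.
fn run_jobs(n: usize) -> Vec<ThreadId> {
    let (tx, join) = spawn_worker();
    let mut ids = Vec::with_capacity(n);
    for _ in 0..n {
        let (reply_tx, reply_rx) = mpsc::channel();
        tx.send(Job::WhoAmI { reply: reply_tx }).expect("worker alive");
        ids.push(reply_rx.recv().expect("worker replied"));
    }
    tx.send(Job::Shutdown).expect("worker alive");
    join.join().expect("worker exited cleanly");
    ids
}
```

Calling `run_jobs(3)` returns three identical thread ids, none of which is the caller's own.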

What this phase changes:

- `Job` gains `NcclInit`, `NcclSanity`, `CloneLeaderComm` (a Phase 3
  bridge that Phase 4 removes), `TransferInTp`, `DropTp`, `TpClearKv`,
  and `TpForwardLogits`, plus a new opaque key type, `TpHandle(u64)`.
- `DeviceWorkerState` gains `nccl: NcclState`,
  `tp_models: HashMap<TpHandle, Box<TpLeaderModel>>`, and a
  `next_tp_handle` counter for minting handles.
- `WorkerPool` loses its `leader_nccl` field and gains a
  `leader_worker: Arc<DeviceWorkerHandle>` passed at construction.
  `init_nccl`, `nccl_sanity_check`, `load_dense_shard`,
  `generate_step`, and `clear_kv_cache` now route their leader-side ops
  through `Job::Nccl*` / `Job::Tp*` instead of `spawn_blocking` against
  a Mutex-wrapped state. `generate_step` returns `Vec<f32>` instead
  of a device-resident `Tensor`: the worker copies the logits to the CPU
  before replying, so the async caller can sample from a CPU candle
  tensor without touching the device context.
- `TpLoadedModel.leader_model: Arc<Mutex<TpLeaderModel>>` becomes an
  opaque `leader_handle: TpHandle`. The boxed `TpLeaderModel` lives in
  the worker thread's slab, so both the model's CUDA tensors and the
  embedded `Arc<Comm>` clones are released on the same thread that
  allocated them (the drop-thread constraint that cudarc imposes).
- `Job::CloneLeaderComm` is a Phase 3 bridge: the TP shard load still
  runs in spawn_blocking and needs the leader's `Arc<Comm>` to build
  the row-parallel layers' AllReduce ops. The Job clones the Comm
  out of the worker's NcclState and ships it back as `SendComm`.
  Phase 4 deletes this bridge when the load itself moves onto the
  worker.
- `Job::NcclInit` and `Job::NcclSanity` are ungated by `cuda` so the
  no-cuda `NcclState` stubs (which reply with `cuda_feature_not_enabled`)
  still flow through the same channel uniformly; the cuda-only
  TP variants (CloneLeaderComm, Transfer/Drop/Clear/Forward Tp)
  remain gated.
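
The handle-and-slab arrangement in the list above can be sketched as follows. This is an illustration under stated assumptions, not the code from this commit: `ToyModel`, `Handle`, and the three helper functions are hypothetical, and `std::sync::mpsc` replaces the `oneshot` reply channels. The caller only ever holds a `Copy` handle, the model is dropped on the worker thread, and the forward step hands back plain host data.

```rust
use std::collections::HashMap;
use std::sync::mpsc;
use std::thread;

/// Opaque key into the worker's slab (same shape as `TpHandle(u64)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct Handle(u64);

/// Hypothetical stand-in for a model: "forward" scales each token id.
struct ToyModel {
    scale: f32,
}

type Reply<T> = mpsc::Sender<T>;

enum Job {
    TransferIn { model: Box<ToyModel>, reply: Reply<Handle> },
    Forward { handle: Handle, tokens: Vec<u32>, reply: Reply<Result<Vec<f32>, String>> },
    Drop { handle: Handle, reply: Reply<bool> },
}

/// Spawn the worker. The slab is a local of the worker closure, so models
/// are only ever touched (and dropped) on that thread.
fn spawn_worker() -> mpsc::Sender<Job> {
    let (tx, rx) = mpsc::channel::<Job>();
    thread::spawn(move || {
        let mut slab: HashMap<Handle, Box<ToyModel>> = HashMap::new();
        let mut next_handle: u64 = 1;
        for job in rx {
            match job {
                Job::TransferIn { model, reply } => {
                    let handle = Handle(next_handle);
                    next_handle = next_handle.wrapping_add(1);
                    slab.insert(handle, model);
                    let _ = reply.send(handle);
                }
                Job::Forward { handle, tokens, reply } => {
                    // Host-side values go back to the caller, not a device tensor.
                    let result = match slab.get(&handle) {
                        Some(m) => Ok(tokens
                            .iter()
                            .map(|&t| t as f32 * m.scale)
                            .collect::<Vec<f32>>()),
                        None => Err(format!("no model for handle {}", handle.0)),
                    };
                    let _ = reply.send(result);
                }
                Job::Drop { handle, reply } => {
                    // `remove` drops the boxed model here, on the worker thread.
                    let _ = reply.send(slab.remove(&handle).is_some());
                }
            }
        }
    });
    tx
}

fn transfer_in(tx: &mpsc::Sender<Job>, model: ToyModel) -> Handle {
    let (reply, rx) = mpsc::channel();
    tx.send(Job::TransferIn { model: Box::new(model), reply }).expect("worker alive");
    rx.recv().expect("worker replied")
}

fn forward(tx: &mpsc::Sender<Job>, handle: Handle, tokens: &[u32]) -> Result<Vec<f32>, String> {
    let (reply, rx) = mpsc::channel();
    tx.send(Job::Forward { handle, tokens: tokens.to_vec(), reply }).expect("worker alive");
    rx.recv().expect("worker replied")
}

fn drop_model(tx: &mpsc::Sender<Job>, handle: Handle) -> bool {
    let (reply, rx) = mpsc::channel();
    tx.send(Job::Drop { handle, reply }).expect("worker alive");
    rx.recv().expect("worker replied")
}
```

A stale handle is harmless in this sketch: `forward` on a dropped handle returns an `Err`, and a second `drop_model` reports `false`.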

What this phase doesn't touch (yet):

- TP shard load itself — still spawn_blocking, bridged via
  `CloneLeaderComm`. Phase 4 moves it to `Job::TpLoadShard` and
  reads `state.nccl.comm()` directly inside the worker.
- Single-GPU model loads — still spawn_blocking, transferred via
  `Job::TransferIn`. Phase 4 moves them.
- `device_vram_mb` / `cuda_mem_mb` / `log_construction_complete`
  helpers — still present, used inside spawn_blocking load closures.
  Phase 4 cleanup folds them into `dispatch.rs`.

`tp/mod.rs::WorkerPool::spawn` gained a required
`leader_worker: Arc<DeviceWorkerHandle>` argument. The external
callers were updated to match: `CandleHarness::load_tp` (passes the
cached device worker), `main.rs::tp_smoke` (spawns a fresh worker),
and the two `tp_worker_lifecycle*.rs` integration tests.

Public API unchanged. fmt + clippy clean; 37 lib tests + all
integration tests pass. The CUDA-only TP integration smoke test is
deferred to the next deploy on beast.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-27 10:16:02 +03:00
parent b179204fd3
commit 76ab24d98c
8 changed files with 676 additions and 138 deletions


@@ -163,16 +163,22 @@ pub struct TpLoadedModel {
     pub model_id: String,
     pub tokenizer: Tokenizer,
     pub devices: Vec<u32>,
-    /// One end-to-end gate: the pool's RPC stream isn't safe to use
-    /// concurrently and the leader shard's KV cache mutates with every
-    /// step. The same Mutex covers both for the simplest correctness
-    /// story.
+    /// One end-to-end gate: the pool's RPC stream to the subprocess
+    /// workers isn't safe to use concurrently. After Phase 3 the
+    /// leader's `TpLeaderModel` lives in the worker thread's slab,
+    /// so this Mutex no longer covers the leader's KV cache; it just
+    /// serialises subprocess RPC traffic on the pool's
+    /// `Vec<Worker>` channels.
     pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
-    pub leader_model: Arc<tokio::sync::Mutex<super::tp::TpLeaderModel>>,
-    /// Candle device for rank 0. Mirrors what `leader_model.device()`
-    /// would return, but stored separately so the request path can
-    /// query VRAM without locking the leader (which would contend with
-    /// the in-flight forward).
+    /// Handle into the leader device worker's TP slab. The boxed
+    /// `TpLeaderModel` (with its embedded `Arc<Comm>` clones and
+    /// per-rank CUDA tensors) lives on the worker thread; we hold an
+    /// opaque index. Forward / clear_kv / unload all route through
+    /// `Job::Tp*` against this handle.
+    pub leader_handle: super::device_worker::TpHandle,
+    /// Candle device for rank 0. Mirrors what
+    /// `TpLeaderModel::device()` would return, kept on the struct so
+    /// the request path can name the device without an RPC.
     pub leader_device: Device,
     /// Same poisoning gate as [`LoadedModel::poisoned`]. A TP forward
     /// failure (CUDA OOM on any rank, NCCL desync, illegal address) is
@@ -180,9 +186,8 @@ pub struct TpLoadedModel {
     /// reliably reset without restarting the worker subprocesses.
     pub poisoned: AtomicBool,
     /// Worker thread for the leader's CUDA device. Owns the leader's
-    /// `CudaContext` for the daemon's lifetime. VRAM queries route
-    /// through it; in later refactor phases the forward, kv-cache
-    /// clear, and shard unload route through it too.
+    /// `CudaContext`, `NcclState`, and the boxed `TpLeaderModel`
+    /// referenced by `leader_handle`.
     pub worker: Arc<super::device_worker::DeviceWorkerHandle>,
 }
@@ -1642,6 +1647,16 @@ impl Harness for CandleHarness {
                     anyhow::bail!("cannot unload '{model_id}': inference still in flight");
                 }
             };
+            // Drop the leader's TpLeaderModel on the device worker
+            // thread (CUDA tensors and Arc<Comm> clones release on
+            // the same OS thread that allocated them).
+            if let Err(e) = tp.worker.drop_tp(tp.leader_handle).await {
+                tracing::warn!(
+                    model = %model_id,
+                    error = %e,
+                    "TP unload: DropTp RPC failed (leader model may leak in worker slab)"
+                );
+            }
             let mut pool = tp.pool.into_inner();
             if let Err(e) = pool.unload_model(model_id).await {
                 tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed");
@@ -1715,9 +1730,14 @@ impl CandleHarness {
         // 2. Spawn the worker pool. Rank 0 stays in-process; ranks
         // 1..tp_size are subprocesses, one per device after the
-        // leader's own.
+        // leader's own. The leader's device worker thread is
+        // spawned (or reused) here and passed into the pool so
+        // `init_nccl`, the load, every TP forward, and KV-cache
+        // clears all dispatch from the same OS thread.
         let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
-        let mut pool = super::tp::WorkerPool::spawn(&exe, tp_size, &devices).await?;
+        let leader_worker = self.ensure_device_worker(devices[0]).await?;
+        let mut pool =
+            super::tp::WorkerPool::spawn(&exe, tp_size, &devices, leader_worker.clone()).await?;
         // 3. NCCL handshake across all ranks.
         let leader_device_idx = devices[0];
@@ -1727,8 +1747,11 @@ impl CandleHarness {
         let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize)
             .context("Device::new_cuda for TP leader")?;
-        // 5. Load this rank's shard on every rank.
-        let leader_model = pool
+        // 5. Load this rank's shard on every rank. After Phase 3
+        // `load_dense_shard` transfers the freshly-built
+        // `TpLeaderModel` into the device worker's TP slab and
+        // returns the resulting handle.
+        let leader_handle = pool
             .load_dense_shard(
                 &spec.model_id,
                 &config_json,
@@ -1743,21 +1766,18 @@ impl CandleHarness {
         let tokenizer = Tokenizer::from_file(&tokenizer_path)
             .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
-        // 7. Worker thread for the leader's CUDA device. TP always
-        // runs on CUDA — the harness rejects TP without the cuda
-        // feature earlier in this function — so we always have a
-        // device to own.
-        let worker = self.ensure_device_worker(devices[0]).await?;
         let tp_loaded = StdArc::new(TpLoadedModel {
             model_id: spec.model_id.clone(),
             tokenizer,
             devices: devices.clone(),
             pool: TMutex::new(pool),
-            leader_model,
+            leader_handle,
             leader_device: leader_device.clone(),
             poisoned: AtomicBool::new(false),
-            worker,
+            // Same `leader_worker` we passed into the pool above —
+            // single `Arc` shared between WorkerPool and
+            // TpLoadedModel so they reference the same thread.
+            worker: leader_worker,
         });
         let mut models = self.models.write().await;
@@ -1932,14 +1952,14 @@ impl CandleHarness {
                 async move {
                     let mut failure: Option<String> = None;
                     let mut pool = acquire_pool_lock(&tp_for_task.pool, &model_id).await;
-                    let leader_arc = tp_for_task.leader_model.clone();
+                    let leader_handle = tp_for_task.leader_handle;
                     let mut all_tokens: Vec<u32> = Vec::new();
                     let mut decoded_prefix = String::new();
                     let mut finish_reason = "length".to_string();
                     'work: {
-                        if let Err(e) = pool.clear_kv_cache(&model_id, leader_arc.clone()).await {
+                        if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await {
                             failure = Some(format!("clear_kv_cache: {e:#}"));
                             break 'work;
                         }
@@ -1957,8 +1977,8 @@ impl CandleHarness {
                         };
                         // Prefill — every rank embeds the prompt, offset = 0.
-                        let logits = match pool
-                            .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
+                        let logits_vec = match pool
+                            .generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0)
                             .await
                         {
                             Ok(l) => l,
@@ -1974,11 +1994,18 @@ impl CandleHarness {
                             vram_free_mb = post_prefill_vram_free_mb,
                             "TP chat_completion (stream): prefill complete"
                         );
+                        let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) {
+                            Ok(t) => t,
+                            Err(e) => {
+                                failure = Some(format!("prefill build cpu logits: {e:#}"));
+                                break 'work;
+                            }
+                        };
                         let mut next_token =
                             match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
                                 Ok(t) => t,
                                 Err(e) => {
-                                    let health = logits_health(&logits);
+                                    let health = logits_health_slice(&logits_vec);
                                     tracing::warn!(
                                         model = %model_id,
                                         ?health,
@@ -2010,10 +2037,10 @@ impl CandleHarness {
                         }
                         for index in 0..max_new.saturating_sub(1) {
-                            let logits = match pool
+                            let logits_vec = match pool
                                 .generate_step(
                                     &model_id,
-                                    leader_arc.clone(),
+                                    leader_handle,
                                     vec![next_token],
                                     prompt_len + index,
                                 )
@@ -2025,6 +2052,14 @@ impl CandleHarness {
                                     break 'work;
                                 }
                             };
+                            let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) {
+                                Ok(t) => t,
+                                Err(e) => {
+                                    failure =
+                                        Some(format!("decode build cpu logits {index}: {e:#}"));
+                                    break 'work;
+                                }
+                            };
                             next_token = match sample_with_penalty(
                                 &logits,
                                 &all_tokens,
@@ -2032,7 +2067,7 @@ impl CandleHarness {
                             ) {
                                 Ok(t) => t,
                                 Err(e) => {
-                                    let health = logits_health(&logits);
+                                    let health = logits_health_slice(&logits_vec);
                                     tracing::warn!(
                                         model = %model_id,
                                         step = index,
@@ -2180,20 +2215,19 @@ async fn chat_completion_tp_inner(
         "TP chat_completion: starting"
     );
-    // Acquire the pool lock for the duration of the request. The
-    // leader_model's own Mutex is acquired step-by-step inside
-    // pool.generate_step (so spawn_blocking can grab it without
-    // holding the pool lock across the blocking_lock call).
-    // `acquire_pool_lock` warns periodically while we wait so a
-    // stuck holder doesn't make the queueing requests look like
-    // silence in the journal.
+    // Acquire the pool lock for the duration of the request. After
+    // Phase 3 the leader's TpLeaderModel lives in the device worker
+    // thread, so the pool lock now serialises only subprocess RPC
+    // traffic — but holding it for the whole request still keeps
+    // concurrent chat_completions against the same TP model from
+    // interleaving prefill/decode jobs.
     let mut pool = acquire_pool_lock(&tp.pool, &model_id).await;
-    let leader_arc = tp.leader_model.clone();
+    let leader_handle = tp.leader_handle;
     // Reset every rank's KV cache so this request doesn't attend
     // over the previous request's tokens.
     let clear_start = std::time::Instant::now();
-    pool.clear_kv_cache(&model_id, leader_arc.clone())
+    pool.clear_kv_cache(&model_id, leader_handle)
         .await
         .map_err(InferenceError::Other)?;
     tracing::debug!(
@@ -2219,8 +2253,8 @@ async fn chat_completion_tp_inner(
     // Prefill: every rank embeds the whole prompt, offset = 0.
     let prefill_start = std::time::Instant::now();
-    let logits = pool
-        .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
+    let logits_vec = pool
+        .generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0)
         .await
         .map_err(InferenceError::Other)?;
     let (post_prefill_vram_free_mb, _) = tp.query_vram().await;
@@ -2231,6 +2265,11 @@ async fn chat_completion_tp_inner(
         vram_free_mb = post_prefill_vram_free_mb,
         "TP chat_completion: prefill complete"
     );
+    // Wrap the CPU-side logits in a CPU candle Tensor for sampling.
+    // No device touch on the async caller's thread — sampling reads
+    // from CPU memory only.
+    let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)
+        .map_err(|e| InferenceError::Other(anyhow::anyhow!("build cpu logits: {e}")))?;
     let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
         Ok(t) => t,
         Err(e) => {
@@ -2239,7 +2278,7 @@ async fn chat_completion_tp_inner(
             // this WARN sits just above that and carries the actual
             // numerical state so an operator can tell at a glance
             // whether it was a NaN cascade, an Inf, or something else.
-            let health = logits_health(&logits);
+            let health = logits_health_slice(&logits_vec);
             tracing::warn!(
                 model = %model_id,
                 ?health,
@@ -2256,19 +2295,22 @@ async fn chat_completion_tp_inner(
     let decode_start = std::time::Instant::now();
     for index in 0..max_new.saturating_sub(1) {
         let step_start = std::time::Instant::now();
-        let logits = pool
+        let logits_vec = pool
             .generate_step(
                 &model_id,
-                leader_arc.clone(),
+                leader_handle,
                 vec![next_token],
                 prompt_len + index,
             )
             .await
             .map_err(InferenceError::Other)?;
+        let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu).map_err(|e| {
+            InferenceError::Other(anyhow::anyhow!("build cpu logits step {index}: {e}"))
+        })?;
         next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
             Ok(t) => t,
             Err(e) => {
-                let health = logits_health(&logits);
+                let health = logits_health_slice(&logits_vec);
                 tracing::warn!(
                     model = %model_id,
                     step = index,
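
The hunks above swap `logits_health(&logits)` for `logits_health_slice(&logits_vec)`, but the helper's body is not part of this excerpt. One plausible shape for such a helper is sketched below purely as an assumption: the `LogitsHealth` fields and the name `logits_health_sketch` are hypothetical and are not taken from the repository. The point is only that a health summary can be computed from the host-side `&[f32]` without touching a device tensor.

```rust
/// Hypothetical summary of a logits vector (field names are assumptions).
#[derive(Debug, Clone, Copy, PartialEq)]
struct LogitsHealth {
    len: usize,
    nan: usize,
    inf: usize,
    finite_min: Option<f32>,
    finite_max: Option<f32>,
}

/// Count NaN and infinite entries and track the range of the finite ones.
fn logits_health_sketch(logits: &[f32]) -> LogitsHealth {
    let mut health = LogitsHealth {
        len: logits.len(),
        nan: 0,
        inf: 0,
        finite_min: None,
        finite_max: None,
    };
    for &value in logits {
        if value.is_nan() {
            health.nan += 1;
        } else if value.is_infinite() {
            health.inf += 1;
        } else {
            health.finite_min = Some(health.finite_min.map_or(value, |m| m.min(value)));
            health.finite_max = Some(health.finite_max.map_or(value, |m| m.max(value)));
        }
    }
    health
}
```

Because the struct derives `Debug`, it can be logged with the `?health` field syntax that the `tracing::warn!` calls in the diff already use.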


@@ -14,7 +14,12 @@
 //! `tp_models: HashMap<TpHandle, Box<TpLeaderModel>>`.
 use crate::harness::candle::ModelArch;
+#[cfg(feature = "cuda")]
+use crate::harness::device_worker::jobs::TpHandle;
 use crate::harness::device_worker::jobs::{ArchHandle, Job};
+#[cfg(feature = "cuda")]
+use crate::harness::tp::TpLeaderModel;
+use crate::harness::tp::nccl_state::NcclState;
 use std::collections::HashMap;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, Ordering};
@@ -40,6 +45,20 @@ struct DeviceWorkerState {
     /// increments and returns the new value. Wraps at u64::MAX after
     /// ~10^19 model loads — not a practical concern.
     next_handle: u64,
+    /// Leader's NCCL state. Populated by `Job::NcclInit`; the
+    /// underlying `Comm`'s libnccl handle lives bound to this thread
+    /// for its entire lifetime. Subprocess workers maintain their own
+    /// `NcclState` in their own processes — that's not visible from
+    /// here.
+    #[allow(dead_code)] // Read only via methods on NcclState
+    nccl: NcclState,
+    /// TP leader model slab. Same lifecycle as `models`; separate
+    /// namespace so `ArchHandle` and `TpHandle` can't collide.
+    #[cfg(feature = "cuda")]
+    tp_models: HashMap<TpHandle, Box<TpLeaderModel>>,
+    /// Counter for minting fresh `TpHandle`s.
+    #[cfg(feature = "cuda")]
+    next_tp_handle: u64,
     #[cfg(feature = "cuda")]
     #[allow(dead_code)]
     /// `None` only if `CudaContext::new()` failed — in that case the
@@ -128,15 +147,93 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
                 let result = forward_logits(&mut state, handle, &tokens, offset);
                 let _ = reply.send(result);
             }
+            Job::NcclInit {
+                cfg,
+                comm_id_hex,
+                reply,
+            } => {
+                let resp = state.nccl.init(cfg, &comm_id_hex);
+                let _ = reply.send(resp);
+            }
+            Job::NcclSanity { reply } => {
+                let resp = state.nccl.sanity_check();
+                let _ = reply.send(resp);
+            }
+            #[cfg(feature = "cuda")]
+            Job::CloneLeaderComm { reply } => {
+                let result = match state.nccl.comm() {
+                    Some(comm) => Ok(crate::harness::tp::nccl_state::SendComm(comm)),
+                    None => Err(anyhow::anyhow!(
+                        "CloneLeaderComm: NcclState has no Comm; call NcclInit first"
+                    )),
+                };
+                let _ = reply.send(result);
+            }
+            #[cfg(feature = "cuda")]
+            Job::TransferInTp { model, reply } => {
+                let handle = TpHandle(state.next_tp_handle);
+                state.next_tp_handle = state.next_tp_handle.wrapping_add(1);
+                state.tp_models.insert(handle, model);
+                tracing::debug!(
+                    device_index,
+                    tp_handle = handle.0,
+                    slab_size = state.tp_models.len(),
+                    "device worker: TP model transferred in"
+                );
+                let _ = reply.send(Ok(handle));
+            }
+            #[cfg(feature = "cuda")]
+            Job::DropTp { handle, reply } => {
+                let removed = state.tp_models.remove(&handle);
+                let was_present = removed.is_some();
+                drop(removed);
+                tracing::debug!(
+                    device_index,
+                    tp_handle = handle.0,
+                    was_present,
+                    slab_size = state.tp_models.len(),
+                    "device worker: TP model dropped"
+                );
+                let _ = reply.send(());
+            }
+            #[cfg(feature = "cuda")]
+            Job::TpClearKv { handle, reply } => {
+                let result = match state.tp_models.get_mut(&handle) {
+                    Some(model) => {
+                        model.clear_kv_cache();
+                        Ok(())
+                    }
+                    None => Err(anyhow::anyhow!(
+                        "TpClearKv: no TP model for handle {}",
+                        handle.0
+                    )),
+                };
+                let _ = reply.send(result);
+            }
+            #[cfg(feature = "cuda")]
+            Job::TpForwardLogits {
+                handle,
+                tokens,
+                offset,
+                reply,
+            } => {
+                let result = tp_forward_logits(&mut state, handle, &tokens, offset);
+                let _ = reply.send(result);
+            }
             // Handled by the matches!() check above; reaching here
             // means a Shutdown slipped past which is a bug.
             Job::Shutdown => unreachable!("Shutdown should break above"),
         }
     }
+    #[cfg(feature = "cuda")]
+    let tp_slab_size = state.tp_models.len();
+    #[cfg(not(feature = "cuda"))]
+    let tp_slab_size = 0_usize;
     tracing::info!(
         device_index,
         slab_size = state.models.len(),
+        tp_slab_size,
         "device worker exiting; dropping remaining models"
     );
     // Drops every model in the slab on this thread before the function
@@ -193,6 +290,9 @@ fn init_state(device_index: u32) -> DeviceWorkerState {
             device,
             models: HashMap::new(),
             next_handle: 1,
+            nccl: NcclState::new(),
+            tp_models: HashMap::new(),
+            next_tp_handle: 1,
             ctx,
         }
     }
@@ -203,6 +303,7 @@ fn init_state(device_index: u32) -> DeviceWorkerState {
             device: candle_core::Device::Cpu,
             models: HashMap::new(),
             next_handle: 1,
+            nccl: NcclState::new(),
         }
     }
 }
@@ -231,6 +332,38 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
     Ok((0, 0))
 }
+/// TP-equivalent of [`forward_logits`]: looks up the leader's
+/// [`TpLeaderModel`] in the slab, runs its forward, copies the
+/// `[vocab]` logits to a CPU `Vec<f32>`. The leader's `Arc<Comm>`
+/// clones embedded in the TP layers' AllReduce ops fire from this
+/// thread — same thread that bound the CUDA context and that holds
+/// the `Comm` in `state.nccl`.
+#[cfg(feature = "cuda")]
+fn tp_forward_logits(
+    state: &mut DeviceWorkerState,
+    handle: TpHandle,
+    tokens: &[u32],
+    offset: usize,
+) -> anyhow::Result<Vec<f32>> {
+    use candle_core::{DType, Tensor};
+    let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
+    let model = state
+        .tp_models
+        .get_mut(&handle)
+        .ok_or_else(|| anyhow::anyhow!("TpForwardLogits: no model for handle {}", handle.0))?;
+    let logits = model.forward(&input, offset)?;
+    // ForCausalLM forward returns [B, 1, V] after the trailing
+    // .i((.., l - 1.., ..))?.apply(lm_head); squeeze both leading
+    // singleton dims to a rank-1 [V] tensor for sampling.
+    let logits = logits.squeeze(0)?.squeeze(0)?;
+    let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
+    let values = logits.to_vec1::<f32>()?;
+    Ok(values)
+}
 /// Forward step + copy the `[vocab]` logits to a CPU `Vec<f32>` ready
 /// for sampling on the async caller. The model's `device()` (CUDA or
 /// CPU) determines where the kernel runs; this fn doesn't care.
@@ -297,6 +430,38 @@ fn drain_poisoned(job: Job, device_index: u32) {
         Job::ForwardLogits { reply, .. } => {
             let _ = reply.send(Err(err()));
         }
+        Job::NcclInit { reply, .. } => {
+            let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
+                kind: "device_worker_poisoned".into(),
+                message: format!("device worker {device_index} poisoned"),
+            });
+        }
+        Job::NcclSanity { reply } => {
+            let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
+                kind: "device_worker_poisoned".into(),
+                message: format!("device worker {device_index} poisoned"),
+            });
+        }
+        #[cfg(feature = "cuda")]
+        Job::CloneLeaderComm { reply } => {
+            let _ = reply.send(Err(err()));
+        }
+        #[cfg(feature = "cuda")]
+        Job::TransferInTp { reply, .. } => {
+            let _ = reply.send(Err(err()));
+        }
+        #[cfg(feature = "cuda")]
+        Job::DropTp { reply, .. } => {
+            let _ = reply.send(());
+        }
+        #[cfg(feature = "cuda")]
+        Job::TpClearKv { reply, .. } => {
+            let _ = reply.send(Err(err()));
+        }
+        #[cfg(feature = "cuda")]
+        Job::TpForwardLogits { reply, .. } => {
+            let _ = reply.send(Err(err()));
+        }
         Job::Shutdown => {
             // Filtered by the matches!() guard in run(); reaching
             // here would be a logic error.
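
The drain above answers every queued job rather than dropping it silently, so a caller waiting on its reply channel always gets a definite result. A reduced sketch of that contract follows, assuming a hypothetical two-variant `Job` and using `std::sync::mpsc` in place of `oneshot`; `submit_forward_to_poisoned` is an illustrative helper, not a function from this commit.

```rust
use std::sync::mpsc;

/// Hypothetical two-variant job enum; each variant owns its reply channel.
enum Job {
    Forward { reply: mpsc::Sender<Result<Vec<f32>, String>> },
    Drop { reply: mpsc::Sender<()> },
}

/// Answer a job on behalf of a poisoned worker: fallible jobs get an
/// error, infallible ones get their unit acknowledgement.
fn drain_poisoned(job: Job, device_index: u32) {
    let err = || format!("device worker {device_index} poisoned");
    match job {
        Job::Forward { reply } => {
            let _ = reply.send(Err(err()));
        }
        Job::Drop { reply } => {
            let _ = reply.send(());
        }
    }
}

/// Caller-side view: the receive returns immediately with the error.
fn submit_forward_to_poisoned(device_index: u32) -> Result<Vec<f32>, String> {
    let (reply, rx) = mpsc::channel();
    drain_poisoned(Job::Forward { reply }, device_index);
    rx.recv().expect("drain always replies")
}
```

Matching exhaustively on the enum means that adding a new variant without a drain arm fails to compile, which is presumably why every new `Job` variant in this commit also gains an arm here.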


@@ -19,6 +19,15 @@ use tokio::sync::oneshot;
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub struct ArchHandle(pub u64);
+/// Opaque handle to a `TpLeaderModel` stored in the worker thread's
+/// state slab. Same shape as [`ArchHandle`] but in a separate
+/// namespace so the two slabs can coexist without ambiguity. Phase 3
+/// introduces it; Phase 4 may unify the two slabs after the TP forward
+/// path proves out.
+#[cfg(feature = "cuda")]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct TpHandle(pub u64);
 /// One unit of work for the device worker.
 ///
 /// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
@@ -73,6 +82,74 @@ pub enum Job {
         offset: usize,
         reply: oneshot::Sender<Result<Vec<f32>>>,
     },
+    /// Initialize the leader's NCCL communicator. The worker's
+    /// `NcclState` mints the `Comm` here so its underlying
+    /// `ncclComm_t` and `CudaContext` live on the same thread as
+    /// every later `Comm::all_reduce` call. Reply is the worker
+    /// response shape used by the subprocess workers (`InitOk` on
+    /// success, `Error` on failure) so the calling
+    /// `WorkerPool::init_nccl` orchestration stays uniform.
+    ///
+    /// Available on both cuda and no-cuda builds — the dispatch
+    /// handler calls `NcclState::init` which has a no-cuda stub that
+    /// replies with `cuda_feature_not_enabled`. Keeping the Job
+    /// variant ungated lets `WorkerPool::init_nccl` stay uniform.
+    NcclInit {
+        cfg: crate::harness::tp::worker::WorkerConfig,
+        comm_id_hex: String,
+        reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
+    },
+    /// Run NCCL's all_reduce sanity check on the leader's rank 0.
+    /// Same response shape as `NcclInit`; also available on both
+    /// builds via the no-cuda `NcclState::sanity_check` stub.
+    NcclSanity {
+        reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
+    },
+    /// Clone the leader's `Arc<Comm>` out of the worker's `NcclState`
+    /// so a spawn_blocking-based load (Phase 3 bridge) can hand it to
+    /// the row-parallel layers. Wrapped in `SendComm` because
+    /// `Arc<Comm>` is `!Send` at the type level (the NCCL contract
+    /// requires serialised access, which we provide structurally).
+    /// Phase 4 eliminates this when `TpLoadShard` becomes a Job and
+    /// the load runs entirely on the worker thread.
+    #[cfg(feature = "cuda")]
+    CloneLeaderComm {
+        reply: oneshot::Sender<Result<crate::harness::tp::nccl_state::SendComm>>,
+    },
+    /// Move a freshly-built `TpLeaderModel` into the worker's tp slab.
+    /// Returns a `TpHandle` the caller stores on `TpLoadedModel`.
+    #[cfg(feature = "cuda")]
+    TransferInTp {
+        model: Box<crate::harness::tp::TpLeaderModel>,
+        reply: oneshot::Sender<Result<TpHandle>>,
+    },
+    /// Drop the TP leader model on the worker thread. CUDA tensors
+    /// and `Arc<Comm>` clones held inside the model release on the
+    /// thread that allocated them.
+    #[cfg(feature = "cuda")]
+    DropTp {
+        handle: TpHandle,
+        reply: oneshot::Sender<()>,
+    },
+    /// Reset the leader's KV cache for a TP model. Mirrors `ClearKv`
+    /// for single-GPU.
+    #[cfg(feature = "cuda")]
+    TpClearKv {
+        handle: TpHandle,
+        reply: oneshot::Sender<Result<()>>,
+    },
+    /// Run one TP forward step on the leader's shard. Returns CPU-
+    /// side logits as a `Vec<f32>` so the async caller can sample
+    /// without holding a device tensor. The caller is also
+    /// responsible for fan-out to subprocess ranks and drain — only
+    /// the leader's forward moves into the worker thread.
+    #[cfg(feature = "cuda")]
+    TpForwardLogits {
+        handle: TpHandle,
+        tokens: Vec<u32>,
+        offset: usize,
+        reply: oneshot::Sender<Result<Vec<f32>>>,
+    },
     /// Tell the worker to break its dispatch loop and exit. Any jobs
     /// queued after this in the channel reply `Err` to their oneshot
     /// senders (the senders are dropped on the worker's exit, which


@@ -49,6 +49,8 @@ use std::sync::mpsc::{self, Sender};
 use std::thread::JoinHandle;
 use tokio::sync::oneshot;
+#[cfg(feature = "cuda")]
+pub use jobs::TpHandle;
 pub use jobs::{ArchHandle, Job};
 /// Errors returned by `DeviceWorkerHandle` submit methods.
@@ -277,6 +279,192 @@ impl DeviceWorkerHandle {
         }
     }
+    /// Initialise the leader's NCCL communicator. The reply uses
+    /// `WorkerResponse` (same shape subprocess workers use over stdio
+    /// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader +
+    /// subprocess responses uniformly. Available on no-cuda builds
+    /// too — the dispatch handler calls the no-cuda `NcclState::init`
+    /// stub which replies `cuda_feature_not_enabled`.
+    pub async fn nccl_init(
+        &self,
+        cfg: crate::harness::tp::worker::WorkerConfig,
+        comm_id_hex: String,
+    ) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
+        if self.poisoned.load(Ordering::Acquire) {
+            return Err(WorkerError::Poisoned {
+                device_index: self.device_index,
+            });
+        }
+        let (reply_tx, reply_rx) = oneshot::channel();
+        self.tx
+            .send(Job::NcclInit {
+                cfg,
+                comm_id_hex,
+                reply: reply_tx,
+            })
+            .map_err(|_| WorkerError::Gone {
+                device_index: self.device_index,
+            })?;
+        reply_rx.await.map_err(|_| WorkerError::Gone {
+            device_index: self.device_index,
+        })
+    }
+    /// Run an NCCL sanity all_reduce on the leader's rank 0.
+    /// Available on no-cuda builds; replies with an error response.
+    pub async fn nccl_sanity(
+        &self,
+    ) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
+        if self.poisoned.load(Ordering::Acquire) {
+            return Err(WorkerError::Poisoned {
+                device_index: self.device_index,
+            });
+        }
+        let (reply_tx, reply_rx) = oneshot::channel();
+        self.tx
+            .send(Job::NcclSanity { reply: reply_tx })
+            .map_err(|_| WorkerError::Gone {
+                device_index: self.device_index,
+            })?;
+        reply_rx.await.map_err(|_| WorkerError::Gone {
+            device_index: self.device_index,
+        })
+    }
+    /// Clone the leader's `Arc<Comm>` so a spawn_blocking-based load
+    /// (Phase 3 bridge) can pass it to the row-parallel layers.
+    /// Phase 4 eliminates this once the TP load runs on this thread.
+    #[cfg(feature = "cuda")]
+    pub async fn clone_leader_comm(
+        &self,
+    ) -> Result<crate::harness::tp::nccl_state::SendComm, WorkerError> {
+        if self.poisoned.load(Ordering::Acquire) {
+            return Err(WorkerError::Poisoned {
+                device_index: self.device_index,
+            });
+        }
+        let (reply_tx, reply_rx) = oneshot::channel();
+        self.tx
+            .send(Job::CloneLeaderComm { reply: reply_tx })
+            .map_err(|_| WorkerError::Gone {
+                device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Move a freshly-built `TpLeaderModel` into the worker's TP slab.
#[cfg(feature = "cuda")]
pub async fn transfer_in_tp(
&self,
model: Box<crate::harness::tp::TpLeaderModel>,
) -> Result<TpHandle, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::TransferInTp {
model,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Drop the TP model at `handle` on the worker thread.
#[cfg(feature = "cuda")]
pub async fn drop_tp(&self, handle: TpHandle) -> Result<(), WorkerError> {
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::DropTp {
handle,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(()) => Ok(()),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Reset the leader's KV cache for a TP model.
#[cfg(feature = "cuda")]
pub async fn tp_clear_kv(&self, handle: TpHandle) -> Result<(), WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::TpClearKv {
handle,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Run one TP forward step on the leader's shard. Returns CPU-side
/// logits as `Vec<f32>` ready for sampling. The caller is
/// responsible for fan-out / drain of the subprocess workers
/// concurrently with this call.
#[cfg(feature = "cuda")]
pub async fn tp_forward_logits(
&self,
handle: TpHandle,
tokens: Vec<u32>,
offset: usize,
) -> Result<Vec<f32>, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::TpForwardLogits {
handle,
tokens,
offset,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
/// twice is a no-op the second time.
pub fn shutdown(&self) -> anyhow::Result<()> {
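Most of the submit methods added in this hunk repeat the same three steps: check the poisoned flag, send a job that carries a fresh reply channel, then wait for the reply and map a dropped channel to `Gone`. One way to express that shape once is sketched below. This is an illustration, not part of the commit: `ToyHandle`, `SubmitError`, `EchoJob` and `submit` are hypothetical names, and the reply uses a blocking `std::sync::mpsc` channel where the real code awaits a `tokio::sync::oneshot`.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{self, Sender};

/// Hypothetical counterpart of the two `WorkerError` variants used above.
#[derive(Debug, PartialEq, Eq)]
pub enum SubmitError {
    Poisoned { device_index: u32 },
    Gone { device_index: u32 },
}

/// Hypothetical handle, generic over the job type `J`.
pub struct ToyHandle<J> {
    pub tx: Sender<J>,
    pub poisoned: AtomicBool,
    pub device_index: u32,
}

impl<J> ToyHandle<J> {
    /// Build a job around a fresh reply channel, send it, and wait for the reply.
    pub fn submit<R>(&self, make_job: impl FnOnce(Sender<R>) -> J) -> Result<R, SubmitError> {
        let device_index = self.device_index;
        // 1. Refuse new work once the worker has been marked poisoned.
        if self.poisoned.load(Ordering::Acquire) {
            return Err(SubmitError::Poisoned { device_index });
        }
        // 2. A failed send means the worker's receiver is gone.
        let (reply_tx, reply_rx) = mpsc::channel();
        self.tx
            .send(make_job(reply_tx))
            .map_err(|_| SubmitError::Gone { device_index })?;
        // 3. A dropped reply sender means the worker exited before answering.
        reply_rx.recv().map_err(|_| SubmitError::Gone { device_index })
    }
}

/// Hypothetical job used only to exercise `submit`.
pub enum EchoJob {
    Echo { value: u32, reply: Sender<u32> },
}
```

A call then reads `handle.submit(|reply| EchoJob::Echo { value: 41, reply })`. Methods whose reply is itself a `Result` would still need to flatten the inner error, and `drop_tp` above deliberately skips the poisoned check, so a helper like this would need an opt-out for that case.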


@@ -212,10 +212,15 @@ pub struct WorkerPool {
 /// Path to the neuron binary used to launch workers.
 #[allow(dead_code)]
 exe: PathBuf,
-/// Leader's own NCCL rank-0 state. Defaults to empty; populated by
-/// `init_nccl()`. Held here so the leader can participate in
-/// collectives (rank 0) without spawning a fourth subprocess.
-leader_nccl: nccl_state::NcclState,
+/// The leader's per-device CUDA worker thread. Phase 3 moved the
+/// leader's `NcclState` (rank-0 NCCL Comm) into this thread, so
+/// every NCCL op (init, sanity, all_reduce inside forward) issues
+/// from one OS thread for the daemon's lifetime. The handle is
+/// also used by `load_dense_shard` to clone the leader's
+/// `Arc<Comm>` for the row-parallel layers' AllReduce ops; in
+/// Phase 4 the load itself moves onto the worker and that bridge
+/// goes away.
+pub(crate) leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
 }
 impl WorkerPool {
@@ -228,7 +233,12 @@ impl WorkerPool {
 /// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`).
 /// `cuda_devices` is one entry per rank including rank 0. Worker
 /// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`.
-pub async fn spawn(binary: &Path, world_size: u32, cuda_devices: &[u32]) -> Result<Self> {
+pub async fn spawn(
+binary: &Path,
+world_size: u32,
+cuda_devices: &[u32],
+leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
+) -> Result<Self> {
 if world_size < 2 {
 anyhow::bail!(
 "WorkerPool::spawn called with world_size={world_size}; \
@@ -289,7 +299,7 @@ impl WorkerPool {
 world_size,
 workers,
 exe,
-leader_nccl: nccl_state::NcclState::new(),
+leader_worker,
 })
 }
@@ -321,27 +331,26 @@ impl WorkerPool {
 }
 // 2. Leader rank 0 calls Comm::from_rank on its own device.
-// Runs on spawn_blocking because NCCL's init blocks until
-// every rank has called in — that's exactly the workers
-// above. The leader's NcclState is moved through the
-// blocking task and returned to the pool.
+// Phase 3 moved this from spawn_blocking onto the leader's
+// device worker thread (`Job::NcclInit`); the underlying
+// `Comm` now lives on the same OS thread for its entire
+// lifetime, including every later `Comm::all_reduce` issued
+// by the row-parallel layers during forward.
+//
+// NCCL's init blocks until every rank has called in — the
+// subprocess workers above and the leader's device worker
+// here. The Job's reply unblocks when the leader's
+// Comm::from_rank returns.
 let leader_cfg = worker::WorkerConfig {
 rank: 0,
 world_size: self.world_size,
 cuda_device: leader_cuda_device,
 };
-let comm_id_for_leader = comm_id.clone();
-// Swap out the leader's NcclState into a fresh empty one so we
-// can move it into spawn_blocking; restore after the task
-// returns. (NcclState isn't Clone — it owns a real NCCL Comm.)
-let mut leader_state = std::mem::take(&mut self.leader_nccl);
-let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
-let resp = leader_state.init(leader_cfg, &comm_id_for_leader);
-(leader_state, resp)
-})
-.await
-.context("leader NCCL init task panicked")?;
-self.leader_nccl = returned_state;
+let leader_resp = self
+.leader_worker
+.nccl_init(leader_cfg, comm_id.clone())
+.await
+.map_err(|e| anyhow::anyhow!("leader NCCL init via device worker: {e}"))?;
 match leader_resp {
 rpc::WorkerResponse::InitOk => {}
 rpc::WorkerResponse::Error { kind, message } => {
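The ordering in this hunk (send-only fan-out to the subprocess ranks, then the leader's own blocking call, then the drain) matters because the collective returns only once every rank has arrived. A toy illustration of that ordering follows. It assumes nothing about NCCL's internals: `std::sync::Barrier` is a stand-in for the rendezvous, threads stand in for subprocess workers, and `rendezvous` is a hypothetical name.

```rust
use std::sync::{Arc, Barrier};
use std::thread;

/// Simulate a collective in which no rank returns until all `world_size`
/// ranks have arrived. Returns the ranks that completed, leader first.
pub fn rendezvous(world_size: usize) -> Vec<usize> {
    let barrier = Arc::new(Barrier::new(world_size));

    // 1. Fan-out: start ranks 1..world_size and do NOT wait for their replies
    //    yet. Each one blocks inside the collective until rank 0 shows up.
    let workers: Vec<_> = (1..world_size)
        .map(|rank| {
            let barrier = Arc::clone(&barrier);
            thread::spawn(move || {
                barrier.wait();
                rank
            })
        })
        .collect();

    // 2. The leader (rank 0) joins the collective. Waiting on the workers
    //    before this point would deadlock: they cannot finish without rank 0.
    barrier.wait();

    // 3. Drain: only now collect every worker's reply.
    let mut done = vec![0];
    for worker in workers {
        done.push(worker.join().expect("worker thread panicked"));
    }
    done
}
```

Swapping steps 2 and 3 in this sketch hangs forever for any `world_size` above 1, which is the failure mode the fan-out, leader, drain sequence avoids.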
@@ -387,16 +396,16 @@ impl WorkerPool {
 w.send_only(&WorkerRequest::NcclSanityCheck).await?;
 }
-// 2. Leader's own all_reduce, in spawn_blocking. NCCL operations
-// block until every rank participates.
-let mut leader_state = std::mem::take(&mut self.leader_nccl);
-let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
-let resp = leader_state.sanity_check();
-(leader_state, resp)
-})
-.await
-.context("leader NCCL sanity task panicked")?;
-self.leader_nccl = returned_state;
+// 2. Leader's own all_reduce, on its device worker thread.
+// NCCL operations block until every rank participates;
+// Job::NcclSanity returns once the leader's side completes
+// (which happens when every subprocess worker reaches its
+// all_reduce call too).
+let leader_resp = self
+.leader_worker
+.nccl_sanity()
+.await
+.map_err(|e| anyhow::anyhow!("leader NCCL sanity via device worker: {e}"))?;
 let expected = self.world_size;
 let leader_sum = match leader_resp {
@@ -483,21 +492,24 @@ impl WorkerPool {
 leader_device: &candle_core::Device,
 dtype: candle_core::DType,
 quant: Option<String>,
-) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
+) -> Result<super::device_worker::TpHandle> {
 use candle_nn::var_builder::ShardedSafeTensors;
-use std::sync::Arc;
-use tokio::sync::Mutex;
-// Wrap the comm in SendComm immediately so it stays Send across
-// the await points in this method — bare Arc<Comm> would
-// poison the async fn's Send bound (Comm's raw NCCL pointer is
-// !Send). The wrapper's safety contract is satisfied by the
-// pool's outer Mutex serialising callers + the spawn_blocking
-// thread being the only place ops are issued.
-let leader_comm =
-nccl_state::SendComm(self.leader_nccl.comm().ok_or_else(|| {
-anyhow::anyhow!("leader NCCL not initialised; call init_nccl first")
-})?);
+// Ask the leader's device worker for an `Arc<Comm>` clone.
+// Phase 3 moved `NcclState` ownership onto the worker thread,
+// so the spawn_blocking load below can no longer reach the
+// Comm directly. The reply is wrapped in `SendComm` because
+// the underlying `Arc<Comm>` is `!Send` at the type level;
+// the safety contract (only one thread issues NCCL ops at a
+// time) is preserved because the load runs on a single
+// spawn_blocking thread and AllReduce ops fire only from the
+// device worker thread later. Phase 4 eliminates this bridge
+// when the load itself moves onto the worker.
+let leader_comm = self
+.leader_worker
+.clone_leader_comm()
+.await
+.map_err(|e| anyhow::anyhow!("clone leader Comm via device worker: {e}"))?;
 let world_size = self.world_size;
 let safetensors_str: Vec<String> = safetensors_paths
 .iter()
@@ -601,15 +613,32 @@ impl WorkerPool {
 }
 }
-Ok(Arc::new(Mutex::new(leader_model)))
+// Phase 3: move the leader's freshly-built `TpLeaderModel`
+// into the device worker's TP slab. The model holds
+// `Arc<Comm>` clones (in its AllReduce ops) plus CUDA
+// tensors — both need to live on the device worker thread so
+// every `Comm::all_reduce` and tensor op during forward
+// dispatches from the same OS thread that bound the CUDA
+// context.
+let handle = self
+.leader_worker
+.transfer_in_tp(Box::new(leader_model))
+.await
+.map_err(|e| anyhow::anyhow!("transfer TP leader model into device worker: {e}"))?;
+Ok(handle)
 }
 /// Run one forward step across every rank. The leader's forward
-/// returns the last-position logits as a candle Tensor on the
-/// leader's device; the caller does sampling out-of-band. Workers
-/// run their own forwards (the AllReduce inside row-parallel layers
-/// is what lets the leader's collective complete) and reply with
-/// `GenerateStepOk` — they do not ship logits over the wire.
+/// runs on the device worker thread via `Job::TpForwardLogits` and
+/// returns CPU-side `[vocab]` logits as `Vec<f32>`; the async
+/// caller wraps them in a CPU tensor for `apply_repeat_penalty` +
+/// sampling without holding a device-resident tensor on a tokio
+/// thread.
+///
+/// Subprocess workers run their own forwards in parallel (the
+/// AllReduce CustomOps inside row-parallel layers are what let
+/// the leader's collective complete) and reply with
+/// `GenerateStepOk` over the RPC stream — they do not ship logits.
 ///
 /// `tokens` is the input for this step (prompt for prefill, the
 /// previously-sampled token for decode). `offset` is the KV-cache
@@ -618,10 +647,10 @@ impl WorkerPool {
 pub async fn generate_step(
 &mut self,
 model_id: &str,
-leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
+leader_handle: super::device_worker::TpHandle,
 tokens: Vec<u32>,
 offset: usize,
-) -> Result<candle_core::Tensor> {
+) -> Result<Vec<f32>> {
 let step_start = std::time::Instant::now();
 let tokens_len = tokens.len();
 tracing::debug!(
@@ -630,7 +659,7 @@ impl WorkerPool {
 offset,
 "WorkerPool::generate_step: fan-out"
 );
-// 1. Fan-out to workers.
+// 1. Fan-out to subprocess workers.
 for w in &mut self.workers {
 w.send_only(&WorkerRequest::GenerateStep {
 model_id: model_id.to_string(),
@@ -640,35 +669,30 @@ impl WorkerPool {
 .await?;
 }
-// 2. Leader's forward in spawn_blocking. The AllReduce CustomOps
-// inside the row-parallel layers block until every worker's
-// forward issues the matching collective.
+// 2. Leader's forward on its device worker thread. The
+// AllReduce CustomOps inside the row-parallel layers block
+// until every subprocess worker's forward issues the
+// matching collective. Returning CPU-side `Vec<f32>` keeps
+// the device tensor from escaping the worker thread —
+// that's the invariant the whole refactor exists to
+// preserve.
 let leader_start = std::time::Instant::now();
-let leader_result = tokio::task::spawn_blocking(move || -> Result<candle_core::Tensor> {
-let mut model = leader_model.blocking_lock();
-let device = model.device().clone();
-let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
-// ForCausalLM::forward returns [B, 1, V] — squeeze both
-// leading dims to the rank-1 vocab logits the sampler wants.
-let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?;
-Ok(logits)
-})
-.await
-.context("leader forward task panicked");
-let leader_ok = matches!(leader_result, Ok(Ok(_)));
+let leader_result = self
+.leader_worker
+.tp_forward_logits(leader_handle, tokens, offset)
+.await;
+let leader_ok = leader_result.is_ok();
 let leader_ms = leader_start.elapsed().as_millis();
-// Surface the leader's own error at WARN. Previously this was
-// silently coerced to `leader_ok=false` while only worker
-// ranks' errors got logged — when both the leader and a worker
-// fail together (the typical "CUDA context is now poisoned"
-// pattern after an OOM), the operator could see only the
-// worker side and had to guess what hit rank 0.
+// Surface the leader's own error at WARN before draining
+// workers so the operator can correlate it with whatever the
+// subprocess workers logged. Previously this was silently
+// coerced to a bool.
 if !leader_ok {
-let detail = match &leader_result {
-Ok(Err(e)) => format!("{e:#}"),
-Err(e) => format!("task: {e:#}"),
-Ok(Ok(_)) => unreachable!("leader_ok=false implies an error path"),
-};
+let detail = leader_result
+.as_ref()
+.err()
+.map(|e| format!("{e:#}"))
+.unwrap_or_default();
 tracing::warn!(
 model = %model_id,
 tokens = tokens_len,
@@ -707,7 +731,33 @@ impl WorkerPool {
 "WorkerPool::generate_step: workers drained"
 );
-combine_leader_workers(leader_result, worker_errors, "GenerateStep")
+// Combine the leader's Result + the workers' string-error
+// list. Phase 3 inlines this because the upstream
+// `combine_leader_workers` expects the spawn_blocking-shaped
+// `Result<Result<T>>`; the new device-worker path produces a
+// single `Result<T, WorkerError>` instead.
+match leader_result {
+Ok(values) => {
+if worker_errors.is_empty() {
+Ok(values)
+} else {
+anyhow::bail!(
+"GenerateStep: leader succeeded but workers failed: {}",
+worker_errors.join("; ")
+)
+}
+}
+Err(e) => {
+if worker_errors.is_empty() {
+Err(anyhow::Error::new(e).context("GenerateStep: leader forward failed"))
+} else {
+Err(anyhow::Error::new(e).context(format!(
+"GenerateStep: leader forward failed and workers also failed: {}",
+worker_errors.join("; ")
+)))
+}
+}
+}
 }
 /// Reset the KV cache for `model_id` on every rank. Called at the
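The inlined match in `generate_step` covers four cases: leader and workers both fine, only the workers failed, only the leader failed, or both failed. A compact sketch of the same decision table is below, written as a standalone function so it can be tested in isolation. It is an assumption-laden simplification: `combine_leader_and_workers` is a hypothetical name, errors are plain `String`s instead of `anyhow::Error` and `WorkerError`, and the leader's error is appended to the message rather than attached as an `anyhow` context chain.

```rust
/// Merge the leader's result with the list of worker error strings.
/// `op` is a label such as "GenerateStep" used as the message prefix.
pub fn combine_leader_and_workers<T>(
    leader: Result<T, String>,
    worker_errors: Vec<String>,
    op: &str,
) -> Result<T, String> {
    match (leader, worker_errors.is_empty()) {
        // Everything succeeded: pass the leader's value through.
        (Ok(value), true) => Ok(value),
        // The leader's value is discarded; a step counts only if every rank succeeded.
        (Ok(_), false) => Err(format!(
            "{op}: leader succeeded but workers failed: {}",
            worker_errors.join("; ")
        )),
        // Only the leader failed.
        (Err(e), true) => Err(format!("{op}: leader forward failed: {e}")),
        // Both sides failed: report both so neither error hides the other.
        (Err(e), false) => Err(format!(
            "{op}: leader forward failed ({e}) and workers also failed: {}",
            worker_errors.join("; ")
        )),
    }
}
```

Matching on the `(leader, is_empty)` pair keeps each of the four outcomes on its own arm, which makes it harder to forget one when the error types change again in a later phase.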
@@ -716,7 +766,7 @@ impl WorkerPool {
 pub async fn clear_kv_cache(
 &mut self,
 model_id: &str,
-#[cfg(feature = "cuda")] leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
+#[cfg(feature = "cuda")] leader_handle: super::device_worker::TpHandle,
 ) -> Result<()> {
 let start = std::time::Instant::now();
 tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out");
@@ -728,13 +778,18 @@ impl WorkerPool {
 }
 #[cfg(feature = "cuda")]
 {
-let mut m = leader_model.lock().await;
-m.clear_kv_cache();
+// Leader-side clear on the device worker thread —
+// `TpLeaderModel::clear_kv_cache` is infallible but still
+// routes through Job::TpClearKv so the cache reset runs
+// on the same thread that owns the model's CUDA tensors.
+if let Err(e) = self.leader_worker.tp_clear_kv(leader_handle).await {
+anyhow::bail!("leader TP clear_kv_cache via device worker: {e}");
+}
 }
 // Drain workers — same rationale as `generate_step`. The
-// leader's clear_kv_cache is in-process and infallible, but we
-// still always drain so an error on one worker doesn't leave
-// pending responses for the others.
+// leader's clear_kv_cache is now async-via-channel but still
+// returns before the drain so the workers' KvCacheCleared
+// replies are processed in order.
 let worker_errors = drain_workers(&mut self.workers, |r| match r {
 WorkerResponse::KvCacheCleared => Ok(()),
 WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),


@@ -118,7 +118,13 @@ async fn tp_smoke(tp_size: u32, cuda_devices: Vec<u32>) -> Result<()> {
 binary = %exe.display(),
 "tp-smoke: spawning worker pool"
 );
-let mut pool = tp::WorkerPool::spawn(&exe, tp_size, &cuda_devices).await?;
+// tp_smoke is a diagnostic tool; spawn the leader's device worker
+// directly. (In the daemon path, CandleHarness::ensure_device_worker
+// caches one per device.)
+let leader_worker = neuron::harness::device_worker::DeviceWorkerHandle::spawn(leader_device)
+.context("spawn leader device worker for tp-smoke")?;
+let mut pool =
+tp::WorkerPool::spawn(&exe, tp_size, &cuda_devices, leader_worker.clone()).await?;
 tracing::info!("tp-smoke: pinging every worker");
 let pongs = pool.ping_all().await?;


@@ -5,6 +5,7 @@
 //! `Init` and `NcclSanityCheck` are stubbed in 7a-i, so this test
 //! runs on any host the workspace builds on.
+use neuron::harness::device_worker::DeviceWorkerHandle;
 use neuron::harness::tp::{WorkerPool, rpc::WorkerResponse};
 /// Path to the neuron binary built by cargo for this test process.
@@ -19,7 +20,8 @@ const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
 async fn test_spawn_ping_shutdown() {
 // cuda_devices: rank 0 → device 0 (leader, unused here),
 // rank 1 → device 1 (worker; not actually opened in 7a-i).
-let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1])
+let leader_worker = DeviceWorkerHandle::spawn(0).expect("spawn device worker");
+let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1], leader_worker)
 .await
 .expect("spawn worker pool");
@@ -44,7 +46,8 @@ async fn test_spawn_ping_shutdown() {
 /// Three workers — exercise the loop in `ping_all` / `shutdown`.
 #[tokio::test]
 async fn test_spawn_three_workers() {
-let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 3, &[0, 1, 2])
+let leader_worker = DeviceWorkerHandle::spawn(0).expect("spawn device worker");
+let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 3, &[0, 1, 2], leader_worker)
 .await
 .expect("spawn worker pool");


@@ -25,7 +25,9 @@ async fn test_init_and_sanity_check_two_ranks() {
 .try_init();
 // 2 ranks: leader = rank 0 on device 0, worker = rank 1 on device 1.
-let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1])
+let leader_worker = neuron::harness::device_worker::DeviceWorkerHandle::spawn(0)
+.expect("spawn leader device worker");
+let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1], leader_worker)
 .await
 .expect("spawn worker pool");