refactor(neuron): phase 3 — TP forward + NCCL state move onto device worker

Third slice of the per-device CUDA context-ownership refactor planned at
~/.claude/plans/plan-the-per-device-worker-abstract-micali.md. The
leader's `NcclState`, every `Comm::all_reduce` issued by the TP layers,
the leader-side KV cache reset, and the TP forward step itself now all
run on the per-device worker thread — the same OS thread that bound
the leader's `CudaContext` at startup.
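
The ownership pattern described above can be sketched with the standard library alone. This is a minimal illustration, not the project's `device_worker` code: `Handle`, `Model`, `Job`, `spawn_worker`, and `round_trip` are hypothetical stand-ins, and a plain `Vec<f32>` takes the place of CUDA state. It shows one OS thread owning a slab of boxed models, callers holding only an opaque handle, and the destructor running on the thread that built the value.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{self, Sender};
use std::thread::{self, ThreadId};

/// Opaque key into the worker's slab (same idea as `TpHandle(u64)`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct Handle(u64);

/// Hypothetical stand-in for a model that must be dropped where it was built.
struct Model {
    weights: Vec<f32>,
    dropped_on: Sender<ThreadId>,
}

impl Drop for Model {
    fn drop(&mut self) {
        // Report which thread ran the destructor.
        let _ = self.dropped_on.send(thread::current().id());
    }
}

/// Requests the worker understands; each carries its own reply channel.
enum Job {
    Build { weights: Vec<f32>, dropped_on: Sender<ThreadId>, reply: Sender<(Handle, ThreadId)> },
    Forward { handle: Handle, input: f32, reply: Sender<Option<Vec<f32>>> },
    DropModel { handle: Handle, reply: Sender<bool> },
}

/// Spawns the worker thread; the slab never leaves it.
fn spawn_worker() -> Sender<Job> {
    let (tx, rx) = mpsc::channel::<Job>();
    thread::spawn(move || {
        let mut slab: HashMap<Handle, Box<Model>> = HashMap::new();
        let mut next_id = 0u64;
        for job in rx {
            match job {
                Job::Build { weights, dropped_on, reply } => {
                    let handle = Handle(next_id);
                    next_id += 1;
                    slab.insert(handle, Box::new(Model { weights, dropped_on }));
                    let _ = reply.send((handle, thread::current().id()));
                }
                Job::Forward { handle, input, reply } => {
                    // Reply with plain CPU data, never a reference into the slab.
                    let out: Option<Vec<f32>> = slab
                        .get(&handle)
                        .map(|m| m.weights.iter().map(|w| w * input).collect());
                    let _ = reply.send(out);
                }
                Job::DropModel { handle, reply } => {
                    let removed = slab.remove(&handle);
                    let existed = removed.is_some();
                    drop(removed); // destructor runs here, on the worker thread
                    let _ = reply.send(existed);
                }
            }
        }
    });
    tx
}

/// One build / forward / drop round trip.
/// Returns (thread that built the model, thread that dropped it, forward output).
fn round_trip() -> (ThreadId, ThreadId, Vec<f32>) {
    let worker = spawn_worker();
    let (drop_tx, drop_rx) = mpsc::channel();
    let (h_tx, h_rx) = mpsc::channel();
    let build = Job::Build { weights: vec![1.0, 2.0, 3.0], dropped_on: drop_tx, reply: h_tx };
    worker.send(build).unwrap();
    let (handle, built_on) = h_rx.recv().unwrap();

    let (out_tx, out_rx) = mpsc::channel();
    worker.send(Job::Forward { handle, input: 2.0, reply: out_tx }).unwrap();
    let out = out_rx.recv().unwrap().expect("handle is live");

    let (ok_tx, ok_rx) = mpsc::channel();
    worker.send(Job::DropModel { handle, reply: ok_tx }).unwrap();
    assert!(ok_rx.recv().unwrap());
    (built_on, drop_rx.recv().unwrap(), out)
}

fn main() {
    let (built_on, dropped_on, out) = round_trip();
    println!("forward output: {out:?}");
    println!("built and dropped on the same thread: {}", built_on == dropped_on);
}
```

The caller never holds the model itself, only a `Copy` handle, so nothing that must stay on the worker thread can be dropped elsewhere by accident.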

What this phase changes:

- `Job` gains `NcclInit`, `NcclSanity`, `CloneLeaderComm` (a Phase 3
  bridge that Phase 4 removes), `TransferInTp`, `DropTp`, `TpClearKv`,
  and `TpForwardLogits`, plus a new opaque key type, `TpHandle(u64)`.
- `DeviceWorkerState` gains `nccl: NcclState` and
  `tp_models: HashMap<TpHandle, Box<TpLeaderModel>>` (+ counter).
- `WorkerPool` loses its `leader_nccl` field; gains a
  `leader_worker: Arc<DeviceWorkerHandle>` passed at construction.
  `init_nccl`, `nccl_sanity_check`, `load_dense_shard`,
  `generate_step`, `clear_kv_cache` all route their leader-side ops
  through `Job::Nccl*` / `Job::Tp*` instead of spawn_blocking against
  a Mutex-wrapped state. `generate_step` now returns `Vec<f32>`
  instead of a device-resident `Tensor`: the worker copies the logits
  to the CPU before replying, so the async caller can sample from a
  CPU candle tensor without touching the device context.
- `TpLoadedModel.leader_model: Arc<Mutex<TpLeaderModel>>` becomes an
  opaque `leader_handle: TpHandle`. The boxed `TpLeaderModel` lives in
  the worker thread's slab, so both the model's CUDA tensors and the
  embedded `Arc<Comm>` clones are released on the same thread that
  allocated them (a constraint imposed by cudarc's Drop semantics).
- `Job::CloneLeaderComm` is a Phase 3 bridge: the TP shard load still
  runs in spawn_blocking and needs the leader's `Arc<Comm>` to build
  the row-parallel layers' AllReduce ops. The Job clones the Comm
  out of the worker's NcclState and ships it back as `SendComm`.
  Phase 4 deletes this bridge when the load itself moves onto the
  worker.
- `Job::NcclInit` and `Job::NcclSanity` are not gated on the `cuda`
  feature, so the no-cuda `NcclState` stubs (which reply with
  `cuda_feature_not_enabled`) still flow through the same channel;
  the cuda-only TP variants (`CloneLeaderComm`, `TransferInTp`,
  `DropTp`, `TpClearKv`, `TpForwardLogits`) remain gated.
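
Because `generate_step` now hands back a plain `Vec<f32>`, everything after the forward step can work on CPU memory. The sketch below illustrates that idea under stated assumptions: it is not the project's `sample_with_penalty` or `logits_health_slice` (neither implementation appears in this commit message), it uses no candle types, and it swaps in simple greedy argmax for the real sampler. `LogitsSummary`, `summarize_logits`, and `greedy_token` are hypothetical names.

```rust
/// Hypothetical numerical summary of a logits slice.
#[derive(Debug, PartialEq)]
struct LogitsSummary {
    len: usize,
    nan: usize,
    inf: usize,
    max_finite: Option<f32>,
}

/// Counts NaN and infinite entries and tracks the largest finite value.
fn summarize_logits(logits: &[f32]) -> LogitsSummary {
    let mut s = LogitsSummary { len: logits.len(), nan: 0, inf: 0, max_finite: None };
    for &x in logits {
        if x.is_nan() {
            s.nan += 1;
        } else if x.is_infinite() {
            s.inf += 1;
        } else {
            s.max_finite = Some(s.max_finite.map_or(x, |m| m.max(x)));
        }
    }
    s
}

/// Greedy pick: index of the largest finite logit, or `None` if none is finite.
fn greedy_token(logits: &[f32]) -> Option<u32> {
    let mut best: Option<(usize, f32)> = None;
    for (i, &x) in logits.iter().enumerate() {
        if x.is_finite() && best.map_or(true, |(_, b)| x > b) {
            best = Some((i, x));
        }
    }
    best.map(|(i, _)| i as u32)
}

fn main() {
    // Stand-in for the `Vec<f32>` a worker would reply with.
    let logits = vec![0.1_f32, 2.5, f32::NAN, -1.0];
    println!("{:?}", summarize_logits(&logits));
    println!("{:?}", greedy_token(&logits));
}
```

Both functions take `&[f32]`, so they can run on the async caller's thread with no device handle in scope.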

What this phase doesn't touch (yet):

- TP shard load itself — still spawn_blocking, bridged via
  `CloneLeaderComm`. Phase 4 moves it to `Job::TpLoadShard` and
  reads `state.nccl.comm()` directly inside the worker.
- Single-GPU model loads — still spawn_blocking, transferred via
  `Job::TransferIn`. Phase 4 moves them.
- `device_vram_mb` / `cuda_mem_mb` / `log_construction_complete`
  helpers — still present, used inside spawn_blocking load closures.
  Phase 4 cleanup folds them into `dispatch.rs`.

`tp/mod.rs::WorkerPool::spawn` gained a required
`leader_worker: Arc<DeviceWorkerHandle>` argument. External callers
were updated in three places: `CandleHarness::load_tp` (passes the
cached device worker), `main.rs::tp_smoke` (spawns a fresh worker),
and the two `tp_worker_lifecycle*.rs` integration tests.
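
The "cached device worker" idea can be illustrated with a small get-or-create cache keyed by device index. This is a sketch under assumptions, not the project's `ensure_device_worker`: `WorkerHandle`, `WorkerCache`, `Pool`, and `Loaded` are hypothetical stand-ins, and no thread is actually spawned. It only shows that the pool and the loaded model can share a single `Arc` to the same worker.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

/// Hypothetical stand-in for a handle to a per-device worker thread.
struct WorkerHandle {
    device: u32,
}

/// Get-or-create cache: at most one worker handle per device index.
#[derive(Default)]
struct WorkerCache {
    workers: Mutex<HashMap<u32, Arc<WorkerHandle>>>,
}

impl WorkerCache {
    fn ensure(&self, device: u32) -> Arc<WorkerHandle> {
        let mut map = self.workers.lock().unwrap();
        map.entry(device)
            .or_insert_with(|| Arc::new(WorkerHandle { device }))
            .clone()
    }
}

/// Stand-ins for the two owners that must reference the same worker.
struct Pool {
    leader_worker: Arc<WorkerHandle>,
}
struct Loaded {
    worker: Arc<WorkerHandle>,
}

fn main() {
    let cache = WorkerCache::default();
    let leader = cache.ensure(0);
    let pool = Pool { leader_worker: leader.clone() };
    let loaded = Loaded { worker: leader };
    println!(
        "device {} shared: {}",
        loaded.worker.device,
        Arc::ptr_eq(&pool.leader_worker, &loaded.worker)
    );
}
```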

Public API unchanged. fmt + clippy clean; 37 lib tests + all
integration tests pass. CUDA-only TP integration smoke deferred to
the next deploy on beast.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-27 10:16:02 +03:00
parent b179204fd3
commit 76ab24d98c
8 changed files with 676 additions and 138 deletions

@@ -163,16 +163,22 @@ pub struct TpLoadedModel {
pub model_id: String,
pub tokenizer: Tokenizer,
pub devices: Vec<u32>,
-/// One end-to-end gate: the pool's RPC stream isn't safe to use
-/// concurrently and the leader shard's KV cache mutates with every
-/// step. The same Mutex covers both for the simplest correctness
-/// story.
+/// One end-to-end gate: the pool's RPC stream to the subprocess
+/// workers isn't safe to use concurrently. After Phase 3 the
+/// leader's `TpLeaderModel` lives in the worker thread's slab,
+/// so this Mutex no longer covers the leader's KV cache; it just
+/// serialises subprocess RPC traffic on the pool's
+/// `Vec<Worker>` channels.
 pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
-pub leader_model: Arc<tokio::sync::Mutex<super::tp::TpLeaderModel>>,
-/// Candle device for rank 0. Mirrors what `leader_model.device()`
-/// would return, but stored separately so the request path can
-/// query VRAM without locking the leader (which would contend with
-/// the in-flight forward).
+/// Handle into the leader device worker's TP slab. The boxed
+/// `TpLeaderModel` (with its embedded `Arc<Comm>` clones and
+/// per-rank CUDA tensors) lives on the worker thread; we hold an
+/// opaque index. Forward / clear_kv / unload all route through
+/// `Job::Tp*` against this handle.
+pub leader_handle: super::device_worker::TpHandle,
+/// Candle device for rank 0. Mirrors what
+/// `TpLeaderModel::device()` would return, kept on the struct so
+/// the request path can name the device without an RPC.
pub leader_device: Device,
/// Same poisoning gate as [`LoadedModel::poisoned`]. A TP forward
/// failure (CUDA OOM on any rank, NCCL desync, illegal address) is
@@ -180,9 +186,8 @@ pub struct TpLoadedModel {
/// reliably reset without restarting the worker subprocesses.
pub poisoned: AtomicBool,
/// Worker thread for the leader's CUDA device. Owns the leader's
-/// `CudaContext` for the daemon's lifetime. VRAM queries route
-/// through it; in later refactor phases the forward, kv-cache
-/// clear, and shard unload route through it too.
+/// `CudaContext`, `NcclState`, and the boxed `TpLeaderModel`
+/// referenced by `leader_handle`.
pub worker: Arc<super::device_worker::DeviceWorkerHandle>,
}
@@ -1642,6 +1647,16 @@ impl Harness for CandleHarness {
anyhow::bail!("cannot unload '{model_id}': inference still in flight");
}
};
+// Drop the leader's TpLeaderModel on the device worker
+// thread (CUDA tensors and Arc<Comm> clones release on
+// the same OS thread that allocated them).
+if let Err(e) = tp.worker.drop_tp(tp.leader_handle).await {
+tracing::warn!(
+model = %model_id,
+error = %e,
+"TP unload: DropTp RPC failed (leader model may leak in worker slab)"
+);
+}
let mut pool = tp.pool.into_inner();
if let Err(e) = pool.unload_model(model_id).await {
tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed");
@@ -1715,9 +1730,14 @@ impl CandleHarness {
// 2. Spawn the worker pool. Rank 0 stays in-process; ranks
// 1..tp_size are subprocesses, one per device after the
-// leader's own.
+// leader's own. The leader's device worker thread is
+// spawned (or reused) here and passed into the pool so
+// `init_nccl`, the load, every TP forward, and KV-cache
+// clears all dispatch from the same OS thread.
 let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
-let mut pool = super::tp::WorkerPool::spawn(&exe, tp_size, &devices).await?;
+let leader_worker = self.ensure_device_worker(devices[0]).await?;
+let mut pool =
+super::tp::WorkerPool::spawn(&exe, tp_size, &devices, leader_worker.clone()).await?;
// 3. NCCL handshake across all ranks.
let leader_device_idx = devices[0];
@@ -1727,8 +1747,11 @@ impl CandleHarness {
let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize)
.context("Device::new_cuda for TP leader")?;
-// 5. Load this rank's shard on every rank.
-let leader_model = pool
+// 5. Load this rank's shard on every rank. After Phase 3
+// `load_dense_shard` transfers the freshly-built
+// `TpLeaderModel` into the device worker's TP slab and
+// returns the resulting handle.
+let leader_handle = pool
.load_dense_shard(
&spec.model_id,
&config_json,
@@ -1743,21 +1766,18 @@ impl CandleHarness {
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
-// 7. Worker thread for the leader's CUDA device. TP always
-// runs on CUDA — the harness rejects TP without the cuda
-// feature earlier in this function — so we always have a
-// device to own.
-let worker = self.ensure_device_worker(devices[0]).await?;
 let tp_loaded = StdArc::new(TpLoadedModel {
 model_id: spec.model_id.clone(),
 tokenizer,
 devices: devices.clone(),
 pool: TMutex::new(pool),
-leader_model,
+leader_handle,
 leader_device: leader_device.clone(),
 poisoned: AtomicBool::new(false),
-worker,
+// Same `leader_worker` we passed into the pool above —
+// single `Arc` shared between WorkerPool and
+// TpLoadedModel so they reference the same thread.
+worker: leader_worker,
});
let mut models = self.models.write().await;
@@ -1932,14 +1952,14 @@ impl CandleHarness {
async move {
let mut failure: Option<String> = None;
let mut pool = acquire_pool_lock(&tp_for_task.pool, &model_id).await;
-let leader_arc = tp_for_task.leader_model.clone();
+let leader_handle = tp_for_task.leader_handle;
 let mut all_tokens: Vec<u32> = Vec::new();
 let mut decoded_prefix = String::new();
 let mut finish_reason = "length".to_string();
 'work: {
-if let Err(e) = pool.clear_kv_cache(&model_id, leader_arc.clone()).await {
+if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await {
failure = Some(format!("clear_kv_cache: {e:#}"));
break 'work;
}
@@ -1957,8 +1977,8 @@ impl CandleHarness {
};
// Prefill — every rank embeds the prompt, offset = 0.
-let logits = match pool
-.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
+let logits_vec = match pool
+.generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0)
.await
{
Ok(l) => l,
@@ -1974,11 +1994,18 @@ impl CandleHarness {
vram_free_mb = post_prefill_vram_free_mb,
"TP chat_completion (stream): prefill complete"
);
+let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) {
+Ok(t) => t,
+Err(e) => {
+failure = Some(format!("prefill build cpu logits: {e:#}"));
+break 'work;
+}
+};
 let mut next_token =
 match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
 Ok(t) => t,
 Err(e) => {
-let health = logits_health(&logits);
+let health = logits_health_slice(&logits_vec);
tracing::warn!(
model = %model_id,
?health,
@@ -2010,10 +2037,10 @@ impl CandleHarness {
}
for index in 0..max_new.saturating_sub(1) {
-let logits = match pool
+let logits_vec = match pool
 .generate_step(
 &model_id,
-leader_arc.clone(),
+leader_handle,
vec![next_token],
prompt_len + index,
)
@@ -2025,6 +2052,14 @@ impl CandleHarness {
break 'work;
}
};
+let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) {
+Ok(t) => t,
+Err(e) => {
+failure =
+Some(format!("decode build cpu logits {index}: {e:#}"));
+break 'work;
+}
+};
next_token = match sample_with_penalty(
&logits,
&all_tokens,
@@ -2032,7 +2067,7 @@ impl CandleHarness {
) {
Ok(t) => t,
Err(e) => {
-let health = logits_health(&logits);
+let health = logits_health_slice(&logits_vec);
tracing::warn!(
model = %model_id,
step = index,
@@ -2180,20 +2215,19 @@ async fn chat_completion_tp_inner(
"TP chat_completion: starting"
);
-// Acquire the pool lock for the duration of the request. The
-// leader_model's own Mutex is acquired step-by-step inside
-// pool.generate_step (so spawn_blocking can grab it without
-// holding the pool lock across the blocking_lock call).
-// `acquire_pool_lock` warns periodically while we wait so a
-// stuck holder doesn't make the queueing requests look like
-// silence in the journal.
+// Acquire the pool lock for the duration of the request. After
+// Phase 3 the leader's TpLeaderModel lives in the device worker
+// thread, so the pool lock now serialises only subprocess RPC
+// traffic — but holding it for the whole request still keeps
+// concurrent chat_completions against the same TP model from
+// interleaving prefill/decode jobs.
 let mut pool = acquire_pool_lock(&tp.pool, &model_id).await;
-let leader_arc = tp.leader_model.clone();
+let leader_handle = tp.leader_handle;
 // Reset every rank's KV cache so this request doesn't attend
 // over the previous request's tokens.
 let clear_start = std::time::Instant::now();
-pool.clear_kv_cache(&model_id, leader_arc.clone())
+pool.clear_kv_cache(&model_id, leader_handle)
.await
.map_err(InferenceError::Other)?;
tracing::debug!(
@@ -2219,8 +2253,8 @@ async fn chat_completion_tp_inner(
// Prefill: every rank embeds the whole prompt, offset = 0.
let prefill_start = std::time::Instant::now();
-let logits = pool
-.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
+let logits_vec = pool
+.generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0)
.await
.map_err(InferenceError::Other)?;
let (post_prefill_vram_free_mb, _) = tp.query_vram().await;
@@ -2231,6 +2265,11 @@ async fn chat_completion_tp_inner(
vram_free_mb = post_prefill_vram_free_mb,
"TP chat_completion: prefill complete"
);
+// Wrap the CPU-side logits in a CPU candle Tensor for sampling.
+// No device touch on the async caller's thread — sampling reads
+// from CPU memory only.
+let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)
+.map_err(|e| InferenceError::Other(anyhow::anyhow!("build cpu logits: {e}")))?;
let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
@@ -2239,7 +2278,7 @@ async fn chat_completion_tp_inner(
// this WARN sits just above that and carries the actual
// numerical state so an operator can tell at a glance
// whether it was a NaN cascade, an Inf, or something else.
-let health = logits_health(&logits);
+let health = logits_health_slice(&logits_vec);
tracing::warn!(
model = %model_id,
?health,
@@ -2256,19 +2295,22 @@ async fn chat_completion_tp_inner(
let decode_start = std::time::Instant::now();
for index in 0..max_new.saturating_sub(1) {
let step_start = std::time::Instant::now();
-let logits = pool
+let logits_vec = pool
 .generate_step(
 &model_id,
-leader_arc.clone(),
+leader_handle,
 vec![next_token],
 prompt_len + index,
 )
 .await
 .map_err(InferenceError::Other)?;
+let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu).map_err(|e| {
+InferenceError::Other(anyhow::anyhow!("build cpu logits step {index}: {e}"))
+})?;
 next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
 Ok(t) => t,
 Err(e) => {
-let health = logits_health(&logits);
+let health = logits_health_slice(&logits_vec);
tracing::warn!(
model = %model_id,
step = index,