refactor(neuron): phase 3 — TP forward + NCCL state move onto device worker
Some checks failed
CI / Format (push) Successful in 29s
build-prerelease / Resolve version stamps (push) Successful in 32s
CI / Test (push) Failing after 58s
CI / Clippy (push) Successful in 2m31s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m13s
build-prerelease / Build neuron-blackwell (push) Successful in 4m1s
build-prerelease / Package cortex RPM (push) Successful in 1m30s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
Some checks failed
CI / Format (push) Successful in 29s
build-prerelease / Resolve version stamps (push) Successful in 32s
CI / Test (push) Failing after 58s
CI / Clippy (push) Successful in 2m31s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m13s
build-prerelease / Build neuron-blackwell (push) Successful in 4m1s
build-prerelease / Package cortex RPM (push) Successful in 1m30s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
Third slice of the per-device CUDA context-ownership refactor planned at ~/.claude/plans/plan-the-per-device-worker-abstract-micali.md. The leader's `NcclState`, every `Comm::all_reduce` issued by the TP layers, the leader-side KV cache reset, and the TP forward step itself now all run on the per-device worker thread — the same OS thread that bound the leader's `CudaContext` at startup. What this phase changes: - `Job` gains `NcclInit`, `NcclSanity`, `CloneLeaderComm` (Phase 3 bridge — Phase 4 removes), `TransferInTp`, `DropTp`, `TpClearKv`, `TpForwardLogits`. Plus a new `TpHandle(u64)` opaque key. - `DeviceWorkerState` gains `nccl: NcclState` and `tp_models: HashMap<TpHandle, Box<TpLeaderModel>>` (+ counter). - `WorkerPool` loses its `leader_nccl` field; gains a `leader_worker: Arc<DeviceWorkerHandle>` passed at construction. `init_nccl`, `nccl_sanity_check`, `load_dense_shard`, `generate_step`, `clear_kv_cache` all route their leader-side ops through `Job::Nccl*` / `Job::Tp*` instead of spawn_blocking against a Mutex-wrapped state. `generate_step` returns `Vec<f32>` instead of a device-resident `Tensor` — the worker copies logits to CPU before reply so the async caller can sample on a CPU candle tensor with zero device-context touch. - `TpLoadedModel.leader_model: Arc<Mutex<TpLeaderModel>>` → opaque `leader_handle: TpHandle`. The boxed `TpLeaderModel` lives in the worker thread's slab; both the model's CUDA tensors and the embedded `Arc<Comm>` clones release on the same thread that allocated them (the Drop semantics constraint cudarc forces). - `Job::CloneLeaderComm` is a Phase 3 bridge: the TP shard load still runs in spawn_blocking and needs the leader's `Arc<Comm>` to build the row-parallel layers' AllReduce ops. The Job clones the Comm out of the worker's NcclState and ships it back as `SendComm`. Phase 4 deletes this bridge when the load itself moves onto the worker. - `Job::NcclInit` and `Job::NcclSanity` are ungated by `cuda` so the no-cuda `NcclState` stubs (which reply with `cuda_feature_not_enabled`) still flow through the same channel uniformly; the cuda-only TP variants (CloneLeaderComm, Transfer/Drop/Clear/Forward Tp) remain gated. What this phase doesn't touch (yet): - TP shard load itself — still spawn_blocking, bridged via `CloneLeaderComm`. Phase 4 moves it to `Job::TpLoadShard` and reads `state.nccl.comm()` directly inside the worker. - Single-GPU model loads — still spawn_blocking, transferred via `Job::TransferIn`. Phase 4 moves them. - `device_vram_mb` / `cuda_mem_mb` / `log_construction_complete` helpers — still present, used inside spawn_blocking load closures. Phase 4 cleanup folds them into `dispatch.rs`. `tp/mod.rs::WorkerPool::spawn` gained a required `leader_worker: Arc<DeviceWorkerHandle>` argument. Three external callers were updated: `CandleHarness::load_tp` (passes the cached device worker), `main.rs::tp_smoke` (spawns a fresh worker), and the two `tp_worker_lifecycle*.rs` integration tests. Public API unchanged. fmt + clippy clean; 37 lib tests + all integration tests pass. CUDA-only TP integration smoke deferred to the next deploy on beast. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -163,16 +163,22 @@ pub struct TpLoadedModel {
|
||||
pub model_id: String,
|
||||
pub tokenizer: Tokenizer,
|
||||
pub devices: Vec<u32>,
|
||||
/// One end-to-end gate: the pool's RPC stream isn't safe to use
|
||||
/// concurrently and the leader shard's KV cache mutates with every
|
||||
/// step. The same Mutex covers both for the simplest correctness
|
||||
/// story.
|
||||
/// One end-to-end gate: the pool's RPC stream to the subprocess
|
||||
/// workers isn't safe to use concurrently. After Phase 3 the
|
||||
/// leader's `TpLeaderModel` lives in the worker thread's slab,
|
||||
/// so this Mutex no longer covers the leader's KV cache; it just
|
||||
/// serialises subprocess RPC traffic on the pool's
|
||||
/// `Vec<Worker>` channels.
|
||||
pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
|
||||
pub leader_model: Arc<tokio::sync::Mutex<super::tp::TpLeaderModel>>,
|
||||
/// Candle device for rank 0. Mirrors what `leader_model.device()`
|
||||
/// would return, but stored separately so the request path can
|
||||
/// query VRAM without locking the leader (which would contend with
|
||||
/// the in-flight forward).
|
||||
/// Handle into the leader device worker's TP slab. The boxed
|
||||
/// `TpLeaderModel` (with its embedded `Arc<Comm>` clones and
|
||||
/// per-rank CUDA tensors) lives on the worker thread; we hold an
|
||||
/// opaque index. Forward / clear_kv / unload all route through
|
||||
/// `Job::Tp*` against this handle.
|
||||
pub leader_handle: super::device_worker::TpHandle,
|
||||
/// Candle device for rank 0. Mirrors what
|
||||
/// `TpLeaderModel::device()` would return, kept on the struct so
|
||||
/// the request path can name the device without an RPC.
|
||||
pub leader_device: Device,
|
||||
/// Same poisoning gate as [`LoadedModel::poisoned`]. A TP forward
|
||||
/// failure (CUDA OOM on any rank, NCCL desync, illegal address) is
|
||||
@@ -180,9 +186,8 @@ pub struct TpLoadedModel {
|
||||
/// reliably reset without restarting the worker subprocesses.
|
||||
pub poisoned: AtomicBool,
|
||||
/// Worker thread for the leader's CUDA device. Owns the leader's
|
||||
/// `CudaContext` for the daemon's lifetime. VRAM queries route
|
||||
/// through it; in later refactor phases the forward, kv-cache
|
||||
/// clear, and shard unload route through it too.
|
||||
/// `CudaContext`, `NcclState`, and the boxed `TpLeaderModel`
|
||||
/// referenced by `leader_handle`.
|
||||
pub worker: Arc<super::device_worker::DeviceWorkerHandle>,
|
||||
}
|
||||
|
||||
@@ -1642,6 +1647,16 @@ impl Harness for CandleHarness {
|
||||
anyhow::bail!("cannot unload '{model_id}': inference still in flight");
|
||||
}
|
||||
};
|
||||
// Drop the leader's TpLeaderModel on the device worker
|
||||
// thread (CUDA tensors and Arc<Comm> clones release on
|
||||
// the same OS thread that allocated them).
|
||||
if let Err(e) = tp.worker.drop_tp(tp.leader_handle).await {
|
||||
tracing::warn!(
|
||||
model = %model_id,
|
||||
error = %e,
|
||||
"TP unload: DropTp RPC failed (leader model may leak in worker slab)"
|
||||
);
|
||||
}
|
||||
let mut pool = tp.pool.into_inner();
|
||||
if let Err(e) = pool.unload_model(model_id).await {
|
||||
tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed");
|
||||
@@ -1715,9 +1730,14 @@ impl CandleHarness {
|
||||
|
||||
// 2. Spawn the worker pool. Rank 0 stays in-process; ranks
|
||||
// 1..tp_size are subprocesses, one per device after the
|
||||
// leader's own.
|
||||
// leader's own. The leader's device worker thread is
|
||||
// spawned (or reused) here and passed into the pool so
|
||||
// `init_nccl`, the load, every TP forward, and KV-cache
|
||||
// clears all dispatch from the same OS thread.
|
||||
let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
|
||||
let mut pool = super::tp::WorkerPool::spawn(&exe, tp_size, &devices).await?;
|
||||
let leader_worker = self.ensure_device_worker(devices[0]).await?;
|
||||
let mut pool =
|
||||
super::tp::WorkerPool::spawn(&exe, tp_size, &devices, leader_worker.clone()).await?;
|
||||
|
||||
// 3. NCCL handshake across all ranks.
|
||||
let leader_device_idx = devices[0];
|
||||
@@ -1727,8 +1747,11 @@ impl CandleHarness {
|
||||
let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize)
|
||||
.context("Device::new_cuda for TP leader")?;
|
||||
|
||||
// 5. Load this rank's shard on every rank.
|
||||
let leader_model = pool
|
||||
// 5. Load this rank's shard on every rank. After Phase 3
|
||||
// `load_dense_shard` transfers the freshly-built
|
||||
// `TpLeaderModel` into the device worker's TP slab and
|
||||
// returns the resulting handle.
|
||||
let leader_handle = pool
|
||||
.load_dense_shard(
|
||||
&spec.model_id,
|
||||
&config_json,
|
||||
@@ -1743,21 +1766,18 @@ impl CandleHarness {
|
||||
let tokenizer = Tokenizer::from_file(&tokenizer_path)
|
||||
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
|
||||
|
||||
// 7. Worker thread for the leader's CUDA device. TP always
|
||||
// runs on CUDA — the harness rejects TP without the cuda
|
||||
// feature earlier in this function — so we always have a
|
||||
// device to own.
|
||||
let worker = self.ensure_device_worker(devices[0]).await?;
|
||||
|
||||
let tp_loaded = StdArc::new(TpLoadedModel {
|
||||
model_id: spec.model_id.clone(),
|
||||
tokenizer,
|
||||
devices: devices.clone(),
|
||||
pool: TMutex::new(pool),
|
||||
leader_model,
|
||||
leader_handle,
|
||||
leader_device: leader_device.clone(),
|
||||
poisoned: AtomicBool::new(false),
|
||||
worker,
|
||||
// Same `leader_worker` we passed into the pool above —
|
||||
// single `Arc` shared between WorkerPool and
|
||||
// TpLoadedModel so they reference the same thread.
|
||||
worker: leader_worker,
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
@@ -1932,14 +1952,14 @@ impl CandleHarness {
|
||||
async move {
|
||||
let mut failure: Option<String> = None;
|
||||
let mut pool = acquire_pool_lock(&tp_for_task.pool, &model_id).await;
|
||||
let leader_arc = tp_for_task.leader_model.clone();
|
||||
let leader_handle = tp_for_task.leader_handle;
|
||||
|
||||
let mut all_tokens: Vec<u32> = Vec::new();
|
||||
let mut decoded_prefix = String::new();
|
||||
let mut finish_reason = "length".to_string();
|
||||
|
||||
'work: {
|
||||
if let Err(e) = pool.clear_kv_cache(&model_id, leader_arc.clone()).await {
|
||||
if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await {
|
||||
failure = Some(format!("clear_kv_cache: {e:#}"));
|
||||
break 'work;
|
||||
}
|
||||
@@ -1957,8 +1977,8 @@ impl CandleHarness {
|
||||
};
|
||||
|
||||
// Prefill — every rank embeds the prompt, offset = 0.
|
||||
let logits = match pool
|
||||
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
|
||||
let logits_vec = match pool
|
||||
.generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0)
|
||||
.await
|
||||
{
|
||||
Ok(l) => l,
|
||||
@@ -1974,11 +1994,18 @@ impl CandleHarness {
|
||||
vram_free_mb = post_prefill_vram_free_mb,
|
||||
"TP chat_completion (stream): prefill complete"
|
||||
);
|
||||
let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
failure = Some(format!("prefill build cpu logits: {e:#}"));
|
||||
break 'work;
|
||||
}
|
||||
};
|
||||
let mut next_token =
|
||||
match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let health = logits_health(&logits);
|
||||
let health = logits_health_slice(&logits_vec);
|
||||
tracing::warn!(
|
||||
model = %model_id,
|
||||
?health,
|
||||
@@ -2010,10 +2037,10 @@ impl CandleHarness {
|
||||
}
|
||||
|
||||
for index in 0..max_new.saturating_sub(1) {
|
||||
let logits = match pool
|
||||
let logits_vec = match pool
|
||||
.generate_step(
|
||||
&model_id,
|
||||
leader_arc.clone(),
|
||||
leader_handle,
|
||||
vec![next_token],
|
||||
prompt_len + index,
|
||||
)
|
||||
@@ -2025,6 +2052,14 @@ impl CandleHarness {
|
||||
break 'work;
|
||||
}
|
||||
};
|
||||
let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
failure =
|
||||
Some(format!("decode build cpu logits {index}: {e:#}"));
|
||||
break 'work;
|
||||
}
|
||||
};
|
||||
next_token = match sample_with_penalty(
|
||||
&logits,
|
||||
&all_tokens,
|
||||
@@ -2032,7 +2067,7 @@ impl CandleHarness {
|
||||
) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let health = logits_health(&logits);
|
||||
let health = logits_health_slice(&logits_vec);
|
||||
tracing::warn!(
|
||||
model = %model_id,
|
||||
step = index,
|
||||
@@ -2180,20 +2215,19 @@ async fn chat_completion_tp_inner(
|
||||
"TP chat_completion: starting"
|
||||
);
|
||||
|
||||
// Acquire the pool lock for the duration of the request. The
|
||||
// leader_model's own Mutex is acquired step-by-step inside
|
||||
// pool.generate_step (so spawn_blocking can grab it without
|
||||
// holding the pool lock across the blocking_lock call).
|
||||
// `acquire_pool_lock` warns periodically while we wait so a
|
||||
// stuck holder doesn't make the queueing requests look like
|
||||
// silence in the journal.
|
||||
// Acquire the pool lock for the duration of the request. After
|
||||
// Phase 3 the leader's TpLeaderModel lives in the device worker
|
||||
// thread, so the pool lock now serialises only subprocess RPC
|
||||
// traffic — but holding it for the whole request still keeps
|
||||
// concurrent chat_completions against the same TP model from
|
||||
// interleaving prefill/decode jobs.
|
||||
let mut pool = acquire_pool_lock(&tp.pool, &model_id).await;
|
||||
let leader_arc = tp.leader_model.clone();
|
||||
let leader_handle = tp.leader_handle;
|
||||
|
||||
// Reset every rank's KV cache so this request doesn't attend
|
||||
// over the previous request's tokens.
|
||||
let clear_start = std::time::Instant::now();
|
||||
pool.clear_kv_cache(&model_id, leader_arc.clone())
|
||||
pool.clear_kv_cache(&model_id, leader_handle)
|
||||
.await
|
||||
.map_err(InferenceError::Other)?;
|
||||
tracing::debug!(
|
||||
@@ -2219,8 +2253,8 @@ async fn chat_completion_tp_inner(
|
||||
|
||||
// Prefill: every rank embeds the whole prompt, offset = 0.
|
||||
let prefill_start = std::time::Instant::now();
|
||||
let logits = pool
|
||||
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
|
||||
let logits_vec = pool
|
||||
.generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0)
|
||||
.await
|
||||
.map_err(InferenceError::Other)?;
|
||||
let (post_prefill_vram_free_mb, _) = tp.query_vram().await;
|
||||
@@ -2231,6 +2265,11 @@ async fn chat_completion_tp_inner(
|
||||
vram_free_mb = post_prefill_vram_free_mb,
|
||||
"TP chat_completion: prefill complete"
|
||||
);
|
||||
// Wrap the CPU-side logits in a CPU candle Tensor for sampling.
|
||||
// No device touch on the async caller's thread — sampling reads
|
||||
// from CPU memory only.
|
||||
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)
|
||||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("build cpu logits: {e}")))?;
|
||||
let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
@@ -2239,7 +2278,7 @@ async fn chat_completion_tp_inner(
|
||||
// this WARN sits just above that and carries the actual
|
||||
// numerical state so an operator can tell at a glance
|
||||
// whether it was a NaN cascade, an Inf, or something else.
|
||||
let health = logits_health(&logits);
|
||||
let health = logits_health_slice(&logits_vec);
|
||||
tracing::warn!(
|
||||
model = %model_id,
|
||||
?health,
|
||||
@@ -2256,19 +2295,22 @@ async fn chat_completion_tp_inner(
|
||||
let decode_start = std::time::Instant::now();
|
||||
for index in 0..max_new.saturating_sub(1) {
|
||||
let step_start = std::time::Instant::now();
|
||||
let logits = pool
|
||||
let logits_vec = pool
|
||||
.generate_step(
|
||||
&model_id,
|
||||
leader_arc.clone(),
|
||||
leader_handle,
|
||||
vec![next_token],
|
||||
prompt_len + index,
|
||||
)
|
||||
.await
|
||||
.map_err(InferenceError::Other)?;
|
||||
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu).map_err(|e| {
|
||||
InferenceError::Other(anyhow::anyhow!("build cpu logits step {index}: {e}"))
|
||||
})?;
|
||||
next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let health = logits_health(&logits);
|
||||
let health = logits_health_slice(&logits_vec);
|
||||
tracing::warn!(
|
||||
model = %model_id,
|
||||
step = index,
|
||||
|
||||
Reference in New Issue
Block a user