refactor(neuron): phase 3 — TP forward + NCCL state move onto device worker
Some checks failed
CI / Format (push) Successful in 29s
build-prerelease / Resolve version stamps (push) Successful in 32s
CI / Test (push) Failing after 58s
CI / Clippy (push) Successful in 2m31s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m13s
build-prerelease / Build neuron-blackwell (push) Successful in 4m1s
build-prerelease / Package cortex RPM (push) Successful in 1m30s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
Third slice of the per-device CUDA context-ownership refactor planned at ~/.claude/plans/plan-the-per-device-worker-abstract-micali.md. The leader's `NcclState`, every `Comm::all_reduce` issued by the TP layers, the leader-side KV cache reset, and the TP forward step itself now all run on the per-device worker thread, the same OS thread that bound the leader's `CudaContext` at startup.

What this phase changes:

- `Job` gains `NcclInit`, `NcclSanity`, `CloneLeaderComm` (Phase 3 bridge; Phase 4 removes it), `TransferInTp`, `DropTp`, `TpClearKv`, `TpForwardLogits`, plus a new `TpHandle(u64)` opaque key.
- `DeviceWorkerState` gains `nccl: NcclState` and `tp_models: HashMap<TpHandle, Box<TpLeaderModel>>` (+ counter).
- `WorkerPool` loses its `leader_nccl` field; gains a `leader_worker: Arc<DeviceWorkerHandle>` passed at construction. `init_nccl`, `nccl_sanity_check`, `load_dense_shard`, `generate_step`, `clear_kv_cache` all route their leader-side ops through `Job::Nccl*` / `Job::Tp*` instead of spawn_blocking against a Mutex-wrapped state. `generate_step` returns `Vec<f32>` instead of a device-resident `Tensor`: the worker copies logits to CPU before reply so the async caller can sample on a CPU candle tensor with zero device-context touch.
- `TpLoadedModel.leader_model: Arc<Mutex<TpLeaderModel>>` → opaque `leader_handle: TpHandle`. The boxed `TpLeaderModel` lives in the worker thread's slab; both the model's CUDA tensors and the embedded `Arc<Comm>` clones release on the same thread that allocated them (the Drop semantics constraint cudarc forces).
- `Job::CloneLeaderComm` is a Phase 3 bridge: the TP shard load still runs in spawn_blocking and needs the leader's `Arc<Comm>` to build the row-parallel layers' AllReduce ops. The Job clones the Comm out of the worker's NcclState and ships it back as `SendComm`. Phase 4 deletes this bridge when the load itself moves onto the worker.
- `Job::NcclInit` and `Job::NcclSanity` are ungated by `cuda` so the no-cuda `NcclState` stubs (which reply with `cuda_feature_not_enabled`) still flow through the same channel uniformly; the cuda-only TP variants (CloneLeaderComm, Transfer/Drop/Clear/Forward Tp) remain gated.

What this phase doesn't touch (yet):

- TP shard load itself: still spawn_blocking, bridged via `CloneLeaderComm`. Phase 4 moves it to `Job::TpLoadShard` and reads `state.nccl.comm()` directly inside the worker.
- Single-GPU model loads: still spawn_blocking, transferred via `Job::TransferIn`. Phase 4 moves them.
- `device_vram_mb` / `cuda_mem_mb` / `log_construction_complete` helpers: still present, used inside spawn_blocking load closures. Phase 4 cleanup folds them into `dispatch.rs`.

`tp/mod.rs::WorkerPool::spawn` gained a required `leader_worker: Arc<DeviceWorkerHandle>` argument. Three external callers were updated: `CandleHarness::load_tp` (passes the cached device worker), `main.rs::tp_smoke` (spawns a fresh worker), and the two `tp_worker_lifecycle*.rs` integration tests. Public API unchanged.

fmt + clippy clean; 37 lib tests + all integration tests pass. CUDA-only TP integration smoke deferred to the next deploy on beast.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
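
The handle-and-slab arrangement described in the commit message can be illustrated with a minimal, CPU-only sketch. It is not the code from this commit: `Handle`, `ToyModel`, `spawn_worker`, and `demo` are hypothetical names, `std::sync::mpsc` channels stand in for the tokio oneshot replies the real jobs appear to use, and no CUDA or NCCL is involved. It only shows the shape of the idea: one thread owns the models, callers hold an opaque key, forward results come back as a plain `Vec<f32>`, and drops happen on the owning thread.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{self, Sender};
use std::thread;

/// Opaque key into the worker's slab (plays the role `TpHandle(u64)` has in the commit).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Handle(pub u64);

/// Hypothetical stand-in for a device-resident model.
pub struct ToyModel {
    pub bias: f32,
}

impl ToyModel {
    /// Pretend forward pass: returns host-side values only.
    fn forward(&mut self, tokens: &[u32]) -> Vec<f32> {
        tokens.iter().map(|&t| t as f32 + self.bias).collect()
    }
}

/// Jobs sent to the worker thread; each carries its own reply channel.
pub enum Job {
    TransferIn { model: Box<ToyModel>, reply: Sender<Handle> },
    Forward { handle: Handle, tokens: Vec<u32>, reply: Sender<Result<Vec<f32>, String>> },
    DropModel { handle: Handle, reply: Sender<bool> },
    Shutdown,
}

/// Spawns the single thread that owns every model for its whole lifetime.
pub fn spawn_worker() -> (Sender<Job>, thread::JoinHandle<()>) {
    let (tx, rx) = mpsc::channel::<Job>();
    let join = thread::spawn(move || {
        let mut slab: HashMap<Handle, Box<ToyModel>> = HashMap::new();
        let mut next: u64 = 1;
        for job in rx {
            match job {
                Job::TransferIn { model, reply } => {
                    let handle = Handle(next);
                    next = next.wrapping_add(1);
                    slab.insert(handle, model);
                    let _ = reply.send(handle);
                }
                Job::Forward { handle, tokens, reply } => {
                    let result = slab
                        .get_mut(&handle)
                        .map(|m| m.forward(&tokens))
                        .ok_or_else(|| format!("no model for handle {}", handle.0));
                    let _ = reply.send(result);
                }
                Job::DropModel { handle, reply } => {
                    // The removed model is dropped here, on the owning thread.
                    let _ = reply.send(slab.remove(&handle).is_some());
                }
                Job::Shutdown => break,
            }
        }
        // Any models still in the slab are dropped on this thread as it exits.
    });
    (tx, join)
}

/// Walks through transfer-in, forward, drop, and a forward against a stale handle.
pub fn demo() -> (Vec<f32>, bool, bool) {
    let (tx, join) = spawn_worker();

    let (rtx, rrx) = mpsc::channel();
    tx.send(Job::TransferIn { model: Box::new(ToyModel { bias: 0.5 }), reply: rtx }).unwrap();
    let handle = rrx.recv().unwrap();

    let (rtx, rrx) = mpsc::channel();
    tx.send(Job::Forward { handle, tokens: vec![1, 2], reply: rtx }).unwrap();
    let logits = rrx.recv().unwrap().unwrap();

    let (rtx, rrx) = mpsc::channel();
    tx.send(Job::DropModel { handle, reply: rtx }).unwrap();
    let dropped = rrx.recv().unwrap();

    let (rtx, rrx) = mpsc::channel();
    tx.send(Job::Forward { handle, tokens: vec![3], reply: rtx }).unwrap();
    let stale_fails = rrx.recv().unwrap().is_err();

    tx.send(Job::Shutdown).unwrap();
    join.join().unwrap();
    (logits, dropped, stale_fails)
}
```

Because the handle is `Copy` and carries no reference to the model, a caller can hold it across `await` points without holding a lock; the cost is that a stale handle only fails at use time, as the last `Forward` in `demo` shows.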
@@ -163,16 +163,22 @@ pub struct TpLoadedModel {
 pub model_id: String,
 pub tokenizer: Tokenizer,
 pub devices: Vec<u32>,
-/// One end-to-end gate: the pool's RPC stream isn't safe to use
-/// concurrently and the leader shard's KV cache mutates with every
-/// step. The same Mutex covers both for the simplest correctness
-/// story.
+/// One end-to-end gate: the pool's RPC stream to the subprocess
+/// workers isn't safe to use concurrently. After Phase 3 the
+/// leader's `TpLeaderModel` lives in the worker thread's slab,
+/// so this Mutex no longer covers the leader's KV cache; it just
+/// serialises subprocess RPC traffic on the pool's
+/// `Vec<Worker>` channels.
 pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
-pub leader_model: Arc<tokio::sync::Mutex<super::tp::TpLeaderModel>>,
-/// Candle device for rank 0. Mirrors what `leader_model.device()`
-/// would return, but stored separately so the request path can
-/// query VRAM without locking the leader (which would contend with
-/// the in-flight forward).
+/// Handle into the leader device worker's TP slab. The boxed
+/// `TpLeaderModel` (with its embedded `Arc<Comm>` clones and
+/// per-rank CUDA tensors) lives on the worker thread; we hold an
+/// opaque index. Forward / clear_kv / unload all route through
+/// `Job::Tp*` against this handle.
+pub leader_handle: super::device_worker::TpHandle,
+/// Candle device for rank 0. Mirrors what
+/// `TpLeaderModel::device()` would return, kept on the struct so
+/// the request path can name the device without an RPC.
 pub leader_device: Device,
 /// Same poisoning gate as [`LoadedModel::poisoned`]. A TP forward
 /// failure (CUDA OOM on any rank, NCCL desync, illegal address) is
@@ -180,9 +186,8 @@ pub struct TpLoadedModel {
 /// reliably reset without restarting the worker subprocesses.
 pub poisoned: AtomicBool,
 /// Worker thread for the leader's CUDA device. Owns the leader's
-/// `CudaContext` for the daemon's lifetime. VRAM queries route
-/// through it; in later refactor phases the forward, kv-cache
-/// clear, and shard unload route through it too.
+/// `CudaContext`, `NcclState`, and the boxed `TpLeaderModel`
+/// referenced by `leader_handle`.
 pub worker: Arc<super::device_worker::DeviceWorkerHandle>,
 }

@@ -1642,6 +1647,16 @@ impl Harness for CandleHarness {
 anyhow::bail!("cannot unload '{model_id}': inference still in flight");
 }
 };
+// Drop the leader's TpLeaderModel on the device worker
+// thread (CUDA tensors and Arc<Comm> clones release on
+// the same OS thread that allocated them).
+if let Err(e) = tp.worker.drop_tp(tp.leader_handle).await {
+tracing::warn!(
+model = %model_id,
+error = %e,
+"TP unload: DropTp RPC failed (leader model may leak in worker slab)"
+);
+}
 let mut pool = tp.pool.into_inner();
 if let Err(e) = pool.unload_model(model_id).await {
 tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed");
@@ -1715,9 +1730,14 @@ impl CandleHarness {

 // 2. Spawn the worker pool. Rank 0 stays in-process; ranks
 // 1..tp_size are subprocesses, one per device after the
-// leader's own.
+// leader's own. The leader's device worker thread is
+// spawned (or reused) here and passed into the pool so
+// `init_nccl`, the load, every TP forward, and KV-cache
+// clears all dispatch from the same OS thread.
 let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
-let mut pool = super::tp::WorkerPool::spawn(&exe, tp_size, &devices).await?;
+let leader_worker = self.ensure_device_worker(devices[0]).await?;
+let mut pool =
+super::tp::WorkerPool::spawn(&exe, tp_size, &devices, leader_worker.clone()).await?;

 // 3. NCCL handshake across all ranks.
 let leader_device_idx = devices[0];
@@ -1727,8 +1747,11 @@ impl CandleHarness {
 let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize)
 .context("Device::new_cuda for TP leader")?;

-// 5. Load this rank's shard on every rank.
-let leader_model = pool
+// 5. Load this rank's shard on every rank. After Phase 3
+// `load_dense_shard` transfers the freshly-built
+// `TpLeaderModel` into the device worker's TP slab and
+// returns the resulting handle.
+let leader_handle = pool
 .load_dense_shard(
 &spec.model_id,
 &config_json,
@@ -1743,21 +1766,18 @@ impl CandleHarness {
 let tokenizer = Tokenizer::from_file(&tokenizer_path)
 .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;

-// 7. Worker thread for the leader's CUDA device. TP always
-// runs on CUDA — the harness rejects TP without the cuda
-// feature earlier in this function — so we always have a
-// device to own.
-let worker = self.ensure_device_worker(devices[0]).await?;
-
 let tp_loaded = StdArc::new(TpLoadedModel {
 model_id: spec.model_id.clone(),
 tokenizer,
 devices: devices.clone(),
 pool: TMutex::new(pool),
-leader_model,
+leader_handle,
 leader_device: leader_device.clone(),
 poisoned: AtomicBool::new(false),
-worker,
+// Same `leader_worker` we passed into the pool above —
+// single `Arc` shared between WorkerPool and
+// TpLoadedModel so they reference the same thread.
+worker: leader_worker,
 });

 let mut models = self.models.write().await;
@@ -1932,14 +1952,14 @@ impl CandleHarness {
 async move {
 let mut failure: Option<String> = None;
 let mut pool = acquire_pool_lock(&tp_for_task.pool, &model_id).await;
-let leader_arc = tp_for_task.leader_model.clone();
+let leader_handle = tp_for_task.leader_handle;

 let mut all_tokens: Vec<u32> = Vec::new();
 let mut decoded_prefix = String::new();
 let mut finish_reason = "length".to_string();

 'work: {
-if let Err(e) = pool.clear_kv_cache(&model_id, leader_arc.clone()).await {
+if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await {
 failure = Some(format!("clear_kv_cache: {e:#}"));
 break 'work;
 }
@@ -1957,8 +1977,8 @@ impl CandleHarness {
 };

 // Prefill — every rank embeds the prompt, offset = 0.
-let logits = match pool
-.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
+let logits_vec = match pool
+.generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0)
 .await
 {
 Ok(l) => l,
@@ -1974,11 +1994,18 @@ impl CandleHarness {
 vram_free_mb = post_prefill_vram_free_mb,
 "TP chat_completion (stream): prefill complete"
 );
+let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) {
+Ok(t) => t,
+Err(e) => {
+failure = Some(format!("prefill build cpu logits: {e:#}"));
+break 'work;
+}
+};
 let mut next_token =
 match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
 Ok(t) => t,
 Err(e) => {
-let health = logits_health(&logits);
+let health = logits_health_slice(&logits_vec);
 tracing::warn!(
 model = %model_id,
 ?health,
@@ -2010,10 +2037,10 @@ impl CandleHarness {
 }

 for index in 0..max_new.saturating_sub(1) {
-let logits = match pool
+let logits_vec = match pool
 .generate_step(
 &model_id,
-leader_arc.clone(),
+leader_handle,
 vec![next_token],
 prompt_len + index,
 )
@@ -2025,6 +2052,14 @@ impl CandleHarness {
 break 'work;
 }
 };
+let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) {
+Ok(t) => t,
+Err(e) => {
+failure =
+Some(format!("decode build cpu logits {index}: {e:#}"));
+break 'work;
+}
+};
 next_token = match sample_with_penalty(
 &logits,
 &all_tokens,
@@ -2032,7 +2067,7 @@ impl CandleHarness {
 ) {
 Ok(t) => t,
 Err(e) => {
-let health = logits_health(&logits);
+let health = logits_health_slice(&logits_vec);
 tracing::warn!(
 model = %model_id,
 step = index,
@@ -2180,20 +2215,19 @@ async fn chat_completion_tp_inner(
 "TP chat_completion: starting"
 );

-// Acquire the pool lock for the duration of the request. The
-// leader_model's own Mutex is acquired step-by-step inside
-// pool.generate_step (so spawn_blocking can grab it without
-// holding the pool lock across the blocking_lock call).
-// `acquire_pool_lock` warns periodically while we wait so a
-// stuck holder doesn't make the queueing requests look like
-// silence in the journal.
+// Acquire the pool lock for the duration of the request. After
+// Phase 3 the leader's TpLeaderModel lives in the device worker
+// thread, so the pool lock now serialises only subprocess RPC
+// traffic — but holding it for the whole request still keeps
+// concurrent chat_completions against the same TP model from
+// interleaving prefill/decode jobs.
 let mut pool = acquire_pool_lock(&tp.pool, &model_id).await;
-let leader_arc = tp.leader_model.clone();
+let leader_handle = tp.leader_handle;

 // Reset every rank's KV cache so this request doesn't attend
 // over the previous request's tokens.
 let clear_start = std::time::Instant::now();
-pool.clear_kv_cache(&model_id, leader_arc.clone())
+pool.clear_kv_cache(&model_id, leader_handle)
 .await
 .map_err(InferenceError::Other)?;
 tracing::debug!(
@@ -2219,8 +2253,8 @@ async fn chat_completion_tp_inner(

 // Prefill: every rank embeds the whole prompt, offset = 0.
 let prefill_start = std::time::Instant::now();
-let logits = pool
-.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
+let logits_vec = pool
+.generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0)
 .await
 .map_err(InferenceError::Other)?;
 let (post_prefill_vram_free_mb, _) = tp.query_vram().await;
@@ -2231,6 +2265,11 @@ async fn chat_completion_tp_inner(
 vram_free_mb = post_prefill_vram_free_mb,
 "TP chat_completion: prefill complete"
 );
+// Wrap the CPU-side logits in a CPU candle Tensor for sampling.
+// No device touch on the async caller's thread — sampling reads
+// from CPU memory only.
+let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)
+.map_err(|e| InferenceError::Other(anyhow::anyhow!("build cpu logits: {e}")))?;
 let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
 Ok(t) => t,
 Err(e) => {
@@ -2239,7 +2278,7 @@ async fn chat_completion_tp_inner(
 // this WARN sits just above that and carries the actual
 // numerical state so an operator can tell at a glance
 // whether it was a NaN cascade, an Inf, or something else.
-let health = logits_health(&logits);
+let health = logits_health_slice(&logits_vec);
 tracing::warn!(
 model = %model_id,
 ?health,
@@ -2256,19 +2295,22 @@ async fn chat_completion_tp_inner(
 let decode_start = std::time::Instant::now();
 for index in 0..max_new.saturating_sub(1) {
 let step_start = std::time::Instant::now();
-let logits = pool
+let logits_vec = pool
 .generate_step(
 &model_id,
-leader_arc.clone(),
+leader_handle,
 vec![next_token],
 prompt_len + index,
 )
 .await
 .map_err(InferenceError::Other)?;
+let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu).map_err(|e| {
+InferenceError::Other(anyhow::anyhow!("build cpu logits step {index}: {e}"))
+})?;
 next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
 Ok(t) => t,
 Err(e) => {
-let health = logits_health(&logits);
+let health = logits_health_slice(&logits_vec);
 tracing::warn!(
 model = %model_id,
 step = index,
@@ -14,7 +14,12 @@
 //! `tp_models: HashMap<TpHandle, Box<TpLeaderModel>>`.

 use crate::harness::candle::ModelArch;
+#[cfg(feature = "cuda")]
+use crate::harness::device_worker::jobs::TpHandle;
 use crate::harness::device_worker::jobs::{ArchHandle, Job};
+#[cfg(feature = "cuda")]
+use crate::harness::tp::TpLeaderModel;
+use crate::harness::tp::nccl_state::NcclState;
 use std::collections::HashMap;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, Ordering};
@@ -40,6 +45,20 @@ struct DeviceWorkerState {
 /// increments and returns the new value. Wraps at u64::MAX after
 /// ~10^19 model loads — not a practical concern.
 next_handle: u64,
+/// Leader's NCCL state. Populated by `Job::NcclInit`; the
+/// underlying `Comm`'s libnccl handle lives bound to this thread
+/// for its entire lifetime. Subprocess workers maintain their own
+/// `NcclState` in their own processes — that's not visible from
+/// here.
+#[allow(dead_code)] // Read only via methods on NcclState
+nccl: NcclState,
+/// TP leader model slab. Same lifecycle as `models`; separate
+/// namespace so `ArchHandle` and `TpHandle` can't collide.
+#[cfg(feature = "cuda")]
+tp_models: HashMap<TpHandle, Box<TpLeaderModel>>,
+/// Counter for minting fresh `TpHandle`s.
+#[cfg(feature = "cuda")]
+next_tp_handle: u64,
 #[cfg(feature = "cuda")]
 #[allow(dead_code)]
 /// `None` only if `CudaContext::new()` failed — in that case the
@@ -128,15 +147,93 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
 let result = forward_logits(&mut state, handle, &tokens, offset);
 let _ = reply.send(result);
 }
+Job::NcclInit {
+cfg,
+comm_id_hex,
+reply,
+} => {
+let resp = state.nccl.init(cfg, &comm_id_hex);
+let _ = reply.send(resp);
+}
+Job::NcclSanity { reply } => {
+let resp = state.nccl.sanity_check();
+let _ = reply.send(resp);
+}
+#[cfg(feature = "cuda")]
+Job::CloneLeaderComm { reply } => {
+let result = match state.nccl.comm() {
+Some(comm) => Ok(crate::harness::tp::nccl_state::SendComm(comm)),
+None => Err(anyhow::anyhow!(
+"CloneLeaderComm: NcclState has no Comm; call NcclInit first"
+)),
+};
+let _ = reply.send(result);
+}
+#[cfg(feature = "cuda")]
+Job::TransferInTp { model, reply } => {
+let handle = TpHandle(state.next_tp_handle);
+state.next_tp_handle = state.next_tp_handle.wrapping_add(1);
+state.tp_models.insert(handle, model);
+tracing::debug!(
+device_index,
+tp_handle = handle.0,
+slab_size = state.tp_models.len(),
+"device worker: TP model transferred in"
+);
+let _ = reply.send(Ok(handle));
+}
+#[cfg(feature = "cuda")]
+Job::DropTp { handle, reply } => {
+let removed = state.tp_models.remove(&handle);
+let was_present = removed.is_some();
+drop(removed);
+tracing::debug!(
+device_index,
+tp_handle = handle.0,
+was_present,
+slab_size = state.tp_models.len(),
+"device worker: TP model dropped"
+);
+let _ = reply.send(());
+}
+#[cfg(feature = "cuda")]
+Job::TpClearKv { handle, reply } => {
+let result = match state.tp_models.get_mut(&handle) {
+Some(model) => {
+model.clear_kv_cache();
+Ok(())
+}
+None => Err(anyhow::anyhow!(
+"TpClearKv: no TP model for handle {}",
+handle.0
+)),
+};
+let _ = reply.send(result);
+}
+#[cfg(feature = "cuda")]
+Job::TpForwardLogits {
+handle,
+tokens,
+offset,
+reply,
+} => {
+let result = tp_forward_logits(&mut state, handle, &tokens, offset);
+let _ = reply.send(result);
+}
 // Handled by the matches!() check above; reaching here
 // means a Shutdown slipped past which is a bug.
 Job::Shutdown => unreachable!("Shutdown should break above"),
 }
 }

+#[cfg(feature = "cuda")]
+let tp_slab_size = state.tp_models.len();
+#[cfg(not(feature = "cuda"))]
+let tp_slab_size = 0_usize;
 tracing::info!(
 device_index,
 slab_size = state.models.len(),
+tp_slab_size,
 "device worker exiting; dropping remaining models"
 );
 // Drops every model in the slab on this thread before the function
@@ -193,6 +290,9 @@ fn init_state(device_index: u32) -> DeviceWorkerState {
 device,
 models: HashMap::new(),
 next_handle: 1,
+nccl: NcclState::new(),
+tp_models: HashMap::new(),
+next_tp_handle: 1,
 ctx,
 }
 }
@@ -203,6 +303,7 @@ fn init_state(device_index: u32) -> DeviceWorkerState {
 device: candle_core::Device::Cpu,
 models: HashMap::new(),
 next_handle: 1,
+nccl: NcclState::new(),
 }
 }
 }
@@ -231,6 +332,38 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
 Ok((0, 0))
 }

+/// TP-equivalent of [`forward_logits`]: looks up the leader's
+/// [`TpLeaderModel`] in the slab, runs its forward, copies the
+/// `[vocab]` logits to a CPU `Vec<f32>`. The leader's `Arc<Comm>`
+/// clones embedded in the TP layers' AllReduce ops fire from this
+/// thread — same thread that bound the CUDA context and that holds
+/// the `Comm` in `state.nccl`.
+#[cfg(feature = "cuda")]
+fn tp_forward_logits(
+state: &mut DeviceWorkerState,
+handle: TpHandle,
+tokens: &[u32],
+offset: usize,
+) -> anyhow::Result<Vec<f32>> {
+use candle_core::{DType, Tensor};
+
+let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
+
+let model = state
+.tp_models
+.get_mut(&handle)
+.ok_or_else(|| anyhow::anyhow!("TpForwardLogits: no model for handle {}", handle.0))?;
+
+let logits = model.forward(&input, offset)?;
+// ForCausalLM forward returns [B, 1, V] after the trailing
+// .i((.., l - 1.., ..))?.apply(lm_head); squeeze both leading
+// singleton dims to a rank-1 [V] tensor for sampling.
+let logits = logits.squeeze(0)?.squeeze(0)?;
+let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
+let values = logits.to_vec1::<f32>()?;
+Ok(values)
+}
+
 /// Forward step + copy the `[vocab]` logits to a CPU `Vec<f32>` ready
 /// for sampling on the async caller. The model's `device()` (CUDA or
 /// CPU) determines where the kernel runs; this fn doesn't care.
@@ -297,6 +430,38 @@ fn drain_poisoned(job: Job, device_index: u32) {
 Job::ForwardLogits { reply, .. } => {
 let _ = reply.send(Err(err()));
 }
+Job::NcclInit { reply, .. } => {
+let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
+kind: "device_worker_poisoned".into(),
+message: format!("device worker {device_index} poisoned"),
+});
+}
+Job::NcclSanity { reply } => {
+let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
+kind: "device_worker_poisoned".into(),
+message: format!("device worker {device_index} poisoned"),
+});
+}
+#[cfg(feature = "cuda")]
+Job::CloneLeaderComm { reply } => {
+let _ = reply.send(Err(err()));
+}
+#[cfg(feature = "cuda")]
+Job::TransferInTp { reply, .. } => {
+let _ = reply.send(Err(err()));
+}
+#[cfg(feature = "cuda")]
+Job::DropTp { reply, .. } => {
+let _ = reply.send(());
+}
+#[cfg(feature = "cuda")]
+Job::TpClearKv { reply, .. } => {
+let _ = reply.send(Err(err()));
+}
+#[cfg(feature = "cuda")]
+Job::TpForwardLogits { reply, .. } => {
+let _ = reply.send(Err(err()));
+}
 Job::Shutdown => {
 // Filtered by the matches!() guard in run(); reaching
 // here would be a logic error.
@@ -19,6 +19,15 @@ use tokio::sync::oneshot;
|
|||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub struct ArchHandle(pub u64);
|
pub struct ArchHandle(pub u64);
|
||||||
|
|
||||||
|
/// Opaque handle to a `TpLeaderModel` stored in the worker thread's
|
||||||
|
/// state slab. Same shape as [`ArchHandle`] but in a separate
|
||||||
|
/// namespace so the two slabs can coexist without ambiguity. Phase 3
|
||||||
|
/// introduces it; Phase 4 may unify the two slabs after the TP forward
|
||||||
|
/// path proves out.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
pub struct TpHandle(pub u64);
|
||||||
|
|
||||||
/// One unit of work for the device worker.
|
/// One unit of work for the device worker.
|
||||||
///
|
///
|
||||||
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
|
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
|
||||||
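The `TpHandle` doc comment in this hunk describes an opaque key into a slab that lives in the worker thread's state. The following is a minimal, std-only sketch of that idea, not the code from this commit: `Handle`, `Model`, `Slab`, and the `next_id` counter are hypothetical names chosen for illustration, and the real slab may allocate or store handles differently.

```rust
use std::collections::HashMap;

/// Opaque key into the slab; same shape as a `u64` newtype handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Handle(pub u64);

/// Hypothetical stand-in for a model that must stay on one thread.
pub struct Model {
    pub name: String,
}

/// Hypothetical slab: hands out monotonically increasing handles.
#[derive(Default)]
pub struct Slab {
    next_id: u64,
    models: HashMap<Handle, Model>,
}

impl Slab {
    /// Take ownership of `model` and return the key that refers to it.
    pub fn insert(&mut self, model: Model) -> Handle {
        let handle = Handle(self.next_id);
        self.next_id += 1;
        self.models.insert(handle, model);
        handle
    }

    /// Look up a model for a forward step; `None` if the handle is stale.
    pub fn get_mut(&mut self, handle: Handle) -> Option<&mut Model> {
        self.models.get_mut(&handle)
    }

    /// Remove the model and return it, so the caller decides where it drops.
    pub fn remove(&mut self, handle: Handle) -> Option<Model> {
        self.models.remove(&handle)
    }
}
```

Callers outside the owning thread only ever hold the `Copy` handle, so nothing thread-bound has to cross a thread boundary after the initial transfer.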
@@ -73,6 +82,74 @@ pub enum Job {
|
|||||||
offset: usize,
|
offset: usize,
|
||||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
},
|
},
|
||||||
|
/// Initialise the leader's NCCL communicator. The worker's
|
||||||
|
/// `NcclState` mints the `Comm` here so its underlying
|
||||||
|
/// `ncclComm_t` and `CudaContext` live on the same thread as
|
||||||
|
/// every later `Comm::all_reduce` call. Reply is the worker
|
||||||
|
/// response shape used by the subprocess workers (`InitOk` on
|
||||||
|
/// success, `Error` on failure) so the calling
|
||||||
|
/// `WorkerPool::init_nccl` orchestration stays uniform.
|
||||||
|
///
|
||||||
|
/// Available on both cuda and no-cuda builds — the dispatch
|
||||||
|
/// handler calls `NcclState::init` which has a no-cuda stub that
|
||||||
|
/// replies with `cuda_feature_not_enabled`. Keeping the Job
|
||||||
|
/// variant ungated lets `WorkerPool::init_nccl` stay uniform.
|
||||||
|
NcclInit {
|
||||||
|
cfg: crate::harness::tp::worker::WorkerConfig,
|
||||||
|
comm_id_hex: String,
|
||||||
|
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
||||||
|
},
|
||||||
|
/// Run NCCL's all_reduce sanity check on the leader's rank 0.
|
||||||
|
/// Same response shape as `NcclInit`; also available on both
|
||||||
|
/// builds via the no-cuda `NcclState::sanity_check` stub.
|
||||||
|
NcclSanity {
|
||||||
|
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
||||||
|
},
|
||||||
|
/// Clone the leader's `Arc<Comm>` out of the worker's `NcclState`
|
||||||
|
/// so a spawn_blocking-based load (Phase 3 bridge) can hand it to
|
||||||
|
/// the row-parallel layers. Wrapped in `SendComm` because
|
||||||
|
/// `Arc<Comm>` is `!Send` at the type level (the NCCL contract
|
||||||
|
/// requires serialised access, which we provide structurally).
|
||||||
|
/// Phase 4 eliminates this when `TpLoadShard` becomes a Job and
|
||||||
|
/// the load runs entirely on the worker thread.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
CloneLeaderComm {
|
||||||
|
reply: oneshot::Sender<Result<crate::harness::tp::nccl_state::SendComm>>,
|
||||||
|
},
|
||||||
|
/// Move a freshly-built `TpLeaderModel` into the worker's tp slab.
|
||||||
|
/// Returns a `TpHandle` the caller stores on `TpLoadedModel`.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
TransferInTp {
|
||||||
|
model: Box<crate::harness::tp::TpLeaderModel>,
|
||||||
|
reply: oneshot::Sender<Result<TpHandle>>,
|
||||||
|
},
|
||||||
|
/// Drop the TP leader model on the worker thread. CUDA tensors
|
||||||
|
/// and `Arc<Comm>` clones held inside the model release on the
|
||||||
|
/// thread that allocated them.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
DropTp {
|
||||||
|
handle: TpHandle,
|
||||||
|
reply: oneshot::Sender<()>,
|
||||||
|
},
|
||||||
|
/// Reset the leader's KV cache for a TP model. Mirrors `ClearKv`
|
||||||
|
/// for single-GPU.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
TpClearKv {
|
||||||
|
handle: TpHandle,
|
||||||
|
reply: oneshot::Sender<Result<()>>,
|
||||||
|
},
|
||||||
|
/// Run one TP forward step on the leader's shard. Returns CPU-
|
||||||
|
/// side logits as a `Vec<f32>` so the async caller can sample
|
||||||
|
/// without holding a device tensor. The caller is also
|
||||||
|
/// responsible for fan-out to subprocess ranks and drain — only
|
||||||
|
/// the leader's forward moves into the worker thread.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
TpForwardLogits {
|
||||||
|
handle: TpHandle,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
|
},
|
||||||
/// Tell the worker to break its dispatch loop and exit. Any jobs
|
/// Tell the worker to break its dispatch loop and exit. Any jobs
|
||||||
/// queued after this in the channel reply `Err` to their oneshot
|
/// queued after this in the channel reply `Err` to their oneshot
|
||||||
/// senders (the senders are dropped on the worker's exit, which
|
/// senders (the senders are dropped on the worker's exit, which
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ use std::sync::mpsc::{self, Sender};
|
|||||||
use std::thread::JoinHandle;
|
use std::thread::JoinHandle;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub use jobs::TpHandle;
|
||||||
pub use jobs::{ArchHandle, Job};
|
pub use jobs::{ArchHandle, Job};
|
||||||
|
|
||||||
/// Errors returned by `DeviceWorkerHandle` submit methods.
|
/// Errors returned by `DeviceWorkerHandle` submit methods.
|
||||||
@@ -277,6 +279,192 @@ impl DeviceWorkerHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initialise the leader's NCCL communicator. The reply uses
|
||||||
|
/// `WorkerResponse` (same shape subprocess workers use over stdio
|
||||||
|
/// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader +
|
||||||
|
/// subprocess responses uniformly. Available on no-cuda builds
|
||||||
|
/// too — the dispatch handler calls the no-cuda `NcclState::init`
|
||||||
|
/// stub which replies `cuda_feature_not_enabled`.
|
||||||
|
pub async fn nccl_init(
|
||||||
|
&self,
|
||||||
|
cfg: crate::harness::tp::worker::WorkerConfig,
|
||||||
|
comm_id_hex: String,
|
||||||
|
) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::NcclInit {
|
||||||
|
cfg,
|
||||||
|
comm_id_hex,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
reply_rx.await.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run an NCCL sanity all_reduce on the leader's rank 0.
|
||||||
|
/// Available on no-cuda builds; replies with an error response.
|
||||||
|
pub async fn nccl_sanity(
|
||||||
|
&self,
|
||||||
|
) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::NcclSanity { reply: reply_tx })
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
reply_rx.await.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clone the leader's `Arc<Comm>` so a spawn_blocking-based load
|
||||||
|
/// (Phase 3 bridge) can pass it to the row-parallel layers.
|
||||||
|
/// Phase 4 eliminates this once the TP load runs on this thread.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub async fn clone_leader_comm(
|
||||||
|
&self,
|
||||||
|
) -> Result<crate::harness::tp::nccl_state::SendComm, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::CloneLeaderComm { reply: reply_tx })
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Move a freshly-built `TpLeaderModel` into the worker's TP slab.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub async fn transfer_in_tp(
|
||||||
|
&self,
|
||||||
|
model: Box<crate::harness::tp::TpLeaderModel>,
|
||||||
|
) -> Result<TpHandle, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::TransferInTp {
|
||||||
|
model,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drop the TP model at `handle` on the worker thread.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub async fn drop_tp(&self, handle: TpHandle) -> Result<(), WorkerError> {
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::DropTp {
|
||||||
|
handle,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(()) => Ok(()),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reset the leader's KV cache for a TP model.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub async fn tp_clear_kv(&self, handle: TpHandle) -> Result<(), WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::TpClearKv {
|
||||||
|
handle,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run one TP forward step on the leader's shard. Returns CPU-side
|
||||||
|
/// logits as `Vec<f32>` ready for sampling. The caller is
|
||||||
|
/// responsible for fan-out / drain of the subprocess workers
|
||||||
|
/// concurrently with this call.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub async fn tp_forward_logits(
|
||||||
|
&self,
|
||||||
|
handle: TpHandle,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<Vec<f32>, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::TpForwardLogits {
|
||||||
|
handle,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
|
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
|
||||||
/// twice is a no-op the second time.
|
/// twice is a no-op the second time.
|
||||||
pub fn shutdown(&self) -> anyhow::Result<()> {
|
pub fn shutdown(&self) -> anyhow::Result<()> {
|
||||||
|
|||||||
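The submit methods in the hunk above all follow one shape: check the poisoned flag, send a `Job` carrying a reply sender, then wait for the reply and map a closed channel to `Gone`. Here is a minimal sketch of that shape under stated assumptions: it uses a blocking `std::sync::mpsc` channel for the reply instead of `tokio::sync::oneshot`, and `Job::Double`, `WorkerHandle`, `poison`, and the `Failed` variant are hypothetical names that do not appear in the diff.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{self, Sender};
use std::sync::Arc;
use std::thread::{self, JoinHandle};

/// Hypothetical job set: each request variant carries its own reply sender.
enum Job {
    Double { value: u32, reply: Sender<Result<u32, String>> },
    Shutdown,
}

#[derive(Debug, PartialEq)]
pub enum WorkerError {
    Poisoned,
    Gone,
    Failed(String),
}

pub struct WorkerHandle {
    tx: Sender<Job>,
    poisoned: Arc<AtomicBool>,
    join: Option<JoinHandle<()>>,
}

impl WorkerHandle {
    /// Spawn the single thread that owns all worker state.
    pub fn spawn() -> Self {
        let (tx, rx) = mpsc::channel::<Job>();
        let poisoned = Arc::new(AtomicBool::new(false));
        let join = thread::spawn(move || {
            // State created here never leaves this thread.
            let mut calls: u32 = 0;
            for job in rx {
                match job {
                    Job::Double { value, reply } => {
                        calls += 1;
                        let result = value
                            .checked_mul(2)
                            .ok_or_else(|| format!("overflow on call {calls}"));
                        // The caller may have gone away; ignore a failed send.
                        let _ = reply.send(result);
                    }
                    Job::Shutdown => break,
                }
            }
        });
        Self { tx, poisoned, join: Some(join) }
    }

    /// Submit one job and block until the worker replies.
    pub fn double(&self, value: u32) -> Result<u32, WorkerError> {
        if self.poisoned.load(Ordering::Acquire) {
            return Err(WorkerError::Poisoned);
        }
        let (reply_tx, reply_rx) = mpsc::channel();
        self.tx
            .send(Job::Double { value, reply: reply_tx })
            .map_err(|_| WorkerError::Gone)?;
        match reply_rx.recv() {
            Ok(result) => result.map_err(WorkerError::Failed),
            // The reply sender was dropped without answering: the worker exited.
            Err(_) => Err(WorkerError::Gone),
        }
    }

    /// Mark the worker unusable (hypothetical helper for this sketch).
    pub fn poison(&self) {
        self.poisoned.store(true, Ordering::Release);
    }

    /// Ask the worker to exit and join its thread; a second call is a no-op.
    pub fn shutdown(&mut self) {
        let _ = self.tx.send(Job::Shutdown);
        if let Some(join) = self.join.take() {
            let _ = join.join();
        }
    }
}
```

In the sketch, a job submitted after shutdown fails at `send` because the receiver has been dropped, and a job that was still queued when the worker exited fails at `recv` because its reply sender is dropped with the queue. Both surface as `Gone`.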
@@ -212,10 +212,15 @@ pub struct WorkerPool {
|
|||||||
/// Path to the neuron binary used to launch workers.
|
/// Path to the neuron binary used to launch workers.
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
exe: PathBuf,
|
exe: PathBuf,
|
||||||
/// Leader's own NCCL rank-0 state. Defaults to empty; populated by
|
/// The leader's per-device CUDA worker thread. Phase 3 moved the
|
||||||
/// `init_nccl()`. Held here so the leader can participate in
|
/// leader's `NcclState` (rank-0 NCCL Comm) into this thread, so
|
||||||
/// collectives (rank 0) without spawning a fourth subprocess.
|
/// every NCCL op (init, sanity, all_reduce inside forward) issues
|
||||||
leader_nccl: nccl_state::NcclState,
|
/// from one OS thread for the daemon's lifetime. The handle is
|
||||||
|
/// also used by `load_dense_shard` to clone the leader's
|
||||||
|
/// `Arc<Comm>` for the row-parallel layers' AllReduce ops; in
|
||||||
|
/// Phase 4 the load itself moves onto the worker and that bridge
|
||||||
|
/// goes away.
|
||||||
|
pub(crate) leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WorkerPool {
|
impl WorkerPool {
|
||||||
@@ -228,7 +233,12 @@ impl WorkerPool {
|
|||||||
/// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`).
|
/// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`).
|
||||||
/// `cuda_devices` is one entry per rank including rank 0. Worker
|
/// `cuda_devices` is one entry per rank including rank 0. Worker
|
||||||
/// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`.
|
/// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`.
|
||||||
pub async fn spawn(binary: &Path, world_size: u32, cuda_devices: &[u32]) -> Result<Self> {
|
pub async fn spawn(
|
||||||
|
binary: &Path,
|
||||||
|
world_size: u32,
|
||||||
|
cuda_devices: &[u32],
|
||||||
|
leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
|
||||||
|
) -> Result<Self> {
|
||||||
if world_size < 2 {
|
if world_size < 2 {
|
||||||
anyhow::bail!(
|
anyhow::bail!(
|
||||||
"WorkerPool::spawn called with world_size={world_size}; \
|
"WorkerPool::spawn called with world_size={world_size}; \
|
||||||
@@ -289,7 +299,7 @@ impl WorkerPool {
|
|||||||
world_size,
|
world_size,
|
||||||
workers,
|
workers,
|
||||||
exe,
|
exe,
|
||||||
leader_nccl: nccl_state::NcclState::new(),
|
leader_worker,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -321,27 +331,26 @@ impl WorkerPool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. Leader rank 0 calls Comm::from_rank on its own device.
|
// 2. Leader rank 0 calls Comm::from_rank on its own device.
|
||||||
// Runs on spawn_blocking because NCCL's init blocks until
|
// Phase 3 moved this from spawn_blocking onto the leader's
|
||||||
// every rank has called in — that's exactly the workers
|
// device worker thread (`Job::NcclInit`); the underlying
|
||||||
// above. The leader's NcclState is moved through the
|
// `Comm` now lives on the same OS thread for its entire
|
||||||
// blocking task and returned to the pool.
|
// lifetime, including every later `Comm::all_reduce` issued
|
||||||
|
// by the row-parallel layers during forward.
|
||||||
|
//
|
||||||
|
// NCCL's init blocks until every rank has called in — the
|
||||||
|
// subprocess workers above and the leader's device worker
|
||||||
|
// here. The Job's reply unblocks when the leader's
|
||||||
|
// Comm::from_rank returns.
|
||||||
let leader_cfg = worker::WorkerConfig {
|
let leader_cfg = worker::WorkerConfig {
|
||||||
rank: 0,
|
rank: 0,
|
||||||
world_size: self.world_size,
|
world_size: self.world_size,
|
||||||
cuda_device: leader_cuda_device,
|
cuda_device: leader_cuda_device,
|
||||||
};
|
};
|
||||||
let comm_id_for_leader = comm_id.clone();
|
let leader_resp = self
|
||||||
// Swap out the leader's NcclState into a fresh empty one so we
|
.leader_worker
|
||||||
// can move it into spawn_blocking; restore after the task
|
.nccl_init(leader_cfg, comm_id.clone())
|
||||||
// returns. (NcclState isn't Clone — it owns a real NCCL Comm.)
|
|
||||||
let mut leader_state = std::mem::take(&mut self.leader_nccl);
|
|
||||||
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
|
|
||||||
let resp = leader_state.init(leader_cfg, &comm_id_for_leader);
|
|
||||||
(leader_state, resp)
|
|
||||||
})
|
|
||||||
.await
|
.await
|
||||||
.context("leader NCCL init task panicked")?;
|
.map_err(|e| anyhow::anyhow!("leader NCCL init via device worker: {e}"))?;
|
||||||
self.leader_nccl = returned_state;
|
|
||||||
match leader_resp {
|
match leader_resp {
|
||||||
rpc::WorkerResponse::InitOk => {}
|
rpc::WorkerResponse::InitOk => {}
|
||||||
rpc::WorkerResponse::Error { kind, message } => {
|
rpc::WorkerResponse::Error { kind, message } => {
|
||||||
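The comments in this hunk rely on one ordering rule: NCCL init blocks until every rank has called in, so the subprocess workers must be started before the leader makes its own blocking call. The sketch below models only that rule with `std::sync::Barrier`. It is an analogy under stated assumptions, not NCCL: `run_collective` is a hypothetical name, threads stand in for subprocess ranks, and the leader blocks on the calling thread rather than on a device worker thread.

```rust
use std::sync::{Arc, Barrier};
use std::thread;

/// Model a collective: every rank blocks until all `world_size` ranks arrive.
/// Returns the ranks that completed, leader first.
pub fn run_collective(world_size: usize) -> Vec<usize> {
    let barrier = Arc::new(Barrier::new(world_size));

    // 1. Fan out: start ranks 1..world_size so they can reach the barrier.
    let workers: Vec<_> = (1..world_size)
        .map(|rank| {
            let barrier = Arc::clone(&barrier);
            thread::spawn(move || {
                barrier.wait();
                rank
            })
        })
        .collect();

    // 2. Only now does rank 0 (the leader) make its blocking call.
    //    Calling this before step 1 would wait forever.
    barrier.wait();

    // 3. Drain: collect every worker's completion in rank order.
    let mut done = vec![0];
    for worker in workers {
        done.push(worker.join().expect("worker thread panicked"));
    }
    done
}
```

If step 2 ran before step 1, the leader would wait on ranks that had never been started, which is the deadlock the fan-out-first ordering avoids.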
@@ -387,16 +396,16 @@ impl WorkerPool {
|
|||||||
w.send_only(&WorkerRequest::NcclSanityCheck).await?;
|
w.send_only(&WorkerRequest::NcclSanityCheck).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Leader's own all_reduce, in spawn_blocking. NCCL operations
|
// 2. Leader's own all_reduce, on its device worker thread.
|
||||||
// block until every rank participates.
|
// NCCL operations block until every rank participates;
|
||||||
let mut leader_state = std::mem::take(&mut self.leader_nccl);
|
// Job::NcclSanity returns once the leader's side completes
|
||||||
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
|
// (which happens when every subprocess worker reaches its
|
||||||
let resp = leader_state.sanity_check();
|
// all_reduce call too).
|
||||||
(leader_state, resp)
|
let leader_resp = self
|
||||||
})
|
.leader_worker
|
||||||
|
.nccl_sanity()
|
||||||
.await
|
.await
|
||||||
.context("leader NCCL sanity task panicked")?;
|
.map_err(|e| anyhow::anyhow!("leader NCCL sanity via device worker: {e}"))?;
|
||||||
self.leader_nccl = returned_state;
|
|
||||||
|
|
||||||
let expected = self.world_size;
|
let expected = self.world_size;
|
||||||
let leader_sum = match leader_resp {
|
let leader_sum = match leader_resp {
|
||||||
@@ -483,21 +492,24 @@ impl WorkerPool {
|
|||||||
leader_device: &candle_core::Device,
|
leader_device: &candle_core::Device,
|
||||||
dtype: candle_core::DType,
|
dtype: candle_core::DType,
|
||||||
quant: Option<String>,
|
quant: Option<String>,
|
||||||
) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
|
) -> Result<super::device_worker::TpHandle> {
|
||||||
use candle_nn::var_builder::ShardedSafeTensors;
|
use candle_nn::var_builder::ShardedSafeTensors;
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::Mutex;
|
|
||||||
|
|
||||||
// Wrap the comm in SendComm immediately so it stays Send across
|
// Ask the leader's device worker for an `Arc<Comm>` clone.
|
||||||
// the await points in this method — bare Arc<Comm> would
|
// Phase 3 moved `NcclState` ownership onto the worker thread,
|
||||||
// poison the async fn's Send bound (Comm's raw NCCL pointer is
|
// so the spawn_blocking load below can no longer reach the
|
||||||
// !Send). The wrapper's safety contract is satisfied by the
|
// Comm directly. The reply is wrapped in `SendComm` because
|
||||||
// pool's outer Mutex serialising callers + the spawn_blocking
|
// the underlying `Arc<Comm>` is `!Send` at the type level;
|
||||||
// thread being the only place ops are issued.
|
// the safety contract (only one thread issues NCCL ops at a
|
||||||
let leader_comm =
|
// time) is preserved because the load runs on a single
|
||||||
nccl_state::SendComm(self.leader_nccl.comm().ok_or_else(|| {
|
// spawn_blocking thread and AllReduce ops fire only from the
|
||||||
anyhow::anyhow!("leader NCCL not initialised; call init_nccl first")
|
// device worker thread later. Phase 4 eliminates this bridge
|
||||||
})?);
|
// when the load itself moves onto the worker.
|
||||||
|
let leader_comm = self
|
||||||
|
.leader_worker
|
||||||
|
.clone_leader_comm()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("clone leader Comm via device worker: {e}"))?;
|
||||||
let world_size = self.world_size;
|
let world_size = self.world_size;
|
||||||
let safetensors_str: Vec<String> = safetensors_paths
|
let safetensors_str: Vec<String> = safetensors_paths
|
||||||
.iter()
|
.iter()
|
||||||
@@ -601,15 +613,32 @@ impl WorkerPool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Arc::new(Mutex::new(leader_model)))
|
// Phase 3: move the leader's freshly-built `TpLeaderModel`
|
||||||
|
// into the device worker's TP slab. The model holds
|
||||||
|
// `Arc<Comm>` clones (in its AllReduce ops) plus CUDA
|
||||||
|
// tensors — both need to live on the device worker thread so
|
||||||
|
// every `Comm::all_reduce` and tensor op during forward
|
||||||
|
// dispatches from the same OS thread that bound the CUDA
|
||||||
|
// context.
|
||||||
|
let handle = self
|
||||||
|
.leader_worker
|
||||||
|
.transfer_in_tp(Box::new(leader_model))
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("transfer TP leader model into device worker: {e}"))?;
|
||||||
|
Ok(handle)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run one forward step across every rank. The leader's forward
|
/// Run one forward step across every rank. The leader's forward
|
||||||
/// returns the last-position logits as a candle Tensor on the
|
/// runs on the device worker thread via `Job::TpForwardLogits` and
|
||||||
/// leader's device; the caller does sampling out-of-band. Workers
|
/// returns CPU-side `[vocab]` logits as `Vec<f32>`; the async
|
||||||
/// run their own forwards (the AllReduce inside row-parallel layers
|
/// caller wraps them in a CPU tensor for `apply_repeat_penalty` +
|
||||||
/// is what lets the leader's collective complete) and reply with
|
/// sampling without holding a device-resident tensor on a tokio
|
||||||
/// `GenerateStepOk` — they do not ship logits over the wire.
|
/// thread.
|
||||||
|
///
|
||||||
|
/// Subprocess workers run their own forwards in parallel (the
|
||||||
|
/// AllReduce CustomOps inside row-parallel layers are what let
|
||||||
|
/// the leader's collective complete) and reply with
|
||||||
|
/// `GenerateStepOk` over the RPC stream — they do not ship logits.
|
||||||
///
|
///
|
||||||
/// `tokens` is the input for this step (prompt for prefill, the
|
/// `tokens` is the input for this step (prompt for prefill, the
|
||||||
/// previously-sampled token for decode). `offset` is the KV-cache
|
/// previously-sampled token for decode). `offset` is the KV-cache
|
||||||
@@ -618,10 +647,10 @@ impl WorkerPool {
|
|||||||
pub async fn generate_step(
|
pub async fn generate_step(
|
||||||
&mut self,
|
&mut self,
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
|
leader_handle: super::device_worker::TpHandle,
|
||||||
tokens: Vec<u32>,
|
tokens: Vec<u32>,
|
||||||
offset: usize,
|
offset: usize,
|
||||||
) -> Result<candle_core::Tensor> {
|
) -> Result<Vec<f32>> {
|
||||||
let step_start = std::time::Instant::now();
|
let step_start = std::time::Instant::now();
|
||||||
let tokens_len = tokens.len();
|
let tokens_len = tokens.len();
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
@@ -630,7 +659,7 @@ impl WorkerPool {
|
|||||||
offset,
|
offset,
|
||||||
"WorkerPool::generate_step: fan-out"
|
"WorkerPool::generate_step: fan-out"
|
||||||
);
|
);
|
||||||
// 1. Fan-out to workers.
|
// 1. Fan-out to subprocess workers.
|
||||||
for w in &mut self.workers {
|
for w in &mut self.workers {
|
||||||
w.send_only(&WorkerRequest::GenerateStep {
|
w.send_only(&WorkerRequest::GenerateStep {
|
||||||
model_id: model_id.to_string(),
|
model_id: model_id.to_string(),
|
||||||
@@ -640,35 +669,30 @@ impl WorkerPool {
|
|||||||
.await?;
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Leader's forward in spawn_blocking. The AllReduce CustomOps
|
// 2. Leader's forward on its device worker thread. The
|
||||||
// inside the row-parallel layers block until every worker's
|
// AllReduce CustomOps inside the row-parallel layers block
|
||||||
// forward issues the matching collective.
|
// until every subprocess worker's forward issues the
|
||||||
|
// matching collective. Returning CPU-side `Vec<f32>` keeps
|
||||||
|
// the device tensor from escaping the worker thread —
|
||||||
|
// that's the invariant the whole refactor exists to
|
||||||
|
// preserve.
|
||||||
let leader_start = std::time::Instant::now();
|
let leader_start = std::time::Instant::now();
|
||||||
let leader_result = tokio::task::spawn_blocking(move || -> Result<candle_core::Tensor> {
|
let leader_result = self
|
||||||
let mut model = leader_model.blocking_lock();
|
.leader_worker
|
||||||
let device = model.device().clone();
|
.tp_forward_logits(leader_handle, tokens, offset)
|
||||||
let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
.await;
|
||||||
// ForCausalLM::forward returns [B, 1, V] — squeeze both
|
let leader_ok = leader_result.is_ok();
|
||||||
// leading dims to the rank-1 vocab logits the sampler wants.
|
|
||||||
let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?;
|
|
||||||
Ok(logits)
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.context("leader forward task panicked");
|
|
||||||
let leader_ok = matches!(leader_result, Ok(Ok(_)));
|
|
||||||
let leader_ms = leader_start.elapsed().as_millis();
|
let leader_ms = leader_start.elapsed().as_millis();
|
||||||
// Surface the leader's own error at WARN. Previously this was
|
// Surface the leader's own error at WARN before draining
|
||||||
// silently coerced to `leader_ok=false` while only worker
|
// workers so the operator can correlate it with whatever the
|
||||||
// ranks' errors got logged — when both the leader and a worker
|
// subprocess workers logged. Previously this was silently
|
||||||
// fail together (the typical "CUDA context is now poisoned"
|
// coerced to a bool.
|
||||||
// pattern after an OOM), the operator could see only the
|
|
||||||
// worker side and had to guess what hit rank 0.
|
|
||||||
if !leader_ok {
|
if !leader_ok {
|
||||||
let detail = match &leader_result {
|
let detail = leader_result
|
||||||
Ok(Err(e)) => format!("{e:#}"),
|
.as_ref()
|
||||||
Err(e) => format!("task: {e:#}"),
|
.err()
|
||||||
Ok(Ok(_)) => unreachable!("leader_ok=false implies an error path"),
|
.map(|e| format!("{e:#}"))
|
||||||
};
|
.unwrap_or_default();
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
model = %model_id,
|
model = %model_id,
|
||||||
tokens = tokens_len,
|
tokens = tokens_len,
|
||||||
@@ -707,7 +731,33 @@ impl WorkerPool {
|
|||||||
"WorkerPool::generate_step: workers drained"
|
"WorkerPool::generate_step: workers drained"
|
||||||
);
|
);
|
||||||
|
|
||||||
combine_leader_workers(leader_result, worker_errors, "GenerateStep")
|
// Combine the leader's Result + the workers' string-error
|
||||||
|
// list. Phase 3 inlines this because the upstream
|
||||||
|
// `combine_leader_workers` expects the spawn_blocking-shaped
|
||||||
|
// `Result<Result<T>>`; the new device-worker path produces a
|
||||||
|
// single `Result<T, WorkerError>` instead.
|
||||||
|
match leader_result {
|
||||||
|
Ok(values) => {
|
||||||
|
if worker_errors.is_empty() {
|
||||||
|
Ok(values)
|
||||||
|
} else {
|
||||||
|
anyhow::bail!(
|
||||||
|
"GenerateStep: leader succeeded but workers failed: {}",
|
||||||
|
worker_errors.join("; ")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
if worker_errors.is_empty() {
|
||||||
|
Err(anyhow::Error::new(e).context("GenerateStep: leader forward failed"))
|
||||||
|
} else {
|
||||||
|
Err(anyhow::Error::new(e).context(format!(
|
||||||
|
"GenerateStep: leader forward failed and workers also failed: {}",
|
||||||
|
worker_errors.join("; ")
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Reset the KV cache for `model_id` on every rank. Called at the
|
/// Reset the KV cache for `model_id` on every rank. Called at the
|
||||||
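The inlined `match` at the end of `generate_step` merges one leader result with a list of worker error strings. The four cases can be sketched as a small pure function. This is an illustration only: `combine` is a hypothetical name, and plain `String` errors stand in for the `anyhow::Error` and `WorkerError` types used in the diff.

```rust
/// Merge the leader's outcome with error strings collected from the workers.
/// Succeeds only when the leader succeeded and no worker reported an error.
pub fn combine<T>(
    op: &str,
    leader: Result<T, String>,
    worker_errors: &[String],
) -> Result<T, String> {
    match (leader, worker_errors.is_empty()) {
        // Everyone succeeded: pass the leader's value through.
        (Ok(value), true) => Ok(value),
        // The leader's value is discarded because some rank failed.
        (Ok(_), false) => Err(format!(
            "{op}: leader succeeded but workers failed: {}",
            worker_errors.join("; ")
        )),
        // Only the leader failed.
        (Err(e), true) => Err(format!("{op}: leader forward failed: {e}")),
        // Both sides failed: report the workers and keep the leader's cause.
        (Err(e), false) => Err(format!(
            "{op}: leader forward failed and workers also failed: {}: {e}",
            worker_errors.join("; ")
        )),
    }
}
```

Matching on the pair makes it easy to check that all four combinations are handled and that the leader's error is never dropped when workers fail as well.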
@@ -716,7 +766,7 @@ impl WorkerPool {
|
|||||||
pub async fn clear_kv_cache(
|
pub async fn clear_kv_cache(
|
||||||
&mut self,
|
&mut self,
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
#[cfg(feature = "cuda")] leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
|
#[cfg(feature = "cuda")] leader_handle: super::device_worker::TpHandle,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out");
|
tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out");
|
||||||
@@ -728,13 +778,18 @@ impl WorkerPool {
         }
         #[cfg(feature = "cuda")]
         {
-            let mut m = leader_model.lock().await;
-            m.clear_kv_cache();
+            // Leader-side clear on the device worker thread —
+            // `TpLeaderModel::clear_kv_cache` is infallible but still
+            // routes through Job::TpClearKv so the cache reset runs
+            // on the same thread that owns the model's CUDA tensors.
+            if let Err(e) = self.leader_worker.tp_clear_kv(leader_handle).await {
+                anyhow::bail!("leader TP clear_kv_cache via device worker: {e}");
+            }
         }
         // Drain workers — same rationale as `generate_step`. The
-        // leader's clear_kv_cache is in-process and infallible, but we
-        // still always drain so an error on one worker doesn't leave
-        // pending responses for the others.
+        // leader's clear_kv_cache is now async-via-channel but still
+        // returns before the drain so the workers' KvCacheCleared
+        // replies are processed in order.
         let worker_errors = drain_workers(&mut self.workers, |r| match r {
             WorkerResponse::KvCacheCleared => Ok(()),
             WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
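Review note: the drain comment in the hunk above relies on reading one reply per worker even after a failure, so no response is left pending on a pipe. The sketch below illustrates that idea only. `drain_all`, the iterator-based signature, and the `rank N` prefix are assumptions made for illustration; the real `drain_workers` takes `&mut self.workers` and its body is not shown in this diff.

```rust
// Hypothetical stand-in for the RPC response type; only the two
// variants visible in the diff are modelled here.
#[derive(Debug)]
enum WorkerResponse {
    KvCacheCleared,
    Error { kind: String, message: String },
}

/// Consume every pending response, even after a failure, and collect
/// all errors instead of returning on the first one.
fn drain_all<I, F>(responses: I, mut classify: F) -> Vec<String>
where
    I: IntoIterator<Item = WorkerResponse>,
    F: FnMut(WorkerResponse) -> Result<(), String>,
{
    let mut errors = Vec::new();
    for (index, response) in responses.into_iter().enumerate() {
        if let Err(e) = classify(response) {
            // Assumption: rank 0 is the leader, so workers start at rank 1.
            errors.push(format!("rank {} {}", index + 1, e));
        }
    }
    errors
}
```

With three replies where only the second is an error, the returned vector holds a single entry for rank 2, and the third reply has still been consumed.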
@@ -118,7 +118,13 @@ async fn tp_smoke(tp_size: u32, cuda_devices: Vec<u32>) -> Result<()> {
         binary = %exe.display(),
         "tp-smoke: spawning worker pool"
     );
-    let mut pool = tp::WorkerPool::spawn(&exe, tp_size, &cuda_devices).await?;
+    // tp_smoke is a diagnostic tool; spawn the leader's device worker
+    // directly. (In the daemon path, CandleHarness::ensure_device_worker
+    // caches one per device.)
+    let leader_worker = neuron::harness::device_worker::DeviceWorkerHandle::spawn(leader_device)
+        .context("spawn leader device worker for tp-smoke")?;
+    let mut pool =
+        tp::WorkerPool::spawn(&exe, tp_size, &cuda_devices, leader_worker.clone()).await?;
 
     tracing::info!("tp-smoke: pinging every worker");
     let pongs = pool.ping_all().await?;
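Review note: the hunks above replace a shared `Arc<Mutex<TpLeaderModel>>` with a cloneable worker handle plus a lightweight `TpHandle`, so the model state is only ever touched on one thread. A minimal sketch of that ownership pattern follows, using only `std`. Every name here (`DeviceWorker`, `TpAppendKv`, the `u64` handle, the per-handle token counter standing in for CUDA tensors) is hypothetical; only `Job::TpClearKv` is named in the commit, and the real `DeviceWorkerHandle` API is async and not shown in full.

```rust
use std::collections::HashMap;
use std::sync::mpsc;
use std::thread;

/// Hypothetical job enum; the diff names only `Job::TpClearKv`.
enum Job {
    TpAppendKv { handle: u64, tokens: usize },
    TpClearKv { handle: u64, reply: mpsc::Sender<Result<usize, String>> },
}

/// Cloneable handle; callers never hold the state itself.
#[derive(Clone)]
struct DeviceWorker {
    tx: mpsc::Sender<Job>,
}

impl DeviceWorker {
    fn spawn() -> Self {
        let (tx, rx) = mpsc::channel::<Job>();
        thread::spawn(move || {
            // State lives only on this thread (stand-in for CUDA tensors).
            let mut kv_len: HashMap<u64, usize> = HashMap::new();
            // The loop ends once every sender clone has been dropped.
            for job in rx {
                match job {
                    Job::TpAppendKv { handle, tokens } => {
                        *kv_len.entry(handle).or_insert(0) += tokens;
                    }
                    Job::TpClearKv { handle, reply } => {
                        let result = match kv_len.get_mut(&handle) {
                            // Reset the length and report how much was cleared.
                            Some(len) => Ok(std::mem::take(len)),
                            None => Err(format!("unknown handle {handle}")),
                        };
                        // The caller may have gone away; ignore a closed reply channel.
                        let _ = reply.send(result);
                    }
                }
            }
        });
        Self { tx }
    }

    fn append_kv(&self, handle: u64, tokens: usize) -> Result<(), String> {
        self.tx
            .send(Job::TpAppendKv { handle, tokens })
            .map_err(|e| e.to_string())
    }

    /// Blocks until the worker thread has processed the clear.
    fn clear_kv(&self, handle: u64) -> Result<usize, String> {
        let (reply, rx) = mpsc::channel();
        self.tx
            .send(Job::TpClearKv { handle, reply })
            .map_err(|e| e.to_string())?;
        rx.recv().map_err(|e| e.to_string())?
    }
}
```

Because jobs from one sender arrive in order, a clear issued after two appends sees both of them, which mirrors the ordering argument in the updated drain comment.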
@@ -5,6 +5,7 @@
 //! `Init` and `NcclSanityCheck` are stubbed in 7a-i, so this test
 //! runs on any host the workspace builds on.
 
+use neuron::harness::device_worker::DeviceWorkerHandle;
 use neuron::harness::tp::{WorkerPool, rpc::WorkerResponse};
 
 /// Path to the neuron binary built by cargo for this test process.
@@ -19,7 +20,8 @@ const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
 async fn test_spawn_ping_shutdown() {
     // cuda_devices: rank 0 → device 0 (leader, unused here),
     // rank 1 → device 1 (worker; not actually opened in 7a-i).
-    let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1])
+    let leader_worker = DeviceWorkerHandle::spawn(0).expect("spawn device worker");
+    let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1], leader_worker)
         .await
         .expect("spawn worker pool");
 
@@ -44,7 +46,8 @@ async fn test_spawn_ping_shutdown() {
 /// Three workers — exercise the loop in `ping_all` / `shutdown`.
 #[tokio::test]
 async fn test_spawn_three_workers() {
-    let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 3, &[0, 1, 2])
+    let leader_worker = DeviceWorkerHandle::spawn(0).expect("spawn device worker");
+    let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 3, &[0, 1, 2], leader_worker)
         .await
         .expect("spawn worker pool");
 
@@ -25,7 +25,9 @@ async fn test_init_and_sanity_check_two_ranks() {
         .try_init();
 
     // 2 ranks: leader = rank 0 on device 0, worker = rank 1 on device 1.
-    let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1])
+    let leader_worker = neuron::harness::device_worker::DeviceWorkerHandle::spawn(0)
+        .expect("spawn leader device worker");
+    let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1], leader_worker)
         .await
         .expect("spawn worker pool");
 