feat(tp): Stage 7b-iv — RPC + orchestration for TP load/inference
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 38s
CI / Format (push) Successful in 40s
CI / Clippy (push) Successful in 2m20s
build-prerelease / Build cortex binary (push) Successful in 4m25s
build-prerelease / Package cortex RPM (push) Successful in 1m22s
CI / Test (push) Successful in 4m34s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m57s
build-prerelease / Build neuron-ampere (push) Successful in 4m51s
build-prerelease / Build neuron-ada (push) Successful in 5m12s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m49s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m51s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m0s
Wires the in-flight TP machinery (Stage 7a workers, 7b-iii sharded
Qwen3) end to end so a non-streaming chat completion can run across
multiple GPUs via NCCL.

RPC additions (tp/rpc.rs):
- LoadDenseShard{model_id, config_json, safetensors_paths}
- GenerateStep{model_id, tokens, offset}
- ClearKvCache{model_id}
- UnloadModel{model_id}
- LoadDenseShardOk / GenerateStepOk / KvCacheCleared / Unloaded
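
For orientation, the message set listed above might look roughly like the pair of enums below. This is a sketch, not the contents of tp/rpc.rs: the field types are inferred from the pool-side call sites in the diff, the types of the `Error` fields are assumptions, messages that predate this commit (such as `Shutdown`) are omitted, and any serialization derives are left out so the block stays std-only. `expected_reply` is a hypothetical helper added purely for illustration.

```rust
/// Requests the leader sends to every worker rank (sketch; field types assumed).
#[derive(Debug, Clone, PartialEq)]
pub enum WorkerRequest {
    LoadDenseShard {
        model_id: String,
        config_json: String,
        safetensors_paths: Vec<String>,
    },
    GenerateStep {
        model_id: String,
        tokens: Vec<u32>,
        offset: usize,
    },
    ClearKvCache { model_id: String },
    UnloadModel { model_id: String },
}

/// Replies a worker sends back. `Error` shows up in the pool-side match
/// arms; the types of its fields are an assumption.
#[derive(Debug, Clone, PartialEq)]
pub enum WorkerResponse {
    LoadDenseShardOk,
    GenerateStepOk,
    KvCacheCleared,
    Unloaded,
    Error { kind: String, message: String },
}

/// Hypothetical helper (not part of the commit): the success reply that
/// each request is expected to produce.
pub fn expected_reply(req: &WorkerRequest) -> WorkerResponse {
    match *req {
        WorkerRequest::LoadDenseShard { .. } => WorkerResponse::LoadDenseShardOk,
        WorkerRequest::GenerateStep { .. } => WorkerResponse::GenerateStepOk,
        WorkerRequest::ClearKvCache { .. } => WorkerResponse::KvCacheCleared,
        WorkerRequest::UnloadModel { .. } => WorkerResponse::Unloaded,
    }
}
```

Pairing each request with exactly one success reply is what lets the pool treat any other response as a protocol error.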

Worker side (tp/worker.rs):
- WorkerState gains a `models: HashMap<String, TpQwen3ForCausalLM>`
  keyed by model_id. LoadDenseShard mmaps safetensors via
  ShardedVarBuilder (only this rank's slice materialises), then builds
  the TP model with the rank's NCCL Comm cloned from NcclState.
- GenerateStep runs the rank-local forward; the resulting logits are
  dropped (only the leader's are used for sampling). The forward still
  has to run because the NCCL collectives inside its row-parallel
  layers are what let the leader's rank-0 forward make progress.
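
The worker-side bookkeeping described above can be sketched as follows. Everything here is a stand-in: `RankLocalModel` and `CountingModel` are hypothetical names replacing `TpQwen3ForCausalLM`, logits are a plain `Vec<f32>` rather than a tensor, and no NCCL call is made. The sketch only shows the registry keyed by model_id and the fact that the worker runs the forward and then discards its output.

```rust
use std::collections::HashMap;

/// Stand-in for the rank-local model (the commit uses TpQwen3ForCausalLM).
pub trait RankLocalModel {
    fn forward(&mut self, tokens: &[u32], offset: usize) -> Result<Vec<f32>, String>;
    fn clear_kv_cache(&mut self);
}

/// Per-worker registry of loaded shards, keyed by model_id.
pub struct WorkerState<M: RankLocalModel> {
    pub models: HashMap<String, M>,
}

impl<M: RankLocalModel> WorkerState<M> {
    pub fn new() -> Self {
        WorkerState { models: HashMap::new() }
    }

    /// Run the rank-local forward and drop the logits: on a worker the
    /// forward matters only for the collectives it would issue.
    pub fn generate_step(&mut self, model_id: &str, tokens: &[u32], offset: usize) -> Result<(), String> {
        let model = self
            .models
            .get_mut(model_id)
            .ok_or_else(|| format!("model {} not loaded", model_id))?;
        let _logits = model.forward(tokens, offset)?;
        Ok(())
    }

    /// Returns true if a shard was actually removed.
    pub fn unload(&mut self, model_id: &str) -> bool {
        self.models.remove(model_id).is_some()
    }
}

/// Toy model that only tracks how many positions its cache would hold.
pub struct CountingModel {
    pub positions_seen: usize,
}

impl RankLocalModel for CountingModel {
    fn forward(&mut self, tokens: &[u32], offset: usize) -> Result<Vec<f32>, String> {
        self.positions_seen = offset + tokens.len();
        Ok(vec![0.0; 4])
    }

    fn clear_kv_cache(&mut self) {
        self.positions_seen = 0;
    }
}
```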

Pool side (tp/mod.rs):
- WorkerPool::load_dense_shard fans LoadDenseShard out to every worker,
  builds rank 0's shard on the leader via spawn_blocking with a fresh
  SendComm wrapper at the move boundary (Comm is !Send at the type
  level), collects per-rank LoadDenseShardOk. Returns the leader's
  Arc<Mutex<TpQwen3ForCausalLM>>.
- WorkerPool::generate_step fans GenerateStep out, runs the leader's
  rank-0 forward in spawn_blocking (the AllReduce CustomOps inside
  row-parallel layers block until every worker issues the matching
  collective), returns the leader's last-position logits Tensor.
- WorkerPool::clear_kv_cache + unload_model follow the same pattern.
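
The fan-out, leader-work, collect ordering used by the pool methods above can be illustrated with std threads and channels. This is a toy model under stated assumptions: a `std::sync::Barrier` stands in for the blocking AllReduce, mpsc channels stand in for the RPC transport, and the `Worker` struct and both functions are hypothetical. It shows why the requests must be sent before the leader starts its own step: the leader's `wait()` cannot return until every worker has reached the same point.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Barrier};
use std::thread;

/// One worker rank: request sender and reply receiver (stand-ins for the RPC pipes).
pub struct Worker {
    pub rank: usize,
    tx: Sender<Vec<u32>>,
    rx: Receiver<Result<(), String>>,
}

/// Spawn ranks 1..world_size as threads sharing a barrier that plays the
/// role of a blocking collective.
pub fn spawn_workers(world_size: usize, collective: Arc<Barrier>) -> Vec<Worker> {
    (1..world_size)
        .map(|rank| {
            let (req_tx, req_rx) = channel::<Vec<u32>>();
            let (resp_tx, resp_rx) = channel::<Result<(), String>>();
            let collective = Arc::clone(&collective);
            thread::spawn(move || {
                for _tokens in req_rx {
                    // The "forward": all that matters here is joining the collective.
                    collective.wait();
                    if resp_tx.send(Ok(())).is_err() {
                        break;
                    }
                }
            });
            Worker { rank, tx: req_tx, rx: resp_rx }
        })
        .collect()
}

pub fn generate_step(workers: &[Worker], collective: &Barrier, tokens: &[u32]) -> Result<usize, String> {
    // 1. Fan out without waiting for replies.
    for w in workers {
        w.tx
            .send(tokens.to_vec())
            .map_err(|e| format!("rank {}: {}", w.rank, e))?;
    }
    // 2. Leader's own step; blocks until every worker joins the collective.
    collective.wait();
    let leader_result = tokens.len();
    // 3. Collect one confirmation per worker, in rank order.
    for w in workers {
        w.rx
            .recv()
            .map_err(|e| format!("rank {}: {}", w.rank, e))??;
    }
    Ok(leader_result)
}
```

Swapping steps 1 and 2 in this toy would deadlock the leader on the barrier, which mirrors the reason the real pool sends every request before entering spawn_blocking.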

NcclState refactor (tp/nccl_state.rs):
- comm field becomes Option<Arc<Comm>> (was Option<Comm>) so callers
  can share a clone with TpQwen3ForCausalLM::load.
- new `comm()` accessor + `SendComm` wrapper for spawn_blocking moves.
- single allow(clippy::arc_with_non_send_sync) at the canonical
  construction site (Comm is !Send by type but the runtime invariant
  is enforced by SendComm + the pool's Mutex).
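
A minimal sketch of the wrapper idea, assuming nothing about the real `Comm` beyond it being !Send: `FakeComm` is a hypothetical stand-in made !Send and !Sync with a `PhantomData<*const ()>` marker, `std::thread::spawn` replaces `spawn_blocking`, and the safety comment restates the invariant the commit message describes rather than proving it.

```rust
use std::marker::PhantomData;
use std::sync::Arc;
use std::thread;

/// Stand-in for the NCCL communicator. The raw-pointer marker makes it
/// !Send and !Sync, like a struct that owns a raw handle.
pub struct FakeComm {
    pub rank: usize,
    _not_send: PhantomData<*const ()>,
}

/// Single construction site for the shared handle, mirroring the one
/// lint allowance mentioned in the commit message.
#[allow(clippy::arc_with_non_send_sync)]
pub fn new_shared(rank: usize) -> Arc<FakeComm> {
    Arc::new(FakeComm { rank, _not_send: PhantomData })
}

/// Newtype that asserts the handle may cross a thread boundary.
pub struct SendComm(pub Arc<FakeComm>);

// SAFETY (assumed contract, not checked by the compiler): callers are
// serialised externally and only one thread issues operations at a time.
unsafe impl Send for SendComm {}

impl SendComm {
    pub fn into_inner(self) -> Arc<FakeComm> {
        self.0
    }
}

/// Move the handle to another thread and read its rank there.
pub fn rank_on_other_thread(comm: Arc<FakeComm>) -> usize {
    let wrapped = SendComm(comm);
    thread::spawn(move || {
        // Unwrap via a method call so the closure captures the whole
        // wrapper rather than its !Send field.
        let comm = wrapped.into_inner();
        comm.rank
    })
    .join()
    .expect("thread panicked")
}
```

Calling `into_inner()` inside the closure, instead of touching `wrapped.0` directly, matters under edition 2021 disjoint field capture: a direct field access would capture the inner `Arc` and the closure would stop being Send.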

Harness side (candle.rs):
- LoadedHandle enum (Single | Tp) replaces the bare Arc<LoadedModel>
  in the harness's registry. list_models / unload_model /
  inference_endpoint walk the enum uniformly.
- TpLoadedModel holds the pool + leader_model + tokenizer + devices.
- load_model dispatches on `spec.tensor_parallel > 1` to a new
  cuda-gated load_tp path: resolve dense files via hf-hub, spawn the
  pool, init_nccl, load_dense_shard.
- chat_completion branches on the handle variant. The TP path mirrors
  run_inference: clear_kv_cache, prefill, sample, decode loop,
  detokenize. Acquires the pool Mutex for the whole request.
- Streaming through TP is deferred to Stage 7c (returns Other(err)).
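
The prefill, sample, decode sequence named above can be sketched independently of candle. Assumptions: `step` is a hypothetical closure standing in for `WorkerPool::generate_step` (tokens and KV-cache offset in, last-position logits out), sampling is plain greedy argmax rather than whatever sampler the harness configures, and the KV-cache reset and detokenization are left to the caller.

```rust
/// Greedy sampling: index of the largest logit, or None for an empty slice.
pub fn argmax(logits: &[f32]) -> Option<u32> {
    let mut best: Option<(usize, f32)> = None;
    for (i, &v) in logits.iter().enumerate() {
        let better = match best {
            Some((_, b)) => v > b,
            None => true,
        };
        if better {
            best = Some((i, v));
        }
    }
    best.map(|(i, _)| i as u32)
}

/// Prefill with the whole prompt at offset 0, then decode one token at a
/// time, advancing the offset by the number of tokens already cached.
pub fn generate<F>(prompt: &[u32], max_new_tokens: usize, eos: u32, mut step: F) -> Result<Vec<u32>, String>
where
    F: FnMut(&[u32], usize) -> Result<Vec<f32>, String>,
{
    let mut out = Vec::new();
    if prompt.is_empty() {
        return Err("empty prompt".to_string());
    }
    if max_new_tokens == 0 {
        return Ok(out);
    }
    // Prefill: one step over the full prompt.
    let mut logits = step(prompt, 0)?;
    let mut offset = prompt.len();
    loop {
        let next = argmax(&logits).ok_or_else(|| "empty logits".to_string())?;
        if next == eos {
            break;
        }
        out.push(next);
        if out.len() >= max_new_tokens {
            break;
        }
        // Decode: feed only the newly sampled token.
        logits = step(&[next], offset)?;
        offset += 1;
    }
    Ok(out)
}
```

With a toy step that always predicts `(last_token + 1) % 5`, a prompt of `[0, 1]` and `eos = 4` yields `[2, 3]`, and the step is called at offsets 0, 2 and 3.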

Script (script/validate-neuron.sh):
- 4th positional arg `tp_size` (default 1). When >1, switches to the
  dense path (TP and GGUF are mutually exclusive, so the script bails
  if both are requested) and adds `tensor_parallel` + `devices` to the
  load payload. The NEURON_DEVICES env var overrides the default
  0..N-1 device list.
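
The device-list rule above, restated in Rust rather than shell for consistency with the rest of the code. The function name is hypothetical, and two details are assumptions the commit message does not state: that the override is a comma-separated list of indices, and that its length must equal `tp_size`.

```rust
/// Resolve the device list for a TP load: 0..tp_size by default, or an
/// explicit override. The comma-separated format and the length check
/// are assumptions made for this sketch.
pub fn resolve_devices(tp_size: usize, override_csv: Option<&str>) -> Result<Vec<usize>, String> {
    let devices: Vec<usize> = match override_csv {
        None => (0..tp_size).collect(),
        Some(csv) => {
            let mut parsed = Vec::new();
            for part in csv.split(',') {
                let idx = part
                    .trim()
                    .parse::<usize>()
                    .map_err(|e| format!("bad device index {:?}: {}", part, e))?;
                parsed.push(idx);
            }
            parsed
        }
    };
    if devices.len() != tp_size {
        return Err(format!(
            "expected {} devices, got {}",
            tp_size,
            devices.len()
        ));
    }
    Ok(devices)
}
```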
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -338,6 +338,241 @@ impl WorkerPool {
        Ok(out)
    }

    /// Load this rank's shard of a dense Qwen3 model on every rank.
    ///
    /// The leader builds rank 0's `TpQwen3ForCausalLM` directly into
    /// the returned `Arc<Mutex<_>>` — workers build their rank-local
    /// shards in their own address spaces and confirm via
    /// `LoadDenseShardOk`. All ranks see the same `safetensors_paths`;
    /// `ShardedVarBuilder` slices each tensor by rank at materialisation
    /// time, so the per-rank VRAM footprint is roughly `1/world_size`
    /// of the full model (plus the replicated embedding/norm/lm_head).
    ///
    /// `leader_device` is the candle `Device` the leader's shard lives
    /// on — typically `Device::new_cuda(leader_cuda_device)` matching
    /// the same index passed to `init_nccl`. `dtype` is the on-device
    /// element type; bf16 is the canonical Qwen3 distribution dtype.
    ///
    /// `init_nccl` must have completed first. Bails if the leader's
    /// NCCL comm isn't set up yet.
    #[cfg(feature = "cuda")]
    pub async fn load_dense_shard(
        &mut self,
        model_id: &str,
        config_json: &str,
        safetensors_paths: &[std::path::PathBuf],
        leader_device: &candle_core::Device,
        dtype: candle_core::DType,
    ) -> Result<std::sync::Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>> {
        use candle_nn::var_builder::ShardedSafeTensors;
        use std::sync::Arc;
        use tokio::sync::Mutex;

        // Wrap the comm in SendComm immediately so it stays Send across
        // the await points in this method — bare Arc<Comm> would
        // poison the async fn's Send bound (Comm's raw NCCL pointer is
        // !Send). The wrapper's safety contract is satisfied by the
        // pool's outer Mutex serialising callers + the spawn_blocking
        // thread being the only place ops are issued.
        let leader_comm =
            nccl_state::SendComm(self.leader_nccl.comm().ok_or_else(|| {
                anyhow::anyhow!("leader NCCL not initialised; call init_nccl first")
            })?);
        let world_size = self.world_size;
        let safetensors_str: Vec<String> = safetensors_paths
            .iter()
            .map(|p| p.to_string_lossy().into_owned())
            .collect();

        // 1. Fan out the LoadDenseShard request to every worker without
        //    awaiting their replies — they'll build their shards in
        //    parallel with the leader below.
        for w in &mut self.workers {
            w.send_only(&WorkerRequest::LoadDenseShard {
                model_id: model_id.to_string(),
                config_json: config_json.to_string(),
                safetensors_paths: safetensors_str.clone(),
            })
            .await?;
        }

        // 2. Build rank 0's shard on the leader. ShardedVarBuilder reads
        //    only the rank's slice from safetensors — no full-tensor
        //    materialisation. Runs in spawn_blocking because the
        //    file-mmap + slice + copy-to-device work is synchronous.
        let cfg: super::tp::tp_qwen3::Config =
            serde_json::from_str(config_json).context("parse Qwen3 Config JSON for leader load")?;
        let paths_for_leader: Vec<std::path::PathBuf> = safetensors_paths.to_vec();
        let device_for_leader = leader_device.clone();
        let comm_for_leader = leader_comm;
        let model_id_for_log = model_id.to_string();
        let leader_model = tokio::task::spawn_blocking(
            move || -> Result<super::tp::tp_qwen3::TpQwen3ForCausalLM> {
                // SAFETY: same invariant as the single-GPU dense path —
                // the HF cache files are treated as immutable while the
                // mmap is held.
                let vb = unsafe {
                    ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
                        .context("build ShardedVarBuilder over safetensors")?
                };
                let model = super::tp::tp_qwen3::TpQwen3ForCausalLM::load(
                    &cfg,
                    &vb,
                    0,
                    world_size,
                    comm_for_leader.into_inner(),
                )?;
                tracing::info!(rank = 0, model = %model_id_for_log, "loaded TP shard (leader)");
                Ok(model)
            },
        )
        .await
        .context("leader load task panicked")??;

        // 3. Collect worker confirmations. Anything other than
        //    LoadDenseShardOk aborts the whole load — the leader's
        //    already-loaded shard drops when this fn returns Err.
        for w in &mut self.workers {
            let resp = w.recv_only().await?;
            match resp {
                WorkerResponse::LoadDenseShardOk => {}
                WorkerResponse::Error { kind, message } => {
                    anyhow::bail!("worker rank {} LoadDenseShard [{kind}]: {message}", w.rank)
                }
                other => anyhow::bail!(
                    "worker rank {} LoadDenseShard: expected LoadDenseShardOk, got {other:?}",
                    w.rank
                ),
            }
        }

        Ok(Arc::new(Mutex::new(leader_model)))
    }

    /// Run one forward step across every rank. The leader's forward
    /// returns the last-position logits as a candle Tensor on the
    /// leader's device; the caller does sampling out-of-band. Workers
    /// run their own forwards (the AllReduce inside row-parallel layers
    /// is what lets the leader's collective complete) and reply with
    /// `GenerateStepOk` — they do not ship logits over the wire.
    ///
    /// `tokens` is the input for this step (prompt for prefill, the
    /// previously-sampled token for decode). `offset` is the KV-cache
    /// position before this step.
    #[cfg(feature = "cuda")]
    pub async fn generate_step(
        &mut self,
        model_id: &str,
        leader_model: std::sync::Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>,
        tokens: Vec<u32>,
        offset: usize,
    ) -> Result<candle_core::Tensor> {
        // 1. Fan-out to workers.
        for w in &mut self.workers {
            w.send_only(&WorkerRequest::GenerateStep {
                model_id: model_id.to_string(),
                tokens: tokens.clone(),
                offset,
            })
            .await?;
        }

        // 2. Leader's forward in spawn_blocking. The AllReduce CustomOps
        //    inside the row-parallel layers block until every worker's
        //    forward issues the matching collective.
        let logits = tokio::task::spawn_blocking(move || -> Result<candle_core::Tensor> {
            let mut model = leader_model.blocking_lock();
            let device = model.device().clone();
            let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
            // TpQwen3ForCausalLM::forward returns [B, 1, V] (it slices
            // to the last position internally). Squeeze both leading
            // dims to get the rank-1 vocab logits LogitsProcessor wants.
            let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?;
            Ok(logits)
        })
        .await
        .context("leader forward task panicked")??;

        // 3. Collect worker confirmations.
        for w in &mut self.workers {
            let resp = w.recv_only().await?;
            match resp {
                WorkerResponse::GenerateStepOk => {}
                WorkerResponse::Error { kind, message } => {
                    anyhow::bail!("worker rank {} GenerateStep [{kind}]: {message}", w.rank)
                }
                other => anyhow::bail!(
                    "worker rank {} GenerateStep: expected GenerateStepOk, got {other:?}",
                    w.rank
                ),
            }
        }
        Ok(logits)
    }

    /// Reset the KV cache for `model_id` on every rank. Called at the
    /// start of every inference so a fresh request doesn't attend over
    /// the previous one's tokens.
    pub async fn clear_kv_cache(
        &mut self,
        model_id: &str,
        #[cfg(feature = "cuda")] leader_model: std::sync::Arc<
            tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>,
        >,
    ) -> Result<()> {
        for w in &mut self.workers {
            w.send_only(&WorkerRequest::ClearKvCache {
                model_id: model_id.to_string(),
            })
            .await?;
        }
        #[cfg(feature = "cuda")]
        {
            let mut m = leader_model.lock().await;
            m.clear_kv_cache();
        }
        for w in &mut self.workers {
            let resp = w.recv_only().await?;
            match resp {
                WorkerResponse::KvCacheCleared => {}
                WorkerResponse::Error { kind, message } => {
                    anyhow::bail!("worker rank {} ClearKvCache [{kind}]: {message}", w.rank)
                }
                other => anyhow::bail!(
                    "worker rank {} ClearKvCache: expected KvCacheCleared, got {other:?}",
                    w.rank
                ),
            }
        }
        Ok(())
    }

    /// Drop this model's shards on every rank. The leader's shard is
    /// expected to have been dropped by the caller (its `Arc` was held
    /// in the TpLoadedModel and goes away when that's removed).
    pub async fn unload_model(&mut self, model_id: &str) -> Result<()> {
        for w in &mut self.workers {
            w.send_only(&WorkerRequest::UnloadModel {
                model_id: model_id.to_string(),
            })
            .await?;
        }
        for w in &mut self.workers {
            let resp = w.recv_only().await?;
            match resp {
                WorkerResponse::Unloaded => {}
                WorkerResponse::Error { kind, message } => {
                    anyhow::bail!("worker rank {} UnloadModel [{kind}]: {message}", w.rank)
                }
                other => anyhow::bail!(
                    "worker rank {} UnloadModel: expected Unloaded, got {other:?}",
                    w.rank
                ),
            }
        }
        Ok(())
    }

    /// Send `Shutdown` to every worker, await each `Bye`, and reap the
    /// children. Best-effort — individual worker failures are logged
    /// but don't abort the rest of the sweep.