From d46d8d4f6ca000abf1ee9b05672f3667589f4562 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Wed, 20 May 2026 06:38:33 +0300 Subject: [PATCH] =?UTF-8?q?feat(tp):=20Stage=207b-iv=20=E2=80=94=20RPC=20+?= =?UTF-8?q?=20orchestration=20for=20TP=20load/inference?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wires the in-flight TP machinery (Stage 7a workers, 7b-iii sharded Qwen3) end to end so a non-streaming chat completion can run across multiple GPUs via NCCL. RPC additions (tp/rpc.rs): - LoadDenseShard{model_id, config_json, safetensors_paths} - GenerateStep{model_id, tokens, offset} - ClearKvCache{model_id} - UnloadModel{model_id} - LoadDenseShardOk / GenerateStepOk / KvCacheCleared / Unloaded Worker side (tp/worker.rs): - WorkerState gains a `models: HashMap` keyed by model_id. LoadDenseShard mmaps safetensors via ShardedVarBuilder (only this rank's slice materialises), builds the TP model with the rank's NCCL Comm cloned from NcclState. - GenerateStep runs the rank-local forward; the resulting logits are dropped (only the leader's are used for sampling). The forward's value here is the NCCL collectives inside the row-parallel layers letting the leader's rank-0 forward make progress. Pool side (tp/mod.rs): - WorkerPool::load_dense_shard fans LoadDenseShard out to every worker, builds rank 0's shard on the leader via spawn_blocking with a fresh SendComm wrapper at the move boundary (Comm is !Send at the type level), collects per-rank LoadDenseShardOk. Returns the leader's Arc>. - WorkerPool::generate_step fans GenerateStep out, runs the leader's rank-0 forward in spawn_blocking (the AllReduce CustomOps inside row-parallel layers block until every worker issues the matching collective), returns the leader's last-position logits Tensor. - WorkerPool::clear_kv_cache + unload_model follow the same pattern. NcclState refactor (tp/nccl_state.rs): - comm field becomes Option> (was Option) so callers can share a clone with TpQwen3ForCausalLM::load. - new `comm()` accessor + `SendComm` wrapper for spawn_blocking moves. - single allow(clippy::arc_with_non_send_sync) at the canonical construction site (Comm is !Send by type but the runtime invariant is enforced by SendComm + the pool's Mutex). Harness side (candle.rs): - LoadedHandle enum (Single | Tp) replaces the bare Arc in the harness's registry. list_models / unload_model / inference_endpoint walk the enum uniformly. - TpLoadedModel holds the pool + leader_model + tokenizer + devices. - load_model dispatches on `spec.tensor_parallel > 1` to a new cuda-gated load_tp path: resolve dense files via hf-hub, spawn the pool, init_nccl, load_dense_shard. - chat_completion branches on the handle variant. The TP path mirrors run_inference: clear_kv_cache, prefill, sample, decode loop, detokenize. Acquires the pool Mutex for the whole request. - Streaming through TP is deferred to Stage 7c (returns Other(err)). Script (script/validate-neuron.sh): - 4th positional arg `tp_size` (default 1). When >1, switches to the dense path (tp + GGUF is mutually exclusive — bails) and adds `tensor_parallel` + `devices` to the load payload. NEURON_DEVICES env overrides the default 0..N-1 device list. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/neuron/src/harness/candle.rs | 372 +++++++++++++++++++-- crates/neuron/src/harness/tp/mod.rs | 235 +++++++++++++ crates/neuron/src/harness/tp/nccl_state.rs | 58 +++- crates/neuron/src/harness/tp/rpc.rs | 64 ++++ crates/neuron/src/harness/tp/worker.rs | 225 ++++++++++++- script/validate-neuron.sh | 46 ++- 6 files changed, 960 insertions(+), 40 deletions(-) diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 778dc86..cc1ffd1 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -31,11 +31,44 @@ use tokio::sync::{Mutex, RwLock, mpsc}; /// In-process candle harness. Owns the loaded model registry. pub struct CandleHarness { - models: Arc>>>, + models: Arc>>, hf_cache: Option, bind_url: String, } +/// One entry in the harness's loaded-model registry. Single-GPU loads +/// land in `Single`; loads with `tensor_parallel > 1` land in `Tp`. +/// The two variants share the same `model_id` key in the map, so +/// `list_models`, `unload_model`, and `inference_endpoint` can walk +/// them uniformly without branching the storage layout. +/// +/// `Clone` is cheap: both variants hold `Arc<_>` and cloning just bumps +/// the refcount. +#[derive(Clone)] +pub enum LoadedHandle { + Single(Arc), + #[cfg(feature = "cuda")] + Tp(Arc), +} + +impl LoadedHandle { + pub fn model_id(&self) -> &str { + match self { + LoadedHandle::Single(m) => &m.model_id, + #[cfg(feature = "cuda")] + LoadedHandle::Tp(m) => &m.model_id, + } + } + + pub fn devices(&self) -> Vec { + match self { + LoadedHandle::Single(m) => m.devices.clone(), + #[cfg(feature = "cuda")] + LoadedHandle::Tp(m) => m.devices.clone(), + } + } +} + /// A loaded model with its tokenizer, device placement, and architecture- /// specific weights. The `arch` is `Arc>` so the lock guard can be /// moved into `spawn_blocking` for synchronous candle forward passes. @@ -48,6 +81,25 @@ pub struct LoadedModel { pub devices: Vec, } +/// Tensor-parallel loaded model. Holds the leader's rank-0 shard +/// (which the inference loop drives via spawn_blocking) and the +/// `WorkerPool` (which drives every non-zero rank over the RPC +/// channel). Both are behind tokio Mutexes so concurrent inference +/// requests against the same model are serialised; concurrent loads +/// for *different* models would each have their own pool. +#[cfg(feature = "cuda")] +pub struct TpLoadedModel { + pub model_id: String, + pub tokenizer: Tokenizer, + pub devices: Vec, + /// One end-to-end gate: the pool's RPC stream isn't safe to use + /// concurrently and the leader shard's KV cache mutates with every + /// step. The same Mutex covers both for the simplest correctness + /// story. + pub pool: tokio::sync::Mutex, + pub leader_model: Arc>, +} + /// Architecture-specific weights. /// /// - `Qwen3Quantized` — GGUF source, pre-quantized. Single-GPU only; @@ -357,11 +409,22 @@ impl CandleHarness { &self, request: ChatCompletionRequest, ) -> Result { - let loaded = { + let handle = { let models = self.models.read().await; models.get(&request.model).cloned() }; - let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?; + let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?; + // The match is technically infallible without `cuda` (only Single + // exists), but the cfg-gated Tp arm makes this the right shape + // under both feature flags. + #[allow(clippy::infallible_destructuring_match)] + let loaded = match handle { + LoadedHandle::Single(m) => m, + #[cfg(feature = "cuda")] + LoadedHandle::Tp(m) => { + return self.chat_completion_tp(m, request).await; + } + }; let prompt = format_qwen3_prompt(&request.messages); @@ -451,11 +514,29 @@ impl CandleHarness { &self, request: ChatCompletionRequest, ) -> Result, InferenceError> { - let loaded = { + let handle = { let models = self.models.read().await; models.get(&request.model).cloned() }; - let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?; + let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?; + // The match is technically infallible without `cuda` (only Single + // exists), but the cfg-gated Tp arm makes this the right shape + // under both feature flags. + #[allow(clippy::infallible_destructuring_match)] + let loaded = match handle { + LoadedHandle::Single(m) => m, + #[cfg(feature = "cuda")] + LoadedHandle::Tp(_) => { + // Streaming through TP is Stage 7c work — the + // non-streaming path drives the same forwards through + // the pool but doesn't have to interleave SSE writes + // with spawn_blocking forwards. + return Err(InferenceError::Other(anyhow::anyhow!( + "streaming chat completions through TP are not yet supported; \ + retry with stream=false" + ))); + } + }; let prompt = format_qwen3_prompt(&request.messages); let encoding = loaded @@ -552,11 +633,11 @@ impl Harness for CandleHarness { let models = self.models.read().await; Ok(models .values() - .map(|m| ModelInfo { - id: m.model_id.clone(), + .map(|h| ModelInfo { + id: h.model_id().into(), harness: "candle".into(), status: "loaded".into(), - devices: m.devices.clone(), + devices: h.devices(), vram_used_mb: None, }) .collect()) @@ -574,19 +655,20 @@ impl Harness for CandleHarness { } } - // Stage 7a-i scaffolds tensor-parallel worker subprocesses but - // does not yet route inference through them. Refuse TP loads - // for now with a clear marker so the request surface is honest; - // Stage 7b-iv replaces this bail with the TP dispatch. let tp_size = spec.tensor_parallel.unwrap_or(1); if tp_size > 1 { - anyhow::bail!( - "tensor_parallel={tp_size} requested for '{}': TP worker \ - lifecycle + NCCL handshake are in place (Stage 7a) but \ - TP-aware Qwen3 inference orchestration lands in Stage \ - 7b-iv; single-GPU loads only for now", - spec.model_id - ); + #[cfg(feature = "cuda")] + { + return self.load_tp(spec, tp_size).await; + } + #[cfg(not(feature = "cuda"))] + { + anyhow::bail!( + "tensor_parallel={tp_size} requested for '{}': this neuron \ + binary was built without --features cuda; TP requires CUDA + NCCL", + spec.model_id + ); + } } let devices = spec.devices.clone().unwrap_or_else(|| vec![0]); @@ -615,15 +697,52 @@ impl Harness for CandleHarness { }); let mut models = self.models.write().await; - models.insert(spec.model_id.clone(), loaded); + models.insert(spec.model_id.clone(), LoadedHandle::Single(loaded)); tracing::info!(model = %spec.model_id, "model loaded"); Ok(()) } async fn unload_model(&self, model_id: &str) -> Result<()> { - let mut models = self.models.write().await; - if models.remove(model_id).is_none() { + let removed = { + let mut models = self.models.write().await; + models.remove(model_id) + }; + let Some(handle) = removed else { anyhow::bail!("model '{model_id}' not loaded"); + }; + // Single-GPU drops are immediate — the LoadedModel goes out of + // scope with the Arc and candle frees VRAM. TP unloads also + // need to tell every worker to drop its shard before the pool + // itself is dropped (otherwise the workers keep their shards + // around until Shutdown, which is wasteful and would surface + // as VRAM not freed promptly). + match handle { + LoadedHandle::Single(_) => {} + #[cfg(feature = "cuda")] + LoadedHandle::Tp(tp) => { + // Try to recover the inner TpLoadedModel so we can move + // the pool and shut it down. If anyone else still holds + // a clone of the Arc (shouldn't happen — the only owners + // are the registry and any in-flight chat_completion), + // bail with a clear marker rather than silently leaking. + let tp = match Arc::try_unwrap(tp) { + Ok(t) => t, + Err(arc) => { + // Reinsert so we don't leave the registry in an + // inconsistent state. + let mut models = self.models.write().await; + models.insert(model_id.into(), LoadedHandle::Tp(arc)); + anyhow::bail!("cannot unload '{model_id}': inference still in flight"); + } + }; + let mut pool = tp.pool.into_inner(); + if let Err(e) = pool.unload_model(model_id).await { + tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed"); + } + if let Err(e) = pool.shutdown().await { + tracing::warn!(model = %model_id, error = %e, "TP pool shutdown failed"); + } + } } tracing::info!(model = %model_id, "model unloaded"); Ok(()) @@ -635,6 +754,215 @@ impl Harness for CandleHarness { } } +impl CandleHarness { + /// Tensor-parallel load. Resolves dense safetensors via hf-hub the + /// same way the single-GPU dense path does, spins up a TP worker + /// pool sized to `tp_size`, runs the NCCL handshake, then has + /// every rank load its shard of the model. + /// + /// `spec.devices` carries the per-rank CUDA device indices (one + /// entry per rank, in rank order); defaults to `0..tp_size`. + #[cfg(feature = "cuda")] + async fn load_tp(&self, spec: &ModelSpec, tp_size: u32) -> Result<()> { + use std::sync::Arc as StdArc; + use tokio::sync::Mutex as TMutex; + + // Default per-rank device assignment: 0, 1, ..., tp_size - 1. + let devices = spec + .devices + .clone() + .unwrap_or_else(|| (0..tp_size).collect()); + if devices.len() as u32 != tp_size { + anyhow::bail!( + "tensor_parallel={tp_size} requires {tp_size} entries in devices, got {}", + devices.len() + ); + } + if spec.quant.is_some() { + anyhow::bail!( + "tensor_parallel={tp_size} with quant={:?}: GGUF quantized models \ + are not supported in the TP path; use a dense safetensors source", + spec.quant + ); + } + + // 1. Resolve config + tokenizer + safetensors via hf-hub. + let (config_path, tokenizer_path, safetensors_paths) = + self.resolve_dense_files(spec).await?; + let config_json = std::fs::read_to_string(&config_path).context("read config.json")?; + + // 2. Spawn the worker pool. Rank 0 stays in-process; ranks + // 1..tp_size are subprocesses, one per device after the + // leader's own. + let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?; + let mut pool = super::tp::WorkerPool::spawn(&exe, tp_size, &devices).await?; + + // 3. NCCL handshake across all ranks. + let leader_device_idx = devices[0]; + pool.init_nccl(leader_device_idx).await?; + + // 4. Pick the leader's candle Device (same index as init_nccl). + let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize) + .context("Device::new_cuda for TP leader")?; + + // 5. Load this rank's shard on every rank. + let leader_model = pool + .load_dense_shard( + &spec.model_id, + &config_json, + &safetensors_paths, + &leader_device, + candle_core::DType::BF16, + ) + .await?; + + // 6. Tokenizer (same as single-GPU path). + let tokenizer = Tokenizer::from_file(&tokenizer_path) + .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?; + + let tp_loaded = StdArc::new(TpLoadedModel { + model_id: spec.model_id.clone(), + tokenizer, + devices: devices.clone(), + pool: TMutex::new(pool), + leader_model, + }); + + let mut models = self.models.write().await; + models.insert(spec.model_id.clone(), LoadedHandle::Tp(tp_loaded)); + tracing::info!( + model = %spec.model_id, + tp_size, + ?devices, + "TP model loaded" + ); + Ok(()) + } + + /// Non-streaming chat completion against a TP model. Pattern mirrors + /// the single-GPU `run_inference`: tokenize, prefill, sample, decode + /// loop, detokenize. Each forward step fans out to every rank via + /// the WorkerPool and uses the leader's last-position logits to + /// sample. + #[cfg(feature = "cuda")] + async fn chat_completion_tp( + &self, + tp: Arc, + request: ChatCompletionRequest, + ) -> Result { + let prompt = format_qwen3_prompt(&request.messages); + let encoding = tp + .tokenizer + .encode(prompt.as_str(), true) + .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?; + let prompt_tokens: Vec = encoding.get_ids().to_vec(); + let prompt_len = prompt_tokens.len(); + + let temperature = request.temperature.unwrap_or(0.7); + let top_p = request.top_p; + let max_new = request.max_tokens.unwrap_or(512) as usize; + let seed = unix_subsec_nanos(); + + let eos_id = tp + .tokenizer + .token_to_id("<|im_end|>") + .or_else(|| tp.tokenizer.token_to_id("<|endoftext|>")); + + let model_id = request.model.clone(); + + // Acquire the pool lock for the duration of the request. The + // leader_model's own Mutex is acquired step-by-step inside + // pool.generate_step (so spawn_blocking can grab it without + // holding the pool lock across the blocking_lock call). + let mut pool = tp.pool.lock().await; + let leader_arc = tp.leader_model.clone(); + + // Reset every rank's KV cache so this request doesn't attend + // over the previous request's tokens. + pool.clear_kv_cache(&model_id, leader_arc.clone()) + .await + .map_err(InferenceError::Other)?; + + let mut logits_processor = { + let sampling = if temperature <= 0.0 { + Sampling::ArgMax + } else { + match top_p { + Some(p) => Sampling::TopP { p, temperature }, + None => Sampling::All { temperature }, + } + }; + LogitsProcessor::from_sampling(seed, sampling) + }; + + let mut generated: Vec = Vec::new(); + let mut finish_reason = "length".to_string(); + + // Prefill: every rank embeds the whole prompt, offset = 0. + let logits = pool + .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0) + .await + .map_err(InferenceError::Other)?; + let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor) + .map_err(InferenceError::Other)?; + + if Some(next_token) == eos_id { + finish_reason = "stop".into(); + } else { + generated.push(next_token); + for index in 0..max_new.saturating_sub(1) { + let logits = pool + .generate_step( + &model_id, + leader_arc.clone(), + vec![next_token], + prompt_len + index, + ) + .await + .map_err(InferenceError::Other)?; + next_token = sample_with_penalty(&logits, &generated, &mut logits_processor) + .map_err(InferenceError::Other)?; + if Some(next_token) == eos_id { + finish_reason = "stop".into(); + break; + } + generated.push(next_token); + } + } + drop(pool); + + let completion_text = tp + .tokenizer + .decode(&generated, true) + .map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?; + + let usage = Usage { + prompt_tokens: prompt_len as u64, + completion_tokens: generated.len() as u64, + total_tokens: (prompt_len + generated.len()) as u64, + }; + + Ok(ChatCompletionResponse { + id: format!("chatcmpl-{:x}", unix_subsec_nanos()), + object: "chat.completion".into(), + created: unix_now_secs(), + model: model_id, + choices: vec![ChatCompletionChoice { + index: 0, + message: ChatMessage { + role: "assistant".into(), + content: MessageContent::Text(completion_text), + extra: serde_json::Value::Object(Default::default()), + }, + finish_reason: Some(finish_reason), + extra: serde_json::Value::Object(Default::default()), + }], + usage: Some(usage), + extra: serde_json::Value::Object(Default::default()), + }) + } +} + /// Errors returned by `CandleHarness::chat_completion`. The /// `ModelNotLoaded` variant lets the HTTP handler map cleanly to 404 /// without string-matching on anyhow messages. diff --git a/crates/neuron/src/harness/tp/mod.rs b/crates/neuron/src/harness/tp/mod.rs index e497dd7..451904c 100644 --- a/crates/neuron/src/harness/tp/mod.rs +++ b/crates/neuron/src/harness/tp/mod.rs @@ -338,6 +338,241 @@ impl WorkerPool { Ok(out) } + /// Load this rank's shard of a dense Qwen3 model on every rank. + /// + /// The leader builds rank 0's `TpQwen3ForCausalLM` directly into + /// the returned `Arc>` — workers build their rank-local + /// shards in their own address spaces and confirm via + /// `LoadDenseShardOk`. All ranks see the same `safetensors_paths`; + /// `ShardedVarBuilder` slices each tensor by rank at materialisation + /// time, so the per-rank VRAM footprint is roughly `1/world_size` + /// of the full model (plus the replicated embedding/norm/lm_head). + /// + /// `leader_device` is the candle `Device` the leader's shard lives + /// on — typically `Device::new_cuda(leader_cuda_device)` matching + /// the same index passed to `init_nccl`. `dtype` is the on-device + /// element type; bf16 is the canonical Qwen3 distribution dtype. + /// + /// `init_nccl` must have completed first. Bails if the leader's + /// NCCL comm isn't set up yet. + #[cfg(feature = "cuda")] + pub async fn load_dense_shard( + &mut self, + model_id: &str, + config_json: &str, + safetensors_paths: &[std::path::PathBuf], + leader_device: &candle_core::Device, + dtype: candle_core::DType, + ) -> Result>> { + use candle_nn::var_builder::ShardedSafeTensors; + use std::sync::Arc; + use tokio::sync::Mutex; + + // Wrap the comm in SendComm immediately so it stays Send across + // the await points in this method — bare Arc would + // poison the async fn's Send bound (Comm's raw NCCL pointer is + // !Send). The wrapper's safety contract is satisfied by the + // pool's outer Mutex serialising callers + the spawn_blocking + // thread being the only place ops are issued. + let leader_comm = + nccl_state::SendComm(self.leader_nccl.comm().ok_or_else(|| { + anyhow::anyhow!("leader NCCL not initialised; call init_nccl first") + })?); + let world_size = self.world_size; + let safetensors_str: Vec = safetensors_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + + // 1. Fan out the LoadDenseShard request to every worker without + // awaiting their replies — they'll build their shards in + // parallel with the leader below. + for w in &mut self.workers { + w.send_only(&WorkerRequest::LoadDenseShard { + model_id: model_id.to_string(), + config_json: config_json.to_string(), + safetensors_paths: safetensors_str.clone(), + }) + .await?; + } + + // 2. Build rank 0's shard on the leader. ShardedVarBuilder reads + // only the rank's slice from safetensors — no full-tensor + // materialisation. Runs in spawn_blocking because the + // file-mmap + slice + copy-to-device work is synchronous. + let cfg: super::tp::tp_qwen3::Config = + serde_json::from_str(config_json).context("parse Qwen3 Config JSON for leader load")?; + let paths_for_leader: Vec = safetensors_paths.to_vec(); + let device_for_leader = leader_device.clone(); + let comm_for_leader = leader_comm; + let model_id_for_log = model_id.to_string(); + let leader_model = tokio::task::spawn_blocking( + move || -> Result { + // SAFETY: same invariant as the single-GPU dense path — + // the HF cache files are treated as immutable while the + // mmap is held. + let vb = unsafe { + ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader) + .context("build ShardedVarBuilder over safetensors")? + }; + let model = super::tp::tp_qwen3::TpQwen3ForCausalLM::load( + &cfg, + &vb, + 0, + world_size, + comm_for_leader.into_inner(), + )?; + tracing::info!(rank = 0, model = %model_id_for_log, "loaded TP shard (leader)"); + Ok(model) + }, + ) + .await + .context("leader load task panicked")??; + + // 3. Collect worker confirmations. Anything other than + // LoadDenseShardOk aborts the whole load — the leader's + // already-loaded shard drops when this fn returns Err. + for w in &mut self.workers { + let resp = w.recv_only().await?; + match resp { + WorkerResponse::LoadDenseShardOk => {} + WorkerResponse::Error { kind, message } => { + anyhow::bail!("worker rank {} LoadDenseShard [{kind}]: {message}", w.rank) + } + other => anyhow::bail!( + "worker rank {} LoadDenseShard: expected LoadDenseShardOk, got {other:?}", + w.rank + ), + } + } + + Ok(Arc::new(Mutex::new(leader_model))) + } + + /// Run one forward step across every rank. The leader's forward + /// returns the last-position logits as a candle Tensor on the + /// leader's device; the caller does sampling out-of-band. Workers + /// run their own forwards (the AllReduce inside row-parallel layers + /// is what lets the leader's collective complete) and reply with + /// `GenerateStepOk` — they do not ship logits over the wire. + /// + /// `tokens` is the input for this step (prompt for prefill, the + /// previously-sampled token for decode). `offset` is the KV-cache + /// position before this step. + #[cfg(feature = "cuda")] + pub async fn generate_step( + &mut self, + model_id: &str, + leader_model: std::sync::Arc>, + tokens: Vec, + offset: usize, + ) -> Result { + // 1. Fan-out to workers. + for w in &mut self.workers { + w.send_only(&WorkerRequest::GenerateStep { + model_id: model_id.to_string(), + tokens: tokens.clone(), + offset, + }) + .await?; + } + + // 2. Leader's forward in spawn_blocking. The AllReduce CustomOps + // inside the row-parallel layers block until every worker's + // forward issues the matching collective. + let logits = tokio::task::spawn_blocking(move || -> Result { + let mut model = leader_model.blocking_lock(); + let device = model.device().clone(); + let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; + // TpQwen3ForCausalLM::forward returns [B, 1, V] (it slices + // to the last position internally). Squeeze both leading + // dims to get the rank-1 vocab logits LogitsProcessor wants. + let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?; + Ok(logits) + }) + .await + .context("leader forward task panicked")??; + + // 3. Collect worker confirmations. + for w in &mut self.workers { + let resp = w.recv_only().await?; + match resp { + WorkerResponse::GenerateStepOk => {} + WorkerResponse::Error { kind, message } => { + anyhow::bail!("worker rank {} GenerateStep [{kind}]: {message}", w.rank) + } + other => anyhow::bail!( + "worker rank {} GenerateStep: expected GenerateStepOk, got {other:?}", + w.rank + ), + } + } + Ok(logits) + } + + /// Reset the KV cache for `model_id` on every rank. Called at the + /// start of every inference so a fresh request doesn't attend over + /// the previous one's tokens. + pub async fn clear_kv_cache( + &mut self, + model_id: &str, + #[cfg(feature = "cuda")] leader_model: std::sync::Arc< + tokio::sync::Mutex, + >, + ) -> Result<()> { + for w in &mut self.workers { + w.send_only(&WorkerRequest::ClearKvCache { + model_id: model_id.to_string(), + }) + .await?; + } + #[cfg(feature = "cuda")] + { + let mut m = leader_model.lock().await; + m.clear_kv_cache(); + } + for w in &mut self.workers { + let resp = w.recv_only().await?; + match resp { + WorkerResponse::KvCacheCleared => {} + WorkerResponse::Error { kind, message } => { + anyhow::bail!("worker rank {} ClearKvCache [{kind}]: {message}", w.rank) + } + other => anyhow::bail!( + "worker rank {} ClearKvCache: expected KvCacheCleared, got {other:?}", + w.rank + ), + } + } + Ok(()) + } + + /// Drop this model's shards on every rank. The leader's shard is + /// expected to have been dropped by the caller (its `Arc` was held + /// in the TpLoadedModel and goes away when that's removed). + pub async fn unload_model(&mut self, model_id: &str) -> Result<()> { + for w in &mut self.workers { + w.send_only(&WorkerRequest::UnloadModel { + model_id: model_id.to_string(), + }) + .await?; + } + for w in &mut self.workers { + let resp = w.recv_only().await?; + match resp { + WorkerResponse::Unloaded => {} + WorkerResponse::Error { kind, message } => { + anyhow::bail!("worker rank {} UnloadModel [{kind}]: {message}", w.rank) + } + other => anyhow::bail!( + "worker rank {} UnloadModel: expected Unloaded, got {other:?}", + w.rank + ), + } + } + Ok(()) + } + /// Send `Shutdown` to every worker, await each `Bye`, and reap the /// children. Best-effort — individual worker failures are logged /// but don't abort the rest of the sweep. diff --git a/crates/neuron/src/harness/tp/nccl_state.rs b/crates/neuron/src/harness/tp/nccl_state.rs index 7638599..402e218 100644 --- a/crates/neuron/src/harness/tp/nccl_state.rs +++ b/crates/neuron/src/harness/tp/nccl_state.rs @@ -83,7 +83,13 @@ mod cuda_impl { const NCCL_ID_BYTES: usize = 128; pub struct NcclState { - comm: Option, + /// Wrapped in `Arc` so we can hand a clone to `TpQwen3ForCausalLM` + /// at load time (every row-parallel layer needs a reference to + /// run its trailing `AllReduce`). The `Arc` is the single source + /// of truth for the comm's lifetime — when the pool drops and + /// every layer that captured a clone drops, NCCL releases the + /// underlying `ncclComm_t`. + comm: Option>, /// Held alongside the Comm so the device isn't dropped /// underneath the NCCL handle. #[allow(dead_code)] @@ -103,6 +109,40 @@ mod cuda_impl { ctx: None, } } + + /// Clone the comm out as an `Arc` so callers (the leader-side + /// `TpQwen3ForCausalLM::load`, or the worker's own model load) + /// can hold a reference for the lifetime of the model. Returns + /// `None` before `init` has run. + pub fn comm(&self) -> Option> { + self.comm.clone() + } + } + + /// `Arc` doesn't impl `Send` because `Comm` wraps a raw + /// `ncclComm_t` pointer. The NCCL contract is "operations against a + /// given comm must be serialised", not "the handle must stay on the + /// thread that created it" — so it's safe to move an `Arc` + /// across threads as long as no concurrent ops are issued. The + /// pool's outer Mutex serialises us into `spawn_blocking`, so this + /// wrapper at the move boundary is the only thing missing. + /// + /// `Sync` is also marked safe because the `Arc` clones held + /// by the row-parallel layers are only used from the + /// `spawn_blocking` thread driving the forward pass; concurrent + /// access from another thread would still be a bug. + pub struct SendComm(pub Arc); + + // SAFETY: see the doc-comment above; the invariant is enforced at + // the call site (pool Mutex + single spawn_blocking thread), not at + // the type level. + unsafe impl Send for SendComm {} + unsafe impl Sync for SendComm {} + + impl SendComm { + pub fn into_inner(self) -> Arc { + self.0 + } } // SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer @@ -143,7 +183,7 @@ mod cuda_impl { message: "sanity_check requires Init to have completed first".into(), }; }; - match try_sanity_check(comm) { + match try_sanity_check(comm.as_ref()) { Ok(sum) => WorkerResponse::NcclSanityResult { observed_sum: sum }, Err(msg) => WorkerResponse::Error { kind: "nccl_sanity_failed".into(), @@ -177,7 +217,17 @@ mod cuda_impl { })?; state.ctx = Some(ctx); - state.comm = Some(comm); + // `Comm` is !Send + !Sync at the type level because it wraps a + // raw `ncclComm_t`. The `Arc` is fine in practice — we + // serialise operations through the pool's outer Mutex and the + // SendComm wrapper at thread-crossing boundaries enforces this + // at every move site. clippy's `arc_with_non_send_sync` lint + // can't see that invariant; allow once at the canonical + // construction site. + #[allow(clippy::arc_with_non_send_sync)] + { + state.comm = Some(Arc::new(comm)); + } Ok(()) } @@ -202,7 +252,7 @@ mod cuda_impl { } #[cfg(feature = "cuda")] -pub use cuda_impl::{NcclState, generate_comm_id_hex}; +pub use cuda_impl::{NcclState, SendComm, generate_comm_id_hex}; /// Non-cuda stub for the leader: returns a clear marker error rather /// than letting `init_nccl` succeed vacuously. diff --git a/crates/neuron/src/harness/tp/rpc.rs b/crates/neuron/src/harness/tp/rpc.rs index e9fd047..2a444a6 100644 --- a/crates/neuron/src/harness/tp/rpc.rs +++ b/crates/neuron/src/harness/tp/rpc.rs @@ -45,6 +45,52 @@ pub enum WorkerRequest { /// the NCCL handshake is genuinely live, not just configured. NcclSanityCheck, + /// Load this rank's shard of a dense Qwen3 model from mmaped + /// safetensors. The same `safetensors_paths` list is sent to every + /// rank — the ShardedVarBuilder reads only the rank-local slice of + /// each tensor at materialisation time, so the worker's VRAM + /// footprint is `1 / world_size` of the full model (plus replicated + /// embedding/norm/lm_head). + LoadDenseShard { + /// Caller-supplied id for later `GenerateStep` / `UnloadModel` + /// lookups. Typically the HF model id verbatim. + model_id: String, + /// JSON-serialised `candle_transformers::models::qwen3::Config` + /// — the same blob the leader parsed from the HF cache's + /// `config.json`. Threaded through verbatim so the worker uses + /// identical hyperparameters. + config_json: String, + /// Absolute paths the worker should mmap. The same set on every + /// rank; ShardedVarBuilder slices into them per rank. + safetensors_paths: Vec, + }, + + /// Run one forward step on this rank's loaded model. The worker + /// reaches into its NCCL Comm for the row-parallel `AllReduce`s + /// inside the model — and so blocks on every other rank issuing the + /// same op. The leader does *not* receive logits back over RPC; it + /// runs its own rank-0 forward in parallel and uses its own logits + /// for sampling. + GenerateStep { + model_id: String, + /// Input token ids for this step. For prefill, the whole prompt; + /// for decode, a single token. Identical on every rank. + tokens: Vec, + /// KV cache offset (count of tokens already in the cache before + /// this step). + offset: usize, + }, + + /// Reset the KV cache for this model on this rank. Sent at the + /// start of every inference so a fresh request doesn't accidentally + /// attend over the previous one's tokens. + ClearKvCache { model_id: String }, + + /// Drop this rank's shard for the given model. Releases the VRAM + /// the shard's weights occupied; subsequent `GenerateStep` calls + /// against the same `model_id` return an `Error`. + UnloadModel { model_id: String }, + /// Worker should release resources and exit. Worker replies `Bye` /// and then closes stdout / exits zero. The leader reaps the /// child via the `tokio::process::Child` it kept. @@ -74,6 +120,24 @@ pub enum WorkerResponse { /// this matches `world_size`. NcclSanityResult { observed_sum: u32 }, + /// Reply to `LoadDenseShard`. Empty payload — success is the + /// absence of `Error`. By the time this comes back, the rank's + /// `TpQwen3ForCausalLM` is constructed in memory and ready for + /// `GenerateStep`. + LoadDenseShardOk, + + /// Reply to `GenerateStep`. Empty payload — workers don't ship + /// logits over the wire. The leader uses its own rank-0 logits; + /// workers only need to confirm the collective completed. + GenerateStepOk, + + /// Reply to `ClearKvCache`. Empty payload. + KvCacheCleared, + + /// Reply to `UnloadModel`. Empty payload. The named model is no + /// longer present on this rank. + Unloaded, + /// Reply to `Shutdown`. Worker exits immediately after writing this. Bye, diff --git a/crates/neuron/src/harness/tp/worker.rs b/crates/neuron/src/harness/tp/worker.rs index f33995a..d7d0ca7 100644 --- a/crates/neuron/src/harness/tp/worker.rs +++ b/crates/neuron/src/harness/tp/worker.rs @@ -5,18 +5,23 @@ //! exactly one `WorkerResponse` JSON line to stdout. tracing goes to //! stderr so it doesn't collide with the RPC stream. //! -//! NCCL operations (`Init`, `NcclSanityCheck`) are real when built -//! with the `cuda` feature; without it they reply with -//! `Error{kind="cuda_feature_not_enabled"}` so the leader can tell -//! the difference between a misconfigured build and a genuine NCCL -//! failure. +//! NCCL operations (`Init`, `NcclSanityCheck`) and model lifecycle ops +//! (`LoadDenseShard`, `GenerateStep`, `ClearKvCache`, `UnloadModel`) +//! are real when built with the `cuda` feature; without it they reply +//! with `Error{kind="cuda_feature_not_enabled"}` so the leader can tell +//! the difference between a misconfigured build and a genuine NCCL or +//! model failure. use anyhow::Result; +use std::collections::HashMap; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use super::nccl_state::NcclState; use super::rpc::{WorkerRequest, WorkerResponse}; +#[cfg(feature = "cuda")] +use super::tp_qwen3::TpQwen3ForCausalLM; + #[derive(Debug, Clone, Copy)] pub struct WorkerConfig { pub rank: u32, @@ -74,9 +79,22 @@ async fn write_response(stdout: &mut tokio::io::Stdout, resp: &WorkerResponse) - Ok(()) } +/// One rank's local state. Owns the rank's NCCL communicator (via +/// `NcclState`) and the rank's shard of every loaded model. struct WorkerState { config: WorkerConfig, nccl: NcclState, + /// Loaded model shards keyed by `model_id`. Each entry holds this + /// rank's `TpQwen3ForCausalLM` — the column/row-parallel layers + /// hold an `Arc` cloned from `nccl`. Cuda-only: there is no + /// TpQwen3ForCausalLM type without the cuda feature in scope. + #[cfg(feature = "cuda")] + models: HashMap, + /// Placeholder so the non-cuda build keeps the same field name set + /// and `WorkerState::new` reads the same on both. + #[cfg(not(feature = "cuda"))] + #[allow(dead_code)] + models: HashMap, } impl WorkerState { @@ -84,6 +102,7 @@ impl WorkerState { Self { config, nccl: NcclState::new(), + models: HashMap::new(), } } @@ -96,7 +115,203 @@ impl WorkerState { }, WorkerRequest::Init { comm_id } => self.nccl.init(self.config, &comm_id), WorkerRequest::NcclSanityCheck => self.nccl.sanity_check(), + WorkerRequest::LoadDenseShard { + model_id, + config_json, + safetensors_paths, + } => self.handle_load_dense_shard(model_id, config_json, safetensors_paths), + WorkerRequest::GenerateStep { + model_id, + tokens, + offset, + } => self.handle_generate_step(&model_id, tokens, offset), + WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id), + WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id), WorkerRequest::Shutdown => WorkerResponse::Bye, } } + + #[cfg(feature = "cuda")] + fn handle_load_dense_shard( + &mut self, + model_id: String, + config_json: String, + safetensors_paths: Vec, + ) -> WorkerResponse { + use candle_core::{DType, Device}; + use candle_nn::var_builder::ShardedSafeTensors; + use candle_transformers::models::qwen3 as qwen3_dense; + use std::path::PathBuf; + + if self.models.contains_key(&model_id) { + return WorkerResponse::Error { + kind: "already_loaded".into(), + message: format!("model '{model_id}' already loaded on this rank"), + }; + } + let comm = match self.nccl.comm() { + Some(c) => c, + None => { + return WorkerResponse::Error { + kind: "nccl_not_initialised".into(), + message: "LoadDenseShard requires Init to have completed first".into(), + }; + } + }; + + let cfg: qwen3_dense::Config = match serde_json::from_str(&config_json) { + Ok(c) => c, + Err(e) => { + return WorkerResponse::Error { + kind: "bad_request".into(), + message: format!("parse Qwen3 Config JSON: {e}"), + }; + } + }; + + let device = match Device::new_cuda(self.config.cuda_device as usize) { + Ok(d) => d, + Err(e) => { + return WorkerResponse::Error { + kind: "cuda_unavailable".into(), + message: format!("Device::new_cuda({}) failed: {e}", self.config.cuda_device), + }; + } + }; + + let paths: Vec = safetensors_paths.into_iter().map(PathBuf::from).collect(); + // SAFETY: same invariant as the single-GPU dense path — the HF + // cache files are treated as immutable while the mmap is held. + let vb = match unsafe { ShardedSafeTensors::var_builder(&paths, DType::BF16, &device) } { + Ok(v) => v, + Err(e) => { + return WorkerResponse::Error { + kind: "load_failed".into(), + message: format!("ShardedSafeTensors::var_builder: {e}"), + }; + } + }; + let model = match TpQwen3ForCausalLM::load( + &cfg, + &vb, + self.config.rank, + self.config.world_size, + comm, + ) { + Ok(m) => m, + Err(e) => { + return WorkerResponse::Error { + kind: "load_failed".into(), + message: format!("TpQwen3ForCausalLM::load: {e:#}"), + }; + } + }; + + self.models.insert(model_id.clone(), model); + tracing::info!(rank = self.config.rank, model = %model_id, "loaded TP shard"); + WorkerResponse::LoadDenseShardOk + } + + #[cfg(not(feature = "cuda"))] + fn handle_load_dense_shard( + &mut self, + _model_id: String, + _config_json: String, + _safetensors_paths: Vec, + ) -> WorkerResponse { + WorkerResponse::Error { + kind: "cuda_feature_not_enabled".into(), + message: "LoadDenseShard requires --features cuda".into(), + } + } + + #[cfg(feature = "cuda")] + fn handle_generate_step( + &mut self, + model_id: &str, + tokens: Vec, + offset: usize, + ) -> WorkerResponse { + use candle_core::Tensor; + + let Some(model) = self.models.get_mut(model_id) else { + return WorkerResponse::Error { + kind: "model_not_loaded".into(), + message: format!("model '{model_id}' not loaded on rank {}", self.config.rank), + }; + }; + let device = model.device().clone(); + let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) { + Ok(t) => t, + Err(e) => { + return WorkerResponse::Error { + kind: "forward_failed".into(), + message: format!("build input tensor: {e}"), + }; + } + }; + // Drop the resulting logits — the leader uses its own copy from + // rank 0. The forward's value here is the NCCL collectives it + // issues, which let the leader's rank-0 forward make progress. + if let Err(e) = model.forward(&input, offset) { + return WorkerResponse::Error { + kind: "forward_failed".into(), + message: format!("TpQwen3ForCausalLM::forward: {e}"), + }; + } + WorkerResponse::GenerateStepOk + } + + #[cfg(not(feature = "cuda"))] + fn handle_generate_step( + &mut self, + _model_id: &str, + _tokens: Vec, + _offset: usize, + ) -> WorkerResponse { + WorkerResponse::Error { + kind: "cuda_feature_not_enabled".into(), + message: "GenerateStep requires --features cuda".into(), + } + } + + #[cfg(feature = "cuda")] + fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse { + let Some(model) = self.models.get_mut(model_id) else { + return WorkerResponse::Error { + kind: "model_not_loaded".into(), + message: format!("model '{model_id}' not loaded on rank {}", self.config.rank), + }; + }; + model.clear_kv_cache(); + WorkerResponse::KvCacheCleared + } + + #[cfg(not(feature = "cuda"))] + fn handle_clear_kv_cache(&mut self, _model_id: &str) -> WorkerResponse { + WorkerResponse::Error { + kind: "cuda_feature_not_enabled".into(), + message: "ClearKvCache requires --features cuda".into(), + } + } + + #[cfg(feature = "cuda")] + fn handle_unload_model(&mut self, model_id: &str) -> WorkerResponse { + if self.models.remove(model_id).is_none() { + return WorkerResponse::Error { + kind: "model_not_loaded".into(), + message: format!("model '{model_id}' not loaded on rank {}", self.config.rank), + }; + } + tracing::info!(rank = self.config.rank, model = %model_id, "unloaded TP shard"); + WorkerResponse::Unloaded + } + + #[cfg(not(feature = "cuda"))] + fn handle_unload_model(&mut self, _model_id: &str) -> WorkerResponse { + WorkerResponse::Error { + kind: "cuda_feature_not_enabled".into(), + message: "UnloadModel requires --features cuda".into(), + } + } } diff --git a/script/validate-neuron.sh b/script/validate-neuron.sh index fae0cf4..ceacd0c 100755 --- a/script/validate-neuron.sh +++ b/script/validate-neuron.sh @@ -9,14 +9,15 @@ # after pushing new neuron builds. # # Usage: -# script/validate-neuron.sh [host] [model_id] [quant] +# script/validate-neuron.sh [host] [model_id] [quant] [tp_size] # # Defaults: # host = beast.hanzalova.internal # model_id = unsloth/Qwen3-0.6B-GGUF (official Qwen3-*-GGUF repos # ship Q8_0 only; unsloth's mirror ships the full Q-spectrum # including Q4_K_M) -# quant = Q4_K_M +# quant = Q4_K_M (empty = dense safetensors path) +# tp_size = unset (= 1 = single-GPU; pass 2 to drive the TP path) set -euo pipefail @@ -25,6 +26,11 @@ MODEL_ID="${2:-unsloth/Qwen3-0.6B-GGUF}" # `${3-Q4_K_M}` (no colon) only uses the default when the arg is # UNSET — passing an explicit empty string drives the dense path. QUANT="${3-Q4_K_M}" +# tp_size > 1 forces the dense path (TP requires safetensors) and adds +# `tensor_parallel: N` to the load payload. The harness picks device +# indices 0..N-1 by default; override by passing NEURON_DEVICES="0,1,..." +# in the environment. +TP_SIZE="${4-1}" PORT="${NEURON_PORT:-13131}" BASE="http://${HOST}:${PORT}" @@ -69,21 +75,43 @@ is_loaded() { } trigger_load() { - say "POST /models/load ${MODEL_ID} (quant=${QUANT:-}, device=[0])" + # Build the per-rank CUDA device list as a JSON array. Either + # honour NEURON_DEVICES (`0,1,2`) verbatim or default to + # `[0, 1, ..., tp_size - 1]`. + local devices_json + if [[ -n "${NEURON_DEVICES:-}" ]]; then + devices_json=$(jq -n -c --arg s "${NEURON_DEVICES}" \ + '$s | split(",") | map(tonumber)') + else + devices_json=$(jq -n -c --argjson n "${TP_SIZE}" '[range(0; $n)]') + fi + say "POST /models/load ${MODEL_ID} (quant=${QUANT:-}, tp=${TP_SIZE}, devices=${devices_json})" say " (synchronous; may take a minute on first run while HF downloads)" - # Build the payload via jq so the optional `quant` field is - # omitted entirely when empty — that's the signal to the harness - # to take the dense safetensors load path rather than GGUF. + if (( TP_SIZE > 1 )) && [[ -n "${QUANT}" ]]; then + die "tp_size>1 requires dense safetensors — pass quant='' as the 3rd argument" + fi + # Build the payload via jq so the optional `quant` and + # `tensor_parallel` fields are omitted entirely when not in use — + # that's how the harness tells dense from quantized and single-GPU + # from TP. local payload - if [[ -z "${QUANT}" ]]; then + if [[ -z "${QUANT}" ]] && (( TP_SIZE > 1 )); then payload=$(jq -n -c \ --arg id "${MODEL_ID}" \ - '{model_id: $id, harness: "candle", devices: [0]}') + --argjson tp "${TP_SIZE}" \ + --argjson devices "${devices_json}" \ + '{model_id: $id, harness: "candle", tensor_parallel: $tp, devices: $devices}') + elif [[ -z "${QUANT}" ]]; then + payload=$(jq -n -c \ + --arg id "${MODEL_ID}" \ + --argjson devices "${devices_json}" \ + '{model_id: $id, harness: "candle", devices: $devices}') else payload=$(jq -n -c \ --arg id "${MODEL_ID}" \ --arg q "${QUANT}" \ - '{model_id: $id, harness: "candle", quant: $q, devices: [0]}') + --argjson devices "${devices_json}" \ + '{model_id: $id, harness: "candle", quant: $q, devices: $devices}') fi # --write-out captures the response code on a separate line so we # can surface a real diagnostic instead of relying on --fail.