feat(tp): Stage 7b-iv — RPC + orchestration for TP load/inference
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 38s
CI / Format (push) Successful in 40s
CI / Clippy (push) Successful in 2m20s
build-prerelease / Build cortex binary (push) Successful in 4m25s
build-prerelease / Package cortex RPM (push) Successful in 1m22s
CI / Test (push) Successful in 4m34s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m57s
build-prerelease / Build neuron-ampere (push) Successful in 4m51s
build-prerelease / Build neuron-ada (push) Successful in 5m12s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m49s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m51s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m0s
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 38s
CI / Format (push) Successful in 40s
CI / Clippy (push) Successful in 2m20s
build-prerelease / Build cortex binary (push) Successful in 4m25s
build-prerelease / Package cortex RPM (push) Successful in 1m22s
CI / Test (push) Successful in 4m34s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m57s
build-prerelease / Build neuron-ampere (push) Successful in 4m51s
build-prerelease / Build neuron-ada (push) Successful in 5m12s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m49s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m51s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m0s
Wires the in-flight TP machinery (Stage 7a workers, 7b-iii sharded
Qwen3) end to end so a non-streaming chat completion can run across
multiple GPUs via NCCL.
RPC additions (tp/rpc.rs):
- LoadDenseShard{model_id, config_json, safetensors_paths}
- GenerateStep{model_id, tokens, offset}
- ClearKvCache{model_id}
- UnloadModel{model_id}
- LoadDenseShardOk / GenerateStepOk / KvCacheCleared / Unloaded
Worker side (tp/worker.rs):
- WorkerState gains a `models: HashMap<String, TpQwen3ForCausalLM>`
keyed by model_id. LoadDenseShard mmaps safetensors via
ShardedVarBuilder (only this rank's slice materialises), builds the
TP model with the rank's NCCL Comm cloned from NcclState.
- GenerateStep runs the rank-local forward; the resulting logits are
dropped (only the leader's are used for sampling). The forward's
value here is the NCCL collectives inside the row-parallel layers
letting the leader's rank-0 forward make progress.
Pool side (tp/mod.rs):
- WorkerPool::load_dense_shard fans LoadDenseShard out to every worker,
builds rank 0's shard on the leader via spawn_blocking with a fresh
SendComm wrapper at the move boundary (Comm is !Send at the type
level), collects per-rank LoadDenseShardOk. Returns the leader's
Arc<Mutex<TpQwen3ForCausalLM>>.
- WorkerPool::generate_step fans GenerateStep out, runs the leader's
rank-0 forward in spawn_blocking (the AllReduce CustomOps inside
row-parallel layers block until every worker issues the matching
collective), returns the leader's last-position logits Tensor.
- WorkerPool::clear_kv_cache + unload_model follow the same pattern.
NcclState refactor (tp/nccl_state.rs):
- comm field becomes Option<Arc<Comm>> (was Option<Comm>) so callers
can share a clone with TpQwen3ForCausalLM::load.
- new `comm()` accessor + `SendComm` wrapper for spawn_blocking moves.
- single allow(clippy::arc_with_non_send_sync) at the canonical
construction site (Comm is !Send by type but the runtime invariant
is enforced by SendComm + the pool's Mutex).
Harness side (candle.rs):
- LoadedHandle enum (Single | Tp) replaces the bare Arc<LoadedModel>
in the harness's registry. list_models / unload_model /
inference_endpoint walk the enum uniformly.
- TpLoadedModel holds the pool + leader_model + tokenizer + devices.
- load_model dispatches on `spec.tensor_parallel > 1` to a new
cuda-gated load_tp path: resolve dense files via hf-hub, spawn the
pool, init_nccl, load_dense_shard.
- chat_completion branches on the handle variant. The TP path mirrors
run_inference: clear_kv_cache, prefill, sample, decode loop,
detokenize. Acquires the pool Mutex for the whole request.
- Streaming through TP is deferred to Stage 7c (returns Other(err)).
Script (script/validate-neuron.sh):
- 4th positional arg `tp_size` (default 1). When >1, switches to the
dense path (tp + GGUF is mutually exclusive — bails) and adds
`tensor_parallel` + `devices` to the load payload. NEURON_DEVICES
env overrides the default 0..N-1 device list.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -31,11 +31,44 @@ use tokio::sync::{Mutex, RwLock, mpsc};
|
||||
|
||||
/// In-process candle harness. Owns the loaded model registry.
|
||||
pub struct CandleHarness {
|
||||
models: Arc<RwLock<HashMap<String, Arc<LoadedModel>>>>,
|
||||
models: Arc<RwLock<HashMap<String, LoadedHandle>>>,
|
||||
hf_cache: Option<PathBuf>,
|
||||
bind_url: String,
|
||||
}
|
||||
|
||||
/// One entry in the harness's loaded-model registry. Single-GPU loads
|
||||
/// land in `Single`; loads with `tensor_parallel > 1` land in `Tp`.
|
||||
/// The two variants share the same `model_id` key in the map, so
|
||||
/// `list_models`, `unload_model`, and `inference_endpoint` can walk
|
||||
/// them uniformly without branching the storage layout.
|
||||
///
|
||||
/// `Clone` is cheap: both variants hold `Arc<_>` and cloning just bumps
|
||||
/// the refcount.
|
||||
#[derive(Clone)]
|
||||
pub enum LoadedHandle {
|
||||
Single(Arc<LoadedModel>),
|
||||
#[cfg(feature = "cuda")]
|
||||
Tp(Arc<TpLoadedModel>),
|
||||
}
|
||||
|
||||
impl LoadedHandle {
|
||||
pub fn model_id(&self) -> &str {
|
||||
match self {
|
||||
LoadedHandle::Single(m) => &m.model_id,
|
||||
#[cfg(feature = "cuda")]
|
||||
LoadedHandle::Tp(m) => &m.model_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn devices(&self) -> Vec<u32> {
|
||||
match self {
|
||||
LoadedHandle::Single(m) => m.devices.clone(),
|
||||
#[cfg(feature = "cuda")]
|
||||
LoadedHandle::Tp(m) => m.devices.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A loaded model with its tokenizer, device placement, and architecture-
|
||||
/// specific weights. The `arch` is `Arc<Mutex<>>` so the lock guard can be
|
||||
/// moved into `spawn_blocking` for synchronous candle forward passes.
|
||||
@@ -48,6 +81,25 @@ pub struct LoadedModel {
|
||||
pub devices: Vec<u32>,
|
||||
}
|
||||
|
||||
/// Tensor-parallel loaded model. Holds the leader's rank-0 shard
|
||||
/// (which the inference loop drives via spawn_blocking) and the
|
||||
/// `WorkerPool` (which drives every non-zero rank over the RPC
|
||||
/// channel). Both are behind tokio Mutexes so concurrent inference
|
||||
/// requests against the same model are serialised; concurrent loads
|
||||
/// for *different* models would each have their own pool.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub struct TpLoadedModel {
|
||||
pub model_id: String,
|
||||
pub tokenizer: Tokenizer,
|
||||
pub devices: Vec<u32>,
|
||||
/// One end-to-end gate: the pool's RPC stream isn't safe to use
|
||||
/// concurrently and the leader shard's KV cache mutates with every
|
||||
/// step. The same Mutex covers both for the simplest correctness
|
||||
/// story.
|
||||
pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
|
||||
pub leader_model: Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>,
|
||||
}
|
||||
|
||||
/// Architecture-specific weights.
|
||||
///
|
||||
/// - `Qwen3Quantized` — GGUF source, pre-quantized. Single-GPU only;
|
||||
@@ -357,11 +409,22 @@ impl CandleHarness {
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, InferenceError> {
|
||||
let loaded = {
|
||||
let handle = {
|
||||
let models = self.models.read().await;
|
||||
models.get(&request.model).cloned()
|
||||
};
|
||||
let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
||||
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
||||
// The match is technically infallible without `cuda` (only Single
|
||||
// exists), but the cfg-gated Tp arm makes this the right shape
|
||||
// under both feature flags.
|
||||
#[allow(clippy::infallible_destructuring_match)]
|
||||
let loaded = match handle {
|
||||
LoadedHandle::Single(m) => m,
|
||||
#[cfg(feature = "cuda")]
|
||||
LoadedHandle::Tp(m) => {
|
||||
return self.chat_completion_tp(m, request).await;
|
||||
}
|
||||
};
|
||||
|
||||
let prompt = format_qwen3_prompt(&request.messages);
|
||||
|
||||
@@ -451,11 +514,29 @@ impl CandleHarness {
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
|
||||
let loaded = {
|
||||
let handle = {
|
||||
let models = self.models.read().await;
|
||||
models.get(&request.model).cloned()
|
||||
};
|
||||
let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
||||
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
||||
// The match is technically infallible without `cuda` (only Single
|
||||
// exists), but the cfg-gated Tp arm makes this the right shape
|
||||
// under both feature flags.
|
||||
#[allow(clippy::infallible_destructuring_match)]
|
||||
let loaded = match handle {
|
||||
LoadedHandle::Single(m) => m,
|
||||
#[cfg(feature = "cuda")]
|
||||
LoadedHandle::Tp(_) => {
|
||||
// Streaming through TP is Stage 7c work — the
|
||||
// non-streaming path drives the same forwards through
|
||||
// the pool but doesn't have to interleave SSE writes
|
||||
// with spawn_blocking forwards.
|
||||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||||
"streaming chat completions through TP are not yet supported; \
|
||||
retry with stream=false"
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let prompt = format_qwen3_prompt(&request.messages);
|
||||
let encoding = loaded
|
||||
@@ -552,11 +633,11 @@ impl Harness for CandleHarness {
|
||||
let models = self.models.read().await;
|
||||
Ok(models
|
||||
.values()
|
||||
.map(|m| ModelInfo {
|
||||
id: m.model_id.clone(),
|
||||
.map(|h| ModelInfo {
|
||||
id: h.model_id().into(),
|
||||
harness: "candle".into(),
|
||||
status: "loaded".into(),
|
||||
devices: m.devices.clone(),
|
||||
devices: h.devices(),
|
||||
vram_used_mb: None,
|
||||
})
|
||||
.collect())
|
||||
@@ -574,20 +655,21 @@ impl Harness for CandleHarness {
|
||||
}
|
||||
}
|
||||
|
||||
// Stage 7a-i scaffolds tensor-parallel worker subprocesses but
|
||||
// does not yet route inference through them. Refuse TP loads
|
||||
// for now with a clear marker so the request surface is honest;
|
||||
// Stage 7b-iv replaces this bail with the TP dispatch.
|
||||
let tp_size = spec.tensor_parallel.unwrap_or(1);
|
||||
if tp_size > 1 {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
return self.load_tp(spec, tp_size).await;
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
anyhow::bail!(
|
||||
"tensor_parallel={tp_size} requested for '{}': TP worker \
|
||||
lifecycle + NCCL handshake are in place (Stage 7a) but \
|
||||
TP-aware Qwen3 inference orchestration lands in Stage \
|
||||
7b-iv; single-GPU loads only for now",
|
||||
"tensor_parallel={tp_size} requested for '{}': this neuron \
|
||||
binary was built without --features cuda; TP requires CUDA + NCCL",
|
||||
spec.model_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let devices = spec.devices.clone().unwrap_or_else(|| vec![0]);
|
||||
let device = Self::pick_device(&devices)?;
|
||||
@@ -615,15 +697,52 @@ impl Harness for CandleHarness {
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
models.insert(spec.model_id.clone(), loaded);
|
||||
models.insert(spec.model_id.clone(), LoadedHandle::Single(loaded));
|
||||
tracing::info!(model = %spec.model_id, "model loaded");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unload_model(&self, model_id: &str) -> Result<()> {
|
||||
let removed = {
|
||||
let mut models = self.models.write().await;
|
||||
if models.remove(model_id).is_none() {
|
||||
models.remove(model_id)
|
||||
};
|
||||
let Some(handle) = removed else {
|
||||
anyhow::bail!("model '{model_id}' not loaded");
|
||||
};
|
||||
// Single-GPU drops are immediate — the LoadedModel goes out of
|
||||
// scope with the Arc and candle frees VRAM. TP unloads also
|
||||
// need to tell every worker to drop its shard before the pool
|
||||
// itself is dropped (otherwise the workers keep their shards
|
||||
// around until Shutdown, which is wasteful and would surface
|
||||
// as VRAM not freed promptly).
|
||||
match handle {
|
||||
LoadedHandle::Single(_) => {}
|
||||
#[cfg(feature = "cuda")]
|
||||
LoadedHandle::Tp(tp) => {
|
||||
// Try to recover the inner TpLoadedModel so we can move
|
||||
// the pool and shut it down. If anyone else still holds
|
||||
// a clone of the Arc (shouldn't happen — the only owners
|
||||
// are the registry and any in-flight chat_completion),
|
||||
// bail with a clear marker rather than silently leaking.
|
||||
let tp = match Arc::try_unwrap(tp) {
|
||||
Ok(t) => t,
|
||||
Err(arc) => {
|
||||
// Reinsert so we don't leave the registry in an
|
||||
// inconsistent state.
|
||||
let mut models = self.models.write().await;
|
||||
models.insert(model_id.into(), LoadedHandle::Tp(arc));
|
||||
anyhow::bail!("cannot unload '{model_id}': inference still in flight");
|
||||
}
|
||||
};
|
||||
let mut pool = tp.pool.into_inner();
|
||||
if let Err(e) = pool.unload_model(model_id).await {
|
||||
tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed");
|
||||
}
|
||||
if let Err(e) = pool.shutdown().await {
|
||||
tracing::warn!(model = %model_id, error = %e, "TP pool shutdown failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::info!(model = %model_id, "model unloaded");
|
||||
Ok(())
|
||||
@@ -635,6 +754,215 @@ impl Harness for CandleHarness {
|
||||
}
|
||||
}
|
||||
|
||||
impl CandleHarness {
|
||||
/// Tensor-parallel load. Resolves dense safetensors via hf-hub the
|
||||
/// same way the single-GPU dense path does, spins up a TP worker
|
||||
/// pool sized to `tp_size`, runs the NCCL handshake, then has
|
||||
/// every rank load its shard of the model.
|
||||
///
|
||||
/// `spec.devices` carries the per-rank CUDA device indices (one
|
||||
/// entry per rank, in rank order); defaults to `0..tp_size`.
|
||||
#[cfg(feature = "cuda")]
|
||||
async fn load_tp(&self, spec: &ModelSpec, tp_size: u32) -> Result<()> {
|
||||
use std::sync::Arc as StdArc;
|
||||
use tokio::sync::Mutex as TMutex;
|
||||
|
||||
// Default per-rank device assignment: 0, 1, ..., tp_size - 1.
|
||||
let devices = spec
|
||||
.devices
|
||||
.clone()
|
||||
.unwrap_or_else(|| (0..tp_size).collect());
|
||||
if devices.len() as u32 != tp_size {
|
||||
anyhow::bail!(
|
||||
"tensor_parallel={tp_size} requires {tp_size} entries in devices, got {}",
|
||||
devices.len()
|
||||
);
|
||||
}
|
||||
if spec.quant.is_some() {
|
||||
anyhow::bail!(
|
||||
"tensor_parallel={tp_size} with quant={:?}: GGUF quantized models \
|
||||
are not supported in the TP path; use a dense safetensors source",
|
||||
spec.quant
|
||||
);
|
||||
}
|
||||
|
||||
// 1. Resolve config + tokenizer + safetensors via hf-hub.
|
||||
let (config_path, tokenizer_path, safetensors_paths) =
|
||||
self.resolve_dense_files(spec).await?;
|
||||
let config_json = std::fs::read_to_string(&config_path).context("read config.json")?;
|
||||
|
||||
// 2. Spawn the worker pool. Rank 0 stays in-process; ranks
|
||||
// 1..tp_size are subprocesses, one per device after the
|
||||
// leader's own.
|
||||
let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
|
||||
let mut pool = super::tp::WorkerPool::spawn(&exe, tp_size, &devices).await?;
|
||||
|
||||
// 3. NCCL handshake across all ranks.
|
||||
let leader_device_idx = devices[0];
|
||||
pool.init_nccl(leader_device_idx).await?;
|
||||
|
||||
// 4. Pick the leader's candle Device (same index as init_nccl).
|
||||
let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize)
|
||||
.context("Device::new_cuda for TP leader")?;
|
||||
|
||||
// 5. Load this rank's shard on every rank.
|
||||
let leader_model = pool
|
||||
.load_dense_shard(
|
||||
&spec.model_id,
|
||||
&config_json,
|
||||
&safetensors_paths,
|
||||
&leader_device,
|
||||
candle_core::DType::BF16,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// 6. Tokenizer (same as single-GPU path).
|
||||
let tokenizer = Tokenizer::from_file(&tokenizer_path)
|
||||
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
|
||||
|
||||
let tp_loaded = StdArc::new(TpLoadedModel {
|
||||
model_id: spec.model_id.clone(),
|
||||
tokenizer,
|
||||
devices: devices.clone(),
|
||||
pool: TMutex::new(pool),
|
||||
leader_model,
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
models.insert(spec.model_id.clone(), LoadedHandle::Tp(tp_loaded));
|
||||
tracing::info!(
|
||||
model = %spec.model_id,
|
||||
tp_size,
|
||||
?devices,
|
||||
"TP model loaded"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Non-streaming chat completion against a TP model. Pattern mirrors
|
||||
/// the single-GPU `run_inference`: tokenize, prefill, sample, decode
|
||||
/// loop, detokenize. Each forward step fans out to every rank via
|
||||
/// the WorkerPool and uses the leader's last-position logits to
|
||||
/// sample.
|
||||
#[cfg(feature = "cuda")]
|
||||
async fn chat_completion_tp(
|
||||
&self,
|
||||
tp: Arc<TpLoadedModel>,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, InferenceError> {
|
||||
let prompt = format_qwen3_prompt(&request.messages);
|
||||
let encoding = tp
|
||||
.tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
|
||||
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
|
||||
let prompt_len = prompt_tokens.len();
|
||||
|
||||
let temperature = request.temperature.unwrap_or(0.7);
|
||||
let top_p = request.top_p;
|
||||
let max_new = request.max_tokens.unwrap_or(512) as usize;
|
||||
let seed = unix_subsec_nanos();
|
||||
|
||||
let eos_id = tp
|
||||
.tokenizer
|
||||
.token_to_id("<|im_end|>")
|
||||
.or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
|
||||
|
||||
let model_id = request.model.clone();
|
||||
|
||||
// Acquire the pool lock for the duration of the request. The
|
||||
// leader_model's own Mutex is acquired step-by-step inside
|
||||
// pool.generate_step (so spawn_blocking can grab it without
|
||||
// holding the pool lock across the blocking_lock call).
|
||||
let mut pool = tp.pool.lock().await;
|
||||
let leader_arc = tp.leader_model.clone();
|
||||
|
||||
// Reset every rank's KV cache so this request doesn't attend
|
||||
// over the previous request's tokens.
|
||||
pool.clear_kv_cache(&model_id, leader_arc.clone())
|
||||
.await
|
||||
.map_err(InferenceError::Other)?;
|
||||
|
||||
let mut logits_processor = {
|
||||
let sampling = if temperature <= 0.0 {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match top_p {
|
||||
Some(p) => Sampling::TopP { p, temperature },
|
||||
None => Sampling::All { temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(seed, sampling)
|
||||
};
|
||||
|
||||
let mut generated: Vec<u32> = Vec::new();
|
||||
let mut finish_reason = "length".to_string();
|
||||
|
||||
// Prefill: every rank embeds the whole prompt, offset = 0.
|
||||
let logits = pool
|
||||
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
|
||||
.await
|
||||
.map_err(InferenceError::Other)?;
|
||||
let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
|
||||
.map_err(InferenceError::Other)?;
|
||||
|
||||
if Some(next_token) == eos_id {
|
||||
finish_reason = "stop".into();
|
||||
} else {
|
||||
generated.push(next_token);
|
||||
for index in 0..max_new.saturating_sub(1) {
|
||||
let logits = pool
|
||||
.generate_step(
|
||||
&model_id,
|
||||
leader_arc.clone(),
|
||||
vec![next_token],
|
||||
prompt_len + index,
|
||||
)
|
||||
.await
|
||||
.map_err(InferenceError::Other)?;
|
||||
next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
|
||||
.map_err(InferenceError::Other)?;
|
||||
if Some(next_token) == eos_id {
|
||||
finish_reason = "stop".into();
|
||||
break;
|
||||
}
|
||||
generated.push(next_token);
|
||||
}
|
||||
}
|
||||
drop(pool);
|
||||
|
||||
let completion_text = tp
|
||||
.tokenizer
|
||||
.decode(&generated, true)
|
||||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
|
||||
|
||||
let usage = Usage {
|
||||
prompt_tokens: prompt_len as u64,
|
||||
completion_tokens: generated.len() as u64,
|
||||
total_tokens: (prompt_len + generated.len()) as u64,
|
||||
};
|
||||
|
||||
Ok(ChatCompletionResponse {
|
||||
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
|
||||
object: "chat.completion".into(),
|
||||
created: unix_now_secs(),
|
||||
model: model_id,
|
||||
choices: vec![ChatCompletionChoice {
|
||||
index: 0,
|
||||
message: ChatMessage {
|
||||
role: "assistant".into(),
|
||||
content: MessageContent::Text(completion_text),
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
},
|
||||
finish_reason: Some(finish_reason),
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
}],
|
||||
usage: Some(usage),
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors returned by `CandleHarness::chat_completion`. The
|
||||
/// `ModelNotLoaded` variant lets the HTTP handler map cleanly to 404
|
||||
/// without string-matching on anyhow messages.
|
||||
|
||||
@@ -338,6 +338,241 @@ impl WorkerPool {
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Load this rank's shard of a dense Qwen3 model on every rank.
|
||||
///
|
||||
/// The leader builds rank 0's `TpQwen3ForCausalLM` directly into
|
||||
/// the returned `Arc<Mutex<_>>` — workers build their rank-local
|
||||
/// shards in their own address spaces and confirm via
|
||||
/// `LoadDenseShardOk`. All ranks see the same `safetensors_paths`;
|
||||
/// `ShardedVarBuilder` slices each tensor by rank at materialisation
|
||||
/// time, so the per-rank VRAM footprint is roughly `1/world_size`
|
||||
/// of the full model (plus the replicated embedding/norm/lm_head).
|
||||
///
|
||||
/// `leader_device` is the candle `Device` the leader's shard lives
|
||||
/// on — typically `Device::new_cuda(leader_cuda_device)` matching
|
||||
/// the same index passed to `init_nccl`. `dtype` is the on-device
|
||||
/// element type; bf16 is the canonical Qwen3 distribution dtype.
|
||||
///
|
||||
/// `init_nccl` must have completed first. Bails if the leader's
|
||||
/// NCCL comm isn't set up yet.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub async fn load_dense_shard(
|
||||
&mut self,
|
||||
model_id: &str,
|
||||
config_json: &str,
|
||||
safetensors_paths: &[std::path::PathBuf],
|
||||
leader_device: &candle_core::Device,
|
||||
dtype: candle_core::DType,
|
||||
) -> Result<std::sync::Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>> {
|
||||
use candle_nn::var_builder::ShardedSafeTensors;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
// Wrap the comm in SendComm immediately so it stays Send across
|
||||
// the await points in this method — bare Arc<Comm> would
|
||||
// poison the async fn's Send bound (Comm's raw NCCL pointer is
|
||||
// !Send). The wrapper's safety contract is satisfied by the
|
||||
// pool's outer Mutex serialising callers + the spawn_blocking
|
||||
// thread being the only place ops are issued.
|
||||
let leader_comm =
|
||||
nccl_state::SendComm(self.leader_nccl.comm().ok_or_else(|| {
|
||||
anyhow::anyhow!("leader NCCL not initialised; call init_nccl first")
|
||||
})?);
|
||||
let world_size = self.world_size;
|
||||
let safetensors_str: Vec<String> = safetensors_paths
|
||||
.iter()
|
||||
.map(|p| p.to_string_lossy().into_owned())
|
||||
.collect();
|
||||
|
||||
// 1. Fan out the LoadDenseShard request to every worker without
|
||||
// awaiting their replies — they'll build their shards in
|
||||
// parallel with the leader below.
|
||||
for w in &mut self.workers {
|
||||
w.send_only(&WorkerRequest::LoadDenseShard {
|
||||
model_id: model_id.to_string(),
|
||||
config_json: config_json.to_string(),
|
||||
safetensors_paths: safetensors_str.clone(),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
// 2. Build rank 0's shard on the leader. ShardedVarBuilder reads
|
||||
// only the rank's slice from safetensors — no full-tensor
|
||||
// materialisation. Runs in spawn_blocking because the
|
||||
// file-mmap + slice + copy-to-device work is synchronous.
|
||||
let cfg: super::tp::tp_qwen3::Config =
|
||||
serde_json::from_str(config_json).context("parse Qwen3 Config JSON for leader load")?;
|
||||
let paths_for_leader: Vec<std::path::PathBuf> = safetensors_paths.to_vec();
|
||||
let device_for_leader = leader_device.clone();
|
||||
let comm_for_leader = leader_comm;
|
||||
let model_id_for_log = model_id.to_string();
|
||||
let leader_model = tokio::task::spawn_blocking(
|
||||
move || -> Result<super::tp::tp_qwen3::TpQwen3ForCausalLM> {
|
||||
// SAFETY: same invariant as the single-GPU dense path —
|
||||
// the HF cache files are treated as immutable while the
|
||||
// mmap is held.
|
||||
let vb = unsafe {
|
||||
ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
|
||||
.context("build ShardedVarBuilder over safetensors")?
|
||||
};
|
||||
let model = super::tp::tp_qwen3::TpQwen3ForCausalLM::load(
|
||||
&cfg,
|
||||
&vb,
|
||||
0,
|
||||
world_size,
|
||||
comm_for_leader.into_inner(),
|
||||
)?;
|
||||
tracing::info!(rank = 0, model = %model_id_for_log, "loaded TP shard (leader)");
|
||||
Ok(model)
|
||||
},
|
||||
)
|
||||
.await
|
||||
.context("leader load task panicked")??;
|
||||
|
||||
// 3. Collect worker confirmations. Anything other than
|
||||
// LoadDenseShardOk aborts the whole load — the leader's
|
||||
// already-loaded shard drops when this fn returns Err.
|
||||
for w in &mut self.workers {
|
||||
let resp = w.recv_only().await?;
|
||||
match resp {
|
||||
WorkerResponse::LoadDenseShardOk => {}
|
||||
WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("worker rank {} LoadDenseShard [{kind}]: {message}", w.rank)
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"worker rank {} LoadDenseShard: expected LoadDenseShardOk, got {other:?}",
|
||||
w.rank
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Arc::new(Mutex::new(leader_model)))
|
||||
}
|
||||
|
||||
/// Run one forward step across every rank. The leader's forward
|
||||
/// returns the last-position logits as a candle Tensor on the
|
||||
/// leader's device; the caller does sampling out-of-band. Workers
|
||||
/// run their own forwards (the AllReduce inside row-parallel layers
|
||||
/// is what lets the leader's collective complete) and reply with
|
||||
/// `GenerateStepOk` — they do not ship logits over the wire.
|
||||
///
|
||||
/// `tokens` is the input for this step (prompt for prefill, the
|
||||
/// previously-sampled token for decode). `offset` is the KV-cache
|
||||
/// position before this step.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub async fn generate_step(
|
||||
&mut self,
|
||||
model_id: &str,
|
||||
leader_model: std::sync::Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
) -> Result<candle_core::Tensor> {
|
||||
// 1. Fan-out to workers.
|
||||
for w in &mut self.workers {
|
||||
w.send_only(&WorkerRequest::GenerateStep {
|
||||
model_id: model_id.to_string(),
|
||||
tokens: tokens.clone(),
|
||||
offset,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
// 2. Leader's forward in spawn_blocking. The AllReduce CustomOps
|
||||
// inside the row-parallel layers block until every worker's
|
||||
// forward issues the matching collective.
|
||||
let logits = tokio::task::spawn_blocking(move || -> Result<candle_core::Tensor> {
|
||||
let mut model = leader_model.blocking_lock();
|
||||
let device = model.device().clone();
|
||||
let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
// TpQwen3ForCausalLM::forward returns [B, 1, V] (it slices
|
||||
// to the last position internally). Squeeze both leading
|
||||
// dims to get the rank-1 vocab logits LogitsProcessor wants.
|
||||
let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?;
|
||||
Ok(logits)
|
||||
})
|
||||
.await
|
||||
.context("leader forward task panicked")??;
|
||||
|
||||
// 3. Collect worker confirmations.
|
||||
for w in &mut self.workers {
|
||||
let resp = w.recv_only().await?;
|
||||
match resp {
|
||||
WorkerResponse::GenerateStepOk => {}
|
||||
WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("worker rank {} GenerateStep [{kind}]: {message}", w.rank)
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"worker rank {} GenerateStep: expected GenerateStepOk, got {other:?}",
|
||||
w.rank
|
||||
),
|
||||
}
|
||||
}
|
||||
Ok(logits)
|
||||
}
|
||||
|
||||
/// Reset the KV cache for `model_id` on every rank. Called at the
|
||||
/// start of every inference so a fresh request doesn't attend over
|
||||
/// the previous one's tokens.
|
||||
pub async fn clear_kv_cache(
|
||||
&mut self,
|
||||
model_id: &str,
|
||||
#[cfg(feature = "cuda")] leader_model: std::sync::Arc<
|
||||
tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>,
|
||||
>,
|
||||
) -> Result<()> {
|
||||
for w in &mut self.workers {
|
||||
w.send_only(&WorkerRequest::ClearKvCache {
|
||||
model_id: model_id.to_string(),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
let mut m = leader_model.lock().await;
|
||||
m.clear_kv_cache();
|
||||
}
|
||||
for w in &mut self.workers {
|
||||
let resp = w.recv_only().await?;
|
||||
match resp {
|
||||
WorkerResponse::KvCacheCleared => {}
|
||||
WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("worker rank {} ClearKvCache [{kind}]: {message}", w.rank)
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"worker rank {} ClearKvCache: expected KvCacheCleared, got {other:?}",
|
||||
w.rank
|
||||
),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Drop this model's shards on every rank. The leader's shard is
|
||||
/// expected to have been dropped by the caller (its `Arc` was held
|
||||
/// in the TpLoadedModel and goes away when that's removed).
|
||||
pub async fn unload_model(&mut self, model_id: &str) -> Result<()> {
|
||||
for w in &mut self.workers {
|
||||
w.send_only(&WorkerRequest::UnloadModel {
|
||||
model_id: model_id.to_string(),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
for w in &mut self.workers {
|
||||
let resp = w.recv_only().await?;
|
||||
match resp {
|
||||
WorkerResponse::Unloaded => {}
|
||||
WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("worker rank {} UnloadModel [{kind}]: {message}", w.rank)
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"worker rank {} UnloadModel: expected Unloaded, got {other:?}",
|
||||
w.rank
|
||||
),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send `Shutdown` to every worker, await each `Bye`, and reap the
|
||||
/// children. Best-effort — individual worker failures are logged
|
||||
/// but don't abort the rest of the sweep.
|
||||
|
||||
@@ -83,7 +83,13 @@ mod cuda_impl {
|
||||
const NCCL_ID_BYTES: usize = 128;
|
||||
|
||||
pub struct NcclState {
|
||||
comm: Option<Comm>,
|
||||
/// Wrapped in `Arc` so we can hand a clone to `TpQwen3ForCausalLM`
|
||||
/// at load time (every row-parallel layer needs a reference to
|
||||
/// run its trailing `AllReduce`). The `Arc` is the single source
|
||||
/// of truth for the comm's lifetime — when the pool drops and
|
||||
/// every layer that captured a clone drops, NCCL releases the
|
||||
/// underlying `ncclComm_t`.
|
||||
comm: Option<Arc<Comm>>,
|
||||
/// Held alongside the Comm so the device isn't dropped
|
||||
/// underneath the NCCL handle.
|
||||
#[allow(dead_code)]
|
||||
@@ -103,6 +109,40 @@ mod cuda_impl {
|
||||
ctx: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Clone the comm out as an `Arc` so callers (the leader-side
|
||||
/// `TpQwen3ForCausalLM::load`, or the worker's own model load)
|
||||
/// can hold a reference for the lifetime of the model. Returns
|
||||
/// `None` before `init` has run.
|
||||
pub fn comm(&self) -> Option<Arc<Comm>> {
|
||||
self.comm.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// `Arc<Comm>` doesn't impl `Send` because `Comm` wraps a raw
|
||||
/// `ncclComm_t` pointer. The NCCL contract is "operations against a
|
||||
/// given comm must be serialised", not "the handle must stay on the
|
||||
/// thread that created it" — so it's safe to move an `Arc<Comm>`
|
||||
/// across threads as long as no concurrent ops are issued. The
|
||||
/// pool's outer Mutex serialises us into `spawn_blocking`, so this
|
||||
/// wrapper at the move boundary is the only thing missing.
|
||||
///
|
||||
/// `Sync` is also marked safe because the `Arc<Comm>` clones held
|
||||
/// by the row-parallel layers are only used from the
|
||||
/// `spawn_blocking` thread driving the forward pass; concurrent
|
||||
/// access from another thread would still be a bug.
|
||||
pub struct SendComm(pub Arc<Comm>);
|
||||
|
||||
// SAFETY: see the doc-comment above; the invariant is enforced at
|
||||
// the call site (pool Mutex + single spawn_blocking thread), not at
|
||||
// the type level.
|
||||
unsafe impl Send for SendComm {}
|
||||
unsafe impl Sync for SendComm {}
|
||||
|
||||
impl SendComm {
|
||||
pub fn into_inner(self) -> Arc<Comm> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer
|
||||
@@ -143,7 +183,7 @@ mod cuda_impl {
|
||||
message: "sanity_check requires Init to have completed first".into(),
|
||||
};
|
||||
};
|
||||
match try_sanity_check(comm) {
|
||||
match try_sanity_check(comm.as_ref()) {
|
||||
Ok(sum) => WorkerResponse::NcclSanityResult { observed_sum: sum },
|
||||
Err(msg) => WorkerResponse::Error {
|
||||
kind: "nccl_sanity_failed".into(),
|
||||
@@ -177,7 +217,17 @@ mod cuda_impl {
|
||||
})?;
|
||||
|
||||
state.ctx = Some(ctx);
|
||||
state.comm = Some(comm);
|
||||
// `Comm` is !Send + !Sync at the type level because it wraps a
|
||||
// raw `ncclComm_t`. The `Arc` is fine in practice — we
|
||||
// serialise operations through the pool's outer Mutex and the
|
||||
// SendComm wrapper at thread-crossing boundaries enforces this
|
||||
// at every move site. clippy's `arc_with_non_send_sync` lint
|
||||
// can't see that invariant; allow once at the canonical
|
||||
// construction site.
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
{
|
||||
state.comm = Some(Arc::new(comm));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -202,7 +252,7 @@ mod cuda_impl {
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
pub use cuda_impl::{NcclState, generate_comm_id_hex};
|
||||
pub use cuda_impl::{NcclState, SendComm, generate_comm_id_hex};
|
||||
|
||||
/// Non-cuda stub for the leader: returns a clear marker error rather
|
||||
/// than letting `init_nccl` succeed vacuously.
|
||||
|
||||
@@ -45,6 +45,52 @@ pub enum WorkerRequest {
|
||||
/// the NCCL handshake is genuinely live, not just configured.
|
||||
NcclSanityCheck,
|
||||
|
||||
/// Load this rank's shard of a dense Qwen3 model from mmaped
|
||||
/// safetensors. The same `safetensors_paths` list is sent to every
|
||||
/// rank — the ShardedVarBuilder reads only the rank-local slice of
|
||||
/// each tensor at materialisation time, so the worker's VRAM
|
||||
/// footprint is `1 / world_size` of the full model (plus replicated
|
||||
/// embedding/norm/lm_head).
|
||||
LoadDenseShard {
|
||||
/// Caller-supplied id for later `GenerateStep` / `UnloadModel`
|
||||
/// lookups. Typically the HF model id verbatim.
|
||||
model_id: String,
|
||||
/// JSON-serialised `candle_transformers::models::qwen3::Config`
|
||||
/// — the same blob the leader parsed from the HF cache's
|
||||
/// `config.json`. Threaded through verbatim so the worker uses
|
||||
/// identical hyperparameters.
|
||||
config_json: String,
|
||||
/// Absolute paths the worker should mmap. The same set on every
|
||||
/// rank; ShardedVarBuilder slices into them per rank.
|
||||
safetensors_paths: Vec<String>,
|
||||
},
|
||||
|
||||
/// Run one forward step on this rank's loaded model. The worker
|
||||
/// reaches into its NCCL Comm for the row-parallel `AllReduce`s
|
||||
/// inside the model — and so blocks on every other rank issuing the
|
||||
/// same op. The leader does *not* receive logits back over RPC; it
|
||||
/// runs its own rank-0 forward in parallel and uses its own logits
|
||||
/// for sampling.
|
||||
GenerateStep {
|
||||
model_id: String,
|
||||
/// Input token ids for this step. For prefill, the whole prompt;
|
||||
/// for decode, a single token. Identical on every rank.
|
||||
tokens: Vec<u32>,
|
||||
/// KV cache offset (count of tokens already in the cache before
|
||||
/// this step).
|
||||
offset: usize,
|
||||
},
|
||||
|
||||
/// Reset the KV cache for this model on this rank. Sent at the
|
||||
/// start of every inference so a fresh request doesn't accidentally
|
||||
/// attend over the previous one's tokens.
|
||||
ClearKvCache { model_id: String },
|
||||
|
||||
/// Drop this rank's shard for the given model. Releases the VRAM
|
||||
/// the shard's weights occupied; subsequent `GenerateStep` calls
|
||||
/// against the same `model_id` return an `Error`.
|
||||
UnloadModel { model_id: String },
|
||||
|
||||
/// Worker should release resources and exit. Worker replies `Bye`
|
||||
/// and then closes stdout / exits zero. The leader reaps the
|
||||
/// child via the `tokio::process::Child` it kept.
|
||||
@@ -74,6 +120,24 @@ pub enum WorkerResponse {
|
||||
/// this matches `world_size`.
|
||||
NcclSanityResult { observed_sum: u32 },
|
||||
|
||||
/// Reply to `LoadDenseShard`. Empty payload — success is the
|
||||
/// absence of `Error`. By the time this comes back, the rank's
|
||||
/// `TpQwen3ForCausalLM` is constructed in memory and ready for
|
||||
/// `GenerateStep`.
|
||||
LoadDenseShardOk,
|
||||
|
||||
/// Reply to `GenerateStep`. Empty payload — workers don't ship
|
||||
/// logits over the wire. The leader uses its own rank-0 logits;
|
||||
/// workers only need to confirm the collective completed.
|
||||
GenerateStepOk,
|
||||
|
||||
/// Reply to `ClearKvCache`. Empty payload.
|
||||
KvCacheCleared,
|
||||
|
||||
/// Reply to `UnloadModel`. Empty payload. The named model is no
|
||||
/// longer present on this rank.
|
||||
Unloaded,
|
||||
|
||||
/// Reply to `Shutdown`. Worker exits immediately after writing this.
|
||||
Bye,
|
||||
|
||||
|
||||
@@ -5,18 +5,23 @@
|
||||
//! exactly one `WorkerResponse` JSON line to stdout. tracing goes to
|
||||
//! stderr so it doesn't collide with the RPC stream.
|
||||
//!
|
||||
//! NCCL operations (`Init`, `NcclSanityCheck`) are real when built
|
||||
//! with the `cuda` feature; without it they reply with
|
||||
//! `Error{kind="cuda_feature_not_enabled"}` so the leader can tell
|
||||
//! the difference between a misconfigured build and a genuine NCCL
|
||||
//! failure.
|
||||
//! NCCL operations (`Init`, `NcclSanityCheck`) and model lifecycle ops
|
||||
//! (`LoadDenseShard`, `GenerateStep`, `ClearKvCache`, `UnloadModel`)
|
||||
//! are real when built with the `cuda` feature; without it they reply
|
||||
//! with `Error{kind="cuda_feature_not_enabled"}` so the leader can tell
|
||||
//! the difference between a misconfigured build and a genuine NCCL or
|
||||
//! model failure.
|
||||
|
||||
use anyhow::Result;
|
||||
use std::collections::HashMap;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
|
||||
use super::nccl_state::NcclState;
|
||||
use super::rpc::{WorkerRequest, WorkerResponse};
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use super::tp_qwen3::TpQwen3ForCausalLM;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct WorkerConfig {
|
||||
pub rank: u32,
|
||||
@@ -74,9 +79,22 @@ async fn write_response(stdout: &mut tokio::io::Stdout, resp: &WorkerResponse) -
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// One rank's local state. Owns the rank's NCCL communicator (via
|
||||
/// `NcclState`) and the rank's shard of every loaded model.
|
||||
struct WorkerState {
|
||||
config: WorkerConfig,
|
||||
nccl: NcclState,
|
||||
/// Loaded model shards keyed by `model_id`. Each entry holds this
|
||||
/// rank's `TpQwen3ForCausalLM` — the column/row-parallel layers
|
||||
/// hold an `Arc<Comm>` cloned from `nccl`. Cuda-only: there is no
|
||||
/// TpQwen3ForCausalLM type without the cuda feature in scope.
|
||||
#[cfg(feature = "cuda")]
|
||||
models: HashMap<String, TpQwen3ForCausalLM>,
|
||||
/// Placeholder so the non-cuda build keeps the same field name set
|
||||
/// and `WorkerState::new` reads the same on both.
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
#[allow(dead_code)]
|
||||
models: HashMap<String, ()>,
|
||||
}
|
||||
|
||||
impl WorkerState {
|
||||
@@ -84,6 +102,7 @@ impl WorkerState {
|
||||
Self {
|
||||
config,
|
||||
nccl: NcclState::new(),
|
||||
models: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,7 +115,203 @@ impl WorkerState {
|
||||
},
|
||||
WorkerRequest::Init { comm_id } => self.nccl.init(self.config, &comm_id),
|
||||
WorkerRequest::NcclSanityCheck => self.nccl.sanity_check(),
|
||||
WorkerRequest::LoadDenseShard {
|
||||
model_id,
|
||||
config_json,
|
||||
safetensors_paths,
|
||||
} => self.handle_load_dense_shard(model_id, config_json, safetensors_paths),
|
||||
WorkerRequest::GenerateStep {
|
||||
model_id,
|
||||
tokens,
|
||||
offset,
|
||||
} => self.handle_generate_step(&model_id, tokens, offset),
|
||||
WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
|
||||
WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
|
||||
WorkerRequest::Shutdown => WorkerResponse::Bye,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn handle_load_dense_shard(
|
||||
&mut self,
|
||||
model_id: String,
|
||||
config_json: String,
|
||||
safetensors_paths: Vec<String>,
|
||||
) -> WorkerResponse {
|
||||
use candle_core::{DType, Device};
|
||||
use candle_nn::var_builder::ShardedSafeTensors;
|
||||
use candle_transformers::models::qwen3 as qwen3_dense;
|
||||
use std::path::PathBuf;
|
||||
|
||||
if self.models.contains_key(&model_id) {
|
||||
return WorkerResponse::Error {
|
||||
kind: "already_loaded".into(),
|
||||
message: format!("model '{model_id}' already loaded on this rank"),
|
||||
};
|
||||
}
|
||||
let comm = match self.nccl.comm() {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "nccl_not_initialised".into(),
|
||||
message: "LoadDenseShard requires Init to have completed first".into(),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let cfg: qwen3_dense::Config = match serde_json::from_str(&config_json) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "bad_request".into(),
|
||||
message: format!("parse Qwen3 Config JSON: {e}"),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let device = match Device::new_cuda(self.config.cuda_device as usize) {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "cuda_unavailable".into(),
|
||||
message: format!("Device::new_cuda({}) failed: {e}", self.config.cuda_device),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let paths: Vec<PathBuf> = safetensors_paths.into_iter().map(PathBuf::from).collect();
|
||||
// SAFETY: same invariant as the single-GPU dense path — the HF
|
||||
// cache files are treated as immutable while the mmap is held.
|
||||
let vb = match unsafe { ShardedSafeTensors::var_builder(&paths, DType::BF16, &device) } {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "load_failed".into(),
|
||||
message: format!("ShardedSafeTensors::var_builder: {e}"),
|
||||
};
|
||||
}
|
||||
};
|
||||
let model = match TpQwen3ForCausalLM::load(
|
||||
&cfg,
|
||||
&vb,
|
||||
self.config.rank,
|
||||
self.config.world_size,
|
||||
comm,
|
||||
) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "load_failed".into(),
|
||||
message: format!("TpQwen3ForCausalLM::load: {e:#}"),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
self.models.insert(model_id.clone(), model);
|
||||
tracing::info!(rank = self.config.rank, model = %model_id, "loaded TP shard");
|
||||
WorkerResponse::LoadDenseShardOk
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn handle_load_dense_shard(
|
||||
&mut self,
|
||||
_model_id: String,
|
||||
_config_json: String,
|
||||
_safetensors_paths: Vec<String>,
|
||||
) -> WorkerResponse {
|
||||
WorkerResponse::Error {
|
||||
kind: "cuda_feature_not_enabled".into(),
|
||||
message: "LoadDenseShard requires --features cuda".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn handle_generate_step(
|
||||
&mut self,
|
||||
model_id: &str,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
) -> WorkerResponse {
|
||||
use candle_core::Tensor;
|
||||
|
||||
let Some(model) = self.models.get_mut(model_id) else {
|
||||
return WorkerResponse::Error {
|
||||
kind: "model_not_loaded".into(),
|
||||
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
|
||||
};
|
||||
};
|
||||
let device = model.device().clone();
|
||||
let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "forward_failed".into(),
|
||||
message: format!("build input tensor: {e}"),
|
||||
};
|
||||
}
|
||||
};
|
||||
// Drop the resulting logits — the leader uses its own copy from
|
||||
// rank 0. The forward's value here is the NCCL collectives it
|
||||
// issues, which let the leader's rank-0 forward make progress.
|
||||
if let Err(e) = model.forward(&input, offset) {
|
||||
return WorkerResponse::Error {
|
||||
kind: "forward_failed".into(),
|
||||
message: format!("TpQwen3ForCausalLM::forward: {e}"),
|
||||
};
|
||||
}
|
||||
WorkerResponse::GenerateStepOk
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn handle_generate_step(
|
||||
&mut self,
|
||||
_model_id: &str,
|
||||
_tokens: Vec<u32>,
|
||||
_offset: usize,
|
||||
) -> WorkerResponse {
|
||||
WorkerResponse::Error {
|
||||
kind: "cuda_feature_not_enabled".into(),
|
||||
message: "GenerateStep requires --features cuda".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse {
|
||||
let Some(model) = self.models.get_mut(model_id) else {
|
||||
return WorkerResponse::Error {
|
||||
kind: "model_not_loaded".into(),
|
||||
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
|
||||
};
|
||||
};
|
||||
model.clear_kv_cache();
|
||||
WorkerResponse::KvCacheCleared
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn handle_clear_kv_cache(&mut self, _model_id: &str) -> WorkerResponse {
|
||||
WorkerResponse::Error {
|
||||
kind: "cuda_feature_not_enabled".into(),
|
||||
message: "ClearKvCache requires --features cuda".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn handle_unload_model(&mut self, model_id: &str) -> WorkerResponse {
|
||||
if self.models.remove(model_id).is_none() {
|
||||
return WorkerResponse::Error {
|
||||
kind: "model_not_loaded".into(),
|
||||
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
|
||||
};
|
||||
}
|
||||
tracing::info!(rank = self.config.rank, model = %model_id, "unloaded TP shard");
|
||||
WorkerResponse::Unloaded
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn handle_unload_model(&mut self, _model_id: &str) -> WorkerResponse {
|
||||
WorkerResponse::Error {
|
||||
kind: "cuda_feature_not_enabled".into(),
|
||||
message: "UnloadModel requires --features cuda".into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,14 +9,15 @@
|
||||
# after pushing new neuron builds.
|
||||
#
|
||||
# Usage:
|
||||
# script/validate-neuron.sh [host] [model_id] [quant]
|
||||
# script/validate-neuron.sh [host] [model_id] [quant] [tp_size]
|
||||
#
|
||||
# Defaults:
|
||||
# host = beast.hanzalova.internal
|
||||
# model_id = unsloth/Qwen3-0.6B-GGUF (official Qwen3-*-GGUF repos
|
||||
# ship Q8_0 only; unsloth's mirror ships the full Q-spectrum
|
||||
# including Q4_K_M)
|
||||
# quant = Q4_K_M
|
||||
# quant = Q4_K_M (empty = dense safetensors path)
|
||||
# tp_size = unset (= 1 = single-GPU; pass 2 to drive the TP path)
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
@@ -25,6 +26,11 @@ MODEL_ID="${2:-unsloth/Qwen3-0.6B-GGUF}"
|
||||
# `${3-Q4_K_M}` (no colon) only uses the default when the arg is
|
||||
# UNSET — passing an explicit empty string drives the dense path.
|
||||
QUANT="${3-Q4_K_M}"
|
||||
# tp_size > 1 forces the dense path (TP requires safetensors) and adds
|
||||
# `tensor_parallel: N` to the load payload. The harness picks device
|
||||
# indices 0..N-1 by default; override by passing NEURON_DEVICES="0,1,..."
|
||||
# in the environment.
|
||||
TP_SIZE="${4-1}"
|
||||
PORT="${NEURON_PORT:-13131}"
|
||||
BASE="http://${HOST}:${PORT}"
|
||||
|
||||
@@ -69,21 +75,43 @@ is_loaded() {
|
||||
}
|
||||
|
||||
trigger_load() {
|
||||
say "POST /models/load ${MODEL_ID} (quant=${QUANT:-<dense>}, device=[0])"
|
||||
# Build the per-rank CUDA device list as a JSON array. Either
|
||||
# honour NEURON_DEVICES (`0,1,2`) verbatim or default to
|
||||
# `[0, 1, ..., tp_size - 1]`.
|
||||
local devices_json
|
||||
if [[ -n "${NEURON_DEVICES:-}" ]]; then
|
||||
devices_json=$(jq -n -c --arg s "${NEURON_DEVICES}" \
|
||||
'$s | split(",") | map(tonumber)')
|
||||
else
|
||||
devices_json=$(jq -n -c --argjson n "${TP_SIZE}" '[range(0; $n)]')
|
||||
fi
|
||||
say "POST /models/load ${MODEL_ID} (quant=${QUANT:-<dense>}, tp=${TP_SIZE}, devices=${devices_json})"
|
||||
say " (synchronous; may take a minute on first run while HF downloads)"
|
||||
# Build the payload via jq so the optional `quant` field is
|
||||
# omitted entirely when empty — that's the signal to the harness
|
||||
# to take the dense safetensors load path rather than GGUF.
|
||||
if (( TP_SIZE > 1 )) && [[ -n "${QUANT}" ]]; then
|
||||
die "tp_size>1 requires dense safetensors — pass quant='' as the 3rd argument"
|
||||
fi
|
||||
# Build the payload via jq so the optional `quant` and
|
||||
# `tensor_parallel` fields are omitted entirely when not in use —
|
||||
# that's how the harness tells dense from quantized and single-GPU
|
||||
# from TP.
|
||||
local payload
|
||||
if [[ -z "${QUANT}" ]]; then
|
||||
if [[ -z "${QUANT}" ]] && (( TP_SIZE > 1 )); then
|
||||
payload=$(jq -n -c \
|
||||
--arg id "${MODEL_ID}" \
|
||||
'{model_id: $id, harness: "candle", devices: [0]}')
|
||||
--argjson tp "${TP_SIZE}" \
|
||||
--argjson devices "${devices_json}" \
|
||||
'{model_id: $id, harness: "candle", tensor_parallel: $tp, devices: $devices}')
|
||||
elif [[ -z "${QUANT}" ]]; then
|
||||
payload=$(jq -n -c \
|
||||
--arg id "${MODEL_ID}" \
|
||||
--argjson devices "${devices_json}" \
|
||||
'{model_id: $id, harness: "candle", devices: $devices}')
|
||||
else
|
||||
payload=$(jq -n -c \
|
||||
--arg id "${MODEL_ID}" \
|
||||
--arg q "${QUANT}" \
|
||||
'{model_id: $id, harness: "candle", quant: $q, devices: [0]}')
|
||||
--argjson devices "${devices_json}" \
|
||||
'{model_id: $id, harness: "candle", quant: $q, devices: $devices}')
|
||||
fi
|
||||
# --write-out captures the response code on a separate line so we
|
||||
# can surface a real diagnostic instead of relying on --fail.
|
||||
|
||||
Reference in New Issue
Block a user