feat(tp): Stage 7b-iv — RPC + orchestration for TP load/inference
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 38s
CI / Format (push) Successful in 40s
CI / Clippy (push) Successful in 2m20s
build-prerelease / Build cortex binary (push) Successful in 4m25s
build-prerelease / Package cortex RPM (push) Successful in 1m22s
CI / Test (push) Successful in 4m34s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m57s
build-prerelease / Build neuron-ampere (push) Successful in 4m51s
build-prerelease / Build neuron-ada (push) Successful in 5m12s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m49s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m51s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m0s

Wires the in-flight TP machinery (Stage 7a workers, 7b-iii sharded
Qwen3) end to end so a non-streaming chat completion can run across
multiple GPUs via NCCL.

RPC additions (tp/rpc.rs):
- LoadDenseShard{model_id, config_json, safetensors_paths}
- GenerateStep{model_id, tokens, offset}
- ClearKvCache{model_id}
- UnloadModel{model_id}
- LoadDenseShardOk / GenerateStepOk / KvCacheCleared / Unloaded

Worker side (tp/worker.rs):
- WorkerState gains a `models: HashMap<String, TpQwen3ForCausalLM>`
  keyed by model_id. LoadDenseShard mmaps safetensors via
  ShardedVarBuilder (only this rank's slice materialises), builds the
  TP model with the rank's NCCL Comm cloned from NcclState.
- GenerateStep runs the rank-local forward; the resulting logits are
  dropped (only the leader's are used for sampling). The forward's
  value here is the NCCL collectives inside the row-parallel layers
  letting the leader's rank-0 forward make progress.

Pool side (tp/mod.rs):
- WorkerPool::load_dense_shard fans LoadDenseShard out to every worker,
  builds rank 0's shard on the leader via spawn_blocking with a fresh
  SendComm wrapper at the move boundary (Comm is !Send at the type
  level), collects per-rank LoadDenseShardOk. Returns the leader's
  Arc<Mutex<TpQwen3ForCausalLM>>.
- WorkerPool::generate_step fans GenerateStep out, runs the leader's
  rank-0 forward in spawn_blocking (the AllReduce CustomOps inside
  row-parallel layers block until every worker issues the matching
  collective), returns the leader's last-position logits Tensor.
- WorkerPool::clear_kv_cache + unload_model follow the same pattern.

NcclState refactor (tp/nccl_state.rs):
- comm field becomes Option<Arc<Comm>> (was Option<Comm>) so callers
  can share a clone with TpQwen3ForCausalLM::load.
- new `comm()` accessor + `SendComm` wrapper for spawn_blocking moves.
- single allow(clippy::arc_with_non_send_sync) at the canonical
  construction site (Comm is !Send by type but the runtime invariant
  is enforced by SendComm + the pool's Mutex).

Harness side (candle.rs):
- LoadedHandle enum (Single | Tp) replaces the bare Arc<LoadedModel>
  in the harness's registry. list_models / unload_model /
  inference_endpoint walk the enum uniformly.
- TpLoadedModel holds the pool + leader_model + tokenizer + devices.
- load_model dispatches on `spec.tensor_parallel > 1` to a new
  cuda-gated load_tp path: resolve dense files via hf-hub, spawn the
  pool, init_nccl, load_dense_shard.
- chat_completion branches on the handle variant. The TP path mirrors
  run_inference: clear_kv_cache, prefill, sample, decode loop,
  detokenize. Acquires the pool Mutex for the whole request.
- Streaming through TP is deferred to Stage 7c (returns Other(err)).

Script (script/validate-neuron.sh):
- 4th positional arg `tp_size` (default 1). When >1, switches to the
  dense path (tp + GGUF is mutually exclusive — bails) and adds
  `tensor_parallel` + `devices` to the load payload. NEURON_DEVICES
  env overrides the default 0..N-1 device list.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-20 06:38:33 +03:00
parent 9b8bd146f6
commit d46d8d4f6c
6 changed files with 960 additions and 40 deletions

View File

@@ -31,11 +31,44 @@ use tokio::sync::{Mutex, RwLock, mpsc};
/// In-process candle harness. Owns the loaded model registry.
pub struct CandleHarness {
models: Arc<RwLock<HashMap<String, Arc<LoadedModel>>>>,
models: Arc<RwLock<HashMap<String, LoadedHandle>>>,
hf_cache: Option<PathBuf>,
bind_url: String,
}
/// One entry in the harness's loaded-model registry. Single-GPU loads
/// land in `Single`; loads with `tensor_parallel > 1` land in `Tp`.
/// The two variants share the same `model_id` key in the map, so
/// `list_models`, `unload_model`, and `inference_endpoint` can walk
/// them uniformly without branching the storage layout.
///
/// `Clone` is cheap: both variants hold `Arc<_>` and cloning just bumps
/// the refcount.
#[derive(Clone)]
pub enum LoadedHandle {
Single(Arc<LoadedModel>),
#[cfg(feature = "cuda")]
Tp(Arc<TpLoadedModel>),
}
impl LoadedHandle {
pub fn model_id(&self) -> &str {
match self {
LoadedHandle::Single(m) => &m.model_id,
#[cfg(feature = "cuda")]
LoadedHandle::Tp(m) => &m.model_id,
}
}
pub fn devices(&self) -> Vec<u32> {
match self {
LoadedHandle::Single(m) => m.devices.clone(),
#[cfg(feature = "cuda")]
LoadedHandle::Tp(m) => m.devices.clone(),
}
}
}
/// A loaded model with its tokenizer, device placement, and architecture-
/// specific weights. The `arch` is `Arc<Mutex<>>` so the lock guard can be
/// moved into `spawn_blocking` for synchronous candle forward passes.
@@ -48,6 +81,25 @@ pub struct LoadedModel {
pub devices: Vec<u32>,
}
/// Tensor-parallel loaded model. Holds the leader's rank-0 shard
/// (which the inference loop drives via spawn_blocking) and the
/// `WorkerPool` (which drives every non-zero rank over the RPC
/// channel). Both are behind tokio Mutexes so concurrent inference
/// requests against the same model are serialised; concurrent loads
/// for *different* models would each have their own pool.
#[cfg(feature = "cuda")]
pub struct TpLoadedModel {
pub model_id: String,
pub tokenizer: Tokenizer,
pub devices: Vec<u32>,
/// One end-to-end gate: the pool's RPC stream isn't safe to use
/// concurrently and the leader shard's KV cache mutates with every
/// step. The same Mutex covers both for the simplest correctness
/// story.
pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
pub leader_model: Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>,
}
/// Architecture-specific weights.
///
/// - `Qwen3Quantized` — GGUF source, pre-quantized. Single-GPU only;
@@ -357,11 +409,22 @@ impl CandleHarness {
&self,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, InferenceError> {
let loaded = {
let handle = {
let models = self.models.read().await;
models.get(&request.model).cloned()
};
let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
// The match is technically infallible without `cuda` (only Single
// exists), but the cfg-gated Tp arm makes this the right shape
// under both feature flags.
#[allow(clippy::infallible_destructuring_match)]
let loaded = match handle {
LoadedHandle::Single(m) => m,
#[cfg(feature = "cuda")]
LoadedHandle::Tp(m) => {
return self.chat_completion_tp(m, request).await;
}
};
let prompt = format_qwen3_prompt(&request.messages);
@@ -451,11 +514,29 @@ impl CandleHarness {
&self,
request: ChatCompletionRequest,
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
let loaded = {
let handle = {
let models = self.models.read().await;
models.get(&request.model).cloned()
};
let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
// The match is technically infallible without `cuda` (only Single
// exists), but the cfg-gated Tp arm makes this the right shape
// under both feature flags.
#[allow(clippy::infallible_destructuring_match)]
let loaded = match handle {
LoadedHandle::Single(m) => m,
#[cfg(feature = "cuda")]
LoadedHandle::Tp(_) => {
// Streaming through TP is Stage 7c work — the
// non-streaming path drives the same forwards through
// the pool but doesn't have to interleave SSE writes
// with spawn_blocking forwards.
return Err(InferenceError::Other(anyhow::anyhow!(
"streaming chat completions through TP are not yet supported; \
retry with stream=false"
)));
}
};
let prompt = format_qwen3_prompt(&request.messages);
let encoding = loaded
@@ -552,11 +633,11 @@ impl Harness for CandleHarness {
let models = self.models.read().await;
Ok(models
.values()
.map(|m| ModelInfo {
id: m.model_id.clone(),
.map(|h| ModelInfo {
id: h.model_id().into(),
harness: "candle".into(),
status: "loaded".into(),
devices: m.devices.clone(),
devices: h.devices(),
vram_used_mb: None,
})
.collect())
@@ -574,19 +655,20 @@ impl Harness for CandleHarness {
}
}
// Stage 7a-i scaffolds tensor-parallel worker subprocesses but
// does not yet route inference through them. Refuse TP loads
// for now with a clear marker so the request surface is honest;
// Stage 7b-iv replaces this bail with the TP dispatch.
let tp_size = spec.tensor_parallel.unwrap_or(1);
if tp_size > 1 {
anyhow::bail!(
"tensor_parallel={tp_size} requested for '{}': TP worker \
lifecycle + NCCL handshake are in place (Stage 7a) but \
TP-aware Qwen3 inference orchestration lands in Stage \
7b-iv; single-GPU loads only for now",
spec.model_id
);
#[cfg(feature = "cuda")]
{
return self.load_tp(spec, tp_size).await;
}
#[cfg(not(feature = "cuda"))]
{
anyhow::bail!(
"tensor_parallel={tp_size} requested for '{}': this neuron \
binary was built without --features cuda; TP requires CUDA + NCCL",
spec.model_id
);
}
}
let devices = spec.devices.clone().unwrap_or_else(|| vec![0]);
@@ -615,15 +697,52 @@ impl Harness for CandleHarness {
});
let mut models = self.models.write().await;
models.insert(spec.model_id.clone(), loaded);
models.insert(spec.model_id.clone(), LoadedHandle::Single(loaded));
tracing::info!(model = %spec.model_id, "model loaded");
Ok(())
}
async fn unload_model(&self, model_id: &str) -> Result<()> {
let mut models = self.models.write().await;
if models.remove(model_id).is_none() {
let removed = {
let mut models = self.models.write().await;
models.remove(model_id)
};
let Some(handle) = removed else {
anyhow::bail!("model '{model_id}' not loaded");
};
// Single-GPU drops are immediate — the LoadedModel goes out of
// scope with the Arc and candle frees VRAM. TP unloads also
// need to tell every worker to drop its shard before the pool
// itself is dropped (otherwise the workers keep their shards
// around until Shutdown, which is wasteful and would surface
// as VRAM not freed promptly).
match handle {
LoadedHandle::Single(_) => {}
#[cfg(feature = "cuda")]
LoadedHandle::Tp(tp) => {
// Try to recover the inner TpLoadedModel so we can move
// the pool and shut it down. If anyone else still holds
// a clone of the Arc (shouldn't happen — the only owners
// are the registry and any in-flight chat_completion),
// bail with a clear marker rather than silently leaking.
let tp = match Arc::try_unwrap(tp) {
Ok(t) => t,
Err(arc) => {
// Reinsert so we don't leave the registry in an
// inconsistent state.
let mut models = self.models.write().await;
models.insert(model_id.into(), LoadedHandle::Tp(arc));
anyhow::bail!("cannot unload '{model_id}': inference still in flight");
}
};
let mut pool = tp.pool.into_inner();
if let Err(e) = pool.unload_model(model_id).await {
tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed");
}
if let Err(e) = pool.shutdown().await {
tracing::warn!(model = %model_id, error = %e, "TP pool shutdown failed");
}
}
}
tracing::info!(model = %model_id, "model unloaded");
Ok(())
@@ -635,6 +754,215 @@ impl Harness for CandleHarness {
}
}
impl CandleHarness {
/// Tensor-parallel load. Resolves dense safetensors via hf-hub the
/// same way the single-GPU dense path does, spins up a TP worker
/// pool sized to `tp_size`, runs the NCCL handshake, then has
/// every rank load its shard of the model.
///
/// `spec.devices` carries the per-rank CUDA device indices (one
/// entry per rank, in rank order); defaults to `0..tp_size`.
#[cfg(feature = "cuda")]
async fn load_tp(&self, spec: &ModelSpec, tp_size: u32) -> Result<()> {
use std::sync::Arc as StdArc;
use tokio::sync::Mutex as TMutex;
// Default per-rank device assignment: 0, 1, ..., tp_size - 1.
let devices = spec
.devices
.clone()
.unwrap_or_else(|| (0..tp_size).collect());
if devices.len() as u32 != tp_size {
anyhow::bail!(
"tensor_parallel={tp_size} requires {tp_size} entries in devices, got {}",
devices.len()
);
}
if spec.quant.is_some() {
anyhow::bail!(
"tensor_parallel={tp_size} with quant={:?}: GGUF quantized models \
are not supported in the TP path; use a dense safetensors source",
spec.quant
);
}
// 1. Resolve config + tokenizer + safetensors via hf-hub.
let (config_path, tokenizer_path, safetensors_paths) =
self.resolve_dense_files(spec).await?;
let config_json = std::fs::read_to_string(&config_path).context("read config.json")?;
// 2. Spawn the worker pool. Rank 0 stays in-process; ranks
// 1..tp_size are subprocesses, one per device after the
// leader's own.
let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
let mut pool = super::tp::WorkerPool::spawn(&exe, tp_size, &devices).await?;
// 3. NCCL handshake across all ranks.
let leader_device_idx = devices[0];
pool.init_nccl(leader_device_idx).await?;
// 4. Pick the leader's candle Device (same index as init_nccl).
let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize)
.context("Device::new_cuda for TP leader")?;
// 5. Load this rank's shard on every rank.
let leader_model = pool
.load_dense_shard(
&spec.model_id,
&config_json,
&safetensors_paths,
&leader_device,
candle_core::DType::BF16,
)
.await?;
// 6. Tokenizer (same as single-GPU path).
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
let tp_loaded = StdArc::new(TpLoadedModel {
model_id: spec.model_id.clone(),
tokenizer,
devices: devices.clone(),
pool: TMutex::new(pool),
leader_model,
});
let mut models = self.models.write().await;
models.insert(spec.model_id.clone(), LoadedHandle::Tp(tp_loaded));
tracing::info!(
model = %spec.model_id,
tp_size,
?devices,
"TP model loaded"
);
Ok(())
}
/// Non-streaming chat completion against a TP model. Pattern mirrors
/// the single-GPU `run_inference`: tokenize, prefill, sample, decode
/// loop, detokenize. Each forward step fans out to every rank via
/// the WorkerPool and uses the leader's last-position logits to
/// sample.
#[cfg(feature = "cuda")]
async fn chat_completion_tp(
&self,
tp: Arc<TpLoadedModel>,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, InferenceError> {
let prompt = format_qwen3_prompt(&request.messages);
let encoding = tp
.tokenizer
.encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
let prompt_len = prompt_tokens.len();
let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(512) as usize;
let seed = unix_subsec_nanos();
let eos_id = tp
.tokenizer
.token_to_id("<|im_end|>")
.or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
let model_id = request.model.clone();
// Acquire the pool lock for the duration of the request. The
// leader_model's own Mutex is acquired step-by-step inside
// pool.generate_step (so spawn_blocking can grab it without
// holding the pool lock across the blocking_lock call).
let mut pool = tp.pool.lock().await;
let leader_arc = tp.leader_model.clone();
// Reset every rank's KV cache so this request doesn't attend
// over the previous request's tokens.
pool.clear_kv_cache(&model_id, leader_arc.clone())
.await
.map_err(InferenceError::Other)?;
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut generated: Vec<u32> = Vec::new();
let mut finish_reason = "length".to_string();
// Prefill: every rank embeds the whole prompt, offset = 0.
let logits = pool
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
.await
.map_err(InferenceError::Other)?;
let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
.map_err(InferenceError::Other)?;
if Some(next_token) == eos_id {
finish_reason = "stop".into();
} else {
generated.push(next_token);
for index in 0..max_new.saturating_sub(1) {
let logits = pool
.generate_step(
&model_id,
leader_arc.clone(),
vec![next_token],
prompt_len + index,
)
.await
.map_err(InferenceError::Other)?;
next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
.map_err(InferenceError::Other)?;
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
}
generated.push(next_token);
}
}
drop(pool);
let completion_text = tp
.tokenizer
.decode(&generated, true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
let usage = Usage {
prompt_tokens: prompt_len as u64,
completion_tokens: generated.len() as u64,
total_tokens: (prompt_len + generated.len()) as u64,
};
Ok(ChatCompletionResponse {
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
object: "chat.completion".into(),
created: unix_now_secs(),
model: model_id,
choices: vec![ChatCompletionChoice {
index: 0,
message: ChatMessage {
role: "assistant".into(),
content: MessageContent::Text(completion_text),
extra: serde_json::Value::Object(Default::default()),
},
finish_reason: Some(finish_reason),
extra: serde_json::Value::Object(Default::default()),
}],
usage: Some(usage),
extra: serde_json::Value::Object(Default::default()),
})
}
}
/// Errors returned by `CandleHarness::chat_completion`. The
/// `ModelNotLoaded` variant lets the HTTP handler map cleanly to 404
/// without string-matching on anyhow messages.