feat(tp): Stage 7b-iv — RPC + orchestration for TP load/inference
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 38s
CI / Format (push) Successful in 40s
CI / Clippy (push) Successful in 2m20s
build-prerelease / Build cortex binary (push) Successful in 4m25s
build-prerelease / Package cortex RPM (push) Successful in 1m22s
CI / Test (push) Successful in 4m34s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m57s
build-prerelease / Build neuron-ampere (push) Successful in 4m51s
build-prerelease / Build neuron-ada (push) Successful in 5m12s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m49s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m51s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m0s
Wires the in-flight TP machinery (Stage 7a workers, 7b-iii sharded
Qwen3) end to end so a non-streaming chat completion can run across
multiple GPUs via NCCL.
RPC additions (tp/rpc.rs):
- LoadDenseShard{model_id, config_json, safetensors_paths}
- GenerateStep{model_id, tokens, offset}
- ClearKvCache{model_id}
- UnloadModel{model_id}
- LoadDenseShardOk / GenerateStepOk / KvCacheCleared / Unloaded
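The message set above could be modelled as a pair of enums. This is a minimal sketch, not the contents of tp/rpc.rs: the field types are inferred from the call sites in the diff, the `Error` variant's `kind` and `message` are assumed to be strings, serialization is left out, and `expected_reply` / `check_reply` are hypothetical helpers added for illustration.

```rust
/// Leader -> worker requests (sketch; field types are assumptions).
#[derive(Debug, Clone, PartialEq)]
pub enum WorkerRequest {
    LoadDenseShard {
        model_id: String,
        config_json: String,
        safetensors_paths: Vec<String>,
    },
    GenerateStep {
        model_id: String,
        tokens: Vec<u32>,
        offset: usize,
    },
    ClearKvCache { model_id: String },
    UnloadModel { model_id: String },
}

/// Worker -> leader replies (sketch).
#[derive(Debug, Clone, PartialEq)]
pub enum WorkerResponse {
    LoadDenseShardOk,
    GenerateStepOk,
    KvCacheCleared,
    Unloaded,
    Error { kind: String, message: String },
}

/// Hypothetical helper: the success reply the leader expects for a request.
pub fn expected_reply(req: &WorkerRequest) -> WorkerResponse {
    match req {
        WorkerRequest::LoadDenseShard { .. } => WorkerResponse::LoadDenseShardOk,
        WorkerRequest::GenerateStep { .. } => WorkerResponse::GenerateStepOk,
        WorkerRequest::ClearKvCache { .. } => WorkerResponse::KvCacheCleared,
        WorkerRequest::UnloadModel { .. } => WorkerResponse::Unloaded,
    }
}

/// Hypothetical helper: turn a worker reply into Ok or a descriptive error.
pub fn check_reply(rank: u32, req: &WorkerRequest, resp: WorkerResponse) -> Result<(), String> {
    let want = expected_reply(req);
    if resp == want {
        return Ok(());
    }
    match resp {
        WorkerResponse::Error { kind, message } => {
            Err(format!("worker rank {rank} [{kind}]: {message}"))
        }
        other => Err(format!("worker rank {rank}: expected {want:?}, got {other:?}")),
    }
}
```

Pairing each request with exactly one success reply keeps the leader's collection loop to a three-way match: expected reply, worker error, or protocol violation.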
Worker side (tp/worker.rs):
- WorkerState gains a `models: HashMap<String, TpQwen3ForCausalLM>`
keyed by model_id. LoadDenseShard mmaps safetensors via
ShardedVarBuilder (only this rank's slice materialises), builds the
TP model with the rank's NCCL Comm cloned from NcclState.
- GenerateStep runs the rank-local forward and drops the resulting
logits (only the leader's are used for sampling). The worker forward
matters because it issues the NCCL collectives inside the row-parallel
layers, which is what lets the leader's rank-0 forward make progress.
Pool side (tp/mod.rs):
- WorkerPool::load_dense_shard fans LoadDenseShard out to every worker,
builds rank 0's shard on the leader via spawn_blocking with a fresh
SendComm wrapper at the move boundary (Comm is !Send at the type
level), collects per-rank LoadDenseShardOk. Returns the leader's
Arc<Mutex<TpQwen3ForCausalLM>>.
- WorkerPool::generate_step fans GenerateStep out, runs the leader's
rank-0 forward in spawn_blocking (the AllReduce CustomOps inside
row-parallel layers block until every worker issues the matching
collective), returns the leader's last-position logits Tensor.
- WorkerPool::clear_kv_cache + unload_model follow the same pattern.
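The send-to-all, run-the-leader, then collect pattern described above can be sketched with plain threads and channels. This is an illustration under stated assumptions, not the pool's implementation: threads stand in for worker subprocesses, the `Worker` struct and `spawn_worker` are hypothetical, there is no NCCL, and the leader's "forward" is replaced by a trivial computation.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

/// Hypothetical stand-in for one worker's RPC channel pair.
pub struct Worker {
    rank: u32,
    tx: Sender<Vec<u32>>,
    rx: Receiver<Result<(), String>>,
}

/// Spawn a thread that acknowledges every step it receives.
pub fn spawn_worker(rank: u32) -> Worker {
    let (req_tx, req_rx) = channel::<Vec<u32>>();
    let (resp_tx, resp_rx) = channel::<Result<(), String>>();
    thread::spawn(move || {
        for tokens in req_rx {
            // A real worker would run its rank-local forward here.
            let reply = if tokens.is_empty() {
                Err("empty step".to_string())
            } else {
                Ok(())
            };
            if resp_tx.send(reply).is_err() {
                break;
            }
        }
    });
    Worker { rank, tx: req_tx, rx: resp_rx }
}

/// One step: fan out, do the leader's share, then collect confirmations.
pub fn generate_step(workers: &[Worker], tokens: &[u32]) -> Result<usize, String> {
    // 1. Send the request to every worker without waiting for replies.
    for w in workers {
        w.tx
            .send(tokens.to_vec())
            .map_err(|e| format!("rank {} send: {e}", w.rank))?;
    }
    // 2. Leader's own work runs while the workers are busy
    //    (stand-in for the rank-0 forward).
    let leader_result = tokens.len();
    // 3. Collect one reply per worker; any failure aborts the step.
    for w in workers {
        w.rx
            .recv()
            .map_err(|e| format!("rank {} recv: {e}", w.rank))?
            .map_err(|m| format!("rank {}: {m}", w.rank))?;
    }
    Ok(leader_result)
}
```

Sending to every worker before doing any local work is what allows all ranks to enter the collective together; if the leader ran its forward first, it would block waiting for peers that had not yet been told to start.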
NcclState refactor (tp/nccl_state.rs):
- comm field becomes Option<Arc<Comm>> (was Option<Comm>) so callers
can share a clone with TpQwen3ForCausalLM::load.
- new `comm()` accessor + `SendComm` wrapper for spawn_blocking moves.
- single allow(clippy::arc_with_non_send_sync) at the canonical
construction site (Comm is !Send by type but the runtime invariant
is enforced by SendComm + the pool's Mutex).
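A wrapper of the kind described above might look like the following. It is a sketch only: `FakeComm` is a hypothetical stand-in for the real NCCL `Comm` (made `!Send` with a raw-pointer marker), a plain thread replaces `spawn_blocking`, and the safety comment states the invariant the commit message describes rather than anything this snippet can enforce.

```rust
use std::marker::PhantomData;
use std::sync::Arc;
use std::thread;

/// Hypothetical stand-in for an NCCL communicator. The raw-pointer
/// marker makes the type !Send and !Sync, like a handle wrapping a C pointer.
pub struct FakeComm {
    rank: usize,
    _not_send: PhantomData<*mut ()>,
}

impl FakeComm {
    pub fn new(rank: usize) -> Self {
        FakeComm { rank, _not_send: PhantomData }
    }

    pub fn rank(&self) -> usize {
        self.rank
    }
}

/// Newtype that asserts the wrapped handle may cross a thread boundary.
pub struct SendComm(pub Arc<FakeComm>);

// SAFETY (assumed invariant, mirroring the commit message): callers are
// serialised by an outer lock, and the handle is only used on the thread
// it is moved to. The compiler cannot check this.
unsafe impl Send for SendComm {}

impl SendComm {
    /// Unwrap on the destination thread.
    pub fn into_inner(self) -> Arc<FakeComm> {
        self.0
    }
}

/// Move the handle into another thread and use it there.
pub fn rank_on_other_thread(comm: Arc<FakeComm>) -> usize {
    let wrapped = SendComm(comm);
    thread::spawn(move || {
        // Calling a by-value method captures the whole wrapper, so the
        // closure is Send even though the inner Arc is not.
        let comm = wrapped.into_inner();
        comm.rank()
    })
    .join()
    .expect("worker thread panicked")
}
```

Unwrapping through `into_inner()` rather than reading `wrapped.0` inside the closure matters under Rust 2021's disjoint closure captures: a direct field access would capture only the inner `Arc`, which is not `Send`, and the spawn would fail to compile.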
Harness side (candle.rs):
- LoadedHandle enum (Single | Tp) replaces the bare Arc<LoadedModel>
in the harness's registry. list_models / unload_model /
inference_endpoint walk the enum uniformly.
- TpLoadedModel holds the pool + leader_model + tokenizer + devices.
- load_model dispatches on `spec.tensor_parallel > 1` to a new
cuda-gated load_tp path: resolve dense files via hf-hub, spawn the
pool, init_nccl, load_dense_shard.
- chat_completion branches on the handle variant. The TP path mirrors
run_inference: clear_kv_cache, prefill, sample, decode loop,
detokenize. Acquires the pool Mutex for the whole request.
- Streaming through TP is deferred to Stage 7c (returns Other(err)).
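The prefill/decode bookkeeping in the TP chat path can be sketched independently of the model. In this sketch the `step` closure is a hypothetical stand-in for `generate_step` plus sampling: it receives the input tokens and the KV-cache offset and returns the next token. The offsets follow the diff: the prompt goes in at offset 0, and decode step `index` feeds one token at `prompt_len + index`.

```rust
/// Run prefill plus a decode loop, returning the generated tokens and
/// a finish reason ("stop" on EOS, otherwise "length").
pub fn generate<F>(
    prompt: &[u32],
    max_new: usize,
    eos: Option<u32>,
    mut step: F,
) -> (Vec<u32>, &'static str)
where
    F: FnMut(&[u32], usize) -> u32,
{
    let mut generated = Vec::new();

    // Prefill: the whole prompt at offset 0. Note that, as in the diff,
    // the first sampled token is kept even when max_new is 0.
    let mut next = step(prompt, 0);
    if Some(next) == eos {
        return (generated, "stop");
    }
    generated.push(next);

    // Decode: one token per step, offset advancing past the prompt.
    for index in 0..max_new.saturating_sub(1) {
        next = step(&[next], prompt.len() + index);
        if Some(next) == eos {
            return (generated, "stop");
        }
        generated.push(next);
    }
    (generated, "length")
}
```

With a three-token prompt and `max_new = 3`, the closure is called with `(3 tokens, offset 0)`, then `(1 token, offset 3)` and `(1 token, offset 4)`, which matches the cache positions each step writes to.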
Script (script/validate-neuron.sh):
- 4th positional arg `tp_size` (default 1). When >1, switches to the
dense path (TP and GGUF are mutually exclusive, so the script bails if
both are requested) and adds `tensor_parallel` + `devices` to the load
payload. The NEURON_DEVICES env var overrides the default 0..N-1
device list.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -31,11 +31,44 @@ use tokio::sync::{Mutex, RwLock, mpsc};
 
 /// In-process candle harness. Owns the loaded model registry.
 pub struct CandleHarness {
-    models: Arc<RwLock<HashMap<String, Arc<LoadedModel>>>>,
+    models: Arc<RwLock<HashMap<String, LoadedHandle>>>,
     hf_cache: Option<PathBuf>,
     bind_url: String,
 }
 
+/// One entry in the harness's loaded-model registry. Single-GPU loads
+/// land in `Single`; loads with `tensor_parallel > 1` land in `Tp`.
+/// The two variants share the same `model_id` key in the map, so
+/// `list_models`, `unload_model`, and `inference_endpoint` can walk
+/// them uniformly without branching the storage layout.
+///
+/// `Clone` is cheap: both variants hold `Arc<_>` and cloning just bumps
+/// the refcount.
+#[derive(Clone)]
+pub enum LoadedHandle {
+    Single(Arc<LoadedModel>),
+    #[cfg(feature = "cuda")]
+    Tp(Arc<TpLoadedModel>),
+}
+
+impl LoadedHandle {
+    pub fn model_id(&self) -> &str {
+        match self {
+            LoadedHandle::Single(m) => &m.model_id,
+            #[cfg(feature = "cuda")]
+            LoadedHandle::Tp(m) => &m.model_id,
+        }
+    }
+
+    pub fn devices(&self) -> Vec<u32> {
+        match self {
+            LoadedHandle::Single(m) => m.devices.clone(),
+            #[cfg(feature = "cuda")]
+            LoadedHandle::Tp(m) => m.devices.clone(),
+        }
+    }
+}
+
 /// A loaded model with its tokenizer, device placement, and architecture-
 /// specific weights. The `arch` is `Arc<Mutex<>>` so the lock guard can be
 /// moved into `spawn_blocking` for synchronous candle forward passes.
@@ -48,6 +81,25 @@ pub struct LoadedModel {
     pub devices: Vec<u32>,
 }
 
+/// Tensor-parallel loaded model. Holds the leader's rank-0 shard
+/// (which the inference loop drives via spawn_blocking) and the
+/// `WorkerPool` (which drives every non-zero rank over the RPC
+/// channel). Both are behind tokio Mutexes so concurrent inference
+/// requests against the same model are serialised; concurrent loads
+/// for *different* models would each have their own pool.
+#[cfg(feature = "cuda")]
+pub struct TpLoadedModel {
+    pub model_id: String,
+    pub tokenizer: Tokenizer,
+    pub devices: Vec<u32>,
+    /// One end-to-end gate: the pool's RPC stream isn't safe to use
+    /// concurrently and the leader shard's KV cache mutates with every
+    /// step. The same Mutex covers both for the simplest correctness
+    /// story.
+    pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
+    pub leader_model: Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>,
+}
+
 /// Architecture-specific weights.
 ///
 /// - `Qwen3Quantized` — GGUF source, pre-quantized. Single-GPU only;
@@ -357,11 +409,22 @@ impl CandleHarness {
         &self,
         request: ChatCompletionRequest,
     ) -> Result<ChatCompletionResponse, InferenceError> {
-        let loaded = {
+        let handle = {
             let models = self.models.read().await;
             models.get(&request.model).cloned()
         };
-        let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
+        let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
+        // The match is technically infallible without `cuda` (only Single
+        // exists), but the cfg-gated Tp arm makes this the right shape
+        // under both feature flags.
+        #[allow(clippy::infallible_destructuring_match)]
+        let loaded = match handle {
+            LoadedHandle::Single(m) => m,
+            #[cfg(feature = "cuda")]
+            LoadedHandle::Tp(m) => {
+                return self.chat_completion_tp(m, request).await;
+            }
+        };
 
         let prompt = format_qwen3_prompt(&request.messages);
 
@@ -451,11 +514,29 @@ impl CandleHarness {
         &self,
         request: ChatCompletionRequest,
     ) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
-        let loaded = {
+        let handle = {
             let models = self.models.read().await;
             models.get(&request.model).cloned()
         };
-        let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
+        let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
+        // The match is technically infallible without `cuda` (only Single
+        // exists), but the cfg-gated Tp arm makes this the right shape
+        // under both feature flags.
+        #[allow(clippy::infallible_destructuring_match)]
+        let loaded = match handle {
+            LoadedHandle::Single(m) => m,
+            #[cfg(feature = "cuda")]
+            LoadedHandle::Tp(_) => {
+                // Streaming through TP is Stage 7c work — the
+                // non-streaming path drives the same forwards through
+                // the pool but doesn't have to interleave SSE writes
+                // with spawn_blocking forwards.
+                return Err(InferenceError::Other(anyhow::anyhow!(
+                    "streaming chat completions through TP are not yet supported; \
+                     retry with stream=false"
+                )));
+            }
+        };
 
         let prompt = format_qwen3_prompt(&request.messages);
         let encoding = loaded
@@ -552,11 +633,11 @@ impl Harness for CandleHarness {
         let models = self.models.read().await;
         Ok(models
             .values()
-            .map(|m| ModelInfo {
-                id: m.model_id.clone(),
+            .map(|h| ModelInfo {
+                id: h.model_id().into(),
                 harness: "candle".into(),
                 status: "loaded".into(),
-                devices: m.devices.clone(),
+                devices: h.devices(),
                 vram_used_mb: None,
             })
             .collect())
@@ -574,20 +655,21 @@ impl Harness for CandleHarness {
             }
         }
 
-        // Stage 7a-i scaffolds tensor-parallel worker subprocesses but
-        // does not yet route inference through them. Refuse TP loads
-        // for now with a clear marker so the request surface is honest;
-        // Stage 7b-iv replaces this bail with the TP dispatch.
         let tp_size = spec.tensor_parallel.unwrap_or(1);
         if tp_size > 1 {
+            #[cfg(feature = "cuda")]
+            {
+                return self.load_tp(spec, tp_size).await;
+            }
+            #[cfg(not(feature = "cuda"))]
+            {
                 anyhow::bail!(
-                    "tensor_parallel={tp_size} requested for '{}': TP worker \
-                     lifecycle + NCCL handshake are in place (Stage 7a) but \
-                     TP-aware Qwen3 inference orchestration lands in Stage \
-                     7b-iv; single-GPU loads only for now",
+                    "tensor_parallel={tp_size} requested for '{}': this neuron \
+                     binary was built without --features cuda; TP requires CUDA + NCCL",
                     spec.model_id
                 );
             }
+        }
 
         let devices = spec.devices.clone().unwrap_or_else(|| vec![0]);
         let device = Self::pick_device(&devices)?;
@@ -615,15 +697,52 @@ impl Harness for CandleHarness {
         });
 
         let mut models = self.models.write().await;
-        models.insert(spec.model_id.clone(), loaded);
+        models.insert(spec.model_id.clone(), LoadedHandle::Single(loaded));
         tracing::info!(model = %spec.model_id, "model loaded");
         Ok(())
     }
 
     async fn unload_model(&self, model_id: &str) -> Result<()> {
+        let removed = {
             let mut models = self.models.write().await;
-        if models.remove(model_id).is_none() {
+            models.remove(model_id)
+        };
+        let Some(handle) = removed else {
             anyhow::bail!("model '{model_id}' not loaded");
+        };
+        // Single-GPU drops are immediate — the LoadedModel goes out of
+        // scope with the Arc and candle frees VRAM. TP unloads also
+        // need to tell every worker to drop its shard before the pool
+        // itself is dropped (otherwise the workers keep their shards
+        // around until Shutdown, which is wasteful and would surface
+        // as VRAM not freed promptly).
+        match handle {
+            LoadedHandle::Single(_) => {}
+            #[cfg(feature = "cuda")]
+            LoadedHandle::Tp(tp) => {
+                // Try to recover the inner TpLoadedModel so we can move
+                // the pool and shut it down. If anyone else still holds
+                // a clone of the Arc (shouldn't happen — the only owners
+                // are the registry and any in-flight chat_completion),
+                // bail with a clear marker rather than silently leaking.
+                let tp = match Arc::try_unwrap(tp) {
+                    Ok(t) => t,
+                    Err(arc) => {
+                        // Reinsert so we don't leave the registry in an
+                        // inconsistent state.
+                        let mut models = self.models.write().await;
+                        models.insert(model_id.into(), LoadedHandle::Tp(arc));
+                        anyhow::bail!("cannot unload '{model_id}': inference still in flight");
+                    }
+                };
+                let mut pool = tp.pool.into_inner();
+                if let Err(e) = pool.unload_model(model_id).await {
+                    tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed");
+                }
+                if let Err(e) = pool.shutdown().await {
+                    tracing::warn!(model = %model_id, error = %e, "TP pool shutdown failed");
+                }
+            }
         }
         tracing::info!(model = %model_id, "model unloaded");
         Ok(())
@@ -635,6 +754,215 @@ impl Harness for CandleHarness {
     }
 }
 
+impl CandleHarness {
+    /// Tensor-parallel load. Resolves dense safetensors via hf-hub the
+    /// same way the single-GPU dense path does, spins up a TP worker
+    /// pool sized to `tp_size`, runs the NCCL handshake, then has
+    /// every rank load its shard of the model.
+    ///
+    /// `spec.devices` carries the per-rank CUDA device indices (one
+    /// entry per rank, in rank order); defaults to `0..tp_size`.
+    #[cfg(feature = "cuda")]
+    async fn load_tp(&self, spec: &ModelSpec, tp_size: u32) -> Result<()> {
+        use std::sync::Arc as StdArc;
+        use tokio::sync::Mutex as TMutex;
+
+        // Default per-rank device assignment: 0, 1, ..., tp_size - 1.
+        let devices = spec
+            .devices
+            .clone()
+            .unwrap_or_else(|| (0..tp_size).collect());
+        if devices.len() as u32 != tp_size {
+            anyhow::bail!(
+                "tensor_parallel={tp_size} requires {tp_size} entries in devices, got {}",
+                devices.len()
+            );
+        }
+        if spec.quant.is_some() {
+            anyhow::bail!(
+                "tensor_parallel={tp_size} with quant={:?}: GGUF quantized models \
+                 are not supported in the TP path; use a dense safetensors source",
+                spec.quant
+            );
+        }
+
+        // 1. Resolve config + tokenizer + safetensors via hf-hub.
+        let (config_path, tokenizer_path, safetensors_paths) =
+            self.resolve_dense_files(spec).await?;
+        let config_json = std::fs::read_to_string(&config_path).context("read config.json")?;
+
+        // 2. Spawn the worker pool. Rank 0 stays in-process; ranks
+        //    1..tp_size are subprocesses, one per device after the
+        //    leader's own.
+        let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
+        let mut pool = super::tp::WorkerPool::spawn(&exe, tp_size, &devices).await?;
+
+        // 3. NCCL handshake across all ranks.
+        let leader_device_idx = devices[0];
+        pool.init_nccl(leader_device_idx).await?;
+
+        // 4. Pick the leader's candle Device (same index as init_nccl).
+        let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize)
+            .context("Device::new_cuda for TP leader")?;
+
+        // 5. Load this rank's shard on every rank.
+        let leader_model = pool
+            .load_dense_shard(
+                &spec.model_id,
+                &config_json,
+                &safetensors_paths,
+                &leader_device,
+                candle_core::DType::BF16,
+            )
+            .await?;
+
+        // 6. Tokenizer (same as single-GPU path).
+        let tokenizer = Tokenizer::from_file(&tokenizer_path)
+            .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
+
+        let tp_loaded = StdArc::new(TpLoadedModel {
+            model_id: spec.model_id.clone(),
+            tokenizer,
+            devices: devices.clone(),
+            pool: TMutex::new(pool),
+            leader_model,
+        });
+
+        let mut models = self.models.write().await;
+        models.insert(spec.model_id.clone(), LoadedHandle::Tp(tp_loaded));
+        tracing::info!(
+            model = %spec.model_id,
+            tp_size,
+            ?devices,
+            "TP model loaded"
+        );
+        Ok(())
+    }
+
+    /// Non-streaming chat completion against a TP model. Pattern mirrors
+    /// the single-GPU `run_inference`: tokenize, prefill, sample, decode
+    /// loop, detokenize. Each forward step fans out to every rank via
+    /// the WorkerPool and uses the leader's last-position logits to
+    /// sample.
+    #[cfg(feature = "cuda")]
+    async fn chat_completion_tp(
+        &self,
+        tp: Arc<TpLoadedModel>,
+        request: ChatCompletionRequest,
+    ) -> Result<ChatCompletionResponse, InferenceError> {
+        let prompt = format_qwen3_prompt(&request.messages);
+        let encoding = tp
+            .tokenizer
+            .encode(prompt.as_str(), true)
+            .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
+        let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
+        let prompt_len = prompt_tokens.len();
+
+        let temperature = request.temperature.unwrap_or(0.7);
+        let top_p = request.top_p;
+        let max_new = request.max_tokens.unwrap_or(512) as usize;
+        let seed = unix_subsec_nanos();
+
+        let eos_id = tp
+            .tokenizer
+            .token_to_id("<|im_end|>")
+            .or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
+
+        let model_id = request.model.clone();
+
+        // Acquire the pool lock for the duration of the request. The
+        // leader_model's own Mutex is acquired step-by-step inside
+        // pool.generate_step (so spawn_blocking can grab it without
+        // holding the pool lock across the blocking_lock call).
+        let mut pool = tp.pool.lock().await;
+        let leader_arc = tp.leader_model.clone();
+
+        // Reset every rank's KV cache so this request doesn't attend
+        // over the previous request's tokens.
+        pool.clear_kv_cache(&model_id, leader_arc.clone())
+            .await
+            .map_err(InferenceError::Other)?;
+
+        let mut logits_processor = {
+            let sampling = if temperature <= 0.0 {
+                Sampling::ArgMax
+            } else {
+                match top_p {
+                    Some(p) => Sampling::TopP { p, temperature },
+                    None => Sampling::All { temperature },
+                }
+            };
+            LogitsProcessor::from_sampling(seed, sampling)
+        };
+
+        let mut generated: Vec<u32> = Vec::new();
+        let mut finish_reason = "length".to_string();
+
+        // Prefill: every rank embeds the whole prompt, offset = 0.
+        let logits = pool
+            .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
+            .await
+            .map_err(InferenceError::Other)?;
+        let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
+            .map_err(InferenceError::Other)?;
+
+        if Some(next_token) == eos_id {
+            finish_reason = "stop".into();
+        } else {
+            generated.push(next_token);
+            for index in 0..max_new.saturating_sub(1) {
+                let logits = pool
+                    .generate_step(
+                        &model_id,
+                        leader_arc.clone(),
+                        vec![next_token],
+                        prompt_len + index,
+                    )
+                    .await
+                    .map_err(InferenceError::Other)?;
+                next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
+                    .map_err(InferenceError::Other)?;
+                if Some(next_token) == eos_id {
+                    finish_reason = "stop".into();
+                    break;
+                }
+                generated.push(next_token);
+            }
+        }
+        drop(pool);
+
+        let completion_text = tp
+            .tokenizer
+            .decode(&generated, true)
+            .map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
+
+        let usage = Usage {
+            prompt_tokens: prompt_len as u64,
+            completion_tokens: generated.len() as u64,
+            total_tokens: (prompt_len + generated.len()) as u64,
+        };
+
+        Ok(ChatCompletionResponse {
+            id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
+            object: "chat.completion".into(),
+            created: unix_now_secs(),
+            model: model_id,
+            choices: vec![ChatCompletionChoice {
+                index: 0,
+                message: ChatMessage {
+                    role: "assistant".into(),
+                    content: MessageContent::Text(completion_text),
+                    extra: serde_json::Value::Object(Default::default()),
+                },
+                finish_reason: Some(finish_reason),
+                extra: serde_json::Value::Object(Default::default()),
+            }],
+            usage: Some(usage),
+            extra: serde_json::Value::Object(Default::default()),
+        })
+    }
+}
+
 /// Errors returned by `CandleHarness::chat_completion`. The
 /// `ModelNotLoaded` variant lets the HTTP handler map cleanly to 404
 /// without string-matching on anyhow messages.
@@ -338,6 +338,241 @@ impl WorkerPool {
|
|||||||
Ok(out)
|
Ok(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Load this rank's shard of a dense Qwen3 model on every rank.
|
||||||
|
///
|
||||||
|
/// The leader builds rank 0's `TpQwen3ForCausalLM` directly into
|
||||||
|
/// the returned `Arc<Mutex<_>>` — workers build their rank-local
|
||||||
|
/// shards in their own address spaces and confirm via
|
||||||
|
/// `LoadDenseShardOk`. All ranks see the same `safetensors_paths`;
|
||||||
|
/// `ShardedVarBuilder` slices each tensor by rank at materialisation
|
||||||
|
/// time, so the per-rank VRAM footprint is roughly `1/world_size`
|
||||||
|
/// of the full model (plus the replicated embedding/norm/lm_head).
|
||||||
|
///
|
||||||
|
/// `leader_device` is the candle `Device` the leader's shard lives
|
||||||
|
/// on — typically `Device::new_cuda(leader_cuda_device)` matching
|
||||||
|
/// the same index passed to `init_nccl`. `dtype` is the on-device
|
||||||
|
/// element type; bf16 is the canonical Qwen3 distribution dtype.
|
||||||
|
///
|
||||||
|
/// `init_nccl` must have completed first. Bails if the leader's
|
||||||
|
/// NCCL comm isn't set up yet.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub async fn load_dense_shard(
|
||||||
|
&mut self,
|
||||||
|
model_id: &str,
|
||||||
|
config_json: &str,
|
||||||
|
safetensors_paths: &[std::path::PathBuf],
|
||||||
|
leader_device: &candle_core::Device,
|
||||||
|
dtype: candle_core::DType,
|
||||||
|
) -> Result<std::sync::Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>> {
|
||||||
|
use candle_nn::var_builder::ShardedSafeTensors;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
// Wrap the comm in SendComm immediately so it stays Send across
|
||||||
|
// the await points in this method — bare Arc<Comm> would
|
||||||
|
// poison the async fn's Send bound (Comm's raw NCCL pointer is
|
||||||
|
// !Send). The wrapper's safety contract is satisfied by the
|
||||||
|
// pool's outer Mutex serialising callers + the spawn_blocking
|
||||||
|
// thread being the only place ops are issued.
|
||||||
|
let leader_comm =
|
||||||
|
nccl_state::SendComm(self.leader_nccl.comm().ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("leader NCCL not initialised; call init_nccl first")
|
||||||
|
})?);
|
||||||
|
let world_size = self.world_size;
|
||||||
|
let safetensors_str: Vec<String> = safetensors_paths
|
||||||
|
.iter()
|
||||||
|
.map(|p| p.to_string_lossy().into_owned())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// 1. Fan out the LoadDenseShard request to every worker without
|
||||||
|
// awaiting their replies — they'll build their shards in
|
||||||
|
// parallel with the leader below.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
w.send_only(&WorkerRequest::LoadDenseShard {
|
||||||
|
model_id: model_id.to_string(),
|
||||||
|
config_json: config_json.to_string(),
|
||||||
|
safetensors_paths: safetensors_str.clone(),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Build rank 0's shard on the leader. ShardedVarBuilder reads
|
||||||
|
// only the rank's slice from safetensors — no full-tensor
|
||||||
|
// materialisation. Runs in spawn_blocking because the
|
||||||
|
// file-mmap + slice + copy-to-device work is synchronous.
|
||||||
|
let cfg: super::tp::tp_qwen3::Config =
|
||||||
|
serde_json::from_str(config_json).context("parse Qwen3 Config JSON for leader load")?;
|
||||||
|
let paths_for_leader: Vec<std::path::PathBuf> = safetensors_paths.to_vec();
|
||||||
|
let device_for_leader = leader_device.clone();
|
||||||
|
let comm_for_leader = leader_comm;
|
||||||
|
let model_id_for_log = model_id.to_string();
|
||||||
|
let leader_model = tokio::task::spawn_blocking(
|
||||||
|
move || -> Result<super::tp::tp_qwen3::TpQwen3ForCausalLM> {
|
||||||
|
// SAFETY: same invariant as the single-GPU dense path —
|
||||||
|
// the HF cache files are treated as immutable while the
|
||||||
|
// mmap is held.
|
||||||
|
let vb = unsafe {
|
||||||
|
ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
|
||||||
|
.context("build ShardedVarBuilder over safetensors")?
|
||||||
|
};
|
||||||
|
let model = super::tp::tp_qwen3::TpQwen3ForCausalLM::load(
|
||||||
|
&cfg,
|
||||||
|
&vb,
|
||||||
|
0,
|
||||||
|
world_size,
|
||||||
|
comm_for_leader.into_inner(),
|
||||||
|
)?;
|
||||||
|
tracing::info!(rank = 0, model = %model_id_for_log, "loaded TP shard (leader)");
|
||||||
|
Ok(model)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.context("leader load task panicked")??;
|
||||||
|
|
||||||
|
// 3. Collect worker confirmations. Anything other than
|
||||||
|
// LoadDenseShardOk aborts the whole load — the leader's
|
||||||
|
// already-loaded shard drops when this fn returns Err.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.recv_only().await?;
|
||||||
|
match resp {
|
||||||
|
WorkerResponse::LoadDenseShardOk => {}
|
||||||
|
WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("worker rank {} LoadDenseShard [{kind}]: {message}", w.rank)
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"worker rank {} LoadDenseShard: expected LoadDenseShardOk, got {other:?}",
|
||||||
|
w.rank
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Arc::new(Mutex::new(leader_model)))
|
||||||
|
}
|
||||||
|
|
||||||
|
+    /// Run one forward step across every rank. The leader's forward
+    /// returns the last-position logits as a candle Tensor on the
+    /// leader's device; the caller does sampling out-of-band. Workers
+    /// run their own forwards (the AllReduce inside row-parallel layers
+    /// is what lets the leader's collective complete) and reply with
+    /// `GenerateStepOk` — they do not ship logits over the wire.
+    ///
+    /// `tokens` is the input for this step (prompt for prefill, the
+    /// previously-sampled token for decode). `offset` is the KV-cache
+    /// position before this step.
+    #[cfg(feature = "cuda")]
+    pub async fn generate_step(
+        &mut self,
+        model_id: &str,
+        leader_model: std::sync::Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>,
+        tokens: Vec<u32>,
+        offset: usize,
+    ) -> Result<candle_core::Tensor> {
+        // 1. Fan-out to workers.
+        for w in &mut self.workers {
+            w.send_only(&WorkerRequest::GenerateStep {
+                model_id: model_id.to_string(),
+                tokens: tokens.clone(),
+                offset,
+            })
+            .await?;
+        }
+
+        // 2. Leader's forward in spawn_blocking. The AllReduce CustomOps
+        // inside the row-parallel layers block until every worker's
+        // forward issues the matching collective.
+        let logits = tokio::task::spawn_blocking(move || -> Result<candle_core::Tensor> {
+            let mut model = leader_model.blocking_lock();
+            let device = model.device().clone();
+            let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
+            // TpQwen3ForCausalLM::forward returns [B, 1, V] (it slices
+            // to the last position internally). Squeeze both leading
+            // dims to get the rank-1 vocab logits LogitsProcessor wants.
+            let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?;
+            Ok(logits)
+        })
+        .await
+        .context("leader forward task panicked")??;
+
+        // 3. Collect worker confirmations.
+        for w in &mut self.workers {
+            let resp = w.recv_only().await?;
+            match resp {
+                WorkerResponse::GenerateStepOk => {}
+                WorkerResponse::Error { kind, message } => {
+                    anyhow::bail!("worker rank {} GenerateStep [{kind}]: {message}", w.rank)
+                }
+                other => anyhow::bail!(
+                    "worker rank {} GenerateStep: expected GenerateStepOk, got {other:?}",
+                    w.rank
+                ),
+            }
+        }
+        Ok(logits)
+    }
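
The three numbered steps in `generate_step` depend on their order: every worker has to receive its request before the leader enters its own forward, because the collective only completes once all ranks have joined it. Below is a minimal std-only sketch of that shape. It uses threads and channels in place of worker processes and a `Barrier` as a stand-in for the AllReduce; `Worker`, `spawn_worker`, and `run_step` are hypothetical names for illustration, not part of this change.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Barrier};
use std::thread;

/// Hypothetical stand-in for one worker process: a request channel in,
/// an acknowledgement channel out.
struct Worker {
    rank: usize,
    req: Sender<u32>,
    ack: Receiver<usize>,
    handle: thread::JoinHandle<()>,
}

fn spawn_worker(rank: usize, collective: Arc<Barrier>) -> Worker {
    let (req_tx, req_rx) = channel::<u32>();
    let (ack_tx, ack_rx) = channel::<usize>();
    let handle = thread::spawn(move || {
        // Serve requests until the leader drops its sender.
        for _step in req_rx {
            // Stand-in for the AllReduce: blocks until every rank arrives.
            collective.wait();
            ack_tx.send(rank).expect("leader hung up");
        }
    });
    Worker { rank, req: req_tx, ack: ack_rx, handle }
}

/// Run one step across `world_size` ranks and return the worker ranks
/// that acknowledged, in collection order.
pub fn run_step(world_size: usize) -> Vec<usize> {
    let collective = Arc::new(Barrier::new(world_size));
    let workers: Vec<Worker> = (1..world_size)
        .map(|rank| spawn_worker(rank, Arc::clone(&collective)))
        .collect();

    // 1. Fan out first, without waiting for any reply.
    for w in &workers {
        w.req.send(0).expect("worker hung up");
    }
    // 2. The leader joins the collective itself (rank 0's "forward").
    collective.wait();
    // 3. Only now collect one acknowledgement per worker.
    let mut acked = Vec::new();
    for w in &workers {
        let rank = w.ack.recv().expect("worker died");
        assert_eq!(rank, w.rank);
        acked.push(rank);
    }
    // Dropping each sender ends that worker's loop; then reap the thread.
    for w in workers {
        drop(w.req);
        w.handle.join().expect("worker panicked");
    }
    acked
}
```

If the leader instead sent to one worker and waited for its reply before contacting the next, that first worker would sit in the barrier indefinitely, since the remaining ranks would never be asked to join.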
+
+    /// Reset the KV cache for `model_id` on every rank. Called at the
+    /// start of every inference so a fresh request doesn't attend over
+    /// the previous one's tokens.
+    pub async fn clear_kv_cache(
+        &mut self,
+        model_id: &str,
+        #[cfg(feature = "cuda")] leader_model: std::sync::Arc<
+            tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>,
+        >,
+    ) -> Result<()> {
+        for w in &mut self.workers {
+            w.send_only(&WorkerRequest::ClearKvCache {
+                model_id: model_id.to_string(),
+            })
+            .await?;
+        }
+        #[cfg(feature = "cuda")]
+        {
+            let mut m = leader_model.lock().await;
+            m.clear_kv_cache();
+        }
+        for w in &mut self.workers {
+            let resp = w.recv_only().await?;
+            match resp {
+                WorkerResponse::KvCacheCleared => {}
+                WorkerResponse::Error { kind, message } => {
+                    anyhow::bail!("worker rank {} ClearKvCache [{kind}]: {message}", w.rank)
+                }
+                other => anyhow::bail!(
+                    "worker rank {} ClearKvCache: expected KvCacheCleared, got {other:?}",
+                    w.rank
+                ),
+            }
+        }
+        Ok(())
+    }
+
+    /// Drop this model's shard on every worker rank. The leader's shard
+    /// is expected to have been dropped by the caller (its `Arc` was held
+    /// in the TpLoadedModel and goes away when that's removed).
+    pub async fn unload_model(&mut self, model_id: &str) -> Result<()> {
+        for w in &mut self.workers {
+            w.send_only(&WorkerRequest::UnloadModel {
+                model_id: model_id.to_string(),
+            })
+            .await?;
+        }
+        for w in &mut self.workers {
+            let resp = w.recv_only().await?;
+            match resp {
+                WorkerResponse::Unloaded => {}
+                WorkerResponse::Error { kind, message } => {
+                    anyhow::bail!("worker rank {} UnloadModel [{kind}]: {message}", w.rank)
+                }
+                other => anyhow::bail!(
+                    "worker rank {} UnloadModel: expected Unloaded, got {other:?}",
+                    w.rank
+                ),
+            }
+        }
+        Ok(())
+    }
+
     /// Send `Shutdown` to every worker, await each `Bye`, and reap the
     /// children. Best-effort — individual worker failures are logged
     /// but don't abort the rest of the sweep.

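
The collect loops in the load, generate, clear, and unload paths all classify a worker reply the same way: the expected acknowledgement passes, an `Error` is surfaced with its kind and message, and anything else is reported as a protocol mismatch. As an illustration only, that shared shape could be written as one helper. The `Reply` enum here is a simplified, hypothetical mirror of `WorkerResponse`, and `expect_reply` is not a function in this change.

```rust
/// Simplified, hypothetical mirror of the worker reply type.
#[derive(Debug, Clone, PartialEq)]
pub enum Reply {
    LoadDenseShardOk,
    GenerateStepOk,
    KvCacheCleared,
    Unloaded,
    Error { kind: String, message: String },
}

/// Check one worker's reply against the acknowledgement the leader
/// expects for `op`, producing messages in the same style as the
/// hand-written match arms.
pub fn expect_reply(rank: u32, op: &str, expected: &Reply, got: Reply) -> Result<(), String> {
    if &got == expected {
        return Ok(());
    }
    match got {
        // The worker reported a failure of its own.
        Reply::Error { kind, message } => {
            Err(format!("worker rank {} {} [{}]: {}", rank, op, kind, message))
        }
        // A well-formed reply, but not the one this request calls for.
        other => Err(format!(
            "worker rank {} {}: expected {:?}, got {:?}",
            rank, op, expected, other
        )),
    }
}
```

Keeping the "unexpected variant" arm separate from the `Error` arm preserves the distinction the original messages make between a worker-side failure and a request/response desynchronisation.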
@@ -83,7 +83,13 @@ mod cuda_impl {
     const NCCL_ID_BYTES: usize = 128;

     pub struct NcclState {
-        comm: Option<Comm>,
+        /// Wrapped in `Arc` so we can hand a clone to `TpQwen3ForCausalLM`
+        /// at load time (every row-parallel layer needs a reference to
+        /// run its trailing `AllReduce`). The `Arc` is the single source
+        /// of truth for the comm's lifetime — when the pool drops and
+        /// every layer that captured a clone drops, NCCL releases the
+        /// underlying `ncclComm_t`.
+        comm: Option<Arc<Comm>>,
         /// Held alongside the Comm so the device isn't dropped
         /// underneath the NCCL handle.
         #[allow(dead_code)]
@@ -103,6 +109,40 @@ mod cuda_impl {
                 ctx: None,
             }
         }
+
+        /// Clone the comm out as an `Arc` so callers (the leader-side
+        /// `TpQwen3ForCausalLM::load`, or the worker's own model load)
+        /// can hold a reference for the lifetime of the model. Returns
+        /// `None` before `init` has run.
+        pub fn comm(&self) -> Option<Arc<Comm>> {
+            self.comm.clone()
+        }
+    }
+
+    /// `Arc<Comm>` doesn't impl `Send` because `Comm` wraps a raw
+    /// `ncclComm_t` pointer. The NCCL contract is "operations against a
+    /// given comm must be serialised", not "the handle must stay on the
+    /// thread that created it" — so it's safe to move an `Arc<Comm>`
+    /// across threads as long as no concurrent ops are issued. The
+    /// pool's outer Mutex serialises us into `spawn_blocking`, so this
+    /// wrapper at the move boundary is the only thing missing.
+    ///
+    /// `Sync` is also marked safe because the `Arc<Comm>` clones held
+    /// by the row-parallel layers are only used from the
+    /// `spawn_blocking` thread driving the forward pass; concurrent
+    /// access from another thread would still be a bug.
+    pub struct SendComm(pub Arc<Comm>);
+
+    // SAFETY: see the doc-comment above; the invariant is enforced at
+    // the call site (pool Mutex + single spawn_blocking thread), not at
+    // the type level.
+    unsafe impl Send for SendComm {}
+    unsafe impl Sync for SendComm {}
+
+    impl SendComm {
+        pub fn into_inner(self) -> Arc<Comm> {
+            self.0
+        }
     }

     // SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer
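
The `SendComm` newtype follows a general pattern: wrap a handle that is `!Send` only because it holds a raw pointer, assert `Send` on the wrapper, and unwrap it on the far side of the thread boundary. Here is a self-contained sketch of that pattern using only the standard library. `RawHandle`, `SendHandle`, and `bump_on_other_thread` are hypothetical stand-ins, and serialisation is guaranteed here by `join` rather than by a pool mutex.

```rust
use std::sync::Arc;
use std::thread;

/// Hypothetical stand-in for a handle wrapping a raw pointer (the role
/// `Comm` plays above). The raw pointer makes it !Send and !Sync.
pub struct RawHandle {
    ptr: *mut u64,
}

impl RawHandle {
    pub fn new(initial: u64) -> Self {
        Self { ptr: Box::into_raw(Box::new(initial)) }
    }

    /// Increment the counter behind the pointer and return the new value.
    /// Sound only while no other thread touches the handle concurrently.
    pub fn bump(&self) -> u64 {
        // SAFETY: `ptr` came from `Box::into_raw` and is freed only in `Drop`.
        unsafe {
            *self.ptr += 1;
            *self.ptr
        }
    }
}

impl Drop for RawHandle {
    fn drop(&mut self) {
        // SAFETY: reclaims the allocation made in `new`, exactly once.
        unsafe { drop(Box::from_raw(self.ptr)) }
    }
}

/// Newtype that asserts "moving this across threads is fine".
pub struct SendHandle(pub Arc<RawHandle>);

// SAFETY: callers promise that use of the handle is serialised; in this
// sketch the spawning thread does not touch it until after `join`.
unsafe impl Send for SendHandle {}

impl SendHandle {
    pub fn into_inner(self) -> Arc<RawHandle> {
        self.0
    }
}

/// Bump the counter `n` times on another thread, then once more here.
pub fn bump_on_other_thread(n: u64) -> u64 {
    let handle = Arc::new(RawHandle::new(0));
    let wrapped = SendHandle(Arc::clone(&handle));
    thread::spawn(move || {
        // Unwrap via a method so the closure captures the whole wrapper
        // (edition 2021 closures would otherwise capture just the
        // non-Send field `wrapped.0`).
        let inner = wrapped.into_inner();
        for _ in 0..n {
            inner.bump();
        }
    })
    .join()
    .expect("worker thread panicked");
    handle.bump()
}
```

The wrapper only states the invariant. Nothing in the type system stops a caller from using two clones concurrently, which is why the safety argument has to live at the call site.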
@@ -143,7 +183,7 @@ mod cuda_impl {
                     message: "sanity_check requires Init to have completed first".into(),
                 };
             };
-            match try_sanity_check(comm) {
+            match try_sanity_check(comm.as_ref()) {
                 Ok(sum) => WorkerResponse::NcclSanityResult { observed_sum: sum },
                 Err(msg) => WorkerResponse::Error {
                     kind: "nccl_sanity_failed".into(),
@@ -177,7 +217,17 @@ mod cuda_impl {
         })?;

         state.ctx = Some(ctx);
-        state.comm = Some(comm);
+        // `Comm` is !Send + !Sync at the type level because it wraps a
+        // raw `ncclComm_t`. The `Arc` is fine in practice: we
+        // serialise operations through the pool's outer Mutex, and the
+        // SendComm wrapper makes each thread-crossing move explicit
+        // (it asserts the invariant; it does not enforce it). clippy's
+        // `arc_with_non_send_sync` lint can't see that invariant; allow
+        // it once at the canonical construction site.
+        #[allow(clippy::arc_with_non_send_sync)]
+        {
+            state.comm = Some(Arc::new(comm));
+        }
         Ok(())
     }

@@ -202,7 +252,7 @@ mod cuda_impl {
 }

 #[cfg(feature = "cuda")]
-pub use cuda_impl::{NcclState, generate_comm_id_hex};
+pub use cuda_impl::{NcclState, SendComm, generate_comm_id_hex};

 /// Non-cuda stub for the leader: returns a clear marker error rather
 /// than letting `init_nccl` succeed vacuously.

@@ -45,6 +45,52 @@ pub enum WorkerRequest {
     /// the NCCL handshake is genuinely live, not just configured.
     NcclSanityCheck,

+    /// Load this rank's shard of a dense Qwen3 model from mmapped
+    /// safetensors. The same `safetensors_paths` list is sent to every
+    /// rank — the ShardedVarBuilder reads only the rank-local slice of
+    /// each tensor at materialisation time, so the worker's VRAM
+    /// footprint is `1 / world_size` of the full model (plus replicated
+    /// embedding/norm/lm_head).
+    LoadDenseShard {
+        /// Caller-supplied id for later `GenerateStep` / `UnloadModel`
+        /// lookups. Typically the HF model id verbatim.
+        model_id: String,
+        /// JSON-serialised `candle_transformers::models::qwen3::Config`
+        /// — the same blob the leader parsed from the HF cache's
+        /// `config.json`. Threaded through verbatim so the worker uses
+        /// identical hyperparameters.
+        config_json: String,
+        /// Absolute paths the worker should mmap. The same set on every
+        /// rank; ShardedVarBuilder slices into them per rank.
+        safetensors_paths: Vec<String>,
+    },
+
+    /// Run one forward step on this rank's loaded model. The worker
+    /// reaches into its NCCL Comm for the row-parallel `AllReduce`s
+    /// inside the model — and so blocks on every other rank issuing the
+    /// same op. The leader does *not* receive logits back over RPC; it
+    /// runs its own rank-0 forward in parallel and uses its own logits
+    /// for sampling.
+    GenerateStep {
+        model_id: String,
+        /// Input token ids for this step. For prefill, the whole prompt;
+        /// for decode, a single token. Identical on every rank.
+        tokens: Vec<u32>,
+        /// KV cache offset (count of tokens already in the cache before
+        /// this step).
+        offset: usize,
+    },
+
+    /// Reset the KV cache for this model on this rank. Sent at the
+    /// start of every inference so a fresh request doesn't accidentally
+    /// attend over the previous one's tokens.
+    ClearKvCache { model_id: String },
+
+    /// Drop this rank's shard for the given model. Releases the VRAM
+    /// the shard's weights occupied; subsequent `GenerateStep` calls
+    /// against the same `model_id` return an `Error`.
+    UnloadModel { model_id: String },
+
     /// Worker should release resources and exit. Worker replies `Bye`
     /// and then closes stdout / exits zero. The leader reaps the
     /// child via the `tokio::process::Child` it kept.
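
The `1 / world_size` footprint described for `LoadDenseShard` comes from each rank materialising only its own slice of every sharded tensor. The index arithmetic behind such a split can be sketched as below, assuming an even split along a single dimension. This is an illustration of the idea only: `shard_range` is a hypothetical helper, and candle's `ShardedVarBuilder` is not reproduced here and may differ in detail.

```rust
use std::ops::Range;

/// Slice of a dimension of size `dim` owned by `rank` out of
/// `world_size`, assuming `dim` divides evenly across ranks.
pub fn shard_range(dim: usize, rank: usize, world_size: usize) -> Result<Range<usize>, String> {
    if world_size == 0 || rank >= world_size {
        return Err(format!("rank {} out of range for world_size {}", rank, world_size));
    }
    if dim % world_size != 0 {
        return Err(format!("dim {} is not divisible by world_size {}", dim, world_size));
    }
    // Each rank owns one contiguous block of equal size.
    let block = dim / world_size;
    Ok(rank * block..(rank + 1) * block)
}
```

With two ranks and a dimension of 4096, rank 0 would own `0..2048` and rank 1 would own `2048..4096`, so each rank holds half of that tensor.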
@@ -74,6 +120,24 @@ pub enum WorkerResponse {
     /// this matches `world_size`.
     NcclSanityResult { observed_sum: u32 },

+    /// Reply to `LoadDenseShard`. Empty payload — success is the
+    /// absence of `Error`. By the time this comes back, the rank's
+    /// `TpQwen3ForCausalLM` is constructed in memory and ready for
+    /// `GenerateStep`.
+    LoadDenseShardOk,
+
+    /// Reply to `GenerateStep`. Empty payload — workers don't ship
+    /// logits over the wire. The leader uses its own rank-0 logits;
+    /// workers only need to confirm the collective completed.
+    GenerateStepOk,
+
+    /// Reply to `ClearKvCache`. Empty payload.
+    KvCacheCleared,
+
+    /// Reply to `UnloadModel`. Empty payload. The named model is no
+    /// longer present on this rank.
+    Unloaded,
+
     /// Reply to `Shutdown`. Worker exits immediately after writing this.
     Bye,

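
The `tokens` and `offset` fields of `GenerateStep` follow a simple progression over one request: a prefill step carries the whole prompt at offset 0, and each decode step carries a single token at an offset equal to the number of tokens already cached. A small sketch of that bookkeeping follows, assuming the cache was cleared before the prefill. `StepShape` and `step_shapes` are hypothetical names used only for illustration.

```rust
/// One `GenerateStep` as seen on the wire: how many tokens are sent and
/// the KV-cache offset (tokens already cached before the step).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StepShape {
    pub num_tokens: usize,
    pub offset: usize,
}

/// Shapes for one prefill over `prompt_len` tokens followed by
/// `decode_steps` single-token decode steps.
pub fn step_shapes(prompt_len: usize, decode_steps: usize) -> Vec<StepShape> {
    let mut shapes = Vec::with_capacity(decode_steps + 1);
    // Prefill: the whole prompt, starting from an empty cache.
    shapes.push(StepShape { num_tokens: prompt_len, offset: 0 });
    // Decode: one token per step; the offset is what is already cached.
    let mut cached = prompt_len;
    for _ in 0..decode_steps {
        shapes.push(StepShape { num_tokens: 1, offset: cached });
        cached += 1;
    }
    shapes
}
```

Because the same `tokens` and `offset` are sent to every rank, each rank's cache advances in lockstep with the leader's.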
@@ -5,18 +5,23 @@
 //! exactly one `WorkerResponse` JSON line to stdout. tracing goes to
 //! stderr so it doesn't collide with the RPC stream.
 //!
-//! NCCL operations (`Init`, `NcclSanityCheck`) are real when built
-//! with the `cuda` feature; without it they reply with
-//! `Error{kind="cuda_feature_not_enabled"}` so the leader can tell
-//! the difference between a misconfigured build and a genuine NCCL
-//! failure.
+//! NCCL operations (`Init`, `NcclSanityCheck`) and model lifecycle ops
+//! (`LoadDenseShard`, `GenerateStep`, `ClearKvCache`, `UnloadModel`)
+//! are real when built with the `cuda` feature; without it they reply
+//! with `Error{kind="cuda_feature_not_enabled"}` so the leader can tell
+//! the difference between a misconfigured build and a genuine NCCL or
+//! model failure.

 use anyhow::Result;
+use std::collections::HashMap;
 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};

 use super::nccl_state::NcclState;
 use super::rpc::{WorkerRequest, WorkerResponse};

+#[cfg(feature = "cuda")]
+use super::tp_qwen3::TpQwen3ForCausalLM;
+
 #[derive(Debug, Clone, Copy)]
 pub struct WorkerConfig {
     pub rank: u32,
@@ -74,9 +79,22 @@ async fn write_response(stdout: &mut tokio::io::Stdout, resp: &WorkerResponse) -
     Ok(())
 }

+/// One rank's local state. Owns the rank's NCCL communicator (via
+/// `NcclState`) and the rank's shard of every loaded model.
 struct WorkerState {
     config: WorkerConfig,
     nccl: NcclState,
+    /// Loaded model shards keyed by `model_id`. Each entry holds this
+    /// rank's `TpQwen3ForCausalLM` — the column/row-parallel layers
+    /// hold an `Arc<Comm>` cloned from `nccl`. Cuda-only: there is no
+    /// TpQwen3ForCausalLM type without the cuda feature in scope.
+    #[cfg(feature = "cuda")]
+    models: HashMap<String, TpQwen3ForCausalLM>,
+    /// Placeholder so the non-cuda build keeps the same field name set
+    /// and `WorkerState::new` reads the same on both.
+    #[cfg(not(feature = "cuda"))]
+    #[allow(dead_code)]
+    models: HashMap<String, ()>,
 }

 impl WorkerState {
@@ -84,6 +102,7 @@ impl WorkerState {
         Self {
             config,
             nccl: NcclState::new(),
+            models: HashMap::new(),
         }
     }

@@ -96,7 +115,203 @@ impl WorkerState {
             },
             WorkerRequest::Init { comm_id } => self.nccl.init(self.config, &comm_id),
             WorkerRequest::NcclSanityCheck => self.nccl.sanity_check(),
+            WorkerRequest::LoadDenseShard {
+                model_id,
+                config_json,
+                safetensors_paths,
+            } => self.handle_load_dense_shard(model_id, config_json, safetensors_paths),
+            WorkerRequest::GenerateStep {
+                model_id,
+                tokens,
+                offset,
+            } => self.handle_generate_step(&model_id, tokens, offset),
+            WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
+            WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
             WorkerRequest::Shutdown => WorkerResponse::Bye,
         }
     }
+
+    #[cfg(feature = "cuda")]
+    fn handle_load_dense_shard(
+        &mut self,
+        model_id: String,
+        config_json: String,
+        safetensors_paths: Vec<String>,
+    ) -> WorkerResponse {
+        use candle_core::{DType, Device};
+        use candle_nn::var_builder::ShardedSafeTensors;
+        use candle_transformers::models::qwen3 as qwen3_dense;
+        use std::path::PathBuf;
+
+        if self.models.contains_key(&model_id) {
+            return WorkerResponse::Error {
+                kind: "already_loaded".into(),
+                message: format!("model '{model_id}' already loaded on this rank"),
+            };
+        }
+        let comm = match self.nccl.comm() {
+            Some(c) => c,
+            None => {
+                return WorkerResponse::Error {
+                    kind: "nccl_not_initialised".into(),
+                    message: "LoadDenseShard requires Init to have completed first".into(),
+                };
+            }
+        };
+
+        let cfg: qwen3_dense::Config = match serde_json::from_str(&config_json) {
+            Ok(c) => c,
+            Err(e) => {
+                return WorkerResponse::Error {
+                    kind: "bad_request".into(),
+                    message: format!("parse Qwen3 Config JSON: {e}"),
+                };
+            }
+        };
+
+        let device = match Device::new_cuda(self.config.cuda_device as usize) {
+            Ok(d) => d,
+            Err(e) => {
+                return WorkerResponse::Error {
+                    kind: "cuda_unavailable".into(),
+                    message: format!("Device::new_cuda({}) failed: {e}", self.config.cuda_device),
+                };
+            }
+        };
+
+        let paths: Vec<PathBuf> = safetensors_paths.into_iter().map(PathBuf::from).collect();
+        // SAFETY: same invariant as the single-GPU dense path — the HF
+        // cache files are treated as immutable while the mmap is held.
+        let vb = match unsafe { ShardedSafeTensors::var_builder(&paths, DType::BF16, &device) } {
+            Ok(v) => v,
+            Err(e) => {
+                return WorkerResponse::Error {
+                    kind: "load_failed".into(),
+                    message: format!("ShardedSafeTensors::var_builder: {e}"),
+                };
+            }
+        };
+        let model = match TpQwen3ForCausalLM::load(
+            &cfg,
+            &vb,
+            self.config.rank,
+            self.config.world_size,
+            comm,
+        ) {
+            Ok(m) => m,
+            Err(e) => {
+                return WorkerResponse::Error {
+                    kind: "load_failed".into(),
+                    message: format!("TpQwen3ForCausalLM::load: {e:#}"),
+                };
+            }
+        };
+
+        self.models.insert(model_id.clone(), model);
+        tracing::info!(rank = self.config.rank, model = %model_id, "loaded TP shard");
+        WorkerResponse::LoadDenseShardOk
+    }
+
+    #[cfg(not(feature = "cuda"))]
+    fn handle_load_dense_shard(
+        &mut self,
+        _model_id: String,
+        _config_json: String,
+        _safetensors_paths: Vec<String>,
+    ) -> WorkerResponse {
+        WorkerResponse::Error {
+            kind: "cuda_feature_not_enabled".into(),
+            message: "LoadDenseShard requires --features cuda".into(),
+        }
+    }
+
+    #[cfg(feature = "cuda")]
+    fn handle_generate_step(
+        &mut self,
+        model_id: &str,
+        tokens: Vec<u32>,
+        offset: usize,
+    ) -> WorkerResponse {
+        use candle_core::Tensor;
+
+        let Some(model) = self.models.get_mut(model_id) else {
+            return WorkerResponse::Error {
+                kind: "model_not_loaded".into(),
+                message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
+            };
+        };
+        let device = model.device().clone();
+        let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
+            Ok(t) => t,
+            Err(e) => {
+                return WorkerResponse::Error {
+                    kind: "forward_failed".into(),
+                    message: format!("build input tensor: {e}"),
+                };
+            }
+        };
+        // Drop the resulting logits — the leader uses its own copy from
+        // rank 0. The forward's value here is the NCCL collectives it
+        // issues, which let the leader's rank-0 forward make progress.
+        if let Err(e) = model.forward(&input, offset) {
+            return WorkerResponse::Error {
+                kind: "forward_failed".into(),
+                message: format!("TpQwen3ForCausalLM::forward: {e}"),
+            };
+        }
+        WorkerResponse::GenerateStepOk
+    }
+
+    #[cfg(not(feature = "cuda"))]
+    fn handle_generate_step(
+        &mut self,
+        _model_id: &str,
+        _tokens: Vec<u32>,
+        _offset: usize,
+    ) -> WorkerResponse {
+        WorkerResponse::Error {
+            kind: "cuda_feature_not_enabled".into(),
+            message: "GenerateStep requires --features cuda".into(),
+        }
+    }
+
+    #[cfg(feature = "cuda")]
+    fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse {
+        let Some(model) = self.models.get_mut(model_id) else {
+            return WorkerResponse::Error {
+                kind: "model_not_loaded".into(),
+                message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
+            };
+        };
+        model.clear_kv_cache();
+        WorkerResponse::KvCacheCleared
+    }
+
+    #[cfg(not(feature = "cuda"))]
+    fn handle_clear_kv_cache(&mut self, _model_id: &str) -> WorkerResponse {
+        WorkerResponse::Error {
+            kind: "cuda_feature_not_enabled".into(),
+            message: "ClearKvCache requires --features cuda".into(),
+        }
+    }
+
+    #[cfg(feature = "cuda")]
+    fn handle_unload_model(&mut self, model_id: &str) -> WorkerResponse {
+        if self.models.remove(model_id).is_none() {
+            return WorkerResponse::Error {
+                kind: "model_not_loaded".into(),
+                message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
+            };
+        }
+        tracing::info!(rank = self.config.rank, model = %model_id, "unloaded TP shard");
+        WorkerResponse::Unloaded
+    }
+
+    #[cfg(not(feature = "cuda"))]
+    fn handle_unload_model(&mut self, _model_id: &str) -> WorkerResponse {
+        WorkerResponse::Error {
+            kind: "cuda_feature_not_enabled".into(),
+            message: "UnloadModel requires --features cuda".into(),
+        }
+    }
 }

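
Taken together, the worker handlers implement a small per-rank registry keyed by `model_id`: loading an id twice is rejected with `already_loaded`, and stepping or unloading an unknown id is rejected with `model_not_loaded`. The sketch below isolates just those rules with the model replaced by `()`. `Outcome` and `Registry` are hypothetical names, and the real handlers also check NCCL initialisation, parse the config, and touch the GPU.

```rust
use std::collections::HashMap;

/// Hypothetical, simplified replies; `Error` carries only the `kind`.
#[derive(Debug, PartialEq)]
pub enum Outcome {
    Loaded,
    StepOk,
    Unloaded,
    Error(&'static str),
}

/// Per-rank registry of loaded shards. `()` stands in for the model.
#[derive(Default)]
pub struct Registry {
    models: HashMap<String, ()>,
}

impl Registry {
    /// Register a shard; a second load of the same id is an error.
    pub fn load(&mut self, model_id: &str) -> Outcome {
        if self.models.contains_key(model_id) {
            return Outcome::Error("already_loaded");
        }
        self.models.insert(model_id.to_string(), ());
        Outcome::Loaded
    }

    /// Run a step against a shard that must already be loaded.
    pub fn step(&mut self, model_id: &str) -> Outcome {
        match self.models.get_mut(model_id) {
            Some(_model) => Outcome::StepOk,
            None => Outcome::Error("model_not_loaded"),
        }
    }

    /// Remove a shard; removing an unknown id is an error.
    pub fn unload(&mut self, model_id: &str) -> Outcome {
        match self.models.remove(model_id) {
            Some(()) => Outcome::Unloaded,
            None => Outcome::Error("model_not_loaded"),
        }
    }
}
```

Returning a typed error for every misuse, rather than panicking, keeps the one-request-one-response contract of the RPC stream intact even when the leader sends requests in the wrong order.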
@@ -9,14 +9,15 @@
 # after pushing new neuron builds.
 #
 # Usage:
-# script/validate-neuron.sh [host] [model_id] [quant]
+# script/validate-neuron.sh [host] [model_id] [quant] [tp_size]
 #
 # Defaults:
 # host = beast.hanzalova.internal
 # model_id = unsloth/Qwen3-0.6B-GGUF (official Qwen3-*-GGUF repos
 # ship Q8_0 only; unsloth's mirror ships the full Q-spectrum
 # including Q4_K_M)
-# quant = Q4_K_M
+# quant = Q4_K_M (empty = dense safetensors path)
+# tp_size = 1 (single-GPU; pass 2 or more to drive the TP path)

 set -euo pipefail

@@ -25,6 +26,11 @@ MODEL_ID="${2:-unsloth/Qwen3-0.6B-GGUF}"
 # `${3-Q4_K_M}` (no colon) only uses the default when the arg is
 # UNSET — passing an explicit empty string drives the dense path.
 QUANT="${3-Q4_K_M}"
+# tp_size > 1 requires the dense path (TP needs safetensors, so quant
+# must be empty) and adds `tensor_parallel: N` to the load payload.
+# This script uses device indices 0..N-1 by default; override by
+# passing NEURON_DEVICES="0,1,..." in the environment.
+TP_SIZE="${4-1}"
 PORT="${NEURON_PORT:-13131}"
 BASE="http://${HOST}:${PORT}"

@@ -69,21 +75,43 @@ is_loaded() {
 }

 trigger_load() {
-  say "POST /models/load ${MODEL_ID} (quant=${QUANT:-<dense>}, device=[0])"
+  # Build the per-rank CUDA device list as a JSON array. Either
+  # honour NEURON_DEVICES (`0,1,2`) verbatim or default to
+  # `[0, 1, ..., tp_size - 1]`.
+  local devices_json
+  if [[ -n "${NEURON_DEVICES:-}" ]]; then
+    devices_json=$(jq -n -c --arg s "${NEURON_DEVICES}" \
+      '$s | split(",") | map(tonumber)')
+  else
+    devices_json=$(jq -n -c --argjson n "${TP_SIZE}" '[range(0; $n)]')
+  fi
+  say "POST /models/load ${MODEL_ID} (quant=${QUANT:-<dense>}, tp=${TP_SIZE}, devices=${devices_json})"
   say " (synchronous; may take a minute on first run while HF downloads)"
-  # Build the payload via jq so the optional `quant` field is
-  # omitted entirely when empty — that's the signal to the harness
-  # to take the dense safetensors load path rather than GGUF.
+  if (( TP_SIZE > 1 )) && [[ -n "${QUANT}" ]]; then
+    die "tp_size>1 requires dense safetensors — pass quant='' as the 3rd argument"
+  fi
+  # Build the payload via jq so the optional `quant` and
+  # `tensor_parallel` fields are omitted entirely when not in use —
+  # that's how the harness tells dense from quantized and single-GPU
+  # from TP.
   local payload
-  if [[ -z "${QUANT}" ]]; then
+  if [[ -z "${QUANT}" ]] && (( TP_SIZE > 1 )); then
     payload=$(jq -n -c \
       --arg id "${MODEL_ID}" \
-      '{model_id: $id, harness: "candle", devices: [0]}')
+      --argjson tp "${TP_SIZE}" \
+      --argjson devices "${devices_json}" \
+      '{model_id: $id, harness: "candle", tensor_parallel: $tp, devices: $devices}')
+  elif [[ -z "${QUANT}" ]]; then
+    payload=$(jq -n -c \
+      --arg id "${MODEL_ID}" \
+      --argjson devices "${devices_json}" \
+      '{model_id: $id, harness: "candle", devices: $devices}')
   else
     payload=$(jq -n -c \
       --arg id "${MODEL_ID}" \
       --arg q "${QUANT}" \
-      '{model_id: $id, harness: "candle", quant: $q, devices: [0]}')
+      --argjson devices "${devices_json}" \
+      '{model_id: $id, harness: "candle", quant: $q, devices: $devices}')
   fi
   # --write-out captures the response code on a separate line so we
   # can surface a real diagnostic instead of relying on --fail.

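
The script's argument handling combines three rules: `quant` defaults only when the argument is unset (an explicit empty string selects the dense path), `tp_size > 1` is rejected unless `quant` is empty, and the device list comes from `NEURON_DEVICES` when set and otherwise defaults to `0..tp_size`. The same decision logic is sketched below in Rust for consistency with the rest of the change. `LoadPlan` and `plan_load` are hypothetical names, and the error strings are illustrative rather than copied from the script.

```rust
/// Hypothetical plan for one `/models/load` call, mirroring the
/// script's argument handling.
#[derive(Debug, PartialEq)]
pub struct LoadPlan {
    /// `None` means the dense safetensors path.
    pub quant: Option<String>,
    /// `None` means single-GPU (no `tensor_parallel` field is sent).
    pub tensor_parallel: Option<u32>,
    pub devices: Vec<u32>,
}

pub fn plan_load(
    quant_arg: Option<&str>,   // 3rd positional argument, if given
    tp_size: u32,              // 4th positional argument, default 1
    devices_env: Option<&str>, // NEURON_DEVICES, if set
) -> Result<LoadPlan, String> {
    // `${3-Q4_K_M}`: default only when UNSET; an explicit "" stays empty.
    let quant = quant_arg.unwrap_or("Q4_K_M");
    if tp_size > 1 && !quant.is_empty() {
        return Err("tp_size>1 requires dense safetensors".to_string());
    }
    // NEURON_DEVICES wins when set and non-empty, else 0..tp_size.
    let devices = match devices_env {
        Some(s) if !s.is_empty() => s
            .split(',')
            .map(|d| d.parse::<u32>().map_err(|e| format!("bad device '{}': {}", d, e)))
            .collect::<Result<Vec<u32>, String>>()?,
        _ => (0..tp_size).collect(),
    };
    Ok(LoadPlan {
        quant: if quant.is_empty() { None } else { Some(quant.to_string()) },
        tensor_parallel: if tp_size > 1 { Some(tp_size) } else { None },
        devices,
    })
}
```

Modelling the third argument as `Option<&str>` makes the unset-versus-empty distinction explicit, which is the same distinction the colon-less `${3-Q4_K_M}` expansion draws in the shell.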