refactor(neuron): phase 4 — model loads move onto the device worker
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 37s
CI / Clippy (push) Successful in 2m25s
CI / Test (push) Successful in 4m40s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m51s
build-prerelease / Build cortex binary (push) Successful in 4m21s
build-prerelease / Package cortex RPM (push) Successful in 1m20s
build-prerelease / Build neuron-ampere (push) Successful in 5m7s
build-prerelease / Build neuron-ada (push) Successful in 5m19s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s

Final structural slice of the per-device CUDA context-ownership
refactor. The four remaining spawn_blocking sites that did CUDA work
on the leader are gone:

- Single-GPU GGUF load (`load_arch_gguf` spawn_blocking) →
  `Job::LoadGguf` dispatched on the worker.
- Single-GPU dense load (`load_arch_dense` spawn_blocking) →
  `Job::LoadDense` on the worker.
- TP shard load (`WorkerPool::load_dense_shard` spawn_blocking) →
  `Job::TpLoadShard`. The dispatch handler reads `state.nccl.comm()`
  directly — no cross-thread `Arc<Comm>` transfer, no `SendComm`
  wrapper for this path.

The Phase 2 / Phase 3 bridges that moved freshly-built models across
the channel boundary (`Job::TransferIn`, `Job::TransferInTp`,
`Job::CloneLeaderComm`) are removed. Models are now constructed on
the worker thread directly; the slab gets populated by `insert_arch` /
the inline `tp_models.insert` in dispatch handlers.

What this phase preserves:

- CPU loads still use `tokio::task::spawn_blocking` against
  `Arc<Mutex<ModelArch>>`. There's no CUDA context to own on CPU and
  channel overhead would only add latency. Four `spawn_blocking`
  references remain in `candle.rs` (load_arch_gguf, load_arch_dense,
  chat_completion, chat_completion_stream) and all are deliberate
  CPU-only fallback.
- Public API unchanged. `Harness::load_model`, `chat_completion`,
  HTTP routes all keep identical signatures.

What this phase removes:

- `SendComm` wrapper is no longer used in the load path (the Phase 3
  bridge that justified it). It remains in `nccl_state.rs` for the
  Phase 1–3 era and any future cross-thread Comm move; consider
  deleting in a follow-up.
- `Job::TransferIn`, `Job::TransferInTp`, `Job::CloneLeaderComm` and
  their handle convenience methods deleted.
- The leader_device parameter on `load_dense_shard` is now `_` —
  unused since the worker has its own bound device. Removing the
  arg outright is a public-API change; keeping the underscore prefix
  preserves the signature and signals deadness without churn.

Helper relocation:

- `LlamaDense::from_parts` is a new pub(crate) constructor so the
  worker-thread loader can build a `LlamaDense` without going through
  the original `load_arch_dense` async function.
- `check_dense_config_supported` is bumped to `pub(crate)` for the
  same reason.

Sweep verified: `grep -rn spawn_blocking crates/neuron/src/harness/`
returns only CPU-fallback hits in `candle.rs` + doc-comment references
to the old design. All four leader-side CUDA `spawn_blocking` sites
are gone.

fmt + clippy clean; 37 lib tests + all integration tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-27 10:24:38 +03:00
parent 76ab24d98c
commit b4f3576d82
5 changed files with 475 additions and 225 deletions

View File

@@ -307,6 +307,26 @@ pub struct LlamaDense {
}
impl LlamaDense {
/// Constructor used by the dispatch-side loader. Keeps the field
/// names private while letting the worker thread build a
/// `LlamaDense` from already-loaded weights without going through
/// async candle code.
pub(crate) fn from_parts(
model: llama_dense::Llama,
cache: llama_dense::Cache,
config: llama_dense::Config,
dtype: DType,
device: Device,
) -> Self {
Self {
model,
cache,
config,
dtype,
device,
}
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
Ok(self.model.forward(input, offset, &mut self.cache)?)
}
@@ -348,7 +368,7 @@ const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["llama", "qwen3", "qwen3_5", "qwe
/// The result message names the model_type we saw, the supported set,
/// and points at the files an operator (or future contributor) needs
/// to touch to grow the supported set.
fn check_dense_config_supported(config_json: &str, model_id: &str) -> Result<()> {
pub(crate) fn check_dense_config_supported(config_json: &str, model_id: &str) -> Result<()> {
let v: serde_json::Value = serde_json::from_str(config_json)
.with_context(|| format!("parse config.json for '{model_id}' as JSON"))?;
let model_type = v.get("model_type").and_then(|x| x.as_str()).unwrap_or("");
@@ -1547,41 +1567,46 @@ impl Harness for CandleHarness {
let devices = spec.devices.clone().unwrap_or_else(|| vec![0]);
let device = Self::pick_device(&devices)?;
// Dispatch by source format: GGUF (pre-quantized, single-GPU
// only path) vs safetensors dense (bf16/fp16; the path that
// grows TP support). `spec.quant` is the signal — Some means
// the operator picked a quantized GGUF; None means dense.
let (tokenizer_path, arch) = if spec.quant.is_some() {
self.load_arch_gguf(spec, &device).await?
} else {
self.load_arch_dense(spec, &device).await?
};
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
// Worker thread for the chosen device. CPU loads (CUDA
// unavailable / not requested) skip the worker — there's no
// context to own. For CUDA loads, the arch is transferred
// into the worker's slab now so the inference path can
// reference it via the returned `ArchHandle`. The explicit
// type annotation lets the no-cuda build resolve `None` to
// the right `Option<Arc<DeviceWorkerHandle>>` type.
// Phase 4: load directly on the worker thread for CUDA;
// legacy spawn_blocking + Arc<Mutex<>> only for CPU. Resolve
// hf-hub paths up front (always async), then either dispatch
// a load Job (CUDA) or call the legacy local loader (CPU).
let worker: Option<Arc<super::device_worker::DeviceWorkerHandle>> = match &device {
#[cfg(feature = "cuda")]
Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?),
_ => None,
};
let (arch_local, arch_handle) = match &worker {
Some(w) => {
let (tokenizer_path, arch_local, arch_handle) = if let Some(w) = &worker {
// CUDA path: resolve, then load in the worker.
if spec.quant.is_some() {
let (gguf_path, tokenizer_path) = self.resolve_files(spec).await?;
let handle = w
.transfer_in(Box::new(arch))
.load_gguf(gguf_path, spec.model_id.clone())
.await
.map_err(|e| anyhow::anyhow!("transfer arch into device worker: {e}"))?;
(None, Some(handle))
.map_err(|e| anyhow::anyhow!("worker load_gguf: {e}"))?;
(tokenizer_path, None, Some(handle))
} else {
let (config_path, tokenizer_path, safetensors_paths) =
self.resolve_dense_files(spec).await?;
let handle = w
.load_dense(config_path, safetensors_paths, spec.model_id.clone())
.await
.map_err(|e| anyhow::anyhow!("worker load_dense: {e}"))?;
(tokenizer_path, None, Some(handle))
}
None => (Some(Arc::new(Mutex::new(arch))), None),
} else {
// CPU path: legacy spawn_blocking + Arc<Mutex<ModelArch>>.
let (tokenizer_path, arch) = if spec.quant.is_some() {
self.load_arch_gguf(spec, &device).await?
} else {
self.load_arch_dense(spec, &device).await?
};
(tokenizer_path, Some(Arc::new(Mutex::new(arch))), None)
};
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
let loaded = Arc::new(LoadedModel {
model_id: spec.model_id.clone(),

View File

@@ -100,17 +100,25 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
// discard the reply.
let _ = reply.send(result);
}
Job::TransferIn { arch, reply } => {
let handle = ArchHandle(state.next_handle);
state.next_handle = state.next_handle.wrapping_add(1);
state.models.insert(handle, arch);
tracing::debug!(
device_index,
handle = handle.0,
slab_size = state.models.len(),
"device worker: model transferred in"
);
let _ = reply.send(Ok(handle));
Job::LoadGguf {
gguf_path,
model_id,
reply,
} => {
let result = load_gguf_inner(&state.device, &gguf_path, &model_id)
.map(|arch| insert_arch(&mut state, Box::new(arch)));
let _ = reply.send(result);
}
Job::LoadDense {
config_path,
safetensors_paths,
model_id,
reply,
} => {
let result =
load_dense_inner(&state.device, &config_path, &safetensors_paths, &model_id)
.map(|arch| insert_arch(&mut state, Box::new(arch)));
let _ = reply.send(result);
}
Job::DropArch { handle, reply } => {
let removed = state.models.remove(&handle);
@@ -160,27 +168,25 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
let _ = reply.send(resp);
}
#[cfg(feature = "cuda")]
Job::CloneLeaderComm { reply } => {
let result = match state.nccl.comm() {
Some(comm) => Ok(crate::harness::tp::nccl_state::SendComm(comm)),
None => Err(anyhow::anyhow!(
"CloneLeaderComm: NcclState has no Comm; call NcclInit first"
)),
};
let _ = reply.send(result);
}
#[cfg(feature = "cuda")]
Job::TransferInTp { model, reply } => {
let handle = TpHandle(state.next_tp_handle);
state.next_tp_handle = state.next_tp_handle.wrapping_add(1);
state.tp_models.insert(handle, model);
tracing::debug!(
device_index,
tp_handle = handle.0,
slab_size = state.tp_models.len(),
"device worker: TP model transferred in"
Job::TpLoadShard {
model_id,
config_json,
safetensors_paths,
dtype,
quant,
world_size,
reply,
} => {
let result = tp_load_shard_inner(
&mut state,
&model_id,
&config_json,
&safetensors_paths,
dtype,
quant.as_deref(),
world_size,
);
let _ = reply.send(Ok(handle));
let _ = reply.send(result);
}
#[cfg(feature = "cuda")]
Job::DropTp { handle, reply } => {
@@ -332,6 +338,265 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
Ok((0, 0))
}
/// Insert a freshly-built `ModelArch` into the slab and mint a fresh
/// `ArchHandle`. Used by both `LoadGguf` and `LoadDense` dispatch
/// handlers — they differ only in *how* the arch is built; the
/// post-construction bookkeeping is identical.
fn insert_arch(state: &mut DeviceWorkerState, arch: Box<ModelArch>) -> ArchHandle {
let handle = ArchHandle(state.next_handle);
state.next_handle = state.next_handle.wrapping_add(1);
state.models.insert(handle, arch);
tracing::debug!(
device_index = state.device_index,
handle = handle.0,
slab_size = state.models.len(),
"device worker: model inserted"
);
handle
}
/// Load a GGUF (pre-quantized) model on the worker thread. Pulled
/// verbatim from the spawn_blocking closure that used to live in
/// `CandleHarness::load_arch_gguf`; the only change is that `device`
/// is now `state.device` (the worker's permanently-bound device).
fn load_gguf_inner(
device: &candle_core::Device,
gguf_path: &std::path::Path,
model_id: &str,
) -> anyhow::Result<ModelArch> {
use anyhow::Context;
use candle_core::DType;
use candle_core::quantized::gguf_file;
use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaWeights;
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
tracing::info!(model = %model_id, path = ?gguf_path, "loading GGUF");
let mut file = std::fs::File::open(gguf_path).context("open GGUF file")?;
let content =
gguf_file::Content::read(&mut file).map_err(|e| anyhow::anyhow!("parse GGUF: {e}"))?;
let architecture = content
.metadata
.get("general.architecture")
.and_then(|v| v.to_string().ok().cloned())
.unwrap_or_default();
tracing::info!(architecture = %architecture, "GGUF architecture");
// The `general.architecture` GGUF metadata key follows
// llama.cpp conventions (lowercase, no underscores in some
// cases) — `qwen3moe`, not `qwen3_moe`.
match architecture.as_str() {
"qwen3" => {
let weights = QuantizedQwen3Weights::from_gguf(content, &mut file, device)
.map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?;
Ok(ModelArch::Qwen3Quantized(weights))
}
"qwen3moe" => {
// GGUFQWenMoE takes an explicit compute dtype alongside
// the device — F16 matches the GGUF weights' typical
// accumulation precision and gives the best tokens/sec on
// consumer cards.
let weights = GGUFQWenMoE::from_gguf(content, &mut file, device, DType::F16)
.map_err(|e| anyhow::anyhow!("from_gguf qwen3_moe: {e}"))?;
Ok(ModelArch::Qwen3MoeQuantized(weights))
}
"llama" => {
let weights = QuantizedLlamaWeights::from_gguf(content, &mut file, device)
.map_err(|e| anyhow::anyhow!("from_gguf llama: {e}"))?;
Ok(ModelArch::LlamaQuantized(weights))
}
other => anyhow::bail!(
"unsupported GGUF architecture '{other}'; quantized path supports \
qwen3, qwen3moe, llama"
),
}
}
/// Load a dense safetensors model on the worker thread.
fn load_dense_inner(
device: &candle_core::Device,
config_path: &std::path::Path,
safetensors_paths: &[std::path::PathBuf],
model_id: &str,
) -> anyhow::Result<ModelArch> {
use anyhow::Context;
use candle_core::DType;
use candle_nn::VarBuilder;
use candle_transformers::models::llama as llama_dense;
use candle_transformers::models::qwen3 as qwen3_dense;
use candle_transformers::models::qwen3_moe as qwen3_moe_dense;
let cfg_text = std::fs::read_to_string(config_path).context("read config.json")?;
crate::harness::candle::check_dense_config_supported(&cfg_text, model_id)?;
// Peek at model_type to choose the family before the typed
// deserialize — each family has its own Config.
let model_type = serde_json::from_str::<serde_json::Value>(&cfg_text)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
tracing::info!(
model = %model_id,
model_type = %model_type,
shards = safetensors_paths.len(),
"loading dense model from safetensors"
);
// bf16 is the canonical distribution dtype for Qwen3 / Llama 3 /
// Qwen3 MoE. CUDA on Ada+ has hardware bf16; Ampere has it too.
// CPU emulates.
let dtype = DType::BF16;
// SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files;
// mutation by another process while we hold the mapping is UB.
// We trust the HF cache is immutable-by-design.
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(safetensors_paths, dtype, device)
.context("build VarBuilder over safetensors")?
};
match model_type.as_str() {
"qwen3" => {
let cfg: qwen3_dense::Config =
serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
.map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
Ok(ModelArch::Qwen3Dense(model))
}
"qwen3_moe" => {
let cfg: qwen3_moe_dense::Config =
serde_json::from_str(&cfg_text).context("parse Qwen3 MoE config.json")?;
let model = qwen3_moe_dense::ModelForCausalLM::new(&cfg, vb)
.map_err(|e| anyhow::anyhow!("build Qwen3 MoE dense model: {e}"))?;
Ok(ModelArch::Qwen3MoeDense(model))
}
"llama" => {
let cfg: llama_dense::LlamaConfig =
serde_json::from_str(&cfg_text).context("parse Llama config.json")?;
let config = cfg.into_config(false);
let cache = llama_dense::Cache::new(true, dtype, &config, device)
.context("build Llama Cache")?;
let model = llama_dense::Llama::load(vb, &config)
.map_err(|e| anyhow::anyhow!("build Llama dense model: {e}"))?;
Ok(ModelArch::LlamaDense(Box::new(
crate::harness::candle::LlamaDense::from_parts(
model,
cache,
config,
dtype,
device.clone(),
),
)))
}
"qwen3_5" => {
let cfg: crate::harness::arch::qwen3_5::Config = serde_json::from_str(&cfg_text)
.context("parse Qwen3-Next (qwen3_5) config.json")?;
let sharded_vb = unsafe {
candle_nn::var_builder::ShardedSafeTensors::var_builder(
safetensors_paths,
dtype,
device,
)
.context("build ShardedVarBuilder for Qwen3-Next")?
};
let model = crate::harness::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, sharded_vb)
.context("build Qwen3-Next dense model")?;
Ok(ModelArch::Qwen3_5Dense(model))
}
other => anyhow::bail!(
"unrouted supported model_type '{other}' — \
DENSE_SUPPORTED_MODEL_TYPES and load_dense_inner \
must stay in sync"
),
}
}
/// Load the leader's TP shard on the worker thread. Reads the Comm
/// directly from `state.nccl`; no cross-thread Arc<Comm> transfer.
#[cfg(feature = "cuda")]
fn tp_load_shard_inner(
state: &mut DeviceWorkerState,
model_id: &str,
config_json: &str,
safetensors_paths: &[std::path::PathBuf],
dtype: candle_core::DType,
quant: Option<&str>,
world_size: u32,
) -> anyhow::Result<TpHandle> {
use anyhow::Context;
use candle_nn::var_builder::ShardedSafeTensors;
let comm = state.nccl.comm().ok_or_else(|| {
anyhow::anyhow!("TpLoadShard: NcclState has no Comm; call NcclInit first")
})?;
let model_type = serde_json::from_str::<serde_json::Value>(config_json)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
// SAFETY: same invariant as the single-GPU dense path — the HF
// cache files are treated as immutable while the mmap is held.
let vb = unsafe {
ShardedSafeTensors::var_builder(safetensors_paths, dtype, &state.device)
.context("build ShardedVarBuilder over safetensors")?
};
let mmap = unsafe {
candle_core::safetensors::MmapedSafetensors::multi(safetensors_paths)
.context("build MmapedSafetensors for leader load")?
};
let loaded = match model_type.as_str() {
"qwen3" => {
let cfg: crate::harness::tp::tp_qwen3::Config = serde_json::from_str(config_json)
.context("parse Qwen3 Config JSON for leader load")?;
TpLeaderModel::Qwen3(crate::harness::tp::tp_qwen3::TpQwen3ForCausalLM::load(
&cfg, &vb, 0, world_size, comm,
)?)
}
"qwen3_5" => {
let cfg: crate::harness::tp::tp_qwen3_5::Config = serde_json::from_str(config_json)
.context("parse Qwen3-Next Config JSON for leader load")?;
let quant_dtype = crate::harness::tp::worker::parse_quant_string(quant)?;
TpLeaderModel::Qwen3_5(crate::harness::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
cfg,
&vb,
&mmap,
0,
world_size,
comm,
quant_dtype,
)?)
}
other => anyhow::bail!(
"TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)"
),
};
tracing::info!(
rank = 0,
model = %model_id,
model_type = %model_type,
"loaded TP shard (leader)"
);
let handle = TpHandle(state.next_tp_handle);
state.next_tp_handle = state.next_tp_handle.wrapping_add(1);
state.tp_models.insert(handle, Box::new(loaded));
tracing::debug!(
device_index = state.device_index,
tp_handle = handle.0,
slab_size = state.tp_models.len(),
"device worker: TP model inserted"
);
Ok(handle)
}
/// TP-equivalent of [`forward_logits`]: looks up the leader's
/// [`TpLeaderModel`] in the slab, runs its forward, copies the
/// `[vocab]` logits to a CPU `Vec<f32>`. The leader's `Arc<Comm>`
@@ -414,7 +679,10 @@ fn drain_poisoned(job: Job, device_index: u32) {
Job::QueryVram { reply } => {
let _ = reply.send(Err(err()));
}
Job::TransferIn { reply, .. } => {
Job::LoadGguf { reply, .. } => {
let _ = reply.send(Err(err()));
}
Job::LoadDense { reply, .. } => {
let _ = reply.send(Err(err()));
}
Job::DropArch { reply, .. } => {
@@ -443,11 +711,7 @@ fn drain_poisoned(job: Job, device_index: u32) {
});
}
#[cfg(feature = "cuda")]
Job::CloneLeaderComm { reply } => {
let _ = reply.send(Err(err()));
}
#[cfg(feature = "cuda")]
Job::TransferInTp { reply, .. } => {
Job::TpLoadShard { reply, .. } => {
let _ = reply.send(Err(err()));
}
#[cfg(feature = "cuda")]

View File

@@ -5,8 +5,8 @@
//! async-side `DeviceWorkerHandle` constructs a job, sends it down the
//! `std::sync::mpsc` channel, and `await`s the oneshot for the reply.
use crate::harness::candle::ModelArch;
use anyhow::Result;
use std::path::PathBuf;
use tokio::sync::oneshot;
/// Opaque handle to a `ModelArch` stored in the worker thread's state
@@ -48,12 +48,24 @@ pub enum Job {
QueryVram {
reply: oneshot::Sender<Result<(u64, u64)>>,
},
/// Move a freshly-loaded `ModelArch` into the worker's state slab.
/// Returns an `ArchHandle` the caller stores on `LoadedModel` and
/// passes back in subsequent `ClearKv` / `ForwardLogits` /
/// `DropArch` jobs.
TransferIn {
arch: Box<ModelArch>,
/// Load a GGUF (pre-quantized) single-GPU model on the worker
/// thread. The dispatch handler opens the GGUF file, parses
/// metadata, dispatches on `general.architecture`, and inserts
/// the resulting `ModelArch` into the slab. Returns the fresh
/// `ArchHandle`.
LoadGguf {
gguf_path: PathBuf,
model_id: String,
reply: oneshot::Sender<Result<ArchHandle>>,
},
/// Load a dense safetensors single-GPU model on the worker
/// thread. The dispatch handler reads `config.json`, dispatches on
/// `model_type`, builds a `VarBuilder` over the mmap'd
/// safetensors, and inserts the resulting `ModelArch`.
LoadDense {
config_path: PathBuf,
safetensors_paths: Vec<PathBuf>,
model_id: String,
reply: oneshot::Sender<Result<ArchHandle>>,
},
/// Remove the model from the slab and drop it. The `Drop` runs on
@@ -105,22 +117,21 @@ pub enum Job {
NcclSanity {
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
},
/// Clone the leader's `Arc<Comm>` out of the worker's `NcclState`
/// so a spawn_blocking-based load (Phase 3 bridge) can hand it to
/// the row-parallel layers. Wrapped in `SendComm` because
/// `Arc<Comm>` is `!Send` at the type level (the NCCL contract
/// requires serialised access, which we provide structurally).
/// Phase 4 eliminates this when `TpLoadShard` becomes a Job and
/// the load runs entirely on the worker thread.
/// Load the leader's TP shard on the worker thread. The dispatch
/// handler reads `state.nccl.comm()` directly (no cross-thread
/// `Arc<Comm>` transfer, no `SendComm` wrapper) and builds the
/// `TpLeaderModel` against that Comm. The model's embedded
/// `Arc<Comm>` clones, `CudaContext`, and all per-rank CUDA
/// tensors live on this thread for the model's lifetime.
/// Inserts into the TP slab and returns the fresh `TpHandle`.
#[cfg(feature = "cuda")]
CloneLeaderComm {
reply: oneshot::Sender<Result<crate::harness::tp::nccl_state::SendComm>>,
},
/// Move a freshly-built `TpLeaderModel` into the worker's tp slab.
/// Returns a `TpHandle` the caller stores on `TpLoadedModel`.
#[cfg(feature = "cuda")]
TransferInTp {
model: Box<crate::harness::tp::TpLeaderModel>,
TpLoadShard {
model_id: String,
config_json: String,
safetensors_paths: Vec<PathBuf>,
dtype: candle_core::DType,
quant: Option<String>,
world_size: u32,
reply: oneshot::Sender<Result<TpHandle>>,
},
/// Drop the TP leader model on the worker thread. CUDA tensors

View File

@@ -161,13 +161,15 @@ impl DeviceWorkerHandle {
}
}
/// Move a freshly-loaded `ModelArch` into the worker's state slab.
/// Returns the `ArchHandle` the caller stores on `LoadedModel`.
/// The `Box<ModelArch>` crosses the channel; the worker thread
/// owns it from here on.
pub async fn transfer_in(
/// Load a GGUF (pre-quantized) single-GPU model on the worker
/// thread. The hf-hub resolution happens on the async caller; the
/// resolved local `gguf_path` plus the spec's model_id are sent
/// into the worker which opens, parses, and constructs the
/// `ModelArch` on the right thread.
pub async fn load_gguf(
&self,
arch: Box<crate::harness::candle::ModelArch>,
gguf_path: std::path::PathBuf,
model_id: String,
) -> Result<ArchHandle, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
@@ -176,8 +178,40 @@ impl DeviceWorkerHandle {
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::TransferIn {
arch,
.send(Job::LoadGguf {
gguf_path,
model_id,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Load a dense safetensors single-GPU model on the worker thread.
pub async fn load_dense(
&self,
config_path: std::path::PathBuf,
safetensors_paths: Vec<std::path::PathBuf>,
model_id: String,
) -> Result<ArchHandle, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::LoadDense {
config_path,
safetensors_paths,
model_id,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
@@ -331,37 +365,21 @@ impl DeviceWorkerHandle {
})
}
/// Clone the leader's `Arc<Comm>` so a spawn_blocking-based load
/// (Phase 3 bridge) can pass it to the row-parallel layers.
/// Phase 4 eliminates this once the TP load runs on this thread.
/// Load the leader's TP shard on the worker thread. The dispatch
/// handler reads its own `NcclState`'s `Arc<Comm>` directly — no
/// cross-thread Comm transfer — and builds the `TpLeaderModel`
/// against it. Phase 4 replaces the Phase 3 Clone/TransferIn
/// bridge with this single Job.
#[cfg(feature = "cuda")]
pub async fn clone_leader_comm(
#[allow(clippy::too_many_arguments)]
pub async fn tp_load_shard(
&self,
) -> Result<crate::harness::tp::nccl_state::SendComm, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::CloneLeaderComm { reply: reply_tx })
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Move a freshly-built `TpLeaderModel` into the worker's TP slab.
#[cfg(feature = "cuda")]
pub async fn transfer_in_tp(
&self,
model: Box<crate::harness::tp::TpLeaderModel>,
model_id: String,
config_json: String,
safetensors_paths: Vec<std::path::PathBuf>,
dtype: candle_core::DType,
quant: Option<String>,
world_size: u32,
) -> Result<TpHandle, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
@@ -370,8 +388,13 @@ impl DeviceWorkerHandle {
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::TransferInTp {
model,
.send(Job::TpLoadShard {
model_id,
config_json,
safetensors_paths,
dtype,
quant,
world_size,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {

View File

@@ -489,36 +489,19 @@ impl WorkerPool {
model_id: &str,
config_json: &str,
safetensors_paths: &[std::path::PathBuf],
leader_device: &candle_core::Device,
_leader_device: &candle_core::Device,
dtype: candle_core::DType,
quant: Option<String>,
) -> Result<super::device_worker::TpHandle> {
use candle_nn::var_builder::ShardedSafeTensors;
// Ask the leader's device worker for an `Arc<Comm>` clone.
// Phase 3 moved `NcclState` ownership onto the worker thread,
// so the spawn_blocking load below can no longer reach the
// Comm directly. The reply is wrapped in `SendComm` because
// the underlying `Arc<Comm>` is `!Send` at the type level;
// the safety contract (only one thread issues NCCL ops at a
// time) is preserved because the load runs on a single
// spawn_blocking thread and AllReduce ops fire only from the
// device worker thread later. Phase 4 eliminates this bridge
// when the load itself moves onto the worker.
let leader_comm = self
.leader_worker
.clone_leader_comm()
.await
.map_err(|e| anyhow::anyhow!("clone leader Comm via device worker: {e}"))?;
let world_size = self.world_size;
let safetensors_str: Vec<String> = safetensors_paths
.iter()
.map(|p| p.to_string_lossy().into_owned())
.collect();
// 1. Fan out the LoadDenseShard request to every worker without
// awaiting their replies — they'll build their shards in
// parallel with the leader below.
// 1. Fan out the LoadDenseShard request to every subprocess
// worker without awaiting their replies — they'll build
// their shards in parallel with the leader below.
for w in &mut self.workers {
w.send_only(&WorkerRequest::LoadDenseShard {
model_id: model_id.to_string(),
@@ -529,76 +512,32 @@ impl WorkerPool {
.await?;
}
// 2. Build rank 0's shard on the leader. Dispatch on model_type
// — for `qwen3` we build a `TpQwen3ForCausalLM`, for
// `qwen3_5` (Qwen3-Next, Qwen3.6's architecture) we build
// `TpQwen3_5ForCausalLM`. Both end up wrapped in the
// `TpLeaderModel` enum so downstream callers don't care.
let model_type = serde_json::from_str::<serde_json::Value>(config_json)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let paths_for_leader: Vec<std::path::PathBuf> = safetensors_paths.to_vec();
let device_for_leader = leader_device.clone();
let comm_for_leader = leader_comm;
let model_id_for_log = model_id.to_string();
let config_json_for_leader = config_json.to_string();
let quant_for_leader = quant.clone();
let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
// SAFETY: same invariant as the single-GPU dense path —
// the HF cache files are treated as immutable while the
// mmap is held.
let vb = unsafe {
ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
.context("build ShardedVarBuilder over safetensors")?
};
// SAFETY: as above — the HF cache files are immutable.
let mmap = unsafe {
candle_core::safetensors::MmapedSafetensors::multi(&paths_for_leader)
.context("build MmapedSafetensors for leader load")?
};
let comm = comm_for_leader.into_inner();
let loaded = match model_type.as_str() {
"qwen3" => {
let cfg: super::tp::tp_qwen3::Config = serde_json::from_str(&config_json_for_leader)
.context("parse Qwen3 Config JSON for leader load")?;
TpLeaderModel::Qwen3(super::tp::tp_qwen3::TpQwen3ForCausalLM::load(
&cfg, &vb, 0, world_size, comm,
)?)
}
"qwen3_5" => {
let cfg: super::tp::tp_qwen3_5::Config =
serde_json::from_str(&config_json_for_leader)
.context("parse Qwen3-Next Config JSON for leader load")?;
let quant_dtype =
super::tp::worker::parse_quant_string(quant_for_leader.as_deref())?;
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
cfg,
&vb,
&mmap,
0,
// 2. Build rank 0's shard on the leader's device worker
// thread. Phase 4 moved the load itself onto the worker —
// the dispatch handler reads `state.nccl.comm()` directly
// so the leader's `Arc<Comm>` clones embedded in the
// row-parallel layers are constructed and used on the same
// OS thread for the model's entire lifetime. No
// spawn_blocking, no SendComm bridge.
let handle = self
.leader_worker
.tp_load_shard(
model_id.to_string(),
config_json.to_string(),
safetensors_paths.to_vec(),
dtype,
quant.clone(),
world_size,
comm,
quant_dtype,
)?)
}
other => anyhow::bail!(
"TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)"
),
};
tracing::info!(rank = 0, model = %model_id_for_log, model_type = %model_type, "loaded TP shard (leader)");
Ok(loaded)
})
)
.await
.context("leader load task panicked")??;
.map_err(|e| anyhow::anyhow!("leader TP shard load via device worker: {e}"))?;
// 3. Collect worker confirmations. Anything other than
// LoadDenseShardOk aborts the whole load — the leader's
// already-loaded shard drops when this fn returns Err.
// already-inserted shard would leak in the worker slab
// until the daemon restarts; an explicit DropTp would be
// cleaner but the failure here is rare and the operator's
// next step is to restart anyway.
for w in &mut self.workers {
let resp = w.recv_only().await?;
match resp {
@@ -613,18 +552,6 @@ impl WorkerPool {
}
}
// Phase 3: move the leader's freshly-built `TpLeaderModel`
// into the device worker's TP slab. The model holds
// `Arc<Comm>` clones (in its AllReduce ops) plus CUDA
// tensors — both need to live on the device worker thread so
// every `Comm::all_reduce` and tensor op during forward
// dispatches from the same OS thread that bound the CUDA
// context.
let handle = self
.leader_worker
.transfer_in_tp(Box::new(leader_model))
.await
.map_err(|e| anyhow::anyhow!("transfer TP leader model into device worker: {e}"))?;
Ok(handle)
}