refactor(neuron): phase 4 — model loads move onto the device worker
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 37s
CI / Clippy (push) Successful in 2m25s
CI / Test (push) Successful in 4m40s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m51s
build-prerelease / Build cortex binary (push) Successful in 4m21s
build-prerelease / Package cortex RPM (push) Successful in 1m20s
build-prerelease / Build neuron-ampere (push) Successful in 5m7s
build-prerelease / Build neuron-ada (push) Successful in 5m19s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 37s
CI / Clippy (push) Successful in 2m25s
CI / Test (push) Successful in 4m40s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m51s
build-prerelease / Build cortex binary (push) Successful in 4m21s
build-prerelease / Package cortex RPM (push) Successful in 1m20s
build-prerelease / Build neuron-ampere (push) Successful in 5m7s
build-prerelease / Build neuron-ada (push) Successful in 5m19s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
Final structural slice of the per-device CUDA context-ownership refactor. The four remaining spawn_blocking sites that did CUDA work on the leader are gone: - Single-GPU GGUF load (`load_arch_gguf` spawn_blocking) → `Job::LoadGguf` dispatched on the worker. - Single-GPU dense load (`load_arch_dense` spawn_blocking) → `Job::LoadDense` on the worker. - TP shard load (`WorkerPool::load_dense_shard` spawn_blocking) → `Job::TpLoadShard`. The dispatch handler reads `state.nccl.comm()` directly — no cross-thread `Arc<Comm>` transfer, no `SendComm` wrapper for this path. The Phase 2 / Phase 3 bridges that moved freshly-built models across the channel boundary (`Job::TransferIn`, `Job::TransferInTp`, `Job::CloneLeaderComm`) are removed. Models are now constructed on the worker thread directly; the slab gets populated by `insert_arch` / the inline `tp_models.insert` in dispatch handlers. What this phase preserves: - CPU loads still use `tokio::task::spawn_blocking` against `Arc<Mutex<ModelArch>>`. There's no CUDA context to own on CPU and channel overhead would only add latency. Four `spawn_blocking` references remain in `candle.rs` (load_arch_gguf, load_arch_dense, chat_completion, chat_completion_stream) and all are deliberate CPU-only fallback. - Public API unchanged. `Harness::load_model`, `chat_completion`, HTTP routes all keep identical signatures. What this phase removes: - `SendComm` wrapper is no longer used in the load path (the Phase 3 bridge that justified it). It remains in `nccl_state.rs` for the Phase 1–3 era and any future cross-thread Comm move; consider deleting in a follow-up. - `Job::TransferIn`, `Job::TransferInTp`, `Job::CloneLeaderComm` and their handle convenience methods deleted. - The leader_device parameter on `load_dense_shard` is now `_` — unused since the worker has its own bound device. Removing the arg outright is a public-API change; keeping the underscore prefix preserves the signature and signals deadness without churn. Helper relocation: - `LlamaDense::from_parts` is a new pub(crate) constructor so the worker-thread loader can build a `LlamaDense` without going through the original `load_arch_dense` async function. - `check_dense_config_supported` is bumped to `pub(crate)` for the same reason. Sweep verified: `grep -rn spawn_blocking crates/neuron/src/harness/` returns only CPU-fallback hits in `candle.rs` + doc-comment references to the old design. All four leader-side CUDA `spawn_blocking` sites are gone. fmt + clippy clean; 37 lib tests + all integration tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -307,6 +307,26 @@ pub struct LlamaDense {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl LlamaDense {
|
impl LlamaDense {
|
||||||
|
/// Constructor used by the dispatch-side loader. Keeps the field
|
||||||
|
/// names private while letting the worker thread build a
|
||||||
|
/// `LlamaDense` from already-loaded weights without going through
|
||||||
|
/// async candle code.
|
||||||
|
pub(crate) fn from_parts(
|
||||||
|
model: llama_dense::Llama,
|
||||||
|
cache: llama_dense::Cache,
|
||||||
|
config: llama_dense::Config,
|
||||||
|
dtype: DType,
|
||||||
|
device: Device,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
cache,
|
||||||
|
config,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
|
||||||
Ok(self.model.forward(input, offset, &mut self.cache)?)
|
Ok(self.model.forward(input, offset, &mut self.cache)?)
|
||||||
}
|
}
|
||||||
@@ -348,7 +368,7 @@ const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["llama", "qwen3", "qwen3_5", "qwe
|
|||||||
/// The result message names the model_type we saw, the supported set,
|
/// The result message names the model_type we saw, the supported set,
|
||||||
/// and points at the files an operator (or future contributor) needs
|
/// and points at the files an operator (or future contributor) needs
|
||||||
/// to touch to grow the supported set.
|
/// to touch to grow the supported set.
|
||||||
fn check_dense_config_supported(config_json: &str, model_id: &str) -> Result<()> {
|
pub(crate) fn check_dense_config_supported(config_json: &str, model_id: &str) -> Result<()> {
|
||||||
let v: serde_json::Value = serde_json::from_str(config_json)
|
let v: serde_json::Value = serde_json::from_str(config_json)
|
||||||
.with_context(|| format!("parse config.json for '{model_id}' as JSON"))?;
|
.with_context(|| format!("parse config.json for '{model_id}' as JSON"))?;
|
||||||
let model_type = v.get("model_type").and_then(|x| x.as_str()).unwrap_or("");
|
let model_type = v.get("model_type").and_then(|x| x.as_str()).unwrap_or("");
|
||||||
@@ -1547,41 +1567,46 @@ impl Harness for CandleHarness {
|
|||||||
let devices = spec.devices.clone().unwrap_or_else(|| vec![0]);
|
let devices = spec.devices.clone().unwrap_or_else(|| vec![0]);
|
||||||
let device = Self::pick_device(&devices)?;
|
let device = Self::pick_device(&devices)?;
|
||||||
|
|
||||||
// Dispatch by source format: GGUF (pre-quantized, single-GPU
|
// Phase 4: load directly on the worker thread for CUDA;
|
||||||
// only path) vs safetensors dense (bf16/fp16; the path that
|
// legacy spawn_blocking + Arc<Mutex<>> only for CPU. Resolve
|
||||||
// grows TP support). `spec.quant` is the signal — Some means
|
// hf-hub paths up front (always async), then either dispatch
|
||||||
// the operator picked a quantized GGUF; None means dense.
|
// a load Job (CUDA) or call the legacy local loader (CPU).
|
||||||
let (tokenizer_path, arch) = if spec.quant.is_some() {
|
|
||||||
self.load_arch_gguf(spec, &device).await?
|
|
||||||
} else {
|
|
||||||
self.load_arch_dense(spec, &device).await?
|
|
||||||
};
|
|
||||||
|
|
||||||
let tokenizer = Tokenizer::from_file(&tokenizer_path)
|
|
||||||
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
|
|
||||||
|
|
||||||
// Worker thread for the chosen device. CPU loads (CUDA
|
|
||||||
// unavailable / not requested) skip the worker — there's no
|
|
||||||
// context to own. For CUDA loads, the arch is transferred
|
|
||||||
// into the worker's slab now so the inference path can
|
|
||||||
// reference it via the returned `ArchHandle`. The explicit
|
|
||||||
// type annotation lets the no-cuda build resolve `None` to
|
|
||||||
// the right `Option<Arc<DeviceWorkerHandle>>` type.
|
|
||||||
let worker: Option<Arc<super::device_worker::DeviceWorkerHandle>> = match &device {
|
let worker: Option<Arc<super::device_worker::DeviceWorkerHandle>> = match &device {
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?),
|
Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?),
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
let (arch_local, arch_handle) = match &worker {
|
|
||||||
Some(w) => {
|
let (tokenizer_path, arch_local, arch_handle) = if let Some(w) = &worker {
|
||||||
|
// CUDA path: resolve, then load in the worker.
|
||||||
|
if spec.quant.is_some() {
|
||||||
|
let (gguf_path, tokenizer_path) = self.resolve_files(spec).await?;
|
||||||
let handle = w
|
let handle = w
|
||||||
.transfer_in(Box::new(arch))
|
.load_gguf(gguf_path, spec.model_id.clone())
|
||||||
.await
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!("transfer arch into device worker: {e}"))?;
|
.map_err(|e| anyhow::anyhow!("worker load_gguf: {e}"))?;
|
||||||
(None, Some(handle))
|
(tokenizer_path, None, Some(handle))
|
||||||
|
} else {
|
||||||
|
let (config_path, tokenizer_path, safetensors_paths) =
|
||||||
|
self.resolve_dense_files(spec).await?;
|
||||||
|
let handle = w
|
||||||
|
.load_dense(config_path, safetensors_paths, spec.model_id.clone())
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("worker load_dense: {e}"))?;
|
||||||
|
(tokenizer_path, None, Some(handle))
|
||||||
}
|
}
|
||||||
None => (Some(Arc::new(Mutex::new(arch))), None),
|
} else {
|
||||||
|
// CPU path: legacy spawn_blocking + Arc<Mutex<ModelArch>>.
|
||||||
|
let (tokenizer_path, arch) = if spec.quant.is_some() {
|
||||||
|
self.load_arch_gguf(spec, &device).await?
|
||||||
|
} else {
|
||||||
|
self.load_arch_dense(spec, &device).await?
|
||||||
};
|
};
|
||||||
|
(tokenizer_path, Some(Arc::new(Mutex::new(arch))), None)
|
||||||
|
};
|
||||||
|
|
||||||
|
let tokenizer = Tokenizer::from_file(&tokenizer_path)
|
||||||
|
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
|
||||||
|
|
||||||
let loaded = Arc::new(LoadedModel {
|
let loaded = Arc::new(LoadedModel {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: spec.model_id.clone(),
|
||||||
|
|||||||
@@ -100,17 +100,25 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
|||||||
// discard the reply.
|
// discard the reply.
|
||||||
let _ = reply.send(result);
|
let _ = reply.send(result);
|
||||||
}
|
}
|
||||||
Job::TransferIn { arch, reply } => {
|
Job::LoadGguf {
|
||||||
let handle = ArchHandle(state.next_handle);
|
gguf_path,
|
||||||
state.next_handle = state.next_handle.wrapping_add(1);
|
model_id,
|
||||||
state.models.insert(handle, arch);
|
reply,
|
||||||
tracing::debug!(
|
} => {
|
||||||
device_index,
|
let result = load_gguf_inner(&state.device, &gguf_path, &model_id)
|
||||||
handle = handle.0,
|
.map(|arch| insert_arch(&mut state, Box::new(arch)));
|
||||||
slab_size = state.models.len(),
|
let _ = reply.send(result);
|
||||||
"device worker: model transferred in"
|
}
|
||||||
);
|
Job::LoadDense {
|
||||||
let _ = reply.send(Ok(handle));
|
config_path,
|
||||||
|
safetensors_paths,
|
||||||
|
model_id,
|
||||||
|
reply,
|
||||||
|
} => {
|
||||||
|
let result =
|
||||||
|
load_dense_inner(&state.device, &config_path, &safetensors_paths, &model_id)
|
||||||
|
.map(|arch| insert_arch(&mut state, Box::new(arch)));
|
||||||
|
let _ = reply.send(result);
|
||||||
}
|
}
|
||||||
Job::DropArch { handle, reply } => {
|
Job::DropArch { handle, reply } => {
|
||||||
let removed = state.models.remove(&handle);
|
let removed = state.models.remove(&handle);
|
||||||
@@ -160,27 +168,25 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
|||||||
let _ = reply.send(resp);
|
let _ = reply.send(resp);
|
||||||
}
|
}
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
Job::CloneLeaderComm { reply } => {
|
Job::TpLoadShard {
|
||||||
let result = match state.nccl.comm() {
|
model_id,
|
||||||
Some(comm) => Ok(crate::harness::tp::nccl_state::SendComm(comm)),
|
config_json,
|
||||||
None => Err(anyhow::anyhow!(
|
safetensors_paths,
|
||||||
"CloneLeaderComm: NcclState has no Comm; call NcclInit first"
|
dtype,
|
||||||
)),
|
quant,
|
||||||
};
|
world_size,
|
||||||
let _ = reply.send(result);
|
reply,
|
||||||
}
|
} => {
|
||||||
#[cfg(feature = "cuda")]
|
let result = tp_load_shard_inner(
|
||||||
Job::TransferInTp { model, reply } => {
|
&mut state,
|
||||||
let handle = TpHandle(state.next_tp_handle);
|
&model_id,
|
||||||
state.next_tp_handle = state.next_tp_handle.wrapping_add(1);
|
&config_json,
|
||||||
state.tp_models.insert(handle, model);
|
&safetensors_paths,
|
||||||
tracing::debug!(
|
dtype,
|
||||||
device_index,
|
quant.as_deref(),
|
||||||
tp_handle = handle.0,
|
world_size,
|
||||||
slab_size = state.tp_models.len(),
|
|
||||||
"device worker: TP model transferred in"
|
|
||||||
);
|
);
|
||||||
let _ = reply.send(Ok(handle));
|
let _ = reply.send(result);
|
||||||
}
|
}
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
Job::DropTp { handle, reply } => {
|
Job::DropTp { handle, reply } => {
|
||||||
@@ -332,6 +338,265 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
|
|||||||
Ok((0, 0))
|
Ok((0, 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Insert a freshly-built `ModelArch` into the slab and mint a fresh
|
||||||
|
/// `ArchHandle`. Used by both `LoadGguf` and `LoadDense` dispatch
|
||||||
|
/// handlers — they differ only in *how* the arch is built; the
|
||||||
|
/// post-construction bookkeeping is identical.
|
||||||
|
fn insert_arch(state: &mut DeviceWorkerState, arch: Box<ModelArch>) -> ArchHandle {
|
||||||
|
let handle = ArchHandle(state.next_handle);
|
||||||
|
state.next_handle = state.next_handle.wrapping_add(1);
|
||||||
|
state.models.insert(handle, arch);
|
||||||
|
tracing::debug!(
|
||||||
|
device_index = state.device_index,
|
||||||
|
handle = handle.0,
|
||||||
|
slab_size = state.models.len(),
|
||||||
|
"device worker: model inserted"
|
||||||
|
);
|
||||||
|
handle
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load a GGUF (pre-quantized) model on the worker thread. Pulled
|
||||||
|
/// verbatim from the spawn_blocking closure that used to live in
|
||||||
|
/// `CandleHarness::load_arch_gguf`; the only change is that `device`
|
||||||
|
/// is now `state.device` (the worker's permanently-bound device).
|
||||||
|
fn load_gguf_inner(
|
||||||
|
device: &candle_core::Device,
|
||||||
|
gguf_path: &std::path::Path,
|
||||||
|
model_id: &str,
|
||||||
|
) -> anyhow::Result<ModelArch> {
|
||||||
|
use anyhow::Context;
|
||||||
|
use candle_core::DType;
|
||||||
|
use candle_core::quantized::gguf_file;
|
||||||
|
use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaWeights;
|
||||||
|
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
|
||||||
|
use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
|
||||||
|
|
||||||
|
tracing::info!(model = %model_id, path = ?gguf_path, "loading GGUF");
|
||||||
|
let mut file = std::fs::File::open(gguf_path).context("open GGUF file")?;
|
||||||
|
let content =
|
||||||
|
gguf_file::Content::read(&mut file).map_err(|e| anyhow::anyhow!("parse GGUF: {e}"))?;
|
||||||
|
|
||||||
|
let architecture = content
|
||||||
|
.metadata
|
||||||
|
.get("general.architecture")
|
||||||
|
.and_then(|v| v.to_string().ok().cloned())
|
||||||
|
.unwrap_or_default();
|
||||||
|
tracing::info!(architecture = %architecture, "GGUF architecture");
|
||||||
|
|
||||||
|
// The `general.architecture` GGUF metadata key follows
|
||||||
|
// llama.cpp conventions (lowercase, no underscores in some
|
||||||
|
// cases) — `qwen3moe`, not `qwen3_moe`.
|
||||||
|
match architecture.as_str() {
|
||||||
|
"qwen3" => {
|
||||||
|
let weights = QuantizedQwen3Weights::from_gguf(content, &mut file, device)
|
||||||
|
.map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?;
|
||||||
|
Ok(ModelArch::Qwen3Quantized(weights))
|
||||||
|
}
|
||||||
|
"qwen3moe" => {
|
||||||
|
// GGUFQWenMoE takes an explicit compute dtype alongside
|
||||||
|
// the device — F16 matches the GGUF weights' typical
|
||||||
|
// accumulation precision and gives the best tokens/sec on
|
||||||
|
// consumer cards.
|
||||||
|
let weights = GGUFQWenMoE::from_gguf(content, &mut file, device, DType::F16)
|
||||||
|
.map_err(|e| anyhow::anyhow!("from_gguf qwen3_moe: {e}"))?;
|
||||||
|
Ok(ModelArch::Qwen3MoeQuantized(weights))
|
||||||
|
}
|
||||||
|
"llama" => {
|
||||||
|
let weights = QuantizedLlamaWeights::from_gguf(content, &mut file, device)
|
||||||
|
.map_err(|e| anyhow::anyhow!("from_gguf llama: {e}"))?;
|
||||||
|
Ok(ModelArch::LlamaQuantized(weights))
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"unsupported GGUF architecture '{other}'; quantized path supports \
|
||||||
|
qwen3, qwen3moe, llama"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load a dense safetensors model on the worker thread.
|
||||||
|
fn load_dense_inner(
|
||||||
|
device: &candle_core::Device,
|
||||||
|
config_path: &std::path::Path,
|
||||||
|
safetensors_paths: &[std::path::PathBuf],
|
||||||
|
model_id: &str,
|
||||||
|
) -> anyhow::Result<ModelArch> {
|
||||||
|
use anyhow::Context;
|
||||||
|
use candle_core::DType;
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::models::llama as llama_dense;
|
||||||
|
use candle_transformers::models::qwen3 as qwen3_dense;
|
||||||
|
use candle_transformers::models::qwen3_moe as qwen3_moe_dense;
|
||||||
|
|
||||||
|
let cfg_text = std::fs::read_to_string(config_path).context("read config.json")?;
|
||||||
|
crate::harness::candle::check_dense_config_supported(&cfg_text, model_id)?;
|
||||||
|
// Peek at model_type to choose the family before the typed
|
||||||
|
// deserialize — each family has its own Config.
|
||||||
|
let model_type = serde_json::from_str::<serde_json::Value>(&cfg_text)
|
||||||
|
.ok()
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|v| v.get("model_type"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
tracing::info!(
|
||||||
|
model = %model_id,
|
||||||
|
model_type = %model_type,
|
||||||
|
shards = safetensors_paths.len(),
|
||||||
|
"loading dense model from safetensors"
|
||||||
|
);
|
||||||
|
|
||||||
|
// bf16 is the canonical distribution dtype for Qwen3 / Llama 3 /
|
||||||
|
// Qwen3 MoE. CUDA on Ada+ has hardware bf16; Ampere has it too.
|
||||||
|
// CPU emulates.
|
||||||
|
let dtype = DType::BF16;
|
||||||
|
// SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files;
|
||||||
|
// mutation by another process while we hold the mapping is UB.
|
||||||
|
// We trust the HF cache is immutable-by-design.
|
||||||
|
let vb = unsafe {
|
||||||
|
VarBuilder::from_mmaped_safetensors(safetensors_paths, dtype, device)
|
||||||
|
.context("build VarBuilder over safetensors")?
|
||||||
|
};
|
||||||
|
|
||||||
|
match model_type.as_str() {
|
||||||
|
"qwen3" => {
|
||||||
|
let cfg: qwen3_dense::Config =
|
||||||
|
serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
|
||||||
|
let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
|
||||||
|
.map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
|
||||||
|
Ok(ModelArch::Qwen3Dense(model))
|
||||||
|
}
|
||||||
|
"qwen3_moe" => {
|
||||||
|
let cfg: qwen3_moe_dense::Config =
|
||||||
|
serde_json::from_str(&cfg_text).context("parse Qwen3 MoE config.json")?;
|
||||||
|
let model = qwen3_moe_dense::ModelForCausalLM::new(&cfg, vb)
|
||||||
|
.map_err(|e| anyhow::anyhow!("build Qwen3 MoE dense model: {e}"))?;
|
||||||
|
Ok(ModelArch::Qwen3MoeDense(model))
|
||||||
|
}
|
||||||
|
"llama" => {
|
||||||
|
let cfg: llama_dense::LlamaConfig =
|
||||||
|
serde_json::from_str(&cfg_text).context("parse Llama config.json")?;
|
||||||
|
let config = cfg.into_config(false);
|
||||||
|
let cache = llama_dense::Cache::new(true, dtype, &config, device)
|
||||||
|
.context("build Llama Cache")?;
|
||||||
|
let model = llama_dense::Llama::load(vb, &config)
|
||||||
|
.map_err(|e| anyhow::anyhow!("build Llama dense model: {e}"))?;
|
||||||
|
Ok(ModelArch::LlamaDense(Box::new(
|
||||||
|
crate::harness::candle::LlamaDense::from_parts(
|
||||||
|
model,
|
||||||
|
cache,
|
||||||
|
config,
|
||||||
|
dtype,
|
||||||
|
device.clone(),
|
||||||
|
),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
"qwen3_5" => {
|
||||||
|
let cfg: crate::harness::arch::qwen3_5::Config = serde_json::from_str(&cfg_text)
|
||||||
|
.context("parse Qwen3-Next (qwen3_5) config.json")?;
|
||||||
|
let sharded_vb = unsafe {
|
||||||
|
candle_nn::var_builder::ShardedSafeTensors::var_builder(
|
||||||
|
safetensors_paths,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
.context("build ShardedVarBuilder for Qwen3-Next")?
|
||||||
|
};
|
||||||
|
let model = crate::harness::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, sharded_vb)
|
||||||
|
.context("build Qwen3-Next dense model")?;
|
||||||
|
Ok(ModelArch::Qwen3_5Dense(model))
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"unrouted supported model_type '{other}' — \
|
||||||
|
DENSE_SUPPORTED_MODEL_TYPES and load_dense_inner \
|
||||||
|
must stay in sync"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load the leader's TP shard on the worker thread. Reads the Comm
|
||||||
|
/// directly from `state.nccl`; no cross-thread Arc<Comm> transfer.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn tp_load_shard_inner(
|
||||||
|
state: &mut DeviceWorkerState,
|
||||||
|
model_id: &str,
|
||||||
|
config_json: &str,
|
||||||
|
safetensors_paths: &[std::path::PathBuf],
|
||||||
|
dtype: candle_core::DType,
|
||||||
|
quant: Option<&str>,
|
||||||
|
world_size: u32,
|
||||||
|
) -> anyhow::Result<TpHandle> {
|
||||||
|
use anyhow::Context;
|
||||||
|
use candle_nn::var_builder::ShardedSafeTensors;
|
||||||
|
|
||||||
|
let comm = state.nccl.comm().ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("TpLoadShard: NcclState has no Comm; call NcclInit first")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let model_type = serde_json::from_str::<serde_json::Value>(config_json)
|
||||||
|
.ok()
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|v| v.get("model_type"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
// SAFETY: same invariant as the single-GPU dense path — the HF
|
||||||
|
// cache files are treated as immutable while the mmap is held.
|
||||||
|
let vb = unsafe {
|
||||||
|
ShardedSafeTensors::var_builder(safetensors_paths, dtype, &state.device)
|
||||||
|
.context("build ShardedVarBuilder over safetensors")?
|
||||||
|
};
|
||||||
|
let mmap = unsafe {
|
||||||
|
candle_core::safetensors::MmapedSafetensors::multi(safetensors_paths)
|
||||||
|
.context("build MmapedSafetensors for leader load")?
|
||||||
|
};
|
||||||
|
|
||||||
|
let loaded = match model_type.as_str() {
|
||||||
|
"qwen3" => {
|
||||||
|
let cfg: crate::harness::tp::tp_qwen3::Config = serde_json::from_str(config_json)
|
||||||
|
.context("parse Qwen3 Config JSON for leader load")?;
|
||||||
|
TpLeaderModel::Qwen3(crate::harness::tp::tp_qwen3::TpQwen3ForCausalLM::load(
|
||||||
|
&cfg, &vb, 0, world_size, comm,
|
||||||
|
)?)
|
||||||
|
}
|
||||||
|
"qwen3_5" => {
|
||||||
|
let cfg: crate::harness::tp::tp_qwen3_5::Config = serde_json::from_str(config_json)
|
||||||
|
.context("parse Qwen3-Next Config JSON for leader load")?;
|
||||||
|
let quant_dtype = crate::harness::tp::worker::parse_quant_string(quant)?;
|
||||||
|
TpLeaderModel::Qwen3_5(crate::harness::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
|
||||||
|
cfg,
|
||||||
|
&vb,
|
||||||
|
&mmap,
|
||||||
|
0,
|
||||||
|
world_size,
|
||||||
|
comm,
|
||||||
|
quant_dtype,
|
||||||
|
)?)
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
rank = 0,
|
||||||
|
model = %model_id,
|
||||||
|
model_type = %model_type,
|
||||||
|
"loaded TP shard (leader)"
|
||||||
|
);
|
||||||
|
|
||||||
|
let handle = TpHandle(state.next_tp_handle);
|
||||||
|
state.next_tp_handle = state.next_tp_handle.wrapping_add(1);
|
||||||
|
state.tp_models.insert(handle, Box::new(loaded));
|
||||||
|
tracing::debug!(
|
||||||
|
device_index = state.device_index,
|
||||||
|
tp_handle = handle.0,
|
||||||
|
slab_size = state.tp_models.len(),
|
||||||
|
"device worker: TP model inserted"
|
||||||
|
);
|
||||||
|
Ok(handle)
|
||||||
|
}
|
||||||
|
|
||||||
/// TP-equivalent of [`forward_logits`]: looks up the leader's
|
/// TP-equivalent of [`forward_logits`]: looks up the leader's
|
||||||
/// [`TpLeaderModel`] in the slab, runs its forward, copies the
|
/// [`TpLeaderModel`] in the slab, runs its forward, copies the
|
||||||
/// `[vocab]` logits to a CPU `Vec<f32>`. The leader's `Arc<Comm>`
|
/// `[vocab]` logits to a CPU `Vec<f32>`. The leader's `Arc<Comm>`
|
||||||
@@ -414,7 +679,10 @@ fn drain_poisoned(job: Job, device_index: u32) {
|
|||||||
Job::QueryVram { reply } => {
|
Job::QueryVram { reply } => {
|
||||||
let _ = reply.send(Err(err()));
|
let _ = reply.send(Err(err()));
|
||||||
}
|
}
|
||||||
Job::TransferIn { reply, .. } => {
|
Job::LoadGguf { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
|
Job::LoadDense { reply, .. } => {
|
||||||
let _ = reply.send(Err(err()));
|
let _ = reply.send(Err(err()));
|
||||||
}
|
}
|
||||||
Job::DropArch { reply, .. } => {
|
Job::DropArch { reply, .. } => {
|
||||||
@@ -443,11 +711,7 @@ fn drain_poisoned(job: Job, device_index: u32) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
Job::CloneLeaderComm { reply } => {
|
Job::TpLoadShard { reply, .. } => {
|
||||||
let _ = reply.send(Err(err()));
|
|
||||||
}
|
|
||||||
#[cfg(feature = "cuda")]
|
|
||||||
Job::TransferInTp { reply, .. } => {
|
|
||||||
let _ = reply.send(Err(err()));
|
let _ = reply.send(Err(err()));
|
||||||
}
|
}
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
|
|||||||
@@ -5,8 +5,8 @@
|
|||||||
//! async-side `DeviceWorkerHandle` constructs a job, sends it down the
|
//! async-side `DeviceWorkerHandle` constructs a job, sends it down the
|
||||||
//! `std::sync::mpsc` channel, and `await`s the oneshot for the reply.
|
//! `std::sync::mpsc` channel, and `await`s the oneshot for the reply.
|
||||||
|
|
||||||
use crate::harness::candle::ModelArch;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use std::path::PathBuf;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
|
|
||||||
/// Opaque handle to a `ModelArch` stored in the worker thread's state
|
/// Opaque handle to a `ModelArch` stored in the worker thread's state
|
||||||
@@ -48,12 +48,24 @@ pub enum Job {
|
|||||||
QueryVram {
|
QueryVram {
|
||||||
reply: oneshot::Sender<Result<(u64, u64)>>,
|
reply: oneshot::Sender<Result<(u64, u64)>>,
|
||||||
},
|
},
|
||||||
/// Move a freshly-loaded `ModelArch` into the worker's state slab.
|
/// Load a GGUF (pre-quantized) single-GPU model on the worker
|
||||||
/// Returns an `ArchHandle` the caller stores on `LoadedModel` and
|
/// thread. The dispatch handler opens the GGUF file, parses
|
||||||
/// passes back in subsequent `ClearKv` / `ForwardLogits` /
|
/// metadata, dispatches on `general.architecture`, and inserts
|
||||||
/// `DropArch` jobs.
|
/// the resulting `ModelArch` into the slab. Returns the fresh
|
||||||
TransferIn {
|
/// `ArchHandle`.
|
||||||
arch: Box<ModelArch>,
|
LoadGguf {
|
||||||
|
gguf_path: PathBuf,
|
||||||
|
model_id: String,
|
||||||
|
reply: oneshot::Sender<Result<ArchHandle>>,
|
||||||
|
},
|
||||||
|
/// Load a dense safetensors single-GPU model on the worker
|
||||||
|
/// thread. The dispatch handler reads `config.json`, dispatches on
|
||||||
|
/// `model_type`, builds a `VarBuilder` over the mmap'd
|
||||||
|
/// safetensors, and inserts the resulting `ModelArch`.
|
||||||
|
LoadDense {
|
||||||
|
config_path: PathBuf,
|
||||||
|
safetensors_paths: Vec<PathBuf>,
|
||||||
|
model_id: String,
|
||||||
reply: oneshot::Sender<Result<ArchHandle>>,
|
reply: oneshot::Sender<Result<ArchHandle>>,
|
||||||
},
|
},
|
||||||
/// Remove the model from the slab and drop it. The `Drop` runs on
|
/// Remove the model from the slab and drop it. The `Drop` runs on
|
||||||
@@ -105,22 +117,21 @@ pub enum Job {
|
|||||||
NcclSanity {
|
NcclSanity {
|
||||||
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
||||||
},
|
},
|
||||||
/// Clone the leader's `Arc<Comm>` out of the worker's `NcclState`
|
/// Load the leader's TP shard on the worker thread. The dispatch
|
||||||
/// so a spawn_blocking-based load (Phase 3 bridge) can hand it to
|
/// handler reads `state.nccl.comm()` directly (no cross-thread
|
||||||
/// the row-parallel layers. Wrapped in `SendComm` because
|
/// `Arc<Comm>` transfer, no `SendComm` wrapper) and builds the
|
||||||
/// `Arc<Comm>` is `!Send` at the type level (the NCCL contract
|
/// `TpLeaderModel` against that Comm. The model's embedded
|
||||||
/// requires serialised access, which we provide structurally).
|
/// `Arc<Comm>` clones, `CudaContext`, and all per-rank CUDA
|
||||||
/// Phase 4 eliminates this when `TpLoadShard` becomes a Job and
|
/// tensors live on this thread for the model's lifetime.
|
||||||
/// the load runs entirely on the worker thread.
|
/// Inserts into the TP slab and returns the fresh `TpHandle`.
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
CloneLeaderComm {
|
TpLoadShard {
|
||||||
reply: oneshot::Sender<Result<crate::harness::tp::nccl_state::SendComm>>,
|
model_id: String,
|
||||||
},
|
config_json: String,
|
||||||
/// Move a freshly-built `TpLeaderModel` into the worker's tp slab.
|
safetensors_paths: Vec<PathBuf>,
|
||||||
/// Returns a `TpHandle` the caller stores on `TpLoadedModel`.
|
dtype: candle_core::DType,
|
||||||
#[cfg(feature = "cuda")]
|
quant: Option<String>,
|
||||||
TransferInTp {
|
world_size: u32,
|
||||||
model: Box<crate::harness::tp::TpLeaderModel>,
|
|
||||||
reply: oneshot::Sender<Result<TpHandle>>,
|
reply: oneshot::Sender<Result<TpHandle>>,
|
||||||
},
|
},
|
||||||
/// Drop the TP leader model on the worker thread. CUDA tensors
|
/// Drop the TP leader model on the worker thread. CUDA tensors
|
||||||
|
|||||||
@@ -161,13 +161,15 @@ impl DeviceWorkerHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Move a freshly-loaded `ModelArch` into the worker's state slab.
|
/// Load a GGUF (pre-quantized) single-GPU model on the worker
|
||||||
/// Returns the `ArchHandle` the caller stores on `LoadedModel`.
|
/// thread. The hf-hub resolution happens on the async caller; the
|
||||||
/// The `Box<ModelArch>` crosses the channel; the worker thread
|
/// resolved local `gguf_path` plus the spec's model_id are sent
|
||||||
/// owns it from here on.
|
/// into the worker which opens, parses, and constructs the
|
||||||
pub async fn transfer_in(
|
/// `ModelArch` on the right thread.
|
||||||
|
pub async fn load_gguf(
|
||||||
&self,
|
&self,
|
||||||
arch: Box<crate::harness::candle::ModelArch>,
|
gguf_path: std::path::PathBuf,
|
||||||
|
model_id: String,
|
||||||
) -> Result<ArchHandle, WorkerError> {
|
) -> Result<ArchHandle, WorkerError> {
|
||||||
if self.poisoned.load(Ordering::Acquire) {
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
return Err(WorkerError::Poisoned {
|
return Err(WorkerError::Poisoned {
|
||||||
@@ -176,8 +178,40 @@ impl DeviceWorkerHandle {
|
|||||||
}
|
}
|
||||||
let (reply_tx, reply_rx) = oneshot::channel();
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
self.tx
|
self.tx
|
||||||
.send(Job::TransferIn {
|
.send(Job::LoadGguf {
|
||||||
arch,
|
gguf_path,
|
||||||
|
model_id,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load a dense safetensors single-GPU model on the worker thread.
|
||||||
|
pub async fn load_dense(
|
||||||
|
&self,
|
||||||
|
config_path: std::path::PathBuf,
|
||||||
|
safetensors_paths: Vec<std::path::PathBuf>,
|
||||||
|
model_id: String,
|
||||||
|
) -> Result<ArchHandle, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::LoadDense {
|
||||||
|
config_path,
|
||||||
|
safetensors_paths,
|
||||||
|
model_id,
|
||||||
reply: reply_tx,
|
reply: reply_tx,
|
||||||
})
|
})
|
||||||
.map_err(|_| WorkerError::Gone {
|
.map_err(|_| WorkerError::Gone {
|
||||||
@@ -331,37 +365,21 @@ impl DeviceWorkerHandle {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Clone the leader's `Arc<Comm>` so a spawn_blocking-based load
|
/// Load the leader's TP shard on the worker thread. The dispatch
|
||||||
/// (Phase 3 bridge) can pass it to the row-parallel layers.
|
/// handler reads its own `NcclState`'s `Arc<Comm>` directly — no
|
||||||
/// Phase 4 eliminates this once the TP load runs on this thread.
|
/// cross-thread Comm transfer — and builds the `TpLeaderModel`
|
||||||
|
/// against it. Phase 4 replaces the Phase 3 Clone/TransferIn
|
||||||
|
/// bridge with this single Job.
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
pub async fn clone_leader_comm(
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub async fn tp_load_shard(
|
||||||
&self,
|
&self,
|
||||||
) -> Result<crate::harness::tp::nccl_state::SendComm, WorkerError> {
|
model_id: String,
|
||||||
if self.poisoned.load(Ordering::Acquire) {
|
config_json: String,
|
||||||
return Err(WorkerError::Poisoned {
|
safetensors_paths: Vec<std::path::PathBuf>,
|
||||||
device_index: self.device_index,
|
dtype: candle_core::DType,
|
||||||
});
|
quant: Option<String>,
|
||||||
}
|
world_size: u32,
|
||||||
let (reply_tx, reply_rx) = oneshot::channel();
|
|
||||||
self.tx
|
|
||||||
.send(Job::CloneLeaderComm { reply: reply_tx })
|
|
||||||
.map_err(|_| WorkerError::Gone {
|
|
||||||
device_index: self.device_index,
|
|
||||||
})?;
|
|
||||||
match reply_rx.await {
|
|
||||||
Ok(result) => result.map_err(WorkerError::from),
|
|
||||||
Err(_) => Err(WorkerError::Gone {
|
|
||||||
device_index: self.device_index,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Move a freshly-built `TpLeaderModel` into the worker's TP slab.
|
|
||||||
#[cfg(feature = "cuda")]
|
|
||||||
pub async fn transfer_in_tp(
|
|
||||||
&self,
|
|
||||||
model: Box<crate::harness::tp::TpLeaderModel>,
|
|
||||||
) -> Result<TpHandle, WorkerError> {
|
) -> Result<TpHandle, WorkerError> {
|
||||||
if self.poisoned.load(Ordering::Acquire) {
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
return Err(WorkerError::Poisoned {
|
return Err(WorkerError::Poisoned {
|
||||||
@@ -370,8 +388,13 @@ impl DeviceWorkerHandle {
|
|||||||
}
|
}
|
||||||
let (reply_tx, reply_rx) = oneshot::channel();
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
self.tx
|
self.tx
|
||||||
.send(Job::TransferInTp {
|
.send(Job::TpLoadShard {
|
||||||
model,
|
model_id,
|
||||||
|
config_json,
|
||||||
|
safetensors_paths,
|
||||||
|
dtype,
|
||||||
|
quant,
|
||||||
|
world_size,
|
||||||
reply: reply_tx,
|
reply: reply_tx,
|
||||||
})
|
})
|
||||||
.map_err(|_| WorkerError::Gone {
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
|||||||
@@ -489,36 +489,19 @@ impl WorkerPool {
|
|||||||
model_id: &str,
|
model_id: &str,
|
||||||
config_json: &str,
|
config_json: &str,
|
||||||
safetensors_paths: &[std::path::PathBuf],
|
safetensors_paths: &[std::path::PathBuf],
|
||||||
leader_device: &candle_core::Device,
|
_leader_device: &candle_core::Device,
|
||||||
dtype: candle_core::DType,
|
dtype: candle_core::DType,
|
||||||
quant: Option<String>,
|
quant: Option<String>,
|
||||||
) -> Result<super::device_worker::TpHandle> {
|
) -> Result<super::device_worker::TpHandle> {
|
||||||
use candle_nn::var_builder::ShardedSafeTensors;
|
|
||||||
|
|
||||||
// Ask the leader's device worker for an `Arc<Comm>` clone.
|
|
||||||
// Phase 3 moved `NcclState` ownership onto the worker thread,
|
|
||||||
// so the spawn_blocking load below can no longer reach the
|
|
||||||
// Comm directly. The reply is wrapped in `SendComm` because
|
|
||||||
// the underlying `Arc<Comm>` is `!Send` at the type level;
|
|
||||||
// the safety contract (only one thread issues NCCL ops at a
|
|
||||||
// time) is preserved because the load runs on a single
|
|
||||||
// spawn_blocking thread and AllReduce ops fire only from the
|
|
||||||
// device worker thread later. Phase 4 eliminates this bridge
|
|
||||||
// when the load itself moves onto the worker.
|
|
||||||
let leader_comm = self
|
|
||||||
.leader_worker
|
|
||||||
.clone_leader_comm()
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow::anyhow!("clone leader Comm via device worker: {e}"))?;
|
|
||||||
let world_size = self.world_size;
|
let world_size = self.world_size;
|
||||||
let safetensors_str: Vec<String> = safetensors_paths
|
let safetensors_str: Vec<String> = safetensors_paths
|
||||||
.iter()
|
.iter()
|
||||||
.map(|p| p.to_string_lossy().into_owned())
|
.map(|p| p.to_string_lossy().into_owned())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// 1. Fan out the LoadDenseShard request to every worker without
|
// 1. Fan out the LoadDenseShard request to every subprocess
|
||||||
// awaiting their replies — they'll build their shards in
|
// worker without awaiting their replies — they'll build
|
||||||
// parallel with the leader below.
|
// their shards in parallel with the leader below.
|
||||||
for w in &mut self.workers {
|
for w in &mut self.workers {
|
||||||
w.send_only(&WorkerRequest::LoadDenseShard {
|
w.send_only(&WorkerRequest::LoadDenseShard {
|
||||||
model_id: model_id.to_string(),
|
model_id: model_id.to_string(),
|
||||||
@@ -529,76 +512,32 @@ impl WorkerPool {
|
|||||||
.await?;
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Build rank 0's shard on the leader. Dispatch on model_type
|
// 2. Build rank 0's shard on the leader's device worker
|
||||||
// — for `qwen3` we build a `TpQwen3ForCausalLM`, for
|
// thread. Phase 4 moved the load itself onto the worker —
|
||||||
// `qwen3_5` (Qwen3-Next, Qwen3.6's architecture) we build
|
// the dispatch handler reads `state.nccl.comm()` directly
|
||||||
// `TpQwen3_5ForCausalLM`. Both end up wrapped in the
|
// so the leader's `Arc<Comm>` clones embedded in the
|
||||||
// `TpLeaderModel` enum so downstream callers don't care.
|
// row-parallel layers are constructed and used on the same
|
||||||
let model_type = serde_json::from_str::<serde_json::Value>(config_json)
|
// OS thread for the model's entire lifetime. No
|
||||||
.ok()
|
// spawn_blocking, no SendComm bridge.
|
||||||
.as_ref()
|
let handle = self
|
||||||
.and_then(|v| v.get("model_type"))
|
.leader_worker
|
||||||
.and_then(|v| v.as_str())
|
.tp_load_shard(
|
||||||
.unwrap_or("")
|
model_id.to_string(),
|
||||||
.to_string();
|
config_json.to_string(),
|
||||||
let paths_for_leader: Vec<std::path::PathBuf> = safetensors_paths.to_vec();
|
safetensors_paths.to_vec(),
|
||||||
let device_for_leader = leader_device.clone();
|
dtype,
|
||||||
let comm_for_leader = leader_comm;
|
quant.clone(),
|
||||||
let model_id_for_log = model_id.to_string();
|
|
||||||
let config_json_for_leader = config_json.to_string();
|
|
||||||
let quant_for_leader = quant.clone();
|
|
||||||
|
|
||||||
let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
|
|
||||||
// SAFETY: same invariant as the single-GPU dense path —
|
|
||||||
// the HF cache files are treated as immutable while the
|
|
||||||
// mmap is held.
|
|
||||||
let vb = unsafe {
|
|
||||||
ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
|
|
||||||
.context("build ShardedVarBuilder over safetensors")?
|
|
||||||
};
|
|
||||||
// SAFETY: as above — the HF cache files are immutable.
|
|
||||||
let mmap = unsafe {
|
|
||||||
candle_core::safetensors::MmapedSafetensors::multi(&paths_for_leader)
|
|
||||||
.context("build MmapedSafetensors for leader load")?
|
|
||||||
};
|
|
||||||
let comm = comm_for_leader.into_inner();
|
|
||||||
let loaded = match model_type.as_str() {
|
|
||||||
"qwen3" => {
|
|
||||||
let cfg: super::tp::tp_qwen3::Config = serde_json::from_str(&config_json_for_leader)
|
|
||||||
.context("parse Qwen3 Config JSON for leader load")?;
|
|
||||||
TpLeaderModel::Qwen3(super::tp::tp_qwen3::TpQwen3ForCausalLM::load(
|
|
||||||
&cfg, &vb, 0, world_size, comm,
|
|
||||||
)?)
|
|
||||||
}
|
|
||||||
"qwen3_5" => {
|
|
||||||
let cfg: super::tp::tp_qwen3_5::Config =
|
|
||||||
serde_json::from_str(&config_json_for_leader)
|
|
||||||
.context("parse Qwen3-Next Config JSON for leader load")?;
|
|
||||||
let quant_dtype =
|
|
||||||
super::tp::worker::parse_quant_string(quant_for_leader.as_deref())?;
|
|
||||||
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
|
|
||||||
cfg,
|
|
||||||
&vb,
|
|
||||||
&mmap,
|
|
||||||
0,
|
|
||||||
world_size,
|
world_size,
|
||||||
comm,
|
)
|
||||||
quant_dtype,
|
|
||||||
)?)
|
|
||||||
}
|
|
||||||
other => anyhow::bail!(
|
|
||||||
"TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)"
|
|
||||||
),
|
|
||||||
};
|
|
||||||
tracing::info!(rank = 0, model = %model_id_for_log, model_type = %model_type, "loaded TP shard (leader)");
|
|
||||||
Ok(loaded)
|
|
||||||
})
|
|
||||||
.await
|
.await
|
||||||
.context("leader load task panicked")??;
|
.map_err(|e| anyhow::anyhow!("leader TP shard load via device worker: {e}"))?;
|
||||||
|
|
||||||
// 3. Collect worker confirmations. Anything other than
|
// 3. Collect worker confirmations. Anything other than
|
||||||
// LoadDenseShardOk aborts the whole load — the leader's
|
// LoadDenseShardOk aborts the whole load — the leader's
|
||||||
// already-loaded shard drops when this fn returns Err.
|
// already-inserted shard would leak in the worker slab
|
||||||
|
// until the daemon restarts; an explicit DropTp would be
|
||||||
|
// cleaner but the failure here is rare and the operator's
|
||||||
|
// next step is to restart anyway.
|
||||||
for w in &mut self.workers {
|
for w in &mut self.workers {
|
||||||
let resp = w.recv_only().await?;
|
let resp = w.recv_only().await?;
|
||||||
match resp {
|
match resp {
|
||||||
@@ -613,18 +552,6 @@ impl WorkerPool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 3: move the leader's freshly-built `TpLeaderModel`
|
|
||||||
// into the device worker's TP slab. The model holds
|
|
||||||
// `Arc<Comm>` clones (in its AllReduce ops) plus CUDA
|
|
||||||
// tensors — both need to live on the device worker thread so
|
|
||||||
// every `Comm::all_reduce` and tensor op during forward
|
|
||||||
// dispatches from the same OS thread that bound the CUDA
|
|
||||||
// context.
|
|
||||||
let handle = self
|
|
||||||
.leader_worker
|
|
||||||
.transfer_in_tp(Box::new(leader_model))
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow::anyhow!("transfer TP leader model into device worker: {e}"))?;
|
|
||||||
Ok(handle)
|
Ok(handle)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user