feat(stage-8d-7): direct safetensors fused-region loader
Some checks failed
build-prerelease / Package cortex RPM (push) Blocked by required conditions
CI / Format (push) Successful in 35s
build-prerelease / Resolve version stamps (push) Successful in 39s
CI / Clippy (push) Successful in 2m18s
CI / Test (push) Successful in 4m28s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m51s
build-prerelease / Build cortex binary (push) Successful in 4m13s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
Some checks failed
build-prerelease / Package cortex RPM (push) Blocked by required conditions
CI / Format (push) Successful in 35s
build-prerelease / Resolve version stamps (push) Successful in 39s
CI / Clippy (push) Successful in 2m18s
CI / Test (push) Successful in 4m28s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m51s
build-prerelease / Build cortex binary (push) Successful in 4m13s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
Replaces load_fused_qkv_slice_2d/_3d with reads from a separate MmapedSafetensors handle. Each per-rank fused tensor is built by reading the three region byte-slices directly from the mmap, concatenating them host-side, and uploading as one device allocation — no full-fused-tensor device materialisation. The prior approach allocated a ~100 MB transient device tensor per linear-attention layer; on Qwen3.6-27B with 48 linear-attn layers that's ~4.8 GB of allocator churn during load — enough to fragment the cuda caching allocator on a tight-VRAM 32 GB consumer GPU, which is what triggered the layer-22 up_proj OOM seen on beast. Threading: MmapedSafetensors flows worker → ForCausalLM → Model → DecoderLayer → GatedDeltaNet::load. Both leader (mod.rs) and worker (worker.rs) construct their own mmap; Linux's page cache shares the underlying pages. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2120,6 +2120,7 @@ dependencies = [
|
||||
"half",
|
||||
"hf-hub",
|
||||
"reqwest",
|
||||
"safetensors 0.7.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.18",
|
||||
|
||||
@@ -76,6 +76,11 @@ cudarc = { version = "0.19", optional = true, default-features = false, features
|
||||
half = { version = "2.5", optional = true }
|
||||
tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
|
||||
hf-hub = { version = "0.4", features = ["tokio"] }
|
||||
# Direct dep on `safetensors` (re-exported by candle but its `TensorView`
|
||||
# / `slice::IndexOp` types are public-but-not-re-exported). Used by the
|
||||
# tp `fused_load` module to read per-rank slices of fused QKV tensors
|
||||
# without materialising the full tensor on device.
|
||||
safetensors = "0.7"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["test-util"] }
|
||||
|
||||
213
crates/neuron/src/harness/tp/fused_load.rs
Normal file
213
crates/neuron/src/harness/tp/fused_load.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
//! Direct safetensors readers for fused-region weight tensors.
|
||||
//!
|
||||
//! Qwen3-Next's `in_proj_qkv` and `conv1d` weights are *fused* —
|
||||
//! three regions stored sequentially along dim 0 (`[key_q, key_k,
|
||||
//! value]`). The per-rank shard for each region has unequal size
|
||||
//! (`key_dim/ws` vs `value_dim/ws`), so candle's `ShardedSafeTensors`
|
||||
//! built-in `Shard { dim, rank, world_size }` (uniform split) doesn't
|
||||
//! map to the right slices.
|
||||
//!
|
||||
//! The previous approach loaded the full fused tensor onto the device,
|
||||
//! `narrow`ed the three regions, and `Tensor::cat(...).contiguous()`'d
|
||||
//! the per-rank slice. That left ~100 MB of transient device memory
|
||||
//! per linear-attention layer — 48 layers × 100 MB = ~4.8 GB of
|
||||
//! allocator pressure during load, enough to trigger fragmentation
|
||||
//! OOM on tight-VRAM consumer GPUs.
|
||||
//!
|
||||
//! This module reads the three per-rank byte ranges *directly from
|
||||
//! the safetensors mmap* (host-side), concatenates them into a single
|
||||
//! contiguous byte buffer, and uploads as one device allocation. No
|
||||
//! full-tensor device materialisation.
|
||||
|
||||
use anyhow::{Context, Result, bail};
|
||||
use candle_core::safetensors::MmapedSafetensors;
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
|
||||
/// Read a 2D fused-QKV tensor `[conv_dim, hidden_size]` and return
|
||||
/// this rank's per-region slice as a `[per_rank_conv_dim, hidden_size]`
|
||||
/// device tensor.
|
||||
///
|
||||
/// `tensor_name` must be the fully-qualified safetensors key (e.g.
|
||||
/// `"model.language_model.layers.5.linear_attn.in_proj_qkv.weight"`).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn load_fused_qkv_2d(
|
||||
mmap: &MmapedSafetensors,
|
||||
tensor_name: &str,
|
||||
hidden_size: usize,
|
||||
key_dim: usize,
|
||||
value_dim: usize,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
target_dtype: DType,
|
||||
device: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let ws = world_size as usize;
|
||||
let r = rank as usize;
|
||||
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
|
||||
bail!(
|
||||
"fused qkv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
|
||||
must each be divisible by world_size ({ws})"
|
||||
);
|
||||
}
|
||||
let per_rank_key = key_dim / ws;
|
||||
let per_rank_value = value_dim / ws;
|
||||
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
|
||||
|
||||
let view = mmap
|
||||
.get(tensor_name)
|
||||
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 2D"))?;
|
||||
let view_dtype: DType = view
|
||||
.dtype()
|
||||
.try_into()
|
||||
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
|
||||
|
||||
let shape = view.shape();
|
||||
if shape.len() != 2 {
|
||||
bail!(
|
||||
"fused qkv tensor '{tensor_name}' has shape {shape:?}, expected 2D \
|
||||
[conv_dim, hidden_size]"
|
||||
);
|
||||
}
|
||||
let conv_dim = key_dim * 2 + value_dim;
|
||||
if shape[0] != conv_dim || shape[1] != hidden_size {
|
||||
bail!(
|
||||
"fused qkv tensor '{tensor_name}' shape {shape:?} \
|
||||
doesn't match expected [{conv_dim}, {hidden_size}]"
|
||||
);
|
||||
}
|
||||
|
||||
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
|
||||
let k_bytes = slice_dim0_bytes(
|
||||
&view,
|
||||
key_dim + r * per_rank_key,
|
||||
per_rank_key,
|
||||
tensor_name,
|
||||
"k",
|
||||
)?;
|
||||
let v_bytes = slice_dim0_bytes(
|
||||
&view,
|
||||
2 * key_dim + r * per_rank_value,
|
||||
per_rank_value,
|
||||
tensor_name,
|
||||
"v",
|
||||
)?;
|
||||
|
||||
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
|
||||
bytes.extend_from_slice(&q_bytes);
|
||||
bytes.extend_from_slice(&k_bytes);
|
||||
bytes.extend_from_slice(&v_bytes);
|
||||
|
||||
let tensor = Tensor::from_raw_buffer(
|
||||
&bytes,
|
||||
view_dtype,
|
||||
&[per_rank_conv_dim, hidden_size],
|
||||
device,
|
||||
)
|
||||
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused qkv '{tensor_name}'"))?;
|
||||
tensor
|
||||
.to_dtype(target_dtype)
|
||||
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
|
||||
}
|
||||
|
||||
/// Read a 3D fused-QKV tensor `[conv_dim, 1, kernel_size]` (the
|
||||
/// depthwise conv1d weight) and return this rank's per-region slice
|
||||
/// as a `[per_rank_conv_dim, 1, kernel_size]` device tensor.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn load_fused_qkv_3d(
|
||||
mmap: &MmapedSafetensors,
|
||||
tensor_name: &str,
|
||||
mid: usize,
|
||||
kernel_size: usize,
|
||||
key_dim: usize,
|
||||
value_dim: usize,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
target_dtype: DType,
|
||||
device: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let ws = world_size as usize;
|
||||
let r = rank as usize;
|
||||
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
|
||||
bail!(
|
||||
"fused conv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
|
||||
must each be divisible by world_size ({ws})"
|
||||
);
|
||||
}
|
||||
let per_rank_key = key_dim / ws;
|
||||
let per_rank_value = value_dim / ws;
|
||||
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
|
||||
|
||||
let view = mmap
|
||||
.get(tensor_name)
|
||||
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 3D"))?;
|
||||
let view_dtype: DType = view
|
||||
.dtype()
|
||||
.try_into()
|
||||
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
|
||||
|
||||
let shape = view.shape();
|
||||
if shape.len() != 3 {
|
||||
bail!(
|
||||
"fused conv tensor '{tensor_name}' has shape {shape:?}, expected 3D \
|
||||
[conv_dim, mid, kernel_size]"
|
||||
);
|
||||
}
|
||||
let conv_dim = key_dim * 2 + value_dim;
|
||||
if shape[0] != conv_dim || shape[1] != mid || shape[2] != kernel_size {
|
||||
bail!(
|
||||
"fused conv tensor '{tensor_name}' shape {shape:?} \
|
||||
doesn't match expected [{conv_dim}, {mid}, {kernel_size}]"
|
||||
);
|
||||
}
|
||||
|
||||
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
|
||||
let k_bytes = slice_dim0_bytes(
|
||||
&view,
|
||||
key_dim + r * per_rank_key,
|
||||
per_rank_key,
|
||||
tensor_name,
|
||||
"k",
|
||||
)?;
|
||||
let v_bytes = slice_dim0_bytes(
|
||||
&view,
|
||||
2 * key_dim + r * per_rank_value,
|
||||
per_rank_value,
|
||||
tensor_name,
|
||||
"v",
|
||||
)?;
|
||||
|
||||
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
|
||||
bytes.extend_from_slice(&q_bytes);
|
||||
bytes.extend_from_slice(&k_bytes);
|
||||
bytes.extend_from_slice(&v_bytes);
|
||||
|
||||
let tensor = Tensor::from_raw_buffer(
|
||||
&bytes,
|
||||
view_dtype,
|
||||
&[per_rank_conv_dim, mid, kernel_size],
|
||||
device,
|
||||
)
|
||||
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused conv '{tensor_name}'"))?;
|
||||
tensor
|
||||
.to_dtype(target_dtype)
|
||||
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
|
||||
}
|
||||
|
||||
/// Read `len` consecutive rows along dim 0 starting at `start` from
|
||||
/// the safetensors view, returning the raw bytes. Wraps the same
|
||||
/// `view.slice(start..stop)` machinery that candle's
|
||||
/// `ShardedSafeTensors::get` uses internally.
|
||||
fn slice_dim0_bytes(
|
||||
view: &safetensors::tensor::TensorView<'_>,
|
||||
start: usize,
|
||||
len: usize,
|
||||
tensor_name: &str,
|
||||
region: &str,
|
||||
) -> Result<Vec<u8>> {
|
||||
use safetensors::slice::IndexOp;
|
||||
let stop = start + len;
|
||||
let iter = view.slice(start..stop).map_err(|e| {
|
||||
anyhow::anyhow!("slice '{tensor_name}' region {region} ({start}..{stop}): {e:?}")
|
||||
})?;
|
||||
Ok(iter.into_iter().flatten().copied().collect())
|
||||
}
|
||||
@@ -18,6 +18,7 @@
|
||||
//! - **7c:** crash detection, streaming SSE, graceful unload.
|
||||
|
||||
pub mod all_reduce;
|
||||
pub mod fused_load;
|
||||
pub mod nccl_state;
|
||||
pub mod rpc;
|
||||
pub mod tp_linear;
|
||||
@@ -539,6 +540,11 @@ impl WorkerPool {
|
||||
ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
|
||||
.context("build ShardedVarBuilder over safetensors")?
|
||||
};
|
||||
// SAFETY: as above — the HF cache files are immutable.
|
||||
let mmap = unsafe {
|
||||
candle_core::safetensors::MmapedSafetensors::multi(&paths_for_leader)
|
||||
.context("build MmapedSafetensors for leader load")?
|
||||
};
|
||||
let comm = comm_for_leader.into_inner();
|
||||
let loaded = match model_type.as_str() {
|
||||
"qwen3" => {
|
||||
@@ -553,7 +559,7 @@ impl WorkerPool {
|
||||
serde_json::from_str(&config_json_for_leader)
|
||||
.context("parse Qwen3-Next Config JSON for leader load")?;
|
||||
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
|
||||
cfg, &vb, 0, world_size, comm,
|
||||
cfg, &vb, &mmap, 0, world_size, comm,
|
||||
)?)
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
|
||||
@@ -31,6 +31,7 @@
|
||||
//! linear-attention block, lm_head, the rotary table.
|
||||
|
||||
use anyhow::{Context, Result, bail};
|
||||
use candle_core::safetensors::MmapedSafetensors;
|
||||
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
use candle_nn::{Embedding, Linear, kv_cache::ConcatKvCache};
|
||||
@@ -94,26 +95,29 @@ impl TpQwen3_5GatedDeltaNet {
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
vb: &ShardedVarBuilder,
|
||||
mmap: &MmapedSafetensors,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
) -> Result<Self> {
|
||||
Self::load_inner(cfg, vb, rank, world_size, comm)
|
||||
Self::load_inner(cfg, vb, mmap, rank, world_size, comm)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
vb: &ShardedVarBuilder,
|
||||
mmap: &MmapedSafetensors,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
) -> Result<Self> {
|
||||
Self::load_inner(cfg, vb, rank, world_size)
|
||||
Self::load_inner(cfg, vb, mmap, rank, world_size)
|
||||
}
|
||||
|
||||
fn load_inner(
|
||||
cfg: &TextConfig,
|
||||
vb: &ShardedVarBuilder,
|
||||
mmap: &MmapedSafetensors,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
#[cfg(feature = "cuda")] comm: Arc<Comm>,
|
||||
@@ -150,29 +154,43 @@ impl TpQwen3_5GatedDeltaNet {
|
||||
|
||||
let key_dim = head_k_dim * num_k_heads;
|
||||
let value_dim = head_v_dim * num_v_heads;
|
||||
let conv_dim = key_dim * 2 + value_dim;
|
||||
let _conv_dim = key_dim * 2 + value_dim;
|
||||
let hidden_size = cfg.hidden_size;
|
||||
|
||||
// ----- Fused `in_proj_qkv` and `conv1d` (per-region slicing). -----
|
||||
let in_proj_qkv_weight = load_fused_qkv_slice_2d(
|
||||
vb,
|
||||
"in_proj_qkv",
|
||||
conv_dim,
|
||||
// ----- Fused `in_proj_qkv` and `conv1d` (direct safetensors slicing).
|
||||
// Reads only this rank's per-region byte slices from the mmap
|
||||
// and uploads as one device allocation per fused tensor — no
|
||||
// full-fused-tensor device materialisation, which on the prior
|
||||
// narrow+cat approach was the main allocator-fragmentation
|
||||
// source on consumer GPUs near their VRAM ceiling.
|
||||
let dtype = vb.dtype();
|
||||
let device = vb.device().clone();
|
||||
let in_proj_qkv_name = format!("{}.in_proj_qkv.weight", vb.prefix());
|
||||
let in_proj_qkv_weight = super::fused_load::load_fused_qkv_2d(
|
||||
mmap,
|
||||
&in_proj_qkv_name,
|
||||
hidden_size,
|
||||
key_dim,
|
||||
value_dim,
|
||||
rank,
|
||||
world_size,
|
||||
dtype,
|
||||
&device,
|
||||
)?;
|
||||
let in_proj_qkv = Linear::new(in_proj_qkv_weight, None);
|
||||
|
||||
let conv1d_weight = load_fused_qkv_slice_3d(
|
||||
&vb.pp("conv1d"),
|
||||
(conv_dim, 1, conv_kernel_size),
|
||||
let conv1d_name = format!("{}.conv1d.weight", vb.prefix());
|
||||
let conv1d_weight = super::fused_load::load_fused_qkv_3d(
|
||||
mmap,
|
||||
&conv1d_name,
|
||||
1,
|
||||
conv_kernel_size,
|
||||
key_dim,
|
||||
value_dim,
|
||||
rank,
|
||||
world_size,
|
||||
dtype,
|
||||
&device,
|
||||
)?;
|
||||
|
||||
// ----- Uniformly-sharded projections (along output dim 0). -----
|
||||
@@ -621,11 +639,13 @@ pub struct TpQwen3_5DecoderLayer {
|
||||
|
||||
impl TpQwen3_5DecoderLayer {
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
rotary: Arc<RotaryEmbedding>,
|
||||
layer_idx: usize,
|
||||
vb: &ShardedVarBuilder,
|
||||
mmap: &MmapedSafetensors,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
@@ -647,6 +667,7 @@ impl TpQwen3_5DecoderLayer {
|
||||
"linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load(
|
||||
cfg,
|
||||
&vb.pp("linear_attn"),
|
||||
mmap,
|
||||
rank,
|
||||
world_size,
|
||||
comm.clone(),
|
||||
@@ -675,6 +696,7 @@ impl TpQwen3_5DecoderLayer {
|
||||
rotary: Arc<RotaryEmbedding>,
|
||||
layer_idx: usize,
|
||||
vb: &ShardedVarBuilder,
|
||||
mmap: &MmapedSafetensors,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
) -> Result<Self> {
|
||||
@@ -694,6 +716,7 @@ impl TpQwen3_5DecoderLayer {
|
||||
"linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load(
|
||||
cfg,
|
||||
&vb.pp("linear_attn"),
|
||||
mmap,
|
||||
rank,
|
||||
world_size,
|
||||
)?),
|
||||
@@ -755,6 +778,7 @@ impl TpQwen3_5Model {
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
vb: &ShardedVarBuilder,
|
||||
mmap: &MmapedSafetensors,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
@@ -789,6 +813,7 @@ impl TpQwen3_5Model {
|
||||
rotary.clone(),
|
||||
i,
|
||||
&vb_l.pp(i),
|
||||
mmap,
|
||||
rank,
|
||||
world_size,
|
||||
comm.clone(),
|
||||
@@ -816,6 +841,7 @@ impl TpQwen3_5Model {
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
vb: &ShardedVarBuilder,
|
||||
mmap: &MmapedSafetensors,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
) -> Result<Self> {
|
||||
@@ -848,6 +874,7 @@ impl TpQwen3_5Model {
|
||||
rotary.clone(),
|
||||
i,
|
||||
&vb_l.pp(i),
|
||||
mmap,
|
||||
rank,
|
||||
world_size,
|
||||
)?);
|
||||
@@ -907,12 +934,13 @@ impl TpQwen3_5ForCausalLM {
|
||||
pub fn load(
|
||||
config: Config,
|
||||
vb: &ShardedVarBuilder,
|
||||
mmap: &MmapedSafetensors,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
) -> Result<Self> {
|
||||
let cfg = &config.text_config;
|
||||
let base = TpQwen3_5Model::load(cfg, vb, rank, world_size, comm)?;
|
||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm)?;
|
||||
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||
Ok(Self { base, lm_head })
|
||||
}
|
||||
@@ -921,11 +949,12 @@ impl TpQwen3_5ForCausalLM {
|
||||
pub fn load(
|
||||
config: Config,
|
||||
vb: &ShardedVarBuilder,
|
||||
mmap: &MmapedSafetensors,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
) -> Result<Self> {
|
||||
let cfg = &config.text_config;
|
||||
let base = TpQwen3_5Model::load(cfg, vb, rank, world_size)?;
|
||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size)?;
|
||||
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||
Ok(Self { base, lm_head })
|
||||
}
|
||||
@@ -978,89 +1007,6 @@ fn load_replicated<S: Into<candle_core::Shape>>(
|
||||
.with_context(|| format!("load replicated '{}/{name}'", vb.prefix()))
|
||||
}
|
||||
|
||||
/// Load a fused QKV-style 2D weight tensor that stores three regions
|
||||
/// sequentially along dim 0: `[first key_dim, second key_dim, value_dim]`.
|
||||
/// Returns the per-rank slice formed by extracting the rank's share
|
||||
/// from each region and concatenating along dim 0.
|
||||
///
|
||||
/// The full tensor materialises briefly on the device before the
|
||||
/// slices are extracted (`narrow` views + `contiguous` copy). Memory
|
||||
/// peak is one full-tensor load per layer during construction; only
|
||||
/// the per-rank concatenation stays after `full` drops.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn load_fused_qkv_slice_2d(
|
||||
vb: &ShardedVarBuilder,
|
||||
name: &str,
|
||||
conv_dim: usize,
|
||||
hidden_size: usize,
|
||||
key_dim: usize,
|
||||
value_dim: usize,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
) -> Result<Tensor> {
|
||||
let ws = world_size as usize;
|
||||
let r = rank as usize;
|
||||
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
|
||||
bail!(
|
||||
"fused qkv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
|
||||
must each be divisible by world_size ({ws})"
|
||||
);
|
||||
}
|
||||
let per_rank_key = key_dim / ws;
|
||||
let per_rank_value = value_dim / ws;
|
||||
|
||||
// Force full-tensor load via `vb.get`, which defaults to
|
||||
// `Shard { world_size: 1 }` and falls through to SimpleBackend.
|
||||
let full = vb
|
||||
.pp(name)
|
||||
.get((conv_dim, hidden_size), "weight")
|
||||
.with_context(|| format!("load fused qkv '{}/{}/weight'", vb.prefix(), name))?;
|
||||
|
||||
let q = full.narrow(0, r * per_rank_key, per_rank_key)?;
|
||||
let k = full.narrow(0, key_dim + r * per_rank_key, per_rank_key)?;
|
||||
let v = full.narrow(0, 2 * key_dim + r * per_rank_value, per_rank_value)?;
|
||||
|
||||
Tensor::cat(&[&q, &k, &v], 0)?
|
||||
.contiguous()
|
||||
.with_context(|| format!("materialise fused qkv slice for rank {r}"))
|
||||
}
|
||||
|
||||
/// Same per-region slicing pattern for a 3D fused tensor (the depthwise
|
||||
/// `conv1d.weight` of the linear-attention block: shape
|
||||
/// `(conv_dim, 1, kernel_size)`).
|
||||
fn load_fused_qkv_slice_3d(
|
||||
vb: &ShardedVarBuilder,
|
||||
shape: (usize, usize, usize),
|
||||
key_dim: usize,
|
||||
value_dim: usize,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
) -> Result<Tensor> {
|
||||
let (conv_dim, mid, kernel_size) = shape;
|
||||
let ws = world_size as usize;
|
||||
let r = rank as usize;
|
||||
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
|
||||
bail!(
|
||||
"fused conv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
|
||||
must each be divisible by world_size ({ws})"
|
||||
);
|
||||
}
|
||||
let per_rank_key = key_dim / ws;
|
||||
let per_rank_value = value_dim / ws;
|
||||
|
||||
let full = vb
|
||||
.get((conv_dim, mid, kernel_size), "weight")
|
||||
.with_context(|| format!("load fused conv '{}/weight'", vb.prefix()))?;
|
||||
|
||||
let q = full.narrow(0, r * per_rank_key, per_rank_key)?;
|
||||
let k = full.narrow(0, key_dim + r * per_rank_key, per_rank_key)?;
|
||||
let v = full.narrow(0, 2 * key_dim + r * per_rank_value, per_rank_value)?;
|
||||
|
||||
Tensor::cat(&[&q, &k, &v], 0)?
|
||||
.contiguous()
|
||||
.with_context(|| format!("materialise fused conv slice for rank {r}"))
|
||||
}
|
||||
|
||||
/// Query the cuda driver for free/total VRAM on the current device.
|
||||
/// Returns `(free_mb, total_mb)`. Returns `(0, 0)` if the query fails
|
||||
/// (so logging never crashes the load path).
|
||||
|
||||
@@ -232,6 +232,19 @@ impl WorkerState {
|
||||
};
|
||||
}
|
||||
};
|
||||
// Separate mmap of the same paths for the direct fused-region
|
||||
// loader in `fused_load`. Linux's page cache shares the
|
||||
// underlying pages between the two mmaps; the cost is one
|
||||
// extra set of safetensors-header parses.
|
||||
let mmap = match unsafe { candle_core::safetensors::MmapedSafetensors::multi(&paths) } {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "load_failed".into(),
|
||||
message: format!("MmapedSafetensors::multi: {e}"),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let loaded = match model_type.as_str() {
|
||||
"qwen3" => {
|
||||
@@ -273,6 +286,7 @@ impl WorkerState {
|
||||
match TpQwen3_5ForCausalLM::load(
|
||||
cfg,
|
||||
&vb,
|
||||
&mmap,
|
||||
self.config.rank,
|
||||
self.config.world_size,
|
||||
comm,
|
||||
|
||||
Reference in New Issue
Block a user