feat(stage-8d-7): direct safetensors fused-region loader
Some checks failed
build-prerelease / Package cortex RPM (push) Blocked by required conditions
CI / Format (push) Successful in 35s
build-prerelease / Resolve version stamps (push) Successful in 39s
CI / Clippy (push) Successful in 2m18s
CI / Test (push) Successful in 4m28s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m51s
build-prerelease / Build cortex binary (push) Successful in 4m13s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled

Replaces load_fused_qkv_slice_2d/_3d with reads from a separate
MmapedSafetensors handle. Each per-rank fused tensor is built by
reading the three region byte-slices directly from the mmap,
concatenating them host-side, and uploading as one device
allocation — no full-fused-tensor device materialisation.

The prior approach allocated a ~100 MB transient device tensor
per linear-attention layer; on Qwen3.6-27B with 48 linear-attn
layers that's ~4.8 GB of allocator churn during load — enough
to fragment the cuda caching allocator on a tight-VRAM 32 GB
consumer GPU, which is what triggered the layer-22 up_proj
OOM seen on beast.

Threading: MmapedSafetensors flows worker → ForCausalLM →
Model → DecoderLayer → GatedDeltaNet::load. Both leader (mod.rs)
and worker (worker.rs) construct their own mmap; Linux's page
cache shares the underlying pages.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 17:49:35 +03:00
parent 89d98d1fb2
commit 8d7b099b36
6 changed files with 282 additions and 97 deletions

1
Cargo.lock generated
View File

@@ -2120,6 +2120,7 @@ dependencies = [
"half", "half",
"hf-hub", "hf-hub",
"reqwest", "reqwest",
"safetensors 0.7.0",
"serde", "serde",
"serde_json", "serde_json",
"thiserror 2.0.18", "thiserror 2.0.18",

View File

@@ -76,6 +76,11 @@ cudarc = { version = "0.19", optional = true, default-features = false, features
half = { version = "2.5", optional = true } half = { version = "2.5", optional = true }
tokenizers = { version = "0.22", default-features = false, features = ["onig"] } tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
hf-hub = { version = "0.4", features = ["tokio"] } hf-hub = { version = "0.4", features = ["tokio"] }
# Direct dep on `safetensors` (re-exported by candle but its `TensorView`
# / `slice::IndexOp` types are public-but-not-re-exported). Used by the
# tp `fused_load` module to read per-rank slices of fused QKV tensors
# without materialising the full tensor on device.
safetensors = "0.7"
[dev-dependencies] [dev-dependencies]
tokio = { workspace = true, features = ["test-util"] } tokio = { workspace = true, features = ["test-util"] }

View File

@@ -0,0 +1,213 @@
//! Direct safetensors readers for fused-region weight tensors.
//!
//! Qwen3-Next's `in_proj_qkv` and `conv1d` weights are *fused* —
//! three regions stored sequentially along dim 0 (`[key_q, key_k,
//! value]`). The per-rank shard for each region has unequal size
//! (`key_dim/ws` vs `value_dim/ws`), so candle's `ShardedSafeTensors`
//! built-in `Shard { dim, rank, world_size }` (uniform split) doesn't
//! map to the right slices.
//!
//! The previous approach loaded the full fused tensor onto the device,
//! `narrow`ed the three regions, and `Tensor::cat(...).contiguous()`'d
//! the per-rank slice. That left ~100 MB of transient device memory
//! per linear-attention layer — 48 layers × 100 MB = ~4.8 GB of
//! allocator pressure during load, enough to trigger fragmentation
//! OOM on tight-VRAM consumer GPUs.
//!
//! This module reads the three per-rank byte ranges *directly from
//! the safetensors mmap* (host-side), concatenates them into a single
//! contiguous byte buffer, and uploads as one device allocation. No
//! full-tensor device materialisation.
use anyhow::{Context, Result, bail};
use candle_core::safetensors::MmapedSafetensors;
use candle_core::{DType, Device, Tensor};
/// Read a 2D fused-QKV tensor `[conv_dim, hidden_size]` and return
/// this rank's per-region slice as a `[per_rank_conv_dim, hidden_size]`
/// device tensor.
///
/// `tensor_name` must be the fully-qualified safetensors key (e.g.
/// `"model.language_model.layers.5.linear_attn.in_proj_qkv.weight"`).
#[allow(clippy::too_many_arguments)]
pub fn load_fused_qkv_2d(
mmap: &MmapedSafetensors,
tensor_name: &str,
hidden_size: usize,
key_dim: usize,
value_dim: usize,
rank: u32,
world_size: u32,
target_dtype: DType,
device: &Device,
) -> Result<Tensor> {
let ws = world_size as usize;
let r = rank as usize;
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
bail!(
"fused qkv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
must each be divisible by world_size ({ws})"
);
}
let per_rank_key = key_dim / ws;
let per_rank_value = value_dim / ws;
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
let view = mmap
.get(tensor_name)
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 2D"))?;
let view_dtype: DType = view
.dtype()
.try_into()
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
let shape = view.shape();
if shape.len() != 2 {
bail!(
"fused qkv tensor '{tensor_name}' has shape {shape:?}, expected 2D \
[conv_dim, hidden_size]"
);
}
let conv_dim = key_dim * 2 + value_dim;
if shape[0] != conv_dim || shape[1] != hidden_size {
bail!(
"fused qkv tensor '{tensor_name}' shape {shape:?} \
doesn't match expected [{conv_dim}, {hidden_size}]"
);
}
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
let k_bytes = slice_dim0_bytes(
&view,
key_dim + r * per_rank_key,
per_rank_key,
tensor_name,
"k",
)?;
let v_bytes = slice_dim0_bytes(
&view,
2 * key_dim + r * per_rank_value,
per_rank_value,
tensor_name,
"v",
)?;
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
bytes.extend_from_slice(&q_bytes);
bytes.extend_from_slice(&k_bytes);
bytes.extend_from_slice(&v_bytes);
let tensor = Tensor::from_raw_buffer(
&bytes,
view_dtype,
&[per_rank_conv_dim, hidden_size],
device,
)
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused qkv '{tensor_name}'"))?;
tensor
.to_dtype(target_dtype)
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
}
/// Read a 3D fused-QKV tensor `[conv_dim, 1, kernel_size]` (the
/// depthwise conv1d weight) and return this rank's per-region slice
/// as a `[per_rank_conv_dim, 1, kernel_size]` device tensor.
#[allow(clippy::too_many_arguments)]
pub fn load_fused_qkv_3d(
mmap: &MmapedSafetensors,
tensor_name: &str,
mid: usize,
kernel_size: usize,
key_dim: usize,
value_dim: usize,
rank: u32,
world_size: u32,
target_dtype: DType,
device: &Device,
) -> Result<Tensor> {
let ws = world_size as usize;
let r = rank as usize;
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
bail!(
"fused conv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
must each be divisible by world_size ({ws})"
);
}
let per_rank_key = key_dim / ws;
let per_rank_value = value_dim / ws;
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
let view = mmap
.get(tensor_name)
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 3D"))?;
let view_dtype: DType = view
.dtype()
.try_into()
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
let shape = view.shape();
if shape.len() != 3 {
bail!(
"fused conv tensor '{tensor_name}' has shape {shape:?}, expected 3D \
[conv_dim, mid, kernel_size]"
);
}
let conv_dim = key_dim * 2 + value_dim;
if shape[0] != conv_dim || shape[1] != mid || shape[2] != kernel_size {
bail!(
"fused conv tensor '{tensor_name}' shape {shape:?} \
doesn't match expected [{conv_dim}, {mid}, {kernel_size}]"
);
}
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
let k_bytes = slice_dim0_bytes(
&view,
key_dim + r * per_rank_key,
per_rank_key,
tensor_name,
"k",
)?;
let v_bytes = slice_dim0_bytes(
&view,
2 * key_dim + r * per_rank_value,
per_rank_value,
tensor_name,
"v",
)?;
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
bytes.extend_from_slice(&q_bytes);
bytes.extend_from_slice(&k_bytes);
bytes.extend_from_slice(&v_bytes);
let tensor = Tensor::from_raw_buffer(
&bytes,
view_dtype,
&[per_rank_conv_dim, mid, kernel_size],
device,
)
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused conv '{tensor_name}'"))?;
tensor
.to_dtype(target_dtype)
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
}
/// Read `len` consecutive rows along dim 0 starting at `start` from
/// the safetensors view, returning the raw bytes. Wraps the same
/// `view.slice(start..stop)` machinery that candle's
/// `ShardedSafeTensors::get` uses internally.
fn slice_dim0_bytes(
view: &safetensors::tensor::TensorView<'_>,
start: usize,
len: usize,
tensor_name: &str,
region: &str,
) -> Result<Vec<u8>> {
use safetensors::slice::IndexOp;
let stop = start + len;
let iter = view.slice(start..stop).map_err(|e| {
anyhow::anyhow!("slice '{tensor_name}' region {region} ({start}..{stop}): {e:?}")
})?;
Ok(iter.into_iter().flatten().copied().collect())
}

View File

@@ -18,6 +18,7 @@
//! - **7c:** crash detection, streaming SSE, graceful unload. //! - **7c:** crash detection, streaming SSE, graceful unload.
pub mod all_reduce; pub mod all_reduce;
pub mod fused_load;
pub mod nccl_state; pub mod nccl_state;
pub mod rpc; pub mod rpc;
pub mod tp_linear; pub mod tp_linear;
@@ -539,6 +540,11 @@ impl WorkerPool {
ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader) ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
.context("build ShardedVarBuilder over safetensors")? .context("build ShardedVarBuilder over safetensors")?
}; };
// SAFETY: as above — the HF cache files are immutable.
let mmap = unsafe {
candle_core::safetensors::MmapedSafetensors::multi(&paths_for_leader)
.context("build MmapedSafetensors for leader load")?
};
let comm = comm_for_leader.into_inner(); let comm = comm_for_leader.into_inner();
let loaded = match model_type.as_str() { let loaded = match model_type.as_str() {
"qwen3" => { "qwen3" => {
@@ -553,7 +559,7 @@ impl WorkerPool {
serde_json::from_str(&config_json_for_leader) serde_json::from_str(&config_json_for_leader)
.context("parse Qwen3-Next Config JSON for leader load")?; .context("parse Qwen3-Next Config JSON for leader load")?;
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load( TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
cfg, &vb, 0, world_size, comm, cfg, &vb, &mmap, 0, world_size, comm,
)?) )?)
} }
other => anyhow::bail!( other => anyhow::bail!(

View File

@@ -31,6 +31,7 @@
//! linear-attention block, lm_head, the rotary table. //! linear-attention block, lm_head, the rotary table.
use anyhow::{Context, Result, bail}; use anyhow::{Context, Result, bail};
use candle_core::safetensors::MmapedSafetensors;
use candle_core::{DType, Device, IndexOp, Module, Tensor}; use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder; use candle_nn::var_builder::ShardedVarBuilder;
use candle_nn::{Embedding, Linear, kv_cache::ConcatKvCache}; use candle_nn::{Embedding, Linear, kv_cache::ConcatKvCache};
@@ -94,26 +95,29 @@ impl TpQwen3_5GatedDeltaNet {
pub fn load( pub fn load(
cfg: &TextConfig, cfg: &TextConfig,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
mmap: &MmapedSafetensors,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
comm: Arc<Comm>, comm: Arc<Comm>,
) -> Result<Self> { ) -> Result<Self> {
Self::load_inner(cfg, vb, rank, world_size, comm) Self::load_inner(cfg, vb, mmap, rank, world_size, comm)
} }
#[cfg(not(feature = "cuda"))] #[cfg(not(feature = "cuda"))]
pub fn load( pub fn load(
cfg: &TextConfig, cfg: &TextConfig,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
mmap: &MmapedSafetensors,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
) -> Result<Self> { ) -> Result<Self> {
Self::load_inner(cfg, vb, rank, world_size) Self::load_inner(cfg, vb, mmap, rank, world_size)
} }
fn load_inner( fn load_inner(
cfg: &TextConfig, cfg: &TextConfig,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
mmap: &MmapedSafetensors,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
#[cfg(feature = "cuda")] comm: Arc<Comm>, #[cfg(feature = "cuda")] comm: Arc<Comm>,
@@ -150,29 +154,43 @@ impl TpQwen3_5GatedDeltaNet {
let key_dim = head_k_dim * num_k_heads; let key_dim = head_k_dim * num_k_heads;
let value_dim = head_v_dim * num_v_heads; let value_dim = head_v_dim * num_v_heads;
let conv_dim = key_dim * 2 + value_dim; let _conv_dim = key_dim * 2 + value_dim;
let hidden_size = cfg.hidden_size; let hidden_size = cfg.hidden_size;
// ----- Fused `in_proj_qkv` and `conv1d` (per-region slicing). ----- // ----- Fused `in_proj_qkv` and `conv1d` (direct safetensors slicing).
let in_proj_qkv_weight = load_fused_qkv_slice_2d( // Reads only this rank's per-region byte slices from the mmap
vb, // and uploads as one device allocation per fused tensor — no
"in_proj_qkv", // full-fused-tensor device materialisation, which on the prior
conv_dim, // narrow+cat approach was the main allocator-fragmentation
// source on consumer GPUs near their VRAM ceiling.
let dtype = vb.dtype();
let device = vb.device().clone();
let in_proj_qkv_name = format!("{}.in_proj_qkv.weight", vb.prefix());
let in_proj_qkv_weight = super::fused_load::load_fused_qkv_2d(
mmap,
&in_proj_qkv_name,
hidden_size, hidden_size,
key_dim, key_dim,
value_dim, value_dim,
rank, rank,
world_size, world_size,
dtype,
&device,
)?; )?;
let in_proj_qkv = Linear::new(in_proj_qkv_weight, None); let in_proj_qkv = Linear::new(in_proj_qkv_weight, None);
let conv1d_weight = load_fused_qkv_slice_3d( let conv1d_name = format!("{}.conv1d.weight", vb.prefix());
&vb.pp("conv1d"), let conv1d_weight = super::fused_load::load_fused_qkv_3d(
(conv_dim, 1, conv_kernel_size), mmap,
&conv1d_name,
1,
conv_kernel_size,
key_dim, key_dim,
value_dim, value_dim,
rank, rank,
world_size, world_size,
dtype,
&device,
)?; )?;
// ----- Uniformly-sharded projections (along output dim 0). ----- // ----- Uniformly-sharded projections (along output dim 0). -----
@@ -621,11 +639,13 @@ pub struct TpQwen3_5DecoderLayer {
impl TpQwen3_5DecoderLayer { impl TpQwen3_5DecoderLayer {
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn load( pub fn load(
cfg: &TextConfig, cfg: &TextConfig,
rotary: Arc<RotaryEmbedding>, rotary: Arc<RotaryEmbedding>,
layer_idx: usize, layer_idx: usize,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
mmap: &MmapedSafetensors,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
comm: Arc<Comm>, comm: Arc<Comm>,
@@ -647,6 +667,7 @@ impl TpQwen3_5DecoderLayer {
"linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load( "linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load(
cfg, cfg,
&vb.pp("linear_attn"), &vb.pp("linear_attn"),
mmap,
rank, rank,
world_size, world_size,
comm.clone(), comm.clone(),
@@ -675,6 +696,7 @@ impl TpQwen3_5DecoderLayer {
rotary: Arc<RotaryEmbedding>, rotary: Arc<RotaryEmbedding>,
layer_idx: usize, layer_idx: usize,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
mmap: &MmapedSafetensors,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
) -> Result<Self> { ) -> Result<Self> {
@@ -694,6 +716,7 @@ impl TpQwen3_5DecoderLayer {
"linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load( "linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load(
cfg, cfg,
&vb.pp("linear_attn"), &vb.pp("linear_attn"),
mmap,
rank, rank,
world_size, world_size,
)?), )?),
@@ -755,6 +778,7 @@ impl TpQwen3_5Model {
pub fn load( pub fn load(
cfg: &TextConfig, cfg: &TextConfig,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
mmap: &MmapedSafetensors,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
comm: Arc<Comm>, comm: Arc<Comm>,
@@ -789,6 +813,7 @@ impl TpQwen3_5Model {
rotary.clone(), rotary.clone(),
i, i,
&vb_l.pp(i), &vb_l.pp(i),
mmap,
rank, rank,
world_size, world_size,
comm.clone(), comm.clone(),
@@ -816,6 +841,7 @@ impl TpQwen3_5Model {
pub fn load( pub fn load(
cfg: &TextConfig, cfg: &TextConfig,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
mmap: &MmapedSafetensors,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
) -> Result<Self> { ) -> Result<Self> {
@@ -848,6 +874,7 @@ impl TpQwen3_5Model {
rotary.clone(), rotary.clone(),
i, i,
&vb_l.pp(i), &vb_l.pp(i),
mmap,
rank, rank,
world_size, world_size,
)?); )?);
@@ -907,12 +934,13 @@ impl TpQwen3_5ForCausalLM {
pub fn load( pub fn load(
config: Config, config: Config,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
mmap: &MmapedSafetensors,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
comm: Arc<Comm>, comm: Arc<Comm>,
) -> Result<Self> { ) -> Result<Self> {
let cfg = &config.text_config; let cfg = &config.text_config;
let base = TpQwen3_5Model::load(cfg, vb, rank, world_size, comm)?; let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm)?;
let lm_head = build_lm_head(cfg, vb, &base)?; let lm_head = build_lm_head(cfg, vb, &base)?;
Ok(Self { base, lm_head }) Ok(Self { base, lm_head })
} }
@@ -921,11 +949,12 @@ impl TpQwen3_5ForCausalLM {
pub fn load( pub fn load(
config: Config, config: Config,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
mmap: &MmapedSafetensors,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
) -> Result<Self> { ) -> Result<Self> {
let cfg = &config.text_config; let cfg = &config.text_config;
let base = TpQwen3_5Model::load(cfg, vb, rank, world_size)?; let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size)?;
let lm_head = build_lm_head(cfg, vb, &base)?; let lm_head = build_lm_head(cfg, vb, &base)?;
Ok(Self { base, lm_head }) Ok(Self { base, lm_head })
} }
@@ -978,89 +1007,6 @@ fn load_replicated<S: Into<candle_core::Shape>>(
.with_context(|| format!("load replicated '{}/{name}'", vb.prefix())) .with_context(|| format!("load replicated '{}/{name}'", vb.prefix()))
} }
/// Load a fused QKV-style 2D weight tensor that stores three regions
/// sequentially along dim 0: `[first key_dim, second key_dim, value_dim]`.
/// Returns the per-rank slice formed by extracting the rank's share
/// from each region and concatenating along dim 0.
///
/// The full tensor materialises briefly on the device before the
/// slices are extracted (`narrow` views + `contiguous` copy). Memory
/// peak is one full-tensor load per layer during construction; only
/// the per-rank concatenation stays after `full` drops.
#[allow(clippy::too_many_arguments)]
fn load_fused_qkv_slice_2d(
vb: &ShardedVarBuilder,
name: &str,
conv_dim: usize,
hidden_size: usize,
key_dim: usize,
value_dim: usize,
rank: u32,
world_size: u32,
) -> Result<Tensor> {
let ws = world_size as usize;
let r = rank as usize;
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
bail!(
"fused qkv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
must each be divisible by world_size ({ws})"
);
}
let per_rank_key = key_dim / ws;
let per_rank_value = value_dim / ws;
// Force full-tensor load via `vb.get`, which defaults to
// `Shard { world_size: 1 }` and falls through to SimpleBackend.
let full = vb
.pp(name)
.get((conv_dim, hidden_size), "weight")
.with_context(|| format!("load fused qkv '{}/{}/weight'", vb.prefix(), name))?;
let q = full.narrow(0, r * per_rank_key, per_rank_key)?;
let k = full.narrow(0, key_dim + r * per_rank_key, per_rank_key)?;
let v = full.narrow(0, 2 * key_dim + r * per_rank_value, per_rank_value)?;
Tensor::cat(&[&q, &k, &v], 0)?
.contiguous()
.with_context(|| format!("materialise fused qkv slice for rank {r}"))
}
/// Same per-region slicing pattern for a 3D fused tensor (the depthwise
/// `conv1d.weight` of the linear-attention block: shape
/// `(conv_dim, 1, kernel_size)`).
fn load_fused_qkv_slice_3d(
vb: &ShardedVarBuilder,
shape: (usize, usize, usize),
key_dim: usize,
value_dim: usize,
rank: u32,
world_size: u32,
) -> Result<Tensor> {
let (conv_dim, mid, kernel_size) = shape;
let ws = world_size as usize;
let r = rank as usize;
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
bail!(
"fused conv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
must each be divisible by world_size ({ws})"
);
}
let per_rank_key = key_dim / ws;
let per_rank_value = value_dim / ws;
let full = vb
.get((conv_dim, mid, kernel_size), "weight")
.with_context(|| format!("load fused conv '{}/weight'", vb.prefix()))?;
let q = full.narrow(0, r * per_rank_key, per_rank_key)?;
let k = full.narrow(0, key_dim + r * per_rank_key, per_rank_key)?;
let v = full.narrow(0, 2 * key_dim + r * per_rank_value, per_rank_value)?;
Tensor::cat(&[&q, &k, &v], 0)?
.contiguous()
.with_context(|| format!("materialise fused conv slice for rank {r}"))
}
/// Query the cuda driver for free/total VRAM on the current device. /// Query the cuda driver for free/total VRAM on the current device.
/// Returns `(free_mb, total_mb)`. Returns `(0, 0)` if the query fails /// Returns `(free_mb, total_mb)`. Returns `(0, 0)` if the query fails
/// (so logging never crashes the load path). /// (so logging never crashes the load path).

View File

@@ -232,6 +232,19 @@ impl WorkerState {
}; };
} }
}; };
// Separate mmap of the same paths for the direct fused-region
// loader in `fused_load`. Linux's page cache shares the
// underlying pages between the two mmaps; the cost is one
// extra set of safetensors-header parses.
let mmap = match unsafe { candle_core::safetensors::MmapedSafetensors::multi(&paths) } {
Ok(m) => m,
Err(e) => {
return WorkerResponse::Error {
kind: "load_failed".into(),
message: format!("MmapedSafetensors::multi: {e}"),
};
}
};
let loaded = match model_type.as_str() { let loaded = match model_type.as_str() {
"qwen3" => { "qwen3" => {
@@ -273,6 +286,7 @@ impl WorkerState {
match TpQwen3_5ForCausalLM::load( match TpQwen3_5ForCausalLM::load(
cfg, cfg,
&vb, &vb,
&mmap,
self.config.rank, self.config.rank,
self.config.world_size, self.config.world_size,
comm, comm,