diff --git a/Cargo.lock b/Cargo.lock index f8965d7..edd0d76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2120,6 +2120,7 @@ dependencies = [ "half", "hf-hub", "reqwest", + "safetensors 0.7.0", "serde", "serde_json", "thiserror 2.0.18", diff --git a/crates/neuron/Cargo.toml b/crates/neuron/Cargo.toml index b653623..7cf7cf4 100644 --- a/crates/neuron/Cargo.toml +++ b/crates/neuron/Cargo.toml @@ -76,6 +76,11 @@ cudarc = { version = "0.19", optional = true, default-features = false, features half = { version = "2.5", optional = true } tokenizers = { version = "0.22", default-features = false, features = ["onig"] } hf-hub = { version = "0.4", features = ["tokio"] } +# Direct dep on `safetensors` (re-exported by candle but its `TensorView` +# / `slice::IndexOp` types are public-but-not-re-exported). Used by the +# tp `fused_load` module to read per-rank slices of fused QKV tensors +# without materialising the full tensor on device. +safetensors = "0.7" [dev-dependencies] tokio = { workspace = true, features = ["test-util"] } diff --git a/crates/neuron/src/harness/tp/fused_load.rs b/crates/neuron/src/harness/tp/fused_load.rs new file mode 100644 index 0000000..dd155e5 --- /dev/null +++ b/crates/neuron/src/harness/tp/fused_load.rs @@ -0,0 +1,213 @@ +//! Direct safetensors readers for fused-region weight tensors. +//! +//! Qwen3-Next's `in_proj_qkv` and `conv1d` weights are *fused* — +//! three regions stored sequentially along dim 0 (`[key_q, key_k, +//! value]`). The per-rank shard for each region has unequal size +//! (`key_dim/ws` vs `value_dim/ws`), so candle's `ShardedSafeTensors` +//! built-in `Shard { dim, rank, world_size }` (uniform split) doesn't +//! map to the right slices. +//! +//! The previous approach loaded the full fused tensor onto the device, +//! `narrow`ed the three regions, and `Tensor::cat(...).contiguous()`'d +//! the per-rank slice. That left ~100 MB of transient device memory +//! per linear-attention layer — 48 layers × 100 MB = ~4.8 GB of +//! allocator pressure during load, enough to trigger fragmentation +//! OOM on tight-VRAM consumer GPUs. +//! +//! This module reads the three per-rank byte ranges *directly from +//! the safetensors mmap* (host-side), concatenates them into a single +//! contiguous byte buffer, and uploads as one device allocation. No +//! full-tensor device materialisation. + +use anyhow::{Context, Result, bail}; +use candle_core::safetensors::MmapedSafetensors; +use candle_core::{DType, Device, Tensor}; + +/// Read a 2D fused-QKV tensor `[conv_dim, hidden_size]` and return +/// this rank's per-region slice as a `[per_rank_conv_dim, hidden_size]` +/// device tensor. +/// +/// `tensor_name` must be the fully-qualified safetensors key (e.g. +/// `"model.language_model.layers.5.linear_attn.in_proj_qkv.weight"`). +#[allow(clippy::too_many_arguments)] +pub fn load_fused_qkv_2d( + mmap: &MmapedSafetensors, + tensor_name: &str, + hidden_size: usize, + key_dim: usize, + value_dim: usize, + rank: u32, + world_size: u32, + target_dtype: DType, + device: &Device, +) -> Result { + let ws = world_size as usize; + let r = rank as usize; + if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) { + bail!( + "fused qkv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \ + must each be divisible by world_size ({ws})" + ); + } + let per_rank_key = key_dim / ws; + let per_rank_value = value_dim / ws; + let per_rank_conv_dim = per_rank_key * 2 + per_rank_value; + + let view = mmap + .get(tensor_name) + .with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 2D"))?; + let view_dtype: DType = view + .dtype() + .try_into() + .with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?; + + let shape = view.shape(); + if shape.len() != 2 { + bail!( + "fused qkv tensor '{tensor_name}' has shape {shape:?}, expected 2D \ + [conv_dim, hidden_size]" + ); + } + let conv_dim = key_dim * 2 + value_dim; + if shape[0] != conv_dim || shape[1] != hidden_size { + bail!( + "fused qkv tensor '{tensor_name}' shape {shape:?} \ + doesn't match expected [{conv_dim}, {hidden_size}]" + ); + } + + let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?; + let k_bytes = slice_dim0_bytes( + &view, + key_dim + r * per_rank_key, + per_rank_key, + tensor_name, + "k", + )?; + let v_bytes = slice_dim0_bytes( + &view, + 2 * key_dim + r * per_rank_value, + per_rank_value, + tensor_name, + "v", + )?; + + let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len()); + bytes.extend_from_slice(&q_bytes); + bytes.extend_from_slice(&k_bytes); + bytes.extend_from_slice(&v_bytes); + + let tensor = Tensor::from_raw_buffer( + &bytes, + view_dtype, + &[per_rank_conv_dim, hidden_size], + device, + ) + .with_context(|| format!("Tensor::from_raw_buffer for per-rank fused qkv '{tensor_name}'"))?; + tensor + .to_dtype(target_dtype) + .with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}")) +} + +/// Read a 3D fused-QKV tensor `[conv_dim, 1, kernel_size]` (the +/// depthwise conv1d weight) and return this rank's per-region slice +/// as a `[per_rank_conv_dim, 1, kernel_size]` device tensor. +#[allow(clippy::too_many_arguments)] +pub fn load_fused_qkv_3d( + mmap: &MmapedSafetensors, + tensor_name: &str, + mid: usize, + kernel_size: usize, + key_dim: usize, + value_dim: usize, + rank: u32, + world_size: u32, + target_dtype: DType, + device: &Device, +) -> Result { + let ws = world_size as usize; + let r = rank as usize; + if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) { + bail!( + "fused conv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \ + must each be divisible by world_size ({ws})" + ); + } + let per_rank_key = key_dim / ws; + let per_rank_value = value_dim / ws; + let per_rank_conv_dim = per_rank_key * 2 + per_rank_value; + + let view = mmap + .get(tensor_name) + .with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 3D"))?; + let view_dtype: DType = view + .dtype() + .try_into() + .with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?; + + let shape = view.shape(); + if shape.len() != 3 { + bail!( + "fused conv tensor '{tensor_name}' has shape {shape:?}, expected 3D \ + [conv_dim, mid, kernel_size]" + ); + } + let conv_dim = key_dim * 2 + value_dim; + if shape[0] != conv_dim || shape[1] != mid || shape[2] != kernel_size { + bail!( + "fused conv tensor '{tensor_name}' shape {shape:?} \ + doesn't match expected [{conv_dim}, {mid}, {kernel_size}]" + ); + } + + let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?; + let k_bytes = slice_dim0_bytes( + &view, + key_dim + r * per_rank_key, + per_rank_key, + tensor_name, + "k", + )?; + let v_bytes = slice_dim0_bytes( + &view, + 2 * key_dim + r * per_rank_value, + per_rank_value, + tensor_name, + "v", + )?; + + let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len()); + bytes.extend_from_slice(&q_bytes); + bytes.extend_from_slice(&k_bytes); + bytes.extend_from_slice(&v_bytes); + + let tensor = Tensor::from_raw_buffer( + &bytes, + view_dtype, + &[per_rank_conv_dim, mid, kernel_size], + device, + ) + .with_context(|| format!("Tensor::from_raw_buffer for per-rank fused conv '{tensor_name}'"))?; + tensor + .to_dtype(target_dtype) + .with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}")) +} + +/// Read `len` consecutive rows along dim 0 starting at `start` from +/// the safetensors view, returning the raw bytes. Wraps the same +/// `view.slice(start..stop)` machinery that candle's +/// `ShardedSafeTensors::get` uses internally. +fn slice_dim0_bytes( + view: &safetensors::tensor::TensorView<'_>, + start: usize, + len: usize, + tensor_name: &str, + region: &str, +) -> Result> { + use safetensors::slice::IndexOp; + let stop = start + len; + let iter = view.slice(start..stop).map_err(|e| { + anyhow::anyhow!("slice '{tensor_name}' region {region} ({start}..{stop}): {e:?}") + })?; + Ok(iter.into_iter().flatten().copied().collect()) +} diff --git a/crates/neuron/src/harness/tp/mod.rs b/crates/neuron/src/harness/tp/mod.rs index a53212f..bc0e22d 100644 --- a/crates/neuron/src/harness/tp/mod.rs +++ b/crates/neuron/src/harness/tp/mod.rs @@ -18,6 +18,7 @@ //! - **7c:** crash detection, streaming SSE, graceful unload. pub mod all_reduce; +pub mod fused_load; pub mod nccl_state; pub mod rpc; pub mod tp_linear; @@ -539,6 +540,11 @@ impl WorkerPool { ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader) .context("build ShardedVarBuilder over safetensors")? }; + // SAFETY: as above — the HF cache files are immutable. + let mmap = unsafe { + candle_core::safetensors::MmapedSafetensors::multi(&paths_for_leader) + .context("build MmapedSafetensors for leader load")? + }; let comm = comm_for_leader.into_inner(); let loaded = match model_type.as_str() { "qwen3" => { @@ -553,7 +559,7 @@ impl WorkerPool { serde_json::from_str(&config_json_for_leader) .context("parse Qwen3-Next Config JSON for leader load")?; TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load( - cfg, &vb, 0, world_size, comm, + cfg, &vb, &mmap, 0, world_size, comm, )?) } other => anyhow::bail!( diff --git a/crates/neuron/src/harness/tp/tp_qwen3_5.rs b/crates/neuron/src/harness/tp/tp_qwen3_5.rs index 0ec4942..a023f70 100644 --- a/crates/neuron/src/harness/tp/tp_qwen3_5.rs +++ b/crates/neuron/src/harness/tp/tp_qwen3_5.rs @@ -31,6 +31,7 @@ //! linear-attention block, lm_head, the rotary table. use anyhow::{Context, Result, bail}; +use candle_core::safetensors::MmapedSafetensors; use candle_core::{DType, Device, IndexOp, Module, Tensor}; use candle_nn::var_builder::ShardedVarBuilder; use candle_nn::{Embedding, Linear, kv_cache::ConcatKvCache}; @@ -94,26 +95,29 @@ impl TpQwen3_5GatedDeltaNet { pub fn load( cfg: &TextConfig, vb: &ShardedVarBuilder, + mmap: &MmapedSafetensors, rank: u32, world_size: u32, comm: Arc, ) -> Result { - Self::load_inner(cfg, vb, rank, world_size, comm) + Self::load_inner(cfg, vb, mmap, rank, world_size, comm) } #[cfg(not(feature = "cuda"))] pub fn load( cfg: &TextConfig, vb: &ShardedVarBuilder, + mmap: &MmapedSafetensors, rank: u32, world_size: u32, ) -> Result { - Self::load_inner(cfg, vb, rank, world_size) + Self::load_inner(cfg, vb, mmap, rank, world_size) } fn load_inner( cfg: &TextConfig, vb: &ShardedVarBuilder, + mmap: &MmapedSafetensors, rank: u32, world_size: u32, #[cfg(feature = "cuda")] comm: Arc, @@ -150,29 +154,43 @@ impl TpQwen3_5GatedDeltaNet { let key_dim = head_k_dim * num_k_heads; let value_dim = head_v_dim * num_v_heads; - let conv_dim = key_dim * 2 + value_dim; + let _conv_dim = key_dim * 2 + value_dim; let hidden_size = cfg.hidden_size; - // ----- Fused `in_proj_qkv` and `conv1d` (per-region slicing). ----- - let in_proj_qkv_weight = load_fused_qkv_slice_2d( - vb, - "in_proj_qkv", - conv_dim, + // ----- Fused `in_proj_qkv` and `conv1d` (direct safetensors slicing). + // Reads only this rank's per-region byte slices from the mmap + // and uploads as one device allocation per fused tensor — no + // full-fused-tensor device materialisation, which on the prior + // narrow+cat approach was the main allocator-fragmentation + // source on consumer GPUs near their VRAM ceiling. + let dtype = vb.dtype(); + let device = vb.device().clone(); + let in_proj_qkv_name = format!("{}.in_proj_qkv.weight", vb.prefix()); + let in_proj_qkv_weight = super::fused_load::load_fused_qkv_2d( + mmap, + &in_proj_qkv_name, hidden_size, key_dim, value_dim, rank, world_size, + dtype, + &device, )?; let in_proj_qkv = Linear::new(in_proj_qkv_weight, None); - let conv1d_weight = load_fused_qkv_slice_3d( - &vb.pp("conv1d"), - (conv_dim, 1, conv_kernel_size), + let conv1d_name = format!("{}.conv1d.weight", vb.prefix()); + let conv1d_weight = super::fused_load::load_fused_qkv_3d( + mmap, + &conv1d_name, + 1, + conv_kernel_size, key_dim, value_dim, rank, world_size, + dtype, + &device, )?; // ----- Uniformly-sharded projections (along output dim 0). ----- @@ -621,11 +639,13 @@ pub struct TpQwen3_5DecoderLayer { impl TpQwen3_5DecoderLayer { #[cfg(feature = "cuda")] + #[allow(clippy::too_many_arguments)] pub fn load( cfg: &TextConfig, rotary: Arc, layer_idx: usize, vb: &ShardedVarBuilder, + mmap: &MmapedSafetensors, rank: u32, world_size: u32, comm: Arc, @@ -647,6 +667,7 @@ impl TpQwen3_5DecoderLayer { "linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load( cfg, &vb.pp("linear_attn"), + mmap, rank, world_size, comm.clone(), @@ -675,6 +696,7 @@ impl TpQwen3_5DecoderLayer { rotary: Arc, layer_idx: usize, vb: &ShardedVarBuilder, + mmap: &MmapedSafetensors, rank: u32, world_size: u32, ) -> Result { @@ -694,6 +716,7 @@ impl TpQwen3_5DecoderLayer { "linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load( cfg, &vb.pp("linear_attn"), + mmap, rank, world_size, )?), @@ -755,6 +778,7 @@ impl TpQwen3_5Model { pub fn load( cfg: &TextConfig, vb: &ShardedVarBuilder, + mmap: &MmapedSafetensors, rank: u32, world_size: u32, comm: Arc, @@ -789,6 +813,7 @@ impl TpQwen3_5Model { rotary.clone(), i, &vb_l.pp(i), + mmap, rank, world_size, comm.clone(), @@ -816,6 +841,7 @@ impl TpQwen3_5Model { pub fn load( cfg: &TextConfig, vb: &ShardedVarBuilder, + mmap: &MmapedSafetensors, rank: u32, world_size: u32, ) -> Result { @@ -848,6 +874,7 @@ impl TpQwen3_5Model { rotary.clone(), i, &vb_l.pp(i), + mmap, rank, world_size, )?); @@ -907,12 +934,13 @@ impl TpQwen3_5ForCausalLM { pub fn load( config: Config, vb: &ShardedVarBuilder, + mmap: &MmapedSafetensors, rank: u32, world_size: u32, comm: Arc, ) -> Result { let cfg = &config.text_config; - let base = TpQwen3_5Model::load(cfg, vb, rank, world_size, comm)?; + let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm)?; let lm_head = build_lm_head(cfg, vb, &base)?; Ok(Self { base, lm_head }) } @@ -921,11 +949,12 @@ impl TpQwen3_5ForCausalLM { pub fn load( config: Config, vb: &ShardedVarBuilder, + mmap: &MmapedSafetensors, rank: u32, world_size: u32, ) -> Result { let cfg = &config.text_config; - let base = TpQwen3_5Model::load(cfg, vb, rank, world_size)?; + let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size)?; let lm_head = build_lm_head(cfg, vb, &base)?; Ok(Self { base, lm_head }) } @@ -978,89 +1007,6 @@ fn load_replicated>( .with_context(|| format!("load replicated '{}/{name}'", vb.prefix())) } -/// Load a fused QKV-style 2D weight tensor that stores three regions -/// sequentially along dim 0: `[first key_dim, second key_dim, value_dim]`. -/// Returns the per-rank slice formed by extracting the rank's share -/// from each region and concatenating along dim 0. -/// -/// The full tensor materialises briefly on the device before the -/// slices are extracted (`narrow` views + `contiguous` copy). Memory -/// peak is one full-tensor load per layer during construction; only -/// the per-rank concatenation stays after `full` drops. -#[allow(clippy::too_many_arguments)] -fn load_fused_qkv_slice_2d( - vb: &ShardedVarBuilder, - name: &str, - conv_dim: usize, - hidden_size: usize, - key_dim: usize, - value_dim: usize, - rank: u32, - world_size: u32, -) -> Result { - let ws = world_size as usize; - let r = rank as usize; - if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) { - bail!( - "fused qkv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \ - must each be divisible by world_size ({ws})" - ); - } - let per_rank_key = key_dim / ws; - let per_rank_value = value_dim / ws; - - // Force full-tensor load via `vb.get`, which defaults to - // `Shard { world_size: 1 }` and falls through to SimpleBackend. - let full = vb - .pp(name) - .get((conv_dim, hidden_size), "weight") - .with_context(|| format!("load fused qkv '{}/{}/weight'", vb.prefix(), name))?; - - let q = full.narrow(0, r * per_rank_key, per_rank_key)?; - let k = full.narrow(0, key_dim + r * per_rank_key, per_rank_key)?; - let v = full.narrow(0, 2 * key_dim + r * per_rank_value, per_rank_value)?; - - Tensor::cat(&[&q, &k, &v], 0)? - .contiguous() - .with_context(|| format!("materialise fused qkv slice for rank {r}")) -} - -/// Same per-region slicing pattern for a 3D fused tensor (the depthwise -/// `conv1d.weight` of the linear-attention block: shape -/// `(conv_dim, 1, kernel_size)`). -fn load_fused_qkv_slice_3d( - vb: &ShardedVarBuilder, - shape: (usize, usize, usize), - key_dim: usize, - value_dim: usize, - rank: u32, - world_size: u32, -) -> Result { - let (conv_dim, mid, kernel_size) = shape; - let ws = world_size as usize; - let r = rank as usize; - if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) { - bail!( - "fused conv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \ - must each be divisible by world_size ({ws})" - ); - } - let per_rank_key = key_dim / ws; - let per_rank_value = value_dim / ws; - - let full = vb - .get((conv_dim, mid, kernel_size), "weight") - .with_context(|| format!("load fused conv '{}/weight'", vb.prefix()))?; - - let q = full.narrow(0, r * per_rank_key, per_rank_key)?; - let k = full.narrow(0, key_dim + r * per_rank_key, per_rank_key)?; - let v = full.narrow(0, 2 * key_dim + r * per_rank_value, per_rank_value)?; - - Tensor::cat(&[&q, &k, &v], 0)? - .contiguous() - .with_context(|| format!("materialise fused conv slice for rank {r}")) -} - /// Query the cuda driver for free/total VRAM on the current device. /// Returns `(free_mb, total_mb)`. Returns `(0, 0)` if the query fails /// (so logging never crashes the load path). diff --git a/crates/neuron/src/harness/tp/worker.rs b/crates/neuron/src/harness/tp/worker.rs index 2d3e444..98b70a5 100644 --- a/crates/neuron/src/harness/tp/worker.rs +++ b/crates/neuron/src/harness/tp/worker.rs @@ -232,6 +232,19 @@ impl WorkerState { }; } }; + // Separate mmap of the same paths for the direct fused-region + // loader in `fused_load`. Linux's page cache shares the + // underlying pages between the two mmaps; the cost is one + // extra set of safetensors-header parses. + let mmap = match unsafe { candle_core::safetensors::MmapedSafetensors::multi(&paths) } { + Ok(m) => m, + Err(e) => { + return WorkerResponse::Error { + kind: "load_failed".into(), + message: format!("MmapedSafetensors::multi: {e}"), + }; + } + }; let loaded = match model_type.as_str() { "qwen3" => { @@ -273,6 +286,7 @@ impl WorkerState { match TpQwen3_5ForCausalLM::load( cfg, &vb, + &mmap, self.config.rank, self.config.world_size, comm,