From 95dc8745eb4c028294115ce886275b51bb2a3a6c Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Wed, 20 May 2026 22:02:42 +0300 Subject: [PATCH] feat(stage-8c): TP-aware Qwen3-Next (tp_qwen3_5) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `harness/tp/tp_qwen3_5.rs` — the tensor-parallel variant of the Qwen3-Next architecture — plus the dispatch wiring needed to route a load through it on both the leader and the workers. Architecture pieces (all per-rank, follow `tp_qwen3.rs` patterns for the full-attention layers + a new pattern for linear-attention): - TpQwen3_5GatedDeltaNet: V-head-dim sharded. `num_v_heads / world_size` V-heads per rank, `num_k_heads / world_size` K-heads. `in_proj_z`, `in_proj_b`, `in_proj_a`, `A_log`, `dt_bias` shard uniformly along the V-head dim. `out_proj` is row-parallel + AllReduce (the only collective inside the block). The recurrent state shards 1:1 with V-heads — no cross-rank sync inside the delta-rule loop. `in_proj_qkv` and `conv1d.weight` are FUSED tensors with three regions along dim 0 (`[first key_dim, second key_dim, value_dim]`). Standard uniform-slicing doesn't align with the head boundaries — rank 0 would end up with `[first half of K_0, full K_1, first half of V]`. New `load_fused_qkv_slice_{2d,3d}` helpers load the full tensor, narrow per-region per-rank, and `Tensor::cat` the three slices into a per-rank fused weight. Transient peak of one full tensor per layer during construction; net memory is properly per-rank after the full tensor drops. - TpQwen3_5Attention: column-parallel `q_proj` (the widened `2 * num_heads * head_dim` output, including the gate half — shards along the head axis so both query AND gate halves stay consistent per rank), `k_proj`, `v_proj`; row-parallel `o_proj` with AllReduce. Otherwise mirrors `tp_qwen3.rs`'s attention. 
- TpQwen3_5MLP, TpQwen3_5DecoderLayer (dispatches on layer_types), TpQwen3_5Model (with `model.language_model.*` prefix), and TpQwen3_5ForCausalLM (with tied or separate `lm_head` at top level). Dispatch wiring: - New `tp::TpLeaderModel` enum holds either Qwen3 or Qwen3_5 variant. `WorkerPool::load_dense_shard` now dispatches on `model_type` from the config JSON and returns `Arc<Mutex<TpLeaderModel>>`. The two downstream methods (`generate_step`, `clear_kv_cache`) thread this enum through — the inner forward+clear_kv_cache dispatch happens via the enum's pub methods. Adding another TP architecture later is one more enum variant + match arms. - Worker side gets a parallel `WorkerModel` enum + dispatch in `handle_load_dense_shard`, branching on the same `model_type`. - Harness gate `TP_SUPPORTED_MODEL_TYPES` now `["qwen3", "qwen3_5"]`. `TpLoadedModel.leader_model` retyped to the enum. Helpers in `arch/qwen3_5/linear_attn.rs`: - `softplus` and `repeat_interleave` made `pub(crate)` so the TP module reuses them rather than duplicating. Reuses unchanged: `Qwen3_5RmsNorm` (replicated weight), the gated `Qwen3_5RmsNormGated` tail, `l2norm`, the `RotaryEmbedding` (partial RoPE with `partial_rotary_factor` already correct). CPU build + clippy + 32 lib tests pass; `cargo clippy --features cuda` also clean inside the patched runner container. Single inflight risk to call out: tensor names. For full-attention layers the per-layer prefix is `model.language_model.layers.<i>.self_attn.*` and for linear-attention layers `model.language_model.layers.<i>.linear_attn.*` — the same as the single-GPU path. lm_head sits at the top level (not under `language_model`) — consistent with the single-GPU path that validated against Qwen3.5-0.8B. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/harness/arch/qwen3_5/linear_attn.rs | 8 +- crates/neuron/src/harness/candle.rs | 4 +- crates/neuron/src/harness/tp/mod.rs | 124 +- crates/neuron/src/harness/tp/tp_qwen3_5.rs | 1082 +++++++++++++++++ crates/neuron/src/harness/tp/worker.rs | 150 ++- 5 files changed, 1305 insertions(+), 63 deletions(-) create mode 100644 crates/neuron/src/harness/tp/tp_qwen3_5.rs diff --git a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs index 102277b..7ced2da 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs @@ -404,7 +404,7 @@ fn load_linear_no_bias( /// Numerically-stable `softplus(x) = ln(1 + exp(x))`. Matches PyTorch's /// `F.softplus` default (beta=1, threshold=20: for large positive x, /// returns x as-is to avoid overflow in the exp). -fn softplus(x: &Tensor) -> candle_core::Result { +pub(crate) fn softplus(x: &Tensor) -> candle_core::Result { let threshold = 20.0_f64; let big = x.ge(threshold)?; // Tensor mask let safe = x.minimum(&x.affine(0.0, 0.0)?.affine(0.0, threshold)?)?; // min(x, threshold) @@ -415,7 +415,11 @@ fn softplus(x: &Tensor) -> candle_core::Result { /// `repeat_interleave` along a single dim. Candle has no built-in for /// this; emulate with unsqueeze + expand + reshape. -fn repeat_interleave(x: &Tensor, repeats: usize, dim: usize) -> candle_core::Result { +pub(crate) fn repeat_interleave( + x: &Tensor, + repeats: usize, + dim: usize, +) -> candle_core::Result { if repeats == 1 { return Ok(x.clone()); } diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 342df69..75323a4 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -101,7 +101,7 @@ pub struct TpLoadedModel { /// step. The same Mutex covers both for the simplest correctness /// story. 
pub pool: tokio::sync::Mutex, - pub leader_model: Arc>, + pub leader_model: Arc>, } /// Architecture-specific weights. Each variant covers one (family, @@ -291,7 +291,7 @@ fn check_dense_config_supported(config_json: &str, model_id: &str) -> Result<()> /// families than the TP path because each TP-aware module is a real /// chunk of work (`tp_qwen3.rs` is the only one shipped today). #[cfg(feature = "cuda")] -const TP_SUPPORTED_MODEL_TYPES: &[&str] = &["qwen3"]; +const TP_SUPPORTED_MODEL_TYPES: &[&str] = &["qwen3", "qwen3_5"]; /// TP-side counterpart to `check_dense_config_supported`. Gates the /// `load_tp` path on a narrower architecture set: even though the diff --git a/crates/neuron/src/harness/tp/mod.rs b/crates/neuron/src/harness/tp/mod.rs index 451904c..c05522f 100644 --- a/crates/neuron/src/harness/tp/mod.rs +++ b/crates/neuron/src/harness/tp/mod.rs @@ -22,6 +22,7 @@ pub mod nccl_state; pub mod rpc; pub mod tp_linear; pub mod tp_qwen3; +pub mod tp_qwen3_5; pub mod worker; use anyhow::{Context, Result}; @@ -32,6 +33,49 @@ use tokio::process::{Child, ChildStdin, ChildStdout, Command}; use rpc::{WorkerRequest, WorkerResponse}; +/// Leader-side handle for any TP-loaded model. The pool's +/// `load_dense_shard` dispatches on `config.json#/model_type` to build +/// the right variant; downstream callers (the harness's +/// `chat_completion_tp` path, `generate_step`, `clear_kv_cache`, +/// `unload_model`) all hold this enum and let the variant dispatch +/// determine the concrete forward. +/// +/// Variants gated on `cuda` because the underlying TP models hold +/// `Arc` references — irrelevant on CPU builds. 
+#[cfg(feature = "cuda")] +pub enum TpLeaderModel { + Qwen3(tp_qwen3::TpQwen3ForCausalLM), + Qwen3_5(tp_qwen3_5::TpQwen3_5ForCausalLM), +} + +#[cfg(feature = "cuda")] +impl TpLeaderModel { + pub fn forward( + &mut self, + input: &candle_core::Tensor, + offset: usize, + ) -> candle_core::Result { + match self { + TpLeaderModel::Qwen3(m) => m.forward(input, offset), + TpLeaderModel::Qwen3_5(m) => m.forward(input, offset), + } + } + + pub fn clear_kv_cache(&mut self) { + match self { + TpLeaderModel::Qwen3(m) => m.clear_kv_cache(), + TpLeaderModel::Qwen3_5(m) => m.clear_kv_cache(), + } + } + + pub fn device(&self) -> &candle_core::Device { + match self { + TpLeaderModel::Qwen3(m) => m.device(), + TpLeaderModel::Qwen3_5(m) => m.device(), + } + } +} + /// One worker subprocess plus its bidirectional stdio handles. struct Worker { rank: u32, @@ -363,7 +407,7 @@ impl WorkerPool { safetensors_paths: &[std::path::PathBuf], leader_device: &candle_core::Device, dtype: candle_core::DType, - ) -> Result>> { + ) -> Result>> { use candle_nn::var_builder::ShardedSafeTensors; use std::sync::Arc; use tokio::sync::Mutex; @@ -396,36 +440,56 @@ impl WorkerPool { .await?; } - // 2. Build rank 0's shard on the leader. ShardedVarBuilder reads - // only the rank's slice from safetensors — no full-tensor - // materialisation. Runs in spawn_blocking because the - // file-mmap + slice + copy-to-device work is synchronous. - let cfg: super::tp::tp_qwen3::Config = - serde_json::from_str(config_json).context("parse Qwen3 Config JSON for leader load")?; + // 2. Build rank 0's shard on the leader. Dispatch on model_type + // — for `qwen3` we build a `TpQwen3ForCausalLM`, for + // `qwen3_5` (Qwen3-Next, Qwen3.6's architecture) we build + // `TpQwen3_5ForCausalLM`. Both end up wrapped in the + // `TpLeaderModel` enum so downstream callers don't care. 
+ let model_type = serde_json::from_str::(config_json) + .ok() + .as_ref() + .and_then(|v| v.get("model_type")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); let paths_for_leader: Vec = safetensors_paths.to_vec(); let device_for_leader = leader_device.clone(); let comm_for_leader = leader_comm; let model_id_for_log = model_id.to_string(); - let leader_model = tokio::task::spawn_blocking( - move || -> Result { - // SAFETY: same invariant as the single-GPU dense path — - // the HF cache files are treated as immutable while the - // mmap is held. - let vb = unsafe { - ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader) - .context("build ShardedVarBuilder over safetensors")? - }; - let model = super::tp::tp_qwen3::TpQwen3ForCausalLM::load( - &cfg, - &vb, - 0, - world_size, - comm_for_leader.into_inner(), - )?; - tracing::info!(rank = 0, model = %model_id_for_log, "loaded TP shard (leader)"); - Ok(model) - }, - ) + let config_json_for_leader = config_json.to_string(); + + let leader_model = tokio::task::spawn_blocking(move || -> Result { + // SAFETY: same invariant as the single-GPU dense path — + // the HF cache files are treated as immutable while the + // mmap is held. + let vb = unsafe { + ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader) + .context("build ShardedVarBuilder over safetensors")? + }; + let comm = comm_for_leader.into_inner(); + let loaded = match model_type.as_str() { + "qwen3" => { + let cfg: super::tp::tp_qwen3::Config = serde_json::from_str(&config_json_for_leader) + .context("parse Qwen3 Config JSON for leader load")?; + TpLeaderModel::Qwen3(super::tp::tp_qwen3::TpQwen3ForCausalLM::load( + &cfg, &vb, 0, world_size, comm, + )?) 
+ } + "qwen3_5" => { + let cfg: super::tp::tp_qwen3_5::Config = + serde_json::from_str(&config_json_for_leader) + .context("parse Qwen3-Next Config JSON for leader load")?; + TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load( + cfg, &vb, 0, world_size, comm, + )?) + } + other => anyhow::bail!( + "TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)" + ), + }; + tracing::info!(rank = 0, model = %model_id_for_log, model_type = %model_type, "loaded TP shard (leader)"); + Ok(loaded) + }) .await .context("leader load task panicked")??; @@ -463,7 +527,7 @@ impl WorkerPool { pub async fn generate_step( &mut self, model_id: &str, - leader_model: std::sync::Arc>, + leader_model: std::sync::Arc>, tokens: Vec, offset: usize, ) -> Result { @@ -516,9 +580,7 @@ impl WorkerPool { pub async fn clear_kv_cache( &mut self, model_id: &str, - #[cfg(feature = "cuda")] leader_model: std::sync::Arc< - tokio::sync::Mutex, - >, + #[cfg(feature = "cuda")] leader_model: std::sync::Arc>, ) -> Result<()> { for w in &mut self.workers { w.send_only(&WorkerRequest::ClearKvCache { diff --git a/crates/neuron/src/harness/tp/tp_qwen3_5.rs b/crates/neuron/src/harness/tp/tp_qwen3_5.rs new file mode 100644 index 0000000..1b45235 --- /dev/null +++ b/crates/neuron/src/harness/tp/tp_qwen3_5.rs @@ -0,0 +1,1082 @@ +//! Tensor-parallel Qwen3-Next (`qwen3_5`) model. +//! +//! Two distinct sharding strategies coexist in the same model because +//! `layer_types[i]` dispatches per layer: +//! +//! - **Full-attention layers** (`Qwen3_5Attention`): column-parallel +//! `q_proj` (the doubled `2 * num_heads * head_dim` output sharded +//! on the head axis, including the gate half), `k_proj`, `v_proj`; +//! row-parallel `o_proj` with the trailing `AllReduce`. Same shape +//! of work as `tp_qwen3.rs` apart from the gate. +//! +//! - **Linear-attention layers** (`Qwen3_5GatedDeltaNet`): V-head-dim +//! sharding. 
Per rank: `num_v_heads / world_size` value heads and +//! `num_k_heads / world_size` key heads. The recurrent state shards +//! 1:1 with the V-heads; no cross-rank sync inside the delta-rule +//! loop. `out_proj` is row-parallel + AllReduce — the only +//! collective inside the block. +//! +//! The `in_proj_qkv` and `conv1d` weights are *fused* tensors with +//! three regions sequentially along dim 0: +//! `[first key_dim, second key_dim, value_dim]`. Uniform +//! slicing-along-dim-0 (the standard `ShardedSafeTensors` behaviour) +//! does **not** align with these head boundaries — rank 0 would end +//! up with `[first half of key_dim_0, full key_dim_1, first half of +//! value_dim]`, garbage. So we load the full tensor and re-slice it +//! per-region per-rank, dropping the unused portion. Net memory is +//! the same as proper per-rank loading; transient peak is one +//! full-tensor allocation per layer during construction. +//! +//! Replicated: embedding, all RmsNorms, the gated RMSNorm tail of the +//! linear-attention block, lm_head, the rotary table. + +use anyhow::{Context, Result, bail}; +use candle_core::{D, DType, Device, IndexOp, Module, Tensor}; +use candle_nn::var_builder::ShardedVarBuilder; +use candle_nn::{Embedding, Linear, kv_cache::ConcatKvCache}; +use candle_transformers::utils::repeat_kv; +use std::sync::Arc; + +#[cfg(feature = "cuda")] +use cudarc::nccl::Comm; + +use super::tp_linear::{ColumnParallelLinear, RowParallelLinear}; +use crate::harness::arch::qwen3_5::linear_attn::{repeat_interleave, softplus}; +use crate::harness::arch::qwen3_5::rmsnorm::{Qwen3_5RmsNorm, Qwen3_5RmsNormGated, l2norm}; +use crate::harness::arch::qwen3_5::rope::RotaryEmbedding; +pub use crate::harness::arch::qwen3_5::{Config, TextConfig}; + +// ─── linear-attention (Gated DeltaNet) ────────────────────────────── + +/// Per-rank, per-layer state for the TP linear-attention block. 
+/// Identical shape to the single-GPU `GatedDeltaNetState` but with +/// `num_v_heads` replaced by `per_rank_num_v_heads`. +#[derive(Default)] +pub struct TpGatedDeltaNetState { + pub conv_state: Option, + pub recurrent_state: Option, +} + +pub(crate) struct TpQwen3_5GatedDeltaNet { + in_proj_qkv: Linear, + in_proj_z: ColumnParallelLinear, + in_proj_b: ColumnParallelLinear, + in_proj_a: ColumnParallelLinear, + out_proj: RowParallelLinear, + + /// Depthwise causal Conv1d weight, sharded per-region by V-head. + /// Shape: `(per_rank_conv_dim, 1, conv_kernel_size)`. + conv1d_weight: Tensor, + + /// Per-V-head discretisation params, sharded along `num_v_heads`. + a_log: Tensor, + dt_bias: Tensor, + + /// Output gated RMSNorm (replicated; the norm dim is `head_v_dim` + /// which doesn't change with sharding). + norm: Qwen3_5RmsNormGated, + + // Per-rank shape hyperparams. + per_rank_num_v_heads: usize, + per_rank_num_k_heads: usize, + head_k_dim: usize, + head_v_dim: usize, + per_rank_key_dim: usize, + per_rank_value_dim: usize, + per_rank_conv_dim: usize, + conv_kernel_size: usize, + + state: TpGatedDeltaNetState, +} + +impl TpQwen3_5GatedDeltaNet { + #[cfg(feature = "cuda")] + pub fn load( + cfg: &TextConfig, + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + comm: Arc, + ) -> Result { + Self::load_inner(cfg, vb, rank, world_size, comm) + } + + #[cfg(not(feature = "cuda"))] + pub fn load( + cfg: &TextConfig, + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + ) -> Result { + Self::load_inner(cfg, vb, rank, world_size) + } + + fn load_inner( + cfg: &TextConfig, + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + #[cfg(feature = "cuda")] comm: Arc, + ) -> Result { + let ws = world_size as usize; + let num_v_heads = cfg.linear_num_value_heads; + let num_k_heads = cfg.linear_num_key_heads; + if num_v_heads == 0 || num_k_heads == 0 { + bail!( + "Qwen3-Next linear_num_*_heads must be set; got v={num_v_heads}, k={num_k_heads}" + ); + } + if 
!num_v_heads.is_multiple_of(num_k_heads) { + bail!( + "linear_num_value_heads ({num_v_heads}) must be a multiple of \ + linear_num_key_heads ({num_k_heads}) for GQA-style head expansion" + ); + } + if !num_v_heads.is_multiple_of(ws) { + bail!("linear_num_value_heads ({num_v_heads}) not divisible by world_size {ws}"); + } + if !num_k_heads.is_multiple_of(ws) { + bail!("linear_num_key_heads ({num_k_heads}) not divisible by world_size {ws}"); + } + + let head_k_dim = cfg.linear_key_head_dim; + let head_v_dim = cfg.linear_value_head_dim; + let conv_kernel_size = cfg.linear_conv_kernel_dim; + let per_rank_num_v_heads = num_v_heads / ws; + let per_rank_num_k_heads = num_k_heads / ws; + let per_rank_key_dim = head_k_dim * per_rank_num_k_heads; + let per_rank_value_dim = head_v_dim * per_rank_num_v_heads; + let per_rank_conv_dim = per_rank_key_dim * 2 + per_rank_value_dim; + + let key_dim = head_k_dim * num_k_heads; + let value_dim = head_v_dim * num_v_heads; + let conv_dim = key_dim * 2 + value_dim; + let hidden_size = cfg.hidden_size; + + // ----- Fused `in_proj_qkv` and `conv1d` (per-region slicing). ----- + let in_proj_qkv_weight = load_fused_qkv_slice_2d( + vb, + "in_proj_qkv", + conv_dim, + hidden_size, + key_dim, + value_dim, + rank, + world_size, + )?; + let in_proj_qkv = Linear::new(in_proj_qkv_weight, None); + + let conv1d_weight = load_fused_qkv_slice_3d( + &vb.pp("conv1d"), + (conv_dim, 1, conv_kernel_size), + key_dim, + value_dim, + rank, + world_size, + )?; + + // ----- Uniformly-sharded projections (along output dim 0). ----- + // in_proj_z: hidden → value_dim, sharded along value_dim (V-head). + let in_proj_z = ColumnParallelLinear::load(&vb.pp("in_proj_z"), rank, world_size)?; + // in_proj_b, in_proj_a: hidden → num_v_heads, sharded along output. 
+ let in_proj_b = ColumnParallelLinear::load(&vb.pp("in_proj_b"), rank, world_size)?; + let in_proj_a = ColumnParallelLinear::load(&vb.pp("in_proj_a"), rank, world_size)?; + + // ----- Per-V-head 1D params (sharded uniformly). ----- + let a_log = vb + .get_with_hints((), "A_log", super::tp_linear::shard(0, rank, world_size)) + .with_context(|| format!("load '{}/A_log'", vb.prefix()))?; + let dt_bias = vb + .get_with_hints((), "dt_bias", super::tp_linear::shard(0, rank, world_size)) + .with_context(|| format!("load '{}/dt_bias'", vb.prefix()))?; + + // ----- Output gated RMSNorm (replicated, norm dim is head_v_dim). ----- + let norm = Qwen3_5RmsNormGated::load(&vb.pp("norm"), head_v_dim, cfg.rms_norm_eps)?; + + // ----- Output projection: row-parallel + AllReduce. ----- + #[cfg(feature = "cuda")] + let out_proj = RowParallelLinear::load(&vb.pp("out_proj"), rank, world_size, comm)?; + #[cfg(not(feature = "cuda"))] + let out_proj = RowParallelLinear::load(&vb.pp("out_proj"), rank, world_size)?; + + Ok(Self { + in_proj_qkv, + in_proj_z, + in_proj_b, + in_proj_a, + out_proj, + conv1d_weight, + a_log, + dt_bias, + norm, + per_rank_num_v_heads, + per_rank_num_k_heads, + head_k_dim, + head_v_dim, + per_rank_key_dim, + per_rank_value_dim, + per_rank_conv_dim, + conv_kernel_size, + state: TpGatedDeltaNetState::default(), + }) + } + + pub fn clear_kv_cache(&mut self) { + self.state = TpGatedDeltaNetState::default(); + } + + /// `x` shape: `(B, L, hidden_size)`. Returns `(B, L, hidden_size)` + /// after the row-parallel AllReduce. + pub fn forward(&mut self, x: &Tensor) -> candle_core::Result { + let (batch_size, seq_len, _) = x.dims3()?; + let dtype = x.dtype(); + let device = x.device().clone(); + + // ----- Projections (per-rank). 
----- + let mixed_qkv = self.in_proj_qkv.forward(x)?; // (B, L, per_rank_conv_dim) + let mixed_qkv_chw = mixed_qkv.transpose(1, 2)?.contiguous()?; + + let z = self.in_proj_z.forward(x)?.reshape(( + batch_size, + seq_len, + self.per_rank_num_v_heads, + self.head_v_dim, + ))?; + + let b = self.in_proj_b.forward(x)?; // (B, L, per_rank_num_v_heads) + let a = self.in_proj_a.forward(x)?; + + // ----- State-aware causal Conv1d + SiLU. ----- + let prepended = match &self.state.conv_state { + Some(prev) => Tensor::cat(&[prev, &mixed_qkv_chw], 2)?, + None => mixed_qkv_chw.clone(), + }; + let prep_len = prepended.dims()[2]; + let new_state = if prep_len >= self.conv_kernel_size { + prepended.narrow(2, prep_len - self.conv_kernel_size, self.conv_kernel_size)? + } else { + let pad = Tensor::zeros( + ( + batch_size, + self.per_rank_conv_dim, + self.conv_kernel_size - prep_len, + ), + dtype, + &device, + )?; + Tensor::cat(&[&pad, &prepended], 2)? + }; + self.state.conv_state = Some(new_state); + + let conv_out = prepended.conv1d( + &self.conv1d_weight, + self.conv_kernel_size - 1, + 1, + 1, + self.per_rank_conv_dim, + )?; + let conv_out = conv_out.narrow(2, 0, prep_len)?; + let conv_out = candle_nn::ops::silu(&conv_out)?; + let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?; + let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?; + + // ----- Split into q, k, v (per-rank head counts). ----- + let q = mixed_qkv.narrow(2, 0, self.per_rank_key_dim)?; + let k = mixed_qkv.narrow(2, self.per_rank_key_dim, self.per_rank_key_dim)?; + let v = mixed_qkv.narrow(2, 2 * self.per_rank_key_dim, self.per_rank_value_dim)?; + + let q = q.reshape(( + batch_size, + seq_len, + self.per_rank_num_k_heads, + self.head_k_dim, + ))?; + let k = k.reshape(( + batch_size, + seq_len, + self.per_rank_num_k_heads, + self.head_k_dim, + ))?; + let v = v.reshape(( + batch_size, + seq_len, + self.per_rank_num_v_heads, + self.head_v_dim, + ))?; + + // ----- beta + g (per-V-head, per-token). 
----- + let beta = candle_nn::ops::sigmoid(&b)?; + let a_log_f32 = self.a_log.to_dtype(DType::F32)?; + let neg_a_exp = a_log_f32.exp()?.neg()?; // (per_rank_num_v_heads,) + let dt_b_f32 = self.dt_bias.to_dtype(DType::F32)?; + let a_f32 = a.to_dtype(DType::F32)?; + let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?; + let softplus_a = softplus(&a_plus_dt)?; + let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?; + let g = neg_a_exp_b.broadcast_mul(&softplus_a)?; // F32 + + // ----- GQA expansion if per-rank ratio > 1. ----- + let (q, k) = if self.per_rank_num_v_heads > self.per_rank_num_k_heads { + let rep = self.per_rank_num_v_heads / self.per_rank_num_k_heads; + ( + repeat_interleave(&q, rep, 2)?, + repeat_interleave(&k, rep, 2)?, + ) + } else { + (q, k) + }; + + // ----- L2norm on q, k. ----- + let q = l2norm(&q, 1e-6)?; + let k = l2norm(&k, 1e-6)?; + + // ----- Transpose to (B, H, L, D) for delta-rule loop. ----- + let q = q.transpose(1, 2)?.contiguous()?; + let k = k.transpose(1, 2)?.contiguous()?; + let v = v.transpose(1, 2)?.contiguous()?; + let g = g.transpose(1, 2)?.contiguous()?; + let beta = beta.transpose(1, 2)?.contiguous()?; + + let scale = 1.0_f64 / (self.head_k_dim as f64).sqrt(); + let q = (q.to_dtype(DType::F32)? 
* scale)?; + let k = k.to_dtype(DType::F32)?; + let v = v.to_dtype(DType::F32)?; + let beta = beta.to_dtype(DType::F32)?; + + let mut state = match self.state.recurrent_state.take() { + Some(s) => s.to_dtype(DType::F32)?, + None => Tensor::zeros( + ( + batch_size, + self.per_rank_num_v_heads, + self.head_k_dim, + self.head_v_dim, + ), + DType::F32, + &device, + )?, + }; + + let mut outputs: Vec = Vec::with_capacity(seq_len); + for t in 0..seq_len { + let q_t = q.i((.., .., t, ..))?; + let k_t = k.i((.., .., t, ..))?; + let v_t = v.i((.., .., t, ..))?; + let g_t = g.i((.., .., t))?; + let beta_t = beta.i((.., .., t))?; + + let decay = g_t.exp()?.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?; + state = state.broadcast_mul(&decay)?; + + let k_col = k_t.unsqueeze(D::Minus1)?; + let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?; + + let beta_col = beta_t.unsqueeze(D::Minus1)?; + let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?; + + let delta_row = delta.unsqueeze(2)?; + let outer = k_col.broadcast_mul(&delta_row)?; + state = (state + outer)?; + + let q_col = q_t.unsqueeze(D::Minus1)?; + let out_t = state.broadcast_mul(&q_col)?.sum(2)?; + outputs.push(out_t.unsqueeze(2)?); + } + self.state.recurrent_state = Some(state.to_dtype(dtype)?); + + let core_attn_out = Tensor::cat(&outputs, 2)?; + let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; + let core_attn_out = core_attn_out.to_dtype(dtype)?; + let core_attn_flat = core_attn_out.reshape(( + batch_size * seq_len * self.per_rank_num_v_heads, + self.head_v_dim, + ))?; + let z_flat = z.reshape(( + batch_size * seq_len * self.per_rank_num_v_heads, + self.head_v_dim, + ))?; + let normed = self.norm.forward(&core_attn_flat, &z_flat)?; + let normed = normed.reshape(( + batch_size, + seq_len, + self.per_rank_num_v_heads * self.head_v_dim, + ))?; + + // Row-parallel out_proj + AllReduce. 
+ self.out_proj.forward(&normed) + } +} + +// ─── full-attention layer ─────────────────────────────────────────── + +pub(crate) struct TpQwen3_5Attention { + q_proj: ColumnParallelLinear, // output = 2 * num_heads * head_dim + k_proj: ColumnParallelLinear, + v_proj: ColumnParallelLinear, + o_proj: RowParallelLinear, + q_norm: Qwen3_5RmsNorm, + k_norm: Qwen3_5RmsNorm, + per_rank_num_heads: usize, + per_rank_num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + per_rank_hidden_size: usize, + rotary: Arc, + kv_cache: ConcatKvCache, +} + +impl TpQwen3_5Attention { + #[cfg(feature = "cuda")] + pub fn load( + cfg: &TextConfig, + rotary: Arc, + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + comm: Arc, + ) -> Result { + Self::load_inner(cfg, rotary, vb, rank, world_size, comm) + } + + #[cfg(not(feature = "cuda"))] + pub fn load( + cfg: &TextConfig, + rotary: Arc, + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + ) -> Result { + Self::load_inner(cfg, rotary, vb, rank, world_size) + } + + fn load_inner( + cfg: &TextConfig, + rotary: Arc, + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + #[cfg(feature = "cuda")] comm: Arc, + ) -> Result { + let ws = world_size as usize; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + if !num_heads.is_multiple_of(ws) { + bail!("num_attention_heads ({num_heads}) not divisible by world_size {ws}"); + } + if !num_kv_heads.is_multiple_of(ws) { + bail!("num_key_value_heads ({num_kv_heads}) not divisible by world_size {ws}"); + } + let per_rank_num_heads = num_heads / ws; + let per_rank_num_kv_heads = num_kv_heads / ws; + let num_kv_groups = per_rank_num_heads / per_rank_num_kv_heads; + let head_dim = cfg.head_dim; + let per_rank_hidden_size = head_dim * per_rank_num_heads; + + // q_proj has 2x output width (query + gate halves). 
Column-parallel + // sharding along the output (head) axis splits both halves + // consistently — rank R holds heads `[R*per_rank, (R+1)*per_rank)` + // for both query AND gate, so the post-attention `gate.sigmoid()` + // multiply against the per-rank attention output matches up. + let q_proj = ColumnParallelLinear::load(&vb.pp("q_proj"), rank, world_size)?; + let k_proj = ColumnParallelLinear::load(&vb.pp("k_proj"), rank, world_size)?; + let v_proj = ColumnParallelLinear::load(&vb.pp("v_proj"), rank, world_size)?; + #[cfg(feature = "cuda")] + let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size, comm)?; + #[cfg(not(feature = "cuda"))] + let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size)?; + + let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?; + let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?; + + let kv_cache = ConcatKvCache::new(2); + + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + q_norm, + k_norm, + per_rank_num_heads, + per_rank_num_kv_heads, + num_kv_groups, + head_dim, + per_rank_hidden_size, + rotary, + kv_cache, + }) + } + + pub fn forward( + &mut self, + x: &Tensor, + attn_mask: Option<&Tensor>, + offset: usize, + ) -> candle_core::Result { + let (b, l, _) = x.dims3()?; + + // 1. q_proj outputs (B, L, per_rank_num_heads * head_dim * 2) + // — split into (query, gate) per rank. + let q_raw = + self.q_proj + .forward(x)? + .reshape((b, l, self.per_rank_num_heads, self.head_dim * 2))?; + let q = q_raw.narrow(3, 0, self.head_dim)?; + let gate = q_raw.narrow(3, self.head_dim, self.head_dim)?; + let gate = gate + .contiguous()? + .reshape((b, l, self.per_rank_num_heads * self.head_dim))?; + + let q = self.q_norm.forward(&q.contiguous()?)?; + let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D) + + let k = + self.k_proj + .forward(x)? 
+ .reshape((b, l, self.per_rank_num_kv_heads, self.head_dim))?; + let k = self.k_norm.forward(&k.contiguous()?)?; + let k = k.transpose(1, 2)?.contiguous()?; + + let v = self + .v_proj + .forward(x)? + .reshape((b, l, self.per_rank_num_kv_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + + let (q, k) = self.rotary.apply(&q, &k, offset)?; + let (k, v) = self.kv_cache.append(&k, &v)?; + let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?; + let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?; + + let scale = 1.0_f64 / (self.head_dim as f64).sqrt(); + let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?; + if let Some(m) = attn_mask { + scores = scores.broadcast_add(m)?; + } + let probs = candle_nn::ops::softmax_last_dim(&scores)?; + let ctx = probs.matmul(&v)?; + + let ctx = ctx + .transpose(1, 2)? + .contiguous()? + .reshape((b, l, self.per_rank_hidden_size))?; + let gate_sig = candle_nn::ops::sigmoid(&gate)?; + let gated = (ctx * gate_sig)?; + self.o_proj.forward(&gated) + } + + pub fn clear_kv_cache(&mut self) { + self.kv_cache.reset(); + } +} + +// ─── MLP ──────────────────────────────────────────────────────────── + +pub(crate) struct TpQwen3_5MLP { + gate_proj: ColumnParallelLinear, + up_proj: ColumnParallelLinear, + down_proj: RowParallelLinear, +} + +impl TpQwen3_5MLP { + #[cfg(feature = "cuda")] + pub fn load( + cfg: &TextConfig, + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + comm: Arc, + ) -> Result { + if !cfg.intermediate_size.is_multiple_of(world_size as usize) { + bail!( + "intermediate_size {} not divisible by world_size {}", + cfg.intermediate_size, + world_size + ); + } + Ok(Self { + gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?, + up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?, + down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size, comm)?, + }) + } + + #[cfg(not(feature = "cuda"))] + pub fn load( + cfg: &TextConfig, + vb: 
&ShardedVarBuilder, + rank: u32, + world_size: u32, + ) -> Result { + if !cfg.intermediate_size.is_multiple_of(world_size as usize) { + bail!( + "intermediate_size {} not divisible by world_size {}", + cfg.intermediate_size, + world_size + ); + } + Ok(Self { + gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?, + up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?, + down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size)?, + }) + } +} + +impl Module for TpQwen3_5MLP { + fn forward(&self, x: &Tensor) -> candle_core::Result { + let lhs = candle_nn::ops::silu(&self.gate_proj.forward(x)?)?; + let rhs = self.up_proj.forward(x)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +// ─── decoder layer ────────────────────────────────────────────────── + +enum TpAttentionKind { + Full(TpQwen3_5Attention), + Linear(TpQwen3_5GatedDeltaNet), +} + +pub struct TpQwen3_5DecoderLayer { + input_layernorm: Qwen3_5RmsNorm, + post_attention_layernorm: Qwen3_5RmsNorm, + mlp: TpQwen3_5MLP, + attention: TpAttentionKind, +} + +impl TpQwen3_5DecoderLayer { + #[cfg(feature = "cuda")] + pub fn load( + cfg: &TextConfig, + rotary: Arc, + layer_idx: usize, + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + comm: Arc, + ) -> Result { + let layer_type = cfg + .layer_types + .get(layer_idx) + .map(String::as_str) + .ok_or_else(|| anyhow::anyhow!("layer_types[{layer_idx}] missing"))?; + let attention = match layer_type { + "full_attention" => TpAttentionKind::Full(TpQwen3_5Attention::load( + cfg, + rotary, + &vb.pp("self_attn"), + rank, + world_size, + comm.clone(), + )?), + "linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load( + cfg, + &vb.pp("linear_attn"), + rank, + world_size, + comm.clone(), + )?), + other => bail!("unknown layer_type '{other}' for layer {layer_idx}"), + }; + let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm)?; + let input_layernorm = + 
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?; + let post_attention_layernorm = Qwen3_5RmsNorm::load( + &vb.pp("post_attention_layernorm"), + cfg.hidden_size, + cfg.rms_norm_eps, + )?; + Ok(Self { + input_layernorm, + post_attention_layernorm, + mlp, + attention, + }) + } + + #[cfg(not(feature = "cuda"))] + pub fn load( + cfg: &TextConfig, + rotary: Arc, + layer_idx: usize, + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + ) -> Result { + let layer_type = cfg + .layer_types + .get(layer_idx) + .map(String::as_str) + .ok_or_else(|| anyhow::anyhow!("layer_types[{layer_idx}] missing"))?; + let attention = match layer_type { + "full_attention" => TpAttentionKind::Full(TpQwen3_5Attention::load( + cfg, + rotary, + &vb.pp("self_attn"), + rank, + world_size, + )?), + "linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load( + cfg, + &vb.pp("linear_attn"), + rank, + world_size, + )?), + other => bail!("unknown layer_type '{other}' for layer {layer_idx}"), + }; + let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size)?; + let input_layernorm = + Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?; + let post_attention_layernorm = Qwen3_5RmsNorm::load( + &vb.pp("post_attention_layernorm"), + cfg.hidden_size, + cfg.rms_norm_eps, + )?; + Ok(Self { + input_layernorm, + post_attention_layernorm, + mlp, + attention, + }) + } + + pub fn forward( + &mut self, + x: &Tensor, + attn_mask: Option<&Tensor>, + offset: usize, + ) -> candle_core::Result { + let h = self.input_layernorm.forward(x)?; + let attn_out = match &mut self.attention { + TpAttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?, + TpAttentionKind::Linear(net) => net.forward(&h)?, + }; + let x = (x + attn_out)?; + let h2 = self.post_attention_layernorm.forward(&x)?; + let h2 = self.mlp.forward(&h2)?; + x + h2 + } + + pub fn clear_kv_cache(&mut self) { + match &mut self.attention { + 
TpAttentionKind::Full(a) => a.clear_kv_cache(), + TpAttentionKind::Linear(n) => n.clear_kv_cache(), + } + } +} + +// ─── base Model ───────────────────────────────────────────────────── + +pub struct TpQwen3_5Model { + embed_tokens: Embedding, + layers: Vec, + norm: Qwen3_5RmsNorm, + device: Device, + dtype: DType, +} + +impl TpQwen3_5Model { + #[cfg(feature = "cuda")] + pub fn load( + cfg: &TextConfig, + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + comm: Arc, + ) -> Result { + let dtype = vb.dtype(); + let device = vb.device().clone(); + let text_vb = vb.pp("model.language_model"); + + let embed_weight = load_replicated( + &text_vb.pp("embed_tokens"), + (cfg.vocab_size, cfg.hidden_size), + "weight", + )?; + let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size); + + let rotary = Arc::new(RotaryEmbedding::new(dtype, cfg, &device)?); + + if cfg.layer_types.len() != cfg.num_hidden_layers { + bail!( + "layer_types must have num_hidden_layers ({}) entries; got {}", + cfg.num_hidden_layers, + cfg.layer_types.len() + ); + } + + let vb_l = text_vb.pp("layers"); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + for i in 0..cfg.num_hidden_layers { + layers.push(TpQwen3_5DecoderLayer::load( + cfg, + rotary.clone(), + i, + &vb_l.pp(i), + rank, + world_size, + comm.clone(), + )?); + } + + let norm = Qwen3_5RmsNorm::load(&text_vb.pp("norm"), cfg.hidden_size, cfg.rms_norm_eps)?; + + Ok(Self { + embed_tokens, + layers, + norm, + device, + dtype, + }) + } + + #[cfg(not(feature = "cuda"))] + pub fn load( + cfg: &TextConfig, + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + ) -> Result { + let dtype = vb.dtype(); + let device = vb.device().clone(); + let text_vb = vb.pp("model.language_model"); + + let embed_weight = load_replicated( + &text_vb.pp("embed_tokens"), + (cfg.vocab_size, cfg.hidden_size), + "weight", + )?; + let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size); + + let rotary = 
Arc::new(RotaryEmbedding::new(dtype, cfg, &device)?); + + if cfg.layer_types.len() != cfg.num_hidden_layers { + bail!( + "layer_types must have num_hidden_layers ({}) entries; got {}", + cfg.num_hidden_layers, + cfg.layer_types.len() + ); + } + + let vb_l = text_vb.pp("layers"); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + for i in 0..cfg.num_hidden_layers { + layers.push(TpQwen3_5DecoderLayer::load( + cfg, + rotary.clone(), + i, + &vb_l.pp(i), + rank, + world_size, + )?); + } + + let norm = Qwen3_5RmsNorm::load(&text_vb.pp("norm"), cfg.hidden_size, cfg.rms_norm_eps)?; + + Ok(Self { + embed_tokens, + layers, + norm, + device, + dtype, + }) + } + + pub fn embed_weight(&self) -> &Tensor { + self.embed_tokens.embeddings() + } + + pub fn clear_kv_cache(&mut self) { + for l in &mut self.layers { + l.clear_kv_cache(); + } + } + + /// Build the additive causal mask for `tgt` new positions attending over + /// `tgt + offset` keys: entry `(i, j)` is `0` when `j <= i + offset` and + /// `-inf` otherwise. Returns a `(b, 1, tgt, tgt + offset)` tensor in + /// `self.dtype` on `self.device`. + /// + /// The flat buffer holds exactly one `tgt x (tgt + offset)` grid, so the + /// tensor is created with a batch dimension of 1 and then broadcast to `b`; + /// shaping the buffer directly as `(b, ..)` only has enough elements when + /// `b == 1`. + fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result { + let minf = f32::NEG_INFINITY; + let mask: Vec<_> = (0..tgt) + .flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf })) + .collect(); + // Same mask for every batch row: shape the single grid, then broadcast. + Tensor::from_slice(&mask, (1, 1, tgt, tgt + offset), &self.device)? + .to_dtype(self.dtype)? + .broadcast_as((b, 1, tgt, tgt + offset)) + } + + pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result { + let (b, l) = input.dims2()?; + let mut h = self.embed_tokens.forward(input)?; + let causal = if l == 1 { + None + } else { + Some(self.causal_mask(b, l, offset)?) 
+ }; + for layer in &mut self.layers { + h = layer.forward(&h, causal.as_ref(), offset)?; + } + self.norm.forward(&h) + } +} + +pub struct TpQwen3_5ForCausalLM { + base: TpQwen3_5Model, + lm_head: Linear, +} + +impl TpQwen3_5ForCausalLM { + #[cfg(feature = "cuda")] + pub fn load( + config: Config, + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + comm: Arc, + ) -> Result { + let cfg = &config.text_config; + let base = TpQwen3_5Model::load(cfg, vb, rank, world_size, comm)?; + let lm_head = build_lm_head(cfg, vb, &base)?; + Ok(Self { base, lm_head }) + } + + #[cfg(not(feature = "cuda"))] + pub fn load( + config: Config, + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + ) -> Result { + let cfg = &config.text_config; + let base = TpQwen3_5Model::load(cfg, vb, rank, world_size)?; + let lm_head = build_lm_head(cfg, vb, &base)?; + Ok(Self { base, lm_head }) + } + + pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result { + let (_, l) = input.dims2()?; + let hidden = self.base.forward(input, offset)?; + hidden.i((.., l - 1.., ..))?.apply(&self.lm_head) + } + + pub fn clear_kv_cache(&mut self) { + self.base.clear_kv_cache(); + } + + pub fn device(&self) -> &Device { + &self.base.device + } +} + +fn build_lm_head( + cfg: &TextConfig, + vb: &ShardedVarBuilder, + base: &TpQwen3_5Model, +) -> Result { + if cfg.tie_word_embeddings { + Ok(Linear::new(base.embed_weight().clone(), None)) + } else { + // lm_head sits at the top level (sibling of `model.*`), NOT + // under `model.language_model`. + let weight = load_replicated( + &vb.pp("lm_head"), + (cfg.vocab_size, cfg.hidden_size), + "weight", + )?; + Ok(Linear::new(weight, None)) + } +} + +// ─── load helpers ─────────────────────────────────────────────────── + +/// Load a tensor that's the SAME on every rank by asking the +/// ShardedVarBuilder with the default `Shard { world_size: 1 }` hint +/// (which falls through to the unsharded backend). 
+fn load_replicated>( + vb: &ShardedVarBuilder, + shape: S, + name: &str, +) -> Result { + vb.get(shape, name) + .with_context(|| format!("load replicated '{}/{name}'", vb.prefix())) +} + +/// Load a fused QKV-style 2D weight tensor that stores three regions +/// sequentially along dim 0: `[first key_dim, second key_dim, value_dim]`. +/// Returns the per-rank slice formed by extracting the rank's share +/// from each region and concatenating along dim 0. +/// +/// The full tensor materialises briefly on the device before the +/// slices are extracted (`narrow` views + `contiguous` copy). Memory +/// peak is one full-tensor load per layer during construction; only +/// the per-rank concatenation stays after `full` drops. +#[allow(clippy::too_many_arguments)] +fn load_fused_qkv_slice_2d( + vb: &ShardedVarBuilder, + name: &str, + conv_dim: usize, + hidden_size: usize, + key_dim: usize, + value_dim: usize, + rank: u32, + world_size: u32, +) -> Result { + let ws = world_size as usize; + let r = rank as usize; + if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) { + bail!( + "fused qkv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \ + must each be divisible by world_size ({ws})" + ); + } + let per_rank_key = key_dim / ws; + let per_rank_value = value_dim / ws; + + // Force full-tensor load via `vb.get`, which defaults to + // `Shard { world_size: 1 }` and falls through to SimpleBackend. + let full = vb + .pp(name) + .get((conv_dim, hidden_size), "weight") + .with_context(|| format!("load fused qkv '{}/{}/weight'", vb.prefix(), name))?; + + let q = full.narrow(0, r * per_rank_key, per_rank_key)?; + let k = full.narrow(0, key_dim + r * per_rank_key, per_rank_key)?; + let v = full.narrow(0, 2 * key_dim + r * per_rank_value, per_rank_value)?; + + Tensor::cat(&[&q, &k, &v], 0)? 
+ .contiguous() + .with_context(|| format!("materialise fused qkv slice for rank {r}")) +} + +/// Same per-region slicing pattern for a 3D fused tensor (the depthwise +/// `conv1d.weight` of the linear-attention block: shape +/// `(conv_dim, 1, kernel_size)`). +fn load_fused_qkv_slice_3d( + vb: &ShardedVarBuilder, + shape: (usize, usize, usize), + key_dim: usize, + value_dim: usize, + rank: u32, + world_size: u32, +) -> Result { + let (conv_dim, mid, kernel_size) = shape; + let ws = world_size as usize; + let r = rank as usize; + if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) { + bail!( + "fused conv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \ + must each be divisible by world_size ({ws})" + ); + } + let per_rank_key = key_dim / ws; + let per_rank_value = value_dim / ws; + + let full = vb + .get((conv_dim, mid, kernel_size), "weight") + .with_context(|| format!("load fused conv '{}/weight'", vb.prefix()))?; + + let q = full.narrow(0, r * per_rank_key, per_rank_key)?; + let k = full.narrow(0, key_dim + r * per_rank_key, per_rank_key)?; + let v = full.narrow(0, 2 * key_dim + r * per_rank_value, per_rank_value)?; + + Tensor::cat(&[&q, &k, &v], 0)? + .contiguous() + .with_context(|| format!("materialise fused conv slice for rank {r}")) +} diff --git a/crates/neuron/src/harness/tp/worker.rs b/crates/neuron/src/harness/tp/worker.rs index d7d0ca7..82267c5 100644 --- a/crates/neuron/src/harness/tp/worker.rs +++ b/crates/neuron/src/harness/tp/worker.rs @@ -21,6 +21,46 @@ use super::rpc::{WorkerRequest, WorkerResponse}; #[cfg(feature = "cuda")] use super::tp_qwen3::TpQwen3ForCausalLM; +#[cfg(feature = "cuda")] +use super::tp_qwen3_5::TpQwen3_5ForCausalLM; + +/// Worker-side discriminator over the architectures we can load via +/// `LoadDenseShard`. Mirrors `super::TpLeaderModel` on the leader +/// side — the dispatch happens on the `model_type` extracted from the +/// config JSON. 
+#[cfg(feature = "cuda")] +enum WorkerModel { + Qwen3(TpQwen3ForCausalLM), + Qwen3_5(TpQwen3_5ForCausalLM), +} + +#[cfg(feature = "cuda")] +impl WorkerModel { + fn forward( + &mut self, + input: &candle_core::Tensor, + offset: usize, + ) -> candle_core::Result { + match self { + WorkerModel::Qwen3(m) => m.forward(input, offset), + WorkerModel::Qwen3_5(m) => m.forward(input, offset), + } + } + + fn clear_kv_cache(&mut self) { + match self { + WorkerModel::Qwen3(m) => m.clear_kv_cache(), + WorkerModel::Qwen3_5(m) => m.clear_kv_cache(), + } + } + + fn device(&self) -> &candle_core::Device { + match self { + WorkerModel::Qwen3(m) => m.device(), + WorkerModel::Qwen3_5(m) => m.device(), + } + } +} #[derive(Debug, Clone, Copy)] pub struct WorkerConfig { @@ -84,12 +124,13 @@ async fn write_response(stdout: &mut tokio::io::Stdout, resp: &WorkerResponse) - struct WorkerState { config: WorkerConfig, nccl: NcclState, - /// Loaded model shards keyed by `model_id`. Each entry holds this - /// rank's `TpQwen3ForCausalLM` — the column/row-parallel layers - /// hold an `Arc` cloned from `nccl`. Cuda-only: there is no - /// TpQwen3ForCausalLM type without the cuda feature in scope. + /// Loaded model shards keyed by `model_id`. Each entry wraps the + /// rank's TP architecture handle (Qwen3 or Qwen3-Next) — the + /// column/row-parallel layers hold an `Arc` cloned from + /// `nccl`. Cuda-only: the underlying types reference cudarc types + /// that don't exist without the cuda feature. #[cfg(feature = "cuda")] - models: HashMap, + models: HashMap, /// Placeholder so the non-cuda build keeps the same field name set /// and `WorkerState::new` reads the same on both. 
#[cfg(not(feature = "cuda"))] @@ -138,6 +179,7 @@ impl WorkerState { config_json: String, safetensors_paths: Vec, ) -> WorkerResponse { + use crate::harness::arch::qwen3_5 as qwen3_5_arch; use candle_core::{DType, Device}; use candle_nn::var_builder::ShardedSafeTensors; use candle_transformers::models::qwen3 as qwen3_dense; @@ -159,15 +201,14 @@ impl WorkerState { } }; - let cfg: qwen3_dense::Config = match serde_json::from_str(&config_json) { - Ok(c) => c, - Err(e) => { - return WorkerResponse::Error { - kind: "bad_request".into(), - message: format!("parse Qwen3 Config JSON: {e}"), - }; - } - }; + // Peek at model_type so we know which architecture to build. + let model_type = serde_json::from_str::(&config_json) + .ok() + .as_ref() + .and_then(|v| v.get("model_type")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); let device = match Device::new_cuda(self.config.cuda_device as usize) { Ok(d) => d, @@ -191,24 +232,77 @@ impl WorkerState { }; } }; - let model = match TpQwen3ForCausalLM::load( - &cfg, - &vb, - self.config.rank, - self.config.world_size, - comm, - ) { - Ok(m) => m, - Err(e) => { + + let loaded = match model_type.as_str() { + "qwen3" => { + let cfg: qwen3_dense::Config = match serde_json::from_str(&config_json) { + Ok(c) => c, + Err(e) => { + return WorkerResponse::Error { + kind: "bad_request".into(), + message: format!("parse Qwen3 Config JSON: {e}"), + }; + } + }; + match TpQwen3ForCausalLM::load( + &cfg, + &vb, + self.config.rank, + self.config.world_size, + comm, + ) { + Ok(m) => WorkerModel::Qwen3(m), + Err(e) => { + return WorkerResponse::Error { + kind: "load_failed".into(), + message: format!("TpQwen3ForCausalLM::load: {e:#}"), + }; + } + } + } + "qwen3_5" => { + let cfg: qwen3_5_arch::Config = match serde_json::from_str(&config_json) { + Ok(c) => c, + Err(e) => { + return WorkerResponse::Error { + kind: "bad_request".into(), + message: format!("parse Qwen3-Next Config JSON: {e}"), + }; + } + }; + match 
TpQwen3_5ForCausalLM::load( + cfg, + &vb, + self.config.rank, + self.config.world_size, + comm, + ) { + Ok(m) => WorkerModel::Qwen3_5(m), + Err(e) => { + return WorkerResponse::Error { + kind: "load_failed".into(), + message: format!("TpQwen3_5ForCausalLM::load: {e:#}"), + }; + } + } + } + other => { return WorkerResponse::Error { - kind: "load_failed".into(), - message: format!("TpQwen3ForCausalLM::load: {e:#}"), + kind: "unsupported_arch".into(), + message: format!( + "worker: unsupported model_type '{other}' (supported: qwen3, qwen3_5)" + ), }; } }; - self.models.insert(model_id.clone(), model); - tracing::info!(rank = self.config.rank, model = %model_id, "loaded TP shard"); + self.models.insert(model_id.clone(), loaded); + tracing::info!( + rank = self.config.rank, + model = %model_id, + model_type = %model_type, + "loaded TP shard" + ); WorkerResponse::LoadDenseShardOk } @@ -256,7 +350,7 @@ impl WorkerState { if let Err(e) = model.forward(&input, offset) { return WorkerResponse::Error { kind: "forward_failed".into(), - message: format!("TpQwen3ForCausalLM::forward: {e}"), + message: format!("TP forward: {e}"), }; } WorkerResponse::GenerateStepOk