feat(stage-8c): linear-attention layer (Qwen3-Next GatedDeltaNet)
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 39s
CI / Format (push) Successful in 38s
CI / Clippy (push) Successful in 2m17s
build-prerelease / Build neuron-blackwell (push) Successful in 3m48s
CI / Test (push) Successful in 5m1s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m36s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 5m13s
build-prerelease / Build neuron-ada (push) Successful in 4m39s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m55s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m57s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m4s

Implements the recurrent-path Gated DeltaNet block that occupies 48 of
Qwen3.6's 64 decoder layers (`layer_types[i] == "linear_attention"`).
Ported from `huggingface/transformers/models/qwen3_5/modeling_qwen3_5.py`
(`Qwen3_5GatedDeltaNet`, `torch_recurrent_gated_delta_rule`,
`Qwen3_5RMSNormGated`, `l2norm`).

Layout: `arch/qwen3_5.rs` becomes `arch/qwen3_5/` with submodules
- `mod.rs`         — Config + (still-stub) ForCausalLM
- `linear_attn.rs` — GatedDeltaNet + GatedDeltaNetState
- `rmsnorm.rs`     — Qwen3_5RmsNorm `(1+w)*x`, Qwen3_5RmsNormGated, l2norm

Architecture pieces in this commit:
- Block: in_proj_qkv + in_proj_z + in_proj_b + in_proj_a + out_proj
  (all bias=False); depthwise causal Conv1d (k=4) with state-aware
  prepend; SiLU; per-head reshape; L2norm on q,k.
- Discretisation: g = -exp(A_log) * softplus(a + dt_bias); beta = σ(b).
  All computed in f32 to avoid the -inf underflow in fp16 that the
  reference notes.
- Delta rule (recurrent, per-token):
    state *= exp(g_t)
    kv_mem = state^T · k_t
    delta  = (v_t - kv_mem) * beta_t
    state += outer(k_t, delta)
    out_t  = state^T · q_t
- Output: RMSNormGated(core_attn_out, z) reshape out_proj.

State (`GatedDeltaNetState`) lives inline on the layer:
- conv_state: (B, conv_dim, conv_kernel_size) — left-padded tail.
- recurrent_state: (B, num_v_heads, head_k_dim, head_v_dim) — the
  delta-rule outer-product memory.
Cleared via `clear_kv_cache` at the start of every new request.

Config extended with the qwen3_5-specific fields:
- linear_num_value_heads (48 in Qwen3.6-27B)
- linear_num_key_heads   (16)
- linear_key_head_dim    (128)
- linear_value_head_dim  (128)
- linear_conv_kernel_dim (4)
- hidden_act             ("silu")

Performance note: this is the **recurrent** delta-rule (PyTorch's
`torch_recurrent_gated_delta_rule`), correct for any seq_len but O(L)
prefill. The chunked algorithm (`torch_chunk_gated_delta_rule`,
chunk_size=64) is a follow-up perf optimisation; surface stays the
same.

8 unit tests:
- softplus small/large branches
- l2norm hand-calc + zero-vector stability
- repeat_interleave round-trip
- forward_smoke on tiny dims (4-head fixture) — verifies shape +
  no NaN/Inf propagation through the f32-promotion pipeline. Doesn't
  validate numerical correctness against the Python reference; that
  requires a fixed-weight fixture and is the next step.

cargo clippy CPU + --features cuda both clean; 32 lib tests pass.
The ForCausalLM stub still bails on forward — wrapping
attention/MLP/decoder layer + lm_head is the next sub-stage.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-20 09:29:52 +03:00
parent a70f317729
commit 180274548d
3 changed files with 727 additions and 0 deletions

View File

@@ -0,0 +1,531 @@
//! Qwen3-Next's `linear_attention` layer: Gated DeltaNet.
//!
//! The recurrent linear-attention block that occupies 3 out of every 4
//! decoder layers in Qwen3.6 (`layer_types[i] == "linear_attention"`).
//! Implemented against the reference Python in
//! `huggingface/transformers/src/transformers/models/qwen3_5/modeling_qwen3_5.py`
//! (class `Qwen3_5GatedDeltaNet`).
//!
//! ## Block structure
//!
//! ```text
//! x ── in_proj_qkv ── transpose ─► (B, conv_dim, L)
//! │
//! ┌──────────────── conv_state ──┤ prepend cached state (decode)
//! ▼
//! depthwise causal Conv1d (k=4) → SiLU
//! │
//! └─ split → q (k_dim), k (k_dim), v (v_dim) ─► per-head reshape
//!
//! x ── in_proj_z ────────────────► z (gate for the output RMSNorm)
//! x ── in_proj_b ── sigmoid ─────► beta (per-head per-token update rate)
//! x ── in_proj_a ── softplus ────► g (decay; see eqn below)
//!
//! g = -exp(A_log) * softplus(a + dt_bias) # discretisation
//! beta = sigmoid(b)
//!
//! (q, k) ─── L2norm ─── delta rule loop ──── core_attn_out
//! (per-token, per-head):
//! state *= exp(g_t)
//! mem = state^T · k_t
//! delta = (v_t - mem) * beta_t
//! state += outer(k_t, delta)
//! out_t = state^T · q_t
//!
//! core_attn_out ── RMSNormGated(z) ── reshape ── out_proj ── y
//! ```
//!
//! ## State
//!
//! Two tensors persist across decode steps:
//! - `conv_state`: `(B, conv_dim, conv_kernel_size)` — left-padded
//! tail of the input to the depthwise conv, so the next causal
//! window has the right left-context.
//! - `recurrent_state`: `(B, num_v_heads, head_k_dim, head_v_dim)` —
//! the delta-rule outer-product memory.
//!
//! Both are cleared via [`GatedDeltaNet::clear_kv_cache`] at the start
//! of every new request.
//!
//! ## Performance note
//!
//! This impl is the **recurrent** delta-rule for both prefill and
//! decode — i.e. the algorithm in `torch_recurrent_gated_delta_rule`.
//! Correctness-first. The chunked algorithm (chunk_size=64) in
//! `torch_chunk_gated_delta_rule` is a perf optimisation for long
//! prefill; can be added later without changing the surface.
use anyhow::{Context, Result};
use candle_core::{IndexOp, Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
use super::TextConfig;
use super::rmsnorm::{Qwen3_5RmsNormGated, l2norm};
/// Per-rank, per-layer state for the linear-attention block.
///
/// `conv_state` is left-padded with zeros on first use; `recurrent_state`
/// is initialised lazily to zeros once we know the batch size.
#[derive(Default)]
pub struct GatedDeltaNetState {
pub conv_state: Option<Tensor>,
pub recurrent_state: Option<Tensor>,
}
pub struct GatedDeltaNet {
// Projections.
in_proj_qkv: Linear,
in_proj_z: Linear,
in_proj_b: Linear,
in_proj_a: Linear,
out_proj: Linear,
// Depthwise causal Conv1d weight; shape (conv_dim, 1, kernel_size).
// No bias (Python sets bias=False).
conv1d_weight: Tensor,
// Per-head discretisation params.
dt_bias: Tensor,
a_log: Tensor,
// Output norm + gate.
norm: Qwen3_5RmsNormGated,
// Shape hyperparams (cached for forward).
num_v_heads: usize,
num_k_heads: usize,
head_k_dim: usize,
head_v_dim: usize,
key_dim: usize,
value_dim: usize,
conv_dim: usize,
conv_kernel_size: usize,
// Recurrent state held inline. Each request resets via
// `clear_kv_cache`; otherwise the state persists across forwards
// and the per-token offset advances naturally.
state: GatedDeltaNetState,
}
impl GatedDeltaNet {
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let num_v_heads = cfg.linear_num_value_heads;
let num_k_heads = cfg.linear_num_key_heads;
let head_k_dim = cfg.linear_key_head_dim;
let head_v_dim = cfg.linear_value_head_dim;
let conv_kernel_size = cfg.linear_conv_kernel_dim;
if num_v_heads == 0 || num_k_heads == 0 {
anyhow::bail!(
"Qwen3-Next linear_num_*_heads must be set; got v={num_v_heads}, k={num_k_heads}"
);
}
if !num_v_heads.is_multiple_of(num_k_heads) {
anyhow::bail!(
"linear_num_value_heads ({num_v_heads}) must be a multiple of \
linear_num_key_heads ({num_k_heads}) for GQA-style head expansion"
);
}
let key_dim = head_k_dim * num_k_heads;
let value_dim = head_v_dim * num_v_heads;
let conv_dim = key_dim * 2 + value_dim;
// ----- Linear projections (all `bias=False` in the reference). -----
let in_proj_qkv = load_linear_no_bias(vb, "in_proj_qkv", cfg.hidden_size, conv_dim)?;
let in_proj_z = load_linear_no_bias(vb, "in_proj_z", cfg.hidden_size, value_dim)?;
let in_proj_b = load_linear_no_bias(vb, "in_proj_b", cfg.hidden_size, num_v_heads)?;
let in_proj_a = load_linear_no_bias(vb, "in_proj_a", cfg.hidden_size, num_v_heads)?;
let out_proj = load_linear_no_bias(vb, "out_proj", value_dim, cfg.hidden_size)?;
// ----- Conv1d weight (depthwise, bias=False). -----
let conv1d_weight = vb
.pp("conv1d")
.get((conv_dim, 1, conv_kernel_size), "weight")
.with_context(|| format!("load '{}/conv1d/weight'", vb.prefix()))?;
// ----- dt_bias + A_log: per-head 1D params. -----
let dt_bias = vb
.get(num_v_heads, "dt_bias")
.with_context(|| format!("load '{}/dt_bias'", vb.prefix()))?;
let a_log = vb
.get(num_v_heads, "A_log")
.with_context(|| format!("load '{}/A_log'", vb.prefix()))?;
// ----- Output gated RMSNorm (per-head_v_dim). -----
let norm = Qwen3_5RmsNormGated::load(&vb.pp("norm"), head_v_dim, cfg.rms_norm_eps)?;
Ok(Self {
in_proj_qkv,
in_proj_z,
in_proj_b,
in_proj_a,
out_proj,
conv1d_weight,
dt_bias,
a_log,
norm,
num_v_heads,
num_k_heads,
head_k_dim,
head_v_dim,
key_dim,
value_dim,
conv_dim,
conv_kernel_size,
state: GatedDeltaNetState::default(),
})
}
pub fn clear_kv_cache(&mut self) {
self.state = GatedDeltaNetState::default();
}
/// `x` shape: `(B, L, hidden_size)`. Returns the same shape.
pub fn forward(&mut self, x: &Tensor) -> candle_core::Result<Tensor> {
let (batch_size, seq_len, _) = x.dims3()?;
let dtype = x.dtype();
let device = x.device().clone();
// ----- Projections. -----
// mixed_qkv: (B, L, conv_dim)
let mixed_qkv = self.in_proj_qkv.forward(x)?;
// (B, conv_dim, L) for the conv1d.
let mixed_qkv_chw = mixed_qkv.transpose(1, 2)?.contiguous()?;
// z: (B, L, value_dim) → (B, L, num_v_heads, head_v_dim)
let z = self.in_proj_z.forward(x)?.reshape((
batch_size,
seq_len,
self.num_v_heads,
self.head_v_dim,
))?;
// b, a: (B, L, num_v_heads)
let b = self.in_proj_b.forward(x)?;
let a = self.in_proj_a.forward(x)?;
// ----- Depthwise causal Conv1d + SiLU (with state continuation). -----
// If the previous step left a `conv_state`, prepend it so the
// causal kernel window sees the correct left-context.
let prepended = match &self.state.conv_state {
Some(prev) => Tensor::cat(&[prev, &mixed_qkv_chw], 2)?,
None => mixed_qkv_chw.clone(),
};
let prep_len = prepended.dims()[2];
// Update conv_state: keep the last `conv_kernel_size` columns
// of the (possibly prepended) sequence. If the sequence is
// shorter than `conv_kernel_size` (very-short prefill or first
// decode step before warmup), left-pad with zeros.
let new_state = if prep_len >= self.conv_kernel_size {
prepended.narrow(2, prep_len - self.conv_kernel_size, self.conv_kernel_size)?
} else {
let pad = Tensor::zeros(
(batch_size, self.conv_dim, self.conv_kernel_size - prep_len),
dtype,
&device,
)?;
Tensor::cat(&[&pad, &prepended], 2)?
};
self.state.conv_state = Some(new_state);
// Apply the depthwise conv with padding=kernel-1 (so output
// length = input + kernel - 1), then trim back to `prep_len`.
// Matches the reference Python which calls the same nn.Conv1d
// with its baked-in padding and slices `[..., :input_len]`.
let conv_out = prepended.conv1d(
&self.conv1d_weight,
self.conv_kernel_size - 1,
1,
1,
self.conv_dim,
)?;
let conv_out = conv_out.narrow(2, 0, prep_len)?;
let conv_out = candle_nn::ops::silu(&conv_out)?;
// Keep only the last L outputs (drop the prepended-state contribution).
let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?;
// Back to (B, L, conv_dim).
let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?;
// ----- Split into q, k, v. -----
let q = mixed_qkv.narrow(2, 0, self.key_dim)?;
let k = mixed_qkv.narrow(2, self.key_dim, self.key_dim)?;
let v = mixed_qkv.narrow(2, 2 * self.key_dim, self.value_dim)?;
let q = q.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
let k = k.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
let v = v.reshape((batch_size, seq_len, self.num_v_heads, self.head_v_dim))?;
// ----- beta + g (per-head, per-token gates). -----
// beta = sigmoid(b)
let beta = candle_nn::ops::sigmoid(&b)?;
// g = -exp(A_log) * softplus(a + dt_bias)
// Promote everything to f32 — the Python does the same to
// avoid underflow on the -exp path.
let a_log_f32 = self.a_log.to_dtype(candle_core::DType::F32)?;
let neg_a_exp = a_log_f32.exp()?.neg()?; // (num_v_heads,)
let dt_b_f32 = self.dt_bias.to_dtype(candle_core::DType::F32)?;
let a_f32 = a.to_dtype(candle_core::DType::F32)?;
// a is (B, L, num_v_heads); broadcast-add dt_bias.
let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?;
let softplus = softplus(&a_plus_dt)?;
// (1, 1, num_v_heads) × (B, L, num_v_heads).
let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?;
let g = neg_a_exp_b.broadcast_mul(&softplus)?;
// ----- GQA-style key expansion if num_v_heads > num_k_heads. -----
let (q, k) = if self.num_v_heads > self.num_k_heads {
let rep = self.num_v_heads / self.num_k_heads;
(
repeat_interleave(&q, rep, 2)?,
repeat_interleave(&k, rep, 2)?,
)
} else {
(q, k)
};
// ----- L2-norm on q, k (use_qk_l2norm_in_kernel=True in ref). -----
let q = l2norm(&q, 1e-6)?;
let k = l2norm(&k, 1e-6)?;
// ----- Recurrent delta rule. -----
// Inputs: q, k (B, L, H, D_k); v (B, L, H, D_v); g (B, L, H); beta (B, L, H).
// The reference transposes to (B, H, L, D) before the loop. We
// do the same — it makes per-token indexing trivial.
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D_k)
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?; // (B, H, L, D_v)
let g = g.transpose(1, 2)?.contiguous()?; // (B, H, L)
let beta = beta.transpose(1, 2)?.contiguous()?; // (B, H, L)
// Pre-scale q by 1/sqrt(D_k) once.
let scale = 1.0_f64 / (self.head_k_dim as f64).sqrt();
let q = (q.to_dtype(candle_core::DType::F32)? * scale)?;
let k = k.to_dtype(candle_core::DType::F32)?;
let v = v.to_dtype(candle_core::DType::F32)?;
// Initialise the recurrent state from cache or zeros.
let mut state = match self.state.recurrent_state.take() {
Some(s) => s.to_dtype(candle_core::DType::F32)?,
None => Tensor::zeros(
(
batch_size,
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
),
candle_core::DType::F32,
&device,
)?,
};
// Per-token delta-rule loop. Slow-but-correct path; chunked
// optimisation is for later.
let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
for t in 0..seq_len {
// (B, H, D_k) and (B, H, D_v) for token t.
let q_t = q.i((.., .., t, ..))?; // (B, H, D_k)
let k_t = k.i((.., .., t, ..))?;
let v_t = v.i((.., .., t, ..))?;
let g_t = g.i((.., .., t))?; // (B, H)
let beta_t = beta.i((.., .., t))?; // (B, H)
// Decay: state *= exp(g_t). exp(g_t) shape (B, H) → broadcast to (B, H, 1, 1).
let decay = g_t
.exp()?
.unsqueeze(candle_core::D::Minus1)?
.unsqueeze(candle_core::D::Minus1)?; // (B, H, 1, 1)
state = state.broadcast_mul(&decay)?;
// Memory readout: sum_{d_k} state[d_k, d_v] * k_t[d_k] → (B, H, D_v).
// state: (B, H, D_k, D_v); k_t.unsqueeze(-1): (B, H, D_k, 1).
let k_col = k_t.unsqueeze(candle_core::D::Minus1)?; // (B, H, D_k, 1)
let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?; // sum over D_k → (B, H, D_v)
// delta = (v_t - kv_mem) * beta_t (broadcast beta on last dim).
let beta_col = beta_t.unsqueeze(candle_core::D::Minus1)?; // (B, H, 1)
let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?; // (B, H, D_v)
// state += outer(k_t, delta) = k_col * delta_row, broadcast to (B, H, D_k, D_v).
let delta_row = delta.unsqueeze(2)?; // (B, H, 1, D_v)
let outer = k_col.broadcast_mul(&delta_row)?; // (B, H, D_k, D_v)
state = (state + outer)?;
// out_t = sum_{d_k} state[d_k, d_v] * q_t[d_k] → (B, H, D_v).
let q_col = q_t.unsqueeze(candle_core::D::Minus1)?; // (B, H, D_k, 1)
let out_t = state.broadcast_mul(&q_col)?.sum(2)?; // (B, H, D_v)
outputs.push(out_t.unsqueeze(2)?); // (B, H, 1, D_v)
}
// Stash the updated recurrent state for the next call.
self.state.recurrent_state = Some(state.to_dtype(dtype)?);
// core_attn_out: (B, H, L, D_v) → (B, L, H, D_v) → (B*L*H, D_v).
let core_attn_out = Tensor::cat(&outputs, 2)?; // (B, H, L, D_v)
let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; // (B, L, H, D_v)
let core_attn_out = core_attn_out.to_dtype(dtype)?;
let core_attn_flat =
core_attn_out.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
let z_flat = z.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
// RMSNormGated: (out * silu(z) * weight) with the norm.
let normed = self.norm.forward(&core_attn_flat, &z_flat)?;
let normed = normed.reshape((batch_size, seq_len, self.num_v_heads * self.head_v_dim))?;
// Output projection: (B, L, value_dim) → (B, L, hidden_size).
self.out_proj.forward(&normed)
}
}
/// Load a no-bias linear from the ShardedVarBuilder. Weight shape is
/// the standard `[out, in]` order.
fn load_linear_no_bias(
vb: &ShardedVarBuilder,
name: &str,
in_dim: usize,
out_dim: usize,
) -> Result<Linear> {
let weight = vb
.pp(name)
.get((out_dim, in_dim), "weight")
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
Ok(Linear::new(weight, None))
}
/// Numerically-stable `softplus(x) = ln(1 + exp(x))`. Matches PyTorch's
/// `F.softplus` default (beta=1, threshold=20: for large positive x,
/// returns x as-is to avoid overflow in the exp).
fn softplus(x: &Tensor) -> candle_core::Result<Tensor> {
let threshold = 20.0_f64;
let big = x.ge(threshold)?; // Tensor<u8> mask
let safe = x.minimum(&x.affine(0.0, 0.0)?.affine(0.0, threshold)?)?; // min(x, threshold)
let small = ((safe.exp()? + 1.0_f64)?).log()?;
// Select x where big, else small.
big.where_cond(x, &small)
}
/// `repeat_interleave` along a single dim. Candle has no built-in for
/// this; emulate with unsqueeze + expand + reshape.
fn repeat_interleave(x: &Tensor, repeats: usize, dim: usize) -> candle_core::Result<Tensor> {
if repeats == 1 {
return Ok(x.clone());
}
let mut shape = x.dims().to_vec();
let orig = shape[dim];
shape.insert(dim + 1, repeats);
let mut expanded_shape = shape.clone();
expanded_shape[dim + 1] = repeats;
let x = x.unsqueeze(dim + 1)?;
let x = x.expand(expanded_shape)?;
let mut out_shape = x.dims().to_vec();
out_shape.remove(dim + 1);
out_shape[dim] = orig * repeats;
x.contiguous()?.reshape(out_shape)
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::{DType, Device};
#[test]
fn softplus_small_x() {
// softplus(0) = ln(2) ≈ 0.6931
let x = Tensor::new(&[0.0_f32], &Device::Cpu).unwrap();
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
assert!((out[0] - 2.0_f32.ln()).abs() < 1e-4);
}
#[test]
fn softplus_large_x_returns_x() {
// For x = 30, softplus(x) ≈ x (the threshold branch).
let x = Tensor::new(&[30.0_f32], &Device::Cpu).unwrap();
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
assert!((out[0] - 30.0).abs() < 1e-4);
}
#[test]
fn repeat_interleave_doubles_dim() {
let x = Tensor::new(&[[1.0_f32, 2.0], [3.0, 4.0]], &Device::Cpu).unwrap(); // shape (2, 2)
let out = repeat_interleave(&x, 2, 1).unwrap(); // each col duplicated
let v: Vec<Vec<f32>> = out.to_vec2().unwrap();
// Row 0: 1, 1, 2, 2
// Row 1: 3, 3, 4, 4
assert_eq!(v[0], vec![1.0, 1.0, 2.0, 2.0]);
assert_eq!(v[1], vec![3.0, 3.0, 4.0, 4.0]);
}
/// Sanity: the recurrent path produces a finite tensor of the right
/// shape on tiny dimensions. Doesn't validate numerical correctness
/// against the Python reference — that would need a fixed-weight
/// fixture to compare against. Catches structural mistakes
/// (broadcasting shapes, off-by-one slices) early.
#[test]
fn forward_smoke_with_tiny_dimensions() {
let dev = Device::Cpu;
let dtype = DType::F32;
let (b, l) = (1, 3);
let cfg = TextConfig {
vocab_size: 100,
hidden_size: 16,
intermediate_size: 32,
num_hidden_layers: 1,
num_attention_heads: 4,
num_key_value_heads: 1,
head_dim: 4,
max_position_embeddings: 32,
rope_theta: 10000.0,
rms_norm_eps: 1e-6,
tie_word_embeddings: false,
attn_output_gate: true,
layer_types: vec!["linear_attention".into()],
full_attention_interval: Some(4),
hidden_act: "silu".into(),
linear_num_value_heads: 4,
linear_num_key_heads: 2,
linear_key_head_dim: 4,
linear_value_head_dim: 4,
linear_conv_kernel_dim: 4,
};
// Build a synthetic VarBuilder with all-zeros weights.
// Easier path: skip the load and construct GatedDeltaNet
// manually by hand-rolling the Linear/Tensor inputs.
let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &dev).unwrap();
let key_dim = cfg.linear_key_head_dim * cfg.linear_num_key_heads;
let value_dim = cfg.linear_value_head_dim * cfg.linear_num_value_heads;
let conv_dim = key_dim * 2 + value_dim;
let mut net = GatedDeltaNet {
in_proj_qkv: Linear::new(zeros(&[conv_dim, cfg.hidden_size]), None),
in_proj_z: Linear::new(zeros(&[value_dim, cfg.hidden_size]), None),
in_proj_b: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
in_proj_a: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
out_proj: Linear::new(zeros(&[cfg.hidden_size, value_dim]), None),
conv1d_weight: zeros(&[conv_dim, 1, cfg.linear_conv_kernel_dim]),
dt_bias: zeros(&[cfg.linear_num_value_heads]),
a_log: zeros(&[cfg.linear_num_value_heads]),
norm: {
let weight = Tensor::ones(&[cfg.linear_value_head_dim], dtype, &dev).unwrap();
Qwen3_5RmsNormGated::from_weight(weight, cfg.rms_norm_eps)
},
num_v_heads: cfg.linear_num_value_heads,
num_k_heads: cfg.linear_num_key_heads,
head_k_dim: cfg.linear_key_head_dim,
head_v_dim: cfg.linear_value_head_dim,
key_dim,
value_dim,
conv_dim,
conv_kernel_size: cfg.linear_conv_kernel_dim,
state: GatedDeltaNetState::default(),
};
let x = Tensor::ones(&[b, l, cfg.hidden_size], dtype, &dev).unwrap();
let y = net.forward(&x).unwrap();
assert_eq!(y.dims(), &[b, l, cfg.hidden_size]);
// All zero weights → output should be zero. Confirms no NaN/Inf
// poisoning from the f32 promotions.
let v: Vec<f32> = y.flatten_all().unwrap().to_vec1().unwrap();
assert!(v.iter().all(|x| x.is_finite()));
}
}

View File

@@ -69,6 +69,9 @@ use anyhow::Result;
use candle_core::Tensor;
use serde::Deserialize;
pub mod linear_attn;
pub mod rmsnorm;
/// `model_type` we deserialise from `config.json`. Const so the
/// dispatch in `candle.rs::load_arch_dense` can pattern-match without
/// magic strings.
@@ -122,6 +125,38 @@ pub struct TextConfig {
/// logging / validation; the forward dispatches on `layer_types`.
#[serde(default)]
pub full_attention_interval: Option<usize>,
/// Hidden activation (`"silu"` for Qwen3-Next). Used by the MLP
/// and the linear-attention conv1d.
#[serde(default = "default_hidden_act")]
pub hidden_act: String,
// --- Gated DeltaNet (linear-attention) hyperparams -----------------
/// Per-layer linear-attention V-head count (Qwen3.6-27B: 48).
/// More V-heads than K-heads is fine — query/key get
/// `repeat_interleave`'d to match before the delta rule.
#[serde(default)]
pub linear_num_value_heads: usize,
/// Per-layer linear-attention K-head count (Qwen3.6-27B: 16).
#[serde(default)]
pub linear_num_key_heads: usize,
/// Per-head key dimension for the linear-attention path
/// (Qwen3.6-27B: 128). Separate from `head_dim` which the
/// full-attention layers use.
#[serde(default)]
pub linear_key_head_dim: usize,
/// Per-head value dimension for the linear-attention path
/// (Qwen3.6-27B: 128).
#[serde(default)]
pub linear_value_head_dim: usize,
/// Causal Conv1d kernel size used before the delta rule
/// (Qwen3.6-27B: 4).
#[serde(default)]
pub linear_conv_kernel_dim: usize,
}
fn default_hidden_act() -> String {
"silu".into()
}
/// Stub model. Fields are intentionally empty — filling in the

View File

@@ -0,0 +1,161 @@
//! Norm primitives for Qwen3-Next.
//!
//! Two reasons we can't reuse `candle_nn::RmsNorm` directly:
//!
//! 1. **`(1.0 + weight)` scaling.** Qwen3-Next's `Qwen3_5RMSNorm`
//! initialises `weight` to zeros and applies `(1.0 + weight)` to
//! the normalised vector. `candle_nn::RmsNorm` applies `weight`
//! directly. The two are equivalent only when the operator has
//! pre-shifted the weights — the upstream checkpoints have not. See
//! `huggingface/transformers#29402` for the upstream PR that
//! introduced the `(1 + w)` form to recover from the zero-init.
//!
//! 2. **Gated variant.** The linear-attention layer post-normalises
//! its output by an RMSNorm *gated* with a per-element SiLU on
//! a sibling `z` projection — fused for numerical reasons (the
//! norm's float32 promotion has to happen before the SiLU
//! multiply). Not a single existing candle op.
//!
//! Both ops accept inputs in any compute dtype; promotion to f32 for
//! the variance calculation matches the Python reference.
use anyhow::{Context, Result};
use candle_core::{D, Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
/// L2-normalise along the last dim with a small epsilon. Matches the
/// `l2norm` helper in `transformers/models/qwen3_5/modeling_qwen3_5.py`
/// — `x * rsqrt(sum(x*x) + eps)`. The linear-attention path uses this
/// on Q and K before the delta rule when
/// `use_qk_l2norm_in_kernel=True` (which Qwen3-Next always sets).
pub fn l2norm(x: &Tensor, eps: f32) -> candle_core::Result<Tensor> {
let dtype = x.dtype();
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let sq = x_f32.sqr()?;
let sum = sq.sum_keepdim(D::Minus1)?;
let inv = (sum + eps as f64)?.sqrt()?.recip()?;
x_f32.broadcast_mul(&inv)?.to_dtype(dtype)
}
/// Qwen3-Next's RMSNorm. Stores the raw weight tensor; forward applies
/// `(1.0 + weight) * x_normed`.
pub struct Qwen3_5RmsNorm {
weight: Tensor,
eps: f32,
size: usize,
}
impl Qwen3_5RmsNorm {
/// Load `weight` from the ShardedVarBuilder. `vb` should already be
/// `.pp(...)`-ed to the norm's tensor prefix.
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
let weight = vb
.get(size, "weight")
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
Ok(Self {
weight,
eps: eps as f32,
size,
})
}
pub fn size(&self) -> usize {
self.size
}
}
impl Module for Qwen3_5RmsNorm {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let dtype = x.dtype();
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
// Promote weight to f32 and shift by 1.0 *before* multiplying.
// Doing the (1 + w) operation in fp16 lands at -inf for the
// bottom-of-range weights at load time.
let w_f32 = self.weight.to_dtype(candle_core::DType::F32)?;
let scale = (w_f32 + 1.0_f64)?;
normed.broadcast_mul(&scale)?.to_dtype(dtype)
}
}
/// Gated RMSNorm used at the tail of `Qwen3_5GatedDeltaNet`. Equivalent
/// to `x_normed * weight * silu(gate)` but with both the norm and the
/// gate evaluated in float32 to avoid mid-pipeline underflow.
///
/// Note: unlike `Qwen3_5RmsNorm`, this variant matches the Python
/// reference's `Qwen3_5RMSNormGated` which uses `weight` directly (not
/// `1.0 + weight`).
pub struct Qwen3_5RmsNormGated {
weight: Tensor,
eps: f32,
size: usize,
}
impl Qwen3_5RmsNormGated {
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
let weight = vb
.get(size, "weight")
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
Ok(Self {
weight,
eps: eps as f32,
size,
})
}
/// Direct constructor — used by unit tests that build a layer
/// without going through a VarBuilder.
#[cfg(test)]
pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
let size = weight.dims()[0];
Self {
weight,
eps: eps as f32,
size,
}
}
pub fn size(&self) -> usize {
self.size
}
/// `x` and `gate` share the same last-dim shape (`size`).
pub fn forward(&self, x: &Tensor, gate: &Tensor) -> candle_core::Result<Tensor> {
let dtype = x.dtype();
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
let w = self.weight.to_dtype(candle_core::DType::F32)?;
let out = normed.broadcast_mul(&w)?;
// SiLU on the float32 gate, multiply back into the normed
// tensor, then cast to the model dtype.
let g = gate.to_dtype(candle_core::DType::F32)?;
let silu_gate = candle_nn::ops::silu(&g)?;
(out * silu_gate)?.to_dtype(dtype)
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::Device;
#[test]
fn l2norm_matches_hand_calc() {
let x = Tensor::new(&[3.0_f32, 4.0_f32], &Device::Cpu).unwrap();
let out = l2norm(&x, 1e-6).unwrap();
let v: Vec<f32> = out.to_vec1().unwrap();
// |x| = 5, so x/|x| = [0.6, 0.8] (eps is tiny).
assert!((v[0] - 0.6).abs() < 1e-4);
assert!((v[1] - 0.8).abs() < 1e-4);
}
#[test]
fn l2norm_zero_vector_is_safe_via_epsilon() {
let x = Tensor::new(&[0.0_f32, 0.0_f32], &Device::Cpu).unwrap();
let out = l2norm(&x, 1e-6).unwrap();
let v: Vec<f32> = out.to_vec1().unwrap();
assert!(v.iter().all(|x| x.is_finite()));
}
}