diff --git a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs index 58e73b3..e3d9620 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs @@ -737,6 +737,8 @@ mod tests { rope_theta: 10000.0, partial_rotary_factor: 1.0, rope_type: None, + mrope_section: Vec::new(), + mrope_interleaved: false, }, rms_norm_eps: 1e-6, tie_word_embeddings: false, diff --git a/crates/neuron/src/harness/arch/qwen3_5/mod.rs b/crates/neuron/src/harness/arch/qwen3_5/mod.rs index 3431077..4a58d2f 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/mod.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/mod.rs @@ -191,11 +191,12 @@ fn default_hidden_act() -> String { } /// Nested `rope_parameters` block from a Qwen3-Next `config.json`. -/// `mrope_section` and `mrope_interleaved` are accepted via the -/// `#[serde(default)]` flatten-tolerance below but ignored — we treat -/// MRoPE as plain RoPE for text-only inference (the three position -/// grids carry identical ids when there's no vision input, so the -/// interleaving is a no-op). +/// +/// For text-only inference the three MRoPE position grids carry +/// identical ids, so the interleave is a no-op and plain RoPE applies. +/// For vision inputs `mrope_section` + `mrope_interleaved` drive the +/// per-axis (text/height/width) rotary used by image tokens — see +/// `rope.rs`. #[derive(Debug, Clone, Deserialize)] pub struct RopeParameters { /// Base for the inverse-frequency computation. Qwen3.6: 10_000_000. @@ -211,6 +212,16 @@ pub struct RopeParameters { /// implemented here. #[serde(default)] pub rope_type: Option, + /// MRoPE per-axis section sizes `[text, height, width]` — e.g. + /// `[11, 11, 10]` for Qwen3.6, summing to the rotary half-dim. + /// Empty for models that don't declare MRoPE (→ plain RoPE). + #[serde(default)] + pub mrope_section: Vec, + /// Whether the three MRoPE axes are interleaved per-frequency + /// (Qwen3-VL / Qwen3.6 style, `true`) rather than block-concatenated + /// (Qwen2-VL style, `false`). + #[serde(default)] + pub mrope_interleaved: bool, } fn default_rope_theta() -> f64 { diff --git a/crates/neuron/src/harness/arch/qwen3_5/rope.rs b/crates/neuron/src/harness/arch/qwen3_5/rope.rs index 3bc9cfc..67928c1 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/rope.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/rope.rs @@ -1,19 +1,27 @@ //! Rotary position embedding for Qwen3-Next's full-attention layers. //! -//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the -//! reference Python — three position grids interleaved per -//! `mrope_section`. For text-only inference all three grids carry the -//! same position ids and the interleave is a no-op, so this module -//! implements the plain (non-mrope) flavour: the standard inv_freq -//! cosine/sine tables driven by `rope_theta` and `head_dim`. +//! Qwen3.6 declares **interleaved M-RoPE** (multimodal RoPE): the +//! rotary half-dimension is split across three position axes — +//! `[text, height, width]` per `mrope_section` (`[11,11,10]` for +//! Qwen3.6) — interleaved per-frequency. For **text** every token's +//! three axes carry the same position id, so the interleave is a no-op +//! and this reduces exactly to plain RoPE. For **image** tokens the +//! height/width axes carry the patch's 2D grid coordinates, which is +//! how the model reads the 14×14 patch layout (without it, all patches +//! share a height position and the image reads as vertical repetition). //! -//! Rotation flavour: **GLM-style** rotate-half (the second half of the -//! head dim is negated and swapped into the first). The reference -//! Python uses `apply_rotary_pos_emb` with `rotate_half`; candle's -//! `rope_slow` is the matching helper. +//! Two cos/sin builders feed a shared [`RotaryEmbedding::apply`]: +//! - [`RotaryEmbedding::plain_cos_sin`] narrows the precomputed tables +//! at a scalar position — the text / decode fast path. +//! - [`RotaryEmbedding::mrope_cos_sin`] builds per-token cos/sin from a +//! `(3, seq)` position-id tensor, blending the three axes' frequencies +//! at the interleave index sets — the vision-prefill path. +//! +//! Rotation flavour: **GLM-style** rotate-half (candle's `rope_slow`), +//! matching the reference Python's `apply_rotary_pos_emb` + `rotate_half`. use anyhow::Result; -use candle_core::{DType, Device, Tensor}; +use candle_core::{DType, Device, IndexOp, Tensor}; use super::TextConfig; @@ -21,6 +29,18 @@ use super::TextConfig; pub struct RotaryEmbedding { sin: Tensor, cos: Tensor, + /// Inverse frequencies, shape `(1, rotary_dim/2)`. Retained (beyond + /// the precomputed `sin`/`cos` tables) so [`Self::mrope_cos_sin`] can + /// build cos/sin from arbitrary per-axis position ids. + inv_freq: Tensor, + /// Per-axis column masks over the rotary half-dim, shape `(1, half)`, + /// f32 0/1. `mask_t + mask_h + mask_w` partitions the columns; a + /// column belongs to exactly one axis. For a non-MRoPE config + /// `mask_t` is all-ones and the others all-zero (→ plain RoPE). + mask_t: Tensor, + mask_h: Tensor, + mask_w: Tensor, + dtype: DType, /// Number of dims at the head's leading edge that the rotation /// covers. The remaining `head_dim - rotary_dim` dims pass through /// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so @@ -29,6 +49,52 @@ pub struct RotaryEmbedding { head_dim: usize, } +/// Build the per-axis 0/1 column masks over the rotary half-dim from +/// `mrope_section`. Returns `(temporal, height, width)` each length +/// `half`. Temporal is the complement of height ∪ width, so the three +/// masks always partition `0..half` and reduce to all-temporal (plain +/// RoPE) when no usable section is given. +fn mrope_masks( + half: usize, + section: &[usize], + interleaved: bool, +) -> (Vec, Vec, Vec) { + let mut mh = vec![0f32; half]; + let mut mw = vec![0f32; half]; + if section.len() == 3 { + if interleaved { + // Qwen3-VL: height at columns 1,4,7,… ; width at 2,5,8,… ; + // temporal keeps 0,3,6,… — each `take`n from `mrope_section`. + for i in (1..half).step_by(3).take(section[1]) { + mh[i] = 1.0; + } + for i in (2..half).step_by(3).take(section[2]) { + mw[i] = 1.0; + } + } else { + // Qwen2-VL: contiguous blocks [text | height | width]. + let h_start = section[0].min(half); + let h_end = (section[0] + section[1]).min(half); + for m in mh.iter_mut().take(h_end).skip(h_start) { + *m = 1.0; + } + for m in mw.iter_mut().take(half).skip(h_end) { + *m = 1.0; + } + } + } + let mt: Vec = (0..half) + .map(|i| { + if mh[i] == 0.0 && mw[i] == 0.0 { + 1.0 + } else { + 0.0 + } + }) + .collect(); + (mt, mh, mw) +} + impl RotaryEmbedding { pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result { let head_dim = cfg.head_dim; @@ -52,44 +118,88 @@ impl RotaryEmbedding { .step_by(2) .map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32) .collect(); - let n = inv_freq.len(); - let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?; + let half = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, half), dev)?.to_dtype(DType::F32)?; let t = Tensor::arange(0u32, max_seq_len as u32, dev)? .to_dtype(DType::F32)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; + + // MRoPE axis masks. `sum(mrope_section)` should equal `half`; + // warn-tolerant: any shortfall just stays on the temporal axis. + let (mt, mh, mw) = mrope_masks(half, &rope.mrope_section, rope.mrope_interleaved); + let mask_t = Tensor::from_vec(mt, (1, half), dev)?; + let mask_h = Tensor::from_vec(mh, (1, half), dev)?; + let mask_w = Tensor::from_vec(mw, (1, half), dev)?; + Ok(Self { sin: freqs.sin()?.to_dtype(dtype)?, cos: freqs.cos()?.to_dtype(dtype)?, + inv_freq, + mask_t, + mask_h, + mask_w, + dtype, rotary_dim, head_dim, }) } - /// Apply RoPE to q, k. - /// - /// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index - /// into the cached cos/sin table — the position of the first token - /// in the current step. - /// - /// When `rotary_dim < head_dim` the rotation is applied only to the - /// first `rotary_dim` dims of each head; the tail passes through - /// unchanged (matches the reference Python's - /// `apply_rotary_pos_emb` with non-trivial `partial_rotary_factor`). - pub fn apply( + /// cos/sin for a contiguous run of `seq_len` positions starting at + /// `pos`, by narrowing the precomputed tables. The text / decode + /// path (all three MRoPE axes equal → plain RoPE). Shape + /// `(seq_len, rotary_dim/2)`. + pub fn plain_cos_sin( + &self, + pos: usize, + seq_len: usize, + ) -> candle_core::Result<(Tensor, Tensor)> { + let cos = self.cos.narrow(0, pos, seq_len)?; + let sin = self.sin.narrow(0, pos, seq_len)?; + Ok((cos, sin)) + } + + /// cos/sin from explicit per-token 3D position ids, shape + /// `(3, seq_len)` (axes: text, height, width). Builds each axis's + /// frequencies and blends them at the interleave index sets, so + /// every rotary frequency slot is driven by exactly one axis. + /// Reduces exactly to [`Self::plain_cos_sin`] when the three axes are + /// equal. Returns cos/sin of shape `(seq_len, rotary_dim/2)`. + pub fn mrope_cos_sin(&self, position_ids: &Tensor) -> candle_core::Result<(Tensor, Tensor)> { + let pos = position_ids.to_dtype(DType::F32)?; + let (axes, seq_len) = pos.dims2()?; + debug_assert_eq!(axes, 3, "mrope position_ids must have 3 axes"); + // Per-axis freqs: pos[a] (seq,1) @ inv_freq (1,half) → (seq,half). + let ft = pos.i(0)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?; + let fh = pos.i(1)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?; + let fw = pos.i(2)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?; + // Blend: each column belongs to exactly one axis (masks partition + // the half-dim), so this picks the right axis per frequency slot. + let blended = ft + .broadcast_mul(&self.mask_t)? + .add(&fh.broadcast_mul(&self.mask_h)?)? + .add(&fw.broadcast_mul(&self.mask_w)?)?; + let cos = blended.cos()?.to_dtype(self.dtype)?; + let sin = blended.sin()?.to_dtype(self.dtype)?; + Ok((cos, sin)) + } + + /// Apply rotary to `q`, `k` (shape `(B, H, L, head_dim)`) using + /// precomputed `cos`/`sin` of shape `(L, rotary_dim/2)`. Partial + /// rotary: only the first `rotary_dim` dims rotate; the tail passes + /// through unchanged. + pub fn apply_cos_sin( &self, q: &Tensor, k: &Tensor, - offset: usize, + cos: &Tensor, + sin: &Tensor, ) -> candle_core::Result<(Tensor, Tensor)> { - let (_, _, seq_len, head_dim_in) = q.dims4()?; + let (_, _, _seq_len, head_dim_in) = q.dims4()?; debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch"); - let cos = self.cos.narrow(0, offset, seq_len)?; - let sin = self.sin.narrow(0, offset, seq_len)?; if self.rotary_dim == self.head_dim { - // Full rotation. - let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?; - let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?; + let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, cos, sin)?; + let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, cos, sin)?; Ok((q_embed, k_embed)) } else { // Partial rotation: narrow → rotate → cat the untouched tail. @@ -102,8 +212,8 @@ impl RotaryEmbedding { .narrow(candle_core::D::Minus1, 0, self.rotary_dim)? .contiguous()?; let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?; - let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, &cos, &sin)?; - let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, &cos, &sin)?; + let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, cos, sin)?; + let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, cos, sin)?; let q_embed = Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?; let k_embed = @@ -111,4 +221,120 @@ impl RotaryEmbedding { Ok((q_embed, k_embed)) } } + + /// Text/decode convenience: build plain cos/sin for a scalar offset + /// and apply in one call. The current call sites use this; Stages 3–4 + /// move cos/sin construction up into the model forward (computed once + /// per forward) and call [`Self::apply_cos_sin`] directly. + pub fn apply( + &self, + q: &Tensor, + k: &Tensor, + offset: usize, + ) -> candle_core::Result<(Tensor, Tensor)> { + let (_, _, seq_len, _) = q.dims4()?; + let (cos, sin) = self.plain_cos_sin(offset, seq_len)?; + self.apply_cos_sin(q, k, &cos, &sin) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use candle_core::IndexOp; + + /// A TextConfig stub with Qwen3.6's rope params (head_dim 256, + /// partial 0.25 → rotary_dim 64 → half 32; section [11,11,10]). + fn qwen36_cfg() -> TextConfig { + serde_json::from_value(serde_json::json!({ + "hidden_size": 5120, + "num_hidden_layers": 1, + "num_attention_heads": 64, + "num_key_value_heads": 8, + "head_dim": 256, + "intermediate_size": 1, + "vocab_size": 10, + "rms_norm_eps": 1e-6, + "max_position_embeddings": 64, + "layer_types": ["full_attention"], + "rope_parameters": { + "rope_theta": 10000000.0, + "partial_rotary_factor": 0.25, + "mrope_section": [11, 11, 10], + "mrope_interleaved": true + } + })) + .expect("cfg") + } + + #[test] + fn mrope_masks_partition_the_half_dim() { + let (mt, mh, mw) = mrope_masks(32, &[11, 11, 10], true); + // Each column belongs to exactly one axis. + for i in 0..32 { + let s = mt[i] + mh[i] + mw[i]; + assert_eq!(s, 1.0, "column {i} covered {s} times"); + } + assert_eq!(mt.iter().sum::(), 11.0); + assert_eq!(mh.iter().sum::(), 11.0); + assert_eq!(mw.iter().sum::(), 10.0); + // Interleave: temporal 0,3,…; height 1,4,…; width 2,5,… + assert_eq!(mt[0], 1.0); + assert_eq!(mh[1], 1.0); + assert_eq!(mw[2], 1.0); + assert_eq!(mt[3], 1.0); + } + + /// The load-bearing invariant: when all three position axes are + /// equal (text), `mrope_cos_sin` must reproduce `plain_cos_sin` + /// bit-for-bit — i.e. M-RoPE is a no-op for text, so text inference + /// is unchanged. + #[test] + fn mrope_reduces_to_plain_for_equal_axes() { + let dev = Device::Cpu; + let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap(); + + // positions 5,6,7 on all three axes. + let base: Vec = vec![5, 6, 7]; + let pos = + Tensor::from_vec([base.clone(), base.clone(), base].concat(), (3, 3), &dev).unwrap(); + + let (mc, ms) = rope.mrope_cos_sin(&pos).unwrap(); + let (pc, ps) = rope.plain_cos_sin(5, 3).unwrap(); + + let dcos = (mc - pc).unwrap().abs().unwrap().max_all().unwrap(); + let dsin = (ms - ps).unwrap().abs().unwrap().max_all().unwrap(); + assert!( + dcos.to_scalar::().unwrap() < 1e-6, + "cos mismatch {dcos:?}" + ); + assert!( + dsin.to_scalar::().unwrap() < 1e-6, + "sin mismatch {dsin:?}" + ); + } + + /// Hand-checked interleave: a width-axis column (index 2) must track + /// the WIDTH position, while a temporal column (index 0) tracks the + /// TEXT position, even when the axes differ. + #[test] + fn mrope_blends_axes_at_interleave_columns() { + let dev = Device::Cpu; + let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap(); + let half = rope.inv_freq.dim(1).unwrap(); + let inv: Vec = rope.inv_freq.i(0).unwrap().to_vec1().unwrap(); + + // One token: text=10, height=3, width=7 — all distinct. + let pos = Tensor::from_vec(vec![10i64, 3, 7], (3, 1), &dev).unwrap(); + let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap(); + let cos_row: Vec = cos.i(0).unwrap().to_vec1().unwrap(); + assert_eq!(cos_row.len(), half); + + // Column 0 (temporal) → text pos 10. Column 1 (height) → 3. + // Column 2 (width) → 7. + assert!((cos_row[0] - (10.0 * inv[0]).cos()).abs() < 1e-5); + assert!((cos_row[1] - (3.0 * inv[1]).cos()).abs() < 1e-5); + assert!((cos_row[2] - (7.0 * inv[2]).cos()).abs() < 1e-5); + assert!((cos_row[3] - (10.0 * inv[3]).cos()).abs() < 1e-5); + } }