feat(neuron): M-RoPE Stage 1 — interleaved rope machinery + config

Parse + store mrope_section / mrope_interleaved in RopeParameters
(previously accepted-but-ignored). RotaryEmbedding gains:
- inv_freq + per-axis column masks (mask_t/h/w) built from mrope_section;
- plain_cos_sin(pos, seq_len): narrow the precomputed tables (text/decode);
- mrope_cos_sin(position_ids (3,seq)): per-axis freqs blended at the
  interleave columns (vision);
- apply_cos_sin(q,k,cos,sin): the rope_slow application, factored out.

The existing apply(q,k,offset) is retained (delegates to
plain_cos_sin + apply_cos_sin) so current callers are unchanged; Stages
3–4 move cos/sin construction into the model forward and thread the 3D
position ids for image tokens.

Tests: masks partition the half-dim; interleave drives the right axis
per column; and the load-bearing invariant — mrope_cos_sin reduces
bit-for-bit to plain_cos_sin when the three axes are equal (so text
inference is unchanged).

Refs the MRoPE-gap diagnosis (vision spatial misread). Pure non-cuda;
no behaviour change until wired.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-04 18:31:15 +03:00
parent fa013505d1
commit 5731f4c318
3 changed files with 277 additions and 38 deletions

View File

@@ -737,6 +737,8 @@ mod tests {
rope_theta: 10000.0, rope_theta: 10000.0,
partial_rotary_factor: 1.0, partial_rotary_factor: 1.0,
rope_type: None, rope_type: None,
mrope_section: Vec::new(),
mrope_interleaved: false,
}, },
rms_norm_eps: 1e-6, rms_norm_eps: 1e-6,
tie_word_embeddings: false, tie_word_embeddings: false,

View File

@@ -191,11 +191,12 @@ fn default_hidden_act() -> String {
} }
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`. /// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
/// `mrope_section` and `mrope_interleaved` are accepted via the ///
/// `#[serde(default)]` flatten-tolerance below but ignored — we treat /// For text-only inference the three MRoPE position grids carry
/// MRoPE as plain RoPE for text-only inference (the three position /// identical ids, so the interleave is a no-op and plain RoPE applies.
/// grids carry identical ids when there's no vision input, so the /// For vision inputs `mrope_section` + `mrope_interleaved` drive the
/// interleaving is a no-op). /// per-axis (text/height/width) rotary used by image tokens — see
/// `rope.rs`.
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub struct RopeParameters { pub struct RopeParameters {
/// Base for the inverse-frequency computation. Qwen3.6: 10_000_000. /// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
@@ -211,6 +212,16 @@ pub struct RopeParameters {
/// implemented here. /// implemented here.
#[serde(default)] #[serde(default)]
pub rope_type: Option<String>, pub rope_type: Option<String>,
/// MRoPE per-axis section sizes `[text, height, width]` — e.g.
/// `[11, 11, 10]` for Qwen3.6, summing to the rotary half-dim.
/// Empty for models that don't declare MRoPE (→ plain RoPE).
#[serde(default)]
pub mrope_section: Vec<usize>,
/// Whether the three MRoPE axes are interleaved per-frequency
/// (Qwen3-VL / Qwen3.6 style, `true`) rather than block-concatenated
/// (Qwen2-VL style, `false`).
#[serde(default)]
pub mrope_interleaved: bool,
} }
fn default_rope_theta() -> f64 { fn default_rope_theta() -> f64 {

View File

@@ -1,19 +1,27 @@
//! Rotary position embedding for Qwen3-Next's full-attention layers. //! Rotary position embedding for Qwen3-Next's full-attention layers.
//! //!
//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the //! Qwen3.6 declares **interleaved M-RoPE** (multimodal RoPE): the
//! reference Python — three position grids interleaved per //! rotary half-dimension is split across three position axes —
//! `mrope_section`. For text-only inference all three grids carry the //! `[text, height, width]` per `mrope_section` (`[11,11,10]` for
//! same position ids and the interleave is a no-op, so this module //! Qwen3.6) — interleaved per-frequency. For **text** every token's
//! implements the plain (non-mrope) flavour: the standard inv_freq //! three axes carry the same position id, so the interleave is a no-op
//! cosine/sine tables driven by `rope_theta` and `head_dim`. //! and this reduces exactly to plain RoPE. For **image** tokens the
//! height/width axes carry the patch's 2D grid coordinates, which is
//! how the model reads the 14×14 patch layout (without it, all patches
//! share a height position and the image reads as vertical repetition).
//! //!
//! Rotation flavour: **GLM-style** rotate-half (the second half of the //! Two cos/sin builders feed a shared [`RotaryEmbedding::apply`]:
//! head dim is negated and swapped into the first). The reference //! - [`RotaryEmbedding::plain_cos_sin`] narrows the precomputed tables
//! Python uses `apply_rotary_pos_emb` with `rotate_half`; candle's //! at a scalar position — the text / decode fast path.
//! `rope_slow` is the matching helper. //! - [`RotaryEmbedding::mrope_cos_sin`] builds per-token cos/sin from a
//! `(3, seq)` position-id tensor, blending the three axes' frequencies
//! at the interleave index sets — the vision-prefill path.
//!
//! Rotation flavour: **GLM-style** rotate-half (candle's `rope_slow`),
//! matching the reference Python's `apply_rotary_pos_emb` + `rotate_half`.
use anyhow::Result; use anyhow::Result;
use candle_core::{DType, Device, Tensor}; use candle_core::{DType, Device, IndexOp, Tensor};
use super::TextConfig; use super::TextConfig;
@@ -21,6 +29,18 @@ use super::TextConfig;
pub struct RotaryEmbedding { pub struct RotaryEmbedding {
sin: Tensor, sin: Tensor,
cos: Tensor, cos: Tensor,
/// Inverse frequencies, shape `(1, rotary_dim/2)`. Retained (beyond
/// the precomputed `sin`/`cos` tables) so [`Self::mrope_cos_sin`] can
/// build cos/sin from arbitrary per-axis position ids.
inv_freq: Tensor,
/// Per-axis column masks over the rotary half-dim, shape `(1, half)`,
/// f32 0/1. `mask_t + mask_h + mask_w` partitions the columns; a
/// column belongs to exactly one axis. For a non-MRoPE config
/// `mask_t` is all-ones and the others all-zero (→ plain RoPE).
mask_t: Tensor,
mask_h: Tensor,
mask_w: Tensor,
dtype: DType,
/// Number of dims at the head's leading edge that the rotation /// Number of dims at the head's leading edge that the rotation
/// covers. The remaining `head_dim - rotary_dim` dims pass through /// covers. The remaining `head_dim - rotary_dim` dims pass through
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so /// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
@@ -29,6 +49,52 @@ pub struct RotaryEmbedding {
head_dim: usize, head_dim: usize,
} }
/// Build the per-axis 0/1 column masks over the rotary half-dim from
/// `mrope_section`. Returns `(temporal, height, width)` each length
/// `half`. Temporal is the complement of height width, so the three
/// masks always partition `0..half` and reduce to all-temporal (plain
/// RoPE) when no usable section is given.
fn mrope_masks(
half: usize,
section: &[usize],
interleaved: bool,
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let mut mh = vec![0f32; half];
let mut mw = vec![0f32; half];
if section.len() == 3 {
if interleaved {
// Qwen3-VL: height at columns 1,4,7,… ; width at 2,5,8,… ;
// temporal keeps 0,3,6,… — each `take`n from `mrope_section`.
for i in (1..half).step_by(3).take(section[1]) {
mh[i] = 1.0;
}
for i in (2..half).step_by(3).take(section[2]) {
mw[i] = 1.0;
}
} else {
// Qwen2-VL: contiguous blocks [text | height | width].
let h_start = section[0].min(half);
let h_end = (section[0] + section[1]).min(half);
for m in mh.iter_mut().take(h_end).skip(h_start) {
*m = 1.0;
}
for m in mw.iter_mut().take(half).skip(h_end) {
*m = 1.0;
}
}
}
let mt: Vec<f32> = (0..half)
.map(|i| {
if mh[i] == 0.0 && mw[i] == 0.0 {
1.0
} else {
0.0
}
})
.collect();
(mt, mh, mw)
}
impl RotaryEmbedding { impl RotaryEmbedding {
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> { pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
let head_dim = cfg.head_dim; let head_dim = cfg.head_dim;
@@ -52,44 +118,88 @@ impl RotaryEmbedding {
.step_by(2) .step_by(2)
.map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32) .map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32)
.collect(); .collect();
let n = inv_freq.len(); let half = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?; let inv_freq = Tensor::from_vec(inv_freq, (1, half), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)? let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)? .to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?; .reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?; let freqs = t.matmul(&inv_freq)?;
// MRoPE axis masks. `sum(mrope_section)` should equal `half`;
// warn-tolerant: any shortfall just stays on the temporal axis.
let (mt, mh, mw) = mrope_masks(half, &rope.mrope_section, rope.mrope_interleaved);
let mask_t = Tensor::from_vec(mt, (1, half), dev)?;
let mask_h = Tensor::from_vec(mh, (1, half), dev)?;
let mask_w = Tensor::from_vec(mw, (1, half), dev)?;
Ok(Self { Ok(Self {
sin: freqs.sin()?.to_dtype(dtype)?, sin: freqs.sin()?.to_dtype(dtype)?,
cos: freqs.cos()?.to_dtype(dtype)?, cos: freqs.cos()?.to_dtype(dtype)?,
inv_freq,
mask_t,
mask_h,
mask_w,
dtype,
rotary_dim, rotary_dim,
head_dim, head_dim,
}) })
} }
/// Apply RoPE to q, k. /// cos/sin for a contiguous run of `seq_len` positions starting at
/// /// `pos`, by narrowing the precomputed tables. The text / decode
/// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index /// path (all three MRoPE axes equal → plain RoPE). Shape
/// into the cached cos/sin table — the position of the first token /// `(seq_len, rotary_dim/2)`.
/// in the current step. pub fn plain_cos_sin(
/// &self,
/// When `rotary_dim < head_dim` the rotation is applied only to the pos: usize,
/// first `rotary_dim` dims of each head; the tail passes through seq_len: usize,
/// unchanged (matches the reference Python's ) -> candle_core::Result<(Tensor, Tensor)> {
/// `apply_rotary_pos_emb` with non-trivial `partial_rotary_factor`). let cos = self.cos.narrow(0, pos, seq_len)?;
pub fn apply( let sin = self.sin.narrow(0, pos, seq_len)?;
Ok((cos, sin))
}
/// cos/sin from explicit per-token 3D position ids, shape
/// `(3, seq_len)` (axes: text, height, width). Builds each axis's
/// frequencies and blends them at the interleave index sets, so
/// every rotary frequency slot is driven by exactly one axis.
/// Reduces exactly to [`Self::plain_cos_sin`] when the three axes are
/// equal. Returns cos/sin of shape `(seq_len, rotary_dim/2)`.
pub fn mrope_cos_sin(&self, position_ids: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
let pos = position_ids.to_dtype(DType::F32)?;
let (axes, seq_len) = pos.dims2()?;
debug_assert_eq!(axes, 3, "mrope position_ids must have 3 axes");
// Per-axis freqs: pos[a] (seq,1) @ inv_freq (1,half) → (seq,half).
let ft = pos.i(0)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
let fh = pos.i(1)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
let fw = pos.i(2)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
// Blend: each column belongs to exactly one axis (masks partition
// the half-dim), so this picks the right axis per frequency slot.
let blended = ft
.broadcast_mul(&self.mask_t)?
.add(&fh.broadcast_mul(&self.mask_h)?)?
.add(&fw.broadcast_mul(&self.mask_w)?)?;
let cos = blended.cos()?.to_dtype(self.dtype)?;
let sin = blended.sin()?.to_dtype(self.dtype)?;
Ok((cos, sin))
}
/// Apply rotary to `q`, `k` (shape `(B, H, L, head_dim)`) using
/// precomputed `cos`/`sin` of shape `(L, rotary_dim/2)`. Partial
/// rotary: only the first `rotary_dim` dims rotate; the tail passes
/// through unchanged.
pub fn apply_cos_sin(
&self, &self,
q: &Tensor, q: &Tensor,
k: &Tensor, k: &Tensor,
offset: usize, cos: &Tensor,
sin: &Tensor,
) -> candle_core::Result<(Tensor, Tensor)> { ) -> candle_core::Result<(Tensor, Tensor)> {
let (_, _, seq_len, head_dim_in) = q.dims4()?; let (_, _, _seq_len, head_dim_in) = q.dims4()?;
debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch"); debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");
let cos = self.cos.narrow(0, offset, seq_len)?;
let sin = self.sin.narrow(0, offset, seq_len)?;
if self.rotary_dim == self.head_dim { if self.rotary_dim == self.head_dim {
// Full rotation. let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, cos, sin)?;
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?; let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, cos, sin)?;
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed)) Ok((q_embed, k_embed))
} else { } else {
// Partial rotation: narrow → rotate → cat the untouched tail. // Partial rotation: narrow → rotate → cat the untouched tail.
@@ -102,8 +212,8 @@ impl RotaryEmbedding {
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)? .narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
.contiguous()?; .contiguous()?;
let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?; let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, &cos, &sin)?; let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, cos, sin)?;
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, &cos, &sin)?; let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, cos, sin)?;
let q_embed = let q_embed =
Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?; Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?;
let k_embed = let k_embed =
@@ -111,4 +221,120 @@ impl RotaryEmbedding {
Ok((q_embed, k_embed)) Ok((q_embed, k_embed))
} }
} }
/// Text/decode convenience: build plain cos/sin for a scalar offset
/// and apply in one call. The current call sites use this; Stages 34
/// move cos/sin construction up into the model forward (computed once
/// per forward) and call [`Self::apply_cos_sin`] directly.
pub fn apply(
&self,
q: &Tensor,
k: &Tensor,
offset: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let (_, _, seq_len, _) = q.dims4()?;
let (cos, sin) = self.plain_cos_sin(offset, seq_len)?;
self.apply_cos_sin(q, k, &cos, &sin)
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::IndexOp;
/// A TextConfig stub with Qwen3.6's rope params (head_dim 256,
/// partial 0.25 → rotary_dim 64 → half 32; section [11,11,10]).
fn qwen36_cfg() -> TextConfig {
serde_json::from_value(serde_json::json!({
"hidden_size": 5120,
"num_hidden_layers": 1,
"num_attention_heads": 64,
"num_key_value_heads": 8,
"head_dim": 256,
"intermediate_size": 1,
"vocab_size": 10,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 64,
"layer_types": ["full_attention"],
"rope_parameters": {
"rope_theta": 10000000.0,
"partial_rotary_factor": 0.25,
"mrope_section": [11, 11, 10],
"mrope_interleaved": true
}
}))
.expect("cfg")
}
#[test]
fn mrope_masks_partition_the_half_dim() {
let (mt, mh, mw) = mrope_masks(32, &[11, 11, 10], true);
// Each column belongs to exactly one axis.
for i in 0..32 {
let s = mt[i] + mh[i] + mw[i];
assert_eq!(s, 1.0, "column {i} covered {s} times");
}
assert_eq!(mt.iter().sum::<f32>(), 11.0);
assert_eq!(mh.iter().sum::<f32>(), 11.0);
assert_eq!(mw.iter().sum::<f32>(), 10.0);
// Interleave: temporal 0,3,…; height 1,4,…; width 2,5,…
assert_eq!(mt[0], 1.0);
assert_eq!(mh[1], 1.0);
assert_eq!(mw[2], 1.0);
assert_eq!(mt[3], 1.0);
}
/// The load-bearing invariant: when all three position axes are
/// equal (text), `mrope_cos_sin` must reproduce `plain_cos_sin`
/// bit-for-bit — i.e. M-RoPE is a no-op for text, so text inference
/// is unchanged.
#[test]
fn mrope_reduces_to_plain_for_equal_axes() {
let dev = Device::Cpu;
let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
// positions 5,6,7 on all three axes.
let base: Vec<i64> = vec![5, 6, 7];
let pos =
Tensor::from_vec([base.clone(), base.clone(), base].concat(), (3, 3), &dev).unwrap();
let (mc, ms) = rope.mrope_cos_sin(&pos).unwrap();
let (pc, ps) = rope.plain_cos_sin(5, 3).unwrap();
let dcos = (mc - pc).unwrap().abs().unwrap().max_all().unwrap();
let dsin = (ms - ps).unwrap().abs().unwrap().max_all().unwrap();
assert!(
dcos.to_scalar::<f32>().unwrap() < 1e-6,
"cos mismatch {dcos:?}"
);
assert!(
dsin.to_scalar::<f32>().unwrap() < 1e-6,
"sin mismatch {dsin:?}"
);
}
/// Hand-checked interleave: a width-axis column (index 2) must track
/// the WIDTH position, while a temporal column (index 0) tracks the
/// TEXT position, even when the axes differ.
#[test]
fn mrope_blends_axes_at_interleave_columns() {
let dev = Device::Cpu;
let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
let half = rope.inv_freq.dim(1).unwrap();
let inv: Vec<f32> = rope.inv_freq.i(0).unwrap().to_vec1().unwrap();
// One token: text=10, height=3, width=7 — all distinct.
let pos = Tensor::from_vec(vec![10i64, 3, 7], (3, 1), &dev).unwrap();
let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap();
let cos_row: Vec<f32> = cos.i(0).unwrap().to_vec1().unwrap();
assert_eq!(cos_row.len(), half);
// Column 0 (temporal) → text pos 10. Column 1 (height) → 3.
// Column 2 (width) → 7.
assert!((cos_row[0] - (10.0 * inv[0]).cos()).abs() < 1e-5);
assert!((cos_row[1] - (3.0 * inv[1]).cos()).abs() < 1e-5);
assert!((cos_row[2] - (7.0 * inv[2]).cos()).abs() < 1e-5);
assert!((cos_row[3] - (10.0 * inv[3]).cos()).abs() < 1e-5);
}
} }