feat(neuron): M-RoPE Stage 1 — interleaved rope machinery + config
Parse + store mrope_section / mrope_interleaved in RopeParameters (previously accepted-but-ignored). RotaryEmbedding gains: - inv_freq + per-axis column masks (mask_t/h/w) built from mrope_section; - plain_cos_sin(pos, seq_len): narrow the precomputed tables (text/decode); - mrope_cos_sin(position_ids (3,seq)): per-axis freqs blended at the interleave columns (vision); - apply_cos_sin(q,k,cos,sin): the rope_slow application, factored out. The existing apply(q,k,offset) is retained (delegates to plain_cos_sin + apply_cos_sin) so current callers are unchanged; Stages 3–4 move cos/sin construction into the model forward and thread the 3D position ids for image tokens. Tests: masks partition the half-dim; interleave drives the right axis per column; and the load-bearing invariant — mrope_cos_sin reduces bit-for-bit to plain_cos_sin when the three axes are equal (so text inference is unchanged). Refs the MRoPE-gap diagnosis (vision spatial misread). Pure non-cuda; no behaviour change until wired. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -737,6 +737,8 @@ mod tests {
|
||||
rope_theta: 10000.0,
|
||||
partial_rotary_factor: 1.0,
|
||||
rope_type: None,
|
||||
mrope_section: Vec::new(),
|
||||
mrope_interleaved: false,
|
||||
},
|
||||
rms_norm_eps: 1e-6,
|
||||
tie_word_embeddings: false,
|
||||
|
||||
@@ -191,11 +191,12 @@ fn default_hidden_act() -> String {
|
||||
}
|
||||
|
||||
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
|
||||
/// `mrope_section` and `mrope_interleaved` are accepted via the
|
||||
/// `#[serde(default)]` flatten-tolerance below but ignored — we treat
|
||||
/// MRoPE as plain RoPE for text-only inference (the three position
|
||||
/// grids carry identical ids when there's no vision input, so the
|
||||
/// interleaving is a no-op).
|
||||
///
|
||||
/// For text-only inference the three MRoPE position grids carry
|
||||
/// identical ids, so the interleave is a no-op and plain RoPE applies.
|
||||
/// For vision inputs `mrope_section` + `mrope_interleaved` drive the
|
||||
/// per-axis (text/height/width) rotary used by image tokens — see
|
||||
/// `rope.rs`.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RopeParameters {
|
||||
/// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
|
||||
@@ -211,6 +212,16 @@ pub struct RopeParameters {
|
||||
/// implemented here.
|
||||
#[serde(default)]
|
||||
pub rope_type: Option<String>,
|
||||
/// MRoPE per-axis section sizes `[text, height, width]` — e.g.
|
||||
/// `[11, 11, 10]` for Qwen3.6, summing to the rotary half-dim.
|
||||
/// Empty for models that don't declare MRoPE (→ plain RoPE).
|
||||
#[serde(default)]
|
||||
pub mrope_section: Vec<usize>,
|
||||
/// Whether the three MRoPE axes are interleaved per-frequency
|
||||
/// (Qwen3-VL / Qwen3.6 style, `true`) rather than block-concatenated
|
||||
/// (Qwen2-VL style, `false`).
|
||||
#[serde(default)]
|
||||
pub mrope_interleaved: bool,
|
||||
}
|
||||
|
||||
fn default_rope_theta() -> f64 {
|
||||
|
||||
@@ -1,19 +1,27 @@
|
||||
//! Rotary position embedding for Qwen3-Next's full-attention layers.
|
||||
//!
|
||||
//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the
|
||||
//! reference Python — three position grids interleaved per
|
||||
//! `mrope_section`. For text-only inference all three grids carry the
|
||||
//! same position ids and the interleave is a no-op, so this module
|
||||
//! implements the plain (non-mrope) flavour: the standard inv_freq
|
||||
//! cosine/sine tables driven by `rope_theta` and `head_dim`.
|
||||
//! Qwen3.6 declares **interleaved M-RoPE** (multimodal RoPE): the
|
||||
//! rotary half-dimension is split across three position axes —
|
||||
//! `[text, height, width]` per `mrope_section` (`[11,11,10]` for
|
||||
//! Qwen3.6) — interleaved per-frequency. For **text** every token's
|
||||
//! three axes carry the same position id, so the interleave is a no-op
|
||||
//! and this reduces exactly to plain RoPE. For **image** tokens the
|
||||
//! height/width axes carry the patch's 2D grid coordinates, which is
|
||||
//! how the model reads the 14×14 patch layout (without it, all patches
|
||||
//! share a height position and the image reads as vertical repetition).
|
||||
//!
|
||||
//! Rotation flavour: **GLM-style** rotate-half (the second half of the
|
||||
//! head dim is negated and swapped into the first). The reference
|
||||
//! Python uses `apply_rotary_pos_emb` with `rotate_half`; candle's
|
||||
//! `rope_slow` is the matching helper.
|
||||
//! Two cos/sin builders feed a shared [`RotaryEmbedding::apply`]:
|
||||
//! - [`RotaryEmbedding::plain_cos_sin`] narrows the precomputed tables
|
||||
//! at a scalar position — the text / decode fast path.
|
||||
//! - [`RotaryEmbedding::mrope_cos_sin`] builds per-token cos/sin from a
|
||||
//! `(3, seq)` position-id tensor, blending the three axes' frequencies
|
||||
//! at the interleave index sets — the vision-prefill path.
|
||||
//!
|
||||
//! Rotation flavour: **GLM-style** rotate-half (candle's `rope_slow`),
|
||||
//! matching the reference Python's `apply_rotary_pos_emb` + `rotate_half`.
|
||||
|
||||
use anyhow::Result;
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use candle_core::{DType, Device, IndexOp, Tensor};
|
||||
|
||||
use super::TextConfig;
|
||||
|
||||
@@ -21,6 +29,18 @@ use super::TextConfig;
|
||||
pub struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
/// Inverse frequencies, shape `(1, rotary_dim/2)`. Retained (beyond
|
||||
/// the precomputed `sin`/`cos` tables) so [`Self::mrope_cos_sin`] can
|
||||
/// build cos/sin from arbitrary per-axis position ids.
|
||||
inv_freq: Tensor,
|
||||
/// Per-axis column masks over the rotary half-dim, shape `(1, half)`,
|
||||
/// f32 0/1. `mask_t + mask_h + mask_w` partitions the columns; a
|
||||
/// column belongs to exactly one axis. For a non-MRoPE config
|
||||
/// `mask_t` is all-ones and the others all-zero (→ plain RoPE).
|
||||
mask_t: Tensor,
|
||||
mask_h: Tensor,
|
||||
mask_w: Tensor,
|
||||
dtype: DType,
|
||||
/// Number of dims at the head's leading edge that the rotation
|
||||
/// covers. The remaining `head_dim - rotary_dim` dims pass through
|
||||
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
|
||||
@@ -29,6 +49,52 @@ pub struct RotaryEmbedding {
|
||||
head_dim: usize,
|
||||
}
|
||||
|
||||
/// Build the per-axis 0/1 column masks over the rotary half-dim from
|
||||
/// `mrope_section`. Returns `(temporal, height, width)` each length
|
||||
/// `half`. Temporal is the complement of height ∪ width, so the three
|
||||
/// masks always partition `0..half` and reduce to all-temporal (plain
|
||||
/// RoPE) when no usable section is given.
|
||||
fn mrope_masks(
|
||||
half: usize,
|
||||
section: &[usize],
|
||||
interleaved: bool,
|
||||
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
|
||||
let mut mh = vec![0f32; half];
|
||||
let mut mw = vec![0f32; half];
|
||||
if section.len() == 3 {
|
||||
if interleaved {
|
||||
// Qwen3-VL: height at columns 1,4,7,… ; width at 2,5,8,… ;
|
||||
// temporal keeps 0,3,6,… — each `take`n from `mrope_section`.
|
||||
for i in (1..half).step_by(3).take(section[1]) {
|
||||
mh[i] = 1.0;
|
||||
}
|
||||
for i in (2..half).step_by(3).take(section[2]) {
|
||||
mw[i] = 1.0;
|
||||
}
|
||||
} else {
|
||||
// Qwen2-VL: contiguous blocks [text | height | width].
|
||||
let h_start = section[0].min(half);
|
||||
let h_end = (section[0] + section[1]).min(half);
|
||||
for m in mh.iter_mut().take(h_end).skip(h_start) {
|
||||
*m = 1.0;
|
||||
}
|
||||
for m in mw.iter_mut().take(half).skip(h_end) {
|
||||
*m = 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
let mt: Vec<f32> = (0..half)
|
||||
.map(|i| {
|
||||
if mh[i] == 0.0 && mw[i] == 0.0 {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
(mt, mh, mw)
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
|
||||
let head_dim = cfg.head_dim;
|
||||
@@ -52,44 +118,88 @@ impl RotaryEmbedding {
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32)
|
||||
.collect();
|
||||
let n = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?;
|
||||
let half = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, half), dev)?.to_dtype(DType::F32)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
|
||||
// MRoPE axis masks. `sum(mrope_section)` should equal `half`;
|
||||
// warn-tolerant: any shortfall just stays on the temporal axis.
|
||||
let (mt, mh, mw) = mrope_masks(half, &rope.mrope_section, rope.mrope_interleaved);
|
||||
let mask_t = Tensor::from_vec(mt, (1, half), dev)?;
|
||||
let mask_h = Tensor::from_vec(mh, (1, half), dev)?;
|
||||
let mask_w = Tensor::from_vec(mw, (1, half), dev)?;
|
||||
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||
inv_freq,
|
||||
mask_t,
|
||||
mask_h,
|
||||
mask_w,
|
||||
dtype,
|
||||
rotary_dim,
|
||||
head_dim,
|
||||
})
|
||||
}
|
||||
|
||||
/// Apply RoPE to q, k.
|
||||
///
|
||||
/// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index
|
||||
/// into the cached cos/sin table — the position of the first token
|
||||
/// in the current step.
|
||||
///
|
||||
/// When `rotary_dim < head_dim` the rotation is applied only to the
|
||||
/// first `rotary_dim` dims of each head; the tail passes through
|
||||
/// unchanged (matches the reference Python's
|
||||
/// `apply_rotary_pos_emb` with non-trivial `partial_rotary_factor`).
|
||||
pub fn apply(
|
||||
/// cos/sin for a contiguous run of `seq_len` positions starting at
|
||||
/// `pos`, by narrowing the precomputed tables. The text / decode
|
||||
/// path (all three MRoPE axes equal → plain RoPE). Shape
|
||||
/// `(seq_len, rotary_dim/2)`.
|
||||
pub fn plain_cos_sin(
|
||||
&self,
|
||||
pos: usize,
|
||||
seq_len: usize,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let cos = self.cos.narrow(0, pos, seq_len)?;
|
||||
let sin = self.sin.narrow(0, pos, seq_len)?;
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
/// cos/sin from explicit per-token 3D position ids, shape
|
||||
/// `(3, seq_len)` (axes: text, height, width). Builds each axis's
|
||||
/// frequencies and blends them at the interleave index sets, so
|
||||
/// every rotary frequency slot is driven by exactly one axis.
|
||||
/// Reduces exactly to [`Self::plain_cos_sin`] when the three axes are
|
||||
/// equal. Returns cos/sin of shape `(seq_len, rotary_dim/2)`.
|
||||
pub fn mrope_cos_sin(&self, position_ids: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let pos = position_ids.to_dtype(DType::F32)?;
|
||||
let (axes, seq_len) = pos.dims2()?;
|
||||
debug_assert_eq!(axes, 3, "mrope position_ids must have 3 axes");
|
||||
// Per-axis freqs: pos[a] (seq,1) @ inv_freq (1,half) → (seq,half).
|
||||
let ft = pos.i(0)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
|
||||
let fh = pos.i(1)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
|
||||
let fw = pos.i(2)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
|
||||
// Blend: each column belongs to exactly one axis (masks partition
|
||||
// the half-dim), so this picks the right axis per frequency slot.
|
||||
let blended = ft
|
||||
.broadcast_mul(&self.mask_t)?
|
||||
.add(&fh.broadcast_mul(&self.mask_h)?)?
|
||||
.add(&fw.broadcast_mul(&self.mask_w)?)?;
|
||||
let cos = blended.cos()?.to_dtype(self.dtype)?;
|
||||
let sin = blended.sin()?.to_dtype(self.dtype)?;
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
/// Apply rotary to `q`, `k` (shape `(B, H, L, head_dim)`) using
|
||||
/// precomputed `cos`/`sin` of shape `(L, rotary_dim/2)`. Partial
|
||||
/// rotary: only the first `rotary_dim` dims rotate; the tail passes
|
||||
/// through unchanged.
|
||||
pub fn apply_cos_sin(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
offset: usize,
|
||||
cos: &Tensor,
|
||||
sin: &Tensor,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let (_, _, seq_len, head_dim_in) = q.dims4()?;
|
||||
let (_, _, _seq_len, head_dim_in) = q.dims4()?;
|
||||
debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");
|
||||
let cos = self.cos.narrow(0, offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, offset, seq_len)?;
|
||||
if self.rotary_dim == self.head_dim {
|
||||
// Full rotation.
|
||||
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, cos, sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, cos, sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
} else {
|
||||
// Partial rotation: narrow → rotate → cat the untouched tail.
|
||||
@@ -102,8 +212,8 @@ impl RotaryEmbedding {
|
||||
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
|
||||
.contiguous()?;
|
||||
let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
|
||||
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, &cos, &sin)?;
|
||||
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, &cos, &sin)?;
|
||||
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, cos, sin)?;
|
||||
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, cos, sin)?;
|
||||
let q_embed =
|
||||
Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?;
|
||||
let k_embed =
|
||||
@@ -111,4 +221,120 @@ impl RotaryEmbedding {
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
/// Text/decode convenience: build plain cos/sin for a scalar offset
|
||||
/// and apply in one call. The current call sites use this; Stages 3–4
|
||||
/// move cos/sin construction up into the model forward (computed once
|
||||
/// per forward) and call [`Self::apply_cos_sin`] directly.
|
||||
pub fn apply(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
offset: usize,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let (_, _, seq_len, _) = q.dims4()?;
|
||||
let (cos, sin) = self.plain_cos_sin(offset, seq_len)?;
|
||||
self.apply_cos_sin(q, k, &cos, &sin)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use candle_core::IndexOp;
|
||||
|
||||
/// A TextConfig stub with Qwen3.6's rope params (head_dim 256,
|
||||
/// partial 0.25 → rotary_dim 64 → half 32; section [11,11,10]).
|
||||
fn qwen36_cfg() -> TextConfig {
|
||||
serde_json::from_value(serde_json::json!({
|
||||
"hidden_size": 5120,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 64,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 256,
|
||||
"intermediate_size": 1,
|
||||
"vocab_size": 10,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"max_position_embeddings": 64,
|
||||
"layer_types": ["full_attention"],
|
||||
"rope_parameters": {
|
||||
"rope_theta": 10000000.0,
|
||||
"partial_rotary_factor": 0.25,
|
||||
"mrope_section": [11, 11, 10],
|
||||
"mrope_interleaved": true
|
||||
}
|
||||
}))
|
||||
.expect("cfg")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mrope_masks_partition_the_half_dim() {
|
||||
let (mt, mh, mw) = mrope_masks(32, &[11, 11, 10], true);
|
||||
// Each column belongs to exactly one axis.
|
||||
for i in 0..32 {
|
||||
let s = mt[i] + mh[i] + mw[i];
|
||||
assert_eq!(s, 1.0, "column {i} covered {s} times");
|
||||
}
|
||||
assert_eq!(mt.iter().sum::<f32>(), 11.0);
|
||||
assert_eq!(mh.iter().sum::<f32>(), 11.0);
|
||||
assert_eq!(mw.iter().sum::<f32>(), 10.0);
|
||||
// Interleave: temporal 0,3,…; height 1,4,…; width 2,5,…
|
||||
assert_eq!(mt[0], 1.0);
|
||||
assert_eq!(mh[1], 1.0);
|
||||
assert_eq!(mw[2], 1.0);
|
||||
assert_eq!(mt[3], 1.0);
|
||||
}
|
||||
|
||||
/// The load-bearing invariant: when all three position axes are
|
||||
/// equal (text), `mrope_cos_sin` must reproduce `plain_cos_sin`
|
||||
/// bit-for-bit — i.e. M-RoPE is a no-op for text, so text inference
|
||||
/// is unchanged.
|
||||
#[test]
|
||||
fn mrope_reduces_to_plain_for_equal_axes() {
|
||||
let dev = Device::Cpu;
|
||||
let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
|
||||
|
||||
// positions 5,6,7 on all three axes.
|
||||
let base: Vec<i64> = vec![5, 6, 7];
|
||||
let pos =
|
||||
Tensor::from_vec([base.clone(), base.clone(), base].concat(), (3, 3), &dev).unwrap();
|
||||
|
||||
let (mc, ms) = rope.mrope_cos_sin(&pos).unwrap();
|
||||
let (pc, ps) = rope.plain_cos_sin(5, 3).unwrap();
|
||||
|
||||
let dcos = (mc - pc).unwrap().abs().unwrap().max_all().unwrap();
|
||||
let dsin = (ms - ps).unwrap().abs().unwrap().max_all().unwrap();
|
||||
assert!(
|
||||
dcos.to_scalar::<f32>().unwrap() < 1e-6,
|
||||
"cos mismatch {dcos:?}"
|
||||
);
|
||||
assert!(
|
||||
dsin.to_scalar::<f32>().unwrap() < 1e-6,
|
||||
"sin mismatch {dsin:?}"
|
||||
);
|
||||
}
|
||||
|
||||
/// Hand-checked interleave: a width-axis column (index 2) must track
|
||||
/// the WIDTH position, while a temporal column (index 0) tracks the
|
||||
/// TEXT position, even when the axes differ.
|
||||
#[test]
|
||||
fn mrope_blends_axes_at_interleave_columns() {
|
||||
let dev = Device::Cpu;
|
||||
let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
|
||||
let half = rope.inv_freq.dim(1).unwrap();
|
||||
let inv: Vec<f32> = rope.inv_freq.i(0).unwrap().to_vec1().unwrap();
|
||||
|
||||
// One token: text=10, height=3, width=7 — all distinct.
|
||||
let pos = Tensor::from_vec(vec![10i64, 3, 7], (3, 1), &dev).unwrap();
|
||||
let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap();
|
||||
let cos_row: Vec<f32> = cos.i(0).unwrap().to_vec1().unwrap();
|
||||
assert_eq!(cos_row.len(), half);
|
||||
|
||||
// Column 0 (temporal) → text pos 10. Column 1 (height) → 3.
|
||||
// Column 2 (width) → 7.
|
||||
assert!((cos_row[0] - (10.0 * inv[0]).cos()).abs() < 1e-5);
|
||||
assert!((cos_row[1] - (3.0 * inv[1]).cos()).abs() < 1e-5);
|
||||
assert!((cos_row[2] - (7.0 * inv[2]).cos()).abs() < 1e-5);
|
||||
assert!((cos_row[3] - (10.0 * inv[3]).cos()).abs() < 1e-5);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user