feat(neuron): M-RoPE Stage 1 — interleaved rope machinery + config
Parse + store mrope_section / mrope_interleaved in RopeParameters (previously accepted-but-ignored). RotaryEmbedding gains: - inv_freq + per-axis column masks (mask_t/h/w) built from mrope_section; - plain_cos_sin(pos, seq_len): narrow the precomputed tables (text/decode); - mrope_cos_sin(position_ids (3,seq)): per-axis freqs blended at the interleave columns (vision); - apply_cos_sin(q,k,cos,sin): the rope_slow application, factored out. The existing apply(q,k,offset) is retained (delegates to plain_cos_sin + apply_cos_sin) so current callers are unchanged; Stages 3–4 move cos/sin construction into the model forward and thread the 3D position ids for image tokens. Tests: masks partition the half-dim; interleave drives the right axis per column; and the load-bearing invariant — mrope_cos_sin reduces bit-for-bit to plain_cos_sin when the three axes are equal (so text inference is unchanged). Refs the MRoPE-gap diagnosis (vision spatial misread). Pure non-cuda; no behaviour change until wired. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -737,6 +737,8 @@ mod tests {
|
|||||||
rope_theta: 10000.0,
|
rope_theta: 10000.0,
|
||||||
partial_rotary_factor: 1.0,
|
partial_rotary_factor: 1.0,
|
||||||
rope_type: None,
|
rope_type: None,
|
||||||
|
mrope_section: Vec::new(),
|
||||||
|
mrope_interleaved: false,
|
||||||
},
|
},
|
||||||
rms_norm_eps: 1e-6,
|
rms_norm_eps: 1e-6,
|
||||||
tie_word_embeddings: false,
|
tie_word_embeddings: false,
|
||||||
|
|||||||
@@ -191,11 +191,12 @@ fn default_hidden_act() -> String {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
|
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
|
||||||
/// `mrope_section` and `mrope_interleaved` are accepted via the
|
///
|
||||||
/// `#[serde(default)]` flatten-tolerance below but ignored — we treat
|
/// For text-only inference the three MRoPE position grids carry
|
||||||
/// MRoPE as plain RoPE for text-only inference (the three position
|
/// identical ids, so the interleave is a no-op and plain RoPE applies.
|
||||||
/// grids carry identical ids when there's no vision input, so the
|
/// For vision inputs `mrope_section` + `mrope_interleaved` drive the
|
||||||
/// interleaving is a no-op).
|
/// per-axis (text/height/width) rotary used by image tokens — see
|
||||||
|
/// `rope.rs`.
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
pub struct RopeParameters {
|
pub struct RopeParameters {
|
||||||
/// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
|
/// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
|
||||||
@@ -211,6 +212,16 @@ pub struct RopeParameters {
|
|||||||
/// implemented here.
|
/// implemented here.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub rope_type: Option<String>,
|
pub rope_type: Option<String>,
|
||||||
|
/// MRoPE per-axis section sizes `[text, height, width]` — e.g.
|
||||||
|
/// `[11, 11, 10]` for Qwen3.6, summing to the rotary half-dim.
|
||||||
|
/// Empty for models that don't declare MRoPE (→ plain RoPE).
|
||||||
|
#[serde(default)]
|
||||||
|
pub mrope_section: Vec<usize>,
|
||||||
|
/// Whether the three MRoPE axes are interleaved per-frequency
|
||||||
|
/// (Qwen3-VL / Qwen3.6 style, `true`) rather than block-concatenated
|
||||||
|
/// (Qwen2-VL style, `false`).
|
||||||
|
#[serde(default)]
|
||||||
|
pub mrope_interleaved: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_rope_theta() -> f64 {
|
fn default_rope_theta() -> f64 {
|
||||||
|
|||||||
@@ -1,19 +1,27 @@
|
|||||||
//! Rotary position embedding for Qwen3-Next's full-attention layers.
|
//! Rotary position embedding for Qwen3-Next's full-attention layers.
|
||||||
//!
|
//!
|
||||||
//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the
|
//! Qwen3.6 declares **interleaved M-RoPE** (multimodal RoPE): the
|
||||||
//! reference Python — three position grids interleaved per
|
//! rotary half-dimension is split across three position axes —
|
||||||
//! `mrope_section`. For text-only inference all three grids carry the
|
//! `[text, height, width]` per `mrope_section` (`[11,11,10]` for
|
||||||
//! same position ids and the interleave is a no-op, so this module
|
//! Qwen3.6) — interleaved per-frequency. For **text** every token's
|
||||||
//! implements the plain (non-mrope) flavour: the standard inv_freq
|
//! three axes carry the same position id, so the interleave is a no-op
|
||||||
//! cosine/sine tables driven by `rope_theta` and `head_dim`.
|
//! and this reduces exactly to plain RoPE. For **image** tokens the
|
||||||
|
//! height/width axes carry the patch's 2D grid coordinates, which is
|
||||||
|
//! how the model reads the 14×14 patch layout (without it, all patches
|
||||||
|
//! share a height position and the image reads as vertical repetition).
|
||||||
//!
|
//!
|
||||||
//! Rotation flavour: **GLM-style** rotate-half (the second half of the
|
//! Two cos/sin builders feed a shared [`RotaryEmbedding::apply`]:
|
||||||
//! head dim is negated and swapped into the first). The reference
|
//! - [`RotaryEmbedding::plain_cos_sin`] narrows the precomputed tables
|
||||||
//! Python uses `apply_rotary_pos_emb` with `rotate_half`; candle's
|
//! at a scalar position — the text / decode fast path.
|
||||||
//! `rope_slow` is the matching helper.
|
//! - [`RotaryEmbedding::mrope_cos_sin`] builds per-token cos/sin from a
|
||||||
|
//! `(3, seq)` position-id tensor, blending the three axes' frequencies
|
||||||
|
//! at the interleave index sets — the vision-prefill path.
|
||||||
|
//!
|
||||||
|
//! Rotation flavour: **GLM-style** rotate-half (candle's `rope_slow`),
|
||||||
|
//! matching the reference Python's `apply_rotary_pos_emb` + `rotate_half`.
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use candle_core::{DType, Device, Tensor};
|
use candle_core::{DType, Device, IndexOp, Tensor};
|
||||||
|
|
||||||
use super::TextConfig;
|
use super::TextConfig;
|
||||||
|
|
||||||
@@ -21,6 +29,18 @@ use super::TextConfig;
|
|||||||
pub struct RotaryEmbedding {
|
pub struct RotaryEmbedding {
|
||||||
sin: Tensor,
|
sin: Tensor,
|
||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
|
/// Inverse frequencies, shape `(1, rotary_dim/2)`. Retained (beyond
|
||||||
|
/// the precomputed `sin`/`cos` tables) so [`Self::mrope_cos_sin`] can
|
||||||
|
/// build cos/sin from arbitrary per-axis position ids.
|
||||||
|
inv_freq: Tensor,
|
||||||
|
/// Per-axis column masks over the rotary half-dim, shape `(1, half)`,
|
||||||
|
/// f32 0/1. `mask_t + mask_h + mask_w` partitions the columns; a
|
||||||
|
/// column belongs to exactly one axis. For a non-MRoPE config
|
||||||
|
/// `mask_t` is all-ones and the others all-zero (→ plain RoPE).
|
||||||
|
mask_t: Tensor,
|
||||||
|
mask_h: Tensor,
|
||||||
|
mask_w: Tensor,
|
||||||
|
dtype: DType,
|
||||||
/// Number of dims at the head's leading edge that the rotation
|
/// Number of dims at the head's leading edge that the rotation
|
||||||
/// covers. The remaining `head_dim - rotary_dim` dims pass through
|
/// covers. The remaining `head_dim - rotary_dim` dims pass through
|
||||||
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
|
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
|
||||||
@@ -29,6 +49,52 @@ pub struct RotaryEmbedding {
|
|||||||
head_dim: usize,
|
head_dim: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build the per-axis 0/1 column masks over the rotary half-dim from
|
||||||
|
/// `mrope_section`. Returns `(temporal, height, width)` each length
|
||||||
|
/// `half`. Temporal is the complement of height ∪ width, so the three
|
||||||
|
/// masks always partition `0..half` and reduce to all-temporal (plain
|
||||||
|
/// RoPE) when no usable section is given.
|
||||||
|
fn mrope_masks(
|
||||||
|
half: usize,
|
||||||
|
section: &[usize],
|
||||||
|
interleaved: bool,
|
||||||
|
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
|
||||||
|
let mut mh = vec![0f32; half];
|
||||||
|
let mut mw = vec![0f32; half];
|
||||||
|
if section.len() == 3 {
|
||||||
|
if interleaved {
|
||||||
|
// Qwen3-VL: height at columns 1,4,7,… ; width at 2,5,8,… ;
|
||||||
|
// temporal keeps 0,3,6,… — each `take`n from `mrope_section`.
|
||||||
|
for i in (1..half).step_by(3).take(section[1]) {
|
||||||
|
mh[i] = 1.0;
|
||||||
|
}
|
||||||
|
for i in (2..half).step_by(3).take(section[2]) {
|
||||||
|
mw[i] = 1.0;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Qwen2-VL: contiguous blocks [text | height | width].
|
||||||
|
let h_start = section[0].min(half);
|
||||||
|
let h_end = (section[0] + section[1]).min(half);
|
||||||
|
for m in mh.iter_mut().take(h_end).skip(h_start) {
|
||||||
|
*m = 1.0;
|
||||||
|
}
|
||||||
|
for m in mw.iter_mut().take(half).skip(h_end) {
|
||||||
|
*m = 1.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mt: Vec<f32> = (0..half)
|
||||||
|
.map(|i| {
|
||||||
|
if mh[i] == 0.0 && mw[i] == 0.0 {
|
||||||
|
1.0
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
(mt, mh, mw)
|
||||||
|
}
|
||||||
|
|
||||||
impl RotaryEmbedding {
|
impl RotaryEmbedding {
|
||||||
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
|
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
|
||||||
let head_dim = cfg.head_dim;
|
let head_dim = cfg.head_dim;
|
||||||
@@ -52,44 +118,88 @@ impl RotaryEmbedding {
|
|||||||
.step_by(2)
|
.step_by(2)
|
||||||
.map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32)
|
.map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32)
|
||||||
.collect();
|
.collect();
|
||||||
let n = inv_freq.len();
|
let half = inv_freq.len();
|
||||||
let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?;
|
let inv_freq = Tensor::from_vec(inv_freq, (1, half), dev)?.to_dtype(DType::F32)?;
|
||||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
.to_dtype(DType::F32)?
|
.to_dtype(DType::F32)?
|
||||||
.reshape((max_seq_len, 1))?;
|
.reshape((max_seq_len, 1))?;
|
||||||
let freqs = t.matmul(&inv_freq)?;
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
|
||||||
|
// MRoPE axis masks. `sum(mrope_section)` should equal `half`;
|
||||||
|
// warn-tolerant: any shortfall just stays on the temporal axis.
|
||||||
|
let (mt, mh, mw) = mrope_masks(half, &rope.mrope_section, rope.mrope_interleaved);
|
||||||
|
let mask_t = Tensor::from_vec(mt, (1, half), dev)?;
|
||||||
|
let mask_h = Tensor::from_vec(mh, (1, half), dev)?;
|
||||||
|
let mask_w = Tensor::from_vec(mw, (1, half), dev)?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
sin: freqs.sin()?.to_dtype(dtype)?,
|
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||||
cos: freqs.cos()?.to_dtype(dtype)?,
|
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||||
|
inv_freq,
|
||||||
|
mask_t,
|
||||||
|
mask_h,
|
||||||
|
mask_w,
|
||||||
|
dtype,
|
||||||
rotary_dim,
|
rotary_dim,
|
||||||
head_dim,
|
head_dim,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Apply RoPE to q, k.
|
/// cos/sin for a contiguous run of `seq_len` positions starting at
|
||||||
///
|
/// `pos`, by narrowing the precomputed tables. The text / decode
|
||||||
/// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index
|
/// path (all three MRoPE axes equal → plain RoPE). Shape
|
||||||
/// into the cached cos/sin table — the position of the first token
|
/// `(seq_len, rotary_dim/2)`.
|
||||||
/// in the current step.
|
pub fn plain_cos_sin(
|
||||||
///
|
&self,
|
||||||
/// When `rotary_dim < head_dim` the rotation is applied only to the
|
pos: usize,
|
||||||
/// first `rotary_dim` dims of each head; the tail passes through
|
seq_len: usize,
|
||||||
/// unchanged (matches the reference Python's
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
/// `apply_rotary_pos_emb` with non-trivial `partial_rotary_factor`).
|
let cos = self.cos.narrow(0, pos, seq_len)?;
|
||||||
pub fn apply(
|
let sin = self.sin.narrow(0, pos, seq_len)?;
|
||||||
|
Ok((cos, sin))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// cos/sin from explicit per-token 3D position ids, shape
|
||||||
|
/// `(3, seq_len)` (axes: text, height, width). Builds each axis's
|
||||||
|
/// frequencies and blends them at the interleave index sets, so
|
||||||
|
/// every rotary frequency slot is driven by exactly one axis.
|
||||||
|
/// Reduces exactly to [`Self::plain_cos_sin`] when the three axes are
|
||||||
|
/// equal. Returns cos/sin of shape `(seq_len, rotary_dim/2)`.
|
||||||
|
pub fn mrope_cos_sin(&self, position_ids: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let pos = position_ids.to_dtype(DType::F32)?;
|
||||||
|
let (axes, seq_len) = pos.dims2()?;
|
||||||
|
debug_assert_eq!(axes, 3, "mrope position_ids must have 3 axes");
|
||||||
|
// Per-axis freqs: pos[a] (seq,1) @ inv_freq (1,half) → (seq,half).
|
||||||
|
let ft = pos.i(0)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
|
||||||
|
let fh = pos.i(1)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
|
||||||
|
let fw = pos.i(2)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
|
||||||
|
// Blend: each column belongs to exactly one axis (masks partition
|
||||||
|
// the half-dim), so this picks the right axis per frequency slot.
|
||||||
|
let blended = ft
|
||||||
|
.broadcast_mul(&self.mask_t)?
|
||||||
|
.add(&fh.broadcast_mul(&self.mask_h)?)?
|
||||||
|
.add(&fw.broadcast_mul(&self.mask_w)?)?;
|
||||||
|
let cos = blended.cos()?.to_dtype(self.dtype)?;
|
||||||
|
let sin = blended.sin()?.to_dtype(self.dtype)?;
|
||||||
|
Ok((cos, sin))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply rotary to `q`, `k` (shape `(B, H, L, head_dim)`) using
|
||||||
|
/// precomputed `cos`/`sin` of shape `(L, rotary_dim/2)`. Partial
|
||||||
|
/// rotary: only the first `rotary_dim` dims rotate; the tail passes
|
||||||
|
/// through unchanged.
|
||||||
|
pub fn apply_cos_sin(
|
||||||
&self,
|
&self,
|
||||||
q: &Tensor,
|
q: &Tensor,
|
||||||
k: &Tensor,
|
k: &Tensor,
|
||||||
offset: usize,
|
cos: &Tensor,
|
||||||
|
sin: &Tensor,
|
||||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
let (_, _, seq_len, head_dim_in) = q.dims4()?;
|
let (_, _, _seq_len, head_dim_in) = q.dims4()?;
|
||||||
debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");
|
debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");
|
||||||
let cos = self.cos.narrow(0, offset, seq_len)?;
|
|
||||||
let sin = self.sin.narrow(0, offset, seq_len)?;
|
|
||||||
if self.rotary_dim == self.head_dim {
|
if self.rotary_dim == self.head_dim {
|
||||||
// Full rotation.
|
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, cos, sin)?;
|
||||||
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?;
|
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, cos, sin)?;
|
||||||
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?;
|
|
||||||
Ok((q_embed, k_embed))
|
Ok((q_embed, k_embed))
|
||||||
} else {
|
} else {
|
||||||
// Partial rotation: narrow → rotate → cat the untouched tail.
|
// Partial rotation: narrow → rotate → cat the untouched tail.
|
||||||
@@ -102,8 +212,8 @@ impl RotaryEmbedding {
|
|||||||
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
|
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
|
||||||
.contiguous()?;
|
.contiguous()?;
|
||||||
let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
|
let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
|
||||||
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, &cos, &sin)?;
|
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, cos, sin)?;
|
||||||
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, &cos, &sin)?;
|
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, cos, sin)?;
|
||||||
let q_embed =
|
let q_embed =
|
||||||
Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?;
|
Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?;
|
||||||
let k_embed =
|
let k_embed =
|
||||||
@@ -111,4 +221,120 @@ impl RotaryEmbedding {
|
|||||||
Ok((q_embed, k_embed))
|
Ok((q_embed, k_embed))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Text/decode convenience: build plain cos/sin for a scalar offset
|
||||||
|
/// and apply in one call. The current call sites use this; Stages 3–4
|
||||||
|
/// move cos/sin construction up into the model forward (computed once
|
||||||
|
/// per forward) and call [`Self::apply_cos_sin`] directly.
|
||||||
|
pub fn apply(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let (_, _, seq_len, _) = q.dims4()?;
|
||||||
|
let (cos, sin) = self.plain_cos_sin(offset, seq_len)?;
|
||||||
|
self.apply_cos_sin(q, k, &cos, &sin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use candle_core::IndexOp;
|
||||||
|
|
||||||
|
/// A TextConfig stub with Qwen3.6's rope params (head_dim 256,
|
||||||
|
/// partial 0.25 → rotary_dim 64 → half 32; section [11,11,10]).
|
||||||
|
fn qwen36_cfg() -> TextConfig {
|
||||||
|
serde_json::from_value(serde_json::json!({
|
||||||
|
"hidden_size": 5120,
|
||||||
|
"num_hidden_layers": 1,
|
||||||
|
"num_attention_heads": 64,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"head_dim": 256,
|
||||||
|
"intermediate_size": 1,
|
||||||
|
"vocab_size": 10,
|
||||||
|
"rms_norm_eps": 1e-6,
|
||||||
|
"max_position_embeddings": 64,
|
||||||
|
"layer_types": ["full_attention"],
|
||||||
|
"rope_parameters": {
|
||||||
|
"rope_theta": 10000000.0,
|
||||||
|
"partial_rotary_factor": 0.25,
|
||||||
|
"mrope_section": [11, 11, 10],
|
||||||
|
"mrope_interleaved": true
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.expect("cfg")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn mrope_masks_partition_the_half_dim() {
|
||||||
|
let (mt, mh, mw) = mrope_masks(32, &[11, 11, 10], true);
|
||||||
|
// Each column belongs to exactly one axis.
|
||||||
|
for i in 0..32 {
|
||||||
|
let s = mt[i] + mh[i] + mw[i];
|
||||||
|
assert_eq!(s, 1.0, "column {i} covered {s} times");
|
||||||
|
}
|
||||||
|
assert_eq!(mt.iter().sum::<f32>(), 11.0);
|
||||||
|
assert_eq!(mh.iter().sum::<f32>(), 11.0);
|
||||||
|
assert_eq!(mw.iter().sum::<f32>(), 10.0);
|
||||||
|
// Interleave: temporal 0,3,…; height 1,4,…; width 2,5,…
|
||||||
|
assert_eq!(mt[0], 1.0);
|
||||||
|
assert_eq!(mh[1], 1.0);
|
||||||
|
assert_eq!(mw[2], 1.0);
|
||||||
|
assert_eq!(mt[3], 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The load-bearing invariant: when all three position axes are
|
||||||
|
/// equal (text), `mrope_cos_sin` must reproduce `plain_cos_sin`
|
||||||
|
/// bit-for-bit — i.e. M-RoPE is a no-op for text, so text inference
|
||||||
|
/// is unchanged.
|
||||||
|
#[test]
|
||||||
|
fn mrope_reduces_to_plain_for_equal_axes() {
|
||||||
|
let dev = Device::Cpu;
|
||||||
|
let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
|
||||||
|
|
||||||
|
// positions 5,6,7 on all three axes.
|
||||||
|
let base: Vec<i64> = vec![5, 6, 7];
|
||||||
|
let pos =
|
||||||
|
Tensor::from_vec([base.clone(), base.clone(), base].concat(), (3, 3), &dev).unwrap();
|
||||||
|
|
||||||
|
let (mc, ms) = rope.mrope_cos_sin(&pos).unwrap();
|
||||||
|
let (pc, ps) = rope.plain_cos_sin(5, 3).unwrap();
|
||||||
|
|
||||||
|
let dcos = (mc - pc).unwrap().abs().unwrap().max_all().unwrap();
|
||||||
|
let dsin = (ms - ps).unwrap().abs().unwrap().max_all().unwrap();
|
||||||
|
assert!(
|
||||||
|
dcos.to_scalar::<f32>().unwrap() < 1e-6,
|
||||||
|
"cos mismatch {dcos:?}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
dsin.to_scalar::<f32>().unwrap() < 1e-6,
|
||||||
|
"sin mismatch {dsin:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hand-checked interleave: a width-axis column (index 2) must track
|
||||||
|
/// the WIDTH position, while a temporal column (index 0) tracks the
|
||||||
|
/// TEXT position, even when the axes differ.
|
||||||
|
#[test]
|
||||||
|
fn mrope_blends_axes_at_interleave_columns() {
|
||||||
|
let dev = Device::Cpu;
|
||||||
|
let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
|
||||||
|
let half = rope.inv_freq.dim(1).unwrap();
|
||||||
|
let inv: Vec<f32> = rope.inv_freq.i(0).unwrap().to_vec1().unwrap();
|
||||||
|
|
||||||
|
// One token: text=10, height=3, width=7 — all distinct.
|
||||||
|
let pos = Tensor::from_vec(vec![10i64, 3, 7], (3, 1), &dev).unwrap();
|
||||||
|
let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap();
|
||||||
|
let cos_row: Vec<f32> = cos.i(0).unwrap().to_vec1().unwrap();
|
||||||
|
assert_eq!(cos_row.len(), half);
|
||||||
|
|
||||||
|
// Column 0 (temporal) → text pos 10. Column 1 (height) → 3.
|
||||||
|
// Column 2 (width) → 7.
|
||||||
|
assert!((cos_row[0] - (10.0 * inv[0]).cos()).abs() < 1e-5);
|
||||||
|
assert!((cos_row[1] - (3.0 * inv[1]).cos()).abs() < 1e-5);
|
||||||
|
assert!((cos_row[2] - (7.0 * inv[2]).cos()).abs() < 1e-5);
|
||||||
|
assert!((cos_row[3] - (10.0 * inv[3]).cos()).abs() < 1e-5);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user