feat(neuron): M-RoPE Stage 2 — get_rope_index position-id helper
Pure function computing the interleaved-M-RoPE 3D position ids for a prompt with image-placeholder runs, plus the decode rope_delta: text tokens advance a single counter (all axes equal); each image run gets [base+t, base+h, base+w] row-major over a square grid_t=1, grid_h=grid_w=isqrt(run) (196 → 14×14); the counter resumes from base + max(grid). rope_delta = final_counter - seq_len lets decode resume text positions after the position-compressed image blocks. Plus mrope_position_tensor to build the (3, seq) tensor. Unit tests: text-only is sequential (delta 0); text+image+text matches hand-computed grid ids + resume + delta; 196 → 14×14; non-square run rejected; end-to-end through mrope_cos_sin tracks the height axis. #[allow(dead_code)] until Stage 3/4 wire it into the forward. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -238,6 +238,96 @@ impl RotaryEmbedding {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compute interleaved-M-RoPE 3D position ids for a full prompt that may
|
||||||
|
/// contain image-placeholder runs, plus the decode `rope_delta`.
|
||||||
|
///
|
||||||
|
/// Mirrors the reference `get_rope_index`:
|
||||||
|
/// - text tokens advance a single running counter `c`, all three axes
|
||||||
|
/// equal (`[c, c, c]`);
|
||||||
|
/// - each contiguous run of `image_token_id` is one image; its tokens get
|
||||||
|
/// `[base + t, base + h, base + w]` in row-major (t outer, h, w inner),
|
||||||
|
/// where `base` is the counter at the run's start; after the run the
|
||||||
|
/// counter resumes from `base + max(grid_t, grid_h, grid_w)`.
|
||||||
|
///
|
||||||
|
/// Returns `(text_pos, height_pos, width_pos, rope_delta)`, each pos `Vec`
|
||||||
|
/// length `input_ids.len()`. `rope_delta = final_counter - seq_len`: add it
|
||||||
|
/// to a plain decode offset so text resumes from the counter after the
|
||||||
|
/// (position-compressed) image blocks.
|
||||||
|
///
|
||||||
|
/// Fixed-resolution assumption (Stage C): each image run is a perfect
|
||||||
|
/// square with `grid_t = 1` (still image) and `grid_h = grid_w =
|
||||||
|
/// isqrt(run_len)` — 196 → 14×14. Dynamic resolution (#14) would thread
|
||||||
|
/// real per-image grids instead.
|
||||||
|
// Wired into the forward path in Stage 3 (single-GPU) / Stage 4 (TP);
|
||||||
|
// exercised by unit tests until then.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub(crate) fn get_rope_index(input_ids: &[u32], image_token_id: u32) -> Result<MRopeIndex> {
|
||||||
|
let n = input_ids.len();
|
||||||
|
let mut text = Vec::with_capacity(n);
|
||||||
|
let mut height = Vec::with_capacity(n);
|
||||||
|
let mut width = Vec::with_capacity(n);
|
||||||
|
let mut counter: i64 = 0;
|
||||||
|
let mut i = 0;
|
||||||
|
while i < n {
|
||||||
|
if input_ids[i] == image_token_id {
|
||||||
|
let start = i;
|
||||||
|
while i < n && input_ids[i] == image_token_id {
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
let run = i - start;
|
||||||
|
let g = run.isqrt();
|
||||||
|
if g * g != run {
|
||||||
|
anyhow::bail!(
|
||||||
|
"get_rope_index: image run length {run} is not a perfect square \
|
||||||
|
(fixed-resolution Stage C assumes a square grid; dynamic resolution is #14)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let (grid_t, grid_h, grid_w) = (1usize, g, g);
|
||||||
|
let base = counter;
|
||||||
|
for tt in 0..grid_t {
|
||||||
|
for hh in 0..grid_h {
|
||||||
|
for ww in 0..grid_w {
|
||||||
|
text.push(base + tt as i64);
|
||||||
|
height.push(base + hh as i64);
|
||||||
|
width.push(base + ww as i64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
counter = base + grid_t.max(grid_h).max(grid_w) as i64;
|
||||||
|
} else {
|
||||||
|
text.push(counter);
|
||||||
|
height.push(counter);
|
||||||
|
width.push(counter);
|
||||||
|
counter += 1;
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let delta = counter - n as i64;
|
||||||
|
Ok((text, height, width, delta))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `(text_pos, height_pos, width_pos, rope_delta)` returned by
|
||||||
|
/// [`get_rope_index`]; the three vectors combine into the `(3, seq)`
|
||||||
|
/// MRoPE position-id tensor.
|
||||||
|
pub(crate) type MRopeIndex = (Vec<i64>, Vec<i64>, Vec<i64>, i64);
|
||||||
|
|
||||||
|
/// Build the `(3, seq)` position-id tensor consumed by
|
||||||
|
/// [`RotaryEmbedding::mrope_cos_sin`] from the three axis vectors.
|
||||||
|
#[allow(dead_code)] // wired in Stage 3/4
|
||||||
|
pub(crate) fn mrope_position_tensor(
|
||||||
|
text: &[i64],
|
||||||
|
height: &[i64],
|
||||||
|
width: &[i64],
|
||||||
|
dev: &Device,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let seq = text.len();
|
||||||
|
let mut flat = Vec::with_capacity(3 * seq);
|
||||||
|
flat.extend_from_slice(text);
|
||||||
|
flat.extend_from_slice(height);
|
||||||
|
flat.extend_from_slice(width);
|
||||||
|
Tensor::from_vec(flat, (3, seq), dev)
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -337,4 +427,72 @@ mod tests {
|
|||||||
assert!((cos_row[2] - (7.0 * inv[2]).cos()).abs() < 1e-5);
|
assert!((cos_row[2] - (7.0 * inv[2]).cos()).abs() < 1e-5);
|
||||||
assert!((cos_row[3] - (10.0 * inv[3]).cos()).abs() < 1e-5);
|
assert!((cos_row[3] - (10.0 * inv[3]).cos()).abs() < 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn get_rope_index_text_only_is_sequential() {
|
||||||
|
let (t, h, w, delta) = get_rope_index(&[1, 2, 3, 4], 99).unwrap();
|
||||||
|
assert_eq!(t, vec![0, 1, 2, 3]);
|
||||||
|
assert_eq!(h, vec![0, 1, 2, 3]);
|
||||||
|
assert_eq!(w, vec![0, 1, 2, 3]);
|
||||||
|
assert_eq!(delta, 0, "no image → delta 0 → plain decode positions");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn get_rope_index_text_image_text() {
|
||||||
|
// [text, image(2x2 run of 4), text]. image_token = 99.
|
||||||
|
let ids = [1u32, 99, 99, 99, 99, 2];
|
||||||
|
let (t, h, w, delta) = get_rope_index(&ids, 99).unwrap();
|
||||||
|
// token 0: text → 0. image base=1, grid 1x2x2:
|
||||||
|
// t all = 1; h = base+row = [1,1,2,2]; w = base+col = [1,2,1,2].
|
||||||
|
// resume from base + max(1,2,2) = 3. trailing text → 3.
|
||||||
|
assert_eq!(t, vec![0, 1, 1, 1, 1, 3]);
|
||||||
|
assert_eq!(h, vec![0, 1, 1, 2, 2, 3]);
|
||||||
|
assert_eq!(w, vec![0, 1, 2, 1, 2, 3]);
|
||||||
|
// final counter = 4, seq_len = 6 → delta = -2 (the 4 image tokens
|
||||||
|
// advanced the counter by only 2).
|
||||||
|
assert_eq!(delta, -2);
|
||||||
|
// Decode after the prompt (offset = 6) → text position 6 + (-2) = 4.
|
||||||
|
assert_eq!(6 + delta, 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn get_rope_index_rejects_non_square_image_run() {
|
||||||
|
// 196 is square (14x14) — ok. 195 is not.
|
||||||
|
assert!(get_rope_index(&[99u32; 196], 99).is_ok());
|
||||||
|
assert!(get_rope_index(&[99u32; 195], 99).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn position_tensor_round_trips_through_mrope_cos_sin() {
|
||||||
|
// get_rope_index → (3,seq) tensor → mrope_cos_sin, and confirm an
|
||||||
|
// image token's height column tracks its grid row (not the text
|
||||||
|
// counter), i.e. the end-to-end position plumbing is wired right.
|
||||||
|
let dev = Device::Cpu;
|
||||||
|
let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
|
||||||
|
let ids = [1u32, 99, 99, 99, 99]; // text + 2x2 image
|
||||||
|
let (t, h, w, _d) = get_rope_index(&ids, 99).unwrap();
|
||||||
|
let pos = mrope_position_tensor(&t, &h, &w, &dev).unwrap();
|
||||||
|
assert_eq!(pos.dims(), &[3, 5]);
|
||||||
|
let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap();
|
||||||
|
assert_eq!(cos.dims(), &[5, rope.inv_freq.dim(1).unwrap()]);
|
||||||
|
|
||||||
|
let inv: Vec<f32> = rope.inv_freq.i(0).unwrap().to_vec1().unwrap();
|
||||||
|
// Last image token (index 4): grid (h=1, w=1) → base 1 → h=2, w=2.
|
||||||
|
// Height column (index 1) must track h-position 2, not text.
|
||||||
|
let last: Vec<f32> = cos.i(4).unwrap().to_vec1().unwrap();
|
||||||
|
assert!((last[1] - (2.0 * inv[1]).cos()).abs() < 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn get_rope_index_196_is_14x14() {
|
||||||
|
let mut ids = vec![1u32]; // one text token
|
||||||
|
ids.extend(std::iter::repeat_n(99u32, 196));
|
||||||
|
let (t, h, w, _delta) = get_rope_index(&ids, 99).unwrap();
|
||||||
|
// image base = 1. Last image token (index 196) is grid (h=13,w=13).
|
||||||
|
assert_eq!(*t.last().unwrap(), 1, "grid_t=1 → temporal const at base");
|
||||||
|
assert_eq!(h[1], 1, "first image row at base");
|
||||||
|
assert_eq!(w[1], 1, "first image col at base");
|
||||||
|
assert_eq!(h[196], 1 + 13, "last image row = base + 13");
|
||||||
|
assert_eq!(w[196], 1 + 13, "last image col = base + 13");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user