diff --git a/crates/neuron/src/harness/arch/qwen3_5/rope.rs b/crates/neuron/src/harness/arch/qwen3_5/rope.rs
index 01cfa5a..f1d1b16 100644
--- a/crates/neuron/src/harness/arch/qwen3_5/rope.rs
+++ b/crates/neuron/src/harness/arch/qwen3_5/rope.rs
@@ -239,11 +239,44 @@ impl RotaryEmbedding {
 /// to a plain decode offset so text resumes from the counter after the
 /// (position-compressed) image blocks.
 ///
+/// Whether interleaved M-RoPE for image tokens is enabled. Default
+/// **off** — the initial implementation degraded image understanding
+/// (the model misread spatial layout and rambled), so until the
+/// interleave is validated against a numerical reference it is opt-in
+/// via `NEURON_MROPE=1` (or `true`/`yes`/`on`; case and surrounding whitespace
+/// are ignored). When off, image tokens get plain sequential positions (the
+/// pre-M-RoPE behaviour: content recognition works, spatial layout is approximate).
+pub(crate) fn mrope_enabled() -> bool {
+    std::env::var("NEURON_MROPE")
+        .map(|v| {
+            matches!(
+                v.trim().to_ascii_lowercase().as_str(),
+                "1" | "true" | "yes" | "on"
+            )
+        })
+        .unwrap_or(false)
+}
+
+/// Position ids for the forward path. Gated by [`mrope_enabled`]: when
+/// off, returns plain sequential identity positions on all three axes
+/// (`mrope_cos_sin` then reduces exactly to plain RoPE), restoring the
+/// pre-M-RoPE behaviour without touching the rest of the forward.
+pub(crate) fn get_rope_index(input_ids: &[u32], image_token_id: u32) -> Result {
+    if !mrope_enabled() {
+        let seq: Vec<i64> = (0..input_ids.len() as i64).collect();
+        return Ok((seq.clone(), seq.clone(), seq, 0));
+    }
+    compute_mrope_index(input_ids, image_token_id)
+}
+
+/// The real interleaved-M-RoPE position-id computation (always active in
+/// unit tests; gated behind [`get_rope_index`] at runtime).
+///
 /// Fixed-resolution assumption (Stage C): each image run is a perfect
 /// square with `grid_t = 1` (still image) and `grid_h = grid_w =
 /// isqrt(run_len)` — 196 → 14×14. Dynamic resolution (#14) would thread
 /// real per-image grids instead.
-pub(crate) fn get_rope_index(input_ids: &[u32], image_token_id: u32) -> Result {
+pub(crate) fn compute_mrope_index(input_ids: &[u32], image_token_id: u32) -> Result {
     let n = input_ids.len();
     let mut text = Vec::with_capacity(n);
     let mut height = Vec::with_capacity(n);
@@ -415,7 +448,7 @@ mod tests {

     #[test]
     fn get_rope_index_text_only_is_sequential() {
-        let (t, h, w, delta) = get_rope_index(&[1, 2, 3, 4], 99).unwrap();
+        let (t, h, w, delta) = compute_mrope_index(&[1, 2, 3, 4], 99).unwrap();
         assert_eq!(t, vec![0, 1, 2, 3]);
         assert_eq!(h, vec![0, 1, 2, 3]);
         assert_eq!(w, vec![0, 1, 2, 3]);
@@ -426,7 +459,7 @@ mod tests {
     fn get_rope_index_text_image_text() {
         // [text, image(2x2 run of 4), text]. image_token = 99.
         let ids = [1u32, 99, 99, 99, 99, 2];
-        let (t, h, w, delta) = get_rope_index(&ids, 99).unwrap();
+        let (t, h, w, delta) = compute_mrope_index(&ids, 99).unwrap();
         // token 0: text → 0. image base=1, grid 1x2x2:
         //   t all = 1; h = base+row = [1,1,2,2]; w = base+col = [1,2,1,2].
         //   resume from base + max(1,2,2) = 3. trailing text → 3.
@@ -440,11 +473,31 @@ mod tests {
         assert_eq!(6 + delta, 4);
     }

+    #[test]
+    fn get_rope_index_gated_off_by_default() {
+        // With NEURON_MROPE unset (default), the runtime path returns
+        // plain sequential identity positions → mrope_cos_sin reduces to
+        // plain RoPE. The gate reads the real process environment, so
+        // when the flag is exported for this test run we instead check
+        // the pass-through to compute_mrope_index (tested above) rather
+        // than failing spuriously.
+        let ids = [1u32, 99, 99, 99, 99, 2];
+        let (t, h, w, delta) = get_rope_index(&ids, 99).unwrap();
+        if mrope_enabled() {
+            assert_eq!((t, h, w, delta), compute_mrope_index(&ids, 99).unwrap());
+            return;
+        }
+        assert_eq!(t, vec![0, 1, 2, 3, 4, 5]);
+        assert_eq!(h, t);
+        assert_eq!(w, t);
+        assert_eq!(delta, 0);
+    }
+
     #[test]
     fn get_rope_index_rejects_non_square_image_run() {
         // 196 is square (14x14) — ok. 195 is not.
-        assert!(get_rope_index(&[99u32; 196], 99).is_ok());
-        assert!(get_rope_index(&[99u32; 195], 99).is_err());
+        assert!(compute_mrope_index(&[99u32; 196], 99).is_ok());
+        assert!(compute_mrope_index(&[99u32; 195], 99).is_err());
     }

     #[test]
@@ -455,7 +508,7 @@ mod tests {
         let dev = Device::Cpu;
         let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
         let ids = [1u32, 99, 99, 99, 99]; // text + 2x2 image
-        let (t, h, w, _d) = get_rope_index(&ids, 99).unwrap();
+        let (t, h, w, _d) = compute_mrope_index(&ids, 99).unwrap();
         let pos = mrope_position_tensor(&t, &h, &w, &dev).unwrap();
         assert_eq!(pos.dims(), &[3, 5]);
         let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap();
@@ -472,7 +525,7 @@ mod tests {
     fn get_rope_index_196_is_14x14() {
         let mut ids = vec![1u32]; // one text token
         ids.extend(std::iter::repeat_n(99u32, 196));
-        let (t, h, w, _delta) = get_rope_index(&ids, 99).unwrap();
+        let (t, h, w, _delta) = compute_mrope_index(&ids, 99).unwrap();
         // image base = 1. Last image token (index 196) is grid (h=13,w=13).
         assert_eq!(*t.last().unwrap(), 1, "grid_t=1 → temporal const at base");
         assert_eq!(h[1], 1, "first image row at base");