diff --git a/crates/neuron/src/harness/arch/qwen3_5/rope.rs b/crates/neuron/src/harness/arch/qwen3_5/rope.rs index f1d1b16..1d547b1 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/rope.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/rope.rs @@ -240,21 +240,20 @@ impl RotaryEmbedding { /// (position-compressed) image blocks. /// /// Whether interleaved M-RoPE for image tokens is enabled. Default -/// **off** — the initial implementation degraded image understanding -/// (the model misread spatial layout and rambled), so until the -/// interleave is validated against a numerical reference it is opt-in -/// via `NEURON_MROPE=1`. When off, image tokens get plain sequential -/// positions (the pre-M-RoPE behaviour: content recognition works, -/// spatial layout is approximate). +/// **on** — Qwen3.6 was trained with interleaved M-RoPE, and this +/// implementation matches the HF `apply_interleaved_mrope` / +/// `get_rope_index` reference exactly (verified column-for-column). The +/// env var is a **kill switch**: `NEURON_MROPE=0` falls back to plain +/// sequential positions for image tokens (the pre-M-RoPE behaviour). pub(crate) fn mrope_enabled() -> bool { std::env::var("NEURON_MROPE") .map(|v| { - matches!( + !matches!( v.trim().to_ascii_lowercase().as_str(), - "1" | "true" | "yes" | "on" + "0" | "false" | "no" | "off" ) }) - .unwrap_or(false) + .unwrap_or(true) } /// Position ids for the forward path. Gated by [`mrope_enabled`]: when @@ -474,16 +473,17 @@ mod tests { } #[test] - fn get_rope_index_gated_off_by_default() { - // With NEURON_MROPE unset (default), the runtime path returns - // plain sequential identity positions → mrope_cos_sin reduces to - // plain RoPE. (compute_mrope_index, tested above, is the real - // computation used when the flag is on.) - let (t, h, w, delta) = get_rope_index(&[1, 99, 99, 99, 99, 2], 99).unwrap(); - assert_eq!(t, vec![0, 1, 2, 3, 4, 5]); - assert_eq!(h, t); - assert_eq!(w, t); - assert_eq!(delta, 0); + fn get_rope_index_on_by_default() { + // With NEURON_MROPE unset (default ON), the runtime path returns + // the real interleaved-M-RoPE positions, so image tokens carry + // their 2D grid coords (height differs from the text counter). + // (NEURON_MROPE=0 would fall back to identity; not asserted here + // since it depends on env.) + let (t, h, w, _delta) = get_rope_index(&[1, 99, 99, 99, 99, 2], 99).unwrap(); + // Same as compute_mrope_index: 2x2 image after one text token. + assert_eq!(t, vec![0, 1, 1, 1, 1, 3]); + assert_eq!(h, vec![0, 1, 1, 2, 2, 3]); + assert_eq!(w, vec![0, 1, 2, 1, 2, 3]); } #[test] diff --git a/crates/neuron/src/harness/arch/qwen3_5/vision.rs b/crates/neuron/src/harness/arch/qwen3_5/vision.rs index f644e24..857b675 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/vision.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/vision.rs @@ -48,6 +48,31 @@ use candle_nn::var_builder::ShardedVarBuilder; use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear}; use serde::Deserialize; +fn env_truthy(name: &str) -> bool { + std::env::var(name) + .map(|v| { + matches!( + v.trim().to_ascii_lowercase().as_str(), + "1" | "true" | "yes" | "on" + ) + }) + .unwrap_or(false) +} + +/// Legacy escape hatch: when set, use the original Stage-A sequential +/// `pos_embed` lookup instead of the bilinear grid interpolation. +/// Default off (interpolation on) — for A/B comparison only. +fn vision_legacy_pos() -> bool { + env_truthy("NEURON_VISION_LEGACY_POS") +} + +/// Legacy escape hatch: when set, skip the 2D vision rotary in the ViT +/// attention (the original Stage-A behaviour). Default off (rotary on) +/// — for A/B comparison only. +fn vision_legacy_rope() -> bool { + env_truthy("NEURON_VISION_LEGACY_ROPE") +} + /// Qwen3.6 vision tower hyperparameters. Mirrors the `vision_config` /// block of `config.json`. Only the fields we actually need are /// captured; serde tolerates the rest. @@ -118,10 +143,12 @@ impl VisionBlock { }) } - /// `x`: `(N, hidden_size)` un-batched. Returns same shape. - fn forward(&self, x: &Tensor) -> Result { + /// `x`: `(N, hidden_size)` un-batched. `rotary`: optional + /// `(cos, sin)` each `(N, head_dim/2)` — the 2D vision rotary applied + /// to q/k. Returns same shape. + fn forward(&self, x: &Tensor, rotary: Option<&(Tensor, Tensor)>) -> Result { let attn_in = self.norm1.forward(x)?; - let attn_out = self.attention(&attn_in)?; + let attn_out = self.attention(&attn_in, rotary)?; let x = x.add(&attn_out)?; let mlp_in = self.norm2.forward(&x)?; let mlp_out = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&mlp_in)?)?)?; @@ -129,8 +156,11 @@ impl VisionBlock { } /// Multi-head self-attention over the patch sequence. No causal - /// mask — every patch attends to every other patch. - fn attention(&self, x: &Tensor) -> Result { + /// mask — every patch attends to every other patch. When `rotary` is + /// given, the 2D vision rotary (row/col position) is applied to q, k + /// before the scores, matching HF `apply_rotary_pos_emb_vision` + /// (`rope_slow` is the same rotate-half form). + fn attention(&self, x: &Tensor, rotary: Option<&(Tensor, Tensor)>) -> Result { let (n, hidden) = x.dims2()?; // qkv: (N, 3*hidden). Split into Q, K, V each (N, hidden). let qkv = self.qkv.forward(x)?; @@ -140,6 +170,15 @@ impl VisionBlock { let q = qkv.i(0)?; let k = qkv.i(1)?; let v = qkv.i(2)?; + // 2D vision rotary on q, k (full head_dim; rotate-half form). + let (q, k) = match rotary { + Some((cos, sin)) => { + let q = candle_nn::rotary_emb::rope_slow(&q.unsqueeze(0)?, cos, sin)?.squeeze(0)?; + let k = candle_nn::rotary_emb::rope_slow(&k.unsqueeze(0)?, cos, sin)?.squeeze(0)?; + (q, k) + } + None => (q, k), + }; let scale = 1.0 / (self.head_dim as f64).sqrt(); // (num_heads, N, head_dim) @ (num_heads, head_dim, N) -> (num_heads, N, N) let scores = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?; @@ -210,11 +249,65 @@ impl VisionMerger { } } +/// 2D rotary position embedding for the vision tower. Each patch's +/// `head_dim` rotates by its `(row, col)` grid coordinates: the first +/// half of the rotary freqs are driven by the row position, the second +/// half by the column. Mirrors HF `Qwen3VLVisionRotaryEmbedding` + +/// `rot_pos_emb` (θ = 10000, `dim = head_dim/2`). +struct VisionRotaryEmbedding { + /// `(half,)` f32, `half = head_dim/4` freqs per spatial axis. + inv_freq: Vec, +} + +impl VisionRotaryEmbedding { + fn new(head_dim: usize) -> Self { + // HF: Qwen3VLVisionRotaryEmbedding(head_dim // 2), theta 10000. + let dim = head_dim / 2; + let theta = 10000f32; + let inv_freq = (0..dim) + .step_by(2) + .map(|i| 1f32 / theta.powf(i as f32 / dim as f32)) + .collect(); + Self { inv_freq } + } + + /// cos/sin for a `gh×gw` patch grid in **row-major** order. Returns + /// `(cos, sin)` each `(gh*gw, head_dim/2)`: per patch, the row-axis + /// freqs `row·inv_freq` followed by the col-axis freqs `col·inv_freq` + /// (then `rope_slow` duplicates them across the full head_dim). + fn cos_sin( + &self, + gh: usize, + gw: usize, + dev: &Device, + dtype: DType, + ) -> candle_core::Result<(Tensor, Tensor)> { + let half = self.inv_freq.len(); + let n = gh * gw; + let mut data = Vec::with_capacity(n * 2 * half); + for hi in 0..gh { + for wi in 0..gw { + for &f in &self.inv_freq { + data.push(hi as f32 * f); + } + for &f in &self.inv_freq { + data.push(wi as f32 * f); + } + } + } + let freqs = Tensor::from_vec(data, (n, 2 * half), dev)?; + let cos = freqs.cos()?.to_dtype(dtype)?; + let sin = freqs.sin()?.to_dtype(dtype)?; + Ok((cos, sin)) + } +} + /// The vision tower itself. pub struct VisionTower { /// Sum-collapsed temporal kernel (Conv2d, see module doc). patch_embed: Conv2d, pos_embed: Embedding, + rotary: VisionRotaryEmbedding, blocks: Vec, merger: VisionMerger, config: VisionConfig, @@ -265,6 +358,7 @@ impl VisionTower { .get((cfg.num_position_embeddings, cfg.hidden_size), "weight") .context("load model.visual.pos_embed.weight")?; let pos_embed = Embedding::new(pos_embed_weight, cfg.hidden_size); + let rotary = VisionRotaryEmbedding::new(cfg.hidden_size / cfg.num_heads); let blocks_vb = vb.pp("blocks"); let mut blocks = Vec::with_capacity(cfg.depth); @@ -279,6 +373,7 @@ impl VisionTower { Ok(Self { patch_embed, pos_embed, + rotary, blocks, merger, config: cfg, @@ -302,6 +397,81 @@ impl VisionTower { gh * gw * LM_TOKENS_PER_MERGE_GROUP } + /// Bilinearly interpolate the learned `pos_embed` grid (a + /// `num_grid_per_side × num_grid_per_side` table, 48×48 for Qwen3.6) + /// onto the actual `gh × gw` patch grid, in **row-major** patch + /// order. Port of the HF `fast_pos_embed_interpolate`: for each patch + /// at fractional grid coord `(linspace(0, ngrid-1, gh)[hi], + /// linspace(0, ngrid-1, gw)[wi])`, blend the 4 surrounding grid + /// entries by bilinear weights. Returns `(gh*gw, hidden)` in + /// `self.dtype`. + fn interpolated_pos_embed(&self, gh: usize, gw: usize) -> Result { + let ngrid = (self.config.num_position_embeddings as f64).sqrt().round() as usize; + anyhow::ensure!( + ngrid * ngrid == self.config.num_position_embeddings, + "num_position_embeddings {} is not a perfect square", + self.config.num_position_embeddings + ); + // Evenly-spaced fractional indices into the [0, ngrid-1] grid. + let lin = |n: usize| -> Vec { + if n <= 1 { + vec![0.0] + } else { + let step = (ngrid - 1) as f64 / (n - 1) as f64; + (0..n).map(|i| i as f64 * step).collect() + } + }; + let hs = lin(gh); + let ws = lin(gw); + let n = gh * gw; + + // Four corner index sets + bilinear weight sets, row-major. + let mut idx: [Vec; 4] = [ + Vec::with_capacity(n), + Vec::with_capacity(n), + Vec::with_capacity(n), + Vec::with_capacity(n), + ]; + let mut wts: [Vec; 4] = [ + Vec::with_capacity(n), + Vec::with_capacity(n), + Vec::with_capacity(n), + Vec::with_capacity(n), + ]; + for &hv in &hs { + let hf = hv as usize; // floor (hv >= 0) + let hc = (hf + 1).min(ngrid - 1); + let dh = (hv - hf as f64) as f32; + for &wv in &ws { + let wf = wv as usize; + let wc = (wf + 1).min(ngrid - 1); + let dw = (wv - wf as f64) as f32; + idx[0].push((hf * ngrid + wf) as u32); + wts[0].push((1.0 - dh) * (1.0 - dw)); + idx[1].push((hf * ngrid + wc) as u32); + wts[1].push((1.0 - dh) * dw); + idx[2].push((hc * ngrid + wf) as u32); + wts[2].push(dh * (1.0 - dw)); + idx[3].push((hc * ngrid + wc) as u32); + wts[3].push(dh * dw); + } + } + + let mut acc: Option = None; + for corner in 0..4 { + let idx_t = Tensor::from_vec(std::mem::take(&mut idx[corner]), (n,), &self.device)?; + let emb = self.pos_embed.forward(&idx_t)?; // (n, hidden), pos_embed dtype + let wt = Tensor::from_vec(std::mem::take(&mut wts[corner]), (n, 1), &self.device)? + .to_dtype(self.dtype)?; + let term = emb.broadcast_mul(&wt)?; + acc = Some(match acc { + Some(a) => a.add(&term)?, + None => term, + }); + } + Ok(acc.expect("4 corners accumulated")) + } + /// Encode one image. /// /// `image`: row-major `(3, H, W)` f32 tensor on `self.device`, @@ -339,16 +509,34 @@ impl VisionTower { let x = x.permute((1, 2, 0))?.contiguous()?; let x = x.reshape((n_patches, self.config.hidden_size))?; - // Add learned positional embeddings (sequential indices for - // Stage A's fixed-resolution path; full 2D positional logic - // lands with variable resolution, issue #14). - let positions = Tensor::arange(0u32, n_patches as u32, &self.device)?; - let pos = self.pos_embed.forward(&positions)?; + // Learned absolute position embeddings. The `pos_embed` table is + // a `num_position_embeddings = num_grid_per_side²` learned grid + // (48×48 for Qwen3.6); for a `gh×gw` patch grid the reference + // (`fast_pos_embed_interpolate`) bilinearly interpolates that + // grid to `gh×gw`. The legacy path (a naive sequential lookup of + // the first `n_patches` rows) mis-maps the grid stride and + // scrambles spatial structure — kept only behind + // `NEURON_VISION_LEGACY_POS=1` for A/B comparison. + let pos = if vision_legacy_pos() { + let positions = Tensor::arange(0u32, n_patches as u32, &self.device)?; + self.pos_embed.forward(&positions)? + } else { + self.interpolated_pos_embed(gh, gw)? + }; let mut x = x.add(&pos)?; + // 2D vision rotary (row/col per patch), computed once and applied + // in every block's attention. Legacy escape hatch skips it. + let rotary = if vision_legacy_rope() { + None + } else { + Some(self.rotary.cos_sin(gh, gw, &self.device, self.dtype)?) + }; + let rotary_ref = rotary.as_ref(); + for (i, block) in self.blocks.iter().enumerate() { x = block - .forward(&x) + .forward(&x, rotary_ref) .with_context(|| format!("vision block {i}"))?; } @@ -516,9 +704,11 @@ mod tests { spatial_merge_size: cfg.spatial_merge_size, }; + let rotary = VisionRotaryEmbedding::new(cfg.hidden_size / cfg.num_heads); VisionTower { patch_embed, pos_embed, + rotary, blocks, merger, config: cfg.clone(), @@ -548,6 +738,51 @@ mod tests { ); } + #[test] + fn interpolated_pos_embed_reduces_to_sequential_at_native_grid() { + // When the patch grid equals the pos_embed grid (gh=gw=ngrid), + // linspace(0,ngrid-1,ngrid) is the integer ladder, so every patch + // lands exactly on a grid node (dh=dw=0, corner-0 weight 1) and + // the bilinear result is the raw pos_embed rows in row-major + // order — i.e. identical to the legacy sequential lookup. + let cfg = tiny_config(); + let tower = tiny_tower(&cfg); + let ngrid = (cfg.num_position_embeddings as f64).sqrt() as usize; // 8 + let interp = tower.interpolated_pos_embed(ngrid, ngrid).unwrap(); + let seq = tower + .pos_embed + .forward(&Tensor::arange(0u32, (ngrid * ngrid) as u32, &Device::Cpu).unwrap()) + .unwrap(); + let a: Vec = interp.flatten_all().unwrap().to_vec1().unwrap(); + let b: Vec = seq.flatten_all().unwrap().to_vec1().unwrap(); + assert_eq!(a.len(), b.len()); + for (x, y) in a.iter().zip(b.iter()) { + assert!((x - y).abs() < 1e-5, "interp {x} vs seq {y}"); + } + } + + #[test] + fn vision_rotary_row_col_structure() { + // head_dim 8 → rotary dim 4 → inv_freq over [0,2] → 2 freqs/axis. + let rot = VisionRotaryEmbedding::new(8); + assert_eq!(rot.inv_freq.len(), 2); + let (cos, sin) = rot.cos_sin(2, 2, &Device::Cpu, DType::F32).unwrap(); + assert_eq!(cos.dims(), &[4, 4]); // 4 patches, head_dim/2 = 4 cols + + // Patch (0,0): all freqs 0 → cos 1, sin 0. + let s0: Vec = sin.i(0).unwrap().to_vec1().unwrap(); + assert!(s0.iter().all(|&s| s.abs() < 1e-6)); + + // Patch index 2 = grid (1,0): row=1 drives the first half, col=0 + // leaves the second half at zero. + let s2: Vec = sin.i(2).unwrap().to_vec1().unwrap(); + assert!(s2[0].abs() > 1e-6, "row half must be non-zero"); + assert!( + s2[2].abs() < 1e-6 && s2[3].abs() < 1e-6, + "col half must be zero" + ); + } + #[test] fn lm_token_count_matches_grid() { let cfg = tiny_config();