fix(neuron): vision-tower 2D positions + M-RoPE default on

Two fixes to the spatial handling of images, validated against the HF
transformers 4.57.1 qwen3_vl reference on beast.

**Vision tower (the real cause of poor spatial vision).** The Stage-A
tower encoded position incorrectly in two ways, so the model saw image
*content* but not *layout* (a row of 5 people read as "a line of 23", sky
inverted), regardless of the LM-side rope:

- Learned pos-embed was a naive sequential lookup of the first
  `n_patches` rows of the 48×48 (`num_position_embeddings=2304`) grid —
  wrong stride for a 28×28 patch grid. Now bilinearly interpolates the
  grid to `gh×gw` (port of HF `fast_pos_embed_interpolate`), row-major.
- The 2D vision rotary was absent entirely. Added
  `VisionRotaryEmbedding` (θ=10000, dim=head_dim/2) applying per-patch
  `(row, col)` rotary to q/k in every ViT block via rope_slow, matching
  HF `apply_rotary_pos_emb_vision`.
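
The stride problem and the interpolation fix in the first bullet can be sketched without candle. This is a minimal illustration only, assuming a square `ngrid × ngrid` table flattened row-major; `linspace_coords` and `bilinear_corners` are hypothetical names chosen for this sketch, not functions from the commit:

```rust
/// Evenly spaced fractional coordinates over [0, ngrid - 1], like `linspace`.
/// Assumes ngrid >= 1.
fn linspace_coords(ngrid: usize, n: usize) -> Vec<f64> {
    if n <= 1 {
        return vec![0.0];
    }
    let step = (ngrid - 1) as f64 / (n - 1) as f64;
    (0..n).map(|i| i as f64 * step).collect()
}

/// For one patch at fractional coordinate (hv, wv), return the four
/// flattened table indices and their bilinear weights (which sum to 1).
fn bilinear_corners(hv: f64, wv: f64, ngrid: usize) -> [(usize, f64); 4] {
    let (hf, wf) = (hv.floor() as usize, wv.floor() as usize);
    // Clamp the "ceil" neighbours so the last row/col stays in bounds.
    let (hc, wc) = ((hf + 1).min(ngrid - 1), (wf + 1).min(ngrid - 1));
    let (dh, dw) = (hv - hf as f64, wv - wf as f64);
    [
        (hf * ngrid + wf, (1.0 - dh) * (1.0 - dw)),
        (hf * ngrid + wc, (1.0 - dh) * dw),
        (hc * ngrid + wf, dh * (1.0 - dw)),
        (hc * ngrid + wc, dh * dw),
    ]
}
```

With `ngrid = 48` and 28 patches per side, patch 1 sits at fractional coordinate 47/27 ≈ 1.74 rather than at table index 1. The sequential lookup instead read rows 0..784 of the flattened table, which covers only about the first 16 of its 48 rows.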

Both default on; `NEURON_VISION_LEGACY_POS=1` / `NEURON_VISION_LEGACY_ROPE=1`
revert each for A/B (no rebuild). New unit tests: interpolation reduces
to the sequential lookup at the native grid; rotary row/col structure.
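
The row/col structure that the rotary unit test checks can be sketched in plain Rust. This is a rough illustration of the angle table only (before `cos`/`sin`), under the θ = 10000 and `dim = head_dim/2` settings stated above; `vision_inv_freq` and `vision_angles` are hypothetical helper names, not the commit's API:

```rust
/// Inverse frequencies for one spatial axis: head_dim / 4 values,
/// 1 / theta^(i / dim) for even i below dim = head_dim / 2.
fn vision_inv_freq(head_dim: usize, theta: f32) -> Vec<f32> {
    let dim = head_dim / 2;
    (0..dim)
        .step_by(2)
        .map(|i| 1.0 / theta.powf(i as f32 / dim as f32))
        .collect()
}

/// Rotation angles for a gh x gw grid in row-major patch order. Each
/// patch gets row * inv_freq followed by col * inv_freq.
fn vision_angles(gh: usize, gw: usize, inv_freq: &[f32]) -> Vec<Vec<f32>> {
    let mut out = Vec::with_capacity(gh * gw);
    for row in 0..gh {
        for col in 0..gw {
            let mut angles = Vec::with_capacity(2 * inv_freq.len());
            angles.extend(inv_freq.iter().map(|&f| row as f32 * f));
            angles.extend(inv_freq.iter().map(|&f| col as f32 * f));
            out.push(angles);
        }
    }
    out
}
```

For a 2×2 grid, patch index 2 is grid position (row 1, col 0), so only the first (row) half of its angles is non-zero.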

**M-RoPE default on.** The interleaved M-RoPE matches HF
`apply_interleaved_mrope` / `get_rope_index` exactly, and in A/B testing it
was never worse than plain RoPE. `NEURON_MROPE` is now a kill switch (`=0`
for plain), not opt-in: defaults should encode the model's trained
behaviour, not freeze the broken state.
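
The difference between the two flag styles can be sketched as pure functions over the raw variable value. This is only an illustration of the semantics; `kill_switch_enabled` and `opt_in_enabled` are hypothetical names, and they take an `Option<&str>` instead of reading the environment so they are easy to test:

```rust
/// Default-on flag: only an explicit falsy value disables it.
fn kill_switch_enabled(value: Option<&str>) -> bool {
    value
        .map(|v| {
            !matches!(
                v.trim().to_ascii_lowercase().as_str(),
                "0" | "false" | "no" | "off"
            )
        })
        .unwrap_or(true)
}

/// Default-off flag: only an explicit truthy value enables it.
fn opt_in_enabled(value: Option<&str>) -> bool {
    value
        .map(|v| {
            matches!(
                v.trim().to_ascii_lowercase().as_str(),
                "1" | "true" | "yes" | "on"
            )
        })
        .unwrap_or(false)
}
```

One consequence of the kill-switch form is that an unrecognised value leaves the feature on, whereas the opt-in form treats it as off.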

Vision tower is plain candle (CPU-testable): built, clippy-clean, full
workspace tests green locally.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-04 20:52:44 +03:00
parent 7ebcfba5ca
commit dc048ffcc9
2 changed files with 265 additions and 30 deletions


@@ -240,21 +240,20 @@ impl RotaryEmbedding {
 /// (position-compressed) image blocks.
 ///
 /// Whether interleaved M-RoPE for image tokens is enabled. Default
-/// **off** — the initial implementation degraded image understanding
-/// (the model misread spatial layout and rambled), so until the
-/// interleave is validated against a numerical reference it is opt-in
-/// via `NEURON_MROPE=1`. When off, image tokens get plain sequential
-/// positions (the pre-M-RoPE behaviour: content recognition works,
-/// spatial layout is approximate).
+/// **on** — Qwen3.6 was trained with interleaved M-RoPE, and this
+/// implementation matches the HF `apply_interleaved_mrope` /
+/// `get_rope_index` reference exactly (verified column-for-column). The
+/// env var is a **kill switch**: `NEURON_MROPE=0` falls back to plain
+/// sequential positions for image tokens (the pre-M-RoPE behaviour).
 pub(crate) fn mrope_enabled() -> bool {
     std::env::var("NEURON_MROPE")
         .map(|v| {
-            matches!(
+            !matches!(
                 v.trim().to_ascii_lowercase().as_str(),
-                "1" | "true" | "yes" | "on"
+                "0" | "false" | "no" | "off"
             )
         })
-        .unwrap_or(false)
+        .unwrap_or(true)
 }
 /// Position ids for the forward path. Gated by [`mrope_enabled`]: when
@@ -474,16 +473,17 @@ mod tests {
     }
     #[test]
-    fn get_rope_index_gated_off_by_default() {
-        // With NEURON_MROPE unset (default), the runtime path returns
-        // plain sequential identity positions → mrope_cos_sin reduces to
-        // plain RoPE. (compute_mrope_index, tested above, is the real
-        // computation used when the flag is on.)
-        let (t, h, w, delta) = get_rope_index(&[1, 99, 99, 99, 99, 2], 99).unwrap();
-        assert_eq!(t, vec![0, 1, 2, 3, 4, 5]);
-        assert_eq!(h, t);
-        assert_eq!(w, t);
-        assert_eq!(delta, 0);
+    fn get_rope_index_on_by_default() {
+        // With NEURON_MROPE unset (default ON), the runtime path returns
+        // the real interleaved-M-RoPE positions, so image tokens carry
+        // their 2D grid coords (height differs from the text counter).
+        // (NEURON_MROPE=0 would fall back to identity; not asserted here
+        // since it depends on env.)
+        let (t, h, w, _delta) = get_rope_index(&[1, 99, 99, 99, 99, 2], 99).unwrap();
+        // Same as compute_mrope_index: 2x2 image after one text token.
+        assert_eq!(t, vec![0, 1, 1, 1, 1, 3]);
+        assert_eq!(h, vec![0, 1, 1, 2, 2, 3]);
+        assert_eq!(w, vec![0, 1, 2, 1, 2, 3]);
     }
     #[test]


@@ -48,6 +48,31 @@ use candle_nn::var_builder::ShardedVarBuilder;
 use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear};
 use serde::Deserialize;
+fn env_truthy(name: &str) -> bool {
+    std::env::var(name)
+        .map(|v| {
+            matches!(
+                v.trim().to_ascii_lowercase().as_str(),
+                "1" | "true" | "yes" | "on"
+            )
+        })
+        .unwrap_or(false)
+}
+/// Legacy escape hatch: when set, use the original Stage-A sequential
+/// `pos_embed` lookup instead of the bilinear grid interpolation.
+/// Default off (interpolation on) — for A/B comparison only.
+fn vision_legacy_pos() -> bool {
+    env_truthy("NEURON_VISION_LEGACY_POS")
+}
+/// Legacy escape hatch: when set, skip the 2D vision rotary in the ViT
+/// attention (the original Stage-A behaviour). Default off (rotary on)
+/// — for A/B comparison only.
+fn vision_legacy_rope() -> bool {
+    env_truthy("NEURON_VISION_LEGACY_ROPE")
+}
 /// Qwen3.6 vision tower hyperparameters. Mirrors the `vision_config`
 /// block of `config.json`. Only the fields we actually need are
 /// captured; serde tolerates the rest.
@@ -118,10 +143,12 @@ impl VisionBlock {
         })
     }
-    /// `x`: `(N, hidden_size)` un-batched. Returns same shape.
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    /// `x`: `(N, hidden_size)` un-batched. `rotary`: optional
+    /// `(cos, sin)` each `(N, head_dim/2)` — the 2D vision rotary applied
+    /// to q/k. Returns same shape.
+    fn forward(&self, x: &Tensor, rotary: Option<&(Tensor, Tensor)>) -> Result<Tensor> {
         let attn_in = self.norm1.forward(x)?;
-        let attn_out = self.attention(&attn_in)?;
+        let attn_out = self.attention(&attn_in, rotary)?;
         let x = x.add(&attn_out)?;
         let mlp_in = self.norm2.forward(&x)?;
         let mlp_out = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&mlp_in)?)?)?;
@@ -129,8 +156,11 @@ impl VisionBlock {
     }
     /// Multi-head self-attention over the patch sequence. No causal
-    /// mask — every patch attends to every other patch.
-    fn attention(&self, x: &Tensor) -> Result<Tensor> {
+    /// mask — every patch attends to every other patch. When `rotary` is
+    /// given, the 2D vision rotary (row/col position) is applied to q, k
+    /// before the scores, matching HF `apply_rotary_pos_emb_vision`
+    /// (`rope_slow` is the same rotate-half form).
+    fn attention(&self, x: &Tensor, rotary: Option<&(Tensor, Tensor)>) -> Result<Tensor> {
         let (n, hidden) = x.dims2()?;
         // qkv: (N, 3*hidden). Split into Q, K, V each (N, hidden).
         let qkv = self.qkv.forward(x)?;
@@ -140,6 +170,15 @@ impl VisionBlock {
         let q = qkv.i(0)?;
         let k = qkv.i(1)?;
         let v = qkv.i(2)?;
+        // 2D vision rotary on q, k (full head_dim; rotate-half form).
+        let (q, k) = match rotary {
+            Some((cos, sin)) => {
+                let q = candle_nn::rotary_emb::rope_slow(&q.unsqueeze(0)?, cos, sin)?.squeeze(0)?;
+                let k = candle_nn::rotary_emb::rope_slow(&k.unsqueeze(0)?, cos, sin)?.squeeze(0)?;
+                (q, k)
+            }
+            None => (q, k),
+        };
         let scale = 1.0 / (self.head_dim as f64).sqrt();
         // (num_heads, N, head_dim) @ (num_heads, head_dim, N) -> (num_heads, N, N)
         let scores = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
@@ -210,11 +249,65 @@ impl VisionMerger {
     }
 }
+/// 2D rotary position embedding for the vision tower. Each patch's
+/// `head_dim` rotates by its `(row, col)` grid coordinates: the first
+/// half of the rotary freqs are driven by the row position, the second
+/// half by the column. Mirrors HF `Qwen3VLVisionRotaryEmbedding` +
+/// `rot_pos_emb` (θ = 10000, `dim = head_dim/2`).
+struct VisionRotaryEmbedding {
+    /// `(half,)` f32, `half = head_dim/4` freqs per spatial axis.
+    inv_freq: Vec<f32>,
+}
+impl VisionRotaryEmbedding {
+    fn new(head_dim: usize) -> Self {
+        // HF: Qwen3VLVisionRotaryEmbedding(head_dim // 2), theta 10000.
+        let dim = head_dim / 2;
+        let theta = 10000f32;
+        let inv_freq = (0..dim)
+            .step_by(2)
+            .map(|i| 1f32 / theta.powf(i as f32 / dim as f32))
+            .collect();
+        Self { inv_freq }
+    }
+    /// cos/sin for a `gh×gw` patch grid in **row-major** order. Returns
+    /// `(cos, sin)` each `(gh*gw, head_dim/2)`: per patch, the row-axis
+    /// freqs `row·inv_freq` followed by the col-axis freqs `col·inv_freq`
+    /// (then `rope_slow` duplicates them across the full head_dim).
+    fn cos_sin(
+        &self,
+        gh: usize,
+        gw: usize,
+        dev: &Device,
+        dtype: DType,
+    ) -> candle_core::Result<(Tensor, Tensor)> {
+        let half = self.inv_freq.len();
+        let n = gh * gw;
+        let mut data = Vec::with_capacity(n * 2 * half);
+        for hi in 0..gh {
+            for wi in 0..gw {
+                for &f in &self.inv_freq {
+                    data.push(hi as f32 * f);
+                }
+                for &f in &self.inv_freq {
+                    data.push(wi as f32 * f);
+                }
+            }
+        }
+        let freqs = Tensor::from_vec(data, (n, 2 * half), dev)?;
+        let cos = freqs.cos()?.to_dtype(dtype)?;
+        let sin = freqs.sin()?.to_dtype(dtype)?;
+        Ok((cos, sin))
+    }
+}
 /// The vision tower itself.
 pub struct VisionTower {
     /// Sum-collapsed temporal kernel (Conv2d, see module doc).
     patch_embed: Conv2d,
     pos_embed: Embedding,
+    rotary: VisionRotaryEmbedding,
     blocks: Vec<VisionBlock>,
     merger: VisionMerger,
     config: VisionConfig,
@@ -265,6 +358,7 @@ impl VisionTower {
             .get((cfg.num_position_embeddings, cfg.hidden_size), "weight")
             .context("load model.visual.pos_embed.weight")?;
         let pos_embed = Embedding::new(pos_embed_weight, cfg.hidden_size);
+        let rotary = VisionRotaryEmbedding::new(cfg.hidden_size / cfg.num_heads);
         let blocks_vb = vb.pp("blocks");
         let mut blocks = Vec::with_capacity(cfg.depth);
@@ -279,6 +373,7 @@ impl VisionTower {
         Ok(Self {
             patch_embed,
             pos_embed,
+            rotary,
             blocks,
             merger,
             config: cfg,
@@ -302,6 +397,81 @@ impl VisionTower {
         gh * gw * LM_TOKENS_PER_MERGE_GROUP
     }
+    /// Bilinearly interpolate the learned `pos_embed` grid (a
+    /// `num_grid_per_side × num_grid_per_side` table, 48×48 for Qwen3.6)
+    /// onto the actual `gh × gw` patch grid, in **row-major** patch
+    /// order. Port of the HF `fast_pos_embed_interpolate`: for each patch
+    /// at fractional grid coord `(linspace(0, ngrid-1, gh)[hi],
+    /// linspace(0, ngrid-1, gw)[wi])`, blend the 4 surrounding grid
+    /// entries by bilinear weights. Returns `(gh*gw, hidden)` in
+    /// `self.dtype`.
+    fn interpolated_pos_embed(&self, gh: usize, gw: usize) -> Result<Tensor> {
+        let ngrid = (self.config.num_position_embeddings as f64).sqrt().round() as usize;
+        anyhow::ensure!(
+            ngrid * ngrid == self.config.num_position_embeddings,
+            "num_position_embeddings {} is not a perfect square",
+            self.config.num_position_embeddings
+        );
+        // Evenly-spaced fractional indices into the [0, ngrid-1] grid.
+        let lin = |n: usize| -> Vec<f64> {
+            if n <= 1 {
+                vec![0.0]
+            } else {
+                let step = (ngrid - 1) as f64 / (n - 1) as f64;
+                (0..n).map(|i| i as f64 * step).collect()
+            }
+        };
+        let hs = lin(gh);
+        let ws = lin(gw);
+        let n = gh * gw;
+        // Four corner index sets + bilinear weight sets, row-major.
+        let mut idx: [Vec<u32>; 4] = [
+            Vec::with_capacity(n),
+            Vec::with_capacity(n),
+            Vec::with_capacity(n),
+            Vec::with_capacity(n),
+        ];
+        let mut wts: [Vec<f32>; 4] = [
+            Vec::with_capacity(n),
+            Vec::with_capacity(n),
+            Vec::with_capacity(n),
+            Vec::with_capacity(n),
+        ];
+        for &hv in &hs {
+            let hf = hv as usize; // floor (hv >= 0)
+            let hc = (hf + 1).min(ngrid - 1);
+            let dh = (hv - hf as f64) as f32;
+            for &wv in &ws {
+                let wf = wv as usize;
+                let wc = (wf + 1).min(ngrid - 1);
+                let dw = (wv - wf as f64) as f32;
+                idx[0].push((hf * ngrid + wf) as u32);
+                wts[0].push((1.0 - dh) * (1.0 - dw));
+                idx[1].push((hf * ngrid + wc) as u32);
+                wts[1].push((1.0 - dh) * dw);
+                idx[2].push((hc * ngrid + wf) as u32);
+                wts[2].push(dh * (1.0 - dw));
+                idx[3].push((hc * ngrid + wc) as u32);
+                wts[3].push(dh * dw);
+            }
+        }
+        let mut acc: Option<Tensor> = None;
+        for corner in 0..4 {
+            let idx_t = Tensor::from_vec(std::mem::take(&mut idx[corner]), (n,), &self.device)?;
+            let emb = self.pos_embed.forward(&idx_t)?; // (n, hidden), pos_embed dtype
+            let wt = Tensor::from_vec(std::mem::take(&mut wts[corner]), (n, 1), &self.device)?
+                .to_dtype(self.dtype)?;
+            let term = emb.broadcast_mul(&wt)?;
+            acc = Some(match acc {
+                Some(a) => a.add(&term)?,
+                None => term,
+            });
+        }
+        Ok(acc.expect("4 corners accumulated"))
+    }
     /// Encode one image.
     ///
     /// `image`: row-major `(3, H, W)` f32 tensor on `self.device`,
@@ -339,16 +509,34 @@ impl VisionTower {
         let x = x.permute((1, 2, 0))?.contiguous()?;
         let x = x.reshape((n_patches, self.config.hidden_size))?;
-        // Add learned positional embeddings (sequential indices for
-        // Stage A's fixed-resolution path; full 2D positional logic
-        // lands with variable resolution, issue #14).
-        let positions = Tensor::arange(0u32, n_patches as u32, &self.device)?;
-        let pos = self.pos_embed.forward(&positions)?;
+        // Learned absolute position embeddings. The `pos_embed` table is
+        // a `num_position_embeddings = num_grid_per_side²` learned grid
+        // (48×48 for Qwen3.6); for a `gh×gw` patch grid the reference
+        // (`fast_pos_embed_interpolate`) bilinearly interpolates that
+        // grid to `gh×gw`. The legacy path (a naive sequential lookup of
+        // the first `n_patches` rows) mis-maps the grid stride and
+        // scrambles spatial structure — kept only behind
+        // `NEURON_VISION_LEGACY_POS=1` for A/B comparison.
+        let pos = if vision_legacy_pos() {
+            let positions = Tensor::arange(0u32, n_patches as u32, &self.device)?;
+            self.pos_embed.forward(&positions)?
+        } else {
+            self.interpolated_pos_embed(gh, gw)?
+        };
         let mut x = x.add(&pos)?;
+        // 2D vision rotary (row/col per patch), computed once and applied
+        // in every block's attention. Legacy escape hatch skips it.
+        let rotary = if vision_legacy_rope() {
+            None
+        } else {
+            Some(self.rotary.cos_sin(gh, gw, &self.device, self.dtype)?)
+        };
+        let rotary_ref = rotary.as_ref();
         for (i, block) in self.blocks.iter().enumerate() {
             x = block
-                .forward(&x)
+                .forward(&x, rotary_ref)
                 .with_context(|| format!("vision block {i}"))?;
         }
@@ -516,9 +704,11 @@ mod tests {
             spatial_merge_size: cfg.spatial_merge_size,
         };
+        let rotary = VisionRotaryEmbedding::new(cfg.hidden_size / cfg.num_heads);
         VisionTower {
             patch_embed,
             pos_embed,
+            rotary,
             blocks,
             merger,
             config: cfg.clone(),
@@ -548,6 +738,51 @@ mod tests {
         );
     }
+    #[test]
+    fn interpolated_pos_embed_reduces_to_sequential_at_native_grid() {
+        // When the patch grid equals the pos_embed grid (gh=gw=ngrid),
+        // linspace(0,ngrid-1,ngrid) is the integer ladder, so every patch
+        // lands exactly on a grid node (dh=dw=0, corner-0 weight 1) and
+        // the bilinear result is the raw pos_embed rows in row-major
+        // order — i.e. identical to the legacy sequential lookup.
+        let cfg = tiny_config();
+        let tower = tiny_tower(&cfg);
+        let ngrid = (cfg.num_position_embeddings as f64).sqrt() as usize; // 8
+        let interp = tower.interpolated_pos_embed(ngrid, ngrid).unwrap();
+        let seq = tower
+            .pos_embed
+            .forward(&Tensor::arange(0u32, (ngrid * ngrid) as u32, &Device::Cpu).unwrap())
+            .unwrap();
+        let a: Vec<f32> = interp.flatten_all().unwrap().to_vec1().unwrap();
+        let b: Vec<f32> = seq.flatten_all().unwrap().to_vec1().unwrap();
+        assert_eq!(a.len(), b.len());
+        for (x, y) in a.iter().zip(b.iter()) {
+            assert!((x - y).abs() < 1e-5, "interp {x} vs seq {y}");
+        }
+    }
+    #[test]
+    fn vision_rotary_row_col_structure() {
+        // head_dim 8 → rotary dim 4 → inv_freq over [0,2] → 2 freqs/axis.
+        let rot = VisionRotaryEmbedding::new(8);
+        assert_eq!(rot.inv_freq.len(), 2);
+        let (cos, sin) = rot.cos_sin(2, 2, &Device::Cpu, DType::F32).unwrap();
+        assert_eq!(cos.dims(), &[4, 4]); // 4 patches, head_dim/2 = 4 cols
+        // Patch (0,0): all freqs 0 → cos 1, sin 0.
+        let s0: Vec<f32> = sin.i(0).unwrap().to_vec1().unwrap();
+        assert!(s0.iter().all(|&s| s.abs() < 1e-6));
+        // Patch index 2 = grid (1,0): row=1 drives the first half, col=0
+        // leaves the second half at zero.
+        let s2: Vec<f32> = sin.i(2).unwrap().to_vec1().unwrap();
+        assert!(s2[0].abs() > 1e-6, "row half must be non-zero");
+        assert!(
+            s2[2].abs() < 1e-6 && s2[3].abs() < 1e-6,
+            "col half must be zero"
+        );
+    }
     #[test]
     fn lm_token_count_matches_grid() {
         let cfg = tiny_config();