fix(neuron): vision-tower 2D positions + M-RoPE default on
All checks were successful
CI / CUDA type-check (push) Successful in 32s
build-prerelease / Resolve version stamps (push) Successful in 32s
CI / Format (push) Successful in 33s
CI / Clippy (push) Successful in 2m36s
build-prerelease / Build cortex binary (push) Successful in 4m48s
build-prerelease / Build neuron-blackwell (push) Successful in 5m59s
CI / Test (push) Successful in 6m35s
build-prerelease / Build neuron-ampere (push) Successful in 7m51s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Build neuron-ada (push) Successful in 5m13s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m0s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m5s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m49s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m6s
All checks were successful
CI / CUDA type-check (push) Successful in 32s
build-prerelease / Resolve version stamps (push) Successful in 32s
CI / Format (push) Successful in 33s
CI / Clippy (push) Successful in 2m36s
build-prerelease / Build cortex binary (push) Successful in 4m48s
build-prerelease / Build neuron-blackwell (push) Successful in 5m59s
CI / Test (push) Successful in 6m35s
build-prerelease / Build neuron-ampere (push) Successful in 7m51s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Build neuron-ada (push) Successful in 5m13s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m0s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m5s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m49s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m6s
Two fixes to the spatial handling of images, validated against the HF transformers 4.57.1 qwen3_vl reference on beast. **Vision tower (the real cause of poor spatial vision).** The Stage-A tower encoded position two ways wrong, so the model saw image *content* but not *layout* (a row of 5 people read as "a line of 23", sky inverted), regardless of the LM-side rope: - Learned pos-embed was a naive sequential lookup of the first `n_patches` rows of the 48×48 (`num_position_embeddings=2304`) grid — wrong stride for a 28×28 patch grid. Now bilinearly interpolates the grid to `gh×gw` (port of HF `fast_pos_embed_interpolate`), row-major. - The 2D vision rotary was absent entirely. Added `VisionRotaryEmbedding` (θ=10000, dim=head_dim/2) applying per-patch `(row, col)` rotary to q/k in every ViT block via rope_slow, matching HF `apply_rotary_pos_emb_vision`. Both default on; `NEURON_VISION_LEGACY_POS=1` / `NEURON_VISION_LEGACY_ROPE=1` revert each for A/B (no rebuild). New unit tests: interpolation reduces to the sequential lookup at the native grid; rotary row/col structure. **M-RoPE default on.** The interleaved M-RoPE matches HF apply_interleaved_mrope / get_rope_index exactly and A/B'd strictly ≥ plain. `NEURON_MROPE` is now a kill switch (`=0` for plain), not opt-in — defaults should encode the model's trained behaviour, not freeze the broken state. Vision tower is plain candle (CPU-testable): built, clippy-clean, full workspace tests green locally. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -240,21 +240,20 @@ impl RotaryEmbedding {
|
||||
/// (position-compressed) image blocks.
|
||||
///
|
||||
/// Whether interleaved M-RoPE for image tokens is enabled. Default
|
||||
/// **off** — the initial implementation degraded image understanding
|
||||
/// (the model misread spatial layout and rambled), so until the
|
||||
/// interleave is validated against a numerical reference it is opt-in
|
||||
/// via `NEURON_MROPE=1`. When off, image tokens get plain sequential
|
||||
/// positions (the pre-M-RoPE behaviour: content recognition works,
|
||||
/// spatial layout is approximate).
|
||||
/// **on** — Qwen3.6 was trained with interleaved M-RoPE, and this
|
||||
/// implementation matches the HF `apply_interleaved_mrope` /
|
||||
/// `get_rope_index` reference exactly (verified column-for-column). The
|
||||
/// env var is a **kill switch**: `NEURON_MROPE=0` falls back to plain
|
||||
/// sequential positions for image tokens (the pre-M-RoPE behaviour).
|
||||
pub(crate) fn mrope_enabled() -> bool {
|
||||
std::env::var("NEURON_MROPE")
|
||||
.map(|v| {
|
||||
matches!(
|
||||
!matches!(
|
||||
v.trim().to_ascii_lowercase().as_str(),
|
||||
"1" | "true" | "yes" | "on"
|
||||
"0" | "false" | "no" | "off"
|
||||
)
|
||||
})
|
||||
.unwrap_or(false)
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Position ids for the forward path. Gated by [`mrope_enabled`]: when
|
||||
@@ -474,16 +473,17 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_rope_index_gated_off_by_default() {
|
||||
// With NEURON_MROPE unset (default), the runtime path returns
|
||||
// plain sequential identity positions → mrope_cos_sin reduces to
|
||||
// plain RoPE. (compute_mrope_index, tested above, is the real
|
||||
// computation used when the flag is on.)
|
||||
let (t, h, w, delta) = get_rope_index(&[1, 99, 99, 99, 99, 2], 99).unwrap();
|
||||
assert_eq!(t, vec![0, 1, 2, 3, 4, 5]);
|
||||
assert_eq!(h, t);
|
||||
assert_eq!(w, t);
|
||||
assert_eq!(delta, 0);
|
||||
fn get_rope_index_on_by_default() {
|
||||
// With NEURON_MROPE unset (default ON), the runtime path returns
|
||||
// the real interleaved-M-RoPE positions, so image tokens carry
|
||||
// their 2D grid coords (height differs from the text counter).
|
||||
// (NEURON_MROPE=0 would fall back to identity; not asserted here
|
||||
// since it depends on env.)
|
||||
let (t, h, w, _delta) = get_rope_index(&[1, 99, 99, 99, 99, 2], 99).unwrap();
|
||||
// Same as compute_mrope_index: 2x2 image after one text token.
|
||||
assert_eq!(t, vec![0, 1, 1, 1, 1, 3]);
|
||||
assert_eq!(h, vec![0, 1, 1, 2, 2, 3]);
|
||||
assert_eq!(w, vec![0, 1, 2, 1, 2, 3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -48,6 +48,31 @@ use candle_nn::var_builder::ShardedVarBuilder;
|
||||
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear};
|
||||
use serde::Deserialize;
|
||||
|
||||
fn env_truthy(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.map(|v| {
|
||||
matches!(
|
||||
v.trim().to_ascii_lowercase().as_str(),
|
||||
"1" | "true" | "yes" | "on"
|
||||
)
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Legacy escape hatch: when set, use the original Stage-A sequential
|
||||
/// `pos_embed` lookup instead of the bilinear grid interpolation.
|
||||
/// Default off (interpolation on) — for A/B comparison only.
|
||||
fn vision_legacy_pos() -> bool {
|
||||
env_truthy("NEURON_VISION_LEGACY_POS")
|
||||
}
|
||||
|
||||
/// Legacy escape hatch: when set, skip the 2D vision rotary in the ViT
|
||||
/// attention (the original Stage-A behaviour). Default off (rotary on)
|
||||
/// — for A/B comparison only.
|
||||
fn vision_legacy_rope() -> bool {
|
||||
env_truthy("NEURON_VISION_LEGACY_ROPE")
|
||||
}
|
||||
|
||||
/// Qwen3.6 vision tower hyperparameters. Mirrors the `vision_config`
|
||||
/// block of `config.json`. Only the fields we actually need are
|
||||
/// captured; serde tolerates the rest.
|
||||
@@ -118,10 +143,12 @@ impl VisionBlock {
|
||||
})
|
||||
}
|
||||
|
||||
/// `x`: `(N, hidden_size)` un-batched. Returns same shape.
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
/// `x`: `(N, hidden_size)` un-batched. `rotary`: optional
|
||||
/// `(cos, sin)` each `(N, head_dim/2)` — the 2D vision rotary applied
|
||||
/// to q/k. Returns same shape.
|
||||
fn forward(&self, x: &Tensor, rotary: Option<&(Tensor, Tensor)>) -> Result<Tensor> {
|
||||
let attn_in = self.norm1.forward(x)?;
|
||||
let attn_out = self.attention(&attn_in)?;
|
||||
let attn_out = self.attention(&attn_in, rotary)?;
|
||||
let x = x.add(&attn_out)?;
|
||||
let mlp_in = self.norm2.forward(&x)?;
|
||||
let mlp_out = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&mlp_in)?)?)?;
|
||||
@@ -129,8 +156,11 @@ impl VisionBlock {
|
||||
}
|
||||
|
||||
/// Multi-head self-attention over the patch sequence. No causal
|
||||
/// mask — every patch attends to every other patch.
|
||||
fn attention(&self, x: &Tensor) -> Result<Tensor> {
|
||||
/// mask — every patch attends to every other patch. When `rotary` is
|
||||
/// given, the 2D vision rotary (row/col position) is applied to q, k
|
||||
/// before the scores, matching HF `apply_rotary_pos_emb_vision`
|
||||
/// (`rope_slow` is the same rotate-half form).
|
||||
fn attention(&self, x: &Tensor, rotary: Option<&(Tensor, Tensor)>) -> Result<Tensor> {
|
||||
let (n, hidden) = x.dims2()?;
|
||||
// qkv: (N, 3*hidden). Split into Q, K, V each (N, hidden).
|
||||
let qkv = self.qkv.forward(x)?;
|
||||
@@ -140,6 +170,15 @@ impl VisionBlock {
|
||||
let q = qkv.i(0)?;
|
||||
let k = qkv.i(1)?;
|
||||
let v = qkv.i(2)?;
|
||||
// 2D vision rotary on q, k (full head_dim; rotate-half form).
|
||||
let (q, k) = match rotary {
|
||||
Some((cos, sin)) => {
|
||||
let q = candle_nn::rotary_emb::rope_slow(&q.unsqueeze(0)?, cos, sin)?.squeeze(0)?;
|
||||
let k = candle_nn::rotary_emb::rope_slow(&k.unsqueeze(0)?, cos, sin)?.squeeze(0)?;
|
||||
(q, k)
|
||||
}
|
||||
None => (q, k),
|
||||
};
|
||||
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||
// (num_heads, N, head_dim) @ (num_heads, head_dim, N) -> (num_heads, N, N)
|
||||
let scores = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
|
||||
@@ -210,11 +249,65 @@ impl VisionMerger {
|
||||
}
|
||||
}
|
||||
|
||||
/// 2D rotary position embedding for the vision tower. Each patch's
|
||||
/// `head_dim` rotates by its `(row, col)` grid coordinates: the first
|
||||
/// half of the rotary freqs are driven by the row position, the second
|
||||
/// half by the column. Mirrors HF `Qwen3VLVisionRotaryEmbedding` +
|
||||
/// `rot_pos_emb` (θ = 10000, `dim = head_dim/2`).
|
||||
struct VisionRotaryEmbedding {
|
||||
/// `(half,)` f32, `half = head_dim/4` freqs per spatial axis.
|
||||
inv_freq: Vec<f32>,
|
||||
}
|
||||
|
||||
impl VisionRotaryEmbedding {
|
||||
fn new(head_dim: usize) -> Self {
|
||||
// HF: Qwen3VLVisionRotaryEmbedding(head_dim // 2), theta 10000.
|
||||
let dim = head_dim / 2;
|
||||
let theta = 10000f32;
|
||||
let inv_freq = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / theta.powf(i as f32 / dim as f32))
|
||||
.collect();
|
||||
Self { inv_freq }
|
||||
}
|
||||
|
||||
/// cos/sin for a `gh×gw` patch grid in **row-major** order. Returns
|
||||
/// `(cos, sin)` each `(gh*gw, head_dim/2)`: per patch, the row-axis
|
||||
/// freqs `row·inv_freq` followed by the col-axis freqs `col·inv_freq`
|
||||
/// (then `rope_slow` duplicates them across the full head_dim).
|
||||
fn cos_sin(
|
||||
&self,
|
||||
gh: usize,
|
||||
gw: usize,
|
||||
dev: &Device,
|
||||
dtype: DType,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let half = self.inv_freq.len();
|
||||
let n = gh * gw;
|
||||
let mut data = Vec::with_capacity(n * 2 * half);
|
||||
for hi in 0..gh {
|
||||
for wi in 0..gw {
|
||||
for &f in &self.inv_freq {
|
||||
data.push(hi as f32 * f);
|
||||
}
|
||||
for &f in &self.inv_freq {
|
||||
data.push(wi as f32 * f);
|
||||
}
|
||||
}
|
||||
}
|
||||
let freqs = Tensor::from_vec(data, (n, 2 * half), dev)?;
|
||||
let cos = freqs.cos()?.to_dtype(dtype)?;
|
||||
let sin = freqs.sin()?.to_dtype(dtype)?;
|
||||
Ok((cos, sin))
|
||||
}
|
||||
}
|
||||
|
||||
/// The vision tower itself.
|
||||
pub struct VisionTower {
|
||||
/// Sum-collapsed temporal kernel (Conv2d, see module doc).
|
||||
patch_embed: Conv2d,
|
||||
pos_embed: Embedding,
|
||||
rotary: VisionRotaryEmbedding,
|
||||
blocks: Vec<VisionBlock>,
|
||||
merger: VisionMerger,
|
||||
config: VisionConfig,
|
||||
@@ -265,6 +358,7 @@ impl VisionTower {
|
||||
.get((cfg.num_position_embeddings, cfg.hidden_size), "weight")
|
||||
.context("load model.visual.pos_embed.weight")?;
|
||||
let pos_embed = Embedding::new(pos_embed_weight, cfg.hidden_size);
|
||||
let rotary = VisionRotaryEmbedding::new(cfg.hidden_size / cfg.num_heads);
|
||||
|
||||
let blocks_vb = vb.pp("blocks");
|
||||
let mut blocks = Vec::with_capacity(cfg.depth);
|
||||
@@ -279,6 +373,7 @@ impl VisionTower {
|
||||
Ok(Self {
|
||||
patch_embed,
|
||||
pos_embed,
|
||||
rotary,
|
||||
blocks,
|
||||
merger,
|
||||
config: cfg,
|
||||
@@ -302,6 +397,81 @@ impl VisionTower {
|
||||
gh * gw * LM_TOKENS_PER_MERGE_GROUP
|
||||
}
|
||||
|
||||
/// Bilinearly interpolate the learned `pos_embed` grid (a
|
||||
/// `num_grid_per_side × num_grid_per_side` table, 48×48 for Qwen3.6)
|
||||
/// onto the actual `gh × gw` patch grid, in **row-major** patch
|
||||
/// order. Port of the HF `fast_pos_embed_interpolate`: for each patch
|
||||
/// at fractional grid coord `(linspace(0, ngrid-1, gh)[hi],
|
||||
/// linspace(0, ngrid-1, gw)[wi])`, blend the 4 surrounding grid
|
||||
/// entries by bilinear weights. Returns `(gh*gw, hidden)` in
|
||||
/// `self.dtype`.
|
||||
fn interpolated_pos_embed(&self, gh: usize, gw: usize) -> Result<Tensor> {
|
||||
let ngrid = (self.config.num_position_embeddings as f64).sqrt().round() as usize;
|
||||
anyhow::ensure!(
|
||||
ngrid * ngrid == self.config.num_position_embeddings,
|
||||
"num_position_embeddings {} is not a perfect square",
|
||||
self.config.num_position_embeddings
|
||||
);
|
||||
// Evenly-spaced fractional indices into the [0, ngrid-1] grid.
|
||||
let lin = |n: usize| -> Vec<f64> {
|
||||
if n <= 1 {
|
||||
vec![0.0]
|
||||
} else {
|
||||
let step = (ngrid - 1) as f64 / (n - 1) as f64;
|
||||
(0..n).map(|i| i as f64 * step).collect()
|
||||
}
|
||||
};
|
||||
let hs = lin(gh);
|
||||
let ws = lin(gw);
|
||||
let n = gh * gw;
|
||||
|
||||
// Four corner index sets + bilinear weight sets, row-major.
|
||||
let mut idx: [Vec<u32>; 4] = [
|
||||
Vec::with_capacity(n),
|
||||
Vec::with_capacity(n),
|
||||
Vec::with_capacity(n),
|
||||
Vec::with_capacity(n),
|
||||
];
|
||||
let mut wts: [Vec<f32>; 4] = [
|
||||
Vec::with_capacity(n),
|
||||
Vec::with_capacity(n),
|
||||
Vec::with_capacity(n),
|
||||
Vec::with_capacity(n),
|
||||
];
|
||||
for &hv in &hs {
|
||||
let hf = hv as usize; // floor (hv >= 0)
|
||||
let hc = (hf + 1).min(ngrid - 1);
|
||||
let dh = (hv - hf as f64) as f32;
|
||||
for &wv in &ws {
|
||||
let wf = wv as usize;
|
||||
let wc = (wf + 1).min(ngrid - 1);
|
||||
let dw = (wv - wf as f64) as f32;
|
||||
idx[0].push((hf * ngrid + wf) as u32);
|
||||
wts[0].push((1.0 - dh) * (1.0 - dw));
|
||||
idx[1].push((hf * ngrid + wc) as u32);
|
||||
wts[1].push((1.0 - dh) * dw);
|
||||
idx[2].push((hc * ngrid + wf) as u32);
|
||||
wts[2].push(dh * (1.0 - dw));
|
||||
idx[3].push((hc * ngrid + wc) as u32);
|
||||
wts[3].push(dh * dw);
|
||||
}
|
||||
}
|
||||
|
||||
let mut acc: Option<Tensor> = None;
|
||||
for corner in 0..4 {
|
||||
let idx_t = Tensor::from_vec(std::mem::take(&mut idx[corner]), (n,), &self.device)?;
|
||||
let emb = self.pos_embed.forward(&idx_t)?; // (n, hidden), pos_embed dtype
|
||||
let wt = Tensor::from_vec(std::mem::take(&mut wts[corner]), (n, 1), &self.device)?
|
||||
.to_dtype(self.dtype)?;
|
||||
let term = emb.broadcast_mul(&wt)?;
|
||||
acc = Some(match acc {
|
||||
Some(a) => a.add(&term)?,
|
||||
None => term,
|
||||
});
|
||||
}
|
||||
Ok(acc.expect("4 corners accumulated"))
|
||||
}
|
||||
|
||||
/// Encode one image.
|
||||
///
|
||||
/// `image`: row-major `(3, H, W)` f32 tensor on `self.device`,
|
||||
@@ -339,16 +509,34 @@ impl VisionTower {
|
||||
let x = x.permute((1, 2, 0))?.contiguous()?;
|
||||
let x = x.reshape((n_patches, self.config.hidden_size))?;
|
||||
|
||||
// Add learned positional embeddings (sequential indices for
|
||||
// Stage A's fixed-resolution path; full 2D positional logic
|
||||
// lands with variable resolution, issue #14).
|
||||
// Learned absolute position embeddings. The `pos_embed` table is
|
||||
// a `num_position_embeddings = num_grid_per_side²` learned grid
|
||||
// (48×48 for Qwen3.6); for a `gh×gw` patch grid the reference
|
||||
// (`fast_pos_embed_interpolate`) bilinearly interpolates that
|
||||
// grid to `gh×gw`. The legacy path (a naive sequential lookup of
|
||||
// the first `n_patches` rows) mis-maps the grid stride and
|
||||
// scrambles spatial structure — kept only behind
|
||||
// `NEURON_VISION_LEGACY_POS=1` for A/B comparison.
|
||||
let pos = if vision_legacy_pos() {
|
||||
let positions = Tensor::arange(0u32, n_patches as u32, &self.device)?;
|
||||
let pos = self.pos_embed.forward(&positions)?;
|
||||
self.pos_embed.forward(&positions)?
|
||||
} else {
|
||||
self.interpolated_pos_embed(gh, gw)?
|
||||
};
|
||||
let mut x = x.add(&pos)?;
|
||||
|
||||
// 2D vision rotary (row/col per patch), computed once and applied
|
||||
// in every block's attention. Legacy escape hatch skips it.
|
||||
let rotary = if vision_legacy_rope() {
|
||||
None
|
||||
} else {
|
||||
Some(self.rotary.cos_sin(gh, gw, &self.device, self.dtype)?)
|
||||
};
|
||||
let rotary_ref = rotary.as_ref();
|
||||
|
||||
for (i, block) in self.blocks.iter().enumerate() {
|
||||
x = block
|
||||
.forward(&x)
|
||||
.forward(&x, rotary_ref)
|
||||
.with_context(|| format!("vision block {i}"))?;
|
||||
}
|
||||
|
||||
@@ -516,9 +704,11 @@ mod tests {
|
||||
spatial_merge_size: cfg.spatial_merge_size,
|
||||
};
|
||||
|
||||
let rotary = VisionRotaryEmbedding::new(cfg.hidden_size / cfg.num_heads);
|
||||
VisionTower {
|
||||
patch_embed,
|
||||
pos_embed,
|
||||
rotary,
|
||||
blocks,
|
||||
merger,
|
||||
config: cfg.clone(),
|
||||
@@ -548,6 +738,51 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolated_pos_embed_reduces_to_sequential_at_native_grid() {
|
||||
// When the patch grid equals the pos_embed grid (gh=gw=ngrid),
|
||||
// linspace(0,ngrid-1,ngrid) is the integer ladder, so every patch
|
||||
// lands exactly on a grid node (dh=dw=0, corner-0 weight 1) and
|
||||
// the bilinear result is the raw pos_embed rows in row-major
|
||||
// order — i.e. identical to the legacy sequential lookup.
|
||||
let cfg = tiny_config();
|
||||
let tower = tiny_tower(&cfg);
|
||||
let ngrid = (cfg.num_position_embeddings as f64).sqrt() as usize; // 8
|
||||
let interp = tower.interpolated_pos_embed(ngrid, ngrid).unwrap();
|
||||
let seq = tower
|
||||
.pos_embed
|
||||
.forward(&Tensor::arange(0u32, (ngrid * ngrid) as u32, &Device::Cpu).unwrap())
|
||||
.unwrap();
|
||||
let a: Vec<f32> = interp.flatten_all().unwrap().to_vec1().unwrap();
|
||||
let b: Vec<f32> = seq.flatten_all().unwrap().to_vec1().unwrap();
|
||||
assert_eq!(a.len(), b.len());
|
||||
for (x, y) in a.iter().zip(b.iter()) {
|
||||
assert!((x - y).abs() < 1e-5, "interp {x} vs seq {y}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vision_rotary_row_col_structure() {
|
||||
// head_dim 8 → rotary dim 4 → inv_freq over [0,2] → 2 freqs/axis.
|
||||
let rot = VisionRotaryEmbedding::new(8);
|
||||
assert_eq!(rot.inv_freq.len(), 2);
|
||||
let (cos, sin) = rot.cos_sin(2, 2, &Device::Cpu, DType::F32).unwrap();
|
||||
assert_eq!(cos.dims(), &[4, 4]); // 4 patches, head_dim/2 = 4 cols
|
||||
|
||||
// Patch (0,0): all freqs 0 → cos 1, sin 0.
|
||||
let s0: Vec<f32> = sin.i(0).unwrap().to_vec1().unwrap();
|
||||
assert!(s0.iter().all(|&s| s.abs() < 1e-6));
|
||||
|
||||
// Patch index 2 = grid (1,0): row=1 drives the first half, col=0
|
||||
// leaves the second half at zero.
|
||||
let s2: Vec<f32> = sin.i(2).unwrap().to_vec1().unwrap();
|
||||
assert!(s2[0].abs() > 1e-6, "row half must be non-zero");
|
||||
assert!(
|
||||
s2[2].abs() < 1e-6 && s2[3].abs() < 1e-6,
|
||||
"col half must be zero"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lm_token_count_matches_grid() {
|
||||
let cfg = tiny_config();
|
||||
|
||||
Reference in New Issue
Block a user