feat(neuron): M-RoPE Stage 3 — wire interleaved M-RoPE into single-GPU

Qwen3_5Model now builds the rotary cos/sin once per forward and threads
(cos, sin) through the decoder → full-attention → rope, replacing the
scalar offset that reached RotaryEmbedding:

- vision forward computes get_rope_index over the (single-shot) prompt,
  sets rope_delta, and builds interleaved-M-RoPE cos/sin so image tokens
  carry their 14×14 grid (height/width) positions;
- text / decode take plain_cos_sin at offset + rope_delta — with
  rope_delta == 0 (no image) this is bit-for-bit the old plain RoPE, and
  the device→host id copy is skipped on the text decode hot path.

rope_delta is stored on the model and reset in clear_kv_cache, so decode
after a vision prefill resumes text positions from the image-compressed
counter. decoder.rs / full_attn.rs take (cos, sin) instead of offset;
linear-attention layers are unchanged (no RoPE). The TP path still uses
the retained apply(offset) — wired in Stage 4.

Full workspace tests green; the load-bearing invariant (M-RoPE == plain
for equal axes) keeps text unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-04 18:39:52 +03:00
parent ba1b5ba408
commit 4c12c7e2f0
4 changed files with 47 additions and 33 deletions

View File

@@ -93,12 +93,13 @@ impl Qwen3_5DecoderLayer {
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
cos: &Tensor,
sin: &Tensor,
) -> candle_core::Result<Tensor> {
let h = self.input_layernorm.forward(x)?;
let attn_out = match &mut self.attention {
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
// Linear attention ignores attn_mask + offset; its causal
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, cos, sin)?,
// Linear attention ignores attn_mask + rope; its causal
// structure is baked into the recurrent state lifecycle.
AttentionKind::Linear(net) => net.forward(&h)?,
};

View File

@@ -96,7 +96,8 @@ impl Qwen3_5Attention {
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
cos: &Tensor,
sin: &Tensor,
) -> candle_core::Result<Tensor> {
let (b, l, _) = x.dims3()?;
@@ -131,8 +132,9 @@ impl Qwen3_5Attention {
.transpose(1, 2)?
.contiguous()?;
// 3. RoPE on q, k.
let (q, k) = self.rotary.apply(&q, &k, offset)?;
// 3. RoPE on q, k (cos/sin built once per forward by the model —
// interleaved M-RoPE for image tokens, plain for text).
let (q, k) = self.rotary.apply_cos_sin(&q, &k, cos, sin)?;
// 4. KV cache.
let (k, v) = self.kv_cache.append(&k, &v)?;

View File

@@ -314,6 +314,16 @@ pub struct Qwen3_5Model {
embed_tokens: Embedding,
layers: Vec<Qwen3_5DecoderLayer>,
norm: Qwen3_5RmsNorm,
/// Shared with every full-attention layer; the model uses it to
/// build the per-forward cos/sin (interleaved M-RoPE for image
/// tokens, plain for text) once, which the layers then apply.
rotary: Arc<RotaryEmbedding>,
/// `offset + rope_delta` is the text-axis position during decode.
/// 0 for text-only; set from `get_rope_index` during a vision
/// prefill (image tokens compress the position space, so text after
/// the image resumes from a smaller counter than the sequence
/// index). Reset in `clear_kv_cache`.
rope_delta: i64,
device: Device,
dtype: DType,
}
@@ -365,6 +375,8 @@ impl Qwen3_5Model {
embed_tokens,
layers,
norm,
rotary,
rope_delta: 0,
device,
dtype,
})
@@ -378,6 +390,9 @@ impl Qwen3_5Model {
for l in &mut self.layers {
l.clear_kv_cache();
}
// New request → no image-compressed position offset until the
// next vision prefill sets one.
self.rope_delta = 0;
}
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
@@ -435,16 +450,16 @@ impl Qwen3_5Model {
) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
// Splice image embeddings at `image_token_id` positions. The
// caller pre-expanded the prompt so every patch token in the
// image_embeds tensor has a matching position in `input`. We
// index_put the rows in place.
if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
// Locate image-token positions in input_ids. Operate on
// CPU since the input ids are tiny (max ~10k entries
// including the patch expansion) and the comparison is
// not in the per-step hot path.
// Vision path: splice image embeddings at `image_token_id`
// positions and build interleaved M-RoPE cos/sin so image tokens
// carry their 14×14 grid coordinates. Text / decode skip the
// device→host id copy entirely and take the plain-RoPE fast path
// — bit-for-bit the pre-M-RoPE behaviour when `rope_delta == 0`.
let (cos, sin) = if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
// Token ids on CPU — reused for the splice + position ids.
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
let mut positions: Vec<u32> = Vec::with_capacity(img.dim(0)?);
for (idx, id) in ids.iter().enumerate() {
if *id == tok_id {
@@ -462,22 +477,22 @@ impl Qwen3_5Model {
);
}
if !positions.is_empty() {
// Cast image_embeds to the LM's dtype so the splice
// produces a uniform tensor for the decoder stack.
// Cast image_embeds to the LM's dtype, then splice the
// contiguous `<|image_pad|>` runs in place.
let img = img.to_dtype(self.dtype)?;
// index_select would return the rows; we want to put.
// candle's slice_assign with explicit positions ranges
// doesn't exist; use scatter via index_select + an
// accumulator: build a `(B, L, hidden)` zero tensor,
// scatter the image rows in, then add to a masked
// version of `h`. Simpler approach: walk positions
// and use `slice_assign` for contiguous runs. Since
// image_pad runs are contiguous (template emits
// `<|vision_start|><|image_pad|>×N<|vision_end|>`),
// we group positions and assign per run.
h = splice_runs(&h, &img, &positions)?;
}
}
let (text, height, width, delta) = rope::get_rope_index(&ids, tok_id)
.map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
self.rope_delta = delta;
let pos = rope::mrope_position_tensor(&text, &height, &width, &self.device)?;
self.rotary.mrope_cos_sin(&pos)?
} else {
let base = (offset as i64 + self.rope_delta).max(0) as usize;
self.rotary.plain_cos_sin(base, l)?
};
// Causal mask only needed for L > 1 prefill; full-attention
// layers consume it via broadcast_add. Linear-attention layers
// ignore the mask.
@@ -487,7 +502,7 @@ impl Qwen3_5Model {
Some(self.causal_mask(b, l, offset)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
h = layer.forward(&h, causal.as_ref(), &cos, &sin)?;
}
self.norm.forward(&h)
}

View File

@@ -258,9 +258,6 @@ impl RotaryEmbedding {
/// square with `grid_t = 1` (still image) and `grid_h = grid_w =
/// isqrt(run_len)` — 196 → 14×14. Dynamic resolution (#14) would thread
/// real per-image grids instead.
// Wired into the forward path in Stage 3 (single-GPU) / Stage 4 (TP);
// exercised by unit tests until then.
#[allow(dead_code)]
pub(crate) fn get_rope_index(input_ids: &[u32], image_token_id: u32) -> Result<MRopeIndex> {
let n = input_ids.len();
let mut text = Vec::with_capacity(n);
@@ -313,7 +310,6 @@ pub(crate) type MRopeIndex = (Vec<i64>, Vec<i64>, Vec<i64>, i64);
/// Build the `(3, seq)` position-id tensor consumed by
/// [`RotaryEmbedding::mrope_cos_sin`] from the three axis vectors.
#[allow(dead_code)] // wired in Stage 3/4
pub(crate) fn mrope_position_tensor(
text: &[i64],
height: &[i64],