feat(neuron): M-RoPE Stage 3 — wire interleaved M-RoPE into single-GPU

Qwen3_5Model now builds the rotary cos/sin once per forward and threads
(cos, sin) through the decoder → full-attention → rope, replacing the
scalar offset that reached RotaryEmbedding:

- vision forward computes get_rope_index over the (single-shot) prompt,
  sets rope_delta, and builds interleaved-M-RoPE cos/sin so image tokens
  carry their 14×14 grid (height/width) positions;
- text / decode take plain_cos_sin at offset + rope_delta — with
  rope_delta == 0 (no image) this is bit-for-bit the old plain RoPE, and
  the device→host id copy is skipped on the text decode hot path.

rope_delta is stored on the model and reset in clear_kv_cache, so decode
after a vision prefill resumes text positions from the image-compressed
counter. decoder.rs / full_attn.rs take (cos, sin) instead of offset;
linear-attention layers are unchanged (no RoPE). The TP path still uses
the retained apply(offset) — wired in Stage 4.

Full workspace tests green; the load-bearing invariant (M-RoPE == plain
for equal axes) keeps text unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-04 18:39:52 +03:00
parent ba1b5ba408
commit 4c12c7e2f0
4 changed files with 47 additions and 33 deletions

View File

@@ -93,12 +93,13 @@ impl Qwen3_5DecoderLayer {
&mut self, &mut self,
x: &Tensor, x: &Tensor,
attn_mask: Option<&Tensor>, attn_mask: Option<&Tensor>,
offset: usize, cos: &Tensor,
sin: &Tensor,
) -> candle_core::Result<Tensor> { ) -> candle_core::Result<Tensor> {
let h = self.input_layernorm.forward(x)?; let h = self.input_layernorm.forward(x)?;
let attn_out = match &mut self.attention { let attn_out = match &mut self.attention {
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?, AttentionKind::Full(attn) => attn.forward(&h, attn_mask, cos, sin)?,
// Linear attention ignores attn_mask + offset; its causal // Linear attention ignores attn_mask + rope; its causal
// structure is baked into the recurrent state lifecycle. // structure is baked into the recurrent state lifecycle.
AttentionKind::Linear(net) => net.forward(&h)?, AttentionKind::Linear(net) => net.forward(&h)?,
}; };

View File

@@ -96,7 +96,8 @@ impl Qwen3_5Attention {
&mut self, &mut self,
x: &Tensor, x: &Tensor,
attn_mask: Option<&Tensor>, attn_mask: Option<&Tensor>,
offset: usize, cos: &Tensor,
sin: &Tensor,
) -> candle_core::Result<Tensor> { ) -> candle_core::Result<Tensor> {
let (b, l, _) = x.dims3()?; let (b, l, _) = x.dims3()?;
@@ -131,8 +132,9 @@ impl Qwen3_5Attention {
.transpose(1, 2)? .transpose(1, 2)?
.contiguous()?; .contiguous()?;
// 3. RoPE on q, k. // 3. RoPE on q, k (cos/sin built once per forward by the model —
let (q, k) = self.rotary.apply(&q, &k, offset)?; // interleaved M-RoPE for image tokens, plain for text).
let (q, k) = self.rotary.apply_cos_sin(&q, &k, cos, sin)?;
// 4. KV cache. // 4. KV cache.
let (k, v) = self.kv_cache.append(&k, &v)?; let (k, v) = self.kv_cache.append(&k, &v)?;

View File

@@ -314,6 +314,16 @@ pub struct Qwen3_5Model {
embed_tokens: Embedding, embed_tokens: Embedding,
layers: Vec<Qwen3_5DecoderLayer>, layers: Vec<Qwen3_5DecoderLayer>,
norm: Qwen3_5RmsNorm, norm: Qwen3_5RmsNorm,
/// Shared with every full-attention layer; the model uses it to
/// build the per-forward cos/sin (interleaved M-RoPE for image
/// tokens, plain for text) once, which the layers then apply.
rotary: Arc<RotaryEmbedding>,
/// `offset + rope_delta` is the text-axis position during decode.
/// 0 for text-only; set from `get_rope_index` during a vision
/// prefill (image tokens compress the position space, so text after
/// the image resumes from a smaller counter than the sequence
/// index). Reset in `clear_kv_cache`.
rope_delta: i64,
device: Device, device: Device,
dtype: DType, dtype: DType,
} }
@@ -365,6 +375,8 @@ impl Qwen3_5Model {
embed_tokens, embed_tokens,
layers, layers,
norm, norm,
rotary,
rope_delta: 0,
device, device,
dtype, dtype,
}) })
@@ -378,6 +390,9 @@ impl Qwen3_5Model {
for l in &mut self.layers { for l in &mut self.layers {
l.clear_kv_cache(); l.clear_kv_cache();
} }
// New request → no image-compressed position offset until the
// next vision prefill sets one.
self.rope_delta = 0;
} }
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> { fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
@@ -435,16 +450,16 @@ impl Qwen3_5Model {
) -> candle_core::Result<Tensor> { ) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?; let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?; let mut h = self.embed_tokens.forward(input)?;
// Splice image embeddings at `image_token_id` positions. The
// caller pre-expanded the prompt so every patch token in the // Vision path: splice image embeddings at `image_token_id`
// image_embeds tensor has a matching position in `input`. We // positions and build interleaved M-RoPE cos/sin so image tokens
// index_put the rows in place. // carry their 14×14 grid coordinates. Text / decode skip the
if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) { // device→host id copy entirely and take the plain-RoPE fast path
// Locate image-token positions in input_ids. Operate on // — bit-for-bit the pre-M-RoPE behaviour when `rope_delta == 0`.
// CPU since the input ids are tiny (max ~10k entries let (cos, sin) = if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
// including the patch expansion) and the comparison is // Token ids on CPU — reused for the splice + position ids.
// not in the per-step hot path.
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?; let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
let mut positions: Vec<u32> = Vec::with_capacity(img.dim(0)?); let mut positions: Vec<u32> = Vec::with_capacity(img.dim(0)?);
for (idx, id) in ids.iter().enumerate() { for (idx, id) in ids.iter().enumerate() {
if *id == tok_id { if *id == tok_id {
@@ -462,22 +477,22 @@ impl Qwen3_5Model {
); );
} }
if !positions.is_empty() { if !positions.is_empty() {
// Cast image_embeds to the LM's dtype so the splice // Cast image_embeds to the LM's dtype, then splice the
// produces a uniform tensor for the decoder stack. // contiguous `<|image_pad|>` runs in place.
let img = img.to_dtype(self.dtype)?; let img = img.to_dtype(self.dtype)?;
// index_select would return the rows; we want to put.
// candle's slice_assign with explicit positions ranges
// doesn't exist; use scatter via index_select + an
// accumulator: build a `(B, L, hidden)` zero tensor,
// scatter the image rows in, then add to a masked
// version of `h`. Simpler approach: walk positions
// and use `slice_assign` for contiguous runs. Since
// image_pad runs are contiguous (template emits
// `<|vision_start|><|image_pad|>×N<|vision_end|>`),
// we group positions and assign per run.
h = splice_runs(&h, &img, &positions)?; h = splice_runs(&h, &img, &positions)?;
} }
}
let (text, height, width, delta) = rope::get_rope_index(&ids, tok_id)
.map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
self.rope_delta = delta;
let pos = rope::mrope_position_tensor(&text, &height, &width, &self.device)?;
self.rotary.mrope_cos_sin(&pos)?
} else {
let base = (offset as i64 + self.rope_delta).max(0) as usize;
self.rotary.plain_cos_sin(base, l)?
};
// Causal mask only needed for L > 1 prefill; full-attention // Causal mask only needed for L > 1 prefill; full-attention
// layers consume it via broadcast_add. Linear-attention layers // layers consume it via broadcast_add. Linear-attention layers
// ignore the mask. // ignore the mask.
@@ -487,7 +502,7 @@ impl Qwen3_5Model {
Some(self.causal_mask(b, l, offset)?) Some(self.causal_mask(b, l, offset)?)
}; };
for layer in &mut self.layers { for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?; h = layer.forward(&h, causal.as_ref(), &cos, &sin)?;
} }
self.norm.forward(&h) self.norm.forward(&h)
} }

View File

@@ -258,9 +258,6 @@ impl RotaryEmbedding {
/// square with `grid_t = 1` (still image) and `grid_h = grid_w = /// square with `grid_t = 1` (still image) and `grid_h = grid_w =
/// isqrt(run_len)` — 196 → 14×14. Dynamic resolution (#14) would thread /// isqrt(run_len)` — 196 → 14×14. Dynamic resolution (#14) would thread
/// real per-image grids instead. /// real per-image grids instead.
// Wired into the forward path in Stage 3 (single-GPU) / Stage 4 (TP);
// exercised by unit tests until then.
#[allow(dead_code)]
pub(crate) fn get_rope_index(input_ids: &[u32], image_token_id: u32) -> Result<MRopeIndex> { pub(crate) fn get_rope_index(input_ids: &[u32], image_token_id: u32) -> Result<MRopeIndex> {
let n = input_ids.len(); let n = input_ids.len();
let mut text = Vec::with_capacity(n); let mut text = Vec::with_capacity(n);
@@ -313,7 +310,6 @@ pub(crate) type MRopeIndex = (Vec<i64>, Vec<i64>, Vec<i64>, i64);
/// Build the `(3, seq)` position-id tensor consumed by /// Build the `(3, seq)` position-id tensor consumed by
/// [`RotaryEmbedding::mrope_cos_sin`] from the three axis vectors. /// [`RotaryEmbedding::mrope_cos_sin`] from the three axis vectors.
#[allow(dead_code)] // wired in Stage 3/4
pub(crate) fn mrope_position_tensor( pub(crate) fn mrope_position_tensor(
text: &[i64], text: &[i64],
height: &[i64], height: &[i64],