feat(neuron): M-RoPE Stage 3 — wire interleaved M-RoPE into single-GPU
Qwen3_5Model now builds the rotary cos/sin once per forward and threads (cos, sin) through the decoder → full-attention → rope, replacing the scalar offset that reached RotaryEmbedding: - vision forward computes get_rope_index over the (single-shot) prompt, sets rope_delta, and builds interleaved-M-RoPE cos/sin so image tokens carry their 14×14 grid (height/width) positions; - text / decode take plain_cos_sin at offset + rope_delta — with rope_delta == 0 (no image) this is bit-for-bit the old plain RoPE, and the device→host id copy is skipped on the text decode hot path. rope_delta is stored on the model and reset in clear_kv_cache, so decode after a vision prefill resumes text positions from the image-compressed counter. decoder.rs / full_attn.rs take (cos, sin) instead of offset; linear-attention layers are unchanged (no RoPE). The TP path still uses the retained apply(offset) — wired in Stage 4. Full workspace tests green; the load-bearing invariant (M-RoPE == plain for equal axes) keeps text unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -93,12 +93,13 @@ impl Qwen3_5DecoderLayer {
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
attn_mask: Option<&Tensor>,
|
||||
offset: usize,
|
||||
cos: &Tensor,
|
||||
sin: &Tensor,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let h = self.input_layernorm.forward(x)?;
|
||||
let attn_out = match &mut self.attention {
|
||||
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
|
||||
// Linear attention ignores attn_mask + offset; its causal
|
||||
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, cos, sin)?,
|
||||
// Linear attention ignores attn_mask + rope; its causal
|
||||
// structure is baked into the recurrent state lifecycle.
|
||||
AttentionKind::Linear(net) => net.forward(&h)?,
|
||||
};
|
||||
|
||||
@@ -96,7 +96,8 @@ impl Qwen3_5Attention {
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
attn_mask: Option<&Tensor>,
|
||||
offset: usize,
|
||||
cos: &Tensor,
|
||||
sin: &Tensor,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (b, l, _) = x.dims3()?;
|
||||
|
||||
@@ -131,8 +132,9 @@ impl Qwen3_5Attention {
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
|
||||
// 3. RoPE on q, k.
|
||||
let (q, k) = self.rotary.apply(&q, &k, offset)?;
|
||||
// 3. RoPE on q, k (cos/sin built once per forward by the model —
|
||||
// interleaved M-RoPE for image tokens, plain for text).
|
||||
let (q, k) = self.rotary.apply_cos_sin(&q, &k, cos, sin)?;
|
||||
|
||||
// 4. KV cache.
|
||||
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||
|
||||
@@ -314,6 +314,16 @@ pub struct Qwen3_5Model {
|
||||
embed_tokens: Embedding,
|
||||
layers: Vec<Qwen3_5DecoderLayer>,
|
||||
norm: Qwen3_5RmsNorm,
|
||||
/// Shared with every full-attention layer; the model uses it to
|
||||
/// build the per-forward cos/sin (interleaved M-RoPE for image
|
||||
/// tokens, plain for text) once, which the layers then apply.
|
||||
rotary: Arc<RotaryEmbedding>,
|
||||
/// `offset + rope_delta` is the text-axis position during decode.
|
||||
/// 0 for text-only; set from `get_rope_index` during a vision
|
||||
/// prefill (image tokens compress the position space, so text after
|
||||
/// the image resumes from a smaller counter than the sequence
|
||||
/// index). Reset in `clear_kv_cache`.
|
||||
rope_delta: i64,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
}
|
||||
@@ -365,6 +375,8 @@ impl Qwen3_5Model {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
rotary,
|
||||
rope_delta: 0,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
@@ -378,6 +390,9 @@ impl Qwen3_5Model {
|
||||
for l in &mut self.layers {
|
||||
l.clear_kv_cache();
|
||||
}
|
||||
// New request → no image-compressed position offset until the
|
||||
// next vision prefill sets one.
|
||||
self.rope_delta = 0;
|
||||
}
|
||||
|
||||
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||
@@ -435,16 +450,16 @@ impl Qwen3_5Model {
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (b, l) = input.dims2()?;
|
||||
let mut h = self.embed_tokens.forward(input)?;
|
||||
// Splice image embeddings at `image_token_id` positions. The
|
||||
// caller pre-expanded the prompt so every patch token in the
|
||||
// image_embeds tensor has a matching position in `input`. We
|
||||
// index_put the rows in place.
|
||||
if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
|
||||
// Locate image-token positions in input_ids. Operate on
|
||||
// CPU since the input ids are tiny (max ~10k entries
|
||||
// including the patch expansion) and the comparison is
|
||||
// not in the per-step hot path.
|
||||
|
||||
// Vision path: splice image embeddings at `image_token_id`
|
||||
// positions and build interleaved M-RoPE cos/sin so image tokens
|
||||
// carry their 14×14 grid coordinates. Text / decode skip the
|
||||
// device→host id copy entirely and take the plain-RoPE fast path
|
||||
// — bit-for-bit the pre-M-RoPE behaviour when `rope_delta == 0`.
|
||||
let (cos, sin) = if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
|
||||
// Token ids on CPU — reused for the splice + position ids.
|
||||
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
|
||||
|
||||
let mut positions: Vec<u32> = Vec::with_capacity(img.dim(0)?);
|
||||
for (idx, id) in ids.iter().enumerate() {
|
||||
if *id == tok_id {
|
||||
@@ -462,22 +477,22 @@ impl Qwen3_5Model {
|
||||
);
|
||||
}
|
||||
if !positions.is_empty() {
|
||||
// Cast image_embeds to the LM's dtype so the splice
|
||||
// produces a uniform tensor for the decoder stack.
|
||||
// Cast image_embeds to the LM's dtype, then splice the
|
||||
// contiguous `<|image_pad|>` runs in place.
|
||||
let img = img.to_dtype(self.dtype)?;
|
||||
// index_select would return the rows; we want to put.
|
||||
// candle's slice_assign with explicit positions ranges
|
||||
// doesn't exist; use scatter via index_select + an
|
||||
// accumulator: build a `(B, L, hidden)` zero tensor,
|
||||
// scatter the image rows in, then add to a masked
|
||||
// version of `h`. Simpler approach: walk positions
|
||||
// and use `slice_assign` for contiguous runs. Since
|
||||
// image_pad runs are contiguous (template emits
|
||||
// `<|vision_start|><|image_pad|>×N<|vision_end|>`),
|
||||
// we group positions and assign per run.
|
||||
h = splice_runs(&h, &img, &positions)?;
|
||||
}
|
||||
}
|
||||
|
||||
let (text, height, width, delta) = rope::get_rope_index(&ids, tok_id)
|
||||
.map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
|
||||
self.rope_delta = delta;
|
||||
let pos = rope::mrope_position_tensor(&text, &height, &width, &self.device)?;
|
||||
self.rotary.mrope_cos_sin(&pos)?
|
||||
} else {
|
||||
let base = (offset as i64 + self.rope_delta).max(0) as usize;
|
||||
self.rotary.plain_cos_sin(base, l)?
|
||||
};
|
||||
|
||||
// Causal mask only needed for L > 1 prefill; full-attention
|
||||
// layers consume it via broadcast_add. Linear-attention layers
|
||||
// ignore the mask.
|
||||
@@ -487,7 +502,7 @@ impl Qwen3_5Model {
|
||||
Some(self.causal_mask(b, l, offset)?)
|
||||
};
|
||||
for layer in &mut self.layers {
|
||||
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||
h = layer.forward(&h, causal.as_ref(), &cos, &sin)?;
|
||||
}
|
||||
self.norm.forward(&h)
|
||||
}
|
||||
|
||||
@@ -258,9 +258,6 @@ impl RotaryEmbedding {
|
||||
/// square with `grid_t = 1` (still image) and `grid_h = grid_w =
|
||||
/// isqrt(run_len)` — 196 → 14×14. Dynamic resolution (#14) would thread
|
||||
/// real per-image grids instead.
|
||||
// Wired into the forward path in Stage 3 (single-GPU) / Stage 4 (TP);
|
||||
// exercised by unit tests until then.
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn get_rope_index(input_ids: &[u32], image_token_id: u32) -> Result<MRopeIndex> {
|
||||
let n = input_ids.len();
|
||||
let mut text = Vec::with_capacity(n);
|
||||
@@ -313,7 +310,6 @@ pub(crate) type MRopeIndex = (Vec<i64>, Vec<i64>, Vec<i64>, i64);
|
||||
|
||||
/// Build the `(3, seq)` position-id tensor consumed by
|
||||
/// [`RotaryEmbedding::mrope_cos_sin`] from the three axis vectors.
|
||||
#[allow(dead_code)] // wired in Stage 3/4
|
||||
pub(crate) fn mrope_position_tensor(
|
||||
text: &[i64],
|
||||
height: &[i64],
|
||||
|
||||
Reference in New Issue
Block a user