feat(neuron): M-RoPE Stage 3 — wire interleaved M-RoPE into single-GPU
Qwen3_5Model now builds the rotary cos/sin once per forward and threads (cos, sin) through the decoder → full-attention → rope, replacing the scalar offset that reached RotaryEmbedding: - vision forward computes get_rope_index over the (single-shot) prompt, sets rope_delta, and builds interleaved-M-RoPE cos/sin so image tokens carry their 14×14 grid (height/width) positions; - text / decode take plain_cos_sin at offset + rope_delta — with rope_delta == 0 (no image) this is bit-for-bit the old plain RoPE, and the device→host id copy is skipped on the text decode hot path. rope_delta is stored on the model and reset in clear_kv_cache, so decode after a vision prefill resumes text positions from the image-compressed counter. decoder.rs / full_attn.rs take (cos, sin) instead of offset; linear-attention layers are unchanged (no RoPE). The TP path still uses the retained apply(offset) — wired in Stage 4. Full workspace tests green; the load-bearing invariant (M-RoPE == plain for equal axes) keeps text unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -93,12 +93,13 @@ impl Qwen3_5DecoderLayer {
|
|||||||
&mut self,
|
&mut self,
|
||||||
x: &Tensor,
|
x: &Tensor,
|
||||||
attn_mask: Option<&Tensor>,
|
attn_mask: Option<&Tensor>,
|
||||||
offset: usize,
|
cos: &Tensor,
|
||||||
|
sin: &Tensor,
|
||||||
) -> candle_core::Result<Tensor> {
|
) -> candle_core::Result<Tensor> {
|
||||||
let h = self.input_layernorm.forward(x)?;
|
let h = self.input_layernorm.forward(x)?;
|
||||||
let attn_out = match &mut self.attention {
|
let attn_out = match &mut self.attention {
|
||||||
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
|
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, cos, sin)?,
|
||||||
// Linear attention ignores attn_mask + offset; its causal
|
// Linear attention ignores attn_mask + rope; its causal
|
||||||
// structure is baked into the recurrent state lifecycle.
|
// structure is baked into the recurrent state lifecycle.
|
||||||
AttentionKind::Linear(net) => net.forward(&h)?,
|
AttentionKind::Linear(net) => net.forward(&h)?,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -96,7 +96,8 @@ impl Qwen3_5Attention {
|
|||||||
&mut self,
|
&mut self,
|
||||||
x: &Tensor,
|
x: &Tensor,
|
||||||
attn_mask: Option<&Tensor>,
|
attn_mask: Option<&Tensor>,
|
||||||
offset: usize,
|
cos: &Tensor,
|
||||||
|
sin: &Tensor,
|
||||||
) -> candle_core::Result<Tensor> {
|
) -> candle_core::Result<Tensor> {
|
||||||
let (b, l, _) = x.dims3()?;
|
let (b, l, _) = x.dims3()?;
|
||||||
|
|
||||||
@@ -131,8 +132,9 @@ impl Qwen3_5Attention {
|
|||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
.contiguous()?;
|
.contiguous()?;
|
||||||
|
|
||||||
// 3. RoPE on q, k.
|
// 3. RoPE on q, k (cos/sin built once per forward by the model —
|
||||||
let (q, k) = self.rotary.apply(&q, &k, offset)?;
|
// interleaved M-RoPE for image tokens, plain for text).
|
||||||
|
let (q, k) = self.rotary.apply_cos_sin(&q, &k, cos, sin)?;
|
||||||
|
|
||||||
// 4. KV cache.
|
// 4. KV cache.
|
||||||
let (k, v) = self.kv_cache.append(&k, &v)?;
|
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||||
|
|||||||
@@ -314,6 +314,16 @@ pub struct Qwen3_5Model {
|
|||||||
embed_tokens: Embedding,
|
embed_tokens: Embedding,
|
||||||
layers: Vec<Qwen3_5DecoderLayer>,
|
layers: Vec<Qwen3_5DecoderLayer>,
|
||||||
norm: Qwen3_5RmsNorm,
|
norm: Qwen3_5RmsNorm,
|
||||||
|
/// Shared with every full-attention layer; the model uses it to
|
||||||
|
/// build the per-forward cos/sin (interleaved M-RoPE for image
|
||||||
|
/// tokens, plain for text) once, which the layers then apply.
|
||||||
|
rotary: Arc<RotaryEmbedding>,
|
||||||
|
/// `offset + rope_delta` is the text-axis position during decode.
|
||||||
|
/// 0 for text-only; set from `get_rope_index` during a vision
|
||||||
|
/// prefill (image tokens compress the position space, so text after
|
||||||
|
/// the image resumes from a smaller counter than the sequence
|
||||||
|
/// index). Reset in `clear_kv_cache`.
|
||||||
|
rope_delta: i64,
|
||||||
device: Device,
|
device: Device,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
}
|
}
|
||||||
@@ -365,6 +375,8 @@ impl Qwen3_5Model {
|
|||||||
embed_tokens,
|
embed_tokens,
|
||||||
layers,
|
layers,
|
||||||
norm,
|
norm,
|
||||||
|
rotary,
|
||||||
|
rope_delta: 0,
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
})
|
})
|
||||||
@@ -378,6 +390,9 @@ impl Qwen3_5Model {
|
|||||||
for l in &mut self.layers {
|
for l in &mut self.layers {
|
||||||
l.clear_kv_cache();
|
l.clear_kv_cache();
|
||||||
}
|
}
|
||||||
|
// New request → no image-compressed position offset until the
|
||||||
|
// next vision prefill sets one.
|
||||||
|
self.rope_delta = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
@@ -435,16 +450,16 @@ impl Qwen3_5Model {
|
|||||||
) -> candle_core::Result<Tensor> {
|
) -> candle_core::Result<Tensor> {
|
||||||
let (b, l) = input.dims2()?;
|
let (b, l) = input.dims2()?;
|
||||||
let mut h = self.embed_tokens.forward(input)?;
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
// Splice image embeddings at `image_token_id` positions. The
|
|
||||||
// caller pre-expanded the prompt so every patch token in the
|
// Vision path: splice image embeddings at `image_token_id`
|
||||||
// image_embeds tensor has a matching position in `input`. We
|
// positions and build interleaved M-RoPE cos/sin so image tokens
|
||||||
// index_put the rows in place.
|
// carry their 14×14 grid coordinates. Text / decode skip the
|
||||||
if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
|
// device→host id copy entirely and take the plain-RoPE fast path
|
||||||
// Locate image-token positions in input_ids. Operate on
|
// — bit-for-bit the pre-M-RoPE behaviour when `rope_delta == 0`.
|
||||||
// CPU since the input ids are tiny (max ~10k entries
|
let (cos, sin) = if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
|
||||||
// including the patch expansion) and the comparison is
|
// Token ids on CPU — reused for the splice + position ids.
|
||||||
// not in the per-step hot path.
|
|
||||||
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
|
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
|
||||||
|
|
||||||
let mut positions: Vec<u32> = Vec::with_capacity(img.dim(0)?);
|
let mut positions: Vec<u32> = Vec::with_capacity(img.dim(0)?);
|
||||||
for (idx, id) in ids.iter().enumerate() {
|
for (idx, id) in ids.iter().enumerate() {
|
||||||
if *id == tok_id {
|
if *id == tok_id {
|
||||||
@@ -462,22 +477,22 @@ impl Qwen3_5Model {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
if !positions.is_empty() {
|
if !positions.is_empty() {
|
||||||
// Cast image_embeds to the LM's dtype so the splice
|
// Cast image_embeds to the LM's dtype, then splice the
|
||||||
// produces a uniform tensor for the decoder stack.
|
// contiguous `<|image_pad|>` runs in place.
|
||||||
let img = img.to_dtype(self.dtype)?;
|
let img = img.to_dtype(self.dtype)?;
|
||||||
// index_select would return the rows; we want to put.
|
|
||||||
// candle's slice_assign with explicit positions ranges
|
|
||||||
// doesn't exist; use scatter via index_select + an
|
|
||||||
// accumulator: build a `(B, L, hidden)` zero tensor,
|
|
||||||
// scatter the image rows in, then add to a masked
|
|
||||||
// version of `h`. Simpler approach: walk positions
|
|
||||||
// and use `slice_assign` for contiguous runs. Since
|
|
||||||
// image_pad runs are contiguous (template emits
|
|
||||||
// `<|vision_start|><|image_pad|>×N<|vision_end|>`),
|
|
||||||
// we group positions and assign per run.
|
|
||||||
h = splice_runs(&h, &img, &positions)?;
|
h = splice_runs(&h, &img, &positions)?;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
let (text, height, width, delta) = rope::get_rope_index(&ids, tok_id)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
|
||||||
|
self.rope_delta = delta;
|
||||||
|
let pos = rope::mrope_position_tensor(&text, &height, &width, &self.device)?;
|
||||||
|
self.rotary.mrope_cos_sin(&pos)?
|
||||||
|
} else {
|
||||||
|
let base = (offset as i64 + self.rope_delta).max(0) as usize;
|
||||||
|
self.rotary.plain_cos_sin(base, l)?
|
||||||
|
};
|
||||||
|
|
||||||
// Causal mask only needed for L > 1 prefill; full-attention
|
// Causal mask only needed for L > 1 prefill; full-attention
|
||||||
// layers consume it via broadcast_add. Linear-attention layers
|
// layers consume it via broadcast_add. Linear-attention layers
|
||||||
// ignore the mask.
|
// ignore the mask.
|
||||||
@@ -487,7 +502,7 @@ impl Qwen3_5Model {
|
|||||||
Some(self.causal_mask(b, l, offset)?)
|
Some(self.causal_mask(b, l, offset)?)
|
||||||
};
|
};
|
||||||
for layer in &mut self.layers {
|
for layer in &mut self.layers {
|
||||||
h = layer.forward(&h, causal.as_ref(), offset)?;
|
h = layer.forward(&h, causal.as_ref(), &cos, &sin)?;
|
||||||
}
|
}
|
||||||
self.norm.forward(&h)
|
self.norm.forward(&h)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -258,9 +258,6 @@ impl RotaryEmbedding {
|
|||||||
/// square with `grid_t = 1` (still image) and `grid_h = grid_w =
|
/// square with `grid_t = 1` (still image) and `grid_h = grid_w =
|
||||||
/// isqrt(run_len)` — 196 → 14×14. Dynamic resolution (#14) would thread
|
/// isqrt(run_len)` — 196 → 14×14. Dynamic resolution (#14) would thread
|
||||||
/// real per-image grids instead.
|
/// real per-image grids instead.
|
||||||
// Wired into the forward path in Stage 3 (single-GPU) / Stage 4 (TP);
|
|
||||||
// exercised by unit tests until then.
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub(crate) fn get_rope_index(input_ids: &[u32], image_token_id: u32) -> Result<MRopeIndex> {
|
pub(crate) fn get_rope_index(input_ids: &[u32], image_token_id: u32) -> Result<MRopeIndex> {
|
||||||
let n = input_ids.len();
|
let n = input_ids.len();
|
||||||
let mut text = Vec::with_capacity(n);
|
let mut text = Vec::with_capacity(n);
|
||||||
@@ -313,7 +310,6 @@ pub(crate) type MRopeIndex = (Vec<i64>, Vec<i64>, Vec<i64>, i64);
|
|||||||
|
|
||||||
/// Build the `(3, seq)` position-id tensor consumed by
|
/// Build the `(3, seq)` position-id tensor consumed by
|
||||||
/// [`RotaryEmbedding::mrope_cos_sin`] from the three axis vectors.
|
/// [`RotaryEmbedding::mrope_cos_sin`] from the three axis vectors.
|
||||||
#[allow(dead_code)] // wired in Stage 3/4
|
|
||||||
pub(crate) fn mrope_position_tensor(
|
pub(crate) fn mrope_position_tensor(
|
||||||
text: &[i64],
|
text: &[i64],
|
||||||
height: &[i64],
|
height: &[i64],
|
||||||
|
|||||||
Reference in New Issue
Block a user