diff --git a/crates/neuron/src/harness/arch/qwen3_5/decoder.rs b/crates/neuron/src/harness/arch/qwen3_5/decoder.rs index df10bb4..e553d7a 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/decoder.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/decoder.rs @@ -93,12 +93,13 @@ impl Qwen3_5DecoderLayer { &mut self, x: &Tensor, attn_mask: Option<&Tensor>, - offset: usize, + cos: &Tensor, + sin: &Tensor, ) -> candle_core::Result { let h = self.input_layernorm.forward(x)?; let attn_out = match &mut self.attention { - AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?, - // Linear attention ignores attn_mask + offset; its causal + AttentionKind::Full(attn) => attn.forward(&h, attn_mask, cos, sin)?, + // Linear attention ignores attn_mask + rope; its causal // structure is baked into the recurrent state lifecycle. AttentionKind::Linear(net) => net.forward(&h)?, }; diff --git a/crates/neuron/src/harness/arch/qwen3_5/full_attn.rs b/crates/neuron/src/harness/arch/qwen3_5/full_attn.rs index 28b378b..91119bc 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/full_attn.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/full_attn.rs @@ -96,7 +96,8 @@ impl Qwen3_5Attention { &mut self, x: &Tensor, attn_mask: Option<&Tensor>, - offset: usize, + cos: &Tensor, + sin: &Tensor, ) -> candle_core::Result { let (b, l, _) = x.dims3()?; @@ -131,8 +132,9 @@ impl Qwen3_5Attention { .transpose(1, 2)? .contiguous()?; - // 3. RoPE on q, k. - let (q, k) = self.rotary.apply(&q, &k, offset)?; + // 3. RoPE on q, k (cos/sin built once per forward by the model — + // interleaved M-RoPE for image tokens, plain for text). + let (q, k) = self.rotary.apply_cos_sin(&q, &k, cos, sin)?; // 4. KV cache. let (k, v) = self.kv_cache.append(&k, &v)?; diff --git a/crates/neuron/src/harness/arch/qwen3_5/mod.rs b/crates/neuron/src/harness/arch/qwen3_5/mod.rs index 4a58d2f..95b5b64 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/mod.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/mod.rs @@ -314,6 +314,16 @@ pub struct Qwen3_5Model { embed_tokens: Embedding, layers: Vec, norm: Qwen3_5RmsNorm, + /// Shared with every full-attention layer; the model uses it to + /// build the per-forward cos/sin (interleaved M-RoPE for image + /// tokens, plain for text) once, which the layers then apply. + rotary: Arc, + /// `offset + rope_delta` is the text-axis position during decode. + /// 0 for text-only; set from `get_rope_index` during a vision + /// prefill (image tokens compress the position space, so text after + /// the image resumes from a smaller counter than the sequence + /// index). Reset in `clear_kv_cache`. + rope_delta: i64, device: Device, dtype: DType, } @@ -365,6 +375,8 @@ impl Qwen3_5Model { embed_tokens, layers, norm, + rotary, + rope_delta: 0, device, dtype, }) @@ -378,6 +390,9 @@ impl Qwen3_5Model { for l in &mut self.layers { l.clear_kv_cache(); } + // New request → no image-compressed position offset until the + // next vision prefill sets one. + self.rope_delta = 0; } fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result { @@ -435,16 +450,16 @@ impl Qwen3_5Model { ) -> candle_core::Result { let (b, l) = input.dims2()?; let mut h = self.embed_tokens.forward(input)?; - // Splice image embeddings at `image_token_id` positions. The - // caller pre-expanded the prompt so every patch token in the - // image_embeds tensor has a matching position in `input`. We - // index_put the rows in place. - if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) { - // Locate image-token positions in input_ids. Operate on - // CPU since the input ids are tiny (max ~10k entries - // including the patch expansion) and the comparison is - // not in the per-step hot path. + + // Vision path: splice image embeddings at `image_token_id` + // positions and build interleaved M-RoPE cos/sin so image tokens + // carry their 14×14 grid coordinates. Text / decode skip the + // device→host id copy entirely and take the plain-RoPE fast path + // — bit-for-bit the pre-M-RoPE behaviour when `rope_delta == 0`. + let (cos, sin) = if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) { + // Token ids on CPU — reused for the splice + position ids. let ids: Vec = input.flatten_all()?.to_vec1()?; + let mut positions: Vec = Vec::with_capacity(img.dim(0)?); for (idx, id) in ids.iter().enumerate() { if *id == tok_id { @@ -462,22 +477,22 @@ impl Qwen3_5Model { ); } if !positions.is_empty() { - // Cast image_embeds to the LM's dtype so the splice - // produces a uniform tensor for the decoder stack. + // Cast image_embeds to the LM's dtype, then splice the + // contiguous `<|image_pad|>` runs in place. let img = img.to_dtype(self.dtype)?; - // index_select would return the rows; we want to put. - // candle's slice_assign with explicit positions ranges - // doesn't exist; use scatter via index_select + an - // accumulator: build a `(B, L, hidden)` zero tensor, - // scatter the image rows in, then add to a masked - // version of `h`. Simpler approach: walk positions - // and use `slice_assign` for contiguous runs. Since - // image_pad runs are contiguous (template emits - // `<|vision_start|><|image_pad|>×N<|vision_end|>`), - // we group positions and assign per run. h = splice_runs(&h, &img, &positions)?; } - } + + let (text, height, width, delta) = rope::get_rope_index(&ids, tok_id) + .map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?; + self.rope_delta = delta; + let pos = rope::mrope_position_tensor(&text, &height, &width, &self.device)?; + self.rotary.mrope_cos_sin(&pos)? + } else { + let base = (offset as i64 + self.rope_delta).max(0) as usize; + self.rotary.plain_cos_sin(base, l)? + }; + // Causal mask only needed for L > 1 prefill; full-attention // layers consume it via broadcast_add. Linear-attention layers // ignore the mask. @@ -487,7 +502,7 @@ impl Qwen3_5Model { Some(self.causal_mask(b, l, offset)?) }; for layer in &mut self.layers { - h = layer.forward(&h, causal.as_ref(), offset)?; + h = layer.forward(&h, causal.as_ref(), &cos, &sin)?; } self.norm.forward(&h) } diff --git a/crates/neuron/src/harness/arch/qwen3_5/rope.rs b/crates/neuron/src/harness/arch/qwen3_5/rope.rs index cb4fc4d..0c2c63f 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/rope.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/rope.rs @@ -258,9 +258,6 @@ impl RotaryEmbedding { /// square with `grid_t = 1` (still image) and `grid_h = grid_w = /// isqrt(run_len)` — 196 → 14×14. Dynamic resolution (#14) would thread /// real per-image grids instead. -// Wired into the forward path in Stage 3 (single-GPU) / Stage 4 (TP); -// exercised by unit tests until then. -#[allow(dead_code)] pub(crate) fn get_rope_index(input_ids: &[u32], image_token_id: u32) -> Result { let n = input_ids.len(); let mut text = Vec::with_capacity(n); @@ -313,7 +310,6 @@ pub(crate) type MRopeIndex = (Vec, Vec, Vec, i64); /// Build the `(3, seq)` position-id tensor consumed by /// [`RotaryEmbedding::mrope_cos_sin`] from the three axis vectors. -#[allow(dead_code)] // wired in Stage 3/4 pub(crate) fn mrope_position_tensor( text: &[i64], height: &[i64],