From 825bf4e9057951393a158c0aafd0a8fee8a38901 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Thu, 4 Jun 2026 18:46:27 +0300 Subject: [PATCH] =?UTF-8?q?feat(neuron):=20M-RoPE=20Stage=204=20=E2=80=94?= =?UTF-8?q?=20wire=20interleaved=20M-RoPE=20into=20the=20TP=20path?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirror Stage 3 into the tensor-parallel Qwen3.6 model: - TpQwen3_5Attention / DecoderLayer take (cos, sin) instead of a scalar offset and apply via apply_cos_sin. - TpQwen3_5Model gains the replicated rotary + rope_delta (reset in clear_kv_cache, settable). forward_inner builds the cos/sin once — interleaved M-RoPE from explicit position_ids (vision) or plain at offset+rope_delta (text/decode). forward() and forward_with_positions() delegate; the old single-shot forward_with_vision is gone. - prefill_with_images_chunked now computes get_rope_index over the whole prompt once, stores rope_delta on the base model, and slices the (3, prompt_len) position tensor per chunk — so every rank assigns image tokens their 14×14 grid coordinates and steps in lockstep (every chunk, text or image, carries the M-RoPE slice because the image shifts the surrounding text positions). Also build the position-id tensor as f32 directly (positions are small integers, exact in f32) to avoid an i64→f32 cast on the GPU. The TP forward is cuda-gated — CI CUDA type-check is the compile gate. Non-cuda build + clippy + full workspace tests green; rope math + the plain-RoPE-reduction invariant covered by unit tests. Completes the interleaved-M-RoPE work for the vision spatial misread. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../neuron/src/harness/arch/qwen3_5/rope.rs | 25 +-- crates/neuron/src/harness/tp/tp_qwen3_5.rs | 194 ++++++++++++------ 2 files changed, 136 insertions(+), 83 deletions(-) diff --git a/crates/neuron/src/harness/arch/qwen3_5/rope.rs b/crates/neuron/src/harness/arch/qwen3_5/rope.rs index 0c2c63f..01cfa5a 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/rope.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/rope.rs @@ -221,21 +221,6 @@ impl RotaryEmbedding { Ok((q_embed, k_embed)) } } - - /// Text/decode convenience: build plain cos/sin for a scalar offset - /// and apply in one call. The current call sites use this; Stages 3–4 - /// move cos/sin construction up into the model forward (computed once - /// per forward) and call [`Self::apply_cos_sin`] directly. - pub fn apply( - &self, - q: &Tensor, - k: &Tensor, - offset: usize, - ) -> candle_core::Result<(Tensor, Tensor)> { - let (_, _, seq_len, _) = q.dims4()?; - let (cos, sin) = self.plain_cos_sin(offset, seq_len)?; - self.apply_cos_sin(q, k, &cos, &sin) - } } /// Compute interleaved-M-RoPE 3D position ids for a full prompt that may @@ -310,6 +295,10 @@ pub(crate) type MRopeIndex = (Vec, Vec, Vec, i64); /// Build the `(3, seq)` position-id tensor consumed by /// [`RotaryEmbedding::mrope_cos_sin`] from the three axis vectors. +/// +/// Built directly as **f32** (positions are small integers, exact in +/// f32 well past any context length): the freqs matmul needs float +/// anyway, and this avoids an i64 tensor / i64→f32 cast on the GPU. pub(crate) fn mrope_position_tensor( text: &[i64], height: &[i64], @@ -318,9 +307,9 @@ pub(crate) fn mrope_position_tensor( ) -> candle_core::Result { let seq = text.len(); let mut flat = Vec::with_capacity(3 * seq); - flat.extend_from_slice(text); - flat.extend_from_slice(height); - flat.extend_from_slice(width); + flat.extend(text.iter().map(|&x| x as f32)); + flat.extend(height.iter().map(|&x| x as f32)); + flat.extend(width.iter().map(|&x| x as f32)); Tensor::from_vec(flat, (3, seq), dev) } diff --git a/crates/neuron/src/harness/tp/tp_qwen3_5.rs b/crates/neuron/src/harness/tp/tp_qwen3_5.rs index d08e475..afe2713 100644 --- a/crates/neuron/src/harness/tp/tp_qwen3_5.rs +++ b/crates/neuron/src/harness/tp/tp_qwen3_5.rs @@ -526,7 +526,8 @@ impl TpQwen3_5Attention { &mut self, x: &Tensor, attn_mask: Option<&Tensor>, - offset: usize, + cos: &Tensor, + sin: &Tensor, ) -> candle_core::Result { let (b, l, _) = x.dims3()?; @@ -559,7 +560,7 @@ impl TpQwen3_5Attention { .transpose(1, 2)? .contiguous()?; - let (q, k) = self.rotary.apply(&q, &k, offset)?; + let (q, k) = self.rotary.apply_cos_sin(&q, &k, cos, sin)?; let (k, v) = self.kv_cache.append(&k, &v)?; let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?; let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?; @@ -807,11 +808,12 @@ impl TpQwen3_5DecoderLayer { &mut self, x: &Tensor, attn_mask: Option<&Tensor>, - offset: usize, + cos: &Tensor, + sin: &Tensor, ) -> candle_core::Result { let h = self.input_layernorm.forward(x)?; let attn_out = match &mut self.attention { - TpAttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?, + TpAttentionKind::Full(attn) => attn.forward(&h, attn_mask, cos, sin)?, TpAttentionKind::Linear(net) => net.forward(&h)?, }; let x = (x + attn_out)?; @@ -834,6 +836,15 @@ pub struct TpQwen3_5Model { embed_tokens: Embedding, layers: Vec, norm: Qwen3_5RmsNorm, + /// Replicated rotary, shared with every full-attention layer. The + /// model builds the per-forward cos/sin (interleaved M-RoPE for image + /// tokens, plain for text) once and the layers apply it. Identical on + /// every rank, so per-rank position ids stay consistent. + rotary: Arc, + /// `offset + rope_delta` is the text-axis decode position; set from + /// `get_rope_index` during a vision prefill, reset in `clear_kv_cache`. + /// See `Qwen3_5Model::rope_delta`. + rope_delta: i64, device: Device, dtype: DType, } @@ -900,6 +911,8 @@ impl TpQwen3_5Model { embed_tokens, layers, norm, + rotary, + rope_delta: 0, device, dtype, }) @@ -956,6 +969,8 @@ impl TpQwen3_5Model { embed_tokens, layers, norm, + rotary, + rope_delta: 0, device, dtype, }) @@ -969,6 +984,14 @@ impl TpQwen3_5Model { for l in &mut self.layers { l.clear_kv_cache(); } + self.rope_delta = 0; + } + + /// Set the decode `rope_delta` computed by `get_rope_index` during a + /// vision prefill, so decode after the image resumes text positions + /// from the image-compressed counter. + pub fn set_rope_delta(&mut self, delta: i64) { + self.rope_delta = delta; } fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result { @@ -980,64 +1003,80 @@ impl TpQwen3_5Model { } pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result { - let (b, l) = input.dims2()?; - let mut h = self.embed_tokens.forward(input)?; - let causal = if l == 1 { - None - } else { - Some(self.causal_mask(b, l, offset)?) - }; - for layer in &mut self.layers { - h = layer.forward(&h, causal.as_ref(), offset)?; - } - self.norm.forward(&h) + self.forward_inner(input, offset, None, None, None) } - /// Forward with image-embedding splice (TP, replicated tower). - /// - /// Mirrors the single-GPU `Qwen3_5Model::forward_inner` splice: - /// embed locally, replace the rows at `image_token_id` positions - /// with the image patch embeddings, then run the sharded decoder - /// stack. The TP invariant is that every rank holds an identical - /// hidden state (only the attention/MLP matmuls shard, with a - /// trailing `AllReduce`). That holds here because every rank - /// encodes the *same* pixels through its *replicated* vision tower - /// and so produces identical `image_embeds` — no broadcast needed. - pub fn forward_with_vision( + /// Forward for a vision-prefill chunk: optional image-embedding + /// splice plus explicit interleaved-M-RoPE `position_ids` (the + /// chunk's slice of the full prompt's 3D positions). Used by + /// `TpQwen3_5ForCausalLM::prefill_with_images_chunked`, which + /// computes the positions once over the whole prompt and slices them + /// per chunk so every rank steps in lockstep. + pub fn forward_with_positions( &mut self, input: &Tensor, offset: usize, - image_embeds: &Tensor, - image_token_id: u32, + position_ids: &Tensor, + image_embeds: Option<&Tensor>, + image_token_id: Option, + ) -> candle_core::Result { + self.forward_inner( + input, + offset, + image_embeds, + image_token_id, + Some(position_ids), + ) + } + + /// Shared forward. Splices image embeddings at `image_token_id` + /// positions when present, then builds the rotary cos/sin — from the + /// explicit `position_ids` (interleaved M-RoPE, vision) when given, + /// else plain positions at `offset + rope_delta` (text / decode) — + /// and runs the sharded decoder stack. The TP replicated-hidden-state + /// invariant holds because every rank encodes the same pixels and + /// computes the same positions. + fn forward_inner( + &mut self, + input: &Tensor, + offset: usize, + image_embeds: Option<&Tensor>, + image_token_id: Option, + position_ids: Option<&Tensor>, ) -> candle_core::Result { let (b, l) = input.dims2()?; let mut h = self.embed_tokens.forward(input)?; - // Locate the image-token positions in the (pre-expanded) input - // ids and splice the patch rows in. Same CPU-side scan as the - // single-GPU path; the count must match the patch dimension or - // the prompt expansion is wrong. - let ids: Vec = input.flatten_all()?.to_vec1()?; - let mut positions: Vec = Vec::with_capacity(image_embeds.dim(0)?); - for (idx, id) in ids.iter().enumerate() { - if *id == image_token_id { - positions.push(idx as u32); + if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) { + let ids: Vec = input.flatten_all()?.to_vec1()?; + let mut positions: Vec = Vec::with_capacity(img.dim(0)?); + for (idx, id) in ids.iter().enumerate() { + if *id == tok_id { + positions.push(idx as u32); + } + } + let n_img_tokens = img.dim(0)?; + if positions.len() != n_img_tokens { + candle_core::bail!( + "TP forward: chunk has {} image-token positions but image_embeds carries \ + {} tokens — patch-count expansion / chunk slicing mismatch", + positions.len(), + n_img_tokens, + ); + } + if !positions.is_empty() { + let img = img.to_dtype(self.dtype)?; + h = splice_runs(&h, &img, &positions)?; } } - let n_img_tokens = image_embeds.dim(0)?; - if positions.len() != n_img_tokens { - candle_core::bail!( - "TP forward_with_vision: prompt has {} image-token positions but \ - image_embeds carries {} tokens — ensure the per-image patch-count \ - expansion has been applied", - positions.len(), - n_img_tokens, - ); - } - if !positions.is_empty() { - let img = image_embeds.to_dtype(self.dtype)?; - h = splice_runs(&h, &img, &positions)?; - } + + let (cos, sin) = match position_ids { + Some(pos) => self.rotary.mrope_cos_sin(pos)?, + None => { + let base = (offset as i64 + self.rope_delta).max(0) as usize; + self.rotary.plain_cos_sin(base, l)? + } + }; let causal = if l == 1 { None @@ -1045,7 +1084,7 @@ impl TpQwen3_5Model { Some(self.causal_mask(b, l, offset)?) }; for layer in &mut self.layers { - h = layer.forward(&h, causal.as_ref(), offset)?; + h = layer.forward(&h, causal.as_ref(), &cos, &sin)?; } self.norm.forward(&h) } @@ -1174,21 +1213,25 @@ impl TpQwen3_5ForCausalLM { hidden.i((.., l - 1.., ..))?.apply(&self.lm_head) } - /// Forward with image-embedding splice (TP). Mirrors `forward` but - /// routes through `TpQwen3_5Model::forward_with_vision` so the - /// per-rank input embeddings get the image patches spliced in at - /// `image_token_id` positions before the sharded decoder stack. - pub fn forward_with_vision( + /// Forward for a vision-prefill chunk (optional image splice + + /// explicit interleaved-M-RoPE `position_ids`). Mirrors `forward` + /// but routes through `TpQwen3_5Model::forward_with_positions`. + pub fn forward_with_positions( &mut self, input: &Tensor, offset: usize, - image_embeds: &Tensor, - image_token_id: u32, + position_ids: &Tensor, + image_embeds: Option<&Tensor>, + image_token_id: Option, ) -> candle_core::Result { let (_, l) = input.dims2()?; - let hidden = self - .base - .forward_with_vision(input, offset, image_embeds, image_token_id)?; + let hidden = self.base.forward_with_positions( + input, + offset, + position_ids, + image_embeds, + image_token_id, + )?; hidden.i((.., l - 1.., ..))?.apply(&self.lm_head) } @@ -1245,6 +1288,21 @@ impl TpQwen3_5ForCausalLM { let device = self.device().clone(); let image_embeds = self.encode_images_concat(image_pixels)?; + // Interleaved-M-RoPE 3D position ids for the whole prompt, + // computed once and sliced per chunk so every rank assigns image + // tokens their 14×14 grid coordinates (and text after the image + // resumes from the compressed counter). `rope_delta` is stored on + // the base model for the decode that follows this prefill. Every + // chunk — text or image — uses the M-RoPE slice, because the image + // shifts the positions of the text around it. + let (text, height, width, delta) = + crate::harness::arch::qwen3_5::rope::get_rope_index(tokens, image_token_id) + .map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?; + self.base.set_rope_delta(delta); + let full_pos = crate::harness::arch::qwen3_5::rope::mrope_position_tensor( + &text, &height, &width, &device, + )?; + let mut last_logits: Option = None; // Rows of `image_embeds` already spliced by earlier chunks. The // `<|image_pad|>` run is contiguous, so chunks consume embedding @@ -1255,16 +1313,22 @@ impl TpQwen3_5ForCausalLM { let end = (start + chunk_size).min(tokens.len()); let chunk = &tokens[start..end]; let input = Tensor::new(chunk, &device)?.unsqueeze(0)?; + let pos_slice = full_pos.narrow(1, start, end - start)?; let n_here = chunk.iter().filter(|&&t| t == image_token_id).count(); let logits = if n_here == 0 { - // Pure-text chunk — same forward the text prefill runs. - self.forward(&input, base_offset + start)? + self.forward_with_positions(&input, base_offset + start, &pos_slice, None, None)? } else { // Splice the next `n_here` patch rows at this chunk's // local image-pad positions. let rows = image_embeds.narrow(0, img_off, n_here)?; img_off += n_here; - self.forward_with_vision(&input, base_offset + start, &rows, image_token_id)? + self.forward_with_positions( + &input, + base_offset + start, + &pos_slice, + Some(&rows), + Some(image_token_id), + )? }; last_logits = Some(logits); start = end;