feat(neuron): M-RoPE Stage 4 — wire interleaved M-RoPE into the TP path
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 30s
CI / CUDA type-check (push) Successful in 31s
CI / Format (push) Successful in 42s
build-prerelease / Build cortex binary (push) Successful in 5m9s
build-prerelease / Build neuron-blackwell (push) Successful in 6m4s
build-prerelease / Package cortex RPM (push) Successful in 1m32s
CI / Test (push) Successful in 7m19s
build-prerelease / Build neuron-ampere (push) Successful in 8m40s
build-prerelease / Build neuron-ada (push) Successful in 5m17s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m0s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m1s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m53s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m14s
CI / Clippy (push) Successful in 2m29s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped

Mirror Stage 3 into the tensor-parallel Qwen3.6 model:

- TpQwen3_5Attention and TpQwen3_5DecoderLayer take (cos, sin) instead of
  a scalar offset; the decoder layer passes them through and the attention
  applies them via apply_cos_sin.
- TpQwen3_5Model gains the replicated rotary + rope_delta (reset in
  clear_kv_cache, settable). forward_inner builds the cos/sin once —
  interleaved M-RoPE from explicit position_ids (vision) or plain at
  offset+rope_delta (text/decode). forward() and forward_with_positions()
  delegate; the old single-shot forward_with_vision is gone.
- prefill_with_images_chunked now computes get_rope_index over the whole
  prompt once, stores rope_delta on the base model, and slices the
  (3, prompt_len) position tensor per chunk — so every rank assigns image
  tokens their 14×14 grid coordinates and steps in lockstep (every chunk,
  text or image, carries the M-RoPE slice because the image shifts the
  surrounding text positions).

Also build the position-id tensor as f32 directly (positions are small
integers, exact in f32) to avoid an i64→f32 cast on the GPU.

The TP forward is cuda-gated — CI CUDA type-check is the compile gate.
Non-cuda build + clippy + full workspace tests green; rope math + the
plain-RoPE-reduction invariant covered by unit tests.

Completes the interleaved-M-RoPE work for the vision spatial misread.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-04 18:46:27 +03:00
parent 4c12c7e2f0
commit 825bf4e905
2 changed files with 136 additions and 83 deletions

View File

@@ -221,21 +221,6 @@ impl RotaryEmbedding {
Ok((q_embed, k_embed))
}
}
/// Text/decode convenience: build plain cos/sin for a scalar offset
/// and apply in one call. The current call sites use this; Stages 3–4
/// move cos/sin construction up into the model forward (computed once
/// per forward) and call [`Self::apply_cos_sin`] directly.
pub fn apply(
&self,
q: &Tensor,
k: &Tensor,
offset: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let (_, _, seq_len, _) = q.dims4()?;
let (cos, sin) = self.plain_cos_sin(offset, seq_len)?;
self.apply_cos_sin(q, k, &cos, &sin)
}
}
/// Compute interleaved-M-RoPE 3D position ids for a full prompt that may
@@ -310,6 +295,10 @@ pub(crate) type MRopeIndex = (Vec<i64>, Vec<i64>, Vec<i64>, i64);
/// Build the `(3, seq)` position-id tensor consumed by
/// [`RotaryEmbedding::mrope_cos_sin`] from the three axis vectors.
///
/// Built directly as **f32** (positions are small integers, exact in
/// f32 well past any context length): the freqs matmul needs float
/// anyway, and this avoids an i64 tensor / i64→f32 cast on the GPU.
pub(crate) fn mrope_position_tensor(
text: &[i64],
height: &[i64],
@@ -318,9 +307,9 @@ pub(crate) fn mrope_position_tensor(
) -> candle_core::Result<Tensor> {
let seq = text.len();
let mut flat = Vec::with_capacity(3 * seq);
flat.extend_from_slice(text);
flat.extend_from_slice(height);
flat.extend_from_slice(width);
flat.extend(text.iter().map(|&x| x as f32));
flat.extend(height.iter().map(|&x| x as f32));
flat.extend(width.iter().map(|&x| x as f32));
Tensor::from_vec(flat, (3, seq), dev)
}

View File

@@ -526,7 +526,8 @@ impl TpQwen3_5Attention {
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
cos: &Tensor,
sin: &Tensor,
) -> candle_core::Result<Tensor> {
let (b, l, _) = x.dims3()?;
@@ -559,7 +560,7 @@ impl TpQwen3_5Attention {
.transpose(1, 2)?
.contiguous()?;
let (q, k) = self.rotary.apply(&q, &k, offset)?;
let (q, k) = self.rotary.apply_cos_sin(&q, &k, cos, sin)?;
let (k, v) = self.kv_cache.append(&k, &v)?;
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
@@ -807,11 +808,12 @@ impl TpQwen3_5DecoderLayer {
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
cos: &Tensor,
sin: &Tensor,
) -> candle_core::Result<Tensor> {
let h = self.input_layernorm.forward(x)?;
let attn_out = match &mut self.attention {
TpAttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
TpAttentionKind::Full(attn) => attn.forward(&h, attn_mask, cos, sin)?,
TpAttentionKind::Linear(net) => net.forward(&h)?,
};
let x = (x + attn_out)?;
@@ -834,6 +836,15 @@ pub struct TpQwen3_5Model {
embed_tokens: Embedding,
layers: Vec<TpQwen3_5DecoderLayer>,
norm: Qwen3_5RmsNorm,
/// Replicated rotary, shared with every full-attention layer. The
/// model builds the per-forward cos/sin (interleaved M-RoPE for image
/// tokens, plain for text) once and the layers apply it. Identical on
/// every rank, so per-rank position ids stay consistent.
rotary: Arc<RotaryEmbedding>,
/// `offset + rope_delta` is the text-axis decode position; set from
/// `get_rope_index` during a vision prefill, reset in `clear_kv_cache`.
/// See `Qwen3_5Model::rope_delta`.
rope_delta: i64,
device: Device,
dtype: DType,
}
@@ -900,6 +911,8 @@ impl TpQwen3_5Model {
embed_tokens,
layers,
norm,
rotary,
rope_delta: 0,
device,
dtype,
})
@@ -956,6 +969,8 @@ impl TpQwen3_5Model {
embed_tokens,
layers,
norm,
rotary,
rope_delta: 0,
device,
dtype,
})
@@ -969,6 +984,14 @@ impl TpQwen3_5Model {
for l in &mut self.layers {
l.clear_kv_cache();
}
self.rope_delta = 0;
}
/// Record the decode-time `rope_delta` for this model.
///
/// `delta` is the value `get_rope_index` returns for a vision prompt;
/// the chunked image prefill stores it here before running its chunks.
/// Forwards that are given no explicit position ids then build plain
/// rotary positions starting at `offset + rope_delta` (clamped at zero),
/// so decode after the image resumes text positions from the
/// image-compressed counter. `clear_kv_cache` resets the value to 0.
pub fn set_rope_delta(&mut self, delta: i64) {
self.rope_delta = delta;
}
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
@@ -980,64 +1003,80 @@ impl TpQwen3_5Model {
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
let causal = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
}
self.norm.forward(&h)
self.forward_inner(input, offset, None, None, None)
}
/// Forward with image-embedding splice (TP, replicated tower).
///
/// Mirrors the single-GPU `Qwen3_5Model::forward_inner` splice:
/// embed locally, replace the rows at `image_token_id` positions
/// with the image patch embeddings, then run the sharded decoder
/// stack. The TP invariant is that every rank holds an identical
/// hidden state (only the attention/MLP matmuls shard, with a
/// trailing `AllReduce`). That holds here because every rank
/// encodes the *same* pixels through its *replicated* vision tower
/// and so produces identical `image_embeds` — no broadcast needed.
pub fn forward_with_vision(
/// Forward for a vision-prefill chunk: optional image-embedding
/// splice plus explicit interleaved-M-RoPE `position_ids` (the
/// chunk's slice of the full prompt's 3D positions). Used by
/// `TpQwen3_5ForCausalLM::prefill_with_images_chunked`, which
/// computes the positions once over the whole prompt and slices them
/// per chunk so every rank steps in lockstep.
pub fn forward_with_positions(
&mut self,
input: &Tensor,
offset: usize,
image_embeds: &Tensor,
image_token_id: u32,
position_ids: &Tensor,
image_embeds: Option<&Tensor>,
image_token_id: Option<u32>,
) -> candle_core::Result<Tensor> {
self.forward_inner(
input,
offset,
image_embeds,
image_token_id,
Some(position_ids),
)
}
/// Shared forward. Splices image embeddings at `image_token_id`
/// positions when present, then builds the rotary cos/sin — from the
/// explicit `position_ids` (interleaved M-RoPE, vision) when given,
/// else plain positions at `offset + rope_delta` (text / decode) —
/// and runs the sharded decoder stack. The TP replicated-hidden-state
/// invariant holds because every rank encodes the same pixels and
/// computes the same positions.
fn forward_inner(
&mut self,
input: &Tensor,
offset: usize,
image_embeds: Option<&Tensor>,
image_token_id: Option<u32>,
position_ids: Option<&Tensor>,
) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
// Locate the image-token positions in the (pre-expanded) input
// ids and splice the patch rows in. Same CPU-side scan as the
// single-GPU path; the count must match the patch dimension or
// the prompt expansion is wrong.
if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
let mut positions: Vec<u32> = Vec::with_capacity(image_embeds.dim(0)?);
let mut positions: Vec<u32> = Vec::with_capacity(img.dim(0)?);
for (idx, id) in ids.iter().enumerate() {
if *id == image_token_id {
if *id == tok_id {
positions.push(idx as u32);
}
}
let n_img_tokens = image_embeds.dim(0)?;
let n_img_tokens = img.dim(0)?;
if positions.len() != n_img_tokens {
candle_core::bail!(
"TP forward_with_vision: prompt has {} image-token positions but \
image_embeds carries {} tokens — ensure the per-image patch-count \
expansion has been applied",
"TP forward: chunk has {} image-token positions but image_embeds carries \
{} tokens — patch-count expansion / chunk slicing mismatch",
positions.len(),
n_img_tokens,
);
}
if !positions.is_empty() {
let img = image_embeds.to_dtype(self.dtype)?;
let img = img.to_dtype(self.dtype)?;
h = splice_runs(&h, &img, &positions)?;
}
}
let (cos, sin) = match position_ids {
Some(pos) => self.rotary.mrope_cos_sin(pos)?,
None => {
let base = (offset as i64 + self.rope_delta).max(0) as usize;
self.rotary.plain_cos_sin(base, l)?
}
};
let causal = if l == 1 {
None
@@ -1045,7 +1084,7 @@ impl TpQwen3_5Model {
Some(self.causal_mask(b, l, offset)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
h = layer.forward(&h, causal.as_ref(), &cos, &sin)?;
}
self.norm.forward(&h)
}
@@ -1174,21 +1213,25 @@ impl TpQwen3_5ForCausalLM {
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
}
/// Forward with image-embedding splice (TP). Mirrors `forward` but
/// routes through `TpQwen3_5Model::forward_with_vision` so the
/// per-rank input embeddings get the image patches spliced in at
/// `image_token_id` positions before the sharded decoder stack.
pub fn forward_with_vision(
/// Forward for a vision-prefill chunk (optional image splice +
/// explicit interleaved-M-RoPE `position_ids`). Mirrors `forward`
/// but routes through `TpQwen3_5Model::forward_with_positions`.
pub fn forward_with_positions(
&mut self,
input: &Tensor,
offset: usize,
image_embeds: &Tensor,
image_token_id: u32,
position_ids: &Tensor,
image_embeds: Option<&Tensor>,
image_token_id: Option<u32>,
) -> candle_core::Result<Tensor> {
let (_, l) = input.dims2()?;
let hidden = self
.base
.forward_with_vision(input, offset, image_embeds, image_token_id)?;
let hidden = self.base.forward_with_positions(
input,
offset,
position_ids,
image_embeds,
image_token_id,
)?;
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
}
@@ -1245,6 +1288,21 @@ impl TpQwen3_5ForCausalLM {
let device = self.device().clone();
let image_embeds = self.encode_images_concat(image_pixels)?;
// Interleaved-M-RoPE 3D position ids for the whole prompt,
// computed once and sliced per chunk so every rank assigns image
// tokens their 14×14 grid coordinates (and text after the image
// resumes from the compressed counter). `rope_delta` is stored on
// the base model for the decode that follows this prefill. Every
// chunk — text or image — uses the M-RoPE slice, because the image
// shifts the positions of the text around it.
let (text, height, width, delta) =
crate::harness::arch::qwen3_5::rope::get_rope_index(tokens, image_token_id)
.map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
self.base.set_rope_delta(delta);
let full_pos = crate::harness::arch::qwen3_5::rope::mrope_position_tensor(
&text, &height, &width, &device,
)?;
let mut last_logits: Option<Tensor> = None;
// Rows of `image_embeds` already spliced by earlier chunks. The
// `<|image_pad|>` run is contiguous, so chunks consume embedding
@@ -1255,16 +1313,22 @@ impl TpQwen3_5ForCausalLM {
let end = (start + chunk_size).min(tokens.len());
let chunk = &tokens[start..end];
let input = Tensor::new(chunk, &device)?.unsqueeze(0)?;
let pos_slice = full_pos.narrow(1, start, end - start)?;
let n_here = chunk.iter().filter(|&&t| t == image_token_id).count();
let logits = if n_here == 0 {
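// Pure-text chunk — nothing to splice, but it still carries its
// M-RoPE position slice: the image shifts the positions of the
// text around it, so the plain `forward` would misplace it.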
self.forward(&input, base_offset + start)?
self.forward_with_positions(&input, base_offset + start, &pos_slice, None, None)?
} else {
// Splice the next `n_here` patch rows at this chunk's
// local image-pad positions.
let rows = image_embeds.narrow(0, img_off, n_here)?;
img_off += n_here;
self.forward_with_vision(&input, base_offset + start, &rows, image_token_id)?
self.forward_with_positions(
&input,
base_offset + start,
&pos_slice,
Some(&rows),
Some(image_token_id),
)?
};
last_logits = Some(logits);
start = end;