feat(neuron): M-RoPE Stage 4 — wire interleaved M-RoPE into the TP path
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 30s
CI / CUDA type-check (push) Successful in 31s
CI / Format (push) Successful in 42s
build-prerelease / Build cortex binary (push) Successful in 5m9s
build-prerelease / Build neuron-blackwell (push) Successful in 6m4s
build-prerelease / Package cortex RPM (push) Successful in 1m32s
CI / Test (push) Successful in 7m19s
build-prerelease / Build neuron-ampere (push) Successful in 8m40s
build-prerelease / Build neuron-ada (push) Successful in 5m17s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m0s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m1s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m53s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m14s
CI / Clippy (push) Successful in 2m29s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 30s
CI / CUDA type-check (push) Successful in 31s
CI / Format (push) Successful in 42s
build-prerelease / Build cortex binary (push) Successful in 5m9s
build-prerelease / Build neuron-blackwell (push) Successful in 6m4s
build-prerelease / Package cortex RPM (push) Successful in 1m32s
CI / Test (push) Successful in 7m19s
build-prerelease / Build neuron-ampere (push) Successful in 8m40s
build-prerelease / Build neuron-ada (push) Successful in 5m17s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m0s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m1s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m53s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m14s
CI / Clippy (push) Successful in 2m29s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
Mirror Stage 3 into the tensor-parallel Qwen3.6 model: - TpQwen3_5Attention / DecoderLayer take (cos, sin) instead of a scalar offset and apply via apply_cos_sin. - TpQwen3_5Model gains the replicated rotary + rope_delta (reset in clear_kv_cache, settable). forward_inner builds the cos/sin once — interleaved M-RoPE from explicit position_ids (vision) or plain at offset+rope_delta (text/decode). forward() and forward_with_positions() delegate; the old single-shot forward_with_vision is gone. - prefill_with_images_chunked now computes get_rope_index over the whole prompt once, stores rope_delta on the base model, and slices the (3, prompt_len) position tensor per chunk — so every rank assigns image tokens their 14×14 grid coordinates and steps in lockstep (every chunk, text or image, carries the M-RoPE slice because the image shifts the surrounding text positions). Also build the position-id tensor as f32 directly (positions are small integers, exact in f32) to avoid an i64→f32 cast on the GPU. The TP forward is cuda-gated — CI CUDA type-check is the compile gate. Non-cuda build + clippy + full workspace tests green; rope math + the plain-RoPE-reduction invariant covered by unit tests. Completes the interleaved-M-RoPE work for the vision spatial misread. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -221,21 +221,6 @@ impl RotaryEmbedding {
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
/// Text/decode convenience: build plain cos/sin for a scalar offset
|
||||
/// and apply in one call. The current call sites use this; Stages 3–4
|
||||
/// move cos/sin construction up into the model forward (computed once
|
||||
/// per forward) and call [`Self::apply_cos_sin`] directly.
|
||||
pub fn apply(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
offset: usize,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let (_, _, seq_len, _) = q.dims4()?;
|
||||
let (cos, sin) = self.plain_cos_sin(offset, seq_len)?;
|
||||
self.apply_cos_sin(q, k, &cos, &sin)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute interleaved-M-RoPE 3D position ids for a full prompt that may
|
||||
@@ -310,6 +295,10 @@ pub(crate) type MRopeIndex = (Vec<i64>, Vec<i64>, Vec<i64>, i64);
|
||||
|
||||
/// Build the `(3, seq)` position-id tensor consumed by
|
||||
/// [`RotaryEmbedding::mrope_cos_sin`] from the three axis vectors.
|
||||
///
|
||||
/// Built directly as **f32** (positions are small integers, exact in
|
||||
/// f32 well past any context length): the freqs matmul needs float
|
||||
/// anyway, and this avoids an i64 tensor / i64→f32 cast on the GPU.
|
||||
pub(crate) fn mrope_position_tensor(
|
||||
text: &[i64],
|
||||
height: &[i64],
|
||||
@@ -318,9 +307,9 @@ pub(crate) fn mrope_position_tensor(
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let seq = text.len();
|
||||
let mut flat = Vec::with_capacity(3 * seq);
|
||||
flat.extend_from_slice(text);
|
||||
flat.extend_from_slice(height);
|
||||
flat.extend_from_slice(width);
|
||||
flat.extend(text.iter().map(|&x| x as f32));
|
||||
flat.extend(height.iter().map(|&x| x as f32));
|
||||
flat.extend(width.iter().map(|&x| x as f32));
|
||||
Tensor::from_vec(flat, (3, seq), dev)
|
||||
}
|
||||
|
||||
|
||||
@@ -526,7 +526,8 @@ impl TpQwen3_5Attention {
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
attn_mask: Option<&Tensor>,
|
||||
offset: usize,
|
||||
cos: &Tensor,
|
||||
sin: &Tensor,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (b, l, _) = x.dims3()?;
|
||||
|
||||
@@ -559,7 +560,7 @@ impl TpQwen3_5Attention {
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
|
||||
let (q, k) = self.rotary.apply(&q, &k, offset)?;
|
||||
let (q, k) = self.rotary.apply_cos_sin(&q, &k, cos, sin)?;
|
||||
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
|
||||
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
|
||||
@@ -807,11 +808,12 @@ impl TpQwen3_5DecoderLayer {
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
attn_mask: Option<&Tensor>,
|
||||
offset: usize,
|
||||
cos: &Tensor,
|
||||
sin: &Tensor,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let h = self.input_layernorm.forward(x)?;
|
||||
let attn_out = match &mut self.attention {
|
||||
TpAttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
|
||||
TpAttentionKind::Full(attn) => attn.forward(&h, attn_mask, cos, sin)?,
|
||||
TpAttentionKind::Linear(net) => net.forward(&h)?,
|
||||
};
|
||||
let x = (x + attn_out)?;
|
||||
@@ -834,6 +836,15 @@ pub struct TpQwen3_5Model {
|
||||
embed_tokens: Embedding,
|
||||
layers: Vec<TpQwen3_5DecoderLayer>,
|
||||
norm: Qwen3_5RmsNorm,
|
||||
/// Replicated rotary, shared with every full-attention layer. The
|
||||
/// model builds the per-forward cos/sin (interleaved M-RoPE for image
|
||||
/// tokens, plain for text) once and the layers apply it. Identical on
|
||||
/// every rank, so per-rank position ids stay consistent.
|
||||
rotary: Arc<RotaryEmbedding>,
|
||||
/// `offset + rope_delta` is the text-axis decode position; set from
|
||||
/// `get_rope_index` during a vision prefill, reset in `clear_kv_cache`.
|
||||
/// See `Qwen3_5Model::rope_delta`.
|
||||
rope_delta: i64,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
}
|
||||
@@ -900,6 +911,8 @@ impl TpQwen3_5Model {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
rotary,
|
||||
rope_delta: 0,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
@@ -956,6 +969,8 @@ impl TpQwen3_5Model {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
rotary,
|
||||
rope_delta: 0,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
@@ -969,6 +984,14 @@ impl TpQwen3_5Model {
|
||||
for l in &mut self.layers {
|
||||
l.clear_kv_cache();
|
||||
}
|
||||
self.rope_delta = 0;
|
||||
}
|
||||
|
||||
/// Set the decode `rope_delta` computed by `get_rope_index` during a
|
||||
/// vision prefill, so decode after the image resumes text positions
|
||||
/// from the image-compressed counter.
|
||||
pub fn set_rope_delta(&mut self, delta: i64) {
|
||||
self.rope_delta = delta;
|
||||
}
|
||||
|
||||
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||
@@ -980,64 +1003,80 @@ impl TpQwen3_5Model {
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||
let (b, l) = input.dims2()?;
|
||||
let mut h = self.embed_tokens.forward(input)?;
|
||||
let causal = if l == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(self.causal_mask(b, l, offset)?)
|
||||
};
|
||||
for layer in &mut self.layers {
|
||||
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||
}
|
||||
self.norm.forward(&h)
|
||||
self.forward_inner(input, offset, None, None, None)
|
||||
}
|
||||
|
||||
/// Forward with image-embedding splice (TP, replicated tower).
|
||||
///
|
||||
/// Mirrors the single-GPU `Qwen3_5Model::forward_inner` splice:
|
||||
/// embed locally, replace the rows at `image_token_id` positions
|
||||
/// with the image patch embeddings, then run the sharded decoder
|
||||
/// stack. The TP invariant is that every rank holds an identical
|
||||
/// hidden state (only the attention/MLP matmuls shard, with a
|
||||
/// trailing `AllReduce`). That holds here because every rank
|
||||
/// encodes the *same* pixels through its *replicated* vision tower
|
||||
/// and so produces identical `image_embeds` — no broadcast needed.
|
||||
pub fn forward_with_vision(
|
||||
/// Forward for a vision-prefill chunk: optional image-embedding
|
||||
/// splice plus explicit interleaved-M-RoPE `position_ids` (the
|
||||
/// chunk's slice of the full prompt's 3D positions). Used by
|
||||
/// `TpQwen3_5ForCausalLM::prefill_with_images_chunked`, which
|
||||
/// computes the positions once over the whole prompt and slices them
|
||||
/// per chunk so every rank steps in lockstep.
|
||||
pub fn forward_with_positions(
|
||||
&mut self,
|
||||
input: &Tensor,
|
||||
offset: usize,
|
||||
image_embeds: &Tensor,
|
||||
image_token_id: u32,
|
||||
position_ids: &Tensor,
|
||||
image_embeds: Option<&Tensor>,
|
||||
image_token_id: Option<u32>,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
self.forward_inner(
|
||||
input,
|
||||
offset,
|
||||
image_embeds,
|
||||
image_token_id,
|
||||
Some(position_ids),
|
||||
)
|
||||
}
|
||||
|
||||
/// Shared forward. Splices image embeddings at `image_token_id`
|
||||
/// positions when present, then builds the rotary cos/sin — from the
|
||||
/// explicit `position_ids` (interleaved M-RoPE, vision) when given,
|
||||
/// else plain positions at `offset + rope_delta` (text / decode) —
|
||||
/// and runs the sharded decoder stack. The TP replicated-hidden-state
|
||||
/// invariant holds because every rank encodes the same pixels and
|
||||
/// computes the same positions.
|
||||
fn forward_inner(
|
||||
&mut self,
|
||||
input: &Tensor,
|
||||
offset: usize,
|
||||
image_embeds: Option<&Tensor>,
|
||||
image_token_id: Option<u32>,
|
||||
position_ids: Option<&Tensor>,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (b, l) = input.dims2()?;
|
||||
let mut h = self.embed_tokens.forward(input)?;
|
||||
|
||||
// Locate the image-token positions in the (pre-expanded) input
|
||||
// ids and splice the patch rows in. Same CPU-side scan as the
|
||||
// single-GPU path; the count must match the patch dimension or
|
||||
// the prompt expansion is wrong.
|
||||
if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
|
||||
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
|
||||
let mut positions: Vec<u32> = Vec::with_capacity(image_embeds.dim(0)?);
|
||||
let mut positions: Vec<u32> = Vec::with_capacity(img.dim(0)?);
|
||||
for (idx, id) in ids.iter().enumerate() {
|
||||
if *id == image_token_id {
|
||||
if *id == tok_id {
|
||||
positions.push(idx as u32);
|
||||
}
|
||||
}
|
||||
let n_img_tokens = image_embeds.dim(0)?;
|
||||
let n_img_tokens = img.dim(0)?;
|
||||
if positions.len() != n_img_tokens {
|
||||
candle_core::bail!(
|
||||
"TP forward_with_vision: prompt has {} image-token positions but \
|
||||
image_embeds carries {} tokens — ensure the per-image patch-count \
|
||||
expansion has been applied",
|
||||
"TP forward: chunk has {} image-token positions but image_embeds carries \
|
||||
{} tokens — patch-count expansion / chunk slicing mismatch",
|
||||
positions.len(),
|
||||
n_img_tokens,
|
||||
);
|
||||
}
|
||||
if !positions.is_empty() {
|
||||
let img = image_embeds.to_dtype(self.dtype)?;
|
||||
let img = img.to_dtype(self.dtype)?;
|
||||
h = splice_runs(&h, &img, &positions)?;
|
||||
}
|
||||
}
|
||||
|
||||
let (cos, sin) = match position_ids {
|
||||
Some(pos) => self.rotary.mrope_cos_sin(pos)?,
|
||||
None => {
|
||||
let base = (offset as i64 + self.rope_delta).max(0) as usize;
|
||||
self.rotary.plain_cos_sin(base, l)?
|
||||
}
|
||||
};
|
||||
|
||||
let causal = if l == 1 {
|
||||
None
|
||||
@@ -1045,7 +1084,7 @@ impl TpQwen3_5Model {
|
||||
Some(self.causal_mask(b, l, offset)?)
|
||||
};
|
||||
for layer in &mut self.layers {
|
||||
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||
h = layer.forward(&h, causal.as_ref(), &cos, &sin)?;
|
||||
}
|
||||
self.norm.forward(&h)
|
||||
}
|
||||
@@ -1174,21 +1213,25 @@ impl TpQwen3_5ForCausalLM {
|
||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
/// Forward with image-embedding splice (TP). Mirrors `forward` but
|
||||
/// routes through `TpQwen3_5Model::forward_with_vision` so the
|
||||
/// per-rank input embeddings get the image patches spliced in at
|
||||
/// `image_token_id` positions before the sharded decoder stack.
|
||||
pub fn forward_with_vision(
|
||||
/// Forward for a vision-prefill chunk (optional image splice +
|
||||
/// explicit interleaved-M-RoPE `position_ids`). Mirrors `forward`
|
||||
/// but routes through `TpQwen3_5Model::forward_with_positions`.
|
||||
pub fn forward_with_positions(
|
||||
&mut self,
|
||||
input: &Tensor,
|
||||
offset: usize,
|
||||
image_embeds: &Tensor,
|
||||
image_token_id: u32,
|
||||
position_ids: &Tensor,
|
||||
image_embeds: Option<&Tensor>,
|
||||
image_token_id: Option<u32>,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (_, l) = input.dims2()?;
|
||||
let hidden = self
|
||||
.base
|
||||
.forward_with_vision(input, offset, image_embeds, image_token_id)?;
|
||||
let hidden = self.base.forward_with_positions(
|
||||
input,
|
||||
offset,
|
||||
position_ids,
|
||||
image_embeds,
|
||||
image_token_id,
|
||||
)?;
|
||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
@@ -1245,6 +1288,21 @@ impl TpQwen3_5ForCausalLM {
|
||||
let device = self.device().clone();
|
||||
let image_embeds = self.encode_images_concat(image_pixels)?;
|
||||
|
||||
// Interleaved-M-RoPE 3D position ids for the whole prompt,
|
||||
// computed once and sliced per chunk so every rank assigns image
|
||||
// tokens their 14×14 grid coordinates (and text after the image
|
||||
// resumes from the compressed counter). `rope_delta` is stored on
|
||||
// the base model for the decode that follows this prefill. Every
|
||||
// chunk — text or image — uses the M-RoPE slice, because the image
|
||||
// shifts the positions of the text around it.
|
||||
let (text, height, width, delta) =
|
||||
crate::harness::arch::qwen3_5::rope::get_rope_index(tokens, image_token_id)
|
||||
.map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
|
||||
self.base.set_rope_delta(delta);
|
||||
let full_pos = crate::harness::arch::qwen3_5::rope::mrope_position_tensor(
|
||||
&text, &height, &width, &device,
|
||||
)?;
|
||||
|
||||
let mut last_logits: Option<Tensor> = None;
|
||||
// Rows of `image_embeds` already spliced by earlier chunks. The
|
||||
// `<|image_pad|>` run is contiguous, so chunks consume embedding
|
||||
@@ -1255,16 +1313,22 @@ impl TpQwen3_5ForCausalLM {
|
||||
let end = (start + chunk_size).min(tokens.len());
|
||||
let chunk = &tokens[start..end];
|
||||
let input = Tensor::new(chunk, &device)?.unsqueeze(0)?;
|
||||
let pos_slice = full_pos.narrow(1, start, end - start)?;
|
||||
let n_here = chunk.iter().filter(|&&t| t == image_token_id).count();
|
||||
let logits = if n_here == 0 {
|
||||
// Pure-text chunk — same forward the text prefill runs.
|
||||
self.forward(&input, base_offset + start)?
|
||||
self.forward_with_positions(&input, base_offset + start, &pos_slice, None, None)?
|
||||
} else {
|
||||
// Splice the next `n_here` patch rows at this chunk's
|
||||
// local image-pad positions.
|
||||
let rows = image_embeds.narrow(0, img_off, n_here)?;
|
||||
img_off += n_here;
|
||||
self.forward_with_vision(&input, base_offset + start, &rows, image_token_id)?
|
||||
self.forward_with_positions(
|
||||
&input,
|
||||
base_offset + start,
|
||||
&pos_slice,
|
||||
Some(&rows),
|
||||
Some(image_token_id),
|
||||
)?
|
||||
};
|
||||
last_logits = Some(logits);
|
||||
start = end;
|
||||
|
||||
Reference in New Issue
Block a user