feat(neuron): TP-vision Stage 1 — replicated vision tower on the TP model

Load the full, unsharded model.visual.* vision tower on every TP rank
(leader + each subprocess worker mmaps the same local safetensors) when
config.vision_config is present. VisionTower::load already takes a
ShardedVarBuilder whose plain .get() returns the full replicated tensor,
so the tower loads identically regardless of world_size — no sharding,
no NCCL broadcast.

- TpQwen3_5ForCausalLM gains vision: Option<VisionTower> + image_token_id,
  plus has_vision/image_token_id/encode_image/forward_with_vision,
  mirroring the single-GPU Qwen3_5ForCausalLM wrapper.
- TpQwen3_5Model::forward_with_vision mirrors the single-GPU
  forward_inner splice: embed locally, replace rows at image_token_id
  positions, run the sharded decoder stack. Because every rank encodes
  the same pixels through its replicated tower, the spliced input
  embeddings are identical across ranks — preserving the TP
  replicated-hidden-state invariant the row-parallel AllReduce relies on.
- splice_runs is now pub(crate) and shared with the TP model.

No caller yet — Stage 2 wires the RPC/worker path that invokes
encode_image + forward_with_vision per rank. Most of this compiles on
the non-cuda build (only the cuda load variant's tower line is gated);
CI's CUDA type-check covers the rest.

Refs TP-vision plan Stage 1.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-04 15:00:05 +03:00
parent 7bb033b4ed
commit 9a24b05866
2 changed files with 159 additions and 3 deletions

View File

@@ -236,7 +236,11 @@ fn default_partial_rotary_factor() -> f32 {
/// `slice_assign` per run. For typical Qwen3.6 requests this is one /// `slice_assign` per run. For typical Qwen3.6 requests this is one
/// or two runs per image; `slice_assign` does one tensor copy per /// or two runs per image; `slice_assign` does one tensor copy per
/// run, which is cheap relative to the decoder forward pass. /// run, which is cheap relative to the decoder forward pass.
fn splice_runs(h: &Tensor, img: &Tensor, positions: &[u32]) -> candle_core::Result<Tensor> { pub(crate) fn splice_runs(
h: &Tensor,
img: &Tensor,
positions: &[u32],
) -> candle_core::Result<Tensor> {
debug_assert!( debug_assert!(
!positions.is_empty(), !positions.is_empty(),
"splice_runs precondition: non-empty positions" "splice_runs precondition: non-empty positions"

View File

@@ -46,6 +46,8 @@ use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
use crate::harness::arch::qwen3_5::linear_attn::repeat_interleave; use crate::harness::arch::qwen3_5::linear_attn::repeat_interleave;
use crate::harness::arch::qwen3_5::rmsnorm::{Qwen3_5RmsNorm, Qwen3_5RmsNormGated, l2norm}; use crate::harness::arch::qwen3_5::rmsnorm::{Qwen3_5RmsNorm, Qwen3_5RmsNormGated, l2norm};
use crate::harness::arch::qwen3_5::rope::RotaryEmbedding; use crate::harness::arch::qwen3_5::rope::RotaryEmbedding;
use crate::harness::arch::qwen3_5::splice_runs;
use crate::harness::arch::qwen3_5::vision::VisionTower;
pub use crate::harness::arch::qwen3_5::{Config, TextConfig}; pub use crate::harness::arch::qwen3_5::{Config, TextConfig};
// ─── linear-attention (Gated DeltaNet) ────────────────────────────── // ─── linear-attention (Gated DeltaNet) ──────────────────────────────
@@ -990,11 +992,103 @@ impl TpQwen3_5Model {
} }
self.norm.forward(&h) self.norm.forward(&h)
} }
/// Forward with image-embedding splice (TP, replicated tower).
///
/// Mirrors the single-GPU `Qwen3_5Model::forward_inner` splice:
/// embed locally, replace the rows at `image_token_id` positions
/// with the image patch embeddings, then run the sharded decoder
/// stack. The TP invariant is that every rank holds an identical
/// hidden state (only the attention/MLP matmuls shard, with a
/// trailing `AllReduce`). That holds here because every rank
/// encodes the *same* pixels through its *replicated* vision tower
/// and so produces identical `image_embeds` — no broadcast needed.
pub fn forward_with_vision(
&mut self,
input: &Tensor,
offset: usize,
image_embeds: &Tensor,
image_token_id: u32,
) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
// Locate the image-token positions in the (pre-expanded) input
// ids and splice the patch rows in. Same CPU-side scan as the
// single-GPU path; the count must match the patch dimension or
// the prompt expansion is wrong.
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
let mut positions: Vec<u32> = Vec::with_capacity(image_embeds.dim(0)?);
for (idx, id) in ids.iter().enumerate() {
if *id == image_token_id {
positions.push(idx as u32);
}
}
let n_img_tokens = image_embeds.dim(0)?;
if positions.len() != n_img_tokens {
candle_core::bail!(
"TP forward_with_vision: prompt has {} image-token positions but \
image_embeds carries {} tokens — ensure the per-image patch-count \
expansion has been applied",
positions.len(),
n_img_tokens,
);
}
if !positions.is_empty() {
let img = image_embeds.to_dtype(self.dtype)?;
h = splice_runs(&h, &img, &positions)?;
}
let causal = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
}
self.norm.forward(&h)
}
} }
pub struct TpQwen3_5ForCausalLM { pub struct TpQwen3_5ForCausalLM {
base: TpQwen3_5Model, base: TpQwen3_5Model,
lm_head: super::tp_linear::MaybeQuantLinear, lm_head: super::tp_linear::MaybeQuantLinear,
/// Replicated vision tower (TP-vision). Loaded on every rank from
/// the full, unsharded `model.visual.*` weights; `None` for
/// text-only checkpoints. Each rank encodes the same image
/// independently — no sharding, no broadcast — which keeps the
/// spliced input embeddings identical across ranks (the
/// replicated-hidden-state invariant the sharded layers rely on).
vision: Option<VisionTower>,
/// `<|image_pad|>` sentinel id (mirrors `Config::image_token_id`);
/// the splice target for `forward_with_vision`.
image_token_id: Option<u32>,
}
/// Load the replicated vision tower from the unsharded `model.visual.*`
/// weights when the config carries a `vision_config` block. Shared by
/// the cuda and non-cuda `load` variants. `vb.pp("model.visual")`
/// resolves against the same full safetensors every rank mmaps; plain
/// `.get()` on a `ShardedVarBuilder` returns the full (replicated)
/// tensor, so this loads identically regardless of `world_size`.
fn load_replicated_vision_tower(
config: &Config,
vb: &ShardedVarBuilder,
) -> Result<Option<VisionTower>> {
match config.vision_config.clone() {
Some(vcfg) => {
tracing::info!(
depth = vcfg.depth,
hidden_size = vcfg.hidden_size,
"loading qwen3_5 vision tower (TP replicated)"
);
let tower = VisionTower::load(vcfg, vb.pp("model.visual"))
.context("load qwen3_5 vision tower (model.visual.*) [TP replicated]")?;
Ok(Some(tower))
}
None => Ok(None),
}
} }
impl TpQwen3_5ForCausalLM { impl TpQwen3_5ForCausalLM {
@@ -1012,7 +1106,14 @@ impl TpQwen3_5ForCausalLM {
let cfg = &config.text_config; let cfg = &config.text_config;
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?; let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?;
let lm_head = build_lm_head(cfg, vb, &base, quant)?; let lm_head = build_lm_head(cfg, vb, &base, quant)?;
let model = Self { base, lm_head }; let vision = load_replicated_vision_tower(&config, vb)?;
let image_token_id = config.image_token_id;
let model = Self {
base,
lm_head,
vision,
image_token_id,
};
log_construction_complete(cfg, rank, world_size, quant, model.device()); log_construction_complete(cfg, rank, world_size, quant, model.device());
Ok(model) Ok(model)
} }
@@ -1029,17 +1130,68 @@ impl TpQwen3_5ForCausalLM {
let cfg = &config.text_config; let cfg = &config.text_config;
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?; let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?;
let lm_head = build_lm_head(cfg, vb, &base, quant)?; let lm_head = build_lm_head(cfg, vb, &base, quant)?;
let model = Self { base, lm_head }; let vision = load_replicated_vision_tower(&config, vb)?;
let image_token_id = config.image_token_id;
let model = Self {
base,
lm_head,
vision,
image_token_id,
};
log_construction_complete(cfg, rank, world_size, quant, model.device()); log_construction_complete(cfg, rank, world_size, quant, model.device());
Ok(model) Ok(model)
} }
/// True when this TP load materialised a replicated vision tower.
/// Drives capability advertising and the Stage 3 vision dispatch.
pub fn has_vision(&self) -> bool {
self.vision.is_some()
}
/// `<|image_pad|>` sentinel id, when known.
pub fn image_token_id(&self) -> Option<u32> {
self.image_token_id
}
/// Encode one preprocessed `(C, H, W)` image into LM-side patch
/// embeddings `(N_lm, hidden)` via this rank's replicated tower.
/// Errors when loaded without a vision tower.
pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
self.vision
.as_ref()
.ok_or_else(|| {
anyhow::anyhow!(
"encode_image: this TP Qwen3.6 load has no vision tower \
(config.json::vision_config absent or weights missing)"
)
})?
.forward(image)
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> { pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (_, l) = input.dims2()?; let (_, l) = input.dims2()?;
let hidden = self.base.forward(input, offset)?; let hidden = self.base.forward(input, offset)?;
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head) hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
} }
/// Forward with image-embedding splice (TP). Mirrors `forward` but
/// routes through `TpQwen3_5Model::forward_with_vision` so the
/// per-rank input embeddings get the image patches spliced in at
/// `image_token_id` positions before the sharded decoder stack.
pub fn forward_with_vision(
&mut self,
input: &Tensor,
offset: usize,
image_embeds: &Tensor,
image_token_id: u32,
) -> candle_core::Result<Tensor> {
let (_, l) = input.dims2()?;
let hidden = self
.base
.forward_with_vision(input, offset, image_embeds, image_token_id)?;
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) { pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache(); self.base.clear_kv_cache();
} }