feat(neuron): TP-vision Stage 1 — replicated vision tower on the TP model
Load the full, unsharded model.visual.* vision tower on every TP rank (leader + each subprocess worker mmaps the same local safetensors) when config.vision_config is present. VisionTower::load already takes a ShardedVarBuilder whose plain .get() returns the full replicated tensor, so the tower loads identically regardless of world_size — no sharding, no NCCL broadcast.

- TpQwen3_5ForCausalLM gains vision: Option<VisionTower> + image_token_id, plus has_vision/image_token_id/encode_image/forward_with_vision, mirroring the single-GPU Qwen3_5ForCausalLM wrapper.
- TpQwen3_5Model::forward_with_vision mirrors the single-GPU forward_inner splice: embed locally, replace rows at image_token_id positions, run the sharded decoder stack. Because every rank encodes the same pixels through its replicated tower, the spliced input embeddings are identical across ranks — preserving the TP replicated-hidden-state invariant the row-parallel AllReduce relies on.
- splice_runs is now pub(crate) and shared with the TP model.

No caller yet — Stage 2 wires the RPC/worker path that invokes encode_image + forward_with_vision per rank.

Most of this compiles on the non-cuda build (only the cuda load variant's tower line is gated); CI's CUDA type-check covers the rest.

Refs TP-vision plan Stage 1.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -236,7 +236,11 @@ fn default_partial_rotary_factor() -> f32 {
|
||||
/// `slice_assign` per run. For typical Qwen3.6 requests this is one
|
||||
/// or two runs per image; `slice_assign` does one tensor copy per
|
||||
/// run, which is cheap relative to the decoder forward pass.
|
||||
fn splice_runs(h: &Tensor, img: &Tensor, positions: &[u32]) -> candle_core::Result<Tensor> {
|
||||
pub(crate) fn splice_runs(
|
||||
h: &Tensor,
|
||||
img: &Tensor,
|
||||
positions: &[u32],
|
||||
) -> candle_core::Result<Tensor> {
|
||||
debug_assert!(
|
||||
!positions.is_empty(),
|
||||
"splice_runs precondition: non-empty positions"
|
||||
|
||||
@@ -46,6 +46,8 @@ use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
|
||||
use crate::harness::arch::qwen3_5::linear_attn::repeat_interleave;
|
||||
use crate::harness::arch::qwen3_5::rmsnorm::{Qwen3_5RmsNorm, Qwen3_5RmsNormGated, l2norm};
|
||||
use crate::harness::arch::qwen3_5::rope::RotaryEmbedding;
|
||||
use crate::harness::arch::qwen3_5::splice_runs;
|
||||
use crate::harness::arch::qwen3_5::vision::VisionTower;
|
||||
pub use crate::harness::arch::qwen3_5::{Config, TextConfig};
|
||||
|
||||
// ─── linear-attention (Gated DeltaNet) ──────────────────────────────
|
||||
@@ -990,11 +992,103 @@ impl TpQwen3_5Model {
|
||||
}
|
||||
self.norm.forward(&h)
|
||||
}
|
||||
|
||||
/// Forward with image-embedding splice (TP, replicated tower).
|
||||
///
|
||||
/// Mirrors the single-GPU `Qwen3_5Model::forward_inner` splice:
|
||||
/// embed locally, replace the rows at `image_token_id` positions
|
||||
/// with the image patch embeddings, then run the sharded decoder
|
||||
/// stack. The TP invariant is that every rank holds an identical
|
||||
/// hidden state (only the attention/MLP matmuls shard, with a
|
||||
/// trailing `AllReduce`). That holds here because every rank
|
||||
/// encodes the *same* pixels through its *replicated* vision tower
|
||||
/// and so produces identical `image_embeds` — no broadcast needed.
|
||||
pub fn forward_with_vision(
|
||||
&mut self,
|
||||
input: &Tensor,
|
||||
offset: usize,
|
||||
image_embeds: &Tensor,
|
||||
image_token_id: u32,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (b, l) = input.dims2()?;
|
||||
let mut h = self.embed_tokens.forward(input)?;
|
||||
|
||||
// Locate the image-token positions in the (pre-expanded) input
|
||||
// ids and splice the patch rows in. Same CPU-side scan as the
|
||||
// single-GPU path; the count must match the patch dimension or
|
||||
// the prompt expansion is wrong.
|
||||
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
|
||||
let mut positions: Vec<u32> = Vec::with_capacity(image_embeds.dim(0)?);
|
||||
for (idx, id) in ids.iter().enumerate() {
|
||||
if *id == image_token_id {
|
||||
positions.push(idx as u32);
|
||||
}
|
||||
}
|
||||
let n_img_tokens = image_embeds.dim(0)?;
|
||||
if positions.len() != n_img_tokens {
|
||||
candle_core::bail!(
|
||||
"TP forward_with_vision: prompt has {} image-token positions but \
|
||||
image_embeds carries {} tokens — ensure the per-image patch-count \
|
||||
expansion has been applied",
|
||||
positions.len(),
|
||||
n_img_tokens,
|
||||
);
|
||||
}
|
||||
if !positions.is_empty() {
|
||||
let img = image_embeds.to_dtype(self.dtype)?;
|
||||
h = splice_runs(&h, &img, &positions)?;
|
||||
}
|
||||
|
||||
let causal = if l == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(self.causal_mask(b, l, offset)?)
|
||||
};
|
||||
for layer in &mut self.layers {
|
||||
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||
}
|
||||
self.norm.forward(&h)
|
||||
}
|
||||
}
|
||||
|
||||
/// Tensor-parallel causal-LM wrapper: a `TpQwen3_5Model` decoder plus an
/// LM head, with an optional replicated vision tower for image inputs.
pub struct TpQwen3_5ForCausalLM {
|
||||
/// Decoder stack that `forward` / `forward_with_vision` delegate to.
base: TpQwen3_5Model,
|
||||
/// Output projection applied to the last-position hidden state.
lm_head: super::tp_linear::MaybeQuantLinear,
|
||||
/// Replicated vision tower (TP-vision). Loaded on every rank from
|
||||
/// the full, unsharded `model.visual.*` weights; `None` for
|
||||
/// text-only checkpoints. Each rank encodes the same image
|
||||
/// independently — no sharding, no broadcast — which keeps the
|
||||
/// spliced input embeddings identical across ranks (the
|
||||
/// replicated-hidden-state invariant the sharded layers rely on).
|
||||
vision: Option<VisionTower>,
|
||||
/// `<|image_pad|>` sentinel id (mirrors `Config::image_token_id`);
|
||||
/// the splice target for `forward_with_vision`.
|
||||
image_token_id: Option<u32>,
|
||||
}
|
||||
|
||||
/// Load the replicated vision tower from the unsharded `model.visual.*`
|
||||
/// weights when the config carries a `vision_config` block. Shared by
|
||||
/// the cuda and non-cuda `load` variants. `vb.pp("model.visual")`
|
||||
/// resolves against the same full safetensors every rank mmaps; plain
|
||||
/// `.get()` on a `ShardedVarBuilder` returns the full (replicated)
|
||||
/// tensor, so this loads identically regardless of `world_size`.
|
||||
fn load_replicated_vision_tower(
|
||||
config: &Config,
|
||||
vb: &ShardedVarBuilder,
|
||||
) -> Result<Option<VisionTower>> {
|
||||
match config.vision_config.clone() {
|
||||
Some(vcfg) => {
|
||||
tracing::info!(
|
||||
depth = vcfg.depth,
|
||||
hidden_size = vcfg.hidden_size,
|
||||
"loading qwen3_5 vision tower (TP replicated)"
|
||||
);
|
||||
let tower = VisionTower::load(vcfg, vb.pp("model.visual"))
|
||||
.context("load qwen3_5 vision tower (model.visual.*) [TP replicated]")?;
|
||||
Ok(Some(tower))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
impl TpQwen3_5ForCausalLM {
|
||||
@@ -1012,7 +1106,14 @@ impl TpQwen3_5ForCausalLM {
|
||||
let cfg = &config.text_config;
|
||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?;
|
||||
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
||||
let model = Self { base, lm_head };
|
||||
let vision = load_replicated_vision_tower(&config, vb)?;
|
||||
let image_token_id = config.image_token_id;
|
||||
let model = Self {
|
||||
base,
|
||||
lm_head,
|
||||
vision,
|
||||
image_token_id,
|
||||
};
|
||||
log_construction_complete(cfg, rank, world_size, quant, model.device());
|
||||
Ok(model)
|
||||
}
|
||||
@@ -1029,17 +1130,68 @@ impl TpQwen3_5ForCausalLM {
|
||||
let cfg = &config.text_config;
|
||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?;
|
||||
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
||||
let model = Self { base, lm_head };
|
||||
let vision = load_replicated_vision_tower(&config, vb)?;
|
||||
let image_token_id = config.image_token_id;
|
||||
let model = Self {
|
||||
base,
|
||||
lm_head,
|
||||
vision,
|
||||
image_token_id,
|
||||
};
|
||||
log_construction_complete(cfg, rank, world_size, quant, model.device());
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
/// True when this TP load materialised a replicated vision tower,
/// i.e. the `vision` field is `Some`.
|
||||
/// Drives capability advertising and the Stage 3 vision dispatch.
|
||||
pub fn has_vision(&self) -> bool {
|
||||
self.vision.is_some()
|
||||
}
|
||||
|
||||
/// `<|image_pad|>` sentinel id, when known.
/// Returns a copy of the stored `image_token_id` field (`None` when
/// that field is `None`).
|
||||
pub fn image_token_id(&self) -> Option<u32> {
|
||||
self.image_token_id
|
||||
}
|
||||
|
||||
/// Encode one preprocessed `(C, H, W)` image into LM-side patch
|
||||
/// embeddings `(N_lm, hidden)` via this rank's replicated tower.
|
||||
/// Errors when loaded without a vision tower.
|
||||
pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
|
||||
self.vision
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"encode_image: this TP Qwen3.6 load has no vision tower \
|
||||
(config.json::vision_config absent or weights missing)"
|
||||
)
|
||||
})?
|
||||
.forward(image)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||
let (_, l) = input.dims2()?;
|
||||
let hidden = self.base.forward(input, offset)?;
|
||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
/// Forward with image-embedding splice (TP). Mirrors `forward` but
|
||||
/// routes through `TpQwen3_5Model::forward_with_vision` so the
|
||||
/// per-rank input embeddings get the image patches spliced in at
|
||||
/// `image_token_id` positions before the sharded decoder stack.
|
||||
pub fn forward_with_vision(
|
||||
&mut self,
|
||||
input: &Tensor,
|
||||
offset: usize,
|
||||
image_embeds: &Tensor,
|
||||
image_token_id: u32,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (_, l) = input.dims2()?;
|
||||
let hidden = self
|
||||
.base
|
||||
.forward_with_vision(input, offset, image_embeds, image_token_id)?;
|
||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
/// Delegates to the base model's `clear_kv_cache`.
pub fn clear_kv_cache(&mut self) {
|
||||
self.base.clear_kv_cache();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user