feat(neuron): TP-vision Stage 1 — replicated vision tower on the TP model
Load the full, unsharded model.visual.* vision tower on every TP rank (leader + each subprocess worker mmaps the same local safetensors) when config.vision_config is present. VisionTower::load already takes a ShardedVarBuilder whose plain .get() returns the full replicated tensor, so the tower loads identically regardless of world_size — no sharding, no NCCL broadcast. - TpQwen3_5ForCausalLM gains vision: Option<VisionTower> + image_token_id, plus has_vision/image_token_id/encode_image/forward_with_vision, mirroring the single-GPU Qwen3_5ForCausalLM wrapper. - TpQwen3_5Model::forward_with_vision mirrors the single-GPU forward_inner splice: embed locally, replace rows at image_token_id positions, run the sharded decoder stack. Because every rank encodes the same pixels through its replicated tower, the spliced input embeddings are identical across ranks — preserving the TP replicated-hidden-state invariant the row-parallel AllReduce relies on. - splice_runs is now pub(crate) and shared with the TP model. No caller yet — Stage 2 wires the RPC/worker path that invokes encode_image + forward_with_vision per rank. Most of this compiles on the non-cuda build (only the cuda load variant's tower line is gated); CI's CUDA type-check covers the rest. Refs TP-vision plan Stage 1. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -236,7 +236,11 @@ fn default_partial_rotary_factor() -> f32 {
|
|||||||
/// `slice_assign` per run. For typical Qwen3.6 requests this is one
|
/// `slice_assign` per run. For typical Qwen3.6 requests this is one
|
||||||
/// or two runs per image; `slice_assign` does one tensor copy per
|
/// or two runs per image; `slice_assign` does one tensor copy per
|
||||||
/// run, which is cheap relative to the decoder forward pass.
|
/// run, which is cheap relative to the decoder forward pass.
|
||||||
fn splice_runs(h: &Tensor, img: &Tensor, positions: &[u32]) -> candle_core::Result<Tensor> {
|
pub(crate) fn splice_runs(
|
||||||
|
h: &Tensor,
|
||||||
|
img: &Tensor,
|
||||||
|
positions: &[u32],
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
debug_assert!(
|
debug_assert!(
|
||||||
!positions.is_empty(),
|
!positions.is_empty(),
|
||||||
"splice_runs precondition: non-empty positions"
|
"splice_runs precondition: non-empty positions"
|
||||||
|
|||||||
@@ -46,6 +46,8 @@ use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
|
|||||||
use crate::harness::arch::qwen3_5::linear_attn::repeat_interleave;
|
use crate::harness::arch::qwen3_5::linear_attn::repeat_interleave;
|
||||||
use crate::harness::arch::qwen3_5::rmsnorm::{Qwen3_5RmsNorm, Qwen3_5RmsNormGated, l2norm};
|
use crate::harness::arch::qwen3_5::rmsnorm::{Qwen3_5RmsNorm, Qwen3_5RmsNormGated, l2norm};
|
||||||
use crate::harness::arch::qwen3_5::rope::RotaryEmbedding;
|
use crate::harness::arch::qwen3_5::rope::RotaryEmbedding;
|
||||||
|
use crate::harness::arch::qwen3_5::splice_runs;
|
||||||
|
use crate::harness::arch::qwen3_5::vision::VisionTower;
|
||||||
pub use crate::harness::arch::qwen3_5::{Config, TextConfig};
|
pub use crate::harness::arch::qwen3_5::{Config, TextConfig};
|
||||||
|
|
||||||
// ─── linear-attention (Gated DeltaNet) ──────────────────────────────
|
// ─── linear-attention (Gated DeltaNet) ──────────────────────────────
|
||||||
@@ -990,11 +992,103 @@ impl TpQwen3_5Model {
|
|||||||
}
|
}
|
||||||
self.norm.forward(&h)
|
self.norm.forward(&h)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Forward with image-embedding splice (TP, replicated tower).
|
||||||
|
///
|
||||||
|
/// Mirrors the single-GPU `Qwen3_5Model::forward_inner` splice:
|
||||||
|
/// embed locally, replace the rows at `image_token_id` positions
|
||||||
|
/// with the image patch embeddings, then run the sharded decoder
|
||||||
|
/// stack. The TP invariant is that every rank holds an identical
|
||||||
|
/// hidden state (only the attention/MLP matmuls shard, with a
|
||||||
|
/// trailing `AllReduce`). That holds here because every rank
|
||||||
|
/// encodes the *same* pixels through its *replicated* vision tower
|
||||||
|
/// and so produces identical `image_embeds` — no broadcast needed.
|
||||||
|
pub fn forward_with_vision(
|
||||||
|
&mut self,
|
||||||
|
input: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
image_embeds: &Tensor,
|
||||||
|
image_token_id: u32,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let (b, l) = input.dims2()?;
|
||||||
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
|
|
||||||
|
// Locate the image-token positions in the (pre-expanded) input
|
||||||
|
// ids and splice the patch rows in. Same CPU-side scan as the
|
||||||
|
// single-GPU path; the count must match the patch dimension or
|
||||||
|
// the prompt expansion is wrong.
|
||||||
|
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
|
||||||
|
let mut positions: Vec<u32> = Vec::with_capacity(image_embeds.dim(0)?);
|
||||||
|
for (idx, id) in ids.iter().enumerate() {
|
||||||
|
if *id == image_token_id {
|
||||||
|
positions.push(idx as u32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let n_img_tokens = image_embeds.dim(0)?;
|
||||||
|
if positions.len() != n_img_tokens {
|
||||||
|
candle_core::bail!(
|
||||||
|
"TP forward_with_vision: prompt has {} image-token positions but \
|
||||||
|
image_embeds carries {} tokens — ensure the per-image patch-count \
|
||||||
|
expansion has been applied",
|
||||||
|
positions.len(),
|
||||||
|
n_img_tokens,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if !positions.is_empty() {
|
||||||
|
let img = image_embeds.to_dtype(self.dtype)?;
|
||||||
|
h = splice_runs(&h, &img, &positions)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let causal = if l == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.causal_mask(b, l, offset)?)
|
||||||
|
};
|
||||||
|
for layer in &mut self.layers {
|
||||||
|
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||||
|
}
|
||||||
|
self.norm.forward(&h)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct TpQwen3_5ForCausalLM {
|
pub struct TpQwen3_5ForCausalLM {
|
||||||
base: TpQwen3_5Model,
|
base: TpQwen3_5Model,
|
||||||
lm_head: super::tp_linear::MaybeQuantLinear,
|
lm_head: super::tp_linear::MaybeQuantLinear,
|
||||||
|
/// Replicated vision tower (TP-vision). Loaded on every rank from
|
||||||
|
/// the full, unsharded `model.visual.*` weights; `None` for
|
||||||
|
/// text-only checkpoints. Each rank encodes the same image
|
||||||
|
/// independently — no sharding, no broadcast — which keeps the
|
||||||
|
/// spliced input embeddings identical across ranks (the
|
||||||
|
/// replicated-hidden-state invariant the sharded layers rely on).
|
||||||
|
vision: Option<VisionTower>,
|
||||||
|
/// `<|image_pad|>` sentinel id (mirrors `Config::image_token_id`);
|
||||||
|
/// the splice target for `forward_with_vision`.
|
||||||
|
image_token_id: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load the replicated vision tower from the unsharded `model.visual.*`
|
||||||
|
/// weights when the config carries a `vision_config` block. Shared by
|
||||||
|
/// the cuda and non-cuda `load` variants. `vb.pp("model.visual")`
|
||||||
|
/// resolves against the same full safetensors every rank mmaps; plain
|
||||||
|
/// `.get()` on a `ShardedVarBuilder` returns the full (replicated)
|
||||||
|
/// tensor, so this loads identically regardless of `world_size`.
|
||||||
|
fn load_replicated_vision_tower(
|
||||||
|
config: &Config,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
) -> Result<Option<VisionTower>> {
|
||||||
|
match config.vision_config.clone() {
|
||||||
|
Some(vcfg) => {
|
||||||
|
tracing::info!(
|
||||||
|
depth = vcfg.depth,
|
||||||
|
hidden_size = vcfg.hidden_size,
|
||||||
|
"loading qwen3_5 vision tower (TP replicated)"
|
||||||
|
);
|
||||||
|
let tower = VisionTower::load(vcfg, vb.pp("model.visual"))
|
||||||
|
.context("load qwen3_5 vision tower (model.visual.*) [TP replicated]")?;
|
||||||
|
Ok(Some(tower))
|
||||||
|
}
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TpQwen3_5ForCausalLM {
|
impl TpQwen3_5ForCausalLM {
|
||||||
@@ -1012,7 +1106,14 @@ impl TpQwen3_5ForCausalLM {
|
|||||||
let cfg = &config.text_config;
|
let cfg = &config.text_config;
|
||||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?;
|
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?;
|
||||||
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
||||||
let model = Self { base, lm_head };
|
let vision = load_replicated_vision_tower(&config, vb)?;
|
||||||
|
let image_token_id = config.image_token_id;
|
||||||
|
let model = Self {
|
||||||
|
base,
|
||||||
|
lm_head,
|
||||||
|
vision,
|
||||||
|
image_token_id,
|
||||||
|
};
|
||||||
log_construction_complete(cfg, rank, world_size, quant, model.device());
|
log_construction_complete(cfg, rank, world_size, quant, model.device());
|
||||||
Ok(model)
|
Ok(model)
|
||||||
}
|
}
|
||||||
@@ -1029,17 +1130,68 @@ impl TpQwen3_5ForCausalLM {
|
|||||||
let cfg = &config.text_config;
|
let cfg = &config.text_config;
|
||||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?;
|
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?;
|
||||||
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
||||||
let model = Self { base, lm_head };
|
let vision = load_replicated_vision_tower(&config, vb)?;
|
||||||
|
let image_token_id = config.image_token_id;
|
||||||
|
let model = Self {
|
||||||
|
base,
|
||||||
|
lm_head,
|
||||||
|
vision,
|
||||||
|
image_token_id,
|
||||||
|
};
|
||||||
log_construction_complete(cfg, rank, world_size, quant, model.device());
|
log_construction_complete(cfg, rank, world_size, quant, model.device());
|
||||||
Ok(model)
|
Ok(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// True when this TP load materialised a replicated vision tower.
|
||||||
|
/// Drives capability advertising and the Stage 3 vision dispatch.
|
||||||
|
pub fn has_vision(&self) -> bool {
|
||||||
|
self.vision.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `<|image_pad|>` sentinel id, when known.
|
||||||
|
pub fn image_token_id(&self) -> Option<u32> {
|
||||||
|
self.image_token_id
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode one preprocessed `(C, H, W)` image into LM-side patch
|
||||||
|
/// embeddings `(N_lm, hidden)` via this rank's replicated tower.
|
||||||
|
/// Errors when loaded without a vision tower.
|
||||||
|
pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
|
||||||
|
self.vision
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"encode_image: this TP Qwen3.6 load has no vision tower \
|
||||||
|
(config.json::vision_config absent or weights missing)"
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
.forward(image)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
let (_, l) = input.dims2()?;
|
let (_, l) = input.dims2()?;
|
||||||
let hidden = self.base.forward(input, offset)?;
|
let hidden = self.base.forward(input, offset)?;
|
||||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Forward with image-embedding splice (TP). Mirrors `forward` but
|
||||||
|
/// routes through `TpQwen3_5Model::forward_with_vision` so the
|
||||||
|
/// per-rank input embeddings get the image patches spliced in at
|
||||||
|
/// `image_token_id` positions before the sharded decoder stack.
|
||||||
|
pub fn forward_with_vision(
|
||||||
|
&mut self,
|
||||||
|
input: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
image_embeds: &Tensor,
|
||||||
|
image_token_id: u32,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let (_, l) = input.dims2()?;
|
||||||
|
let hidden = self
|
||||||
|
.base
|
||||||
|
.forward_with_vision(input, offset, image_embeds, image_token_id)?;
|
||||||
|
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn clear_kv_cache(&mut self) {
|
pub fn clear_kv_cache(&mut self) {
|
||||||
self.base.clear_kv_cache();
|
self.base.clear_kv_cache();
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user