diff --git a/crates/neuron/src/harness/arch/qwen3_5/mod.rs b/crates/neuron/src/harness/arch/qwen3_5/mod.rs
index 45a4c0d..3431077 100644
--- a/crates/neuron/src/harness/arch/qwen3_5/mod.rs
+++ b/crates/neuron/src/harness/arch/qwen3_5/mod.rs
@@ -236,7 +236,11 @@ fn default_partial_rotary_factor() -> f32 {
 /// `slice_assign` per run. For typical Qwen3.6 requests this is one
 /// or two runs per image; `slice_assign` does one tensor copy per
 /// run, which is cheap relative to the decoder forward pass.
-fn splice_runs(h: &Tensor, img: &Tensor, positions: &[u32]) -> candle_core::Result<Tensor> {
+pub(crate) fn splice_runs(
+    h: &Tensor,
+    img: &Tensor,
+    positions: &[u32],
+) -> candle_core::Result<Tensor> {
     debug_assert!(
         !positions.is_empty(),
         "splice_runs precondition: non-empty positions"
diff --git a/crates/neuron/src/harness/tp/tp_qwen3_5.rs b/crates/neuron/src/harness/tp/tp_qwen3_5.rs
index 9fb7d32..8d812ab 100644
--- a/crates/neuron/src/harness/tp/tp_qwen3_5.rs
+++ b/crates/neuron/src/harness/tp/tp_qwen3_5.rs
@@ -46,6 +46,8 @@ use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
 use crate::harness::arch::qwen3_5::linear_attn::repeat_interleave;
 use crate::harness::arch::qwen3_5::rmsnorm::{Qwen3_5RmsNorm, Qwen3_5RmsNormGated, l2norm};
 use crate::harness::arch::qwen3_5::rope::RotaryEmbedding;
+use crate::harness::arch::qwen3_5::splice_runs;
+use crate::harness::arch::qwen3_5::vision::VisionTower;
 pub use crate::harness::arch::qwen3_5::{Config, TextConfig};
 
 // ─── linear-attention (Gated DeltaNet) ──────────────────────────────
@@ -990,11 +992,103 @@ impl TpQwen3_5Model {
         }
         self.norm.forward(&h)
     }
+
+    /// Forward with image-embedding splice (TP, replicated tower).
+    ///
+    /// Mirrors the single-GPU `Qwen3_5Model::forward_inner` splice:
+    /// embed locally, replace the rows at `image_token_id` positions
+    /// with the image patch embeddings, then run the sharded decoder
+    /// stack. The TP invariant is that every rank holds an identical
+    /// hidden state (only the attention/MLP matmuls shard, with a
+    /// trailing `AllReduce`). That holds here because every rank
+    /// encodes the *same* pixels through its *replicated* vision tower
+    /// and so produces identical `image_embeds` — no broadcast needed.
+    pub fn forward_with_vision(
+        &mut self,
+        input: &Tensor,
+        offset: usize,
+        image_embeds: &Tensor,
+        image_token_id: u32,
+    ) -> candle_core::Result<Tensor> {
+        let (b, l) = input.dims2()?;
+        let mut h = self.embed_tokens.forward(input)?;
+
+        // Locate the image-token positions in the (pre-expanded) input
+        // ids and splice the patch rows in. Same CPU-side scan as the
+        // single-GPU path; the count must match the patch dimension or
+        // the prompt expansion is wrong.
+        let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
+        let n_img_tokens = image_embeds.dim(0)?;
+        let mut positions: Vec<u32> = Vec::with_capacity(n_img_tokens);
+        for (idx, id) in ids.iter().enumerate() {
+            if *id == image_token_id {
+                positions.push(idx as u32);
+            }
+        }
+        if positions.len() != n_img_tokens {
+            candle_core::bail!(
+                "TP forward_with_vision: prompt has {} image-token positions but \
+                 image_embeds carries {} tokens — ensure the per-image patch-count \
+                 expansion has been applied",
+                positions.len(),
+                n_img_tokens,
+            );
+        }
+        if !positions.is_empty() {
+            let img = image_embeds.to_dtype(self.dtype)?;
+            h = splice_runs(&h, &img, &positions)?;
+        }
+
+        let causal = if l == 1 {
+            None
+        } else {
+            Some(self.causal_mask(b, l, offset)?)
+        };
+        for layer in &mut self.layers {
+            h = layer.forward(&h, causal.as_ref(), offset)?;
+        }
+        self.norm.forward(&h)
+    }
 }
 
 pub struct TpQwen3_5ForCausalLM {
     base: TpQwen3_5Model,
     lm_head: super::tp_linear::MaybeQuantLinear,
+    /// Replicated vision tower (TP-vision). Loaded on every rank from
+    /// the full, unsharded `model.visual.*` weights; `None` for
+    /// text-only checkpoints. Each rank encodes the same image
+    /// independently — no sharding, no broadcast — which keeps the
+    /// spliced input embeddings identical across ranks (the
+    /// replicated-hidden-state invariant the sharded layers rely on).
+    vision: Option<VisionTower>,
+    /// `<|image_pad|>` sentinel id (mirrors `Config::image_token_id`);
+    /// the splice target for `forward_with_vision`.
+    image_token_id: Option<u32>,
+}
+
+/// Load the replicated vision tower from the unsharded `model.visual.*`
+/// weights when the config carries a `vision_config` block. Shared by
+/// the cuda and non-cuda `load` variants. `vb.pp("model.visual")`
+/// resolves against the same full safetensors every rank mmaps; plain
+/// `.get()` on a `ShardedVarBuilder` returns the full (replicated)
+/// tensor, so this loads identically regardless of `world_size`.
+fn load_replicated_vision_tower(
+    config: &Config,
+    vb: &ShardedVarBuilder,
+) -> Result<Option<VisionTower>> {
+    match config.vision_config.clone() {
+        Some(vcfg) => {
+            tracing::info!(
+                depth = vcfg.depth,
+                hidden_size = vcfg.hidden_size,
+                "loading qwen3_5 vision tower (TP replicated)"
+            );
+            let tower = VisionTower::load(vcfg, vb.pp("model.visual"))
+                .context("load qwen3_5 vision tower (model.visual.*) [TP replicated]")?;
+            Ok(Some(tower))
+        }
+        None => Ok(None),
+    }
 }
 
 impl TpQwen3_5ForCausalLM {
@@ -1012,7 +1106,14 @@ impl TpQwen3_5ForCausalLM {
         let cfg = &config.text_config;
         let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?;
         let lm_head = build_lm_head(cfg, vb, &base, quant)?;
-        let model = Self { base, lm_head };
+        let vision = load_replicated_vision_tower(&config, vb)?;
+        let image_token_id = config.image_token_id;
+        let model = Self {
+            base,
+            lm_head,
+            vision,
+            image_token_id,
+        };
         log_construction_complete(cfg, rank, world_size, quant, model.device());
         Ok(model)
     }
@@ -1029,17 +1130,68 @@ impl TpQwen3_5ForCausalLM {
         let cfg = &config.text_config;
         let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?;
         let lm_head = build_lm_head(cfg, vb, &base, quant)?;
-        let model = Self { base, lm_head };
+        let vision = load_replicated_vision_tower(&config, vb)?;
+        let image_token_id = config.image_token_id;
+        let model = Self {
+            base,
+            lm_head,
+            vision,
+            image_token_id,
+        };
         log_construction_complete(cfg, rank, world_size, quant, model.device());
         Ok(model)
     }
 
+    /// True when this TP load materialised a replicated vision tower.
+    /// Drives capability advertising and the Stage 3 vision dispatch.
+    pub fn has_vision(&self) -> bool {
+        self.vision.is_some()
+    }
+
+    /// `<|image_pad|>` sentinel id, when known.
+    pub fn image_token_id(&self) -> Option<u32> {
+        self.image_token_id
+    }
+
+    /// Encode one preprocessed `(C, H, W)` image into LM-side patch
+    /// embeddings `(N_lm, hidden)` via this rank's replicated tower.
+    /// Errors when loaded without a vision tower.
+    pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
+        self.vision
+            .as_ref()
+            .ok_or_else(|| {
+                anyhow::anyhow!(
+                    "encode_image: this TP Qwen3.6 load has no vision tower \
+                     (config.json::vision_config absent or weights missing)"
+                )
+            })?
+            .forward(image)
+    }
+
     pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
         let (_, l) = input.dims2()?;
         let hidden = self.base.forward(input, offset)?;
         hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
     }
 
+    /// Forward with image-embedding splice (TP). Mirrors `forward` but
+    /// routes through `TpQwen3_5Model::forward_with_vision` so the
+    /// per-rank input embeddings get the image patches spliced in at
+    /// `image_token_id` positions before the sharded decoder stack.
+    pub fn forward_with_vision(
+        &mut self,
+        input: &Tensor,
+        offset: usize,
+        image_embeds: &Tensor,
+        image_token_id: u32,
+    ) -> candle_core::Result<Tensor> {
+        let (_, l) = input.dims2()?;
+        let hidden = self
+            .base
+            .forward_with_vision(input, offset, image_embeds, image_token_id)?;
+        hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
+    }
+
     pub fn clear_kv_cache(&mut self) {
         self.base.clear_kv_cache();
     }