diff --git a/Cargo.lock b/Cargo.lock index d411ace..77ddcc7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -472,6 +472,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "byteorder-lite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" + [[package]] name = "bytes" version = "1.11.1" @@ -668,6 +674,12 @@ dependencies = [ "cc", ] +[[package]] +name = "color_quant" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" + [[package]] name = "colorchoice" version = "1.0.5" @@ -1223,6 +1235,15 @@ version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" +[[package]] +name = "fdeflate" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" +dependencies = [ + "simd-adler32", +] + [[package]] name = "figment" version = "0.10.19" @@ -1731,6 +1752,16 @@ dependencies = [ "wasip3", ] +[[package]] +name = "gif" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee8cfcc411d9adbbaba82fb72661cc1bcca13e8bba98b364e62b2dba8f960159" +dependencies = [ + "color_quant", + "weezl", +] + [[package]] name = "glob" version = "0.3.3" @@ -2135,6 +2166,34 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "image" +version = "0.25.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85ab80394333c02fe689eaf900ab500fbd0c2213da414687ebf995a65d5a6104" +dependencies = [ + "bytemuck", + "byteorder-lite", + "color_quant", + "gif", + "image-webp", + "moxcms", + "num-traits", + "png", + "zune-core", + "zune-jpeg", +] + +[[package]] +name = "image-webp" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3" +dependencies = [ + "byteorder-lite", + "quick-error", +] + [[package]] name = "indexmap" version = "1.9.3" @@ -2498,6 +2557,16 @@ dependencies = [ "syn", ] +[[package]] +name = "moxcms" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb85c154ba489f01b25c0d36ae69a87e4a1c73a72631fc6c0eb6dde34a73e44b" +dependencies = [ + "num-traits", + "pxfm", +] + [[package]] name = "native-tls" version = "0.2.18" @@ -2522,6 +2591,7 @@ dependencies = [ "anyhow", "async-trait", "axum", + "base64 0.22.1", "candle-core", "candle-nn", "candle-transformers", @@ -2533,6 +2603,7 @@ dependencies = [ "futures", "half", "hf-hub", + "image", "minijinja", "reqwest", "safetensors 0.7.0", @@ -2861,6 +2932,19 @@ version = "0.3.33" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" +[[package]] +name = "png" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60769b8b31b2a9f263dae2776c37b1b28ae246943cf719eb6946a1db05128a61" +dependencies = [ + "bitflags", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide", +] + [[package]] name = "polling" version = "3.11.0" @@ -2974,6 +3058,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40e24eee682d89fb193496edf918a7f407d30175b2e785fe057e4392dfd182e0" +[[package]] +name = "pxfm" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f" + [[package]] name = "quanta" version = "0.12.6" @@ -2989,6 +3079,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "quick-error" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" + [[package]] name = "quinn" version = "0.11.9" @@ -4627,6 +4723,12 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "weezl" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" + [[package]] name = "which" version = "7.0.3" @@ -5164,3 +5266,18 @@ name = "zmij" version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" + +[[package]] +name = "zune-core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9" + +[[package]] +name = "zune-jpeg" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27bc9d5b815bc103f142aa054f561d9187d191692ec7c2d1e2b4737f8dbd7296" +dependencies = [ + "zune-core", +] diff --git a/crates/neuron/Cargo.toml b/crates/neuron/Cargo.toml index 12e57f6..1aa576d 100644 --- a/crates/neuron/Cargo.toml +++ b/crates/neuron/Cargo.toml @@ -90,6 +90,13 @@ minijinja = { version = "2", features = ["builtins", "json", "serde"] } # tp `fused_load` module to read per-rank slices of fused QKV tensors # without materialising the full tensor on device. safetensors = "0.7" +# Vision capability for Qwen3.6 (Stage A of the vision plan in +# doc/vision-qwen3_6-spec.md). `image` decodes PNG/JPEG/etc from +# the bytes embedded in `data:image/...;base64,...` content parts; +# `base64` does the URI decode. Default-features off on `image` to +# avoid pulling in audio/video formats we don't need. +image = { version = "0.25", default-features = false, features = ["png", "jpeg", "webp", "bmp", "gif"] } +base64 = "0.22" [dev-dependencies] tokio = { workspace = true, features = ["test-util"] } diff --git a/crates/neuron/src/harness/arch/qwen3_5/mod.rs b/crates/neuron/src/harness/arch/qwen3_5/mod.rs index 64cbf76..1be0889 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/mod.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/mod.rs @@ -78,6 +78,7 @@ pub mod linear_attn; pub mod mlp; pub mod rmsnorm; pub mod rope; +pub mod vision; use decoder::Qwen3_5DecoderLayer; use rmsnorm::Qwen3_5RmsNorm; @@ -99,6 +100,20 @@ pub struct Config { pub model_type: String, /// The text-side hyperparameters. Everything we actually need. pub text_config: TextConfig, + /// Vision tower hyperparameters. Present on multimodal + /// checkpoints (e.g. Qwen/Qwen3.6-27B); absent on text-only + /// variants. When present, `Qwen3_5ForCausalLM::new` loads the + /// vision tower alongside the language model so vision-bearing + /// requests can splice image embeddings at `<|image_pad|>` token + /// positions. + #[serde(default)] + pub vision_config: Option, + /// Token id the chat template emits per image patch group. + /// Mirrors the LM tokenizer's `<|image_pad|>` id (248056 for + /// Qwen3.6). The runtime locates these in the prompt and splices + /// in `VisionTower::forward` output. `None` for text-only models. + #[serde(default)] + pub image_token_id: Option, } /// Inner config (the `text_config` block). Mirrors the Qwen3 layout @@ -309,6 +324,15 @@ impl Qwen3_5Model { pub struct Qwen3_5ForCausalLM { base: Qwen3_5Model, lm_head: Linear, + /// Vision tower (Stage A4). `None` for text-only checkpoints or + /// when the operator has opted out. When present, the harness's + /// `Job::EncodeImage` dispatch path runs `vision.forward(image)` + /// and the LM forward (Stage B) splices the result at + /// `image_token_id` positions in the input embedding stream. + vision: Option, + /// Mirrors `Config::image_token_id`. Cached here so the runtime + /// doesn't have to round-trip through the parsed config struct. + image_token_id: Option, } impl Qwen3_5ForCausalLM { @@ -324,7 +348,52 @@ impl Qwen3_5ForCausalLM { .with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?; Linear::new(weight, None) }; - Ok(Self { base, lm_head }) + // Stage A4: load the vision tower when the config carries a + // `vision_config` block and the safetensors actually carry + // `model.visual.*` weights. The `Option` on the + // config makes this a single-source-of-truth decision — + // text-only checkpoints just leave `vision_config` unset and + // get `None` here without any extra plumbing. + let vision = if let Some(vcfg) = config.vision_config.clone() { + tracing::info!( + depth = vcfg.depth, + hidden_size = vcfg.hidden_size, + "loading qwen3_5 vision tower" + ); + Some( + vision::VisionTower::load(vcfg, vb.pp("model.visual")) + .context("load qwen3_5 vision tower (model.visual.*)")?, + ) + } else { + None + }; + Ok(Self { + base, + lm_head, + vision, + image_token_id: config.image_token_id, + }) + } + + /// True when this checkpoint loaded a vision tower. Used by the + /// HTTP layer to advertise vision capability in `/v1/models` and + /// to reject image-bearing requests against text-only loads with + /// a clean 400. + pub fn has_vision(&self) -> bool { + self.vision.is_some() + } + + /// Vision tower handle, if loaded. The device-worker + /// `EncodeImage` job dispatches to `vision.forward(image)`. + pub fn vision(&self) -> Option<&vision::VisionTower> { + self.vision.as_ref() + } + + /// `<|image_pad|>` token id from `config.json`, when known. + /// The Stage B prompt-builder uses this to count expansion targets + /// and the LM forward uses it to locate splice positions. + pub fn image_token_id(&self) -> Option { + self.image_token_id } /// `input`: token-id tensor of shape `(B, L)`. Returns logits at diff --git a/crates/neuron/src/harness/arch/qwen3_5/vision.rs b/crates/neuron/src/harness/arch/qwen3_5/vision.rs new file mode 100644 index 0000000..f644e24 --- /dev/null +++ b/crates/neuron/src/harness/arch/qwen3_5/vision.rs @@ -0,0 +1,600 @@ +//! Qwen3.6 vision tower. +//! +//! 27 pre-norm ViT blocks with **LayerNorm** (with biases — not the +//! `(1+w)·x` RmsNorm the language model uses), fused QKV attention, +//! GELU-tanh MLP. Followed by a `merger` that LayerNorms each +//! 1152-dim vision token, spatially 2×2-merges them into 4608-dim +//! groups, and projects to the LM's 5120-dim hidden via +//! `linear_fc1 → GELU → linear_fc2`. +//! +//! Architecture spec sourced from beast's cached Qwen3.6-27B +//! safetensors header (Stage A0, see +//! `doc/vision-qwen3_6-spec.md`). All weight shapes confirmed +//! from the live `.safetensors` headers, not inferred. +//! +//! **Conv3d wrinkle.** The published `patch_embed.proj.weight` is 5D +//! `[1152, 3, 2, 16, 16]` — a 3D conv with kernel +//! `(t=2, h=16, w=16)`. Candle 0.10 has no Conv3d. For static images +//! we get away with a trick: when the temporal patch size is 2 and we +//! duplicate the still image along the temporal axis (`T = 2`, +//! frame_0 == frame_1), the Conv3d output equals a Conv2d run with +//! the *sum* of the two temporal weight slices: +//! +//! ```text +//! output = W_0 · frame_0 + W_1 · frame_1 + bias +//! = (W_0 + W_1) · frame + bias (static image) +//! ``` +//! +//! So at load we sum-collapse the temporal axis and use a 4D +//! `Conv2d` kernel. Video support would have to do the real Conv3d +//! (different frames mean the trick fails) — tracked alongside the +//! dynamic-resolution work in issue #14. +//! +//! Forward signature (Stage A — no LM splice yet): +//! +//! ```text +//! fn forward(&self, image: &Tensor) -> Result +//! ``` +//! +//! `image` is `(3, H, W)` f32, normalised by `preprocess::preprocess`. +//! Returns `(N_lm_tokens, out_hidden_size)` post-merger tokens ready +//! to splice into the LM's input embeddings at `<|image_pad|>` +//! positions. For Qwen3.6 at 448×448 → 28×28 patches → 14×14 = 196 +//! LM tokens of dim 5120. + +use anyhow::{Context, Result}; +use candle_core::{D, DType, Device, IndexOp, Module, Tensor}; +use candle_nn::var_builder::ShardedVarBuilder; +use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear}; +use serde::Deserialize; + +/// Qwen3.6 vision tower hyperparameters. Mirrors the `vision_config` +/// block of `config.json`. Only the fields we actually need are +/// captured; serde tolerates the rest. +#[derive(Debug, Clone, Deserialize)] +pub struct VisionConfig { + /// Number of ViT blocks (`depth: 27` for Qwen3.6). + pub depth: usize, + /// Vision-token dimension throughout the tower (1152 for Qwen3.6). + pub hidden_size: usize, + /// MLP intermediate dim (4304). + pub intermediate_size: usize, + /// Attention head count (16). `head_dim = hidden_size / num_heads`. + pub num_heads: usize, + /// Number of slots in the learned position embedding (2304). + /// Caps the maximum image patch count. + pub num_position_embeddings: usize, + /// Spatial patch edge in pixels (16). + pub patch_size: usize, + /// Temporal kernel depth in the patch embed (2 for Qwen3.6 — we + /// collapse this into a single Conv2d for static-image inference; + /// see the module-level Conv3d wrinkle). + pub temporal_patch_size: usize, + /// Patches grouped per LM token by the merger (2 → 2×2 = 4 + /// patches per LM token). + pub spatial_merge_size: usize, + /// Vision input channels (3, RGB). + pub in_channels: usize, + /// Merger output dim — matches the LM's `hidden_size` (5120 for + /// Qwen3.6). The merger projects from vision dim → LM dim. + pub out_hidden_size: usize, +} + +const LAYER_NORM_EPS: f64 = 1e-6; +/// Number of LM tokens emitted by the merger per vision-token group. +const LM_TOKENS_PER_MERGE_GROUP: usize = 1; + +/// One ViT block: pre-LN → attn → residual; pre-LN → MLP → residual. +struct VisionBlock { + norm1: LayerNorm, + qkv: Linear, + proj: Linear, + norm2: LayerNorm, + fc1: Linear, + fc2: Linear, + num_heads: usize, + head_dim: usize, +} + +impl VisionBlock { + fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result { + let h = cfg.hidden_size; + let head_dim = h / cfg.num_heads; + let norm1 = layer_norm(vb.pp("norm1"), h)?; + let qkv = linear(vb.pp("attn.qkv"), h, 3 * h)?; + let proj = linear(vb.pp("attn.proj"), h, h)?; + let norm2 = layer_norm(vb.pp("norm2"), h)?; + let fc1 = linear(vb.pp("mlp.linear_fc1"), h, cfg.intermediate_size)?; + let fc2 = linear(vb.pp("mlp.linear_fc2"), cfg.intermediate_size, h)?; + Ok(Self { + norm1, + qkv, + proj, + norm2, + fc1, + fc2, + num_heads: cfg.num_heads, + head_dim, + }) + } + + /// `x`: `(N, hidden_size)` un-batched. Returns same shape. + fn forward(&self, x: &Tensor) -> Result { + let attn_in = self.norm1.forward(x)?; + let attn_out = self.attention(&attn_in)?; + let x = x.add(&attn_out)?; + let mlp_in = self.norm2.forward(&x)?; + let mlp_out = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&mlp_in)?)?)?; + x.add(&mlp_out).map_err(Into::into) + } + + /// Multi-head self-attention over the patch sequence. No causal + /// mask — every patch attends to every other patch. + fn attention(&self, x: &Tensor) -> Result { + let (n, hidden) = x.dims2()?; + // qkv: (N, 3*hidden). Split into Q, K, V each (N, hidden). + let qkv = self.qkv.forward(x)?; + let qkv = qkv.reshape((n, 3, self.num_heads, self.head_dim))?; + // Transpose to (3, num_heads, N, head_dim) for per-head views. + let qkv = qkv.permute((1, 2, 0, 3))?.contiguous()?; + let q = qkv.i(0)?; + let k = qkv.i(1)?; + let v = qkv.i(2)?; + let scale = 1.0 / (self.head_dim as f64).sqrt(); + // (num_heads, N, head_dim) @ (num_heads, head_dim, N) -> (num_heads, N, N) + let scores = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?; + let scores = (scores * scale)?; + let probs = candle_nn::ops::softmax_last_dim(&scores)?; + // (num_heads, N, N) @ (num_heads, N, head_dim) -> (num_heads, N, head_dim) + let out = probs.matmul(&v)?; + // Merge heads back: (N, num_heads, head_dim) -> (N, hidden). + let out = out.permute((1, 0, 2))?.contiguous()?.reshape((n, hidden))?; + self.proj.forward(&out).map_err(Into::into) + } +} + +/// `merger`: LayerNorm per token → spatial 2×2 merge (concat 4 +/// adjacent tokens into one 4608-dim vector) → fc1 → GELU-tanh → +/// fc2. Output dim is the LM's hidden_size. +struct VisionMerger { + norm: LayerNorm, + fc1: Linear, + fc2: Linear, + merge_input_dim: usize, + spatial_merge_size: usize, +} + +impl VisionMerger { + fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result { + let h = cfg.hidden_size; + let merge = cfg.spatial_merge_size; + let merge_input_dim = h * merge * merge; + let norm = layer_norm(vb.pp("norm"), h)?; + let fc1 = linear(vb.pp("linear_fc1"), merge_input_dim, merge_input_dim)?; + let fc2 = linear(vb.pp("linear_fc2"), merge_input_dim, cfg.out_hidden_size)?; + Ok(Self { + norm, + fc1, + fc2, + merge_input_dim, + spatial_merge_size: merge, + }) + } + + /// `tokens`: `(grid_h, grid_w, hidden_size)`. The merger reshapes + /// each `merge×merge` block of adjacent patches into a single + /// concatenated vector, then projects. + /// + /// `grid_h` and `grid_w` must both be multiples of + /// `spatial_merge_size`. Returns + /// `(grid_h/merge × grid_w/merge, out_hidden_size)`. + fn forward(&self, tokens: &Tensor) -> Result { + let (gh, gw, h) = tokens.dims3()?; + let m = self.spatial_merge_size; + anyhow::ensure!( + gh.is_multiple_of(m) && gw.is_multiple_of(m), + "merger expects spatial dims divisible by merge_size={m}; got ({gh}, {gw})" + ); + let tokens = self.norm.forward(tokens)?; + // (gh, gw, h) -> (gh/m, m, gw/m, m, h) -> (gh/m, gw/m, m, m, h) + // -> flatten last three -> (gh/m, gw/m, m*m*h) -> (N_lm, merge_input_dim) + let out_h = gh / m; + let out_w = gw / m; + let merged = tokens + .reshape((out_h, m, out_w, m, h))? + .permute((0, 2, 1, 3, 4))? + .contiguous()? + .reshape((out_h * out_w, self.merge_input_dim))?; + let hidden = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&merged)?)?)?; + Ok(hidden) + } +} + +/// The vision tower itself. +pub struct VisionTower { + /// Sum-collapsed temporal kernel (Conv2d, see module doc). + patch_embed: Conv2d, + pos_embed: Embedding, + blocks: Vec, + merger: VisionMerger, + config: VisionConfig, + dtype: DType, + device: Device, +} + +impl VisionTower { + /// Load from a `ShardedVarBuilder` rooted at the safetensors + /// `model.visual.` prefix. Caller is responsible for the `pp` — + /// see `Qwen3_5ForCausalLM::new` (Stage A4). + pub fn load(cfg: VisionConfig, vb: ShardedVarBuilder) -> Result { + let dtype = vb.dtype(); + let device = vb.device().clone(); + + // patch_embed.proj is published as 5D Conv3d weight; we + // sum-collapse the temporal axis (size = temporal_patch_size) + // to get a 4D Conv2d kernel. This is exact for the static- + // image case where T = temporal_patch_size frames are + // identical (i.e. the input was duplicated along T). + let raw_weight = vb + .pp("patch_embed.proj") + .get( + ( + cfg.hidden_size, + cfg.in_channels, + cfg.temporal_patch_size, + cfg.patch_size, + cfg.patch_size, + ), + "weight", + ) + .context("load model.visual.patch_embed.proj.weight (5D Conv3d kernel)")?; + // Sum along the temporal axis (dim 2) — see module doc-comment. + let folded = raw_weight.sum(2)?; // -> (hidden, in_channels, patch, patch) + let proj_bias = vb + .pp("patch_embed.proj") + .get(cfg.hidden_size, "bias") + .context("load model.visual.patch_embed.proj.bias")?; + let conv_cfg = Conv2dConfig { + stride: cfg.patch_size, + ..Default::default() + }; + let patch_embed = Conv2d::new(folded, Some(proj_bias), conv_cfg); + + let pos_embed_weight = vb + .pp("pos_embed") + .get((cfg.num_position_embeddings, cfg.hidden_size), "weight") + .context("load model.visual.pos_embed.weight")?; + let pos_embed = Embedding::new(pos_embed_weight, cfg.hidden_size); + + let blocks_vb = vb.pp("blocks"); + let mut blocks = Vec::with_capacity(cfg.depth); + for i in 0..cfg.depth { + blocks.push( + VisionBlock::load(&cfg, &blocks_vb.pp(i)) + .with_context(|| format!("load vision block {i}"))?, + ); + } + let merger = VisionMerger::load(&cfg, &vb.pp("merger")).context("load vision merger")?; + + Ok(Self { + patch_embed, + pos_embed, + blocks, + merger, + config: cfg, + dtype, + device, + }) + } + + pub fn config(&self) -> &VisionConfig { + &self.config + } + + /// Number of LM tokens this tower emits for an `(H, W)` pixel + /// image after the merger. Equal to + /// `(H / patch_size / spatial_merge_size) * (W / patch_size / spatial_merge_size)`. + pub fn lm_tokens_for(&self, h: u32, w: u32) -> usize { + let m = self.config.spatial_merge_size; + let patch = self.config.patch_size; + let gh = (h as usize) / patch / m; + let gw = (w as usize) / patch / m; + gh * gw * LM_TOKENS_PER_MERGE_GROUP + } + + /// Encode one image. + /// + /// `image`: row-major `(3, H, W)` f32 tensor on `self.device`, + /// already normalised by `preprocess::preprocess`. Both `H` and + /// `W` must be multiples of `patch_size * spatial_merge_size`. + /// + /// Returns `(N_lm, out_hidden_size)` — LM-side image tokens + /// ready to splice into the language model's input embeddings. + pub fn forward(&self, image: &Tensor) -> Result { + let (c, h, w) = image.dims3()?; + anyhow::ensure!( + c == self.config.in_channels, + "image must have {} channels, got {c}", + self.config.in_channels + ); + let patch = self.config.patch_size; + anyhow::ensure!( + h.is_multiple_of(patch) && w.is_multiple_of(patch), + "image dims must be multiples of patch_size={patch}; got ({h}, {w})" + ); + let gh = h / patch; + let gw = w / patch; + let n_patches = gh * gw; + anyhow::ensure!( + n_patches <= self.config.num_position_embeddings, + "patch count {n_patches} exceeds pos_embed budget {}", + self.config.num_position_embeddings + ); + + // Add batch axis for conv: (1, 3, H, W) → (1, hidden, gh, gw) + // → (hidden, gh, gw) → permute to (gh, gw, hidden) → flatten to (N, hidden) + let x = image.unsqueeze(0)?.to_dtype(self.dtype)?; + let x = self.patch_embed.forward(&x)?; + let x = x.squeeze(0)?; + let x = x.permute((1, 2, 0))?.contiguous()?; + let x = x.reshape((n_patches, self.config.hidden_size))?; + + // Add learned positional embeddings (sequential indices for + // Stage A's fixed-resolution path; full 2D positional logic + // lands with variable resolution, issue #14). + let positions = Tensor::arange(0u32, n_patches as u32, &self.device)?; + let pos = self.pos_embed.forward(&positions)?; + let mut x = x.add(&pos)?; + + for (i, block) in self.blocks.iter().enumerate() { + x = block + .forward(&x) + .with_context(|| format!("vision block {i}"))?; + } + + // (n_patches, hidden) → (gh, gw, hidden) for the merger. + let x = x.reshape((gh, gw, self.config.hidden_size))?; + self.merger.forward(&x) + } +} + +/// Manually load a candle_nn LayerNorm from a ShardedVarBuilder. +/// candle_nn's `layer_norm` builder takes `crate::VarBuilder`, not +/// `ShardedVarBuilder`, so the existing arch modules in this crate +/// uniformly do the manual load + struct construction pattern (see +/// `full_attn::load_linear_no_bias`). We follow suit here. +fn layer_norm(vb: ShardedVarBuilder, size: usize) -> Result { + let weight = vb + .get(size, "weight") + .with_context(|| format!("load LayerNorm.weight at '{}'", vb.prefix()))?; + let bias = vb + .get(size, "bias") + .with_context(|| format!("load LayerNorm.bias at '{}'", vb.prefix()))?; + Ok(LayerNorm::new(weight, bias, LAYER_NORM_EPS)) +} + +/// Manually load a candle_nn Linear (with bias) from a +/// ShardedVarBuilder. Same rationale as `layer_norm` above. +fn linear(vb: ShardedVarBuilder, in_dim: usize, out_dim: usize) -> Result { + let weight = vb + .get((out_dim, in_dim), "weight") + .with_context(|| format!("load Linear.weight at '{}'", vb.prefix()))?; + let bias = vb + .get(out_dim, "bias") + .with_context(|| format!("load Linear.bias at '{}'", vb.prefix()))?; + Ok(Linear::new(weight, Some(bias))) +} + +/// PyTorch's `gelu_pytorch_tanh` approximation — what the Qwen3.6 +/// vision tower's `hidden_act` specifies. candle's `Tensor::gelu` +/// uses the exact erf-based GELU, so we compute the tanh +/// approximation explicitly: +/// +/// ```text +/// gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) +/// ``` +fn gelu_tanh(x: &Tensor) -> Result { + // sqrt(2 / pi) = 0.7978845608028654 + const COEFF: f64 = 0.7978845608028654; + const KAPPA: f64 = 0.044715; + let x3 = x.powf(3.0)?; + let inner = (x + (x3 * KAPPA)?)?; + let inner = (inner * COEFF)?; + let t = inner.tanh()?; + let one_plus_t = (t + 1.0)?; + let out = (x * 0.5)?; + let out = out.broadcast_mul(&one_plus_t)?; + Ok(out) +} + +#[cfg(test)] +mod tests { + use super::*; + use candle_core::{DType, Device}; + + /// Build a tiny VisionConfig usable on CPU with random weights. + /// Match the Qwen3.6 shape relations (depth-N stack, hidden mod + /// num_heads, intermediate_size > hidden_size) but with small + /// dims so tests run in milliseconds. + fn tiny_config() -> VisionConfig { + VisionConfig { + depth: 2, + hidden_size: 32, + intermediate_size: 64, + num_heads: 4, + num_position_embeddings: 64, + patch_size: 4, + temporal_patch_size: 2, + spatial_merge_size: 2, + in_channels: 3, + out_hidden_size: 48, + } + } + + /// Hand-construct a VisionTower with random weights. This is the + /// same trick `linear_attn::tests::forward_smoke_with_tiny_dimensions` + /// uses — bypass the safetensors-backed `ShardedVarBuilder` path + /// (which can't be built from in-memory tensors) and assemble the + /// struct fields directly. The real `VisionTower::load` is + /// exercised by the cuda-integration smoke test in Stage A6. + fn tiny_tower(cfg: &VisionConfig) -> VisionTower { + let device = Device::Cpu; + let dtype = DType::F32; + let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &device).unwrap(); + let ones = |shape: &[usize]| Tensor::ones(shape, dtype, &device).unwrap(); + let randn = |shape: &[usize]| Tensor::randn(0_f32, 0.02, shape, &device).unwrap(); + + let patch_embed = Conv2d::new( + randn(&[ + cfg.hidden_size, + cfg.in_channels, + cfg.patch_size, + cfg.patch_size, + ]), + Some(zeros(&[cfg.hidden_size])), + Conv2dConfig { + stride: cfg.patch_size, + ..Default::default() + }, + ); + let pos_embed = Embedding::new( + randn(&[cfg.num_position_embeddings, cfg.hidden_size]), + cfg.hidden_size, + ); + + let mut blocks = Vec::with_capacity(cfg.depth); + for _ in 0..cfg.depth { + let head_dim = cfg.hidden_size / cfg.num_heads; + blocks.push(VisionBlock { + norm1: LayerNorm::new( + ones(&[cfg.hidden_size]), + zeros(&[cfg.hidden_size]), + LAYER_NORM_EPS, + ), + qkv: Linear::new( + randn(&[3 * cfg.hidden_size, cfg.hidden_size]), + Some(zeros(&[3 * cfg.hidden_size])), + ), + proj: Linear::new( + randn(&[cfg.hidden_size, cfg.hidden_size]), + Some(zeros(&[cfg.hidden_size])), + ), + norm2: LayerNorm::new( + ones(&[cfg.hidden_size]), + zeros(&[cfg.hidden_size]), + LAYER_NORM_EPS, + ), + fc1: Linear::new( + randn(&[cfg.intermediate_size, cfg.hidden_size]), + Some(zeros(&[cfg.intermediate_size])), + ), + fc2: Linear::new( + randn(&[cfg.hidden_size, cfg.intermediate_size]), + Some(zeros(&[cfg.hidden_size])), + ), + num_heads: cfg.num_heads, + head_dim, + }); + } + + let merge_input_dim = cfg.hidden_size * cfg.spatial_merge_size * cfg.spatial_merge_size; + let merger = VisionMerger { + norm: LayerNorm::new( + ones(&[cfg.hidden_size]), + zeros(&[cfg.hidden_size]), + LAYER_NORM_EPS, + ), + fc1: Linear::new( + randn(&[merge_input_dim, merge_input_dim]), + Some(zeros(&[merge_input_dim])), + ), + fc2: Linear::new( + randn(&[cfg.out_hidden_size, merge_input_dim]), + Some(zeros(&[cfg.out_hidden_size])), + ), + merge_input_dim, + spatial_merge_size: cfg.spatial_merge_size, + }; + + VisionTower { + patch_embed, + pos_embed, + blocks, + merger, + config: cfg.clone(), + dtype, + device, + } + } + + #[test] + fn forward_with_random_weights_produces_finite_output() { + let cfg = tiny_config(); + let tower = tiny_tower(&cfg); + + // 16×16 image at patch_size=4 → 4×4 patches → after 2×2 + // merge → 2×2 = 4 LM tokens of dim out_hidden_size. + let image = Tensor::randn(0_f32, 1.0, (3, 16, 16), &Device::Cpu).unwrap(); + let out = tower.forward(&image).expect("forward"); + let (n_lm, hidden) = out.dims2().unwrap(); + assert_eq!(n_lm, 4); + assert_eq!(hidden, cfg.out_hidden_size); + + // No NaN/Inf + let values: Vec = out.flatten_all().unwrap().to_vec1().unwrap(); + assert!( + values.iter().all(|v| v.is_finite()), + "forward must produce finite values" + ); + } + + #[test] + fn lm_token_count_matches_grid() { + let cfg = tiny_config(); + let tower = tiny_tower(&cfg); + // 16x16 image → 4x4 patches → 2x2 = 4 LM tokens + assert_eq!(tower.lm_tokens_for(16, 16), 4); + // 32x32 image → 8x8 patches → 4x4 = 16 LM tokens + assert_eq!(tower.lm_tokens_for(32, 32), 16); + } + + #[test] + fn rejects_image_with_dims_not_multiple_of_patch() { + let cfg = tiny_config(); + let tower = tiny_tower(&cfg); + let image = Tensor::randn(0_f32, 1.0, (3, 17, 17), &Device::Cpu).unwrap(); + let err = tower.forward(&image).unwrap_err(); + assert!(format!("{err:#}").contains("patch_size")); + } + + #[test] + fn rejects_image_with_wrong_channel_count() { + let cfg = tiny_config(); + let tower = tiny_tower(&cfg); + let image = Tensor::randn(0_f32, 1.0, (4, 16, 16), &Device::Cpu).unwrap(); + let err = tower.forward(&image).unwrap_err(); + assert!(format!("{err:#}").contains("channels")); + } + + #[test] + fn gelu_tanh_matches_known_values() { + // Reference values for gelu_pytorch_tanh from PyTorch: + // gelu_tanh(0.0) = 0.0 + // gelu_tanh(1.0) ≈ 0.8411920071 + // gelu_tanh(-1.0) ≈ -0.1588079929 + let x = Tensor::new(&[0.0_f32, 1.0, -1.0], &Device::Cpu).unwrap(); + let y = gelu_tanh(&x).unwrap(); + let v: Vec = y.to_vec1().unwrap(); + assert!((v[0]).abs() < 1e-6, "gelu_tanh(0) ≈ 0, got {}", v[0]); + assert!( + (v[1] - 0.841_192_f32).abs() < 1e-5, + "gelu_tanh(1) ≈ 0.84119, got {}", + v[1] + ); + assert!( + (v[2] - -0.158_808_f32).abs() < 1e-5, + "gelu_tanh(-1) ≈ -0.15881, got {}", + v[2] + ); + } +} diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index d1ea49e..c1bd0ef 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -332,6 +332,37 @@ impl ModelArch { } } } + + /// Encode a preprocessed image into LM-side token embeddings via + /// the loaded vision tower. Stage A5. + /// + /// `image`: device-resident `(C, H, W)` f32 tensor — caller has + /// already preprocessed via `harness::preprocess::preprocess` and + /// uploaded to the worker's device. Returns + /// `(N_lm_tokens, hidden_size)`. + /// + /// Errors when the loaded architecture has no vision tower + /// (text-only checkpoint, or architecture that doesn't support + /// vision at all). The HTTP layer maps this to a 400 with + /// `vision_unsupported` so clients see a clean rejection rather + /// than a confident text-only hallucination. + pub fn encode_image(&self, image: &Tensor) -> Result { + match self { + ModelArch::Qwen3_5Dense(m) => m + .vision() + .ok_or_else(|| { + anyhow::anyhow!( + "encode_image: this Qwen3.6 checkpoint was loaded without a vision \ + tower (config.json::vision_config absent or weights missing)" + ) + })? + .forward(image), + other => anyhow::bail!( + "encode_image: architecture {} has no vision tower", + std::any::type_name_of_val(other) + ), + } + } } /// Squeeze any leading singleton dims off the logits tensor so the diff --git a/crates/neuron/src/harness/device_worker/dispatch.rs b/crates/neuron/src/harness/device_worker/dispatch.rs index ebfaab9..0c0c500 100644 --- a/crates/neuron/src/harness/device_worker/dispatch.rs +++ b/crates/neuron/src/harness/device_worker/dispatch.rs @@ -158,6 +158,17 @@ pub(crate) fn run(device_index: u32, rx: Receiver, poisoned: Arc { + let result = encode_image(&mut state, handle, pixels, c, h, w); + let _ = reply.send(result); + } Job::NcclInit { cfg, comm_id_hex, @@ -740,6 +751,49 @@ fn forward_logits( Ok(values) } +/// Run the vision tower on a single preprocessed image. Stage A5. +/// +/// `pixels` is a row-major `(c, h, w)` f32 image that the async-side +/// `harness::preprocess` produced. We reconstruct the tensor on the +/// worker's device (the same device the model was loaded against), +/// call `arch.encode_image`, and copy the resulting +/// `(N_lm_tokens, hidden_size)` embedding back to CPU f32. +/// +/// Returns the flattened embedding as a `Vec` — the caller knows +/// the LM-side token count from `VisionTower::lm_tokens_for(h, w)` +/// and reshapes accordingly. Stage B introduces a device-resident +/// embedding-slab variant that avoids this round-trip when the next +/// forward call needs the result. +fn encode_image( + state: &mut DeviceWorkerState, + handle: ArchHandle, + pixels: Vec, + c: usize, + h: usize, + w: usize, +) -> anyhow::Result> { + use candle_core::{DType, Tensor}; + + anyhow::ensure!( + pixels.len() == c * h * w, + "EncodeImage: pixels length {} does not match shape ({c}, {h}, {w})", + pixels.len() + ); + let image = Tensor::from_vec(pixels, (c, h, w), &state.device)?; + + let arch = state + .models + .get(&handle) + .ok_or_else(|| anyhow::anyhow!("EncodeImage: no model for handle {}", handle.0))?; + + let embed = arch.encode_image(&image)?; + let values = embed + .to_dtype(DType::F32)? + .flatten_all()? + .to_vec1::()?; + Ok(values) +} + /// Reply to a job with the poisoned-worker error. Used when the worker /// has flipped into drain-only mode after a CUDA driver error. /// @@ -773,6 +827,9 @@ fn drain_poisoned(job: Job, device_index: u32) { Job::ForwardLogits { reply, .. } => { let _ = reply.send(Err(err())); } + Job::EncodeImage { reply, .. } => { + let _ = reply.send(Err(err())); + } Job::NcclInit { reply, .. } => { let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error { kind: "device_worker_poisoned".into(), diff --git a/crates/neuron/src/harness/device_worker/jobs.rs b/crates/neuron/src/harness/device_worker/jobs.rs index 2c97e1b..820c1c9 100644 --- a/crates/neuron/src/harness/device_worker/jobs.rs +++ b/crates/neuron/src/harness/device_worker/jobs.rs @@ -94,6 +94,31 @@ pub enum Job { offset: usize, reply: oneshot::Sender>>, }, + /// Encode one image through the model's vision tower. Stage A5 of + /// the vision plan (`doc/vision-qwen3_6-spec.md`). + /// + /// `pixels` is the CPU-side preprocessed image tensor in row-major + /// `(C, H, W)` f32 layout — what `harness::preprocess::preprocess` + /// produces. `c`, `h`, `w` carry the shape since `Vec` itself + /// is rank-1. The handler reconstructs the tensor on the worker's + /// device, runs `VisionTower::forward`, copies the resulting + /// `(N_lm_tokens, hidden_size)` embedding back to CPU as a flat + /// `Vec` (the caller knows the expected shape from + /// `VisionTower::lm_tokens_for(h, w) * hidden_size`). + /// + /// Mirrors the `ForwardLogits` "tensors don't escape" invariant — + /// device-side image embeddings are dropped at handler return. + /// Stage B will introduce a follow-up variant that keeps the + /// embeddings device-resident and references them from the next + /// `ForwardLogits` call, avoiding the round-trip copy. + EncodeImage { + handle: ArchHandle, + pixels: Vec, + c: usize, + h: usize, + w: usize, + reply: oneshot::Sender>>, + }, /// Initialize the leader's NCCL communicator. The worker's /// `NcclState` mints the `Comm` here so its underlying /// `ncclComm_t` and `CudaContext` live on the same thread as diff --git a/crates/neuron/src/harness/device_worker/mod.rs b/crates/neuron/src/harness/device_worker/mod.rs index a277976..bc99e62 100644 --- a/crates/neuron/src/harness/device_worker/mod.rs +++ b/crates/neuron/src/harness/device_worker/mod.rs @@ -313,6 +313,49 @@ impl DeviceWorkerHandle { } } + /// Encode a preprocessed image through the model's vision tower + /// and return the resulting LM-side image embeddings as a + /// flattened CPU `Vec`. Stage A5. + /// + /// `pixels` is the row-major `(c, h, w)` f32 image — + /// `harness::preprocess::preprocess` produces this exact shape. + /// The caller knows the expected output length from + /// `VisionTower::lm_tokens_for(h, w) * hidden_size` and reshapes + /// accordingly. + pub async fn encode_image( + &self, + handle: ArchHandle, + pixels: Vec, + c: usize, + h: usize, + w: usize, + ) -> Result, WorkerError> { + if self.poisoned.load(Ordering::Acquire) { + return Err(WorkerError::Poisoned { + device_index: self.device_index, + }); + } + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::EncodeImage { + handle, + pixels, + c, + h, + w, + reply: reply_tx, + }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + match reply_rx.await { + Ok(result) => result.map_err(WorkerError::from), + Err(_) => Err(WorkerError::Gone { + device_index: self.device_index, + }), + } + } + /// Initialise the leader's NCCL communicator. The reply uses /// `WorkerResponse` (same shape subprocess workers use over stdio /// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader + @@ -569,6 +612,37 @@ mod tests { handle.shutdown().expect("shutdown ok"); } + /// Stage A5: confirm the EncodeImage job round-trips through the + /// worker channel. We don't have a real loaded model in the slab + /// here, so the dispatch handler returns the + /// "no model for handle" error — which is exactly what we want to + /// see, since it proves the message routed through the channel + /// and reached the handler. Real-weights validation lives in the + /// Stage A7 / Stage B post-deploy smoke on beast. + #[tokio::test] + async fn encode_image_routes_to_dispatch_and_errors_on_unknown_handle() { + use crate::harness::device_worker::jobs::ArchHandle; + + let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok"); + let fake_arch = ArchHandle(99_999); + // (3, 4, 4) fake image — minimal payload, gets reconstructed + // on the worker before the handler errors out on the unknown + // ArchHandle lookup. + let pixels = vec![0.0_f32; 3 * 4 * 4]; + let result = handle.encode_image(fake_arch, pixels, 3, 4, 4).await; + match result { + Err(WorkerError::Job(e)) => { + let msg = format!("{e:#}"); + assert!( + msg.contains("EncodeImage: no model for handle"), + "expected unknown-handle error, got: {msg}" + ); + } + other => panic!("expected Job(Err), got {other:?}"), + } + handle.shutdown().expect("shutdown ok"); + } + #[tokio::test] async fn shutdown_drains_pending_jobs() { let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok"); diff --git a/crates/neuron/src/harness/mod.rs b/crates/neuron/src/harness/mod.rs index abedc5e..4b3b501 100644 --- a/crates/neuron/src/harness/mod.rs +++ b/crates/neuron/src/harness/mod.rs @@ -5,6 +5,7 @@ pub mod candle; pub mod chat_template; pub mod device_worker; pub mod preflight; +pub mod preprocess; pub mod tp; use anyhow::Result; diff --git a/crates/neuron/src/harness/preprocess.rs b/crates/neuron/src/harness/preprocess.rs new file mode 100644 index 0000000..0356f4d --- /dev/null +++ b/crates/neuron/src/harness/preprocess.rs @@ -0,0 +1,255 @@ +//! Image preprocessing for vision-capable models. +//! +//! Decodes `data:image/...;base64,...` URIs from OpenAI-style +//! `image_url` content parts into the patch tensors a candle vision +//! tower expects. Stage A ships **fixed resolution** — every image +//! is resized to the same target dimensions (default 448×448 for +//! Qwen3.6, configurable per-call) so the patch count is constant +//! per image. Variable resolution per [Qwen2VL convention] is tracked +//! as issue #14. +//! +//! Spec reference: `doc/vision-qwen3_6-spec.md` — preprocessor +//! section. +//! +//! Normalisation: pixel value `p ∈ [0, 255]` becomes +//! `(p/255 - mean) / std`. Qwen3.6's preprocessor_config.json +//! specifies `image_mean = image_std = [0.5, 0.5, 0.5]`, which +//! simplifies to `2p/255 - 1` mapping `[0,255]` → `[-1, 1]`. We +//! still parameterise mean/std so the same code generalises to other +//! VL families (Qwen2-VL uses imagenet stats, for instance). +//! +//! Pipeline (per image): +//! 1. data: URI → base64 decode → bytes +//! 2. bytes → image::DynamicImage (PNG/JPEG/WebP/etc) +//! 3. resize_exact to target H×W (pixel space) +//! 4. RGB→f32, normalise per mean/std +//! 5. layout to (C, H, W) tensor +//! +//! Patchification (cutting the HxW tensor into `patch_size` blocks) +//! happens inside the vision tower's `patch_embed` conv, so this +//! module stops at "preprocessed RGB f32 tensor." + +use anyhow::{Context, Result, anyhow}; +use base64::Engine; +use image::DynamicImage; +use image::imageops::FilterType; + +/// Preprocessing target. Captures the resize dimensions and the +/// channel-wise normalisation constants from the model's +/// `preprocessor_config.json`. Stage A ships a single `qwen3_6()` +/// constructor for fixed-resolution Qwen3.6 preprocessing; other +/// models can ship their own profile when added. +#[derive(Debug, Clone)] +pub struct PreprocessProfile { + pub target_height: u32, + pub target_width: u32, + pub image_mean: [f32; 3], + pub image_std: [f32; 3], +} + +impl PreprocessProfile { + /// Stage A profile for Qwen3.6. Resize to 448×448, normalise to + /// `[-1, 1]` via mean=std=0.5. Fits within the model's + /// `num_position_embeddings=2304` budget at 28×28 = 784 patches + /// before merging. + pub fn qwen3_6() -> Self { + Self { + target_height: 448, + target_width: 448, + image_mean: [0.5, 0.5, 0.5], + image_std: [0.5, 0.5, 0.5], + } + } + + /// Per-channel CHW tensor length: 3 * H * W. + pub fn pixels_chw(&self) -> usize { + 3 * (self.target_height as usize) * (self.target_width as usize) + } +} + +/// Decode a `data:image/...;base64,...` URI into an in-memory image. +/// +/// Accepts the OpenAI Chat Completions `image_url` shape — a string +/// URL with `data:` scheme and base64 payload. The MIME type is read +/// from the URI for diagnostics but `image::load_from_memory` sniffs +/// the format from the bytes themselves, so the MIME is advisory. +/// +/// Bare `http(s)://` URLs are explicitly rejected here — fetching +/// them from a vision-model server is a fingerprintable behaviour +/// (server-side request forgery, infinite recursion if the URL +/// points at the gateway itself, etc.). Clients that want remote +/// images can fetch them and pass base64 themselves. +pub fn decode_data_uri(uri: &str) -> Result { + let after_scheme = uri + .strip_prefix("data:") + .ok_or_else(|| anyhow!("image_url must use data: scheme; got {uri:.40}…"))?; + let (meta, payload) = after_scheme + .split_once(',') + .ok_or_else(|| anyhow!("malformed data URI: missing ',' separator"))?; + if !meta.contains(";base64") { + anyhow::bail!( + "data URI must use base64 encoding (got '{meta}'); raw URL-encoded payloads not supported" + ); + } + let bytes = base64::engine::general_purpose::STANDARD + .decode(payload.trim()) + .context("base64-decode image data URI payload")?; + image::load_from_memory(&bytes).context("decode image bytes (PNG/JPEG/WebP/etc)") +} + +/// Resize and normalise an image into a `(3, H, W)` row-major +/// `Vec` ready to hand to the vision tower's `patch_embed` +/// conv. +/// +/// Uses bilinear resampling — Qwen2-VL's reference uses bicubic, but +/// bilinear is what the candle ecosystem standardises on and is +/// faster on CPU. Quality difference is marginal for downstream +/// vision-encoder consumption. The numerical-validation issue (#15) +/// will quantify any discrepancy. +pub fn preprocess(img: &DynamicImage, profile: &PreprocessProfile) -> Vec { + let rgb = img + .resize_exact( + profile.target_width, + profile.target_height, + FilterType::Triangle, + ) + .to_rgb8(); + let h = profile.target_height as usize; + let w = profile.target_width as usize; + let mut out = vec![0.0_f32; 3 * h * w]; + // Row-major (C, H, W). Candle's Conv2d expects NCHW, so this is + // the natural layout — the caller stacks `n` of these along the + // batch axis as needed. + for c in 0..3 { + let mean = profile.image_mean[c]; + let std = profile.image_std[c]; + for y in 0..h { + for x in 0..w { + let pixel = rgb.get_pixel(x as u32, y as u32); + let raw = pixel[c] as f32 / 255.0; + out[c * h * w + y * w + x] = (raw - mean) / std; + } + } + } + out +} + +/// Combined helper: decode + preprocess in one call. Most call +/// sites just want the final tensor; the two-step path exists for +/// callers (tests, future video preprocessing) that need the +/// intermediate `DynamicImage`. +pub fn preprocess_data_uri(uri: &str, profile: &PreprocessProfile) -> Result> { + let img = decode_data_uri(uri)?; + Ok(preprocess(&img, profile)) +} + +#[cfg(test)] +mod tests { + use super::*; + use image::{ImageBuffer, Rgb}; + + /// A 1×1 red PNG, hand-built. Matches the well-known smallest + /// valid PNG we use in tests/curl examples elsewhere. + const ONE_BY_ONE_RED_PNG_B64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="; + + fn red_png_uri() -> String { + format!("data:image/png;base64,{ONE_BY_ONE_RED_PNG_B64}") + } + + #[test] + fn decodes_well_formed_png_data_uri() { + let img = decode_data_uri(&red_png_uri()).expect("decode 1x1 png"); + assert_eq!(img.width(), 1); + assert_eq!(img.height(), 1); + } + + #[test] + fn rejects_non_data_scheme() { + let err = decode_data_uri("https://example.com/cat.jpg") + .expect_err("http(s) URLs must be rejected"); + assert!(format!("{err:#}").contains("data:")); + } + + #[test] + fn rejects_malformed_uri_without_comma() { + let err = decode_data_uri("data:image/png;base64").unwrap_err(); + assert!(format!("{err:#}").contains("','")); + } + + #[test] + fn rejects_non_base64_payload() { + let err = decode_data_uri("data:image/png,raw-bytes-here").unwrap_err(); + assert!(format!("{err:#}").contains("base64")); + } + + #[test] + fn rejects_bad_base64_payload() { + let err = decode_data_uri("data:image/png;base64,not!valid!base64!").unwrap_err(); + assert!(format!("{err:#}").contains("base64")); + } + + #[test] + fn rejects_garbage_image_bytes() { + // Valid base64 ("Hello World!"), invalid image bytes. + let err = decode_data_uri("data:image/png;base64,SGVsbG8gV29ybGQh").unwrap_err(); + assert!( + format!("{err:#}").contains("decode image"), + "should fail at image-decode step" + ); + } + + #[test] + fn preprocess_red_image_produces_correct_shape_and_values() { + let profile = PreprocessProfile::qwen3_6(); + // Build a tiny pure-red image directly, skipping data: URI + // decoding so this test isolates the resize+normalise path. + let img: ImageBuffer, Vec> = ImageBuffer::from_pixel(2, 2, Rgb([255, 0, 0])); + let dyn_img = DynamicImage::ImageRgb8(img); + let out = preprocess(&dyn_img, &profile); + + assert_eq!(out.len(), profile.pixels_chw()); + // After mean=0.5, std=0.5: red channel (255/255=1.0) → (1.0 - 0.5)/0.5 = 1.0 + // green/blue (0.0) → (0.0 - 0.5)/0.5 = -1.0 + let h = profile.target_height as usize; + let w = profile.target_width as usize; + assert!( + (out[0] - 1.0).abs() < 1e-5, + "R[0] should be 1.0, got {}", + out[0] + ); + assert!((out[h * w] - (-1.0)).abs() < 1e-5, "G[0] should be -1.0"); + assert!( + (out[2 * h * w] - (-1.0)).abs() < 1e-5, + "B[0] should be -1.0" + ); + // All values are finite + assert!(out.iter().all(|v| v.is_finite()), "no NaN/Inf in output"); + } + + #[test] + fn preprocess_data_uri_end_to_end() { + let profile = PreprocessProfile::qwen3_6(); + let out = preprocess_data_uri(&red_png_uri(), &profile).expect("e2e preprocess"); + assert_eq!(out.len(), profile.pixels_chw()); + assert!(out.iter().all(|v| v.is_finite())); + } + + #[test] + fn preprocess_grayscale_image_promotes_to_rgb() { + let profile = PreprocessProfile::qwen3_6(); + // 1x1 grayscale = 200 → after conversion to RGB, all three + // channels equal 200, normalised → (200/255 - 0.5)/0.5 ≈ 0.569 + let gray = DynamicImage::ImageLuma8(ImageBuffer::from_pixel(1, 1, image::Luma([200]))); + let out = preprocess(&gray, &profile); + let expected = ((200.0 / 255.0) - 0.5) / 0.5; + let h = profile.target_height as usize; + let w = profile.target_width as usize; + for c in 0..3 { + let v = out[c * h * w]; + assert!( + (v - expected).abs() < 1e-3, + "channel {c}: expected {expected}, got {v}" + ); + } + } +} diff --git a/doc/vision-qwen3_6-spec.md b/doc/vision-qwen3_6-spec.md new file mode 100644 index 0000000..0557fee --- /dev/null +++ b/doc/vision-qwen3_6-spec.md @@ -0,0 +1,176 @@ +# Qwen3.6-27B vision specification (Stage A0) + +Sourced from beast's local cache on 2026-06-01: +`/archive3/llm-cache/models--Qwen--Qwen3.6-27B/snapshots/6a9e13bd6fc8f0983b9b99948120bc37f49c13e9/`. + +Single source of truth for Stages A–D of the vision plan in +`~/.claude/plans/foamy-twirling-catmull.md`. Umbrella issue: +[#3](https://git.lair.cafe/helexa/cortex/issues/3). + +--- + +## Top-level shape + +The model is a unified text+vision architecture (`Qwen3_5ForConditionalGeneration`, +`model_type: qwen3_5`) with three weight sections under a single safetensors +index. Counts from `model.safetensors.index.json`: + +| Prefix | Tensors | Role | +|---|---|---| +| `model.language_model.*` | 850 | LM (currently loaded) | +| `model.visual.*` | 333 | Vision tower (currently filtered out at `arch/qwen3_5/mod.rs:228-230`) | +| `mtp.*` | 15 | Multi-token-prediction heads (filtered, out of scope) | +| `lm_head.weight` | 1 | LM head | + +Vision tensors live in shards `model-00007-of-00015.safetensors` and +`model-00008-of-00015.safetensors` (2 of the 15 safetensors). Loading just +these two for vision-tower-only smoke tests is feasible. + +## Vision tower architecture (`model.visual.*`) + +From `config.json::vision_config`: + +``` +depth: 27 (transformer blocks) +hidden_size: 1152 (vision token dim) +num_heads: 16 (per-block self-attention) +intermediate_size: 4304 (MLP hidden) +patch_size: 16 (16×16 spatial patches) +temporal_patch_size: 2 (video frame pairing; irrelevant for stills) +spatial_merge_size: 2 (2×2 spatial merge in the merger → 4 patches/LM token) +num_position_embeddings: 2304 (learned pos embed slots — max patch sequence length) +in_channels: 3 (RGB) +hidden_act: gelu_pytorch_tanh (GELU with tanh approximation, not exact GELU) +out_hidden_size: 5120 (= LM hidden_size, merger output dim) +deepstack_visual_indexes: [] (no deep-stack visual indexes) +``` + +### Module inventory (per-block and global) + +Global: +- `model.visual.patch_embed.proj.{weight, bias}` — Conv2d (3 → 1152, kernel 16×16, stride 16). Turns image patches into tokens. +- `model.visual.pos_embed.weight` — Learned position embedding, shape `(2304, 1152)`. +- `model.visual.merger.{norm, linear_fc1, linear_fc2}` — The projector that merges 2×2 patches and projects to LM hidden_size (1152 → 5120). All weights have biases. + +Per block (×27, named `model.visual.blocks.{0..26}`): +- `norm1.{weight, bias}` — **LayerNorm** before attention (with bias — not RmsNorm). +- `attn.qkv.{weight, bias}` — Fused QKV linear (1152 → 3·1152 = 3456). +- `attn.proj.{weight, bias}` — Attention output projection (1152 → 1152). +- `norm2.{weight, bias}` — LayerNorm before MLP. +- `mlp.linear_fc1.{weight, bias}` — MLP up-projection (1152 → 4304). +- `mlp.linear_fc2.{weight, bias}` — MLP down-projection (4304 → 1152). + +Pattern matches a standard ViT block with **pre-norm** layout (norm → attn → residual, norm → MLP → residual). Activation between fc1/fc2 is GELU-tanh-approx per `hidden_act`. No attention masking inside the vision tower (all patches attend to each other). + +### Forward signature (target) + +``` +VisionTower::forward( + patches: Tensor [N, in_channels, patch_size, patch_size], # CPU-preprocessed RGB float patches + grid_thw: Option<(usize, usize, usize)>, # (t, h, w) patch grid for position lookup +) -> Tensor [N / (spatial_merge_size²), out_hidden_size] # = (N/4, 5120) for static images +``` + +Note: the merger consumes 4 spatially-adjacent patches and emits 1 LM token. So an image producing 64×64 = 4096 patches yields 1024 LM-side image tokens. + +## Image preprocessor (`preprocessor_config.json`) + +```json +{ + "size": { "longest_edge": 16777216, "shortest_edge": 65536 }, + "patch_size": 16, + "temporal_patch_size": 2, + "merge_size": 2, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5], + "processor_class": "Qwen3VLProcessor", + "image_processor_type": "Qwen2VLImageProcessorFast" +} +``` + +Reading: + +- `image_mean = image_std = 0.5` → normalisation is simply `(x/255 - 0.5) / 0.5 = 2*x/255 - 1`, mapping `[0,255]` → `[-1, 1]`. No imagenet-style mean/std. +- `size.{shortest_edge, longest_edge}` are **pixel counts**, not edge lengths. The `Qwen2VLImageProcessorFast` recipe picks a resolution within `[65,536 = 256², 16,777,216 = 4096²]` total pixels, snapping `h` and `w` to multiples of `patch_size × spatial_merge_size = 32` pixels. +- Stage A ships **fixed resolution**: pick a target pixel count (e.g. 448×448 = 200,704 px → 28×28 patches → 14×14 LM tokens after merger). Variable resolution deferred to issue [#14](https://git.lair.cafe/helexa/cortex/issues/14). + +## Chat template (`chat_template.jinja`) + +Image insertion (lines 8–18 of the template): + +```jinja +{%- if 'image' in item or 'image_url' in item or item.type == 'image' %} + ... + {{- '<|vision_start|><|image_pad|><|vision_end|>' }} +``` + +Per image, the template emits **one `<|image_pad|>` token** flanked by `<|vision_start|>` and `<|vision_end|>` sentinels. The runtime must: + +1. Render the template (preserving the single `<|image_pad|>` per image). +2. For each image, replace its single `<|image_pad|>` with N copies, where N is the number of LM tokens that image produces after the vision tower + merger (= `patches / spatial_merge_size²`). +3. Tokenize the expanded string → `input_ids`. +4. At forward time, locate positions where `input_ids == image_token_id` (248056) and splice in the vision tower's merger output. + +Token IDs (top of `config.json`): +- `vision_start_token_id`: 248053 +- `vision_end_token_id`: 248054 +- `image_token_id`: 248056 +- `video_token_id`: 248057 (out of scope) +- `bos_token_id`: 248044 +- `eos_token_id`: 248044, 248046 (per `generation_config.json`) + +System messages cannot contain images (template raises). Other template-side details: +- `add_vision_id` (jinja arg, default false): emits `'Picture N: '` prefixes when true. +- `preserve_thinking` (jinja arg, default false): keeps `` blocks from prior assistant turns in the rendered prompt. +- `enable_thinking` (jinja arg, default true): emits `\n` (or skips it) at the end of the generation prompt. + +The existing chat-template renderer in `crates/neuron/src/harness/chat_template.rs` already passes `MessageContent::Parts` to the Jinja context as a `Value::Array`; the template's `is iterable` branch (line 6 of the template) handles them. **The path is structurally in place** — Stage B just needs to do the `<|image_pad|>` expansion + token-position-aware splice. + +## LM-side considerations + +The LM's RoPE config uses **multi-axis RoPE (MRoPE)**: + +``` +rope_parameters: { + mrope_interleaved: true, + mrope_section: [11, 11, 10], # text + height + width components + partial_rotary_factor: 0.25, + rope_theta: 10000000, + rope_type: "default" +} +``` + +MRoPE encodes spatial position alongside text position so the LM attention layers can reason about image-token spatial structure. The LM's existing forward path *may or may not* already implement this — the qwen3_5 module's doc-comment notes "numerical correctness vs the reference Python is not yet validated." Verifying MRoPE behaviour in the language model is out of Stage A scope (vision tower only) but will be required in Stage B (LM splice) and is tracked under the numerical-validation issue [#15](https://git.lair.cafe/helexa/cortex/issues/15). + +`max_position_embeddings = 262144` (256 K context), so context-length limits are not a constraint for vision. + +## Iteration target decision + +The vision tower has its own self-contained weight tree and is small (~333 tensors in 2 shards, hidden_size 1152 vs LM's 5120). For Stage A specifically (vision-tower-only smoke), we **don't need a smaller iteration model** — we can: + +- Build the Rust `VisionTower` struct against the spec above. +- Run unit tests with random tensor weights matching the exact shapes → assert forward produces correct output shape with finite values. +- Optionally: a CUDA-integration test that loads just the 2 vision shards from beast's cache (or on a smaller GPU like quadbrat's Ampere) and runs encode on a real image. Doesn't require loading the 27B LM at all. + +This sidesteps the "develop against a smaller VL model" question for Stage A. Stage B (LM splice → end-to-end chat with vision) is where iteration speed becomes pressing; revisit there. The default scope pick 2a (smaller iteration model) is therefore deferred to Stage B planning — issue [#13](https://git.lair.cafe/helexa/cortex/issues/13) covers deployment validation regardless. + +## Concrete Stage A1+ inputs + +- Add deps to `crates/neuron/Cargo.toml`: + - `image = "0.25"` + - `base64 = "0.22"` +- Stage A2 preprocessor target resolution (fixed): **448×448 → 28×28 patches → 14×14 = 196 image tokens per image**. This balances minimum-patch-count for cheap tests against the model's expected input range. +- Stage A3 module structure: one `VisionTower` struct holding `patch_embed: Conv2d`, `pos_embed: Embedding`, `blocks: Vec`, `merger: Merger`. `VisionBlock` carries `norm1`, `norm2`, `attn`, `mlp`. Hand-roll using candle primitives. +- Stage A4 weight loading: extend `Qwen3_5ForCausalLM::new()` to construct `Some(VisionTower::new(vb.pp("model.visual"), config))` when `vision_config` is present in the parsed config. +- Stage A5 worker job: `Job::EncodeImage { handle, patches: Vec, patch_shape: (usize, usize, usize, usize, usize), reply: oneshot>> }`. Patch shape = `(N, C, T, H, W)` where T=1 for static images. + +## What this doc does NOT settle (deferred to issues) + +- Numerical correctness of `VisionTower` output vs Python transformers + → issue [#15](https://git.lair.cafe/helexa/cortex/issues/15). +- Variable image resolution + → issue [#14](https://git.lair.cafe/helexa/cortex/issues/14). +- TP-vision (multi-rank vision tower) + → issue [#12](https://git.lair.cafe/helexa/cortex/issues/12). +- 27B production deployment + → issue [#13](https://git.lair.cafe/helexa/cortex/issues/13).