feat(neuron): Stage A — vision tower load + preprocessor for Qwen3.6
All checks were successful
CI / CUDA type-check (push) Successful in 32s
build-prerelease / Resolve version stamps (push) Successful in 30s
CI / Format (push) Successful in 28s
CI / Clippy (push) Successful in 2m35s
build-prerelease / Build cortex binary (push) Successful in 5m13s
build-prerelease / Build neuron-blackwell (push) Successful in 6m23s
build-prerelease / Build neuron-ampere (push) Successful in 7m56s
CI / Test (push) Successful in 7m11s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m19s
build-prerelease / Build neuron-ada (push) Successful in 5m30s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m56s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m45s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 4m25s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s

Stage A of the vision implementation plan
(doc/vision-qwen3_6-spec.md). Builds the vision tower scaffolding
that today's silent-drop failure mode (issue #3) needs — the
Qwen3.6 ViT loads from `model.visual.*`, runs forward producing
post-merger LM-side image embeddings, and routes through the
device worker via a new `Job::EncodeImage`. No LM splice yet —
that's Stage B.

Refs #3 (umbrella). Deferred sub-stages tracked as #12 (TP-vision),
#13 (27B production deploy), #14 (dynamic resolution), #15
(numerical validation).

What landed:

- **A0 — investigation**: pulled config.json, preprocessor_config.json,
  chat_template.jinja, and safetensors index from beast's local
  Qwen3.6-27B cache. Documented in doc/vision-qwen3_6-spec.md with
  exact tensor shapes for every `model.visual.*` weight. Confirms
  27-block ViT with `hidden_size=1152`, `patch_size=16`,
  `spatial_merge_size=2`, `out_hidden_size=5120`. Vision tower lives
  in 2 of the 15 safetensors shards.

- **A1 — deps + scaffolding**: added `image = "0.25"` (default-
  features off, PNG/JPEG/WebP/BMP/GIF) and `base64 = "0.22"` to
  crates/neuron/Cargo.toml. Created `harness::preprocess` and
  `harness::arch::qwen3_5::vision` modules.

- **A2 — preprocess.rs**: `decode_data_uri` strips
  `data:image/...;base64,...` → image bytes → `image::DynamicImage`
  (rejecting `http(s)://` URLs to avoid SSRF/recursion); `preprocess`
  resizes to a fixed `PreprocessProfile::qwen3_6()` (448×448),
  normalises to `[-1, 1]` per the model's mean/std=0.5, emits
  row-major `(3, H, W)` f32. 9 unit tests covering data URI parse,
  decode failure paths, grayscale-to-RGB promotion, and the
  exact-value normalisation contract.

- **A3 — vision.rs**: `VisionTower` struct with `patch_embed: Conv2d`,
  learned `pos_embed: Embedding`, 27 `VisionBlock`s (pre-LN +
  multi-head self-attention with fused QKV + GELU-tanh MLP +
  residuals), and `VisionMerger` (LayerNorm → 2×2 spatial concat →
  linear_fc1 → GELU-tanh → linear_fc2 to LM hidden_size).
  Includes the Conv3d→Conv2d fold trick documented at the top of
  the file — the published patch_embed.proj.weight is 5D
  `(1152, 3, 2, 16, 16)` but candle 0.10 has no Conv3d; for static
  images we sum-collapse the temporal axis. Video would need real
  Conv3d. 5 unit tests including the exact `gelu_pytorch_tanh`
  reference values from PyTorch.

- **A4 — wire vision into Qwen3_5ForCausalLM**: extended `Config`
  with optional `vision_config: Option<VisionConfig>` and
  `image_token_id`; `Qwen3_5ForCausalLM::new` now loads the vision
  tower when present, exposes `has_vision()` and `vision()` so the
  HTTP layer can advertise capability and so the encode path can
  reach it.

- **A5 — device worker `Job::EncodeImage`**: new job variant carrying
  CPU-side `(C, H, W)` pixels. Dispatch handler reconstructs the
  tensor on the worker's device, calls `arch.encode_image(image)`,
  copies the result back to CPU as flat `Vec<f32>`. Keeps the
  "tensors don't escape the worker" invariant. Poisoned-worker
  drain path handles the new variant.

- **A6 — dispatch round-trip test**: `encode_image_routes_to_dispatch_
  and_errors_on_unknown_handle` proves the channel/dispatch wiring
  works end-to-end via the CPU device worker (errors on unknown
  ArchHandle, which is the expected behaviour without a loaded
  model — real-weights validation happens in Stage B when the LM
  splice path exists).

CI gate: cargo fmt --check, cargo clippy --workspace --all-targets
-- -D warnings, cargo test --workspace (all 28 test groups ok,
zero failures). New test counts: +9 in preprocess, +5 in vision,
+1 in device_worker.

Out of scope (deferred):
- LM-side splice of image embeddings at `<|image_pad|>` positions
  → Stage B.
- Streaming SSE for vision-bearing chat completions → Stage C.
- Reject `image_url` with HTTP 400 for non-vision models /
  advertise `capabilities` in /v1/models → Stage C.
- TP-vision (#12), 27B production deploy (#13), dynamic resolution
  (#14), numerical validation (#15).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-06-02 11:40:47 +03:00
parent 5c520c7e90
commit 7df84fed8f
11 changed files with 1413 additions and 1 deletions

117
Cargo.lock generated
View File

@@ -472,6 +472,12 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "byteorder-lite"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
[[package]]
name = "bytes"
version = "1.11.1"
@@ -668,6 +674,12 @@ dependencies = [
"cc",
]
[[package]]
name = "color_quant"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
[[package]]
name = "colorchoice"
version = "1.0.5"
@@ -1223,6 +1235,15 @@ version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6"
[[package]]
name = "fdeflate"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c"
dependencies = [
"simd-adler32",
]
[[package]]
name = "figment"
version = "0.10.19"
@@ -1731,6 +1752,16 @@ dependencies = [
"wasip3",
]
[[package]]
name = "gif"
version = "0.14.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee8cfcc411d9adbbaba82fb72661cc1bcca13e8bba98b364e62b2dba8f960159"
dependencies = [
"color_quant",
"weezl",
]
[[package]]
name = "glob"
version = "0.3.3"
@@ -2135,6 +2166,34 @@ dependencies = [
"icu_properties",
]
[[package]]
name = "image"
version = "0.25.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85ab80394333c02fe689eaf900ab500fbd0c2213da414687ebf995a65d5a6104"
dependencies = [
"bytemuck",
"byteorder-lite",
"color_quant",
"gif",
"image-webp",
"moxcms",
"num-traits",
"png",
"zune-core",
"zune-jpeg",
]
[[package]]
name = "image-webp"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3"
dependencies = [
"byteorder-lite",
"quick-error",
]
[[package]]
name = "indexmap"
version = "1.9.3"
@@ -2498,6 +2557,16 @@ dependencies = [
"syn",
]
[[package]]
name = "moxcms"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb85c154ba489f01b25c0d36ae69a87e4a1c73a72631fc6c0eb6dde34a73e44b"
dependencies = [
"num-traits",
"pxfm",
]
[[package]]
name = "native-tls"
version = "0.2.18"
@@ -2522,6 +2591,7 @@ dependencies = [
"anyhow",
"async-trait",
"axum",
"base64 0.22.1",
"candle-core",
"candle-nn",
"candle-transformers",
@@ -2533,6 +2603,7 @@ dependencies = [
"futures",
"half",
"hf-hub",
"image",
"minijinja",
"reqwest",
"safetensors 0.7.0",
@@ -2861,6 +2932,19 @@ version = "0.3.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e"
[[package]]
name = "png"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60769b8b31b2a9f263dae2776c37b1b28ae246943cf719eb6946a1db05128a61"
dependencies = [
"bitflags",
"crc32fast",
"fdeflate",
"flate2",
"miniz_oxide",
]
[[package]]
name = "polling"
version = "3.11.0"
@@ -2974,6 +3058,12 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40e24eee682d89fb193496edf918a7f407d30175b2e785fe057e4392dfd182e0"
[[package]]
name = "pxfm"
version = "0.1.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f"
[[package]]
name = "quanta"
version = "0.12.6"
@@ -2989,6 +3079,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "quick-error"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
[[package]]
name = "quinn"
version = "0.11.9"
@@ -4627,6 +4723,12 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "weezl"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88"
[[package]]
name = "which"
version = "7.0.3"
@@ -5164,3 +5266,18 @@ name = "zmij"
version = "1.0.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa"
[[package]]
name = "zune-core"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9"
[[package]]
name = "zune-jpeg"
version = "0.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "27bc9d5b815bc103f142aa054f561d9187d191692ec7c2d1e2b4737f8dbd7296"
dependencies = [
"zune-core",
]

View File

@@ -90,6 +90,13 @@ minijinja = { version = "2", features = ["builtins", "json", "serde"] }
# tp `fused_load` module to read per-rank slices of fused QKV tensors
# without materialising the full tensor on device.
safetensors = "0.7"
# Vision capability for Qwen3.6 (Stage A of the vision plan in
# doc/vision-qwen3_6-spec.md). `image` decodes PNG/JPEG/etc from
# the bytes embedded in `data:image/...;base64,...` content parts;
# `base64` does the URI decode. Default-features off on `image` to
# avoid pulling in audio/video formats we don't need.
image = { version = "0.25", default-features = false, features = ["png", "jpeg", "webp", "bmp", "gif"] }
base64 = "0.22"
[dev-dependencies]
tokio = { workspace = true, features = ["test-util"] }

View File

@@ -78,6 +78,7 @@ pub mod linear_attn;
pub mod mlp;
pub mod rmsnorm;
pub mod rope;
pub mod vision;
use decoder::Qwen3_5DecoderLayer;
use rmsnorm::Qwen3_5RmsNorm;
@@ -99,6 +100,20 @@ pub struct Config {
pub model_type: String,
/// The text-side hyperparameters. Everything we actually need.
pub text_config: TextConfig,
/// Vision tower hyperparameters. Present on multimodal
/// checkpoints (e.g. Qwen/Qwen3.6-27B); absent on text-only
/// variants. When present, `Qwen3_5ForCausalLM::new` loads the
/// vision tower alongside the language model so vision-bearing
/// requests can splice image embeddings at `<|image_pad|>` token
/// positions.
#[serde(default)]
pub vision_config: Option<vision::VisionConfig>,
/// Token id the chat template emits per image patch group.
/// Mirrors the LM tokenizer's `<|image_pad|>` id (248056 for
/// Qwen3.6). The runtime locates these in the prompt and splices
/// in `VisionTower::forward` output. `None` for text-only models.
#[serde(default)]
pub image_token_id: Option<u32>,
}
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
@@ -309,6 +324,15 @@ impl Qwen3_5Model {
pub struct Qwen3_5ForCausalLM {
base: Qwen3_5Model,
lm_head: Linear,
/// Vision tower (Stage A4). `None` for text-only checkpoints or
/// when the operator has opted out. When present, the harness's
/// `Job::EncodeImage` dispatch path runs `vision.forward(image)`
/// and the LM forward (Stage B) splices the result at
/// `image_token_id` positions in the input embedding stream.
vision: Option<vision::VisionTower>,
/// Mirrors `Config::image_token_id`. Cached here so the runtime
/// doesn't have to round-trip through the parsed config struct.
image_token_id: Option<u32>,
}
impl Qwen3_5ForCausalLM {
@@ -324,7 +348,52 @@ impl Qwen3_5ForCausalLM {
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
Linear::new(weight, None)
};
Ok(Self { base, lm_head })
// Stage A4: load the vision tower when the config carries a
// `vision_config` block and the safetensors actually carry
// `model.visual.*` weights. The `Option<VisionConfig>` on the
// config makes this a single-source-of-truth decision —
// text-only checkpoints just leave `vision_config` unset and
// get `None` here without any extra plumbing.
let vision = if let Some(vcfg) = config.vision_config.clone() {
tracing::info!(
depth = vcfg.depth,
hidden_size = vcfg.hidden_size,
"loading qwen3_5 vision tower"
);
Some(
vision::VisionTower::load(vcfg, vb.pp("model.visual"))
.context("load qwen3_5 vision tower (model.visual.*)")?,
)
} else {
None
};
Ok(Self {
base,
lm_head,
vision,
image_token_id: config.image_token_id,
})
}
/// True when this checkpoint loaded a vision tower. Used by the
/// HTTP layer to advertise vision capability in `/v1/models` and
/// to reject image-bearing requests against text-only loads with
/// a clean 400.
pub fn has_vision(&self) -> bool {
self.vision.is_some()
}
/// Vision tower handle, if loaded. The device-worker
/// `EncodeImage` job dispatches to `vision.forward(image)`.
pub fn vision(&self) -> Option<&vision::VisionTower> {
self.vision.as_ref()
}
/// `<|image_pad|>` token id from `config.json`, when known.
/// The Stage B prompt-builder uses this to count expansion targets
/// and the LM forward uses it to locate splice positions.
pub fn image_token_id(&self) -> Option<u32> {
self.image_token_id
}
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at

View File

@@ -0,0 +1,600 @@
//! Qwen3.6 vision tower.
//!
//! 27 pre-norm ViT blocks with **LayerNorm** (with biases — not the
//! `(1+w)·x` RmsNorm the language model uses), fused QKV attention,
//! GELU-tanh MLP. Followed by a `merger` that LayerNorms each
//! 1152-dim vision token, spatially 2×2-merges them into 4608-dim
//! groups, and projects to the LM's 5120-dim hidden via
//! `linear_fc1 → GELU → linear_fc2`.
//!
//! Architecture spec sourced from beast's cached Qwen3.6-27B
//! safetensors header (Stage A0, see
//! `doc/vision-qwen3_6-spec.md`). All weight shapes confirmed
//! from the live `.safetensors` headers, not inferred.
//!
//! **Conv3d wrinkle.** The published `patch_embed.proj.weight` is 5D
//! `[1152, 3, 2, 16, 16]` — a 3D conv with kernel
//! `(t=2, h=16, w=16)`. Candle 0.10 has no Conv3d. For static images
//! we get away with a trick: when the temporal patch size is 2 and we
//! duplicate the still image along the temporal axis (`T = 2`,
//! frame_0 == frame_1), the Conv3d output equals a Conv2d run with
//! the *sum* of the two temporal weight slices:
//!
//! ```text
//! output = W_0 · frame_0 + W_1 · frame_1 + bias
//! = (W_0 + W_1) · frame + bias (static image)
//! ```
//!
//! So at load we sum-collapse the temporal axis and use a 4D
//! `Conv2d` kernel. Video support would have to do the real Conv3d
//! (different frames mean the trick fails) — tracked alongside the
//! dynamic-resolution work in issue #14.
//!
//! Forward signature (Stage A — no LM splice yet):
//!
//! ```text
//! fn forward(&self, image: &Tensor) -> Result<Tensor>
//! ```
//!
//! `image` is `(3, H, W)` f32, normalised by `preprocess::preprocess`.
//! Returns `(N_lm_tokens, out_hidden_size)` post-merger tokens ready
//! to splice into the LM's input embeddings at `<|image_pad|>`
//! positions. For Qwen3.6 at 448×448 → 28×28 patches → 14×14 = 196
//! LM tokens of dim 5120.
use anyhow::{Context, Result};
use candle_core::{D, DType, Device, IndexOp, Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear};
use serde::Deserialize;
/// Qwen3.6 vision tower hyperparameters. Mirrors the `vision_config`
/// block of `config.json`. Only the fields we actually need are
/// captured; serde tolerates the rest.
#[derive(Debug, Clone, Deserialize)]
pub struct VisionConfig {
/// Number of ViT blocks (`depth: 27` for Qwen3.6).
pub depth: usize,
/// Vision-token dimension throughout the tower (1152 for Qwen3.6).
pub hidden_size: usize,
/// MLP intermediate dim (4304).
pub intermediate_size: usize,
/// Attention head count (16). `head_dim = hidden_size / num_heads`.
pub num_heads: usize,
/// Number of slots in the learned position embedding (2304).
/// Caps the maximum image patch count.
pub num_position_embeddings: usize,
/// Spatial patch edge in pixels (16).
pub patch_size: usize,
/// Temporal kernel depth in the patch embed (2 for Qwen3.6 — we
/// collapse this into a single Conv2d for static-image inference;
/// see the module-level Conv3d wrinkle).
pub temporal_patch_size: usize,
/// Patches grouped per LM token by the merger (2 → 2×2 = 4
/// patches per LM token).
pub spatial_merge_size: usize,
/// Vision input channels (3, RGB).
pub in_channels: usize,
/// Merger output dim — matches the LM's `hidden_size` (5120 for
/// Qwen3.6). The merger projects from vision dim → LM dim.
pub out_hidden_size: usize,
}
const LAYER_NORM_EPS: f64 = 1e-6;
/// Number of LM tokens emitted by the merger per vision-token group.
const LM_TOKENS_PER_MERGE_GROUP: usize = 1;
/// One ViT block: pre-LN → attn → residual; pre-LN → MLP → residual.
struct VisionBlock {
norm1: LayerNorm,
qkv: Linear,
proj: Linear,
norm2: LayerNorm,
fc1: Linear,
fc2: Linear,
num_heads: usize,
head_dim: usize,
}
impl VisionBlock {
fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let head_dim = h / cfg.num_heads;
let norm1 = layer_norm(vb.pp("norm1"), h)?;
let qkv = linear(vb.pp("attn.qkv"), h, 3 * h)?;
let proj = linear(vb.pp("attn.proj"), h, h)?;
let norm2 = layer_norm(vb.pp("norm2"), h)?;
let fc1 = linear(vb.pp("mlp.linear_fc1"), h, cfg.intermediate_size)?;
let fc2 = linear(vb.pp("mlp.linear_fc2"), cfg.intermediate_size, h)?;
Ok(Self {
norm1,
qkv,
proj,
norm2,
fc1,
fc2,
num_heads: cfg.num_heads,
head_dim,
})
}
/// `x`: `(N, hidden_size)` un-batched. Returns same shape.
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let attn_in = self.norm1.forward(x)?;
let attn_out = self.attention(&attn_in)?;
let x = x.add(&attn_out)?;
let mlp_in = self.norm2.forward(&x)?;
let mlp_out = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&mlp_in)?)?)?;
x.add(&mlp_out).map_err(Into::into)
}
/// Multi-head self-attention over the patch sequence. No causal
/// mask — every patch attends to every other patch.
fn attention(&self, x: &Tensor) -> Result<Tensor> {
let (n, hidden) = x.dims2()?;
// qkv: (N, 3*hidden). Split into Q, K, V each (N, hidden).
let qkv = self.qkv.forward(x)?;
let qkv = qkv.reshape((n, 3, self.num_heads, self.head_dim))?;
// Transpose to (3, num_heads, N, head_dim) for per-head views.
let qkv = qkv.permute((1, 2, 0, 3))?.contiguous()?;
let q = qkv.i(0)?;
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let scale = 1.0 / (self.head_dim as f64).sqrt();
// (num_heads, N, head_dim) @ (num_heads, head_dim, N) -> (num_heads, N, N)
let scores = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
let scores = (scores * scale)?;
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
// (num_heads, N, N) @ (num_heads, N, head_dim) -> (num_heads, N, head_dim)
let out = probs.matmul(&v)?;
// Merge heads back: (N, num_heads, head_dim) -> (N, hidden).
let out = out.permute((1, 0, 2))?.contiguous()?.reshape((n, hidden))?;
self.proj.forward(&out).map_err(Into::into)
}
}
/// `merger`: LayerNorm per token → spatial 2×2 merge (concat 4
/// adjacent tokens into one 4608-dim vector) → fc1 → GELU-tanh →
/// fc2. Output dim is the LM's hidden_size.
struct VisionMerger {
norm: LayerNorm,
fc1: Linear,
fc2: Linear,
merge_input_dim: usize,
spatial_merge_size: usize,
}
impl VisionMerger {
fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let merge = cfg.spatial_merge_size;
let merge_input_dim = h * merge * merge;
let norm = layer_norm(vb.pp("norm"), h)?;
let fc1 = linear(vb.pp("linear_fc1"), merge_input_dim, merge_input_dim)?;
let fc2 = linear(vb.pp("linear_fc2"), merge_input_dim, cfg.out_hidden_size)?;
Ok(Self {
norm,
fc1,
fc2,
merge_input_dim,
spatial_merge_size: merge,
})
}
/// `tokens`: `(grid_h, grid_w, hidden_size)`. The merger reshapes
/// each `merge×merge` block of adjacent patches into a single
/// concatenated vector, then projects.
///
/// `grid_h` and `grid_w` must both be multiples of
/// `spatial_merge_size`. Returns
/// `(grid_h/merge × grid_w/merge, out_hidden_size)`.
fn forward(&self, tokens: &Tensor) -> Result<Tensor> {
let (gh, gw, h) = tokens.dims3()?;
let m = self.spatial_merge_size;
anyhow::ensure!(
gh.is_multiple_of(m) && gw.is_multiple_of(m),
"merger expects spatial dims divisible by merge_size={m}; got ({gh}, {gw})"
);
let tokens = self.norm.forward(tokens)?;
// (gh, gw, h) -> (gh/m, m, gw/m, m, h) -> (gh/m, gw/m, m, m, h)
// -> flatten last three -> (gh/m, gw/m, m*m*h) -> (N_lm, merge_input_dim)
let out_h = gh / m;
let out_w = gw / m;
let merged = tokens
.reshape((out_h, m, out_w, m, h))?
.permute((0, 2, 1, 3, 4))?
.contiguous()?
.reshape((out_h * out_w, self.merge_input_dim))?;
let hidden = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&merged)?)?)?;
Ok(hidden)
}
}
/// The vision tower itself.
pub struct VisionTower {
/// Sum-collapsed temporal kernel (Conv2d, see module doc).
patch_embed: Conv2d,
pos_embed: Embedding,
blocks: Vec<VisionBlock>,
merger: VisionMerger,
config: VisionConfig,
dtype: DType,
device: Device,
}
impl VisionTower {
/// Load from a `ShardedVarBuilder` rooted at the safetensors
/// `model.visual.` prefix. Caller is responsible for the `pp` —
/// see `Qwen3_5ForCausalLM::new` (Stage A4).
pub fn load(cfg: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
let dtype = vb.dtype();
let device = vb.device().clone();
// patch_embed.proj is published as 5D Conv3d weight; we
// sum-collapse the temporal axis (size = temporal_patch_size)
// to get a 4D Conv2d kernel. This is exact for the static-
// image case where T = temporal_patch_size frames are
// identical (i.e. the input was duplicated along T).
let raw_weight = vb
.pp("patch_embed.proj")
.get(
(
cfg.hidden_size,
cfg.in_channels,
cfg.temporal_patch_size,
cfg.patch_size,
cfg.patch_size,
),
"weight",
)
.context("load model.visual.patch_embed.proj.weight (5D Conv3d kernel)")?;
// Sum along the temporal axis (dim 2) — see module doc-comment.
let folded = raw_weight.sum(2)?; // -> (hidden, in_channels, patch, patch)
let proj_bias = vb
.pp("patch_embed.proj")
.get(cfg.hidden_size, "bias")
.context("load model.visual.patch_embed.proj.bias")?;
let conv_cfg = Conv2dConfig {
stride: cfg.patch_size,
..Default::default()
};
let patch_embed = Conv2d::new(folded, Some(proj_bias), conv_cfg);
let pos_embed_weight = vb
.pp("pos_embed")
.get((cfg.num_position_embeddings, cfg.hidden_size), "weight")
.context("load model.visual.pos_embed.weight")?;
let pos_embed = Embedding::new(pos_embed_weight, cfg.hidden_size);
let blocks_vb = vb.pp("blocks");
let mut blocks = Vec::with_capacity(cfg.depth);
for i in 0..cfg.depth {
blocks.push(
VisionBlock::load(&cfg, &blocks_vb.pp(i))
.with_context(|| format!("load vision block {i}"))?,
);
}
let merger = VisionMerger::load(&cfg, &vb.pp("merger")).context("load vision merger")?;
Ok(Self {
patch_embed,
pos_embed,
blocks,
merger,
config: cfg,
dtype,
device,
})
}
pub fn config(&self) -> &VisionConfig {
&self.config
}
/// Number of LM tokens this tower emits for an `(H, W)` pixel
/// image after the merger. Equal to
/// `(H / patch_size / spatial_merge_size) * (W / patch_size / spatial_merge_size)`.
pub fn lm_tokens_for(&self, h: u32, w: u32) -> usize {
let m = self.config.spatial_merge_size;
let patch = self.config.patch_size;
let gh = (h as usize) / patch / m;
let gw = (w as usize) / patch / m;
gh * gw * LM_TOKENS_PER_MERGE_GROUP
}
/// Encode one image.
///
/// `image`: row-major `(3, H, W)` f32 tensor on `self.device`,
/// already normalised by `preprocess::preprocess`. Both `H` and
/// `W` must be multiples of `patch_size * spatial_merge_size`.
///
/// Returns `(N_lm, out_hidden_size)` — LM-side image tokens
/// ready to splice into the language model's input embeddings.
pub fn forward(&self, image: &Tensor) -> Result<Tensor> {
let (c, h, w) = image.dims3()?;
anyhow::ensure!(
c == self.config.in_channels,
"image must have {} channels, got {c}",
self.config.in_channels
);
let patch = self.config.patch_size;
anyhow::ensure!(
h.is_multiple_of(patch) && w.is_multiple_of(patch),
"image dims must be multiples of patch_size={patch}; got ({h}, {w})"
);
let gh = h / patch;
let gw = w / patch;
let n_patches = gh * gw;
anyhow::ensure!(
n_patches <= self.config.num_position_embeddings,
"patch count {n_patches} exceeds pos_embed budget {}",
self.config.num_position_embeddings
);
// Add batch axis for conv: (1, 3, H, W) → (1, hidden, gh, gw)
// → (hidden, gh, gw) → permute to (gh, gw, hidden) → flatten to (N, hidden)
let x = image.unsqueeze(0)?.to_dtype(self.dtype)?;
let x = self.patch_embed.forward(&x)?;
let x = x.squeeze(0)?;
let x = x.permute((1, 2, 0))?.contiguous()?;
let x = x.reshape((n_patches, self.config.hidden_size))?;
// Add learned positional embeddings (sequential indices for
// Stage A's fixed-resolution path; full 2D positional logic
// lands with variable resolution, issue #14).
let positions = Tensor::arange(0u32, n_patches as u32, &self.device)?;
let pos = self.pos_embed.forward(&positions)?;
let mut x = x.add(&pos)?;
for (i, block) in self.blocks.iter().enumerate() {
x = block
.forward(&x)
.with_context(|| format!("vision block {i}"))?;
}
// (n_patches, hidden) → (gh, gw, hidden) for the merger.
let x = x.reshape((gh, gw, self.config.hidden_size))?;
self.merger.forward(&x)
}
}
/// Manually load a candle_nn LayerNorm from a ShardedVarBuilder.
/// candle_nn's `layer_norm` builder takes `crate::VarBuilder`, not
/// `ShardedVarBuilder`, so the existing arch modules in this crate
/// uniformly do the manual load + struct construction pattern (see
/// `full_attn::load_linear_no_bias`). We follow suit here.
fn layer_norm(vb: ShardedVarBuilder, size: usize) -> Result<LayerNorm> {
let weight = vb
.get(size, "weight")
.with_context(|| format!("load LayerNorm.weight at '{}'", vb.prefix()))?;
let bias = vb
.get(size, "bias")
.with_context(|| format!("load LayerNorm.bias at '{}'", vb.prefix()))?;
Ok(LayerNorm::new(weight, bias, LAYER_NORM_EPS))
}
/// Manually load a candle_nn Linear (with bias) from a
/// ShardedVarBuilder. Same rationale as `layer_norm` above.
fn linear(vb: ShardedVarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
let weight = vb
.get((out_dim, in_dim), "weight")
.with_context(|| format!("load Linear.weight at '{}'", vb.prefix()))?;
let bias = vb
.get(out_dim, "bias")
.with_context(|| format!("load Linear.bias at '{}'", vb.prefix()))?;
Ok(Linear::new(weight, Some(bias)))
}
/// PyTorch's `gelu_pytorch_tanh` approximation — what the Qwen3.6
/// vision tower's `hidden_act` specifies. candle's `Tensor::gelu`
/// uses the exact erf-based GELU, so we compute the tanh
/// approximation explicitly:
///
/// ```text
/// gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
/// ```
fn gelu_tanh(x: &Tensor) -> Result<Tensor> {
// sqrt(2 / pi) = 0.7978845608028654
const COEFF: f64 = 0.7978845608028654;
const KAPPA: f64 = 0.044715;
let x3 = x.powf(3.0)?;
let inner = (x + (x3 * KAPPA)?)?;
let inner = (inner * COEFF)?;
let t = inner.tanh()?;
let one_plus_t = (t + 1.0)?;
let out = (x * 0.5)?;
let out = out.broadcast_mul(&one_plus_t)?;
Ok(out)
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::{DType, Device};
/// Build a tiny VisionConfig usable on CPU with random weights.
/// Match the Qwen3.6 shape relations (depth-N stack, hidden mod
/// num_heads, intermediate_size > hidden_size) but with small
/// dims so tests run in milliseconds.
fn tiny_config() -> VisionConfig {
VisionConfig {
depth: 2,
hidden_size: 32,
intermediate_size: 64,
num_heads: 4,
num_position_embeddings: 64,
patch_size: 4,
temporal_patch_size: 2,
spatial_merge_size: 2,
in_channels: 3,
out_hidden_size: 48,
}
}
/// Hand-construct a VisionTower with random weights. This is the
/// same trick `linear_attn::tests::forward_smoke_with_tiny_dimensions`
/// uses — bypass the safetensors-backed `ShardedVarBuilder` path
/// (which can't be built from in-memory tensors) and assemble the
/// struct fields directly. The real `VisionTower::load` is
/// exercised by the cuda-integration smoke test in Stage A6.
fn tiny_tower(cfg: &VisionConfig) -> VisionTower {
let device = Device::Cpu;
let dtype = DType::F32;
let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &device).unwrap();
let ones = |shape: &[usize]| Tensor::ones(shape, dtype, &device).unwrap();
let randn = |shape: &[usize]| Tensor::randn(0_f32, 0.02, shape, &device).unwrap();
let patch_embed = Conv2d::new(
randn(&[
cfg.hidden_size,
cfg.in_channels,
cfg.patch_size,
cfg.patch_size,
]),
Some(zeros(&[cfg.hidden_size])),
Conv2dConfig {
stride: cfg.patch_size,
..Default::default()
},
);
let pos_embed = Embedding::new(
randn(&[cfg.num_position_embeddings, cfg.hidden_size]),
cfg.hidden_size,
);
let mut blocks = Vec::with_capacity(cfg.depth);
for _ in 0..cfg.depth {
let head_dim = cfg.hidden_size / cfg.num_heads;
blocks.push(VisionBlock {
norm1: LayerNorm::new(
ones(&[cfg.hidden_size]),
zeros(&[cfg.hidden_size]),
LAYER_NORM_EPS,
),
qkv: Linear::new(
randn(&[3 * cfg.hidden_size, cfg.hidden_size]),
Some(zeros(&[3 * cfg.hidden_size])),
),
proj: Linear::new(
randn(&[cfg.hidden_size, cfg.hidden_size]),
Some(zeros(&[cfg.hidden_size])),
),
norm2: LayerNorm::new(
ones(&[cfg.hidden_size]),
zeros(&[cfg.hidden_size]),
LAYER_NORM_EPS,
),
fc1: Linear::new(
randn(&[cfg.intermediate_size, cfg.hidden_size]),
Some(zeros(&[cfg.intermediate_size])),
),
fc2: Linear::new(
randn(&[cfg.hidden_size, cfg.intermediate_size]),
Some(zeros(&[cfg.hidden_size])),
),
num_heads: cfg.num_heads,
head_dim,
});
}
let merge_input_dim = cfg.hidden_size * cfg.spatial_merge_size * cfg.spatial_merge_size;
let merger = VisionMerger {
norm: LayerNorm::new(
ones(&[cfg.hidden_size]),
zeros(&[cfg.hidden_size]),
LAYER_NORM_EPS,
),
fc1: Linear::new(
randn(&[merge_input_dim, merge_input_dim]),
Some(zeros(&[merge_input_dim])),
),
fc2: Linear::new(
randn(&[cfg.out_hidden_size, merge_input_dim]),
Some(zeros(&[cfg.out_hidden_size])),
),
merge_input_dim,
spatial_merge_size: cfg.spatial_merge_size,
};
VisionTower {
patch_embed,
pos_embed,
blocks,
merger,
config: cfg.clone(),
dtype,
device,
}
}
#[test]
fn forward_with_random_weights_produces_finite_output() {
let cfg = tiny_config();
let tower = tiny_tower(&cfg);
// 16×16 image at patch_size=4 → 4×4 patches → after 2×2
// merge → 2×2 = 4 LM tokens of dim out_hidden_size.
let image = Tensor::randn(0_f32, 1.0, (3, 16, 16), &Device::Cpu).unwrap();
let out = tower.forward(&image).expect("forward");
let (n_lm, hidden) = out.dims2().unwrap();
assert_eq!(n_lm, 4);
assert_eq!(hidden, cfg.out_hidden_size);
// No NaN/Inf
let values: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
assert!(
values.iter().all(|v| v.is_finite()),
"forward must produce finite values"
);
}
#[test]
fn lm_token_count_matches_grid() {
let cfg = tiny_config();
let tower = tiny_tower(&cfg);
// 16x16 image → 4x4 patches → 2x2 = 4 LM tokens
assert_eq!(tower.lm_tokens_for(16, 16), 4);
// 32x32 image → 8x8 patches → 4x4 = 16 LM tokens
assert_eq!(tower.lm_tokens_for(32, 32), 16);
}
#[test]
fn rejects_image_with_dims_not_multiple_of_patch() {
let cfg = tiny_config();
let tower = tiny_tower(&cfg);
let image = Tensor::randn(0_f32, 1.0, (3, 17, 17), &Device::Cpu).unwrap();
let err = tower.forward(&image).unwrap_err();
assert!(format!("{err:#}").contains("patch_size"));
}
#[test]
fn rejects_image_with_wrong_channel_count() {
let cfg = tiny_config();
let tower = tiny_tower(&cfg);
let image = Tensor::randn(0_f32, 1.0, (4, 16, 16), &Device::Cpu).unwrap();
let err = tower.forward(&image).unwrap_err();
assert!(format!("{err:#}").contains("channels"));
}
#[test]
fn gelu_tanh_matches_known_values() {
// Reference values for gelu_pytorch_tanh from PyTorch:
// gelu_tanh(0.0) = 0.0
// gelu_tanh(1.0) ≈ 0.8411920071
// gelu_tanh(-1.0) ≈ -0.1588079929
let x = Tensor::new(&[0.0_f32, 1.0, -1.0], &Device::Cpu).unwrap();
let y = gelu_tanh(&x).unwrap();
let v: Vec<f32> = y.to_vec1().unwrap();
assert!((v[0]).abs() < 1e-6, "gelu_tanh(0) ≈ 0, got {}", v[0]);
assert!(
(v[1] - 0.841_192_f32).abs() < 1e-5,
"gelu_tanh(1) ≈ 0.84119, got {}",
v[1]
);
assert!(
(v[2] - -0.158_808_f32).abs() < 1e-5,
"gelu_tanh(-1) ≈ -0.15881, got {}",
v[2]
);
}
}

View File

@@ -332,6 +332,37 @@ impl ModelArch {
}
}
}
/// Encode a preprocessed image into LM-side token embeddings via
/// the loaded vision tower. Stage A5.
///
/// `image`: device-resident `(C, H, W)` f32 tensor — caller has
/// already preprocessed via `harness::preprocess::preprocess` and
/// uploaded to the worker's device. Returns
/// `(N_lm_tokens, hidden_size)`.
///
/// Errors when the loaded architecture has no vision tower
/// (text-only checkpoint, or architecture that doesn't support
/// vision at all). The HTTP layer maps this to a 400 with
/// `vision_unsupported` so clients see a clean rejection rather
/// than a confident text-only hallucination.
pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
match self {
ModelArch::Qwen3_5Dense(m) => m
.vision()
.ok_or_else(|| {
anyhow::anyhow!(
"encode_image: this Qwen3.6 checkpoint was loaded without a vision \
tower (config.json::vision_config absent or weights missing)"
)
})?
.forward(image),
other => anyhow::bail!(
"encode_image: architecture {} has no vision tower",
std::any::type_name_of_val(other)
),
}
}
}
/// Squeeze any leading singleton dims off the logits tensor so the

View File

@@ -158,6 +158,17 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
let result = forward_logits(&mut state, handle, &tokens, offset);
let _ = reply.send(result);
}
Job::EncodeImage {
handle,
pixels,
c,
h,
w,
reply,
} => {
let result = encode_image(&mut state, handle, pixels, c, h, w);
let _ = reply.send(result);
}
Job::NcclInit {
cfg,
comm_id_hex,
@@ -740,6 +751,49 @@ fn forward_logits(
Ok(values)
}
/// Run the vision tower on a single preprocessed image. Stage A5.
///
/// `pixels` is a row-major `(c, h, w)` f32 image that the async-side
/// `harness::preprocess` produced. We reconstruct the tensor on the
/// worker's device (the same device the model was loaded against),
/// call `arch.encode_image`, and copy the resulting
/// `(N_lm_tokens, hidden_size)` embedding back to CPU f32.
///
/// Returns the flattened embedding as a `Vec<f32>` — the caller knows
/// the LM-side token count from `VisionTower::lm_tokens_for(h, w)`
/// and reshapes accordingly. Stage B introduces a device-resident
/// embedding-slab variant that avoids this round-trip when the next
/// forward call needs the result.
fn encode_image(
state: &mut DeviceWorkerState,
handle: ArchHandle,
pixels: Vec<f32>,
c: usize,
h: usize,
w: usize,
) -> anyhow::Result<Vec<f32>> {
use candle_core::{DType, Tensor};
anyhow::ensure!(
pixels.len() == c * h * w,
"EncodeImage: pixels length {} does not match shape ({c}, {h}, {w})",
pixels.len()
);
let image = Tensor::from_vec(pixels, (c, h, w), &state.device)?;
let arch = state
.models
.get(&handle)
.ok_or_else(|| anyhow::anyhow!("EncodeImage: no model for handle {}", handle.0))?;
let embed = arch.encode_image(&image)?;
let values = embed
.to_dtype(DType::F32)?
.flatten_all()?
.to_vec1::<f32>()?;
Ok(values)
}
/// Reply to a job with the poisoned-worker error. Used when the worker
/// has flipped into drain-only mode after a CUDA driver error.
///
@@ -773,6 +827,9 @@ fn drain_poisoned(job: Job, device_index: u32) {
Job::ForwardLogits { reply, .. } => {
let _ = reply.send(Err(err()));
}
Job::EncodeImage { reply, .. } => {
let _ = reply.send(Err(err()));
}
Job::NcclInit { reply, .. } => {
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
kind: "device_worker_poisoned".into(),

View File

@@ -94,6 +94,31 @@ pub enum Job {
offset: usize,
reply: oneshot::Sender<Result<Vec<f32>>>,
},
/// Encode one image through the model's vision tower. Stage A5 of
/// the vision plan (`doc/vision-qwen3_6-spec.md`).
///
/// `pixels` is the CPU-side preprocessed image tensor in row-major
/// `(C, H, W)` f32 layout — what `harness::preprocess::preprocess`
/// produces. `c`, `h`, `w` carry the shape since `Vec<f32>` itself
/// is rank-1. The handler reconstructs the tensor on the worker's
/// device, runs `VisionTower::forward`, copies the resulting
/// `(N_lm_tokens, hidden_size)` embedding back to CPU as a flat
/// `Vec<f32>` (the caller knows the expected shape from
/// `VisionTower::lm_tokens_for(h, w) * hidden_size`).
///
/// Mirrors the `ForwardLogits` "tensors don't escape" invariant —
/// device-side image embeddings are dropped at handler return.
/// Stage B will introduce a follow-up variant that keeps the
/// embeddings device-resident and references them from the next
/// `ForwardLogits` call, avoiding the round-trip copy.
EncodeImage {
handle: ArchHandle,
pixels: Vec<f32>,
c: usize,
h: usize,
w: usize,
reply: oneshot::Sender<Result<Vec<f32>>>,
},
/// Initialize the leader's NCCL communicator. The worker's
/// `NcclState` mints the `Comm` here so its underlying
/// `ncclComm_t` and `CudaContext` live on the same thread as

View File

@@ -313,6 +313,49 @@ impl DeviceWorkerHandle {
}
}
/// Encode a preprocessed image through the model's vision tower
/// and return the resulting LM-side image embeddings as a
/// flattened CPU `Vec<f32>`. Stage A5.
///
/// `pixels` is the row-major `(c, h, w)` f32 image —
/// `harness::preprocess::preprocess` produces this exact shape.
/// The caller knows the expected output length from
/// `VisionTower::lm_tokens_for(h, w) * hidden_size` and reshapes
/// accordingly.
pub async fn encode_image(
&self,
handle: ArchHandle,
pixels: Vec<f32>,
c: usize,
h: usize,
w: usize,
) -> Result<Vec<f32>, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::EncodeImage {
handle,
pixels,
c,
h,
w,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Initialise the leader's NCCL communicator. The reply uses
/// `WorkerResponse` (same shape subprocess workers use over stdio
/// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader +
@@ -569,6 +612,37 @@ mod tests {
handle.shutdown().expect("shutdown ok");
}
/// Stage A5: confirm the EncodeImage job round-trips through the
/// worker channel. We don't have a real loaded model in the slab
/// here, so the dispatch handler returns the
/// "no model for handle" error — which is exactly what we want to
/// see, since it proves the message routed through the channel
/// and reached the handler. Real-weights validation lives in the
/// Stage A7 / Stage B post-deploy smoke on beast.
#[tokio::test]
async fn encode_image_routes_to_dispatch_and_errors_on_unknown_handle() {
use crate::harness::device_worker::jobs::ArchHandle;
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
let fake_arch = ArchHandle(99_999);
// (3, 4, 4) fake image — minimal payload, gets reconstructed
// on the worker before the handler errors out on the unknown
// ArchHandle lookup.
let pixels = vec![0.0_f32; 3 * 4 * 4];
let result = handle.encode_image(fake_arch, pixels, 3, 4, 4).await;
match result {
Err(WorkerError::Job(e)) => {
let msg = format!("{e:#}");
assert!(
msg.contains("EncodeImage: no model for handle"),
"expected unknown-handle error, got: {msg}"
);
}
other => panic!("expected Job(Err), got {other:?}"),
}
handle.shutdown().expect("shutdown ok");
}
#[tokio::test]
async fn shutdown_drains_pending_jobs() {
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");

View File

@@ -5,6 +5,7 @@ pub mod candle;
pub mod chat_template;
pub mod device_worker;
pub mod preflight;
pub mod preprocess;
pub mod tp;
use anyhow::Result;

View File

@@ -0,0 +1,255 @@
//! Image preprocessing for vision-capable models.
//!
//! Decodes `data:image/...;base64,...` URIs from OpenAI-style
//! `image_url` content parts into the patch tensors a candle vision
//! tower expects. Stage A ships **fixed resolution** — every image
//! is resized to the same target dimensions (default 448×448 for
//! Qwen3.6, configurable per-call) so the patch count is constant
//! per image. Variable resolution per [Qwen2VL convention] is tracked
//! as issue #14.
//!
//! Spec reference: `doc/vision-qwen3_6-spec.md` — preprocessor
//! section.
//!
//! Normalisation: pixel value `p ∈ [0, 255]` becomes
//! `(p/255 - mean) / std`. Qwen3.6's preprocessor_config.json
//! specifies `image_mean = image_std = [0.5, 0.5, 0.5]`, which
//! simplifies to `2p/255 - 1` mapping `[0,255]` → `[-1, 1]`. We
//! still parameterise mean/std so the same code generalises to other
//! VL families (Qwen2-VL uses imagenet stats, for instance).
//!
//! Pipeline (per image):
//! 1. data: URI → base64 decode → bytes
//! 2. bytes → image::DynamicImage (PNG/JPEG/WebP/etc)
//! 3. resize_exact to target H×W (pixel space)
//! 4. RGB→f32, normalise per mean/std
//! 5. layout to (C, H, W) tensor
//!
//! Patchification (cutting the HxW tensor into `patch_size` blocks)
//! happens inside the vision tower's `patch_embed` conv, so this
//! module stops at "preprocessed RGB f32 tensor."
use anyhow::{Context, Result, anyhow};
use base64::Engine;
use image::DynamicImage;
use image::imageops::FilterType;
/// Preprocessing target. Captures the resize dimensions and the
/// channel-wise normalisation constants from the model's
/// `preprocessor_config.json`. Stage A ships a single `qwen3_6()`
/// constructor for fixed-resolution Qwen3.6 preprocessing; other
/// models can ship their own profile when added.
#[derive(Debug, Clone)]
pub struct PreprocessProfile {
pub target_height: u32,
pub target_width: u32,
pub image_mean: [f32; 3],
pub image_std: [f32; 3],
}
impl PreprocessProfile {
/// Stage A profile for Qwen3.6. Resize to 448×448, normalise to
/// `[-1, 1]` via mean=std=0.5. Fits within the model's
/// `num_position_embeddings=2304` budget at 28×28 = 784 patches
/// before merging.
pub fn qwen3_6() -> Self {
Self {
target_height: 448,
target_width: 448,
image_mean: [0.5, 0.5, 0.5],
image_std: [0.5, 0.5, 0.5],
}
}
/// Per-channel CHW tensor length: 3 * H * W.
pub fn pixels_chw(&self) -> usize {
3 * (self.target_height as usize) * (self.target_width as usize)
}
}
/// Decode a `data:image/...;base64,...` URI into an in-memory image.
///
/// Accepts the OpenAI Chat Completions `image_url` shape — a string
/// URL with `data:` scheme and base64 payload. The MIME type is read
/// from the URI for diagnostics but `image::load_from_memory` sniffs
/// the format from the bytes themselves, so the MIME is advisory.
///
/// Bare `http(s)://` URLs are explicitly rejected here — fetching
/// them from a vision-model server is a fingerprintable behaviour
/// (server-side request forgery, infinite recursion if the URL
/// points at the gateway itself, etc.). Clients that want remote
/// images can fetch them and pass base64 themselves.
pub fn decode_data_uri(uri: &str) -> Result<DynamicImage> {
let after_scheme = uri
.strip_prefix("data:")
.ok_or_else(|| anyhow!("image_url must use data: scheme; got {uri:.40}…"))?;
let (meta, payload) = after_scheme
.split_once(',')
.ok_or_else(|| anyhow!("malformed data URI: missing ',' separator"))?;
if !meta.contains(";base64") {
anyhow::bail!(
"data URI must use base64 encoding (got '{meta}'); raw URL-encoded payloads not supported"
);
}
let bytes = base64::engine::general_purpose::STANDARD
.decode(payload.trim())
.context("base64-decode image data URI payload")?;
image::load_from_memory(&bytes).context("decode image bytes (PNG/JPEG/WebP/etc)")
}
/// Resize and normalise an image into a `(3, H, W)` row-major
/// `Vec<f32>` ready to hand to the vision tower's `patch_embed`
/// conv.
///
/// Uses bilinear resampling — Qwen2-VL's reference uses bicubic, but
/// bilinear is what the candle ecosystem standardises on and is
/// faster on CPU. Quality difference is marginal for downstream
/// vision-encoder consumption. The numerical-validation issue (#15)
/// will quantify any discrepancy.
pub fn preprocess(img: &DynamicImage, profile: &PreprocessProfile) -> Vec<f32> {
let rgb = img
.resize_exact(
profile.target_width,
profile.target_height,
FilterType::Triangle,
)
.to_rgb8();
let h = profile.target_height as usize;
let w = profile.target_width as usize;
let mut out = vec![0.0_f32; 3 * h * w];
// Row-major (C, H, W). Candle's Conv2d expects NCHW, so this is
// the natural layout — the caller stacks `n` of these along the
// batch axis as needed.
for c in 0..3 {
let mean = profile.image_mean[c];
let std = profile.image_std[c];
for y in 0..h {
for x in 0..w {
let pixel = rgb.get_pixel(x as u32, y as u32);
let raw = pixel[c] as f32 / 255.0;
out[c * h * w + y * w + x] = (raw - mean) / std;
}
}
}
out
}
/// Combined helper: decode + preprocess in one call. Most call
/// sites just want the final tensor; the two-step path exists for
/// callers (tests, future video preprocessing) that need the
/// intermediate `DynamicImage`.
pub fn preprocess_data_uri(uri: &str, profile: &PreprocessProfile) -> Result<Vec<f32>> {
let img = decode_data_uri(uri)?;
Ok(preprocess(&img, profile))
}
#[cfg(test)]
mod tests {
use super::*;
use image::{ImageBuffer, Rgb};
/// A 1×1 red PNG, hand-built. Matches the well-known smallest
/// valid PNG we use in tests/curl examples elsewhere.
const ONE_BY_ONE_RED_PNG_B64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==";
fn red_png_uri() -> String {
format!("data:image/png;base64,{ONE_BY_ONE_RED_PNG_B64}")
}
#[test]
fn decodes_well_formed_png_data_uri() {
let img = decode_data_uri(&red_png_uri()).expect("decode 1x1 png");
assert_eq!(img.width(), 1);
assert_eq!(img.height(), 1);
}
#[test]
fn rejects_non_data_scheme() {
let err = decode_data_uri("https://example.com/cat.jpg")
.expect_err("http(s) URLs must be rejected");
assert!(format!("{err:#}").contains("data:"));
}
#[test]
fn rejects_malformed_uri_without_comma() {
let err = decode_data_uri("data:image/png;base64").unwrap_err();
assert!(format!("{err:#}").contains("','"));
}
#[test]
fn rejects_non_base64_payload() {
let err = decode_data_uri("data:image/png,raw-bytes-here").unwrap_err();
assert!(format!("{err:#}").contains("base64"));
}
#[test]
fn rejects_bad_base64_payload() {
let err = decode_data_uri("data:image/png;base64,not!valid!base64!").unwrap_err();
assert!(format!("{err:#}").contains("base64"));
}
#[test]
fn rejects_garbage_image_bytes() {
// Valid base64 ("Hello World!"), invalid image bytes.
let err = decode_data_uri("data:image/png;base64,SGVsbG8gV29ybGQh").unwrap_err();
assert!(
format!("{err:#}").contains("decode image"),
"should fail at image-decode step"
);
}
#[test]
fn preprocess_red_image_produces_correct_shape_and_values() {
let profile = PreprocessProfile::qwen3_6();
// Build a tiny pure-red image directly, skipping data: URI
// decoding so this test isolates the resize+normalise path.
let img: ImageBuffer<Rgb<u8>, Vec<u8>> = ImageBuffer::from_pixel(2, 2, Rgb([255, 0, 0]));
let dyn_img = DynamicImage::ImageRgb8(img);
let out = preprocess(&dyn_img, &profile);
assert_eq!(out.len(), profile.pixels_chw());
// After mean=0.5, std=0.5: red channel (255/255=1.0) → (1.0 - 0.5)/0.5 = 1.0
// green/blue (0.0) → (0.0 - 0.5)/0.5 = -1.0
let h = profile.target_height as usize;
let w = profile.target_width as usize;
assert!(
(out[0] - 1.0).abs() < 1e-5,
"R[0] should be 1.0, got {}",
out[0]
);
assert!((out[h * w] - (-1.0)).abs() < 1e-5, "G[0] should be -1.0");
assert!(
(out[2 * h * w] - (-1.0)).abs() < 1e-5,
"B[0] should be -1.0"
);
// All values are finite
assert!(out.iter().all(|v| v.is_finite()), "no NaN/Inf in output");
}
#[test]
fn preprocess_data_uri_end_to_end() {
let profile = PreprocessProfile::qwen3_6();
let out = preprocess_data_uri(&red_png_uri(), &profile).expect("e2e preprocess");
assert_eq!(out.len(), profile.pixels_chw());
assert!(out.iter().all(|v| v.is_finite()));
}
#[test]
fn preprocess_grayscale_image_promotes_to_rgb() {
let profile = PreprocessProfile::qwen3_6();
// 1x1 grayscale = 200 → after conversion to RGB, all three
// channels equal 200, normalised → (200/255 - 0.5)/0.5 ≈ 0.569
let gray = DynamicImage::ImageLuma8(ImageBuffer::from_pixel(1, 1, image::Luma([200])));
let out = preprocess(&gray, &profile);
let expected = ((200.0 / 255.0) - 0.5) / 0.5;
let h = profile.target_height as usize;
let w = profile.target_width as usize;
for c in 0..3 {
let v = out[c * h * w];
assert!(
(v - expected).abs() < 1e-3,
"channel {c}: expected {expected}, got {v}"
);
}
}
}

176
doc/vision-qwen3_6-spec.md Normal file
View File

@@ -0,0 +1,176 @@
# Qwen3.6-27B vision specification (Stage A0)
Sourced from beast's local cache on 2026-06-01:
`/archive3/llm-cache/models--Qwen--Qwen3.6-27B/snapshots/6a9e13bd6fc8f0983b9b99948120bc37f49c13e9/`.
Single source of truth for Stages AD of the vision plan in
`~/.claude/plans/foamy-twirling-catmull.md`. Umbrella issue:
[#3](https://git.lair.cafe/helexa/cortex/issues/3).
---
## Top-level shape
The model is a unified text+vision architecture (`Qwen3_5ForConditionalGeneration`,
`model_type: qwen3_5`) with three weight sections under a single safetensors
index. Counts from `model.safetensors.index.json`:
| Prefix | Tensors | Role |
|---|---|---|
| `model.language_model.*` | 850 | LM (currently loaded) |
| `model.visual.*` | 333 | Vision tower (currently filtered out at `arch/qwen3_5/mod.rs:228-230`) |
| `mtp.*` | 15 | Multi-token-prediction heads (filtered, out of scope) |
| `lm_head.weight` | 1 | LM head |
Vision tensors live in shards `model-00007-of-00015.safetensors` and
`model-00008-of-00015.safetensors` (2 of the 15 safetensors). Loading just
these two for vision-tower-only smoke tests is feasible.
## Vision tower architecture (`model.visual.*`)
From `config.json::vision_config`:
```
depth: 27 (transformer blocks)
hidden_size: 1152 (vision token dim)
num_heads: 16 (per-block self-attention)
intermediate_size: 4304 (MLP hidden)
patch_size: 16 (16×16 spatial patches)
temporal_patch_size: 2 (video frame pairing; irrelevant for stills)
spatial_merge_size: 2 (2×2 spatial merge in the merger → 4 patches/LM token)
num_position_embeddings: 2304 (learned pos embed slots — max patch sequence length)
in_channels: 3 (RGB)
hidden_act: gelu_pytorch_tanh (GELU with tanh approximation, not exact GELU)
out_hidden_size: 5120 (= LM hidden_size, merger output dim)
deepstack_visual_indexes: [] (no deep-stack visual indexes)
```
### Module inventory (per-block and global)
Global:
- `model.visual.patch_embed.proj.{weight, bias}` — Conv2d (3 → 1152, kernel 16×16, stride 16). Turns image patches into tokens.
- `model.visual.pos_embed.weight` — Learned position embedding, shape `(2304, 1152)`.
- `model.visual.merger.{norm, linear_fc1, linear_fc2}` — The projector that merges 2×2 patches and projects to LM hidden_size (1152 → 5120). All weights have biases.
Per block (×27, named `model.visual.blocks.{0..26}`):
- `norm1.{weight, bias}`**LayerNorm** before attention (with bias — not RmsNorm).
- `attn.qkv.{weight, bias}` — Fused QKV linear (1152 → 3·1152 = 3456).
- `attn.proj.{weight, bias}` — Attention output projection (1152 → 1152).
- `norm2.{weight, bias}` — LayerNorm before MLP.
- `mlp.linear_fc1.{weight, bias}` — MLP up-projection (1152 → 4304).
- `mlp.linear_fc2.{weight, bias}` — MLP down-projection (4304 → 1152).
Pattern matches a standard ViT block with **pre-norm** layout (norm → attn → residual, norm → MLP → residual). Activation between fc1/fc2 is GELU-tanh-approx per `hidden_act`. No attention masking inside the vision tower (all patches attend to each other).
### Forward signature (target)
```
VisionTower::forward(
patches: Tensor [N, in_channels, patch_size, patch_size], # CPU-preprocessed RGB float patches
grid_thw: Option<(usize, usize, usize)>, # (t, h, w) patch grid for position lookup
) -> Tensor [N / (spatial_merge_size²), out_hidden_size] # = (N/4, 5120) for static images
```
Note: the merger consumes 4 spatially-adjacent patches and emits 1 LM token. So an image producing 64×64 = 4096 patches yields 1024 LM-side image tokens.
## Image preprocessor (`preprocessor_config.json`)
```json
{
"size": { "longest_edge": 16777216, "shortest_edge": 65536 },
"patch_size": 16,
"temporal_patch_size": 2,
"merge_size": 2,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5],
"processor_class": "Qwen3VLProcessor",
"image_processor_type": "Qwen2VLImageProcessorFast"
}
```
Reading:
- `image_mean = image_std = 0.5` → normalisation is simply `(x/255 - 0.5) / 0.5 = 2*x/255 - 1`, mapping `[0,255]``[-1, 1]`. No imagenet-style mean/std.
- `size.{shortest_edge, longest_edge}` are **pixel counts**, not edge lengths. The `Qwen2VLImageProcessorFast` recipe picks a resolution within `[65,536 = 256², 16,777,216 = 4096²]` total pixels, snapping `h` and `w` to multiples of `patch_size × spatial_merge_size = 32` pixels.
- Stage A ships **fixed resolution**: pick a target pixel count (e.g. 448×448 = 200,704 px → 28×28 patches → 14×14 LM tokens after merger). Variable resolution deferred to issue [#14](https://git.lair.cafe/helexa/cortex/issues/14).
## Chat template (`chat_template.jinja`)
Image insertion (lines 818 of the template):
```jinja
{%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
...
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
```
Per image, the template emits **one `<|image_pad|>` token** flanked by `<|vision_start|>` and `<|vision_end|>` sentinels. The runtime must:
1. Render the template (preserving the single `<|image_pad|>` per image).
2. For each image, replace its single `<|image_pad|>` with N copies, where N is the number of LM tokens that image produces after the vision tower + merger (= `patches / spatial_merge_size²`).
3. Tokenize the expanded string → `input_ids`.
4. At forward time, locate positions where `input_ids == image_token_id` (248056) and splice in the vision tower's merger output.
Token IDs (top of `config.json`):
- `vision_start_token_id`: 248053
- `vision_end_token_id`: 248054
- `image_token_id`: 248056
- `video_token_id`: 248057 (out of scope)
- `bos_token_id`: 248044
- `eos_token_id`: 248044, 248046 (per `generation_config.json`)
System messages cannot contain images (template raises). Other template-side details:
- `add_vision_id` (jinja arg, default false): emits `'Picture N: '` prefixes when true.
- `preserve_thinking` (jinja arg, default false): keeps `<think>` blocks from prior assistant turns in the rendered prompt.
- `enable_thinking` (jinja arg, default true): emits `<think>\n` (or skips it) at the end of the generation prompt.
The existing chat-template renderer in `crates/neuron/src/harness/chat_template.rs` already passes `MessageContent::Parts` to the Jinja context as a `Value::Array`; the template's `is iterable` branch (line 6 of the template) handles them. **The path is structurally in place** — Stage B just needs to do the `<|image_pad|>` expansion + token-position-aware splice.
## LM-side considerations
The LM's RoPE config uses **multi-axis RoPE (MRoPE)**:
```
rope_parameters: {
mrope_interleaved: true,
mrope_section: [11, 11, 10], # text + height + width components
partial_rotary_factor: 0.25,
rope_theta: 10000000,
rope_type: "default"
}
```
MRoPE encodes spatial position alongside text position so the LM attention layers can reason about image-token spatial structure. The LM's existing forward path *may or may not* already implement this — the qwen3_5 module's doc-comment notes "numerical correctness vs the reference Python is not yet validated." Verifying MRoPE behaviour in the language model is out of Stage A scope (vision tower only) but will be required in Stage B (LM splice) and is tracked under the numerical-validation issue [#15](https://git.lair.cafe/helexa/cortex/issues/15).
`max_position_embeddings = 262144` (256 K context), so context-length limits are not a constraint for vision.
## Iteration target decision
The vision tower has its own self-contained weight tree and is small (~333 tensors in 2 shards, hidden_size 1152 vs LM's 5120). For Stage A specifically (vision-tower-only smoke), we **don't need a smaller iteration model** — we can:
- Build the Rust `VisionTower` struct against the spec above.
- Run unit tests with random tensor weights matching the exact shapes → assert forward produces correct output shape with finite values.
- Optionally: a CUDA-integration test that loads just the 2 vision shards from beast's cache (or on a smaller GPU like quadbrat's Ampere) and runs encode on a real image. Doesn't require loading the 27B LM at all.
This sidesteps the "develop against a smaller VL model" question for Stage A. Stage B (LM splice → end-to-end chat with vision) is where iteration speed becomes pressing; revisit there. The default scope pick 2a (smaller iteration model) is therefore deferred to Stage B planning — issue [#13](https://git.lair.cafe/helexa/cortex/issues/13) covers deployment validation regardless.
## Concrete Stage A1+ inputs
- Add deps to `crates/neuron/Cargo.toml`:
- `image = "0.25"`
- `base64 = "0.22"`
- Stage A2 preprocessor target resolution (fixed): **448×448 → 28×28 patches → 14×14 = 196 image tokens per image**. This balances minimum-patch-count for cheap tests against the model's expected input range.
- Stage A3 module structure: one `VisionTower` struct holding `patch_embed: Conv2d`, `pos_embed: Embedding`, `blocks: Vec<VisionBlock>`, `merger: Merger`. `VisionBlock` carries `norm1`, `norm2`, `attn`, `mlp`. Hand-roll using candle primitives.
- Stage A4 weight loading: extend `Qwen3_5ForCausalLM::new()` to construct `Some(VisionTower::new(vb.pp("model.visual"), config))` when `vision_config` is present in the parsed config.
- Stage A5 worker job: `Job::EncodeImage { handle, patches: Vec<f32>, patch_shape: (usize, usize, usize, usize, usize), reply: oneshot<Result<Vec<f32>>> }`. Patch shape = `(N, C, T, H, W)` where T=1 for static images.
## What this doc does NOT settle (deferred to issues)
- Numerical correctness of `VisionTower` output vs Python transformers
→ issue [#15](https://git.lair.cafe/helexa/cortex/issues/15).
- Variable image resolution
→ issue [#14](https://git.lair.cafe/helexa/cortex/issues/14).
- TP-vision (multi-rank vision tower)
→ issue [#12](https://git.lair.cafe/helexa/cortex/issues/12).
- 27B production deployment
→ issue [#13](https://git.lair.cafe/helexa/cortex/issues/13).