feat(neuron): Stage A — vision tower load + preprocessor for Qwen3.6
All checks were successful
CI / CUDA type-check (push) Successful in 32s
build-prerelease / Resolve version stamps (push) Successful in 30s
CI / Format (push) Successful in 28s
CI / Clippy (push) Successful in 2m35s
build-prerelease / Build cortex binary (push) Successful in 5m13s
build-prerelease / Build neuron-blackwell (push) Successful in 6m23s
build-prerelease / Build neuron-ampere (push) Successful in 7m56s
CI / Test (push) Successful in 7m11s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m19s
build-prerelease / Build neuron-ada (push) Successful in 5m30s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m56s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m45s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 4m25s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
All checks were successful
CI / CUDA type-check (push) Successful in 32s
build-prerelease / Resolve version stamps (push) Successful in 30s
CI / Format (push) Successful in 28s
CI / Clippy (push) Successful in 2m35s
build-prerelease / Build cortex binary (push) Successful in 5m13s
build-prerelease / Build neuron-blackwell (push) Successful in 6m23s
build-prerelease / Build neuron-ampere (push) Successful in 7m56s
CI / Test (push) Successful in 7m11s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m19s
build-prerelease / Build neuron-ada (push) Successful in 5m30s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m56s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m45s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 4m25s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
Stage A of the vision implementation plan (doc/vision-qwen3_6-spec.md). Builds the vision tower scaffolding that today's silent-drop failure mode (issue #3) needs — the Qwen3.6 ViT loads from `model.visual.*`, runs forward producing post-merger LM-side image embeddings, and routes through the device worker via a new `Job::EncodeImage`. No LM splice yet — that's Stage B. Refs #3 (umbrella). Deferred sub-stages tracked as #12 (TP-vision), #13 (27B production deploy), #14 (dynamic resolution), #15 (numerical validation). What landed: - **A0 — investigation**: pulled config.json, preprocessor_config.json, chat_template.jinja, and safetensors index from beast's local Qwen3.6-27B cache. Documented in doc/vision-qwen3_6-spec.md with exact tensor shapes for every `model.visual.*` weight. Confirms 27-block ViT with `hidden_size=1152`, `patch_size=16`, `spatial_merge_size=2`, `out_hidden_size=5120`. Vision tower lives in 2 of the 15 safetensors shards. - **A1 — deps + scaffolding**: added `image = "0.25"` (default- features off, PNG/JPEG/WebP/BMP/GIF) and `base64 = "0.22"` to crates/neuron/Cargo.toml. Created `harness::preprocess` and `harness::arch::qwen3_5::vision` modules. - **A2 — preprocess.rs**: `decode_data_uri` strips `data:image/...;base64,...` → image bytes → `image::DynamicImage` (rejecting `http(s)://` URLs to avoid SSRF/recursion); `preprocess` resizes to a fixed `PreprocessProfile::qwen3_6()` (448×448), normalises to `[-1, 1]` per the model's mean/std=0.5, emits row-major `(3, H, W)` f32. 9 unit tests covering data URI parse, decode failure paths, grayscale-to-RGB promotion, and the exact-value normalisation contract. - **A3 — vision.rs**: `VisionTower` struct with `patch_embed: Conv2d`, learned `pos_embed: Embedding`, 27 `VisionBlock`s (pre-LN + multi-head self-attention with fused QKV + GELU-tanh MLP + residuals), and `VisionMerger` (LayerNorm → 2×2 spatial concat → linear_fc1 → GELU-tanh → linear_fc2 to LM hidden_size). Includes the Conv3d→Conv2d fold trick documented at the top of the file — the published patch_embed.proj.weight is 5D `(1152, 3, 2, 16, 16)` but candle 0.10 has no Conv3d; for static images we sum-collapse the temporal axis. Video would need real Conv3d. 5 unit tests including the exact `gelu_pytorch_tanh` reference values from PyTorch. - **A4 — wire vision into Qwen3_5ForCausalLM**: extended `Config` with optional `vision_config: Option<VisionConfig>` and `image_token_id`; `Qwen3_5ForCausalLM::new` now loads the vision tower when present, exposes `has_vision()` and `vision()` so the HTTP layer can advertise capability and so the encode path can reach it. - **A5 — device worker `Job::EncodeImage`**: new job variant carrying CPU-side `(C, H, W)` pixels. Dispatch handler reconstructs the tensor on the worker's device, calls `arch.encode_image(image)`, copies the result back to CPU as flat `Vec<f32>`. Keeps the "tensors don't escape the worker" invariant. Poisoned-worker drain path handles the new variant. - **A6 — dispatch round-trip test**: `encode_image_routes_to_dispatch_ and_errors_on_unknown_handle` proves the channel/dispatch wiring works end-to-end via the CPU device worker (errors on unknown ArchHandle, which is the expected behaviour without a loaded model — real-weights validation happens in Stage B when the LM splice path exists). CI gate: cargo fmt --check, cargo clippy --workspace --all-targets -- -D warnings, cargo test --workspace (all 28 test groups ok, zero failures). New test counts: +9 in preprocess, +5 in vision, +1 in device_worker. Out of scope (deferred): - LM-side splice of image embeddings at `<|image_pad|>` positions → Stage B. - Streaming SSE for vision-bearing chat completions → Stage C. - Reject `image_url` with HTTP 400 for non-vision models / advertise `capabilities` in /v1/models → Stage C. - TP-vision (#12), 27B production deploy (#13), dynamic resolution (#14), numerical validation (#15). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -78,6 +78,7 @@ pub mod linear_attn;
|
||||
pub mod mlp;
|
||||
pub mod rmsnorm;
|
||||
pub mod rope;
|
||||
pub mod vision;
|
||||
|
||||
use decoder::Qwen3_5DecoderLayer;
|
||||
use rmsnorm::Qwen3_5RmsNorm;
|
||||
@@ -99,6 +100,20 @@ pub struct Config {
|
||||
pub model_type: String,
|
||||
/// The text-side hyperparameters. Everything we actually need.
|
||||
pub text_config: TextConfig,
|
||||
/// Vision tower hyperparameters. Present on multimodal
|
||||
/// checkpoints (e.g. Qwen/Qwen3.6-27B); absent on text-only
|
||||
/// variants. When present, `Qwen3_5ForCausalLM::new` loads the
|
||||
/// vision tower alongside the language model so vision-bearing
|
||||
/// requests can splice image embeddings at `<|image_pad|>` token
|
||||
/// positions.
|
||||
#[serde(default)]
|
||||
pub vision_config: Option<vision::VisionConfig>,
|
||||
/// Token id the chat template emits per image patch group.
|
||||
/// Mirrors the LM tokenizer's `<|image_pad|>` id (248056 for
|
||||
/// Qwen3.6). The runtime locates these in the prompt and splices
|
||||
/// in `VisionTower::forward` output. `None` for text-only models.
|
||||
#[serde(default)]
|
||||
pub image_token_id: Option<u32>,
|
||||
}
|
||||
|
||||
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
|
||||
@@ -309,6 +324,15 @@ impl Qwen3_5Model {
|
||||
pub struct Qwen3_5ForCausalLM {
|
||||
base: Qwen3_5Model,
|
||||
lm_head: Linear,
|
||||
/// Vision tower (Stage A4). `None` for text-only checkpoints or
|
||||
/// when the operator has opted out. When present, the harness's
|
||||
/// `Job::EncodeImage` dispatch path runs `vision.forward(image)`
|
||||
/// and the LM forward (Stage B) splices the result at
|
||||
/// `image_token_id` positions in the input embedding stream.
|
||||
vision: Option<vision::VisionTower>,
|
||||
/// Mirrors `Config::image_token_id`. Cached here so the runtime
|
||||
/// doesn't have to round-trip through the parsed config struct.
|
||||
image_token_id: Option<u32>,
|
||||
}
|
||||
|
||||
impl Qwen3_5ForCausalLM {
|
||||
@@ -324,7 +348,52 @@ impl Qwen3_5ForCausalLM {
|
||||
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
|
||||
Linear::new(weight, None)
|
||||
};
|
||||
Ok(Self { base, lm_head })
|
||||
// Stage A4: load the vision tower when the config carries a
|
||||
// `vision_config` block and the safetensors actually carry
|
||||
// `model.visual.*` weights. The `Option<VisionConfig>` on the
|
||||
// config makes this a single-source-of-truth decision —
|
||||
// text-only checkpoints just leave `vision_config` unset and
|
||||
// get `None` here without any extra plumbing.
|
||||
let vision = if let Some(vcfg) = config.vision_config.clone() {
|
||||
tracing::info!(
|
||||
depth = vcfg.depth,
|
||||
hidden_size = vcfg.hidden_size,
|
||||
"loading qwen3_5 vision tower"
|
||||
);
|
||||
Some(
|
||||
vision::VisionTower::load(vcfg, vb.pp("model.visual"))
|
||||
.context("load qwen3_5 vision tower (model.visual.*)")?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
base,
|
||||
lm_head,
|
||||
vision,
|
||||
image_token_id: config.image_token_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// True when this checkpoint loaded a vision tower. Used by the
|
||||
/// HTTP layer to advertise vision capability in `/v1/models` and
|
||||
/// to reject image-bearing requests against text-only loads with
|
||||
/// a clean 400.
|
||||
pub fn has_vision(&self) -> bool {
|
||||
self.vision.is_some()
|
||||
}
|
||||
|
||||
/// Vision tower handle, if loaded. The device-worker
|
||||
/// `EncodeImage` job dispatches to `vision.forward(image)`.
|
||||
pub fn vision(&self) -> Option<&vision::VisionTower> {
|
||||
self.vision.as_ref()
|
||||
}
|
||||
|
||||
/// `<|image_pad|>` token id from `config.json`, when known.
|
||||
/// The Stage B prompt-builder uses this to count expansion targets
|
||||
/// and the LM forward uses it to locate splice positions.
|
||||
pub fn image_token_id(&self) -> Option<u32> {
|
||||
self.image_token_id
|
||||
}
|
||||
|
||||
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
|
||||
|
||||
600
crates/neuron/src/harness/arch/qwen3_5/vision.rs
Normal file
600
crates/neuron/src/harness/arch/qwen3_5/vision.rs
Normal file
@@ -0,0 +1,600 @@
|
||||
//! Qwen3.6 vision tower.
|
||||
//!
|
||||
//! 27 pre-norm ViT blocks with **LayerNorm** (with biases — not the
|
||||
//! `(1+w)·x` RmsNorm the language model uses), fused QKV attention,
|
||||
//! GELU-tanh MLP. Followed by a `merger` that LayerNorms each
|
||||
//! 1152-dim vision token, spatially 2×2-merges them into 4608-dim
|
||||
//! groups, and projects to the LM's 5120-dim hidden via
|
||||
//! `linear_fc1 → GELU → linear_fc2`.
|
||||
//!
|
||||
//! Architecture spec sourced from beast's cached Qwen3.6-27B
|
||||
//! safetensors header (Stage A0, see
|
||||
//! `doc/vision-qwen3_6-spec.md`). All weight shapes confirmed
|
||||
//! from the live `.safetensors` headers, not inferred.
|
||||
//!
|
||||
//! **Conv3d wrinkle.** The published `patch_embed.proj.weight` is 5D
|
||||
//! `[1152, 3, 2, 16, 16]` — a 3D conv with kernel
|
||||
//! `(t=2, h=16, w=16)`. Candle 0.10 has no Conv3d. For static images
|
||||
//! we get away with a trick: when the temporal patch size is 2 and we
|
||||
//! duplicate the still image along the temporal axis (`T = 2`,
|
||||
//! frame_0 == frame_1), the Conv3d output equals a Conv2d run with
|
||||
//! the *sum* of the two temporal weight slices:
|
||||
//!
|
||||
//! ```text
|
||||
//! output = W_0 · frame_0 + W_1 · frame_1 + bias
|
||||
//! = (W_0 + W_1) · frame + bias (static image)
|
||||
//! ```
|
||||
//!
|
||||
//! So at load we sum-collapse the temporal axis and use a 4D
|
||||
//! `Conv2d` kernel. Video support would have to do the real Conv3d
|
||||
//! (different frames mean the trick fails) — tracked alongside the
|
||||
//! dynamic-resolution work in issue #14.
|
||||
//!
|
||||
//! Forward signature (Stage A — no LM splice yet):
|
||||
//!
|
||||
//! ```text
|
||||
//! fn forward(&self, image: &Tensor) -> Result<Tensor>
|
||||
//! ```
|
||||
//!
|
||||
//! `image` is `(3, H, W)` f32, normalised by `preprocess::preprocess`.
|
||||
//! Returns `(N_lm_tokens, out_hidden_size)` post-merger tokens ready
|
||||
//! to splice into the LM's input embeddings at `<|image_pad|>`
|
||||
//! positions. For Qwen3.6 at 448×448 → 28×28 patches → 14×14 = 196
|
||||
//! LM tokens of dim 5120.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{D, DType, Device, IndexOp, Module, Tensor};
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear};
|
||||
use serde::Deserialize;
|
||||
|
||||
/// Qwen3.6 vision tower hyperparameters. Mirrors the `vision_config`
|
||||
/// block of `config.json`. Only the fields we actually need are
|
||||
/// captured; serde tolerates the rest.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct VisionConfig {
|
||||
/// Number of ViT blocks (`depth: 27` for Qwen3.6).
|
||||
pub depth: usize,
|
||||
/// Vision-token dimension throughout the tower (1152 for Qwen3.6).
|
||||
pub hidden_size: usize,
|
||||
/// MLP intermediate dim (4304).
|
||||
pub intermediate_size: usize,
|
||||
/// Attention head count (16). `head_dim = hidden_size / num_heads`.
|
||||
pub num_heads: usize,
|
||||
/// Number of slots in the learned position embedding (2304).
|
||||
/// Caps the maximum image patch count.
|
||||
pub num_position_embeddings: usize,
|
||||
/// Spatial patch edge in pixels (16).
|
||||
pub patch_size: usize,
|
||||
/// Temporal kernel depth in the patch embed (2 for Qwen3.6 — we
|
||||
/// collapse this into a single Conv2d for static-image inference;
|
||||
/// see the module-level Conv3d wrinkle).
|
||||
pub temporal_patch_size: usize,
|
||||
/// Patches grouped per LM token by the merger (2 → 2×2 = 4
|
||||
/// patches per LM token).
|
||||
pub spatial_merge_size: usize,
|
||||
/// Vision input channels (3, RGB).
|
||||
pub in_channels: usize,
|
||||
/// Merger output dim — matches the LM's `hidden_size` (5120 for
|
||||
/// Qwen3.6). The merger projects from vision dim → LM dim.
|
||||
pub out_hidden_size: usize,
|
||||
}
|
||||
|
||||
const LAYER_NORM_EPS: f64 = 1e-6;
|
||||
/// Number of LM tokens emitted by the merger per vision-token group.
|
||||
const LM_TOKENS_PER_MERGE_GROUP: usize = 1;
|
||||
|
||||
/// One ViT block: pre-LN → attn → residual; pre-LN → MLP → residual.
|
||||
struct VisionBlock {
|
||||
norm1: LayerNorm,
|
||||
qkv: Linear,
|
||||
proj: Linear,
|
||||
norm2: LayerNorm,
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
}
|
||||
|
||||
impl VisionBlock {
|
||||
fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let head_dim = h / cfg.num_heads;
|
||||
let norm1 = layer_norm(vb.pp("norm1"), h)?;
|
||||
let qkv = linear(vb.pp("attn.qkv"), h, 3 * h)?;
|
||||
let proj = linear(vb.pp("attn.proj"), h, h)?;
|
||||
let norm2 = layer_norm(vb.pp("norm2"), h)?;
|
||||
let fc1 = linear(vb.pp("mlp.linear_fc1"), h, cfg.intermediate_size)?;
|
||||
let fc2 = linear(vb.pp("mlp.linear_fc2"), cfg.intermediate_size, h)?;
|
||||
Ok(Self {
|
||||
norm1,
|
||||
qkv,
|
||||
proj,
|
||||
norm2,
|
||||
fc1,
|
||||
fc2,
|
||||
num_heads: cfg.num_heads,
|
||||
head_dim,
|
||||
})
|
||||
}
|
||||
|
||||
/// `x`: `(N, hidden_size)` un-batched. Returns same shape.
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let attn_in = self.norm1.forward(x)?;
|
||||
let attn_out = self.attention(&attn_in)?;
|
||||
let x = x.add(&attn_out)?;
|
||||
let mlp_in = self.norm2.forward(&x)?;
|
||||
let mlp_out = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&mlp_in)?)?)?;
|
||||
x.add(&mlp_out).map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Multi-head self-attention over the patch sequence. No causal
|
||||
/// mask — every patch attends to every other patch.
|
||||
fn attention(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (n, hidden) = x.dims2()?;
|
||||
// qkv: (N, 3*hidden). Split into Q, K, V each (N, hidden).
|
||||
let qkv = self.qkv.forward(x)?;
|
||||
let qkv = qkv.reshape((n, 3, self.num_heads, self.head_dim))?;
|
||||
// Transpose to (3, num_heads, N, head_dim) for per-head views.
|
||||
let qkv = qkv.permute((1, 2, 0, 3))?.contiguous()?;
|
||||
let q = qkv.i(0)?;
|
||||
let k = qkv.i(1)?;
|
||||
let v = qkv.i(2)?;
|
||||
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||
// (num_heads, N, head_dim) @ (num_heads, head_dim, N) -> (num_heads, N, N)
|
||||
let scores = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
|
||||
let scores = (scores * scale)?;
|
||||
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||
// (num_heads, N, N) @ (num_heads, N, head_dim) -> (num_heads, N, head_dim)
|
||||
let out = probs.matmul(&v)?;
|
||||
// Merge heads back: (N, num_heads, head_dim) -> (N, hidden).
|
||||
let out = out.permute((1, 0, 2))?.contiguous()?.reshape((n, hidden))?;
|
||||
self.proj.forward(&out).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
/// `merger`: LayerNorm per token → spatial 2×2 merge (concat 4
|
||||
/// adjacent tokens into one 4608-dim vector) → fc1 → GELU-tanh →
|
||||
/// fc2. Output dim is the LM's hidden_size.
|
||||
struct VisionMerger {
|
||||
norm: LayerNorm,
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
merge_input_dim: usize,
|
||||
spatial_merge_size: usize,
|
||||
}
|
||||
|
||||
impl VisionMerger {
|
||||
fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let merge = cfg.spatial_merge_size;
|
||||
let merge_input_dim = h * merge * merge;
|
||||
let norm = layer_norm(vb.pp("norm"), h)?;
|
||||
let fc1 = linear(vb.pp("linear_fc1"), merge_input_dim, merge_input_dim)?;
|
||||
let fc2 = linear(vb.pp("linear_fc2"), merge_input_dim, cfg.out_hidden_size)?;
|
||||
Ok(Self {
|
||||
norm,
|
||||
fc1,
|
||||
fc2,
|
||||
merge_input_dim,
|
||||
spatial_merge_size: merge,
|
||||
})
|
||||
}
|
||||
|
||||
/// `tokens`: `(grid_h, grid_w, hidden_size)`. The merger reshapes
|
||||
/// each `merge×merge` block of adjacent patches into a single
|
||||
/// concatenated vector, then projects.
|
||||
///
|
||||
/// `grid_h` and `grid_w` must both be multiples of
|
||||
/// `spatial_merge_size`. Returns
|
||||
/// `(grid_h/merge × grid_w/merge, out_hidden_size)`.
|
||||
fn forward(&self, tokens: &Tensor) -> Result<Tensor> {
|
||||
let (gh, gw, h) = tokens.dims3()?;
|
||||
let m = self.spatial_merge_size;
|
||||
anyhow::ensure!(
|
||||
gh.is_multiple_of(m) && gw.is_multiple_of(m),
|
||||
"merger expects spatial dims divisible by merge_size={m}; got ({gh}, {gw})"
|
||||
);
|
||||
let tokens = self.norm.forward(tokens)?;
|
||||
// (gh, gw, h) -> (gh/m, m, gw/m, m, h) -> (gh/m, gw/m, m, m, h)
|
||||
// -> flatten last three -> (gh/m, gw/m, m*m*h) -> (N_lm, merge_input_dim)
|
||||
let out_h = gh / m;
|
||||
let out_w = gw / m;
|
||||
let merged = tokens
|
||||
.reshape((out_h, m, out_w, m, h))?
|
||||
.permute((0, 2, 1, 3, 4))?
|
||||
.contiguous()?
|
||||
.reshape((out_h * out_w, self.merge_input_dim))?;
|
||||
let hidden = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&merged)?)?)?;
|
||||
Ok(hidden)
|
||||
}
|
||||
}
|
||||
|
||||
/// The vision tower itself.
|
||||
pub struct VisionTower {
|
||||
/// Sum-collapsed temporal kernel (Conv2d, see module doc).
|
||||
patch_embed: Conv2d,
|
||||
pos_embed: Embedding,
|
||||
blocks: Vec<VisionBlock>,
|
||||
merger: VisionMerger,
|
||||
config: VisionConfig,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl VisionTower {
|
||||
/// Load from a `ShardedVarBuilder` rooted at the safetensors
|
||||
/// `model.visual.` prefix. Caller is responsible for the `pp` —
|
||||
/// see `Qwen3_5ForCausalLM::new` (Stage A4).
|
||||
pub fn load(cfg: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
|
||||
let dtype = vb.dtype();
|
||||
let device = vb.device().clone();
|
||||
|
||||
// patch_embed.proj is published as 5D Conv3d weight; we
|
||||
// sum-collapse the temporal axis (size = temporal_patch_size)
|
||||
// to get a 4D Conv2d kernel. This is exact for the static-
|
||||
// image case where T = temporal_patch_size frames are
|
||||
// identical (i.e. the input was duplicated along T).
|
||||
let raw_weight = vb
|
||||
.pp("patch_embed.proj")
|
||||
.get(
|
||||
(
|
||||
cfg.hidden_size,
|
||||
cfg.in_channels,
|
||||
cfg.temporal_patch_size,
|
||||
cfg.patch_size,
|
||||
cfg.patch_size,
|
||||
),
|
||||
"weight",
|
||||
)
|
||||
.context("load model.visual.patch_embed.proj.weight (5D Conv3d kernel)")?;
|
||||
// Sum along the temporal axis (dim 2) — see module doc-comment.
|
||||
let folded = raw_weight.sum(2)?; // -> (hidden, in_channels, patch, patch)
|
||||
let proj_bias = vb
|
||||
.pp("patch_embed.proj")
|
||||
.get(cfg.hidden_size, "bias")
|
||||
.context("load model.visual.patch_embed.proj.bias")?;
|
||||
let conv_cfg = Conv2dConfig {
|
||||
stride: cfg.patch_size,
|
||||
..Default::default()
|
||||
};
|
||||
let patch_embed = Conv2d::new(folded, Some(proj_bias), conv_cfg);
|
||||
|
||||
let pos_embed_weight = vb
|
||||
.pp("pos_embed")
|
||||
.get((cfg.num_position_embeddings, cfg.hidden_size), "weight")
|
||||
.context("load model.visual.pos_embed.weight")?;
|
||||
let pos_embed = Embedding::new(pos_embed_weight, cfg.hidden_size);
|
||||
|
||||
let blocks_vb = vb.pp("blocks");
|
||||
let mut blocks = Vec::with_capacity(cfg.depth);
|
||||
for i in 0..cfg.depth {
|
||||
blocks.push(
|
||||
VisionBlock::load(&cfg, &blocks_vb.pp(i))
|
||||
.with_context(|| format!("load vision block {i}"))?,
|
||||
);
|
||||
}
|
||||
let merger = VisionMerger::load(&cfg, &vb.pp("merger")).context("load vision merger")?;
|
||||
|
||||
Ok(Self {
|
||||
patch_embed,
|
||||
pos_embed,
|
||||
blocks,
|
||||
merger,
|
||||
config: cfg,
|
||||
dtype,
|
||||
device,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &VisionConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Number of LM tokens this tower emits for an `(H, W)` pixel
|
||||
/// image after the merger. Equal to
|
||||
/// `(H / patch_size / spatial_merge_size) * (W / patch_size / spatial_merge_size)`.
|
||||
pub fn lm_tokens_for(&self, h: u32, w: u32) -> usize {
|
||||
let m = self.config.spatial_merge_size;
|
||||
let patch = self.config.patch_size;
|
||||
let gh = (h as usize) / patch / m;
|
||||
let gw = (w as usize) / patch / m;
|
||||
gh * gw * LM_TOKENS_PER_MERGE_GROUP
|
||||
}
|
||||
|
||||
/// Encode one image.
|
||||
///
|
||||
/// `image`: row-major `(3, H, W)` f32 tensor on `self.device`,
|
||||
/// already normalised by `preprocess::preprocess`. Both `H` and
|
||||
/// `W` must be multiples of `patch_size * spatial_merge_size`.
|
||||
///
|
||||
/// Returns `(N_lm, out_hidden_size)` — LM-side image tokens
|
||||
/// ready to splice into the language model's input embeddings.
|
||||
pub fn forward(&self, image: &Tensor) -> Result<Tensor> {
|
||||
let (c, h, w) = image.dims3()?;
|
||||
anyhow::ensure!(
|
||||
c == self.config.in_channels,
|
||||
"image must have {} channels, got {c}",
|
||||
self.config.in_channels
|
||||
);
|
||||
let patch = self.config.patch_size;
|
||||
anyhow::ensure!(
|
||||
h.is_multiple_of(patch) && w.is_multiple_of(patch),
|
||||
"image dims must be multiples of patch_size={patch}; got ({h}, {w})"
|
||||
);
|
||||
let gh = h / patch;
|
||||
let gw = w / patch;
|
||||
let n_patches = gh * gw;
|
||||
anyhow::ensure!(
|
||||
n_patches <= self.config.num_position_embeddings,
|
||||
"patch count {n_patches} exceeds pos_embed budget {}",
|
||||
self.config.num_position_embeddings
|
||||
);
|
||||
|
||||
// Add batch axis for conv: (1, 3, H, W) → (1, hidden, gh, gw)
|
||||
// → (hidden, gh, gw) → permute to (gh, gw, hidden) → flatten to (N, hidden)
|
||||
let x = image.unsqueeze(0)?.to_dtype(self.dtype)?;
|
||||
let x = self.patch_embed.forward(&x)?;
|
||||
let x = x.squeeze(0)?;
|
||||
let x = x.permute((1, 2, 0))?.contiguous()?;
|
||||
let x = x.reshape((n_patches, self.config.hidden_size))?;
|
||||
|
||||
// Add learned positional embeddings (sequential indices for
|
||||
// Stage A's fixed-resolution path; full 2D positional logic
|
||||
// lands with variable resolution, issue #14).
|
||||
let positions = Tensor::arange(0u32, n_patches as u32, &self.device)?;
|
||||
let pos = self.pos_embed.forward(&positions)?;
|
||||
let mut x = x.add(&pos)?;
|
||||
|
||||
for (i, block) in self.blocks.iter().enumerate() {
|
||||
x = block
|
||||
.forward(&x)
|
||||
.with_context(|| format!("vision block {i}"))?;
|
||||
}
|
||||
|
||||
// (n_patches, hidden) → (gh, gw, hidden) for the merger.
|
||||
let x = x.reshape((gh, gw, self.config.hidden_size))?;
|
||||
self.merger.forward(&x)
|
||||
}
|
||||
}
|
||||
|
||||
/// Manually load a candle_nn LayerNorm from a ShardedVarBuilder.
|
||||
/// candle_nn's `layer_norm` builder takes `crate::VarBuilder`, not
|
||||
/// `ShardedVarBuilder`, so the existing arch modules in this crate
|
||||
/// uniformly do the manual load + struct construction pattern (see
|
||||
/// `full_attn::load_linear_no_bias`). We follow suit here.
|
||||
fn layer_norm(vb: ShardedVarBuilder, size: usize) -> Result<LayerNorm> {
|
||||
let weight = vb
|
||||
.get(size, "weight")
|
||||
.with_context(|| format!("load LayerNorm.weight at '{}'", vb.prefix()))?;
|
||||
let bias = vb
|
||||
.get(size, "bias")
|
||||
.with_context(|| format!("load LayerNorm.bias at '{}'", vb.prefix()))?;
|
||||
Ok(LayerNorm::new(weight, bias, LAYER_NORM_EPS))
|
||||
}
|
||||
|
||||
/// Manually load a candle_nn Linear (with bias) from a
|
||||
/// ShardedVarBuilder. Same rationale as `layer_norm` above.
|
||||
fn linear(vb: ShardedVarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
|
||||
let weight = vb
|
||||
.get((out_dim, in_dim), "weight")
|
||||
.with_context(|| format!("load Linear.weight at '{}'", vb.prefix()))?;
|
||||
let bias = vb
|
||||
.get(out_dim, "bias")
|
||||
.with_context(|| format!("load Linear.bias at '{}'", vb.prefix()))?;
|
||||
Ok(Linear::new(weight, Some(bias)))
|
||||
}
|
||||
|
||||
/// PyTorch's `gelu_pytorch_tanh` approximation — what the Qwen3.6
|
||||
/// vision tower's `hidden_act` specifies. candle's `Tensor::gelu`
|
||||
/// uses the exact erf-based GELU, so we compute the tanh
|
||||
/// approximation explicitly:
|
||||
///
|
||||
/// ```text
|
||||
/// gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
/// ```
|
||||
fn gelu_tanh(x: &Tensor) -> Result<Tensor> {
|
||||
// sqrt(2 / pi) = 0.7978845608028654
|
||||
const COEFF: f64 = 0.7978845608028654;
|
||||
const KAPPA: f64 = 0.044715;
|
||||
let x3 = x.powf(3.0)?;
|
||||
let inner = (x + (x3 * KAPPA)?)?;
|
||||
let inner = (inner * COEFF)?;
|
||||
let t = inner.tanh()?;
|
||||
let one_plus_t = (t + 1.0)?;
|
||||
let out = (x * 0.5)?;
|
||||
let out = out.broadcast_mul(&one_plus_t)?;
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use candle_core::{DType, Device};
|
||||
|
||||
/// Build a tiny VisionConfig usable on CPU with random weights.
|
||||
/// Match the Qwen3.6 shape relations (depth-N stack, hidden mod
|
||||
/// num_heads, intermediate_size > hidden_size) but with small
|
||||
/// dims so tests run in milliseconds.
|
||||
fn tiny_config() -> VisionConfig {
|
||||
VisionConfig {
|
||||
depth: 2,
|
||||
hidden_size: 32,
|
||||
intermediate_size: 64,
|
||||
num_heads: 4,
|
||||
num_position_embeddings: 64,
|
||||
patch_size: 4,
|
||||
temporal_patch_size: 2,
|
||||
spatial_merge_size: 2,
|
||||
in_channels: 3,
|
||||
out_hidden_size: 48,
|
||||
}
|
||||
}
|
||||
|
||||
/// Hand-construct a VisionTower with random weights. This is the
|
||||
/// same trick `linear_attn::tests::forward_smoke_with_tiny_dimensions`
|
||||
/// uses — bypass the safetensors-backed `ShardedVarBuilder` path
|
||||
/// (which can't be built from in-memory tensors) and assemble the
|
||||
/// struct fields directly. The real `VisionTower::load` is
|
||||
/// exercised by the cuda-integration smoke test in Stage A6.
|
||||
fn tiny_tower(cfg: &VisionConfig) -> VisionTower {
|
||||
let device = Device::Cpu;
|
||||
let dtype = DType::F32;
|
||||
let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &device).unwrap();
|
||||
let ones = |shape: &[usize]| Tensor::ones(shape, dtype, &device).unwrap();
|
||||
let randn = |shape: &[usize]| Tensor::randn(0_f32, 0.02, shape, &device).unwrap();
|
||||
|
||||
let patch_embed = Conv2d::new(
|
||||
randn(&[
|
||||
cfg.hidden_size,
|
||||
cfg.in_channels,
|
||||
cfg.patch_size,
|
||||
cfg.patch_size,
|
||||
]),
|
||||
Some(zeros(&[cfg.hidden_size])),
|
||||
Conv2dConfig {
|
||||
stride: cfg.patch_size,
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
let pos_embed = Embedding::new(
|
||||
randn(&[cfg.num_position_embeddings, cfg.hidden_size]),
|
||||
cfg.hidden_size,
|
||||
);
|
||||
|
||||
let mut blocks = Vec::with_capacity(cfg.depth);
|
||||
for _ in 0..cfg.depth {
|
||||
let head_dim = cfg.hidden_size / cfg.num_heads;
|
||||
blocks.push(VisionBlock {
|
||||
norm1: LayerNorm::new(
|
||||
ones(&[cfg.hidden_size]),
|
||||
zeros(&[cfg.hidden_size]),
|
||||
LAYER_NORM_EPS,
|
||||
),
|
||||
qkv: Linear::new(
|
||||
randn(&[3 * cfg.hidden_size, cfg.hidden_size]),
|
||||
Some(zeros(&[3 * cfg.hidden_size])),
|
||||
),
|
||||
proj: Linear::new(
|
||||
randn(&[cfg.hidden_size, cfg.hidden_size]),
|
||||
Some(zeros(&[cfg.hidden_size])),
|
||||
),
|
||||
norm2: LayerNorm::new(
|
||||
ones(&[cfg.hidden_size]),
|
||||
zeros(&[cfg.hidden_size]),
|
||||
LAYER_NORM_EPS,
|
||||
),
|
||||
fc1: Linear::new(
|
||||
randn(&[cfg.intermediate_size, cfg.hidden_size]),
|
||||
Some(zeros(&[cfg.intermediate_size])),
|
||||
),
|
||||
fc2: Linear::new(
|
||||
randn(&[cfg.hidden_size, cfg.intermediate_size]),
|
||||
Some(zeros(&[cfg.hidden_size])),
|
||||
),
|
||||
num_heads: cfg.num_heads,
|
||||
head_dim,
|
||||
});
|
||||
}
|
||||
|
||||
let merge_input_dim = cfg.hidden_size * cfg.spatial_merge_size * cfg.spatial_merge_size;
|
||||
let merger = VisionMerger {
|
||||
norm: LayerNorm::new(
|
||||
ones(&[cfg.hidden_size]),
|
||||
zeros(&[cfg.hidden_size]),
|
||||
LAYER_NORM_EPS,
|
||||
),
|
||||
fc1: Linear::new(
|
||||
randn(&[merge_input_dim, merge_input_dim]),
|
||||
Some(zeros(&[merge_input_dim])),
|
||||
),
|
||||
fc2: Linear::new(
|
||||
randn(&[cfg.out_hidden_size, merge_input_dim]),
|
||||
Some(zeros(&[cfg.out_hidden_size])),
|
||||
),
|
||||
merge_input_dim,
|
||||
spatial_merge_size: cfg.spatial_merge_size,
|
||||
};
|
||||
|
||||
VisionTower {
|
||||
patch_embed,
|
||||
pos_embed,
|
||||
blocks,
|
||||
merger,
|
||||
config: cfg.clone(),
|
||||
dtype,
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_with_random_weights_produces_finite_output() {
|
||||
let cfg = tiny_config();
|
||||
let tower = tiny_tower(&cfg);
|
||||
|
||||
// 16×16 image at patch_size=4 → 4×4 patches → after 2×2
|
||||
// merge → 2×2 = 4 LM tokens of dim out_hidden_size.
|
||||
let image = Tensor::randn(0_f32, 1.0, (3, 16, 16), &Device::Cpu).unwrap();
|
||||
let out = tower.forward(&image).expect("forward");
|
||||
let (n_lm, hidden) = out.dims2().unwrap();
|
||||
assert_eq!(n_lm, 4);
|
||||
assert_eq!(hidden, cfg.out_hidden_size);
|
||||
|
||||
// No NaN/Inf
|
||||
let values: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
|
||||
assert!(
|
||||
values.iter().all(|v| v.is_finite()),
|
||||
"forward must produce finite values"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lm_token_count_matches_grid() {
|
||||
let cfg = tiny_config();
|
||||
let tower = tiny_tower(&cfg);
|
||||
// 16x16 image → 4x4 patches → 2x2 = 4 LM tokens
|
||||
assert_eq!(tower.lm_tokens_for(16, 16), 4);
|
||||
// 32x32 image → 8x8 patches → 4x4 = 16 LM tokens
|
||||
assert_eq!(tower.lm_tokens_for(32, 32), 16);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_image_with_dims_not_multiple_of_patch() {
|
||||
let cfg = tiny_config();
|
||||
let tower = tiny_tower(&cfg);
|
||||
let image = Tensor::randn(0_f32, 1.0, (3, 17, 17), &Device::Cpu).unwrap();
|
||||
let err = tower.forward(&image).unwrap_err();
|
||||
assert!(format!("{err:#}").contains("patch_size"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_image_with_wrong_channel_count() {
|
||||
let cfg = tiny_config();
|
||||
let tower = tiny_tower(&cfg);
|
||||
let image = Tensor::randn(0_f32, 1.0, (4, 16, 16), &Device::Cpu).unwrap();
|
||||
let err = tower.forward(&image).unwrap_err();
|
||||
assert!(format!("{err:#}").contains("channels"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gelu_tanh_matches_known_values() {
|
||||
// Reference values for gelu_pytorch_tanh from PyTorch:
|
||||
// gelu_tanh(0.0) = 0.0
|
||||
// gelu_tanh(1.0) ≈ 0.8411920071
|
||||
// gelu_tanh(-1.0) ≈ -0.1588079929
|
||||
let x = Tensor::new(&[0.0_f32, 1.0, -1.0], &Device::Cpu).unwrap();
|
||||
let y = gelu_tanh(&x).unwrap();
|
||||
let v: Vec<f32> = y.to_vec1().unwrap();
|
||||
assert!((v[0]).abs() < 1e-6, "gelu_tanh(0) ≈ 0, got {}", v[0]);
|
||||
assert!(
|
||||
(v[1] - 0.841_192_f32).abs() < 1e-5,
|
||||
"gelu_tanh(1) ≈ 0.84119, got {}",
|
||||
v[1]
|
||||
);
|
||||
assert!(
|
||||
(v[2] - -0.158_808_f32).abs() < 1e-5,
|
||||
"gelu_tanh(-1) ≈ -0.15881, got {}",
|
||||
v[2]
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -332,6 +332,37 @@ impl ModelArch {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode a preprocessed image into LM-side token embeddings via
|
||||
/// the loaded vision tower. Stage A5.
|
||||
///
|
||||
/// `image`: device-resident `(C, H, W)` f32 tensor — caller has
|
||||
/// already preprocessed via `harness::preprocess::preprocess` and
|
||||
/// uploaded to the worker's device. Returns
|
||||
/// `(N_lm_tokens, hidden_size)`.
|
||||
///
|
||||
/// Errors when the loaded architecture has no vision tower
|
||||
/// (text-only checkpoint, or architecture that doesn't support
|
||||
/// vision at all). The HTTP layer maps this to a 400 with
|
||||
/// `vision_unsupported` so clients see a clean rejection rather
|
||||
/// than a confident text-only hallucination.
|
||||
pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
ModelArch::Qwen3_5Dense(m) => m
|
||||
.vision()
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"encode_image: this Qwen3.6 checkpoint was loaded without a vision \
|
||||
tower (config.json::vision_config absent or weights missing)"
|
||||
)
|
||||
})?
|
||||
.forward(image),
|
||||
other => anyhow::bail!(
|
||||
"encode_image: architecture {} has no vision tower",
|
||||
std::any::type_name_of_val(other)
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Squeeze any leading singleton dims off the logits tensor so the
|
||||
|
||||
@@ -158,6 +158,17 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
||||
let result = forward_logits(&mut state, handle, &tokens, offset);
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
Job::EncodeImage {
|
||||
handle,
|
||||
pixels,
|
||||
c,
|
||||
h,
|
||||
w,
|
||||
reply,
|
||||
} => {
|
||||
let result = encode_image(&mut state, handle, pixels, c, h, w);
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
Job::NcclInit {
|
||||
cfg,
|
||||
comm_id_hex,
|
||||
@@ -740,6 +751,49 @@ fn forward_logits(
|
||||
Ok(values)
|
||||
}
|
||||
|
||||
/// Run the vision tower on a single preprocessed image. Stage A5.
|
||||
///
|
||||
/// `pixels` is a row-major `(c, h, w)` f32 image that the async-side
|
||||
/// `harness::preprocess` produced. We reconstruct the tensor on the
|
||||
/// worker's device (the same device the model was loaded against),
|
||||
/// call `arch.encode_image`, and copy the resulting
|
||||
/// `(N_lm_tokens, hidden_size)` embedding back to CPU f32.
|
||||
///
|
||||
/// Returns the flattened embedding as a `Vec<f32>` — the caller knows
|
||||
/// the LM-side token count from `VisionTower::lm_tokens_for(h, w)`
|
||||
/// and reshapes accordingly. Stage B introduces a device-resident
|
||||
/// embedding-slab variant that avoids this round-trip when the next
|
||||
/// forward call needs the result.
|
||||
fn encode_image(
|
||||
state: &mut DeviceWorkerState,
|
||||
handle: ArchHandle,
|
||||
pixels: Vec<f32>,
|
||||
c: usize,
|
||||
h: usize,
|
||||
w: usize,
|
||||
) -> anyhow::Result<Vec<f32>> {
|
||||
use candle_core::{DType, Tensor};
|
||||
|
||||
anyhow::ensure!(
|
||||
pixels.len() == c * h * w,
|
||||
"EncodeImage: pixels length {} does not match shape ({c}, {h}, {w})",
|
||||
pixels.len()
|
||||
);
|
||||
let image = Tensor::from_vec(pixels, (c, h, w), &state.device)?;
|
||||
|
||||
let arch = state
|
||||
.models
|
||||
.get(&handle)
|
||||
.ok_or_else(|| anyhow::anyhow!("EncodeImage: no model for handle {}", handle.0))?;
|
||||
|
||||
let embed = arch.encode_image(&image)?;
|
||||
let values = embed
|
||||
.to_dtype(DType::F32)?
|
||||
.flatten_all()?
|
||||
.to_vec1::<f32>()?;
|
||||
Ok(values)
|
||||
}
|
||||
|
||||
/// Reply to a job with the poisoned-worker error. Used when the worker
|
||||
/// has flipped into drain-only mode after a CUDA driver error.
|
||||
///
|
||||
@@ -773,6 +827,9 @@ fn drain_poisoned(job: Job, device_index: u32) {
|
||||
Job::ForwardLogits { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::EncodeImage { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::NcclInit { reply, .. } => {
|
||||
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
||||
kind: "device_worker_poisoned".into(),
|
||||
|
||||
@@ -94,6 +94,31 @@ pub enum Job {
|
||||
offset: usize,
|
||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||
},
|
||||
/// Encode one image through the model's vision tower. Stage A5 of
|
||||
/// the vision plan (`doc/vision-qwen3_6-spec.md`).
|
||||
///
|
||||
/// `pixels` is the CPU-side preprocessed image tensor in row-major
|
||||
/// `(C, H, W)` f32 layout — what `harness::preprocess::preprocess`
|
||||
/// produces. `c`, `h`, `w` carry the shape since `Vec<f32>` itself
|
||||
/// is rank-1. The handler reconstructs the tensor on the worker's
|
||||
/// device, runs `VisionTower::forward`, copies the resulting
|
||||
/// `(N_lm_tokens, hidden_size)` embedding back to CPU as a flat
|
||||
/// `Vec<f32>` (the caller knows the expected shape from
|
||||
/// `VisionTower::lm_tokens_for(h, w) * hidden_size`).
|
||||
///
|
||||
/// Mirrors the `ForwardLogits` "tensors don't escape" invariant —
|
||||
/// device-side image embeddings are dropped at handler return.
|
||||
/// Stage B will introduce a follow-up variant that keeps the
|
||||
/// embeddings device-resident and references them from the next
|
||||
/// `ForwardLogits` call, avoiding the round-trip copy.
|
||||
EncodeImage {
|
||||
handle: ArchHandle,
|
||||
pixels: Vec<f32>,
|
||||
c: usize,
|
||||
h: usize,
|
||||
w: usize,
|
||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||
},
|
||||
/// Initialize the leader's NCCL communicator. The worker's
|
||||
/// `NcclState` mints the `Comm` here so its underlying
|
||||
/// `ncclComm_t` and `CudaContext` live on the same thread as
|
||||
|
||||
@@ -313,6 +313,49 @@ impl DeviceWorkerHandle {
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode a preprocessed image through the model's vision tower
|
||||
/// and return the resulting LM-side image embeddings as a
|
||||
/// flattened CPU `Vec<f32>`. Stage A5.
|
||||
///
|
||||
/// `pixels` is the row-major `(c, h, w)` f32 image —
|
||||
/// `harness::preprocess::preprocess` produces this exact shape.
|
||||
/// The caller knows the expected output length from
|
||||
/// `VisionTower::lm_tokens_for(h, w) * hidden_size` and reshapes
|
||||
/// accordingly.
|
||||
pub async fn encode_image(
|
||||
&self,
|
||||
handle: ArchHandle,
|
||||
pixels: Vec<f32>,
|
||||
c: usize,
|
||||
h: usize,
|
||||
w: usize,
|
||||
) -> Result<Vec<f32>, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::EncodeImage {
|
||||
handle,
|
||||
pixels,
|
||||
c,
|
||||
h,
|
||||
w,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialise the leader's NCCL communicator. The reply uses
|
||||
/// `WorkerResponse` (same shape subprocess workers use over stdio
|
||||
/// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader +
|
||||
@@ -569,6 +612,37 @@ mod tests {
|
||||
handle.shutdown().expect("shutdown ok");
|
||||
}
|
||||
|
||||
/// Stage A5: confirm the EncodeImage job round-trips through the
|
||||
/// worker channel. We don't have a real loaded model in the slab
|
||||
/// here, so the dispatch handler returns the
|
||||
/// "no model for handle" error — which is exactly what we want to
|
||||
/// see, since it proves the message routed through the channel
|
||||
/// and reached the handler. Real-weights validation lives in the
|
||||
/// Stage A7 / Stage B post-deploy smoke on beast.
|
||||
#[tokio::test]
|
||||
async fn encode_image_routes_to_dispatch_and_errors_on_unknown_handle() {
|
||||
use crate::harness::device_worker::jobs::ArchHandle;
|
||||
|
||||
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||
let fake_arch = ArchHandle(99_999);
|
||||
// (3, 4, 4) fake image — minimal payload, gets reconstructed
|
||||
// on the worker before the handler errors out on the unknown
|
||||
// ArchHandle lookup.
|
||||
let pixels = vec![0.0_f32; 3 * 4 * 4];
|
||||
let result = handle.encode_image(fake_arch, pixels, 3, 4, 4).await;
|
||||
match result {
|
||||
Err(WorkerError::Job(e)) => {
|
||||
let msg = format!("{e:#}");
|
||||
assert!(
|
||||
msg.contains("EncodeImage: no model for handle"),
|
||||
"expected unknown-handle error, got: {msg}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected Job(Err), got {other:?}"),
|
||||
}
|
||||
handle.shutdown().expect("shutdown ok");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_drains_pending_jobs() {
|
||||
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||
|
||||
@@ -5,6 +5,7 @@ pub mod candle;
|
||||
pub mod chat_template;
|
||||
pub mod device_worker;
|
||||
pub mod preflight;
|
||||
pub mod preprocess;
|
||||
pub mod tp;
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
255
crates/neuron/src/harness/preprocess.rs
Normal file
255
crates/neuron/src/harness/preprocess.rs
Normal file
@@ -0,0 +1,255 @@
|
||||
//! Image preprocessing for vision-capable models.
|
||||
//!
|
||||
//! Decodes `data:image/...;base64,...` URIs from OpenAI-style
|
||||
//! `image_url` content parts into the patch tensors a candle vision
|
||||
//! tower expects. Stage A ships **fixed resolution** — every image
|
||||
//! is resized to the same target dimensions (default 448×448 for
|
||||
//! Qwen3.6, configurable per-call) so the patch count is constant
|
||||
//! per image. Variable resolution per [Qwen2VL convention] is tracked
|
||||
//! as issue #14.
|
||||
//!
|
||||
//! Spec reference: `doc/vision-qwen3_6-spec.md` — preprocessor
|
||||
//! section.
|
||||
//!
|
||||
//! Normalisation: pixel value `p ∈ [0, 255]` becomes
|
||||
//! `(p/255 - mean) / std`. Qwen3.6's preprocessor_config.json
|
||||
//! specifies `image_mean = image_std = [0.5, 0.5, 0.5]`, which
|
||||
//! simplifies to `2p/255 - 1` mapping `[0,255]` → `[-1, 1]`. We
|
||||
//! still parameterise mean/std so the same code generalises to other
|
||||
//! VL families (Qwen2-VL uses imagenet stats, for instance).
|
||||
//!
|
||||
//! Pipeline (per image):
|
||||
//! 1. data: URI → base64 decode → bytes
|
||||
//! 2. bytes → image::DynamicImage (PNG/JPEG/WebP/etc)
|
||||
//! 3. resize_exact to target H×W (pixel space)
|
||||
//! 4. RGB→f32, normalise per mean/std
|
||||
//! 5. layout to (C, H, W) tensor
|
||||
//!
|
||||
//! Patchification (cutting the HxW tensor into `patch_size` blocks)
|
||||
//! happens inside the vision tower's `patch_embed` conv, so this
|
||||
//! module stops at "preprocessed RGB f32 tensor."
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use base64::Engine;
|
||||
use image::DynamicImage;
|
||||
use image::imageops::FilterType;
|
||||
|
||||
/// Preprocessing target. Captures the resize dimensions and the
|
||||
/// channel-wise normalisation constants from the model's
|
||||
/// `preprocessor_config.json`. Stage A ships a single `qwen3_6()`
|
||||
/// constructor for fixed-resolution Qwen3.6 preprocessing; other
|
||||
/// models can ship their own profile when added.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PreprocessProfile {
|
||||
pub target_height: u32,
|
||||
pub target_width: u32,
|
||||
pub image_mean: [f32; 3],
|
||||
pub image_std: [f32; 3],
|
||||
}
|
||||
|
||||
impl PreprocessProfile {
|
||||
/// Stage A profile for Qwen3.6. Resize to 448×448, normalise to
|
||||
/// `[-1, 1]` via mean=std=0.5. Fits within the model's
|
||||
/// `num_position_embeddings=2304` budget at 28×28 = 784 patches
|
||||
/// before merging.
|
||||
pub fn qwen3_6() -> Self {
|
||||
Self {
|
||||
target_height: 448,
|
||||
target_width: 448,
|
||||
image_mean: [0.5, 0.5, 0.5],
|
||||
image_std: [0.5, 0.5, 0.5],
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-channel CHW tensor length: 3 * H * W.
|
||||
pub fn pixels_chw(&self) -> usize {
|
||||
3 * (self.target_height as usize) * (self.target_width as usize)
|
||||
}
|
||||
}
|
||||
|
||||
/// Decode a `data:image/...;base64,...` URI into an in-memory image.
|
||||
///
|
||||
/// Accepts the OpenAI Chat Completions `image_url` shape — a string
|
||||
/// URL with `data:` scheme and base64 payload. The MIME type is read
|
||||
/// from the URI for diagnostics but `image::load_from_memory` sniffs
|
||||
/// the format from the bytes themselves, so the MIME is advisory.
|
||||
///
|
||||
/// Bare `http(s)://` URLs are explicitly rejected here — fetching
|
||||
/// them from a vision-model server is a fingerprintable behaviour
|
||||
/// (server-side request forgery, infinite recursion if the URL
|
||||
/// points at the gateway itself, etc.). Clients that want remote
|
||||
/// images can fetch them and pass base64 themselves.
|
||||
pub fn decode_data_uri(uri: &str) -> Result<DynamicImage> {
|
||||
let after_scheme = uri
|
||||
.strip_prefix("data:")
|
||||
.ok_or_else(|| anyhow!("image_url must use data: scheme; got {uri:.40}…"))?;
|
||||
let (meta, payload) = after_scheme
|
||||
.split_once(',')
|
||||
.ok_or_else(|| anyhow!("malformed data URI: missing ',' separator"))?;
|
||||
if !meta.contains(";base64") {
|
||||
anyhow::bail!(
|
||||
"data URI must use base64 encoding (got '{meta}'); raw URL-encoded payloads not supported"
|
||||
);
|
||||
}
|
||||
let bytes = base64::engine::general_purpose::STANDARD
|
||||
.decode(payload.trim())
|
||||
.context("base64-decode image data URI payload")?;
|
||||
image::load_from_memory(&bytes).context("decode image bytes (PNG/JPEG/WebP/etc)")
|
||||
}
|
||||
|
||||
/// Resize and normalise an image into a `(3, H, W)` row-major
|
||||
/// `Vec<f32>` ready to hand to the vision tower's `patch_embed`
|
||||
/// conv.
|
||||
///
|
||||
/// Uses bilinear resampling — Qwen2-VL's reference uses bicubic, but
|
||||
/// bilinear is what the candle ecosystem standardises on and is
|
||||
/// faster on CPU. Quality difference is marginal for downstream
|
||||
/// vision-encoder consumption. The numerical-validation issue (#15)
|
||||
/// will quantify any discrepancy.
|
||||
pub fn preprocess(img: &DynamicImage, profile: &PreprocessProfile) -> Vec<f32> {
|
||||
let rgb = img
|
||||
.resize_exact(
|
||||
profile.target_width,
|
||||
profile.target_height,
|
||||
FilterType::Triangle,
|
||||
)
|
||||
.to_rgb8();
|
||||
let h = profile.target_height as usize;
|
||||
let w = profile.target_width as usize;
|
||||
let mut out = vec![0.0_f32; 3 * h * w];
|
||||
// Row-major (C, H, W). Candle's Conv2d expects NCHW, so this is
|
||||
// the natural layout — the caller stacks `n` of these along the
|
||||
// batch axis as needed.
|
||||
for c in 0..3 {
|
||||
let mean = profile.image_mean[c];
|
||||
let std = profile.image_std[c];
|
||||
for y in 0..h {
|
||||
for x in 0..w {
|
||||
let pixel = rgb.get_pixel(x as u32, y as u32);
|
||||
let raw = pixel[c] as f32 / 255.0;
|
||||
out[c * h * w + y * w + x] = (raw - mean) / std;
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Combined helper: decode + preprocess in one call. Most call
|
||||
/// sites just want the final tensor; the two-step path exists for
|
||||
/// callers (tests, future video preprocessing) that need the
|
||||
/// intermediate `DynamicImage`.
|
||||
pub fn preprocess_data_uri(uri: &str, profile: &PreprocessProfile) -> Result<Vec<f32>> {
|
||||
let img = decode_data_uri(uri)?;
|
||||
Ok(preprocess(&img, profile))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use image::{ImageBuffer, Rgb};
|
||||
|
||||
/// A 1×1 red PNG, hand-built. Matches the well-known smallest
|
||||
/// valid PNG we use in tests/curl examples elsewhere.
|
||||
const ONE_BY_ONE_RED_PNG_B64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==";
|
||||
|
||||
fn red_png_uri() -> String {
|
||||
format!("data:image/png;base64,{ONE_BY_ONE_RED_PNG_B64}")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decodes_well_formed_png_data_uri() {
|
||||
let img = decode_data_uri(&red_png_uri()).expect("decode 1x1 png");
|
||||
assert_eq!(img.width(), 1);
|
||||
assert_eq!(img.height(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_non_data_scheme() {
|
||||
let err = decode_data_uri("https://example.com/cat.jpg")
|
||||
.expect_err("http(s) URLs must be rejected");
|
||||
assert!(format!("{err:#}").contains("data:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_malformed_uri_without_comma() {
|
||||
let err = decode_data_uri("data:image/png;base64").unwrap_err();
|
||||
assert!(format!("{err:#}").contains("','"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_non_base64_payload() {
|
||||
let err = decode_data_uri("data:image/png,raw-bytes-here").unwrap_err();
|
||||
assert!(format!("{err:#}").contains("base64"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_bad_base64_payload() {
|
||||
let err = decode_data_uri("data:image/png;base64,not!valid!base64!").unwrap_err();
|
||||
assert!(format!("{err:#}").contains("base64"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_garbage_image_bytes() {
|
||||
// Valid base64 ("Hello World!"), invalid image bytes.
|
||||
let err = decode_data_uri("data:image/png;base64,SGVsbG8gV29ybGQh").unwrap_err();
|
||||
assert!(
|
||||
format!("{err:#}").contains("decode image"),
|
||||
"should fail at image-decode step"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preprocess_red_image_produces_correct_shape_and_values() {
|
||||
let profile = PreprocessProfile::qwen3_6();
|
||||
// Build a tiny pure-red image directly, skipping data: URI
|
||||
// decoding so this test isolates the resize+normalise path.
|
||||
let img: ImageBuffer<Rgb<u8>, Vec<u8>> = ImageBuffer::from_pixel(2, 2, Rgb([255, 0, 0]));
|
||||
let dyn_img = DynamicImage::ImageRgb8(img);
|
||||
let out = preprocess(&dyn_img, &profile);
|
||||
|
||||
assert_eq!(out.len(), profile.pixels_chw());
|
||||
// After mean=0.5, std=0.5: red channel (255/255=1.0) → (1.0 - 0.5)/0.5 = 1.0
|
||||
// green/blue (0.0) → (0.0 - 0.5)/0.5 = -1.0
|
||||
let h = profile.target_height as usize;
|
||||
let w = profile.target_width as usize;
|
||||
assert!(
|
||||
(out[0] - 1.0).abs() < 1e-5,
|
||||
"R[0] should be 1.0, got {}",
|
||||
out[0]
|
||||
);
|
||||
assert!((out[h * w] - (-1.0)).abs() < 1e-5, "G[0] should be -1.0");
|
||||
assert!(
|
||||
(out[2 * h * w] - (-1.0)).abs() < 1e-5,
|
||||
"B[0] should be -1.0"
|
||||
);
|
||||
// All values are finite
|
||||
assert!(out.iter().all(|v| v.is_finite()), "no NaN/Inf in output");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preprocess_data_uri_end_to_end() {
|
||||
let profile = PreprocessProfile::qwen3_6();
|
||||
let out = preprocess_data_uri(&red_png_uri(), &profile).expect("e2e preprocess");
|
||||
assert_eq!(out.len(), profile.pixels_chw());
|
||||
assert!(out.iter().all(|v| v.is_finite()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preprocess_grayscale_image_promotes_to_rgb() {
|
||||
let profile = PreprocessProfile::qwen3_6();
|
||||
// 1x1 grayscale = 200 → after conversion to RGB, all three
|
||||
// channels equal 200, normalised → (200/255 - 0.5)/0.5 ≈ 0.569
|
||||
let gray = DynamicImage::ImageLuma8(ImageBuffer::from_pixel(1, 1, image::Luma([200])));
|
||||
let out = preprocess(&gray, &profile);
|
||||
let expected = ((200.0 / 255.0) - 0.5) / 0.5;
|
||||
let h = profile.target_height as usize;
|
||||
let w = profile.target_width as usize;
|
||||
for c in 0..3 {
|
||||
let v = out[c * h * w];
|
||||
assert!(
|
||||
(v - expected).abs() < 1e-3,
|
||||
"channel {c}: expected {expected}, got {v}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user