feat(neuron): Stage A — vision tower load + preprocessor for Qwen3.6
All checks were successful
CI / CUDA type-check (push) Successful in 32s
build-prerelease / Resolve version stamps (push) Successful in 30s
CI / Format (push) Successful in 28s
CI / Clippy (push) Successful in 2m35s
build-prerelease / Build cortex binary (push) Successful in 5m13s
build-prerelease / Build neuron-blackwell (push) Successful in 6m23s
build-prerelease / Build neuron-ampere (push) Successful in 7m56s
CI / Test (push) Successful in 7m11s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m19s
build-prerelease / Build neuron-ada (push) Successful in 5m30s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m56s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m45s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 4m25s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
All checks were successful
CI / CUDA type-check (push) Successful in 32s
build-prerelease / Resolve version stamps (push) Successful in 30s
CI / Format (push) Successful in 28s
CI / Clippy (push) Successful in 2m35s
build-prerelease / Build cortex binary (push) Successful in 5m13s
build-prerelease / Build neuron-blackwell (push) Successful in 6m23s
build-prerelease / Build neuron-ampere (push) Successful in 7m56s
CI / Test (push) Successful in 7m11s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m19s
build-prerelease / Build neuron-ada (push) Successful in 5m30s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m56s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m45s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 4m25s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
Stage A of the vision implementation plan (doc/vision-qwen3_6-spec.md). Builds the vision tower scaffolding that today's silent-drop failure mode (issue #3) needs — the Qwen3.6 ViT loads from `model.visual.*`, runs forward producing post-merger LM-side image embeddings, and routes through the device worker via a new `Job::EncodeImage`. No LM splice yet — that's Stage B. Refs #3 (umbrella). Deferred sub-stages tracked as #12 (TP-vision), #13 (27B production deploy), #14 (dynamic resolution), #15 (numerical validation). What landed: - **A0 — investigation**: pulled config.json, preprocessor_config.json, chat_template.jinja, and safetensors index from beast's local Qwen3.6-27B cache. Documented in doc/vision-qwen3_6-spec.md with exact tensor shapes for every `model.visual.*` weight. Confirms 27-block ViT with `hidden_size=1152`, `patch_size=16`, `spatial_merge_size=2`, `out_hidden_size=5120`. Vision tower lives in 2 of the 15 safetensors shards. - **A1 — deps + scaffolding**: added `image = "0.25"` (default- features off, PNG/JPEG/WebP/BMP/GIF) and `base64 = "0.22"` to crates/neuron/Cargo.toml. Created `harness::preprocess` and `harness::arch::qwen3_5::vision` modules. - **A2 — preprocess.rs**: `decode_data_uri` strips `data:image/...;base64,...` → image bytes → `image::DynamicImage` (rejecting `http(s)://` URLs to avoid SSRF/recursion); `preprocess` resizes to a fixed `PreprocessProfile::qwen3_6()` (448×448), normalises to `[-1, 1]` per the model's mean/std=0.5, emits row-major `(3, H, W)` f32. 9 unit tests covering data URI parse, decode failure paths, grayscale-to-RGB promotion, and the exact-value normalisation contract. - **A3 — vision.rs**: `VisionTower` struct with `patch_embed: Conv2d`, learned `pos_embed: Embedding`, 27 `VisionBlock`s (pre-LN + multi-head self-attention with fused QKV + GELU-tanh MLP + residuals), and `VisionMerger` (LayerNorm → 2×2 spatial concat → linear_fc1 → GELU-tanh → linear_fc2 to LM hidden_size). Includes the Conv3d→Conv2d fold trick documented at the top of the file — the published patch_embed.proj.weight is 5D `(1152, 3, 2, 16, 16)` but candle 0.10 has no Conv3d; for static images we sum-collapse the temporal axis. Video would need real Conv3d. 5 unit tests including the exact `gelu_pytorch_tanh` reference values from PyTorch. - **A4 — wire vision into Qwen3_5ForCausalLM**: extended `Config` with optional `vision_config: Option<VisionConfig>` and `image_token_id`; `Qwen3_5ForCausalLM::new` now loads the vision tower when present, exposes `has_vision()` and `vision()` so the HTTP layer can advertise capability and so the encode path can reach it. - **A5 — device worker `Job::EncodeImage`**: new job variant carrying CPU-side `(C, H, W)` pixels. Dispatch handler reconstructs the tensor on the worker's device, calls `arch.encode_image(image)`, copies the result back to CPU as flat `Vec<f32>`. Keeps the "tensors don't escape the worker" invariant. Poisoned-worker drain path handles the new variant. - **A6 — dispatch round-trip test**: `encode_image_routes_to_dispatch_ and_errors_on_unknown_handle` proves the channel/dispatch wiring works end-to-end via the CPU device worker (errors on unknown ArchHandle, which is the expected behaviour without a loaded model — real-weights validation happens in Stage B when the LM splice path exists). CI gate: cargo fmt --check, cargo clippy --workspace --all-targets -- -D warnings, cargo test --workspace (all 28 test groups ok, zero failures). New test counts: +9 in preprocess, +5 in vision, +1 in device_worker. Out of scope (deferred): - LM-side splice of image embeddings at `<|image_pad|>` positions → Stage B. - Streaming SSE for vision-bearing chat completions → Stage C. - Reject `image_url` with HTTP 400 for non-vision models / advertise `capabilities` in /v1/models → Stage C. - TP-vision (#12), 27B production deploy (#13), dynamic resolution (#14), numerical validation (#15). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
117
Cargo.lock
generated
117
Cargo.lock
generated
@@ -472,6 +472,12 @@ version = "1.5.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "byteorder-lite"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bytes"
|
name = "bytes"
|
||||||
version = "1.11.1"
|
version = "1.11.1"
|
||||||
@@ -668,6 +674,12 @@ dependencies = [
|
|||||||
"cc",
|
"cc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "color_quant"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "colorchoice"
|
name = "colorchoice"
|
||||||
version = "1.0.5"
|
version = "1.0.5"
|
||||||
@@ -1223,6 +1235,15 @@ version = "2.4.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6"
|
checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fdeflate"
|
||||||
|
version = "0.3.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c"
|
||||||
|
dependencies = [
|
||||||
|
"simd-adler32",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "figment"
|
name = "figment"
|
||||||
version = "0.10.19"
|
version = "0.10.19"
|
||||||
@@ -1731,6 +1752,16 @@ dependencies = [
|
|||||||
"wasip3",
|
"wasip3",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "gif"
|
||||||
|
version = "0.14.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ee8cfcc411d9adbbaba82fb72661cc1bcca13e8bba98b364e62b2dba8f960159"
|
||||||
|
dependencies = [
|
||||||
|
"color_quant",
|
||||||
|
"weezl",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "glob"
|
name = "glob"
|
||||||
version = "0.3.3"
|
version = "0.3.3"
|
||||||
@@ -2135,6 +2166,34 @@ dependencies = [
|
|||||||
"icu_properties",
|
"icu_properties",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "image"
|
||||||
|
version = "0.25.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "85ab80394333c02fe689eaf900ab500fbd0c2213da414687ebf995a65d5a6104"
|
||||||
|
dependencies = [
|
||||||
|
"bytemuck",
|
||||||
|
"byteorder-lite",
|
||||||
|
"color_quant",
|
||||||
|
"gif",
|
||||||
|
"image-webp",
|
||||||
|
"moxcms",
|
||||||
|
"num-traits",
|
||||||
|
"png",
|
||||||
|
"zune-core",
|
||||||
|
"zune-jpeg",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "image-webp"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3"
|
||||||
|
dependencies = [
|
||||||
|
"byteorder-lite",
|
||||||
|
"quick-error",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "indexmap"
|
name = "indexmap"
|
||||||
version = "1.9.3"
|
version = "1.9.3"
|
||||||
@@ -2498,6 +2557,16 @@ dependencies = [
|
|||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "moxcms"
|
||||||
|
version = "0.8.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bb85c154ba489f01b25c0d36ae69a87e4a1c73a72631fc6c0eb6dde34a73e44b"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
"pxfm",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "native-tls"
|
name = "native-tls"
|
||||||
version = "0.2.18"
|
version = "0.2.18"
|
||||||
@@ -2522,6 +2591,7 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum",
|
"axum",
|
||||||
|
"base64 0.22.1",
|
||||||
"candle-core",
|
"candle-core",
|
||||||
"candle-nn",
|
"candle-nn",
|
||||||
"candle-transformers",
|
"candle-transformers",
|
||||||
@@ -2533,6 +2603,7 @@ dependencies = [
|
|||||||
"futures",
|
"futures",
|
||||||
"half",
|
"half",
|
||||||
"hf-hub",
|
"hf-hub",
|
||||||
|
"image",
|
||||||
"minijinja",
|
"minijinja",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"safetensors 0.7.0",
|
"safetensors 0.7.0",
|
||||||
@@ -2861,6 +2932,19 @@ version = "0.3.33"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e"
|
checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "png"
|
||||||
|
version = "0.18.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "60769b8b31b2a9f263dae2776c37b1b28ae246943cf719eb6946a1db05128a61"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"crc32fast",
|
||||||
|
"fdeflate",
|
||||||
|
"flate2",
|
||||||
|
"miniz_oxide",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "polling"
|
name = "polling"
|
||||||
version = "3.11.0"
|
version = "3.11.0"
|
||||||
@@ -2974,6 +3058,12 @@ version = "0.1.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "40e24eee682d89fb193496edf918a7f407d30175b2e785fe057e4392dfd182e0"
|
checksum = "40e24eee682d89fb193496edf918a7f407d30175b2e785fe057e4392dfd182e0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pxfm"
|
||||||
|
version = "0.1.29"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quanta"
|
name = "quanta"
|
||||||
version = "0.12.6"
|
version = "0.12.6"
|
||||||
@@ -2989,6 +3079,12 @@ dependencies = [
|
|||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quick-error"
|
||||||
|
version = "2.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quinn"
|
name = "quinn"
|
||||||
version = "0.11.9"
|
version = "0.11.9"
|
||||||
@@ -4627,6 +4723,12 @@ dependencies = [
|
|||||||
"rustls-pki-types",
|
"rustls-pki-types",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "weezl"
|
||||||
|
version = "0.1.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "which"
|
name = "which"
|
||||||
version = "7.0.3"
|
version = "7.0.3"
|
||||||
@@ -5164,3 +5266,18 @@ name = "zmij"
|
|||||||
version = "1.0.21"
|
version = "1.0.21"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa"
|
checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zune-core"
|
||||||
|
version = "0.5.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zune-jpeg"
|
||||||
|
version = "0.5.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "27bc9d5b815bc103f142aa054f561d9187d191692ec7c2d1e2b4737f8dbd7296"
|
||||||
|
dependencies = [
|
||||||
|
"zune-core",
|
||||||
|
]
|
||||||
|
|||||||
@@ -90,6 +90,13 @@ minijinja = { version = "2", features = ["builtins", "json", "serde"] }
|
|||||||
# tp `fused_load` module to read per-rank slices of fused QKV tensors
|
# tp `fused_load` module to read per-rank slices of fused QKV tensors
|
||||||
# without materialising the full tensor on device.
|
# without materialising the full tensor on device.
|
||||||
safetensors = "0.7"
|
safetensors = "0.7"
|
||||||
|
# Vision capability for Qwen3.6 (Stage A of the vision plan in
|
||||||
|
# doc/vision-qwen3_6-spec.md). `image` decodes PNG/JPEG/etc from
|
||||||
|
# the bytes embedded in `data:image/...;base64,...` content parts;
|
||||||
|
# `base64` does the URI decode. Default-features off on `image` to
|
||||||
|
# avoid pulling in audio/video formats we don't need.
|
||||||
|
image = { version = "0.25", default-features = false, features = ["png", "jpeg", "webp", "bmp", "gif"] }
|
||||||
|
base64 = "0.22"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { workspace = true, features = ["test-util"] }
|
tokio = { workspace = true, features = ["test-util"] }
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ pub mod linear_attn;
|
|||||||
pub mod mlp;
|
pub mod mlp;
|
||||||
pub mod rmsnorm;
|
pub mod rmsnorm;
|
||||||
pub mod rope;
|
pub mod rope;
|
||||||
|
pub mod vision;
|
||||||
|
|
||||||
use decoder::Qwen3_5DecoderLayer;
|
use decoder::Qwen3_5DecoderLayer;
|
||||||
use rmsnorm::Qwen3_5RmsNorm;
|
use rmsnorm::Qwen3_5RmsNorm;
|
||||||
@@ -99,6 +100,20 @@ pub struct Config {
|
|||||||
pub model_type: String,
|
pub model_type: String,
|
||||||
/// The text-side hyperparameters. Everything we actually need.
|
/// The text-side hyperparameters. Everything we actually need.
|
||||||
pub text_config: TextConfig,
|
pub text_config: TextConfig,
|
||||||
|
/// Vision tower hyperparameters. Present on multimodal
|
||||||
|
/// checkpoints (e.g. Qwen/Qwen3.6-27B); absent on text-only
|
||||||
|
/// variants. When present, `Qwen3_5ForCausalLM::new` loads the
|
||||||
|
/// vision tower alongside the language model so vision-bearing
|
||||||
|
/// requests can splice image embeddings at `<|image_pad|>` token
|
||||||
|
/// positions.
|
||||||
|
#[serde(default)]
|
||||||
|
pub vision_config: Option<vision::VisionConfig>,
|
||||||
|
/// Token id the chat template emits per image patch group.
|
||||||
|
/// Mirrors the LM tokenizer's `<|image_pad|>` id (248056 for
|
||||||
|
/// Qwen3.6). The runtime locates these in the prompt and splices
|
||||||
|
/// in `VisionTower::forward` output. `None` for text-only models.
|
||||||
|
#[serde(default)]
|
||||||
|
pub image_token_id: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
|
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
|
||||||
@@ -309,6 +324,15 @@ impl Qwen3_5Model {
|
|||||||
pub struct Qwen3_5ForCausalLM {
|
pub struct Qwen3_5ForCausalLM {
|
||||||
base: Qwen3_5Model,
|
base: Qwen3_5Model,
|
||||||
lm_head: Linear,
|
lm_head: Linear,
|
||||||
|
/// Vision tower (Stage A4). `None` for text-only checkpoints or
|
||||||
|
/// when the operator has opted out. When present, the harness's
|
||||||
|
/// `Job::EncodeImage` dispatch path runs `vision.forward(image)`
|
||||||
|
/// and the LM forward (Stage B) splices the result at
|
||||||
|
/// `image_token_id` positions in the input embedding stream.
|
||||||
|
vision: Option<vision::VisionTower>,
|
||||||
|
/// Mirrors `Config::image_token_id`. Cached here so the runtime
|
||||||
|
/// doesn't have to round-trip through the parsed config struct.
|
||||||
|
image_token_id: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Qwen3_5ForCausalLM {
|
impl Qwen3_5ForCausalLM {
|
||||||
@@ -324,7 +348,52 @@ impl Qwen3_5ForCausalLM {
|
|||||||
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
|
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
|
||||||
Linear::new(weight, None)
|
Linear::new(weight, None)
|
||||||
};
|
};
|
||||||
Ok(Self { base, lm_head })
|
// Stage A4: load the vision tower when the config carries a
|
||||||
|
// `vision_config` block and the safetensors actually carry
|
||||||
|
// `model.visual.*` weights. The `Option<VisionConfig>` on the
|
||||||
|
// config makes this a single-source-of-truth decision —
|
||||||
|
// text-only checkpoints just leave `vision_config` unset and
|
||||||
|
// get `None` here without any extra plumbing.
|
||||||
|
let vision = if let Some(vcfg) = config.vision_config.clone() {
|
||||||
|
tracing::info!(
|
||||||
|
depth = vcfg.depth,
|
||||||
|
hidden_size = vcfg.hidden_size,
|
||||||
|
"loading qwen3_5 vision tower"
|
||||||
|
);
|
||||||
|
Some(
|
||||||
|
vision::VisionTower::load(vcfg, vb.pp("model.visual"))
|
||||||
|
.context("load qwen3_5 vision tower (model.visual.*)")?,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
base,
|
||||||
|
lm_head,
|
||||||
|
vision,
|
||||||
|
image_token_id: config.image_token_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// True when this checkpoint loaded a vision tower. Used by the
|
||||||
|
/// HTTP layer to advertise vision capability in `/v1/models` and
|
||||||
|
/// to reject image-bearing requests against text-only loads with
|
||||||
|
/// a clean 400.
|
||||||
|
pub fn has_vision(&self) -> bool {
|
||||||
|
self.vision.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Vision tower handle, if loaded. The device-worker
|
||||||
|
/// `EncodeImage` job dispatches to `vision.forward(image)`.
|
||||||
|
pub fn vision(&self) -> Option<&vision::VisionTower> {
|
||||||
|
self.vision.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `<|image_pad|>` token id from `config.json`, when known.
|
||||||
|
/// The Stage B prompt-builder uses this to count expansion targets
|
||||||
|
/// and the LM forward uses it to locate splice positions.
|
||||||
|
pub fn image_token_id(&self) -> Option<u32> {
|
||||||
|
self.image_token_id
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
|
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
|
||||||
|
|||||||
600
crates/neuron/src/harness/arch/qwen3_5/vision.rs
Normal file
600
crates/neuron/src/harness/arch/qwen3_5/vision.rs
Normal file
@@ -0,0 +1,600 @@
|
|||||||
|
//! Qwen3.6 vision tower.
|
||||||
|
//!
|
||||||
|
//! 27 pre-norm ViT blocks with **LayerNorm** (with biases — not the
|
||||||
|
//! `(1+w)·x` RmsNorm the language model uses), fused QKV attention,
|
||||||
|
//! GELU-tanh MLP. Followed by a `merger` that LayerNorms each
|
||||||
|
//! 1152-dim vision token, spatially 2×2-merges them into 4608-dim
|
||||||
|
//! groups, and projects to the LM's 5120-dim hidden via
|
||||||
|
//! `linear_fc1 → GELU → linear_fc2`.
|
||||||
|
//!
|
||||||
|
//! Architecture spec sourced from beast's cached Qwen3.6-27B
|
||||||
|
//! safetensors header (Stage A0, see
|
||||||
|
//! `doc/vision-qwen3_6-spec.md`). All weight shapes confirmed
|
||||||
|
//! from the live `.safetensors` headers, not inferred.
|
||||||
|
//!
|
||||||
|
//! **Conv3d wrinkle.** The published `patch_embed.proj.weight` is 5D
|
||||||
|
//! `[1152, 3, 2, 16, 16]` — a 3D conv with kernel
|
||||||
|
//! `(t=2, h=16, w=16)`. Candle 0.10 has no Conv3d. For static images
|
||||||
|
//! we get away with a trick: when the temporal patch size is 2 and we
|
||||||
|
//! duplicate the still image along the temporal axis (`T = 2`,
|
||||||
|
//! frame_0 == frame_1), the Conv3d output equals a Conv2d run with
|
||||||
|
//! the *sum* of the two temporal weight slices:
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! output = W_0 · frame_0 + W_1 · frame_1 + bias
|
||||||
|
//! = (W_0 + W_1) · frame + bias (static image)
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! So at load we sum-collapse the temporal axis and use a 4D
|
||||||
|
//! `Conv2d` kernel. Video support would have to do the real Conv3d
|
||||||
|
//! (different frames mean the trick fails) — tracked alongside the
|
||||||
|
//! dynamic-resolution work in issue #14.
|
||||||
|
//!
|
||||||
|
//! Forward signature (Stage A — no LM splice yet):
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! fn forward(&self, image: &Tensor) -> Result<Tensor>
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! `image` is `(3, H, W)` f32, normalised by `preprocess::preprocess`.
|
||||||
|
//! Returns `(N_lm_tokens, out_hidden_size)` post-merger tokens ready
|
||||||
|
//! to splice into the LM's input embeddings at `<|image_pad|>`
|
||||||
|
//! positions. For Qwen3.6 at 448×448 → 28×28 patches → 14×14 = 196
|
||||||
|
//! LM tokens of dim 5120.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{D, DType, Device, IndexOp, Module, Tensor};
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear};
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
/// Qwen3.6 vision tower hyperparameters. Mirrors the `vision_config`
|
||||||
|
/// block of `config.json`. Only the fields we actually need are
|
||||||
|
/// captured; serde tolerates the rest.
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct VisionConfig {
|
||||||
|
/// Number of ViT blocks (`depth: 27` for Qwen3.6).
|
||||||
|
pub depth: usize,
|
||||||
|
/// Vision-token dimension throughout the tower (1152 for Qwen3.6).
|
||||||
|
pub hidden_size: usize,
|
||||||
|
/// MLP intermediate dim (4304).
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
/// Attention head count (16). `head_dim = hidden_size / num_heads`.
|
||||||
|
pub num_heads: usize,
|
||||||
|
/// Number of slots in the learned position embedding (2304).
|
||||||
|
/// Caps the maximum image patch count.
|
||||||
|
pub num_position_embeddings: usize,
|
||||||
|
/// Spatial patch edge in pixels (16).
|
||||||
|
pub patch_size: usize,
|
||||||
|
/// Temporal kernel depth in the patch embed (2 for Qwen3.6 — we
|
||||||
|
/// collapse this into a single Conv2d for static-image inference;
|
||||||
|
/// see the module-level Conv3d wrinkle).
|
||||||
|
pub temporal_patch_size: usize,
|
||||||
|
/// Patches grouped per LM token by the merger (2 → 2×2 = 4
|
||||||
|
/// patches per LM token).
|
||||||
|
pub spatial_merge_size: usize,
|
||||||
|
/// Vision input channels (3, RGB).
|
||||||
|
pub in_channels: usize,
|
||||||
|
/// Merger output dim — matches the LM's `hidden_size` (5120 for
|
||||||
|
/// Qwen3.6). The merger projects from vision dim → LM dim.
|
||||||
|
pub out_hidden_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
const LAYER_NORM_EPS: f64 = 1e-6;
|
||||||
|
/// Number of LM tokens emitted by the merger per vision-token group.
|
||||||
|
const LM_TOKENS_PER_MERGE_GROUP: usize = 1;
|
||||||
|
|
||||||
|
/// One ViT block: pre-LN → attn → residual; pre-LN → MLP → residual.
|
||||||
|
struct VisionBlock {
|
||||||
|
norm1: LayerNorm,
|
||||||
|
qkv: Linear,
|
||||||
|
proj: Linear,
|
||||||
|
norm2: LayerNorm,
|
||||||
|
fc1: Linear,
|
||||||
|
fc2: Linear,
|
||||||
|
num_heads: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VisionBlock {
|
||||||
|
fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let h = cfg.hidden_size;
|
||||||
|
let head_dim = h / cfg.num_heads;
|
||||||
|
let norm1 = layer_norm(vb.pp("norm1"), h)?;
|
||||||
|
let qkv = linear(vb.pp("attn.qkv"), h, 3 * h)?;
|
||||||
|
let proj = linear(vb.pp("attn.proj"), h, h)?;
|
||||||
|
let norm2 = layer_norm(vb.pp("norm2"), h)?;
|
||||||
|
let fc1 = linear(vb.pp("mlp.linear_fc1"), h, cfg.intermediate_size)?;
|
||||||
|
let fc2 = linear(vb.pp("mlp.linear_fc2"), cfg.intermediate_size, h)?;
|
||||||
|
Ok(Self {
|
||||||
|
norm1,
|
||||||
|
qkv,
|
||||||
|
proj,
|
||||||
|
norm2,
|
||||||
|
fc1,
|
||||||
|
fc2,
|
||||||
|
num_heads: cfg.num_heads,
|
||||||
|
head_dim,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `x`: `(N, hidden_size)` un-batched. Returns same shape.
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let attn_in = self.norm1.forward(x)?;
|
||||||
|
let attn_out = self.attention(&attn_in)?;
|
||||||
|
let x = x.add(&attn_out)?;
|
||||||
|
let mlp_in = self.norm2.forward(&x)?;
|
||||||
|
let mlp_out = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&mlp_in)?)?)?;
|
||||||
|
x.add(&mlp_out).map_err(Into::into)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Multi-head self-attention over the patch sequence. No causal
|
||||||
|
/// mask — every patch attends to every other patch.
|
||||||
|
fn attention(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let (n, hidden) = x.dims2()?;
|
||||||
|
// qkv: (N, 3*hidden). Split into Q, K, V each (N, hidden).
|
||||||
|
let qkv = self.qkv.forward(x)?;
|
||||||
|
let qkv = qkv.reshape((n, 3, self.num_heads, self.head_dim))?;
|
||||||
|
// Transpose to (3, num_heads, N, head_dim) for per-head views.
|
||||||
|
let qkv = qkv.permute((1, 2, 0, 3))?.contiguous()?;
|
||||||
|
let q = qkv.i(0)?;
|
||||||
|
let k = qkv.i(1)?;
|
||||||
|
let v = qkv.i(2)?;
|
||||||
|
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||||
|
// (num_heads, N, head_dim) @ (num_heads, head_dim, N) -> (num_heads, N, N)
|
||||||
|
let scores = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
|
||||||
|
let scores = (scores * scale)?;
|
||||||
|
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||||
|
// (num_heads, N, N) @ (num_heads, N, head_dim) -> (num_heads, N, head_dim)
|
||||||
|
let out = probs.matmul(&v)?;
|
||||||
|
// Merge heads back: (N, num_heads, head_dim) -> (N, hidden).
|
||||||
|
let out = out.permute((1, 0, 2))?.contiguous()?.reshape((n, hidden))?;
|
||||||
|
self.proj.forward(&out).map_err(Into::into)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `merger`: LayerNorm per token → spatial 2×2 merge (concat 4
|
||||||
|
/// adjacent tokens into one 4608-dim vector) → fc1 → GELU-tanh →
|
||||||
|
/// fc2. Output dim is the LM's hidden_size.
|
||||||
|
struct VisionMerger {
|
||||||
|
norm: LayerNorm,
|
||||||
|
fc1: Linear,
|
||||||
|
fc2: Linear,
|
||||||
|
merge_input_dim: usize,
|
||||||
|
spatial_merge_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VisionMerger {
|
||||||
|
fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let h = cfg.hidden_size;
|
||||||
|
let merge = cfg.spatial_merge_size;
|
||||||
|
let merge_input_dim = h * merge * merge;
|
||||||
|
let norm = layer_norm(vb.pp("norm"), h)?;
|
||||||
|
let fc1 = linear(vb.pp("linear_fc1"), merge_input_dim, merge_input_dim)?;
|
||||||
|
let fc2 = linear(vb.pp("linear_fc2"), merge_input_dim, cfg.out_hidden_size)?;
|
||||||
|
Ok(Self {
|
||||||
|
norm,
|
||||||
|
fc1,
|
||||||
|
fc2,
|
||||||
|
merge_input_dim,
|
||||||
|
spatial_merge_size: merge,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `tokens`: `(grid_h, grid_w, hidden_size)`. The merger reshapes
|
||||||
|
/// each `merge×merge` block of adjacent patches into a single
|
||||||
|
/// concatenated vector, then projects.
|
||||||
|
///
|
||||||
|
/// `grid_h` and `grid_w` must both be multiples of
|
||||||
|
/// `spatial_merge_size`. Returns
|
||||||
|
/// `(grid_h/merge × grid_w/merge, out_hidden_size)`.
|
||||||
|
fn forward(&self, tokens: &Tensor) -> Result<Tensor> {
|
||||||
|
let (gh, gw, h) = tokens.dims3()?;
|
||||||
|
let m = self.spatial_merge_size;
|
||||||
|
anyhow::ensure!(
|
||||||
|
gh.is_multiple_of(m) && gw.is_multiple_of(m),
|
||||||
|
"merger expects spatial dims divisible by merge_size={m}; got ({gh}, {gw})"
|
||||||
|
);
|
||||||
|
let tokens = self.norm.forward(tokens)?;
|
||||||
|
// (gh, gw, h) -> (gh/m, m, gw/m, m, h) -> (gh/m, gw/m, m, m, h)
|
||||||
|
// -> flatten last three -> (gh/m, gw/m, m*m*h) -> (N_lm, merge_input_dim)
|
||||||
|
let out_h = gh / m;
|
||||||
|
let out_w = gw / m;
|
||||||
|
let merged = tokens
|
||||||
|
.reshape((out_h, m, out_w, m, h))?
|
||||||
|
.permute((0, 2, 1, 3, 4))?
|
||||||
|
.contiguous()?
|
||||||
|
.reshape((out_h * out_w, self.merge_input_dim))?;
|
||||||
|
let hidden = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&merged)?)?)?;
|
||||||
|
Ok(hidden)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The vision tower itself.
|
||||||
|
pub struct VisionTower {
|
||||||
|
/// Sum-collapsed temporal kernel (Conv2d, see module doc).
|
||||||
|
patch_embed: Conv2d,
|
||||||
|
pos_embed: Embedding,
|
||||||
|
blocks: Vec<VisionBlock>,
|
||||||
|
merger: VisionMerger,
|
||||||
|
config: VisionConfig,
|
||||||
|
dtype: DType,
|
||||||
|
device: Device,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VisionTower {
|
||||||
|
/// Load from a `ShardedVarBuilder` rooted at the safetensors
|
||||||
|
/// `model.visual.` prefix. Caller is responsible for the `pp` —
|
||||||
|
/// see `Qwen3_5ForCausalLM::new` (Stage A4).
|
||||||
|
pub fn load(cfg: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let dtype = vb.dtype();
|
||||||
|
let device = vb.device().clone();
|
||||||
|
|
||||||
|
// patch_embed.proj is published as 5D Conv3d weight; we
|
||||||
|
// sum-collapse the temporal axis (size = temporal_patch_size)
|
||||||
|
// to get a 4D Conv2d kernel. This is exact for the static-
|
||||||
|
// image case where T = temporal_patch_size frames are
|
||||||
|
// identical (i.e. the input was duplicated along T).
|
||||||
|
let raw_weight = vb
|
||||||
|
.pp("patch_embed.proj")
|
||||||
|
.get(
|
||||||
|
(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.in_channels,
|
||||||
|
cfg.temporal_patch_size,
|
||||||
|
cfg.patch_size,
|
||||||
|
cfg.patch_size,
|
||||||
|
),
|
||||||
|
"weight",
|
||||||
|
)
|
||||||
|
.context("load model.visual.patch_embed.proj.weight (5D Conv3d kernel)")?;
|
||||||
|
// Sum along the temporal axis (dim 2) — see module doc-comment.
|
||||||
|
let folded = raw_weight.sum(2)?; // -> (hidden, in_channels, patch, patch)
|
||||||
|
let proj_bias = vb
|
||||||
|
.pp("patch_embed.proj")
|
||||||
|
.get(cfg.hidden_size, "bias")
|
||||||
|
.context("load model.visual.patch_embed.proj.bias")?;
|
||||||
|
let conv_cfg = Conv2dConfig {
|
||||||
|
stride: cfg.patch_size,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let patch_embed = Conv2d::new(folded, Some(proj_bias), conv_cfg);
|
||||||
|
|
||||||
|
let pos_embed_weight = vb
|
||||||
|
.pp("pos_embed")
|
||||||
|
.get((cfg.num_position_embeddings, cfg.hidden_size), "weight")
|
||||||
|
.context("load model.visual.pos_embed.weight")?;
|
||||||
|
let pos_embed = Embedding::new(pos_embed_weight, cfg.hidden_size);
|
||||||
|
|
||||||
|
let blocks_vb = vb.pp("blocks");
|
||||||
|
let mut blocks = Vec::with_capacity(cfg.depth);
|
||||||
|
for i in 0..cfg.depth {
|
||||||
|
blocks.push(
|
||||||
|
VisionBlock::load(&cfg, &blocks_vb.pp(i))
|
||||||
|
.with_context(|| format!("load vision block {i}"))?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let merger = VisionMerger::load(&cfg, &vb.pp("merger")).context("load vision merger")?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
patch_embed,
|
||||||
|
pos_embed,
|
||||||
|
blocks,
|
||||||
|
merger,
|
||||||
|
config: cfg,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn config(&self) -> &VisionConfig {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of LM tokens this tower emits for an `(H, W)` pixel
|
||||||
|
/// image after the merger. Equal to
|
||||||
|
/// `(H / patch_size / spatial_merge_size) * (W / patch_size / spatial_merge_size)`.
|
||||||
|
pub fn lm_tokens_for(&self, h: u32, w: u32) -> usize {
|
||||||
|
let m = self.config.spatial_merge_size;
|
||||||
|
let patch = self.config.patch_size;
|
||||||
|
let gh = (h as usize) / patch / m;
|
||||||
|
let gw = (w as usize) / patch / m;
|
||||||
|
gh * gw * LM_TOKENS_PER_MERGE_GROUP
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode one image.
|
||||||
|
///
|
||||||
|
/// `image`: row-major `(3, H, W)` f32 tensor on `self.device`,
|
||||||
|
/// already normalised by `preprocess::preprocess`. Both `H` and
|
||||||
|
/// `W` must be multiples of `patch_size * spatial_merge_size`.
|
||||||
|
///
|
||||||
|
/// Returns `(N_lm, out_hidden_size)` — LM-side image tokens
|
||||||
|
/// ready to splice into the language model's input embeddings.
|
||||||
|
pub fn forward(&self, image: &Tensor) -> Result<Tensor> {
|
||||||
|
let (c, h, w) = image.dims3()?;
|
||||||
|
anyhow::ensure!(
|
||||||
|
c == self.config.in_channels,
|
||||||
|
"image must have {} channels, got {c}",
|
||||||
|
self.config.in_channels
|
||||||
|
);
|
||||||
|
let patch = self.config.patch_size;
|
||||||
|
anyhow::ensure!(
|
||||||
|
h.is_multiple_of(patch) && w.is_multiple_of(patch),
|
||||||
|
"image dims must be multiples of patch_size={patch}; got ({h}, {w})"
|
||||||
|
);
|
||||||
|
let gh = h / patch;
|
||||||
|
let gw = w / patch;
|
||||||
|
let n_patches = gh * gw;
|
||||||
|
anyhow::ensure!(
|
||||||
|
n_patches <= self.config.num_position_embeddings,
|
||||||
|
"patch count {n_patches} exceeds pos_embed budget {}",
|
||||||
|
self.config.num_position_embeddings
|
||||||
|
);
|
||||||
|
|
||||||
|
// Add batch axis for conv: (1, 3, H, W) → (1, hidden, gh, gw)
|
||||||
|
// → (hidden, gh, gw) → permute to (gh, gw, hidden) → flatten to (N, hidden)
|
||||||
|
let x = image.unsqueeze(0)?.to_dtype(self.dtype)?;
|
||||||
|
let x = self.patch_embed.forward(&x)?;
|
||||||
|
let x = x.squeeze(0)?;
|
||||||
|
let x = x.permute((1, 2, 0))?.contiguous()?;
|
||||||
|
let x = x.reshape((n_patches, self.config.hidden_size))?;
|
||||||
|
|
||||||
|
// Add learned positional embeddings (sequential indices for
|
||||||
|
// Stage A's fixed-resolution path; full 2D positional logic
|
||||||
|
// lands with variable resolution, issue #14).
|
||||||
|
let positions = Tensor::arange(0u32, n_patches as u32, &self.device)?;
|
||||||
|
let pos = self.pos_embed.forward(&positions)?;
|
||||||
|
let mut x = x.add(&pos)?;
|
||||||
|
|
||||||
|
for (i, block) in self.blocks.iter().enumerate() {
|
||||||
|
x = block
|
||||||
|
.forward(&x)
|
||||||
|
.with_context(|| format!("vision block {i}"))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// (n_patches, hidden) → (gh, gw, hidden) for the merger.
|
||||||
|
let x = x.reshape((gh, gw, self.config.hidden_size))?;
|
||||||
|
self.merger.forward(&x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Manually load a candle_nn LayerNorm from a ShardedVarBuilder.
|
||||||
|
/// candle_nn's `layer_norm` builder takes `crate::VarBuilder`, not
|
||||||
|
/// `ShardedVarBuilder`, so the existing arch modules in this crate
|
||||||
|
/// uniformly do the manual load + struct construction pattern (see
|
||||||
|
/// `full_attn::load_linear_no_bias`). We follow suit here.
|
||||||
|
fn layer_norm(vb: ShardedVarBuilder, size: usize) -> Result<LayerNorm> {
|
||||||
|
let weight = vb
|
||||||
|
.get(size, "weight")
|
||||||
|
.with_context(|| format!("load LayerNorm.weight at '{}'", vb.prefix()))?;
|
||||||
|
let bias = vb
|
||||||
|
.get(size, "bias")
|
||||||
|
.with_context(|| format!("load LayerNorm.bias at '{}'", vb.prefix()))?;
|
||||||
|
Ok(LayerNorm::new(weight, bias, LAYER_NORM_EPS))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Manually load a candle_nn Linear (with bias) from a
|
||||||
|
/// ShardedVarBuilder. Same rationale as `layer_norm` above.
|
||||||
|
fn linear(vb: ShardedVarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
|
||||||
|
let weight = vb
|
||||||
|
.get((out_dim, in_dim), "weight")
|
||||||
|
.with_context(|| format!("load Linear.weight at '{}'", vb.prefix()))?;
|
||||||
|
let bias = vb
|
||||||
|
.get(out_dim, "bias")
|
||||||
|
.with_context(|| format!("load Linear.bias at '{}'", vb.prefix()))?;
|
||||||
|
Ok(Linear::new(weight, Some(bias)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// PyTorch's `gelu_pytorch_tanh` approximation — what the Qwen3.6
|
||||||
|
/// vision tower's `hidden_act` specifies. candle's `Tensor::gelu`
|
||||||
|
/// uses the exact erf-based GELU, so we compute the tanh
|
||||||
|
/// approximation explicitly:
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||||
|
/// ```
|
||||||
|
fn gelu_tanh(x: &Tensor) -> Result<Tensor> {
|
||||||
|
// sqrt(2 / pi) = 0.7978845608028654
|
||||||
|
const COEFF: f64 = 0.7978845608028654;
|
||||||
|
const KAPPA: f64 = 0.044715;
|
||||||
|
let x3 = x.powf(3.0)?;
|
||||||
|
let inner = (x + (x3 * KAPPA)?)?;
|
||||||
|
let inner = (inner * COEFF)?;
|
||||||
|
let t = inner.tanh()?;
|
||||||
|
let one_plus_t = (t + 1.0)?;
|
||||||
|
let out = (x * 0.5)?;
|
||||||
|
let out = out.broadcast_mul(&one_plus_t)?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use candle_core::{DType, Device};
|
||||||
|
|
||||||
|
/// Build a tiny VisionConfig usable on CPU with random weights.
|
||||||
|
/// Match the Qwen3.6 shape relations (depth-N stack, hidden mod
|
||||||
|
/// num_heads, intermediate_size > hidden_size) but with small
|
||||||
|
/// dims so tests run in milliseconds.
|
||||||
|
fn tiny_config() -> VisionConfig {
|
||||||
|
VisionConfig {
|
||||||
|
depth: 2,
|
||||||
|
hidden_size: 32,
|
||||||
|
intermediate_size: 64,
|
||||||
|
num_heads: 4,
|
||||||
|
num_position_embeddings: 64,
|
||||||
|
patch_size: 4,
|
||||||
|
temporal_patch_size: 2,
|
||||||
|
spatial_merge_size: 2,
|
||||||
|
in_channels: 3,
|
||||||
|
out_hidden_size: 48,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hand-construct a VisionTower with random weights. This is the
|
||||||
|
/// same trick `linear_attn::tests::forward_smoke_with_tiny_dimensions`
|
||||||
|
/// uses — bypass the safetensors-backed `ShardedVarBuilder` path
|
||||||
|
/// (which can't be built from in-memory tensors) and assemble the
|
||||||
|
/// struct fields directly. The real `VisionTower::load` is
|
||||||
|
/// exercised by the cuda-integration smoke test in Stage A6.
|
||||||
|
fn tiny_tower(cfg: &VisionConfig) -> VisionTower {
|
||||||
|
let device = Device::Cpu;
|
||||||
|
let dtype = DType::F32;
|
||||||
|
let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &device).unwrap();
|
||||||
|
let ones = |shape: &[usize]| Tensor::ones(shape, dtype, &device).unwrap();
|
||||||
|
let randn = |shape: &[usize]| Tensor::randn(0_f32, 0.02, shape, &device).unwrap();
|
||||||
|
|
||||||
|
let patch_embed = Conv2d::new(
|
||||||
|
randn(&[
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.in_channels,
|
||||||
|
cfg.patch_size,
|
||||||
|
cfg.patch_size,
|
||||||
|
]),
|
||||||
|
Some(zeros(&[cfg.hidden_size])),
|
||||||
|
Conv2dConfig {
|
||||||
|
stride: cfg.patch_size,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
);
|
||||||
|
let pos_embed = Embedding::new(
|
||||||
|
randn(&[cfg.num_position_embeddings, cfg.hidden_size]),
|
||||||
|
cfg.hidden_size,
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut blocks = Vec::with_capacity(cfg.depth);
|
||||||
|
for _ in 0..cfg.depth {
|
||||||
|
let head_dim = cfg.hidden_size / cfg.num_heads;
|
||||||
|
blocks.push(VisionBlock {
|
||||||
|
norm1: LayerNorm::new(
|
||||||
|
ones(&[cfg.hidden_size]),
|
||||||
|
zeros(&[cfg.hidden_size]),
|
||||||
|
LAYER_NORM_EPS,
|
||||||
|
),
|
||||||
|
qkv: Linear::new(
|
||||||
|
randn(&[3 * cfg.hidden_size, cfg.hidden_size]),
|
||||||
|
Some(zeros(&[3 * cfg.hidden_size])),
|
||||||
|
),
|
||||||
|
proj: Linear::new(
|
||||||
|
randn(&[cfg.hidden_size, cfg.hidden_size]),
|
||||||
|
Some(zeros(&[cfg.hidden_size])),
|
||||||
|
),
|
||||||
|
norm2: LayerNorm::new(
|
||||||
|
ones(&[cfg.hidden_size]),
|
||||||
|
zeros(&[cfg.hidden_size]),
|
||||||
|
LAYER_NORM_EPS,
|
||||||
|
),
|
||||||
|
fc1: Linear::new(
|
||||||
|
randn(&[cfg.intermediate_size, cfg.hidden_size]),
|
||||||
|
Some(zeros(&[cfg.intermediate_size])),
|
||||||
|
),
|
||||||
|
fc2: Linear::new(
|
||||||
|
randn(&[cfg.hidden_size, cfg.intermediate_size]),
|
||||||
|
Some(zeros(&[cfg.hidden_size])),
|
||||||
|
),
|
||||||
|
num_heads: cfg.num_heads,
|
||||||
|
head_dim,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let merge_input_dim = cfg.hidden_size * cfg.spatial_merge_size * cfg.spatial_merge_size;
|
||||||
|
let merger = VisionMerger {
|
||||||
|
norm: LayerNorm::new(
|
||||||
|
ones(&[cfg.hidden_size]),
|
||||||
|
zeros(&[cfg.hidden_size]),
|
||||||
|
LAYER_NORM_EPS,
|
||||||
|
),
|
||||||
|
fc1: Linear::new(
|
||||||
|
randn(&[merge_input_dim, merge_input_dim]),
|
||||||
|
Some(zeros(&[merge_input_dim])),
|
||||||
|
),
|
||||||
|
fc2: Linear::new(
|
||||||
|
randn(&[cfg.out_hidden_size, merge_input_dim]),
|
||||||
|
Some(zeros(&[cfg.out_hidden_size])),
|
||||||
|
),
|
||||||
|
merge_input_dim,
|
||||||
|
spatial_merge_size: cfg.spatial_merge_size,
|
||||||
|
};
|
||||||
|
|
||||||
|
VisionTower {
|
||||||
|
patch_embed,
|
||||||
|
pos_embed,
|
||||||
|
blocks,
|
||||||
|
merger,
|
||||||
|
config: cfg.clone(),
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forward_with_random_weights_produces_finite_output() {
|
||||||
|
let cfg = tiny_config();
|
||||||
|
let tower = tiny_tower(&cfg);
|
||||||
|
|
||||||
|
// 16×16 image at patch_size=4 → 4×4 patches → after 2×2
|
||||||
|
// merge → 2×2 = 4 LM tokens of dim out_hidden_size.
|
||||||
|
let image = Tensor::randn(0_f32, 1.0, (3, 16, 16), &Device::Cpu).unwrap();
|
||||||
|
let out = tower.forward(&image).expect("forward");
|
||||||
|
let (n_lm, hidden) = out.dims2().unwrap();
|
||||||
|
assert_eq!(n_lm, 4);
|
||||||
|
assert_eq!(hidden, cfg.out_hidden_size);
|
||||||
|
|
||||||
|
// No NaN/Inf
|
||||||
|
let values: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
|
||||||
|
assert!(
|
||||||
|
values.iter().all(|v| v.is_finite()),
|
||||||
|
"forward must produce finite values"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lm_token_count_matches_grid() {
|
||||||
|
let cfg = tiny_config();
|
||||||
|
let tower = tiny_tower(&cfg);
|
||||||
|
// 16x16 image → 4x4 patches → 2x2 = 4 LM tokens
|
||||||
|
assert_eq!(tower.lm_tokens_for(16, 16), 4);
|
||||||
|
// 32x32 image → 8x8 patches → 4x4 = 16 LM tokens
|
||||||
|
assert_eq!(tower.lm_tokens_for(32, 32), 16);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rejects_image_with_dims_not_multiple_of_patch() {
|
||||||
|
let cfg = tiny_config();
|
||||||
|
let tower = tiny_tower(&cfg);
|
||||||
|
let image = Tensor::randn(0_f32, 1.0, (3, 17, 17), &Device::Cpu).unwrap();
|
||||||
|
let err = tower.forward(&image).unwrap_err();
|
||||||
|
assert!(format!("{err:#}").contains("patch_size"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rejects_image_with_wrong_channel_count() {
|
||||||
|
let cfg = tiny_config();
|
||||||
|
let tower = tiny_tower(&cfg);
|
||||||
|
let image = Tensor::randn(0_f32, 1.0, (4, 16, 16), &Device::Cpu).unwrap();
|
||||||
|
let err = tower.forward(&image).unwrap_err();
|
||||||
|
assert!(format!("{err:#}").contains("channels"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn gelu_tanh_matches_known_values() {
|
||||||
|
// Reference values for gelu_pytorch_tanh from PyTorch:
|
||||||
|
// gelu_tanh(0.0) = 0.0
|
||||||
|
// gelu_tanh(1.0) ≈ 0.8411920071
|
||||||
|
// gelu_tanh(-1.0) ≈ -0.1588079929
|
||||||
|
let x = Tensor::new(&[0.0_f32, 1.0, -1.0], &Device::Cpu).unwrap();
|
||||||
|
let y = gelu_tanh(&x).unwrap();
|
||||||
|
let v: Vec<f32> = y.to_vec1().unwrap();
|
||||||
|
assert!((v[0]).abs() < 1e-6, "gelu_tanh(0) ≈ 0, got {}", v[0]);
|
||||||
|
assert!(
|
||||||
|
(v[1] - 0.841_192_f32).abs() < 1e-5,
|
||||||
|
"gelu_tanh(1) ≈ 0.84119, got {}",
|
||||||
|
v[1]
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
(v[2] - -0.158_808_f32).abs() < 1e-5,
|
||||||
|
"gelu_tanh(-1) ≈ -0.15881, got {}",
|
||||||
|
v[2]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -332,6 +332,37 @@ impl ModelArch {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Encode a preprocessed image into LM-side token embeddings via
|
||||||
|
/// the loaded vision tower. Stage A5.
|
||||||
|
///
|
||||||
|
/// `image`: device-resident `(C, H, W)` f32 tensor — caller has
|
||||||
|
/// already preprocessed via `harness::preprocess::preprocess` and
|
||||||
|
/// uploaded to the worker's device. Returns
|
||||||
|
/// `(N_lm_tokens, hidden_size)`.
|
||||||
|
///
|
||||||
|
/// Errors when the loaded architecture has no vision tower
|
||||||
|
/// (text-only checkpoint, or architecture that doesn't support
|
||||||
|
/// vision at all). The HTTP layer maps this to a 400 with
|
||||||
|
/// `vision_unsupported` so clients see a clean rejection rather
|
||||||
|
/// than a confident text-only hallucination.
|
||||||
|
pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
ModelArch::Qwen3_5Dense(m) => m
|
||||||
|
.vision()
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"encode_image: this Qwen3.6 checkpoint was loaded without a vision \
|
||||||
|
tower (config.json::vision_config absent or weights missing)"
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
.forward(image),
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"encode_image: architecture {} has no vision tower",
|
||||||
|
std::any::type_name_of_val(other)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Squeeze any leading singleton dims off the logits tensor so the
|
/// Squeeze any leading singleton dims off the logits tensor so the
|
||||||
|
|||||||
@@ -158,6 +158,17 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
|||||||
let result = forward_logits(&mut state, handle, &tokens, offset);
|
let result = forward_logits(&mut state, handle, &tokens, offset);
|
||||||
let _ = reply.send(result);
|
let _ = reply.send(result);
|
||||||
}
|
}
|
||||||
|
Job::EncodeImage {
|
||||||
|
handle,
|
||||||
|
pixels,
|
||||||
|
c,
|
||||||
|
h,
|
||||||
|
w,
|
||||||
|
reply,
|
||||||
|
} => {
|
||||||
|
let result = encode_image(&mut state, handle, pixels, c, h, w);
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
Job::NcclInit {
|
Job::NcclInit {
|
||||||
cfg,
|
cfg,
|
||||||
comm_id_hex,
|
comm_id_hex,
|
||||||
@@ -740,6 +751,49 @@ fn forward_logits(
|
|||||||
Ok(values)
|
Ok(values)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Run the vision tower on a single preprocessed image. Stage A5.
|
||||||
|
///
|
||||||
|
/// `pixels` is a row-major `(c, h, w)` f32 image that the async-side
|
||||||
|
/// `harness::preprocess` produced. We reconstruct the tensor on the
|
||||||
|
/// worker's device (the same device the model was loaded against),
|
||||||
|
/// call `arch.encode_image`, and copy the resulting
|
||||||
|
/// `(N_lm_tokens, hidden_size)` embedding back to CPU f32.
|
||||||
|
///
|
||||||
|
/// Returns the flattened embedding as a `Vec<f32>` — the caller knows
|
||||||
|
/// the LM-side token count from `VisionTower::lm_tokens_for(h, w)`
|
||||||
|
/// and reshapes accordingly. Stage B introduces a device-resident
|
||||||
|
/// embedding-slab variant that avoids this round-trip when the next
|
||||||
|
/// forward call needs the result.
|
||||||
|
fn encode_image(
|
||||||
|
state: &mut DeviceWorkerState,
|
||||||
|
handle: ArchHandle,
|
||||||
|
pixels: Vec<f32>,
|
||||||
|
c: usize,
|
||||||
|
h: usize,
|
||||||
|
w: usize,
|
||||||
|
) -> anyhow::Result<Vec<f32>> {
|
||||||
|
use candle_core::{DType, Tensor};
|
||||||
|
|
||||||
|
anyhow::ensure!(
|
||||||
|
pixels.len() == c * h * w,
|
||||||
|
"EncodeImage: pixels length {} does not match shape ({c}, {h}, {w})",
|
||||||
|
pixels.len()
|
||||||
|
);
|
||||||
|
let image = Tensor::from_vec(pixels, (c, h, w), &state.device)?;
|
||||||
|
|
||||||
|
let arch = state
|
||||||
|
.models
|
||||||
|
.get(&handle)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("EncodeImage: no model for handle {}", handle.0))?;
|
||||||
|
|
||||||
|
let embed = arch.encode_image(&image)?;
|
||||||
|
let values = embed
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.flatten_all()?
|
||||||
|
.to_vec1::<f32>()?;
|
||||||
|
Ok(values)
|
||||||
|
}
|
||||||
|
|
||||||
/// Reply to a job with the poisoned-worker error. Used when the worker
|
/// Reply to a job with the poisoned-worker error. Used when the worker
|
||||||
/// has flipped into drain-only mode after a CUDA driver error.
|
/// has flipped into drain-only mode after a CUDA driver error.
|
||||||
///
|
///
|
||||||
@@ -773,6 +827,9 @@ fn drain_poisoned(job: Job, device_index: u32) {
|
|||||||
Job::ForwardLogits { reply, .. } => {
|
Job::ForwardLogits { reply, .. } => {
|
||||||
let _ = reply.send(Err(err()));
|
let _ = reply.send(Err(err()));
|
||||||
}
|
}
|
||||||
|
Job::EncodeImage { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
Job::NcclInit { reply, .. } => {
|
Job::NcclInit { reply, .. } => {
|
||||||
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
||||||
kind: "device_worker_poisoned".into(),
|
kind: "device_worker_poisoned".into(),
|
||||||
|
|||||||
@@ -94,6 +94,31 @@ pub enum Job {
|
|||||||
offset: usize,
|
offset: usize,
|
||||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
},
|
},
|
||||||
|
/// Encode one image through the model's vision tower. Stage A5 of
|
||||||
|
/// the vision plan (`doc/vision-qwen3_6-spec.md`).
|
||||||
|
///
|
||||||
|
/// `pixels` is the CPU-side preprocessed image tensor in row-major
|
||||||
|
/// `(C, H, W)` f32 layout — what `harness::preprocess::preprocess`
|
||||||
|
/// produces. `c`, `h`, `w` carry the shape since `Vec<f32>` itself
|
||||||
|
/// is rank-1. The handler reconstructs the tensor on the worker's
|
||||||
|
/// device, runs `VisionTower::forward`, copies the resulting
|
||||||
|
/// `(N_lm_tokens, hidden_size)` embedding back to CPU as a flat
|
||||||
|
/// `Vec<f32>` (the caller knows the expected shape from
|
||||||
|
/// `VisionTower::lm_tokens_for(h, w) * hidden_size`).
|
||||||
|
///
|
||||||
|
/// Mirrors the `ForwardLogits` "tensors don't escape" invariant —
|
||||||
|
/// device-side image embeddings are dropped at handler return.
|
||||||
|
/// Stage B will introduce a follow-up variant that keeps the
|
||||||
|
/// embeddings device-resident and references them from the next
|
||||||
|
/// `ForwardLogits` call, avoiding the round-trip copy.
|
||||||
|
EncodeImage {
|
||||||
|
handle: ArchHandle,
|
||||||
|
pixels: Vec<f32>,
|
||||||
|
c: usize,
|
||||||
|
h: usize,
|
||||||
|
w: usize,
|
||||||
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
|
},
|
||||||
/// Initialize the leader's NCCL communicator. The worker's
|
/// Initialize the leader's NCCL communicator. The worker's
|
||||||
/// `NcclState` mints the `Comm` here so its underlying
|
/// `NcclState` mints the `Comm` here so its underlying
|
||||||
/// `ncclComm_t` and `CudaContext` live on the same thread as
|
/// `ncclComm_t` and `CudaContext` live on the same thread as
|
||||||
|
|||||||
@@ -313,6 +313,49 @@ impl DeviceWorkerHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Encode a preprocessed image through the model's vision tower
|
||||||
|
/// and return the resulting LM-side image embeddings as a
|
||||||
|
/// flattened CPU `Vec<f32>`. Stage A5.
|
||||||
|
///
|
||||||
|
/// `pixels` is the row-major `(c, h, w)` f32 image —
|
||||||
|
/// `harness::preprocess::preprocess` produces this exact shape.
|
||||||
|
/// The caller knows the expected output length from
|
||||||
|
/// `VisionTower::lm_tokens_for(h, w) * hidden_size` and reshapes
|
||||||
|
/// accordingly.
|
||||||
|
pub async fn encode_image(
|
||||||
|
&self,
|
||||||
|
handle: ArchHandle,
|
||||||
|
pixels: Vec<f32>,
|
||||||
|
c: usize,
|
||||||
|
h: usize,
|
||||||
|
w: usize,
|
||||||
|
) -> Result<Vec<f32>, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::EncodeImage {
|
||||||
|
handle,
|
||||||
|
pixels,
|
||||||
|
c,
|
||||||
|
h,
|
||||||
|
w,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Initialise the leader's NCCL communicator. The reply uses
|
/// Initialise the leader's NCCL communicator. The reply uses
|
||||||
/// `WorkerResponse` (same shape subprocess workers use over stdio
|
/// `WorkerResponse` (same shape subprocess workers use over stdio
|
||||||
/// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader +
|
/// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader +
|
||||||
@@ -569,6 +612,37 @@ mod tests {
|
|||||||
handle.shutdown().expect("shutdown ok");
|
handle.shutdown().expect("shutdown ok");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Stage A5: confirm the EncodeImage job round-trips through the
|
||||||
|
/// worker channel. We don't have a real loaded model in the slab
|
||||||
|
/// here, so the dispatch handler returns the
|
||||||
|
/// "no model for handle" error — which is exactly what we want to
|
||||||
|
/// see, since it proves the message routed through the channel
|
||||||
|
/// and reached the handler. Real-weights validation lives in the
|
||||||
|
/// Stage A7 / Stage B post-deploy smoke on beast.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn encode_image_routes_to_dispatch_and_errors_on_unknown_handle() {
|
||||||
|
use crate::harness::device_worker::jobs::ArchHandle;
|
||||||
|
|
||||||
|
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||||
|
let fake_arch = ArchHandle(99_999);
|
||||||
|
// (3, 4, 4) fake image — minimal payload, gets reconstructed
|
||||||
|
// on the worker before the handler errors out on the unknown
|
||||||
|
// ArchHandle lookup.
|
||||||
|
let pixels = vec![0.0_f32; 3 * 4 * 4];
|
||||||
|
let result = handle.encode_image(fake_arch, pixels, 3, 4, 4).await;
|
||||||
|
match result {
|
||||||
|
Err(WorkerError::Job(e)) => {
|
||||||
|
let msg = format!("{e:#}");
|
||||||
|
assert!(
|
||||||
|
msg.contains("EncodeImage: no model for handle"),
|
||||||
|
"expected unknown-handle error, got: {msg}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
other => panic!("expected Job(Err), got {other:?}"),
|
||||||
|
}
|
||||||
|
handle.shutdown().expect("shutdown ok");
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn shutdown_drains_pending_jobs() {
|
async fn shutdown_drains_pending_jobs() {
|
||||||
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ pub mod candle;
|
|||||||
pub mod chat_template;
|
pub mod chat_template;
|
||||||
pub mod device_worker;
|
pub mod device_worker;
|
||||||
pub mod preflight;
|
pub mod preflight;
|
||||||
|
pub mod preprocess;
|
||||||
pub mod tp;
|
pub mod tp;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
|||||||
255
crates/neuron/src/harness/preprocess.rs
Normal file
255
crates/neuron/src/harness/preprocess.rs
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
//! Image preprocessing for vision-capable models.
|
||||||
|
//!
|
||||||
|
//! Decodes `data:image/...;base64,...` URIs from OpenAI-style
|
||||||
|
//! `image_url` content parts into the patch tensors a candle vision
|
||||||
|
//! tower expects. Stage A ships **fixed resolution** — every image
|
||||||
|
//! is resized to the same target dimensions (default 448×448 for
|
||||||
|
//! Qwen3.6, configurable per-call) so the patch count is constant
|
||||||
|
//! per image. Variable resolution per [Qwen2VL convention] is tracked
|
||||||
|
//! as issue #14.
|
||||||
|
//!
|
||||||
|
//! Spec reference: `doc/vision-qwen3_6-spec.md` — preprocessor
|
||||||
|
//! section.
|
||||||
|
//!
|
||||||
|
//! Normalisation: pixel value `p ∈ [0, 255]` becomes
|
||||||
|
//! `(p/255 - mean) / std`. Qwen3.6's preprocessor_config.json
|
||||||
|
//! specifies `image_mean = image_std = [0.5, 0.5, 0.5]`, which
|
||||||
|
//! simplifies to `2p/255 - 1` mapping `[0,255]` → `[-1, 1]`. We
|
||||||
|
//! still parameterise mean/std so the same code generalises to other
|
||||||
|
//! VL families (Qwen2-VL uses imagenet stats, for instance).
|
||||||
|
//!
|
||||||
|
//! Pipeline (per image):
|
||||||
|
//! 1. data: URI → base64 decode → bytes
|
||||||
|
//! 2. bytes → image::DynamicImage (PNG/JPEG/WebP/etc)
|
||||||
|
//! 3. resize_exact to target H×W (pixel space)
|
||||||
|
//! 4. RGB→f32, normalise per mean/std
|
||||||
|
//! 5. layout to (C, H, W) tensor
|
||||||
|
//!
|
||||||
|
//! Patchification (cutting the HxW tensor into `patch_size` blocks)
|
||||||
|
//! happens inside the vision tower's `patch_embed` conv, so this
|
||||||
|
//! module stops at "preprocessed RGB f32 tensor."
|
||||||
|
|
||||||
|
use anyhow::{Context, Result, anyhow};
|
||||||
|
use base64::Engine;
|
||||||
|
use image::DynamicImage;
|
||||||
|
use image::imageops::FilterType;
|
||||||
|
|
||||||
|
/// Preprocessing target. Captures the resize dimensions and the
|
||||||
|
/// channel-wise normalisation constants from the model's
|
||||||
|
/// `preprocessor_config.json`. Stage A ships a single `qwen3_6()`
|
||||||
|
/// constructor for fixed-resolution Qwen3.6 preprocessing; other
|
||||||
|
/// models can ship their own profile when added.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PreprocessProfile {
|
||||||
|
pub target_height: u32,
|
||||||
|
pub target_width: u32,
|
||||||
|
pub image_mean: [f32; 3],
|
||||||
|
pub image_std: [f32; 3],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PreprocessProfile {
|
||||||
|
/// Stage A profile for Qwen3.6. Resize to 448×448, normalise to
|
||||||
|
/// `[-1, 1]` via mean=std=0.5. Fits within the model's
|
||||||
|
/// `num_position_embeddings=2304` budget at 28×28 = 784 patches
|
||||||
|
/// before merging.
|
||||||
|
pub fn qwen3_6() -> Self {
|
||||||
|
Self {
|
||||||
|
target_height: 448,
|
||||||
|
target_width: 448,
|
||||||
|
image_mean: [0.5, 0.5, 0.5],
|
||||||
|
image_std: [0.5, 0.5, 0.5],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-channel CHW tensor length: 3 * H * W.
|
||||||
|
pub fn pixels_chw(&self) -> usize {
|
||||||
|
3 * (self.target_height as usize) * (self.target_width as usize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode a `data:image/...;base64,...` URI into an in-memory image.
|
||||||
|
///
|
||||||
|
/// Accepts the OpenAI Chat Completions `image_url` shape — a string
|
||||||
|
/// URL with `data:` scheme and base64 payload. The MIME type is read
|
||||||
|
/// from the URI for diagnostics but `image::load_from_memory` sniffs
|
||||||
|
/// the format from the bytes themselves, so the MIME is advisory.
|
||||||
|
///
|
||||||
|
/// Bare `http(s)://` URLs are explicitly rejected here — fetching
|
||||||
|
/// them from a vision-model server is a fingerprintable behaviour
|
||||||
|
/// (server-side request forgery, infinite recursion if the URL
|
||||||
|
/// points at the gateway itself, etc.). Clients that want remote
|
||||||
|
/// images can fetch them and pass base64 themselves.
|
||||||
|
pub fn decode_data_uri(uri: &str) -> Result<DynamicImage> {
|
||||||
|
let after_scheme = uri
|
||||||
|
.strip_prefix("data:")
|
||||||
|
.ok_or_else(|| anyhow!("image_url must use data: scheme; got {uri:.40}…"))?;
|
||||||
|
let (meta, payload) = after_scheme
|
||||||
|
.split_once(',')
|
||||||
|
.ok_or_else(|| anyhow!("malformed data URI: missing ',' separator"))?;
|
||||||
|
if !meta.contains(";base64") {
|
||||||
|
anyhow::bail!(
|
||||||
|
"data URI must use base64 encoding (got '{meta}'); raw URL-encoded payloads not supported"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let bytes = base64::engine::general_purpose::STANDARD
|
||||||
|
.decode(payload.trim())
|
||||||
|
.context("base64-decode image data URI payload")?;
|
||||||
|
image::load_from_memory(&bytes).context("decode image bytes (PNG/JPEG/WebP/etc)")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resize and normalise an image into a `(3, H, W)` row-major
|
||||||
|
/// `Vec<f32>` ready to hand to the vision tower's `patch_embed`
|
||||||
|
/// conv.
|
||||||
|
///
|
||||||
|
/// Uses bilinear resampling — Qwen2-VL's reference uses bicubic, but
|
||||||
|
/// bilinear is what the candle ecosystem standardises on and is
|
||||||
|
/// faster on CPU. Quality difference is marginal for downstream
|
||||||
|
/// vision-encoder consumption. The numerical-validation issue (#15)
|
||||||
|
/// will quantify any discrepancy.
|
||||||
|
pub fn preprocess(img: &DynamicImage, profile: &PreprocessProfile) -> Vec<f32> {
|
||||||
|
let rgb = img
|
||||||
|
.resize_exact(
|
||||||
|
profile.target_width,
|
||||||
|
profile.target_height,
|
||||||
|
FilterType::Triangle,
|
||||||
|
)
|
||||||
|
.to_rgb8();
|
||||||
|
let h = profile.target_height as usize;
|
||||||
|
let w = profile.target_width as usize;
|
||||||
|
let mut out = vec![0.0_f32; 3 * h * w];
|
||||||
|
// Row-major (C, H, W). Candle's Conv2d expects NCHW, so this is
|
||||||
|
// the natural layout — the caller stacks `n` of these along the
|
||||||
|
// batch axis as needed.
|
||||||
|
for c in 0..3 {
|
||||||
|
let mean = profile.image_mean[c];
|
||||||
|
let std = profile.image_std[c];
|
||||||
|
for y in 0..h {
|
||||||
|
for x in 0..w {
|
||||||
|
let pixel = rgb.get_pixel(x as u32, y as u32);
|
||||||
|
let raw = pixel[c] as f32 / 255.0;
|
||||||
|
out[c * h * w + y * w + x] = (raw - mean) / std;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Combined helper: decode + preprocess in one call. Most call
|
||||||
|
/// sites just want the final tensor; the two-step path exists for
|
||||||
|
/// callers (tests, future video preprocessing) that need the
|
||||||
|
/// intermediate `DynamicImage`.
|
||||||
|
pub fn preprocess_data_uri(uri: &str, profile: &PreprocessProfile) -> Result<Vec<f32>> {
|
||||||
|
let img = decode_data_uri(uri)?;
|
||||||
|
Ok(preprocess(&img, profile))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use image::{ImageBuffer, Rgb};
|
||||||
|
|
||||||
|
/// A 1×1 red PNG, hand-built. Matches the well-known smallest
|
||||||
|
/// valid PNG we use in tests/curl examples elsewhere.
|
||||||
|
const ONE_BY_ONE_RED_PNG_B64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==";
|
||||||
|
|
||||||
|
fn red_png_uri() -> String {
|
||||||
|
format!("data:image/png;base64,{ONE_BY_ONE_RED_PNG_B64}")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decodes_well_formed_png_data_uri() {
|
||||||
|
let img = decode_data_uri(&red_png_uri()).expect("decode 1x1 png");
|
||||||
|
assert_eq!(img.width(), 1);
|
||||||
|
assert_eq!(img.height(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rejects_non_data_scheme() {
|
||||||
|
let err = decode_data_uri("https://example.com/cat.jpg")
|
||||||
|
.expect_err("http(s) URLs must be rejected");
|
||||||
|
assert!(format!("{err:#}").contains("data:"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rejects_malformed_uri_without_comma() {
|
||||||
|
let err = decode_data_uri("data:image/png;base64").unwrap_err();
|
||||||
|
assert!(format!("{err:#}").contains("','"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rejects_non_base64_payload() {
|
||||||
|
let err = decode_data_uri("data:image/png,raw-bytes-here").unwrap_err();
|
||||||
|
assert!(format!("{err:#}").contains("base64"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rejects_bad_base64_payload() {
|
||||||
|
let err = decode_data_uri("data:image/png;base64,not!valid!base64!").unwrap_err();
|
||||||
|
assert!(format!("{err:#}").contains("base64"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rejects_garbage_image_bytes() {
|
||||||
|
// Valid base64 ("Hello World!"), invalid image bytes.
|
||||||
|
let err = decode_data_uri("data:image/png;base64,SGVsbG8gV29ybGQh").unwrap_err();
|
||||||
|
assert!(
|
||||||
|
format!("{err:#}").contains("decode image"),
|
||||||
|
"should fail at image-decode step"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn preprocess_red_image_produces_correct_shape_and_values() {
|
||||||
|
let profile = PreprocessProfile::qwen3_6();
|
||||||
|
// Build a tiny pure-red image directly, skipping data: URI
|
||||||
|
// decoding so this test isolates the resize+normalise path.
|
||||||
|
let img: ImageBuffer<Rgb<u8>, Vec<u8>> = ImageBuffer::from_pixel(2, 2, Rgb([255, 0, 0]));
|
||||||
|
let dyn_img = DynamicImage::ImageRgb8(img);
|
||||||
|
let out = preprocess(&dyn_img, &profile);
|
||||||
|
|
||||||
|
assert_eq!(out.len(), profile.pixels_chw());
|
||||||
|
// After mean=0.5, std=0.5: red channel (255/255=1.0) → (1.0 - 0.5)/0.5 = 1.0
|
||||||
|
// green/blue (0.0) → (0.0 - 0.5)/0.5 = -1.0
|
||||||
|
let h = profile.target_height as usize;
|
||||||
|
let w = profile.target_width as usize;
|
||||||
|
assert!(
|
||||||
|
(out[0] - 1.0).abs() < 1e-5,
|
||||||
|
"R[0] should be 1.0, got {}",
|
||||||
|
out[0]
|
||||||
|
);
|
||||||
|
assert!((out[h * w] - (-1.0)).abs() < 1e-5, "G[0] should be -1.0");
|
||||||
|
assert!(
|
||||||
|
(out[2 * h * w] - (-1.0)).abs() < 1e-5,
|
||||||
|
"B[0] should be -1.0"
|
||||||
|
);
|
||||||
|
// All values are finite
|
||||||
|
assert!(out.iter().all(|v| v.is_finite()), "no NaN/Inf in output");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn preprocess_data_uri_end_to_end() {
|
||||||
|
let profile = PreprocessProfile::qwen3_6();
|
||||||
|
let out = preprocess_data_uri(&red_png_uri(), &profile).expect("e2e preprocess");
|
||||||
|
assert_eq!(out.len(), profile.pixels_chw());
|
||||||
|
assert!(out.iter().all(|v| v.is_finite()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn preprocess_grayscale_image_promotes_to_rgb() {
|
||||||
|
let profile = PreprocessProfile::qwen3_6();
|
||||||
|
// 1x1 grayscale = 200 → after conversion to RGB, all three
|
||||||
|
// channels equal 200, normalised → (200/255 - 0.5)/0.5 ≈ 0.569
|
||||||
|
let gray = DynamicImage::ImageLuma8(ImageBuffer::from_pixel(1, 1, image::Luma([200])));
|
||||||
|
let out = preprocess(&gray, &profile);
|
||||||
|
let expected = ((200.0 / 255.0) - 0.5) / 0.5;
|
||||||
|
let h = profile.target_height as usize;
|
||||||
|
let w = profile.target_width as usize;
|
||||||
|
for c in 0..3 {
|
||||||
|
let v = out[c * h * w];
|
||||||
|
assert!(
|
||||||
|
(v - expected).abs() < 1e-3,
|
||||||
|
"channel {c}: expected {expected}, got {v}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
176
doc/vision-qwen3_6-spec.md
Normal file
176
doc/vision-qwen3_6-spec.md
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
# Qwen3.6-27B vision specification (Stage A0)
|
||||||
|
|
||||||
|
Sourced from beast's local cache on 2026-06-01:
|
||||||
|
`/archive3/llm-cache/models--Qwen--Qwen3.6-27B/snapshots/6a9e13bd6fc8f0983b9b99948120bc37f49c13e9/`.
|
||||||
|
|
||||||
|
Single source of truth for Stages A–D of the vision plan in
|
||||||
|
`~/.claude/plans/foamy-twirling-catmull.md`. Umbrella issue:
|
||||||
|
[#3](https://git.lair.cafe/helexa/cortex/issues/3).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Top-level shape
|
||||||
|
|
||||||
|
The model is a unified text+vision architecture (`Qwen3_5ForConditionalGeneration`,
|
||||||
|
`model_type: qwen3_5`) with three weight sections under a single safetensors
|
||||||
|
index. Counts from `model.safetensors.index.json`:
|
||||||
|
|
||||||
|
| Prefix | Tensors | Role |
|
||||||
|
|---|---|---|
|
||||||
|
| `model.language_model.*` | 850 | LM (currently loaded) |
|
||||||
|
| `model.visual.*` | 333 | Vision tower (currently filtered out at `arch/qwen3_5/mod.rs:228-230`) |
|
||||||
|
| `mtp.*` | 15 | Multi-token-prediction heads (filtered, out of scope) |
|
||||||
|
| `lm_head.weight` | 1 | LM head |
|
||||||
|
|
||||||
|
Vision tensors live in shards `model-00007-of-00015.safetensors` and
|
||||||
|
`model-00008-of-00015.safetensors` (2 of the 15 safetensors). Loading just
|
||||||
|
these two for vision-tower-only smoke tests is feasible.
|
||||||
|
|
||||||
|
## Vision tower architecture (`model.visual.*`)
|
||||||
|
|
||||||
|
From `config.json::vision_config`:
|
||||||
|
|
||||||
|
```
|
||||||
|
depth: 27 (transformer blocks)
|
||||||
|
hidden_size: 1152 (vision token dim)
|
||||||
|
num_heads: 16 (per-block self-attention)
|
||||||
|
intermediate_size: 4304 (MLP hidden)
|
||||||
|
patch_size: 16 (16×16 spatial patches)
|
||||||
|
temporal_patch_size: 2 (video frame pairing; irrelevant for stills)
|
||||||
|
spatial_merge_size: 2 (2×2 spatial merge in the merger → 4 patches/LM token)
|
||||||
|
num_position_embeddings: 2304 (learned pos embed slots — max patch sequence length)
|
||||||
|
in_channels: 3 (RGB)
|
||||||
|
hidden_act: gelu_pytorch_tanh (GELU with tanh approximation, not exact GELU)
|
||||||
|
out_hidden_size: 5120 (= LM hidden_size, merger output dim)
|
||||||
|
deepstack_visual_indexes: [] (no deep-stack visual indexes)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Module inventory (per-block and global)
|
||||||
|
|
||||||
|
Global:
|
||||||
|
- `model.visual.patch_embed.proj.{weight, bias}` — Conv2d (3 → 1152, kernel 16×16, stride 16). Turns image patches into tokens.
|
||||||
|
- `model.visual.pos_embed.weight` — Learned position embedding, shape `(2304, 1152)`.
|
||||||
|
- `model.visual.merger.{norm, linear_fc1, linear_fc2}` — The projector that merges 2×2 patches and projects to LM hidden_size (1152 → 5120). All weights have biases.
|
||||||
|
|
||||||
|
Per block (×27, named `model.visual.blocks.{0..26}`):
|
||||||
|
- `norm1.{weight, bias}` — **LayerNorm** before attention (with bias — not RmsNorm).
|
||||||
|
- `attn.qkv.{weight, bias}` — Fused QKV linear (1152 → 3·1152 = 3456).
|
||||||
|
- `attn.proj.{weight, bias}` — Attention output projection (1152 → 1152).
|
||||||
|
- `norm2.{weight, bias}` — LayerNorm before MLP.
|
||||||
|
- `mlp.linear_fc1.{weight, bias}` — MLP up-projection (1152 → 4304).
|
||||||
|
- `mlp.linear_fc2.{weight, bias}` — MLP down-projection (4304 → 1152).
|
||||||
|
|
||||||
|
Pattern matches a standard ViT block with **pre-norm** layout (norm → attn → residual, norm → MLP → residual). Activation between fc1/fc2 is GELU-tanh-approx per `hidden_act`. No attention masking inside the vision tower (all patches attend to each other).
|
||||||
|
|
||||||
|
### Forward signature (target)
|
||||||
|
|
||||||
|
```
|
||||||
|
VisionTower::forward(
|
||||||
|
patches: Tensor [N, in_channels, patch_size, patch_size], # CPU-preprocessed RGB float patches
|
||||||
|
grid_thw: Option<(usize, usize, usize)>, # (t, h, w) patch grid for position lookup
|
||||||
|
) -> Tensor [N / (spatial_merge_size²), out_hidden_size] # = (N/4, 5120) for static images
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: the merger consumes 4 spatially-adjacent patches and emits 1 LM token. So an image producing 64×64 = 4096 patches yields 1024 LM-side image tokens.
|
||||||
|
|
||||||
|
## Image preprocessor (`preprocessor_config.json`)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"size": { "longest_edge": 16777216, "shortest_edge": 65536 },
|
||||||
|
"patch_size": 16,
|
||||||
|
"temporal_patch_size": 2,
|
||||||
|
"merge_size": 2,
|
||||||
|
"image_mean": [0.5, 0.5, 0.5],
|
||||||
|
"image_std": [0.5, 0.5, 0.5],
|
||||||
|
"processor_class": "Qwen3VLProcessor",
|
||||||
|
"image_processor_type": "Qwen2VLImageProcessorFast"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Reading:
|
||||||
|
|
||||||
|
- `image_mean = image_std = 0.5` → normalisation is simply `(x/255 - 0.5) / 0.5 = 2*x/255 - 1`, mapping `[0,255]` → `[-1, 1]`. No imagenet-style mean/std.
|
||||||
|
- `size.{shortest_edge, longest_edge}` are **pixel counts**, not edge lengths. The `Qwen2VLImageProcessorFast` recipe picks a resolution within `[65,536 = 256², 16,777,216 = 4096²]` total pixels, snapping `h` and `w` to multiples of `patch_size × spatial_merge_size = 32` pixels.
|
||||||
|
- Stage A ships **fixed resolution**: pick a target pixel count (e.g. 448×448 = 200,704 px → 28×28 patches → 14×14 LM tokens after merger). Variable resolution deferred to issue [#14](https://git.lair.cafe/helexa/cortex/issues/14).
|
||||||
|
|
||||||
|
## Chat template (`chat_template.jinja`)
|
||||||
|
|
||||||
|
Image insertion (lines 8–18 of the template):
|
||||||
|
|
||||||
|
```jinja
|
||||||
|
{%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
|
||||||
|
...
|
||||||
|
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
|
||||||
|
```
|
||||||
|
|
||||||
|
Per image, the template emits **one `<|image_pad|>` token** flanked by `<|vision_start|>` and `<|vision_end|>` sentinels. The runtime must:
|
||||||
|
|
||||||
|
1. Render the template (preserving the single `<|image_pad|>` per image).
|
||||||
|
2. For each image, replace its single `<|image_pad|>` with N copies, where N is the number of LM tokens that image produces after the vision tower + merger (= `patches / spatial_merge_size²`).
|
||||||
|
3. Tokenize the expanded string → `input_ids`.
|
||||||
|
4. At forward time, locate positions where `input_ids == image_token_id` (248056) and splice in the vision tower's merger output.
|
||||||
|
|
||||||
|
Token IDs (top of `config.json`):
|
||||||
|
- `vision_start_token_id`: 248053
|
||||||
|
- `vision_end_token_id`: 248054
|
||||||
|
- `image_token_id`: 248056
|
||||||
|
- `video_token_id`: 248057 (out of scope)
|
||||||
|
- `bos_token_id`: 248044
|
||||||
|
- `eos_token_id`: 248044, 248046 (per `generation_config.json`)
|
||||||
|
|
||||||
|
System messages cannot contain images (template raises). Other template-side details:
|
||||||
|
- `add_vision_id` (jinja arg, default false): emits `'Picture N: '` prefixes when true.
|
||||||
|
- `preserve_thinking` (jinja arg, default false): keeps `<think>` blocks from prior assistant turns in the rendered prompt.
|
||||||
|
- `enable_thinking` (jinja arg, default true): emits `<think>\n` (or skips it) at the end of the generation prompt.
|
||||||
|
|
||||||
|
The existing chat-template renderer in `crates/neuron/src/harness/chat_template.rs` already passes `MessageContent::Parts` to the Jinja context as a `Value::Array`; the template's `is iterable` branch (line 6 of the template) handles them. **The path is structurally in place** — Stage B just needs to do the `<|image_pad|>` expansion + token-position-aware splice.
|
||||||
|
|
||||||
|
## LM-side considerations
|
||||||
|
|
||||||
|
The LM's RoPE config uses **multi-axis RoPE (MRoPE)**:
|
||||||
|
|
||||||
|
```
|
||||||
|
rope_parameters: {
|
||||||
|
mrope_interleaved: true,
|
||||||
|
mrope_section: [11, 11, 10], # text + height + width components
|
||||||
|
partial_rotary_factor: 0.25,
|
||||||
|
rope_theta: 10000000,
|
||||||
|
rope_type: "default"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
MRoPE encodes spatial position alongside text position so the LM attention layers can reason about image-token spatial structure. The LM's existing forward path *may or may not* already implement this — the qwen3_5 module's doc-comment notes "numerical correctness vs the reference Python is not yet validated." Verifying MRoPE behaviour in the language model is out of Stage A scope (vision tower only) but will be required in Stage B (LM splice) and is tracked under the numerical-validation issue [#15](https://git.lair.cafe/helexa/cortex/issues/15).
|
||||||
|
|
||||||
|
`max_position_embeddings = 262144` (256 K context), so context-length limits are not a constraint for vision.
|
||||||
|
|
||||||
|
## Iteration target decision
|
||||||
|
|
||||||
|
The vision tower has its own self-contained weight tree and is small (~333 tensors in 2 shards, hidden_size 1152 vs LM's 5120). For Stage A specifically (vision-tower-only smoke), we **don't need a smaller iteration model** — we can:
|
||||||
|
|
||||||
|
- Build the Rust `VisionTower` struct against the spec above.
|
||||||
|
- Run unit tests with random tensor weights matching the exact shapes → assert forward produces correct output shape with finite values.
|
||||||
|
- Optionally: a CUDA-integration test that loads just the 2 vision shards from beast's cache (or on a smaller GPU like quadbrat's Ampere) and runs encode on a real image. Doesn't require loading the 27B LM at all.
|
||||||
|
|
||||||
|
This sidesteps the "develop against a smaller VL model" question for Stage A. Stage B (LM splice → end-to-end chat with vision) is where iteration speed becomes pressing; revisit there. The default scope pick 2a (smaller iteration model) is therefore deferred to Stage B planning — issue [#13](https://git.lair.cafe/helexa/cortex/issues/13) covers deployment validation regardless.
|
||||||
|
|
||||||
|
## Concrete Stage A1+ inputs
|
||||||
|
|
||||||
|
- Add deps to `crates/neuron/Cargo.toml`:
|
||||||
|
- `image = "0.25"`
|
||||||
|
- `base64 = "0.22"`
|
||||||
|
- Stage A2 preprocessor target resolution (fixed): **448×448 → 28×28 patches → 14×14 = 196 image tokens per image**. This balances minimum-patch-count for cheap tests against the model's expected input range.
|
||||||
|
- Stage A3 module structure: one `VisionTower` struct holding `patch_embed: Conv2d`, `pos_embed: Embedding`, `blocks: Vec<VisionBlock>`, `merger: Merger`. `VisionBlock` carries `norm1`, `norm2`, `attn`, `mlp`. Hand-roll using candle primitives.
|
||||||
|
- Stage A4 weight loading: extend `Qwen3_5ForCausalLM::new()` to construct `Some(VisionTower::new(vb.pp("model.visual"), config))` when `vision_config` is present in the parsed config.
|
||||||
|
- Stage A5 worker job: `Job::EncodeImage { handle, patches: Vec<f32>, patch_shape: (usize, usize, usize, usize, usize), reply: oneshot<Result<Vec<f32>>> }`. Patch shape = `(N, C, T, H, W)` where T=1 for static images.
|
||||||
|
|
||||||
|
## What this doc does NOT settle (deferred to issues)
|
||||||
|
|
||||||
|
- Numerical correctness of `VisionTower` output vs Python transformers
|
||||||
|
→ issue [#15](https://git.lair.cafe/helexa/cortex/issues/15).
|
||||||
|
- Variable image resolution
|
||||||
|
→ issue [#14](https://git.lair.cafe/helexa/cortex/issues/14).
|
||||||
|
- TP-vision (multi-rank vision tower)
|
||||||
|
→ issue [#12](https://git.lair.cafe/helexa/cortex/issues/12).
|
||||||
|
- 27B production deployment
|
||||||
|
→ issue [#13](https://git.lair.cafe/helexa/cortex/issues/13).
|
||||||
Reference in New Issue
Block a user