From 24968e9233adc63362fab2db1f30d8ab083825d0 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Tue, 2 Jun 2026 15:33:00 +0300 Subject: [PATCH] =?UTF-8?q?feat(neuron):=20Stage=20B=20=E2=80=94=20end-to-?= =?UTF-8?q?end=20text+image=20chat=20for=20Qwen3.6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stage B of the vision plan (doc/vision-qwen3_6-spec.md). Wires the vision tower from Stage A through to a complete non-streaming chat completion: extract images from the request, preprocess, encode on the worker thread, splice embeddings into the LM input at `<|image_pad|>` positions, return coherent text response with `prompt_tokens` reflecting patch tokens. Closes the silent-drop class of failures from issue #3 — vision requests against Qwen3.6 now condition the model on the image instead of producing confident text-only hallucinations. Streaming for vision is Stage C. Deferred items tracked under #12 (TP-vision), #13 (27B production), #14 (dynamic resolution), #15 (numerical validation). What landed: - **B1 — `Qwen3_5Model::forward_with_vision`**: text-only `forward` unchanged; new method takes `(input_ids, offset, image_embeds, image_token_id)`, embeds tokens, locates `image_token_id` positions, splices via the new `splice_runs` helper. MRoPE applies text-positions to image tokens for Stage B (spatial MRoPE is the issue #15 numerical-validation follow-up). 2 unit tests for `splice_runs` covering contiguous + non-contiguous runs. - **B2 — `ModelArch::forward_with_vision` dispatch**: routes Qwen3_5Dense to the new method; other arches return an error. Defence-in-depth — the HTTP layer (B6) already rejects image content for non-vision models. - **B3 — `Job::ForwardLogitsWithImages`**: new worker variant carrying tokens + per-image `(pixels, c, h, w)` payloads. The dispatcher encodes each image (device-resident), concatenates the resulting embeddings, calls `arch.forward_with_vision`, and returns CPU logits. 
Image embeddings never copy back to CPU — the "tensors don't escape the worker" invariant from the per-device worker refactor still holds. Poisoned-worker drain path handles the new variant. - **B4 — Prompt builder**: - `request_has_images` detects image content cheaply. - `extract_images_from_request(request, profile)` walks `MessageContent::Parts`, decodes data URIs, runs `harness::preprocess::preprocess` per image, returns `Vec<ImageInput>` in request order. - `expand_image_pad_tokens(input_ids, image_token_id, patches_per_image)` walks the tokenized prompt and replaces each `<|image_pad|>` (id 248056 for Qwen3.6) with N copies matching the per-image patch count. 4 unit tests. - `VisionMeta::from_config_path` peeks `config.json` at load time for `image_token_id`, vision_config patch/merge sizes, and derives `lm_tokens_per_image` for the Stage B fixed resolution. - **B5 — `chat_completion` vision routing**: detects image content, validates the loaded model has vision, expands the prompt, and calls a new `run_inference_with_images_via_worker` helper that does single-shot prefill + standard decode loop (KV cache holds the post-splice hidden states from prefill, so decode steps don't re-splice). Stage B skips chunked prefill for vision — at 448×448 fixed resolution the budget stays well under the activation-memory threshold. Long-vision chunking is a Stage D follow-up. - **B6 — `InferenceError::VisionUnsupported`**: structured 400 with `code=vision_unsupported, model_id, suggestion` when an image request hits a non-vision model. Closes the agent0 failure mode where vision requests degraded silently. - **B7 — `ModelInfo.capabilities`**: per-model array (`["text"]` vs `["text", "vision"]`) in `/v1/models` and forwarded verbatim by cortex-gateway. Lets clients (litellm, agent0) gate image_url submission on the declared capability set. Optional in the wire format; defaults to empty for older clients.
CI gate: cargo fmt --check, cargo clippy --workspace --all-targets -- -D warnings, cargo test --workspace (all 28 test groups ok, 124 lib tests). New unit-test counts: +2 splice_runs, +4 expand_image_pad. Manual verification (after RPMs deploy on beast): curl http://hanzalova.internal:31313/v1/chat/completions \ -H 'Content-Type: application/json' \ -d "{\"model\":\"Qwen/Qwen3.6-27B\", \"messages\":[{\"role\":\"user\",\"content\":[ {\"type\":\"text\",\"text\":\"What's in this image?\"}, {\"type\":\"image_url\",\"image_url\":{\"url\":\"data:image/jpeg;base64,...\"}} ]}], \"max_tokens\":120}" | jq Expect prompt_tokens > 196 (text + 196 patch tokens) and a response that references actual image content. Co-Authored-By: Claude Opus 4.7 --- crates/cortex-core/src/harness.rs | 10 + crates/neuron/src/api.rs | 36 ++ crates/neuron/src/harness/arch/qwen3_5/mod.rs | 221 ++++++++ crates/neuron/src/harness/candle.rs | 492 +++++++++++++++++- .../src/harness/device_worker/dispatch.rs | 85 ++- .../neuron/src/harness/device_worker/jobs.rs | 38 ++ .../neuron/src/harness/device_worker/mod.rs | 41 ++ 7 files changed, 904 insertions(+), 19 deletions(-) diff --git a/crates/cortex-core/src/harness.rs b/crates/cortex-core/src/harness.rs index 1fcae56..a03b6eb 100644 --- a/crates/cortex-core/src/harness.rs +++ b/crates/cortex-core/src/harness.rs @@ -44,6 +44,16 @@ pub struct ModelInfo { pub status: String, pub devices: Vec, pub vram_used_mb: Option, + /// Modalities this loaded model supports. Today: `["text"]` for + /// text-only checkpoints, `["text", "vision"]` for vision-capable + /// ones (Stage B7 of the vision plan). Clients like litellm / + /// agent0 can gate `image_url` submission on the advertised set. + /// + /// Optional in the wire format so older clients that don't read + /// it stay compatible. Default-empty for absent/older data, which + /// callers can interpret as "text". 
+ #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub capabilities: Vec, } /// What an inference harness must do, from neuron's perspective. diff --git a/crates/neuron/src/api.rs b/crates/neuron/src/api.rs index 3106ea8..55027dd 100644 --- a/crates/neuron/src/api.rs +++ b/crates/neuron/src/api.rs @@ -250,6 +250,18 @@ async fn chat_completions( })), ) .into_response(), + Err(InferenceError::VisionUnsupported { model_id }) => ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": format!( + "model '{model_id}' does not support image input" + ), + "code": "vision_unsupported", + "model_id": model_id, + "suggestion": "load a vision-capable model or remove image_url content parts", + })), + ) + .into_response(), Err(InferenceError::Other(e)) => ( StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": format!("{e:#}")})), @@ -289,6 +301,18 @@ async fn chat_completions( })), ) .into_response(), + Err(InferenceError::VisionUnsupported { model_id }) => ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": format!( + "model '{model_id}' does not support image input" + ), + "code": "vision_unsupported", + "model_id": model_id, + "suggestion": "load a vision-capable model or remove image_url content parts", + })), + ) + .into_response(), Err(InferenceError::Other(e)) => ( StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": format!("{e:#}")})), @@ -452,6 +476,18 @@ fn inference_error_response(err: InferenceError) -> axum::response::Response { })), ) .into_response(), + InferenceError::VisionUnsupported { model_id } => ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": format!( + "model '{model_id}' does not support image input" + ), + "code": "vision_unsupported", + "model_id": model_id, + "suggestion": "load a vision-capable model or remove image_url content parts", + })), + ) + .into_response(), InferenceError::Other(e) => ( StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": format!("{e:#}")})), diff --git 
a/crates/neuron/src/harness/arch/qwen3_5/mod.rs b/crates/neuron/src/harness/arch/qwen3_5/mod.rs index 1be0889..45a4c0d 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/mod.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/mod.rs @@ -221,6 +221,76 @@ fn default_partial_rotary_factor() -> f32 { 1.0 } +/// Splice rows from `img` into `h` at `positions`. Stage B helper. +/// +/// `h`: `(1, L, hidden)` — the LM's input embedding tensor after +/// `embed_tokens.forward`. +/// `img`: `(N_img, hidden)` — image embeddings, one row per +/// `<|image_pad|>` token in the prompt. Must already be in `h.dtype()`. +/// `positions`: indices into the `L` axis where image rows go; +/// `positions.len() == N_img`. +/// +/// Approach: group `positions` into contiguous runs (because the chat +/// template emits `<|vision_start|><|image_pad|>×N<|vision_end|>` — +/// the pad tokens for each image land in one contiguous span), then +/// `slice_assign` per run. For typical Qwen3.6 requests this is one +/// or two runs per image; `slice_assign` does one tensor copy per +/// run, which is cheap relative to the decoder forward pass. +fn splice_runs(h: &Tensor, img: &Tensor, positions: &[u32]) -> candle_core::Result { + debug_assert!( + !positions.is_empty(), + "splice_runs precondition: non-empty positions" + ); + let hidden = h.dim(2)?; + let mut out = h.clone(); + let mut img_offset = 0_usize; + let mut run_start = positions[0] as usize; + let mut run_end_exclusive = run_start + 1; + for &p in &positions[1..] 
{ + let p = p as usize; + if p == run_end_exclusive { + run_end_exclusive = p + 1; + } else { + apply_run( + &mut out, + img, + &mut img_offset, + run_start, + run_end_exclusive, + hidden, + )?; + run_start = p; + run_end_exclusive = p + 1; + } + } + apply_run( + &mut out, + img, + &mut img_offset, + run_start, + run_end_exclusive, + hidden, + )?; + Ok(out) +} + +fn apply_run( + out: &mut Tensor, + img: &Tensor, + img_offset: &mut usize, + run_start: usize, + run_end_exclusive: usize, + hidden: usize, +) -> candle_core::Result<()> { + let run_len = run_end_exclusive - run_start; + let slice = img + .narrow(0, *img_offset, run_len)? + .reshape((1, run_len, hidden))?; + *out = out.slice_assign(&[0..1, run_start..run_end_exclusive, 0..hidden], &slice)?; + *img_offset += run_len; + Ok(()) +} + /// Qwen3-Next base transformer (embedding + decoder stack + final /// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can /// also build on it later — for now only `Qwen3_5ForCausalLM` is the @@ -304,8 +374,95 @@ impl Qwen3_5Model { } pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result { + self.forward_inner(input, offset, None, None) + } + + /// Forward with image-embedding splice. Stage B of the vision plan. + /// + /// `input_ids`: `(1, L)` token ids — same shape the text-only + /// `forward` accepts (single-batch; multi-batch vision is not in + /// scope today). + /// `image_embeds`: `(N_image_tokens, hidden_size)` — concatenation + /// of every image's post-merger embedding (`VisionTower::forward` + /// output), in the same order images appear in the input. The + /// caller has already done the per-image patch-count expansion of + /// `<|image_pad|>` tokens in `input_ids`, so `N_image_tokens` + /// equals the number of `image_token_id` positions in `input_ids`. + /// `image_token_id`: the sentinel token (e.g. 248056 for Qwen3.6). 
+ /// + /// The splice replaces the LM's text-side embedding at each + /// `image_token_id` position with the corresponding row from + /// `image_embeds`. After the splice the decoder runs unchanged. + /// + /// **MRoPE gap.** Qwen3.6's `rope_parameters` declares MRoPE + /// (interleaved text/height/width axes); Stage B applies plain + /// text-position RoPE to image tokens. The model still attends + /// to image content but loses spatial structure that MRoPE-aware + /// position encoding would preserve. Tracked under issue #15 + /// (numerical validation) — quality benchmark from Stage D should + /// surface the impact, and the fix lives in `rope::RotaryEmbedding`. + pub fn forward_with_vision( + &mut self, + input_ids: &Tensor, + offset: usize, + image_embeds: &Tensor, + image_token_id: u32, + ) -> candle_core::Result { + self.forward_inner(input_ids, offset, Some(image_embeds), Some(image_token_id)) + } + + fn forward_inner( + &mut self, + input: &Tensor, + offset: usize, + image_embeds: Option<&Tensor>, + image_token_id: Option, + ) -> candle_core::Result { let (b, l) = input.dims2()?; let mut h = self.embed_tokens.forward(input)?; + // Splice image embeddings at `image_token_id` positions. The + // caller pre-expanded the prompt so every patch token in the + // image_embeds tensor has a matching position in `input`. We + // index_put the rows in place. + if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) { + // Locate image-token positions in input_ids. Operate on + // CPU since the input ids are tiny (max ~10k entries + // including the patch expansion) and the comparison is + // not in the per-step hot path. 
+ let ids: Vec = input.flatten_all()?.to_vec1()?; + let mut positions: Vec = Vec::with_capacity(img.dim(0)?); + for (idx, id) in ids.iter().enumerate() { + if *id == tok_id { + positions.push(idx as u32); + } + } + let n_img_tokens = img.dim(0)?; + if positions.len() != n_img_tokens { + candle_core::bail!( + "forward_with_vision: prompt has {} image-token positions but \ + image_embeds carries {} tokens — call build_prompt_for_request to \ + ensure the per-image patch-count expansion has been applied", + positions.len(), + n_img_tokens, + ); + } + if !positions.is_empty() { + // Cast image_embeds to the LM's dtype so the splice + // produces a uniform tensor for the decoder stack. + let img = img.to_dtype(self.dtype)?; + // index_select would return the rows; we want to put. + // candle's slice_assign with explicit positions ranges + // doesn't exist; use scatter via index_select + an + // accumulator: build a `(B, L, hidden)` zero tensor, + // scatter the image rows in, then add to a masked + // version of `h`. Simpler approach: walk positions + // and use `slice_assign` for contiguous runs. Since + // image_pad runs are contiguous (template emits + // `<|vision_start|><|image_pad|>×N<|vision_end|>`), + // we group positions and assign per run. + h = splice_runs(&h, &img, &positions)?; + } + } // Causal mask only needed for L > 1 prefill; full-attention // layers consume it via broadcast_add. Linear-attention layers // ignore the mask. @@ -406,6 +563,24 @@ impl Qwen3_5ForCausalLM { hidden.i((.., l - 1.., ..))?.apply(&self.lm_head) } + /// Stage B: forward with image-embedding splice. Mirrors `forward` + /// but routes through `Qwen3_5Model::forward_with_vision` so the + /// LM's input embeddings get the image patches spliced in at + /// `image_token_id` positions before the decoder stack runs. 
+ pub fn forward_with_vision( + &mut self, + input: &Tensor, + offset: usize, + image_embeds: &Tensor, + image_token_id: u32, + ) -> candle_core::Result { + let (_, l) = input.dims2()?; + let hidden = self + .base + .forward_with_vision(input, offset, image_embeds, image_token_id)?; + hidden.i((.., l - 1.., ..))?.apply(&self.lm_head) + } + pub fn clear_kv_cache(&mut self) { self.base.clear_kv_cache(); } @@ -463,4 +638,50 @@ mod tests { assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0); assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6); } + + /// `splice_runs` replaces (1, L, H) embedding rows at the given + /// positions with rows from a (N_img, H) image-embedding tensor, + /// in the order positions are supplied. + #[test] + fn splice_runs_replaces_at_contiguous_positions() { + use candle_core::{DType, Device}; + + let dev = Device::Cpu; + // (1, L=5, H=2) text embeddings — encoded as floats so the + // assertion can spot the change without dtype conversion. + let h_vals: Vec = vec![ + 10., 11., // pos 0 + 20., 21., // pos 1 + 30., 31., // pos 2 + 40., 41., // pos 3 + 50., 51., // pos 4 + ]; + let h = Tensor::from_vec(h_vals, (1, 5, 2), &dev).unwrap(); + + // Two image embeddings to splice at positions 1 and 2 (a + // contiguous run — single image emitting two patch tokens). + let img_vals: Vec = vec![-1., -2., -3., -4.]; + let img = Tensor::from_vec(img_vals, (2, 2), &dev).unwrap(); + + let out = splice_runs(&h, &img, &[1, 2]).unwrap(); + let flat: Vec = out.flatten_all().unwrap().to_vec1().unwrap(); + assert_eq!(flat, vec![10., 11., -1., -2., -3., -4., 40., 41., 50., 51.]); + let _ = DType::F32; + } + + /// Non-contiguous positions: two images at positions [1] and [3] + /// each contributing one patch. `splice_runs` should iterate + /// runs and place the corresponding image rows. 
+ #[test] + fn splice_runs_handles_non_contiguous_runs() { + use candle_core::Device; + let dev = Device::Cpu; + let h_vals: Vec = vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]; + let h = Tensor::from_vec(h_vals, (1, 5, 2), &dev).unwrap(); + let img_vals: Vec = vec![-1., -2., -3., -4.]; + let img = Tensor::from_vec(img_vals, (2, 2), &dev).unwrap(); + let out = splice_runs(&h, &img, &[1, 3]).unwrap(); + let flat: Vec = out.flatten_all().unwrap().to_vec1().unwrap(); + assert_eq!(flat, vec![1., 1., -1., -2., 3., 3., -3., -4., 5., 5.]); + } } diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index c1bd0ef..324e860 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -105,6 +105,22 @@ impl LoadedHandle { LoadedHandle::Tp(m) => m.poisoned.load(Ordering::Acquire), } } + + /// Modalities the loaded model supports. Stage B7. TP models are + /// always text-only today — TP-vision is tracked under issue #12. + pub fn capabilities(&self) -> Vec { + let mut caps = vec!["text".to_string()]; + match self { + LoadedHandle::Single(m) => { + if m.has_vision { + caps.push("vision".to_string()); + } + } + #[cfg(feature = "cuda")] + LoadedHandle::Tp(_) => {} + } + caps + } } /// A loaded model with its tokenizer, device placement, and architecture- @@ -182,6 +198,25 @@ pub struct LoadedModel { /// is used. The `NEURON_USE_CHAT_TEMPLATE=false` env var /// forces the fallback path even when `Some`. pub chat_template: Option, + /// Vision capability flag derived at load time. `true` iff the + /// loaded `ModelArch` exposes a vision tower (Stage A4 wires this + /// from `Qwen3_5ForCausalLM::has_vision`). Used by the chat + /// completion handler to reject image content on non-vision + /// models with a structured 400 (Stage B6) and by `/v1/models` + /// to advertise `capabilities: ["text", "vision"]` (Stage B7). + pub has_vision: bool, + /// `<|image_pad|>` token id from `config.json::image_token_id`. 
+ /// The Stage B prompt-builder uses this to compute expansion + /// targets and the worker forward uses it to locate splice + /// positions in the LM input embeddings. + pub image_token_id: Option, + /// LM-side tokens this model's vision tower emits per image at + /// the Stage B fixed resolution (448×448 → 196 for Qwen3.6). + /// `None` for text-only models. Set at load time so the + /// hot path doesn't recompute it per request. Stage B fixed + /// resolution → constant; dynamic resolution per #14 makes it + /// per-image. + pub lm_tokens_per_image: Option, } impl LoadedModel { @@ -333,6 +368,35 @@ impl ModelArch { } } + /// Forward step that splices vision-tower output at + /// `<|image_pad|>` token positions. Stage B2. + /// + /// Only `Qwen3_5Dense` supports this — other architectures error + /// because they don't have a vision tower. The HTTP layer is + /// expected to have rejected image content for non-vision models + /// already (Stage B6); this is a defence-in-depth error path. + /// + /// Returns rank-1 `[vocab_size]` logits, same shape contract as + /// `forward`. + pub fn forward_with_vision( + &mut self, + input: &Tensor, + offset: usize, + image_embeds: &Tensor, + image_token_id: u32, + ) -> Result { + let raw = match self { + ModelArch::Qwen3_5Dense(m) => { + m.forward_with_vision(input, offset, image_embeds, image_token_id)? + } + other => anyhow::bail!( + "forward_with_vision: architecture {} has no vision tower", + std::any::type_name_of_val(other) + ), + }; + squeeze_to_vocab(&raw) + } + /// Encode a preprocessed image into LM-side token embeddings via /// the loaded vision tower. Stage A5. 
/// @@ -1548,9 +1612,54 @@ impl CandleHarness { .tokenizer .encode(prompt.as_str(), true) .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?; - let prompt_tokens: Vec = encoding.get_ids().to_vec(); - let prompt_len = prompt_tokens.len(); + let mut prompt_tokens: Vec = encoding.get_ids().to_vec(); + // Stage B: when the request carries images, preprocess + // them, expand each `<|image_pad|>` sentinel to N copies + // matching the per-image patch count, and route to the + // vision-aware worker path. Non-image requests skip all + // of this and follow the existing text-only flow. + let vision_route = if request_has_images(&request) { + // Stage B6: surface a structured `vision_unsupported` + // rejection when the request asks for vision against a + // text-only model. Cheap and stops the issue-#3 silent- + // drop pattern. + if !loaded.has_vision { + return Err(InferenceError::VisionUnsupported { + model_id: request.model.clone(), + }); + } + let image_token_id = loaded + .image_token_id + .ok_or_else(|| InferenceError::VisionUnsupported { + model_id: request.model.clone(), + })?; + let patches_per_image = loaded + .lm_tokens_per_image + .ok_or_else(|| InferenceError::VisionUnsupported { + model_id: request.model.clone(), + })?; + let profile = super::preprocess::PreprocessProfile::qwen3_6(); + let images = extract_images_from_request(&request, &profile).map_err(|e| { + InferenceError::Other(anyhow::anyhow!("extract_images: {e}")) + })?; + if images.is_empty() { + // request_has_images said true but extract returned + // empty — defensive bail rather than silently dropping. 
+ return Err(InferenceError::Other(anyhow::anyhow!( + "request has image content but extractor produced zero images" + ))); + } + let per_image_counts: Vec = vec![patches_per_image; images.len()]; + prompt_tokens = + expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts) + .map_err(InferenceError::Other)?; + Some((images, image_token_id)) + } else { + None + }; + + let prompt_len = prompt_tokens.len(); let temperature = request.temperature.unwrap_or(0.7); let top_p = request.top_p; let max_new = request.max_tokens.unwrap_or(8192) as usize; @@ -1570,6 +1679,7 @@ impl CandleHarness { ?eos_id, vram_free_mb, vram_total_mb, + vision = vision_route.is_some(), "chat_completion: starting" ); @@ -1588,18 +1698,37 @@ impl CandleHarness { // Worker path (CUDA). #[cfg(feature = "cuda")] { - match run_inference_via_worker( - worker, - handle, - &prompt_tokens, - max_new, - temperature, - top_p, - seed, - eos_id, - ) - .await - { + let result = match &vision_route { + Some((images, image_token_id)) => { + run_inference_with_images_via_worker( + worker, + handle, + &prompt_tokens, + images.clone(), + *image_token_id, + max_new, + temperature, + top_p, + seed, + eos_id, + ) + .await + } + None => { + run_inference_via_worker( + worker, + handle, + &prompt_tokens, + max_new, + temperature, + top_p, + seed, + eos_id, + ) + .await + } + }; + match result { Ok(v) => v, Err(e) => { let chain = format!("{e:#}"); @@ -2120,6 +2249,7 @@ impl Harness for CandleHarness { }, devices: h.devices(), vram_used_mb: None, + capabilities: h.capabilities(), }) .collect()) } @@ -2189,7 +2319,7 @@ impl Harness for CandleHarness { _ => None, }; - let (tokenizer_path, arch_local, arch_handle) = if let Some(w) = &worker { + let (tokenizer_path, arch_local, arch_handle, vision_meta) = if let Some(w) = &worker { // CUDA path: resolve, then load in the worker. 
if spec.quant.is_some() { let (gguf_path, tokenizer_path) = self.resolve_files(spec, &source_id).await?; @@ -2197,15 +2327,19 @@ impl Harness for CandleHarness { .load_gguf(gguf_path, spec.model_id.clone()) .await .map_err(|e| anyhow::anyhow!("worker load_gguf: {e}"))?; - (tokenizer_path, None, Some(handle)) + // GGUF Qwen3.6 releases don't ship the vision tower + // (Qwen-VL weights are in the dense safetensors only), + // so a GGUF load is text-only by construction. + (tokenizer_path, None, Some(handle), VisionMeta::default()) } else { let (config_path, tokenizer_path, safetensors_paths) = self.resolve_dense_files(spec, &source_id).await?; + let meta = VisionMeta::from_config_path(&config_path); let handle = w .load_dense(config_path, safetensors_paths, spec.model_id.clone()) .await .map_err(|e| anyhow::anyhow!("worker load_dense: {e}"))?; - (tokenizer_path, None, Some(handle)) + (tokenizer_path, None, Some(handle), meta) } } else { // CPU path: legacy spawn_blocking + Arc>. @@ -2214,7 +2348,16 @@ impl Harness for CandleHarness { } else { self.load_arch_dense(spec, &source_id, &device).await? }; - (tokenizer_path, Some(Arc::new(Mutex::new(arch))), None) + // CPU Qwen3.6 isn't a supported deployment target — the + // 27B doesn't fit any reasonable CPU memory budget — so + // we don't attempt to reach into the arch for vision + // metadata. Stays text-only. 
+ ( + tokenizer_path, + Some(Arc::new(Mutex::new(arch))), + None, + VisionMeta::default(), + ) }; let tokenizer = Tokenizer::from_file(&tokenizer_path) @@ -2278,6 +2421,9 @@ impl Harness for CandleHarness { reasoning_tokens, tool_call_tokens, chat_template, + has_vision: vision_meta.has_vision, + image_token_id: vision_meta.image_token_id, + lm_tokens_per_image: vision_meta.lm_tokens_per_image, }); let mut models = self.models.write().await; @@ -3434,6 +3580,16 @@ pub enum InferenceError { "insufficient free VRAM for prefill: {free_mb} MiB free, need at least {required_mb} MiB" )] InsufficientVram { free_mb: u64, required_mb: u64 }, + /// Request carried `image_url` content but the loaded model has + /// no vision tower. Stage B6 — replaces the silent-drop pattern + /// from issue #3 with an explicit 400 + `vision_unsupported` + /// error body that clients (litellm, agent0, …) can act on. + #[error( + "model '{model_id}' does not support image input; \ + load a vision-capable model (e.g. Qwen/Qwen3.6-27B) or \ + remove the image_url content parts from the request" + )] + VisionUnsupported { model_id: String }, #[error(transparent)] Other(#[from] anyhow::Error), } @@ -3498,6 +3654,169 @@ fn build_prompt_for_request( } } +/// Vision metadata derived at model-load time. Stashed on +/// `LoadedModel` so the chat-completion hot path doesn't have to +/// re-parse `config.json` or reach across the worker thread to peek +/// at the loaded `ModelArch`. +#[derive(Debug, Default, Clone, Copy)] +struct VisionMeta { + has_vision: bool, + image_token_id: Option, + /// LM-side tokens this model's vision tower emits per image at + /// the Stage B fixed `PreprocessProfile::qwen3_6()` resolution + /// (448×448). Equal to `(H/patch_size/spatial_merge_size)²`. + lm_tokens_per_image: Option, +} + +impl VisionMeta { + /// Peek at `config.json` for vision-related fields. 
Returns the + /// default (no-vision) struct on any read/parse error — vision is + /// best-effort metadata; load can still succeed for text usage. + fn from_config_path(config_path: &std::path::Path) -> Self { + let Ok(text) = std::fs::read_to_string(config_path) else { + return Self::default(); + }; + let Ok(v) = serde_json::from_str::(&text) else { + return Self::default(); + }; + let Some(vision_config) = v.get("vision_config") else { + return Self::default(); + }; + let patch_size = vision_config + .get("patch_size") + .and_then(|x| x.as_u64()) + .unwrap_or(16) as usize; + let spatial_merge_size = vision_config + .get("spatial_merge_size") + .and_then(|x| x.as_u64()) + .unwrap_or(2) as usize; + let image_token_id = v + .get("image_token_id") + .and_then(|x| x.as_u64()) + .map(|n| n as u32); + // Compute LM tokens per image at the Stage B fixed resolution + // (PreprocessProfile::qwen3_6() → 448×448). One LM token per + // spatial-merge group of patches. + let target_h = super::preprocess::PreprocessProfile::qwen3_6().target_height as usize; + let target_w = super::preprocess::PreprocessProfile::qwen3_6().target_width as usize; + let lm_tokens_per_image = if patch_size > 0 && spatial_merge_size > 0 { + let gh = target_h / patch_size / spatial_merge_size; + let gw = target_w / patch_size / spatial_merge_size; + Some(gh * gw) + } else { + None + }; + Self { + has_vision: true, + image_token_id, + lm_tokens_per_image, + } + } +} + +/// True iff any message in the request carries an `image_url` +/// content part. The Stage B routing decision in `chat_completion` +/// dispatches to the vision-aware worker job when this is true. 
+fn request_has_images(request: &ChatCompletionRequest) -> bool { + request.messages.iter().any(|m| { + matches!(&m.content, MessageContent::Parts(parts) + if parts.iter().any(|p| + p.get("type").and_then(|v| v.as_str()) == Some("image_url"))) + }) +} + +/// Extract `image_url` content parts from a chat request and turn +/// each one into a preprocessed `ImageInput` ready for the device +/// worker. Stage B4. +/// +/// Walks `request.messages`, looking for `MessageContent::Parts` and +/// pulling out entries whose `type == "image_url"`. Each is run +/// through `harness::preprocess::decode_data_uri` + `preprocess` with +/// the supplied `profile` (Stage B always uses +/// `PreprocessProfile::qwen3_6()` — fixed 448×448 — so every image +/// produces the same patch count; dynamic resolution per issue #14 +/// would parameterise this). +/// +/// Returns images in the order they appear in the request, which +/// matches the order the chat template emits `<|image_pad|>` tokens. +fn extract_images_from_request( + request: &ChatCompletionRequest, + profile: &super::preprocess::PreprocessProfile, +) -> anyhow::Result> { + let mut out = Vec::new(); + for msg in &request.messages { + if let MessageContent::Parts(parts) = &msg.content { + for part in parts { + let kind = part.get("type").and_then(|v| v.as_str()).unwrap_or(""); + if kind != "image_url" { + continue; + } + let url = part + .get("image_url") + .and_then(|v| v.get("url")) + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("image_url part missing url field"))?; + let pixels = super::preprocess::preprocess_data_uri(url, profile) + .with_context(|| format!("preprocess image #{}", out.len()))?; + out.push(super::device_worker::jobs::ImageInput { + pixels, + c: 3, + h: profile.target_height as usize, + w: profile.target_width as usize, + }); + } + } + } + Ok(out) +} + +/// Expand each occurrence of `image_token_id` in `input_ids` into +/// `patches_per_image[i]` copies (one expansion per image, in order). 
+/// Stage B4 helper. +/// +/// The chat template emits a single `<|image_pad|>` per image; this +/// is what fits Qwen3-VL's template-then-runtime-expansion convention. +/// The runtime (us) is responsible for replacing each one with N +/// copies based on the corresponding image's patch count. +/// +/// For Stage B fixed resolution every entry of `patches_per_image` +/// is the same constant (196 at 448×448). For dynamic resolution +/// (issue #14) each image gets its own value. +/// +/// Errors if the number of `image_token_id` occurrences in `input_ids` +/// doesn't equal `patches_per_image.len()` — a mismatch means the +/// template emitted the wrong number of pad tokens (operator-visible +/// template bug, not a runtime error). +fn expand_image_pad_tokens( + input_ids: &[u32], + image_token_id: u32, + patches_per_image: &[usize], +) -> anyhow::Result> { + let occurrences = input_ids.iter().filter(|&&t| t == image_token_id).count(); + if occurrences != patches_per_image.len() { + anyhow::bail!( + "expand_image_pad_tokens: prompt has {occurrences} image_token_id occurrences but \ + {} images were preprocessed — chat template emitted the wrong number of pad tokens", + patches_per_image.len() + ); + } + let total_extra: usize = patches_per_image.iter().map(|n| n.saturating_sub(1)).sum(); + let mut out = Vec::with_capacity(input_ids.len() + total_extra); + let mut img_idx = 0; + for &t in input_ids { + if t == image_token_id { + let n = patches_per_image[img_idx]; + for _ in 0..n { + out.push(image_token_id); + } + img_idx += 1; + } else { + out.push(t); + } + } + Ok(out) +} + /// Apply the Qwen3 chat template: /// /// ```text @@ -3544,6 +3863,103 @@ fn format_qwen3_prompt(messages: &[ChatMessage]) -> String { /// would only add channel overhead with no diagnostic benefit. #[cfg(feature = "cuda")] #[allow(clippy::too_many_arguments)] +/// Vision-aware analogue of `run_inference_via_worker`. Stage B5. 
+///
+/// Single-shot prefill carrying the pre-expanded prompt + the image
+/// payloads. The worker encodes each image through the vision tower,
+/// splices the resulting embeddings at `image_token_id` positions,
+/// and returns the last-position logits. Decode steps thereafter
+/// follow the existing text-only `forward_logits` path — the KV
+/// cache holds the image-conditioned hidden states from prefill, so
+/// no further splicing is needed.
+///
+/// Stage B skips chunked prefill for vision (the fixed-resolution
+/// budget — 196 image tokens at 448×448 + typical text — stays well
+/// under the activation-memory threshold). Long-prompt-with-images
+/// chunking is a Stage D follow-up.
+async fn run_inference_with_images_via_worker(
+    worker: &super::device_worker::DeviceWorkerHandle,
+    handle: super::device_worker::ArchHandle,
+    prompt_tokens: &[u32],
+    images: Vec<super::device_worker::jobs::ImageInput>,
+    image_token_id: u32,
+    max_new: usize,
+    temperature: f64,
+    top_p: Option<f64>,
+    seed: u64,
+    eos_id: Option<u32>,
+) -> Result<(Vec<u32>, String)> {
+    let mut logits_processor = {
+        let sampling = if temperature <= 0.0 {
+            Sampling::ArgMax
+        } else {
+            match top_p {
+                Some(p) => Sampling::TopP { p, temperature },
+                None => Sampling::All { temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(seed, sampling)
+    };
+
+    let mut generated: Vec<u32> = Vec::new();
+    let prompt_len = prompt_tokens.len();
+
+    worker
+        .clear_kv_cache(handle)
+        .await
+        .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
+
+    // Single-shot prefill with image splicing.
+    let logits_vec = worker
+        .forward_logits_with_images(handle, prompt_tokens.to_vec(), 0, images, image_token_id)
+        .await
+        .map_err(|e| anyhow::anyhow!("forward_logits_with_images: {e}"))?;
+    let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
+    let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
+        Ok(t) => t,
+        Err(e) => {
+            let health = logits_health_slice(&logits_vec);
+            tracing::warn!(
+                ?health,
+                "chat_completion (worker, vision): prefill sample failed; logits unhealthy"
+            );
+            return Err(e);
+        }
+    };
+
+    if Some(next_token) == eos_id {
+        return Ok((generated, "stop".into()));
+    }
+    generated.push(next_token);
+
+    for index in 0..max_new.saturating_sub(1) {
+        let logits_vec = worker
+            .forward_logits(handle, vec![next_token], prompt_len + index)
+            .await
+            .map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
+        let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
+        next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
+            Ok(t) => t,
+            Err(e) => {
+                let health = logits_health_slice(&logits_vec);
+                tracing::warn!(
+                    step = index,
+                    ?health,
+                    "chat_completion (worker, vision): decode sample failed; logits unhealthy"
+                );
+                return Err(e);
+            }
+        };
+        if Some(next_token) == eos_id {
+            return Ok((generated, "stop".into()));
+        }
+        generated.push(next_token);
+    }
+    Ok((generated, "length".into()))
+}
+
+#[cfg(feature = "cuda")]
+#[allow(clippy::too_many_arguments)]
 async fn run_inference_via_worker(
     worker: &super::device_worker::DeviceWorkerHandle,
     handle: super::device_worker::ArchHandle,
@@ -4243,4 +4659,44 @@ mod tests {
             .expect("synth huggingface source should build");
         assert_eq!(harness.default_source_scheme(), "huggingface");
     }
+
+    #[test]
+    fn expand_image_pad_replaces_single_token_with_n_copies() {
+        // Mimics the chat template's output: each image emits
+        // [vision_start, image_pad, vision_end]. After expansion
+        // with 3 patches/image we want
+        // [vision_start, pad×3, vision_end].
+        let pad = 248056_u32;
+        let vstart = 248053_u32;
+        let vend = 248054_u32;
+        let input = vec![1, vstart, pad, vend, 2];
+        let out = expand_image_pad_tokens(&input, pad, &[3]).unwrap();
+        assert_eq!(out, vec![1, vstart, pad, pad, pad, vend, 2]);
+    }
+
+    #[test]
+    fn expand_image_pad_handles_multiple_images() {
+        let pad = 248056_u32;
+        // Two images in one prompt; first gets 2 patches, second 3.
+        let input = vec![pad, 99, pad];
+        let out = expand_image_pad_tokens(&input, pad, &[2, 3]).unwrap();
+        assert_eq!(out, vec![pad, pad, 99, pad, pad, pad]);
+    }
+
+    #[test]
+    fn expand_image_pad_errors_on_count_mismatch() {
+        let pad = 248056_u32;
+        // Prompt has 2 pad tokens but caller supplied 3 images.
+        let input = vec![pad, 99, pad];
+        let err = expand_image_pad_tokens(&input, pad, &[2, 3, 4]).unwrap_err();
+        assert!(format!("{err:#}").contains("emitted the wrong number"));
+    }
+
+    #[test]
+    fn expand_image_pad_passes_through_when_no_images() {
+        let pad = 248056_u32;
+        let input = vec![1, 2, 3];
+        let out = expand_image_pad_tokens(&input, pad, &[]).unwrap();
+        assert_eq!(out, input);
+    }
 }
diff --git a/crates/neuron/src/harness/device_worker/dispatch.rs b/crates/neuron/src/harness/device_worker/dispatch.rs
index 0c0c500..8113dfd 100644
--- a/crates/neuron/src/harness/device_worker/dispatch.rs
+++ b/crates/neuron/src/harness/device_worker/dispatch.rs
@@ -16,10 +16,11 @@ use crate::harness::candle::ModelArch;
 
 #[cfg(feature = "cuda")]
 use crate::harness::device_worker::jobs::TpHandle;
-use crate::harness::device_worker::jobs::{ArchHandle, Job};
+use crate::harness::device_worker::jobs::{ArchHandle, ImageInput, Job};
 #[cfg(feature = "cuda")]
 use crate::harness::tp::TpLeaderModel;
 use crate::harness::tp::nccl_state::NcclState;
+use anyhow::Context as _;
 use std::collections::HashMap;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, Ordering};
@@ -172,3 +173,21 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
+            Job::ForwardLogitsWithImages {
+                handle,
+                tokens,
+                offset,
+                images,
+                image_token_id,
+                reply,
+            } => {
+                let result = forward_logits_with_images(
+                    &mut state,
+                    handle,
+                    &tokens,
+                    offset,
+                    images,
+                    image_token_id,
+                );
+                let _ = reply.send(result);
+            }
             Job::NcclInit {
                 cfg,
                 comm_id_hex,
@@ -751,6 +770,67 @@ fn forward_logits(
     Ok(values)
 }
 
+/// Run the LM forward with vision-tower image splicing. Stage B3.
+///
+/// Encodes each image through the vision tower (`VisionTower::forward`,
+/// dispatched via `ModelArch::encode_image`), concatenates the
+/// resulting embeddings into a single `(N_total, hidden)` tensor, and
+/// passes it to `ModelArch::forward_with_vision` along with the
+/// prompt-expanded `tokens`. Image embeddings never leave the device.
+///
+/// Returns CPU `[vocab]` logits — same shape contract as
+/// `ForwardLogits` so the async sampler doesn't have to branch on the
+/// presence of images.
+fn forward_logits_with_images(
+    state: &mut DeviceWorkerState,
+    handle: ArchHandle,
+    tokens: &[u32],
+    offset: usize,
+    images: Vec<ImageInput>,
+    image_token_id: u32,
+) -> anyhow::Result<Vec<f32>> {
+    use candle_core::{DType, Tensor};
+
+    if images.is_empty() {
+        anyhow::bail!("ForwardLogitsWithImages dispatched with zero images");
+    }
+
+    let arch = state.models.get_mut(&handle).ok_or_else(|| {
+        anyhow::anyhow!("ForwardLogitsWithImages: no model for handle {}", handle.0)
+    })?;
+
+    // Encode every image on the worker's device, collecting per-image
+    // post-merger embeddings as device-resident tensors.
+    let mut per_image: Vec<Tensor> = Vec::with_capacity(images.len());
+    for (idx, img) in images.into_iter().enumerate() {
+        anyhow::ensure!(
+            img.pixels.len() == img.c * img.h * img.w,
+            "ForwardLogitsWithImages: image[{idx}] pixels length {} does not match shape ({}, {}, {})",
+            img.pixels.len(),
+            img.c,
+            img.h,
+            img.w,
+        );
+        let image = Tensor::from_vec(img.pixels, (img.c, img.h, img.w), &state.device)?;
+        let embed = arch
+            .encode_image(&image)
+            .with_context(|| format!("encode image[{idx}]"))?;
+        per_image.push(embed);
+    }
+    // Concatenate per-image embeddings along the patch axis →
+    // (sum_of_patches, hidden). `Tensor::cat` keeps the result
+    // device-resident.
+    let image_embeds = Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)?;
+
+    let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
+    let logits = arch.forward_with_vision(&input, offset, &image_embeds, image_token_id)?;
+    let values = logits
+        .to_dtype(DType::F32)?
+        .flatten_all()?
+        .to_vec1::<f32>()?;
+    Ok(values)
+}
+
 /// Run the vision tower on a single preprocessed image. Stage A5.
 ///
 /// `pixels` is a row-major `(c, h, w)` f32 image that the async-side
@@ -830,6 +910,9 @@ fn drain_poisoned(job: Job, device_index: u32) {
         Job::EncodeImage { reply, .. } => {
             let _ = reply.send(Err(err()));
         }
+        Job::ForwardLogitsWithImages { reply, .. } => {
+            let _ = reply.send(Err(err()));
+        }
         Job::NcclInit { reply, .. } => {
             let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
                 kind: "device_worker_poisoned".into(),
diff --git a/crates/neuron/src/harness/device_worker/jobs.rs b/crates/neuron/src/harness/device_worker/jobs.rs
index 820c1c9..38d2d98 100644
--- a/crates/neuron/src/harness/device_worker/jobs.rs
+++ b/crates/neuron/src/harness/device_worker/jobs.rs
@@ -28,6 +28,17 @@ pub struct ArchHandle(pub u64);
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub struct TpHandle(pub u64);
 
+/// One image payload for `Job::ForwardLogitsWithImages` /
+/// `Job::EncodeImage`. Pixels are row-major `(c, h, w)` f32 — the
+/// shape `harness::preprocess::preprocess` produces. Carries the
+/// shape inline since `Vec<f32>` is rank-1.
+pub struct ImageInput {
+    pub pixels: Vec<f32>,
+    pub c: usize,
+    pub h: usize,
+    pub w: usize,
+}
+
 /// One unit of work for the device worker.
 ///
 /// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
@@ -94,6 +105,33 @@ pub enum Job {
         offset: usize,
         reply: oneshot::Sender<anyhow::Result<Vec<f32>>>,
     },
+    /// Run the LM forward with vision splicing in one round-trip.
+    /// Stage B3 of the vision plan.
+    ///
+    /// Inputs:
+    /// - `tokens`: prompt-expanded token ids (the caller has already
+    ///   replaced each `<|image_pad|>` with N copies per the
+    ///   per-image patch count, so `tokens` already contains exactly
+    ///   `sum(n_i)` `image_token_id` entries across all images).
+    /// - `offset`: KV-cache position (same contract as `ForwardLogits`).
+    /// - `images`: one entry per image — preprocessed pixels plus the
+    ///   `(c, h, w)` shape. Images are encoded on the worker via the
+    ///   model's vision tower (`VisionTower::forward`), concatenated
+    ///   in order, and spliced into the LM input embeddings at
+    ///   `image_token_id` positions.
+    /// - `image_token_id`: the sentinel token (248056 for Qwen3.6).
+    ///
+    /// Returns flat CPU `[vocab]` logits, same as `ForwardLogits`.
+    /// Image embeddings stay device-resident — they're never copied
+    /// to CPU. The "tensors don't escape the worker" invariant holds.
+    ForwardLogitsWithImages {
+        handle: ArchHandle,
+        tokens: Vec<u32>,
+        offset: usize,
+        images: Vec<ImageInput>,
+        image_token_id: u32,
+        reply: oneshot::Sender<anyhow::Result<Vec<f32>>>,
+    },
     /// Encode one image through the model's vision tower. Stage A5 of
     /// the vision plan (`doc/vision-qwen3_6-spec.md`).
/// diff --git a/crates/neuron/src/harness/device_worker/mod.rs b/crates/neuron/src/harness/device_worker/mod.rs index bc99e62..48e41df 100644 --- a/crates/neuron/src/harness/device_worker/mod.rs +++ b/crates/neuron/src/harness/device_worker/mod.rs @@ -313,6 +313,47 @@ impl DeviceWorkerHandle { } } + /// Forward with image-aware splicing in one round-trip. Stage B3. + /// + /// Encodes each image on the worker thread (device-resident), then + /// runs the LM forward with the embeddings spliced at + /// `image_token_id` positions. Returns CPU `[vocab]` logits, same + /// shape as `forward_logits`. Image embeddings never copy back to + /// CPU. + pub async fn forward_logits_with_images( + &self, + handle: ArchHandle, + tokens: Vec, + offset: usize, + images: Vec, + image_token_id: u32, + ) -> Result, WorkerError> { + if self.poisoned.load(Ordering::Acquire) { + return Err(WorkerError::Poisoned { + device_index: self.device_index, + }); + } + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::ForwardLogitsWithImages { + handle, + tokens, + offset, + images, + image_token_id, + reply: reply_tx, + }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + match reply_rx.await { + Ok(result) => result.map_err(WorkerError::from), + Err(_) => Err(WorkerError::Gone { + device_index: self.device_index, + }), + } + } + /// Encode a preprocessed image through the model's vision tower /// and return the resulting LM-side image embeddings as a /// flattened CPU `Vec`. Stage A5.