feat(neuron): Stage B — end-to-end text+image chat for Qwen3.6
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 33s
CI / CUDA type-check (push) Failing after 46s
CI / Clippy (push) Successful in 2m37s
build-prerelease / Build cortex binary (push) Successful in 4m32s
build-prerelease / Build neuron-blackwell (push) Failing after 5m35s
CI / Test (push) Successful in 6m40s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-ampere (push) Failing after 7m46s
build-prerelease / Package cortex RPM (push) Successful in 1m22s
build-prerelease / Build neuron-ada (push) Failing after 4m51s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped

Stage B of the vision plan (doc/vision-qwen3_6-spec.md). Wires
the vision tower from Stage A through to a complete non-streaming
chat completion: extract images from the request, preprocess,
encode on the worker thread, splice embeddings into the LM input
at `<|image_pad|>` positions, return coherent text response with
`prompt_tokens` reflecting patch tokens.

Closes the silent-drop class of failures from issue #3 — vision
requests against Qwen3.6 now condition the model on the image
instead of producing confident text-only hallucinations.

Streaming for vision is Stage C. Deferred items tracked under
#12 (TP-vision), #13 (27B production), #14 (dynamic resolution),
#15 (numerical validation).

What landed:

- **B1 — `Qwen3_5Model::forward_with_vision`**: text-only `forward`
  unchanged; new method takes `(input_ids, offset, image_embeds,
  image_token_id)`, embeds tokens, locates `image_token_id`
  positions, splices via the new `splice_runs` helper. MRoPE
  applies text-positions to image tokens for Stage B (spatial
  MRoPE is the issue #15 numerical-validation follow-up). 2 unit
  tests for `splice_runs` covering contiguous + non-contiguous
  runs.

- **B2 — `ModelArch::forward_with_vision` dispatch**: routes
  Qwen3_5Dense to the new method; other arches return an error.
  Defence-in-depth — the HTTP layer (B6) already rejects image
  content for non-vision models.

- **B3 — `Job::ForwardLogitsWithImages`**: new worker variant
  carrying tokens + per-image `(pixels, c, h, w)` payloads. The
  dispatcher encodes each image (device-resident), concatenates
  the resulting embeddings, calls `arch.forward_with_vision`, and
  returns CPU logits. Image embeddings never copy back to CPU —
  the "tensors don't escape the worker" invariant from the
  per-device worker refactor still holds. Poisoned-worker drain
  path handles the new variant.

- **B4 — Prompt builder**:
  - `request_has_images` detects image content cheaply.
  - `extract_images_from_request(request, profile)` walks
    `MessageContent::Parts`, decodes data URIs, runs
    `harness::preprocess::preprocess` per image, returns
    `Vec<ImageInput>` in request order.
  - `expand_image_pad_tokens(input_ids, image_token_id,
    patches_per_image)` walks the tokenized prompt and replaces
    each `<|image_pad|>` (id 248056 for Qwen3.6) with N copies
    matching the per-image patch count. 4 unit tests.
  - `VisionMeta::from_config_path` peeks `config.json` at load
    time for `image_token_id`, vision_config patch/merge sizes,
    and derives `lm_tokens_per_image` for the Stage B fixed
    resolution.

- **B5 — `chat_completion` vision routing**: detects image
  content, validates the loaded model has vision, expands the
  prompt, and calls a new `run_inference_with_images_via_worker`
  helper that does single-shot prefill + standard decode loop
  (KV cache holds the post-splice hidden states from prefill, so
  decode steps don't re-splice). Stage B skips chunked prefill
  for vision — at 448×448 fixed resolution the budget stays well
  under the activation-memory threshold. Long-vision chunking is
  Stage D follow-up.

- **B6 — `InferenceError::VisionUnsupported`**: structured 400
  with `code=vision_unsupported, model_id, suggestion` when an
  image request hits a non-vision model. Closes the agent0
  failure mode where vision requests degraded silently.

- **B7 — `ModelInfo.capabilities`**: per-model array (`["text"]`
  vs `["text", "vision"]`) in `/v1/models` and forwarded verbatim
  by cortex-gateway. Lets clients (litellm, agent0) gate
  image_url submission on the declared capability set. Optional
  in the wire format; defaults to empty for older clients.

CI gate: cargo fmt --check, cargo clippy --workspace --all-targets
-- -D warnings, cargo test --workspace (all 28 test groups ok,
124 lib tests). New unit-test counts: +2 splice_runs, +4
expand_image_pad.

Manual verification (after RPMs deploy on beast):

  curl http://hanzalova.internal:31313/v1/chat/completions \
    -H 'Content-Type: application/json' \
    -d "{\"model\":\"Qwen/Qwen3.6-27B\", \"messages\":[{\"role\":\"user\",\"content\":[
      {\"type\":\"text\",\"text\":\"What's in this image?\"},
      {\"type\":\"image_url\",\"image_url\":{\"url\":\"data:image/jpeg;base64,...\"}}
    ]}], \"max_tokens\":120}" | jq

  Expect prompt_tokens > 196 (text + 196 patch tokens) and a
  response that references actual image content.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-06-02 15:33:00 +03:00
parent 7df84fed8f
commit 24968e9233
7 changed files with 904 additions and 19 deletions

View File

@@ -44,6 +44,16 @@ pub struct ModelInfo {
pub status: String, pub status: String,
pub devices: Vec<u32>, pub devices: Vec<u32>,
pub vram_used_mb: Option<u64>, pub vram_used_mb: Option<u64>,
/// Modalities this loaded model supports. Today: `["text"]` for
/// text-only checkpoints, `["text", "vision"]` for vision-capable
/// ones (Stage B7 of the vision plan). Clients like litellm /
/// agent0 can gate `image_url` submission on the advertised set.
///
/// Optional in the wire format so older clients that don't read
/// it stay compatible. Default-empty for absent/older data, which
/// callers can interpret as "text".
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub capabilities: Vec<String>,
} }
/// What an inference harness must do, from neuron's perspective. /// What an inference harness must do, from neuron's perspective.

View File

@@ -250,6 +250,18 @@ async fn chat_completions(
})), })),
) )
.into_response(), .into_response(),
Err(InferenceError::VisionUnsupported { model_id }) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!(
"model '{model_id}' does not support image input"
),
"code": "vision_unsupported",
"model_id": model_id,
"suggestion": "load a vision-capable model or remove image_url content parts",
})),
)
.into_response(),
Err(InferenceError::Other(e)) => ( Err(InferenceError::Other(e)) => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})), Json(json!({"error": format!("{e:#}")})),
@@ -289,6 +301,18 @@ async fn chat_completions(
})), })),
) )
.into_response(), .into_response(),
Err(InferenceError::VisionUnsupported { model_id }) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!(
"model '{model_id}' does not support image input"
),
"code": "vision_unsupported",
"model_id": model_id,
"suggestion": "load a vision-capable model or remove image_url content parts",
})),
)
.into_response(),
Err(InferenceError::Other(e)) => ( Err(InferenceError::Other(e)) => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})), Json(json!({"error": format!("{e:#}")})),
@@ -452,6 +476,18 @@ fn inference_error_response(err: InferenceError) -> axum::response::Response {
})), })),
) )
.into_response(), .into_response(),
InferenceError::VisionUnsupported { model_id } => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!(
"model '{model_id}' does not support image input"
),
"code": "vision_unsupported",
"model_id": model_id,
"suggestion": "load a vision-capable model or remove image_url content parts",
})),
)
.into_response(),
InferenceError::Other(e) => ( InferenceError::Other(e) => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})), Json(json!({"error": format!("{e:#}")})),

View File

@@ -221,6 +221,76 @@ fn default_partial_rotary_factor() -> f32 {
1.0 1.0
} }
/// Splice rows from `img` into `h` at `positions`. Stage B helper.
///
/// `h`: `(1, L, hidden)` — the LM's input embedding tensor after
/// `embed_tokens.forward`.
/// `img`: `(N_img, hidden)` — image embeddings, one row per
/// `<|image_pad|>` token in the prompt. Must already be in `h.dtype()`.
/// `positions`: indices into the `L` axis where image rows go;
/// `positions.len() == N_img`.
///
/// Approach: group `positions` into contiguous runs (because the chat
/// template emits `<|vision_start|><|image_pad|>×N<|vision_end|>` —
/// the pad tokens for each image land in one contiguous span), then
/// `slice_assign` per run. For typical Qwen3.6 requests this is one
/// or two runs per image; `slice_assign` does one tensor copy per
/// run, which is cheap relative to the decoder forward pass.
fn splice_runs(h: &Tensor, img: &Tensor, positions: &[u32]) -> candle_core::Result<Tensor> {
debug_assert!(
!positions.is_empty(),
"splice_runs precondition: non-empty positions"
);
let hidden = h.dim(2)?;
let mut out = h.clone();
let mut img_offset = 0_usize;
let mut run_start = positions[0] as usize;
let mut run_end_exclusive = run_start + 1;
for &p in &positions[1..] {
let p = p as usize;
if p == run_end_exclusive {
run_end_exclusive = p + 1;
} else {
apply_run(
&mut out,
img,
&mut img_offset,
run_start,
run_end_exclusive,
hidden,
)?;
run_start = p;
run_end_exclusive = p + 1;
}
}
apply_run(
&mut out,
img,
&mut img_offset,
run_start,
run_end_exclusive,
hidden,
)?;
Ok(out)
}
fn apply_run(
out: &mut Tensor,
img: &Tensor,
img_offset: &mut usize,
run_start: usize,
run_end_exclusive: usize,
hidden: usize,
) -> candle_core::Result<()> {
let run_len = run_end_exclusive - run_start;
let slice = img
.narrow(0, *img_offset, run_len)?
.reshape((1, run_len, hidden))?;
*out = out.slice_assign(&[0..1, run_start..run_end_exclusive, 0..hidden], &slice)?;
*img_offset += run_len;
Ok(())
}
/// Qwen3-Next base transformer (embedding + decoder stack + final /// Qwen3-Next base transformer (embedding + decoder stack + final
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can /// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the /// also build on it later — for now only `Qwen3_5ForCausalLM` is the
@@ -304,8 +374,95 @@ impl Qwen3_5Model {
} }
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> { pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
self.forward_inner(input, offset, None, None)
}
/// Forward with image-embedding splice. Stage B of the vision plan.
///
/// `input_ids`: `(1, L)` token ids — same shape the text-only
/// `forward` accepts (single-batch; multi-batch vision is not in
/// scope today).
/// `image_embeds`: `(N_image_tokens, hidden_size)` — concatenation
/// of every image's post-merger embedding (`VisionTower::forward`
/// output), in the same order images appear in the input. The
/// caller has already done the per-image patch-count expansion of
/// `<|image_pad|>` tokens in `input_ids`, so `N_image_tokens`
/// equals the number of `image_token_id` positions in `input_ids`.
/// `image_token_id`: the sentinel token (e.g. 248056 for Qwen3.6).
///
/// The splice replaces the LM's text-side embedding at each
/// `image_token_id` position with the corresponding row from
/// `image_embeds`. After the splice the decoder runs unchanged.
///
/// **MRoPE gap.** Qwen3.6's `rope_parameters` declares MRoPE
/// (interleaved text/height/width axes); Stage B applies plain
/// text-position RoPE to image tokens. The model still attends
/// to image content but loses spatial structure that MRoPE-aware
/// position encoding would preserve. Tracked under issue #15
/// (numerical validation) — quality benchmark from Stage D should
/// surface the impact, and the fix lives in `rope::RotaryEmbedding`.
pub fn forward_with_vision(
&mut self,
input_ids: &Tensor,
offset: usize,
image_embeds: &Tensor,
image_token_id: u32,
) -> candle_core::Result<Tensor> {
self.forward_inner(input_ids, offset, Some(image_embeds), Some(image_token_id))
}
fn forward_inner(
&mut self,
input: &Tensor,
offset: usize,
image_embeds: Option<&Tensor>,
image_token_id: Option<u32>,
) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?; let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?; let mut h = self.embed_tokens.forward(input)?;
// Splice image embeddings at `image_token_id` positions. The
// caller pre-expanded the prompt so every patch token in the
// image_embeds tensor has a matching position in `input`. We
// index_put the rows in place.
if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
// Locate image-token positions in input_ids. Operate on
// CPU since the input ids are tiny (max ~10k entries
// including the patch expansion) and the comparison is
// not in the per-step hot path.
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
let mut positions: Vec<u32> = Vec::with_capacity(img.dim(0)?);
for (idx, id) in ids.iter().enumerate() {
if *id == tok_id {
positions.push(idx as u32);
}
}
let n_img_tokens = img.dim(0)?;
if positions.len() != n_img_tokens {
candle_core::bail!(
"forward_with_vision: prompt has {} image-token positions but \
image_embeds carries {} tokens — call build_prompt_for_request to \
ensure the per-image patch-count expansion has been applied",
positions.len(),
n_img_tokens,
);
}
if !positions.is_empty() {
// Cast image_embeds to the LM's dtype so the splice
// produces a uniform tensor for the decoder stack.
let img = img.to_dtype(self.dtype)?;
// index_select would return the rows; we want to put.
// candle's slice_assign with explicit positions ranges
// doesn't exist; use scatter via index_select + an
// accumulator: build a `(B, L, hidden)` zero tensor,
// scatter the image rows in, then add to a masked
// version of `h`. Simpler approach: walk positions
// and use `slice_assign` for contiguous runs. Since
// image_pad runs are contiguous (template emits
// `<|vision_start|><|image_pad|>×N<|vision_end|>`),
// we group positions and assign per run.
h = splice_runs(&h, &img, &positions)?;
}
}
// Causal mask only needed for L > 1 prefill; full-attention // Causal mask only needed for L > 1 prefill; full-attention
// layers consume it via broadcast_add. Linear-attention layers // layers consume it via broadcast_add. Linear-attention layers
// ignore the mask. // ignore the mask.
@@ -406,6 +563,24 @@ impl Qwen3_5ForCausalLM {
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head) hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
} }
/// Stage B: forward with image-embedding splice. Mirrors `forward`
/// but routes through `Qwen3_5Model::forward_with_vision` so the
/// LM's input embeddings get the image patches spliced in at
/// `image_token_id` positions before the decoder stack runs.
pub fn forward_with_vision(
&mut self,
input: &Tensor,
offset: usize,
image_embeds: &Tensor,
image_token_id: u32,
) -> candle_core::Result<Tensor> {
let (_, l) = input.dims2()?;
let hidden = self
.base
.forward_with_vision(input, offset, image_embeds, image_token_id)?;
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) { pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache(); self.base.clear_kv_cache();
} }
@@ -463,4 +638,50 @@ mod tests {
assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0); assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0);
assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6); assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6);
} }
/// `splice_runs` replaces (1, L, H) embedding rows at the given
/// positions with rows from a (N_img, H) image-embedding tensor,
/// in the order positions are supplied.
#[test]
fn splice_runs_replaces_at_contiguous_positions() {
use candle_core::{DType, Device};
let dev = Device::Cpu;
// (1, L=5, H=2) text embeddings — encoded as floats so the
// assertion can spot the change without dtype conversion.
let h_vals: Vec<f32> = vec![
10., 11., // pos 0
20., 21., // pos 1
30., 31., // pos 2
40., 41., // pos 3
50., 51., // pos 4
];
let h = Tensor::from_vec(h_vals, (1, 5, 2), &dev).unwrap();
// Two image embeddings to splice at positions 1 and 2 (a
// contiguous run — single image emitting two patch tokens).
let img_vals: Vec<f32> = vec![-1., -2., -3., -4.];
let img = Tensor::from_vec(img_vals, (2, 2), &dev).unwrap();
let out = splice_runs(&h, &img, &[1, 2]).unwrap();
let flat: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
assert_eq!(flat, vec![10., 11., -1., -2., -3., -4., 40., 41., 50., 51.]);
let _ = DType::F32;
}
/// Non-contiguous positions: two images at positions [1] and [3]
/// each contributing one patch. `splice_runs` should iterate
/// runs and place the corresponding image rows.
#[test]
fn splice_runs_handles_non_contiguous_runs() {
use candle_core::Device;
let dev = Device::Cpu;
let h_vals: Vec<f32> = vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.];
let h = Tensor::from_vec(h_vals, (1, 5, 2), &dev).unwrap();
let img_vals: Vec<f32> = vec![-1., -2., -3., -4.];
let img = Tensor::from_vec(img_vals, (2, 2), &dev).unwrap();
let out = splice_runs(&h, &img, &[1, 3]).unwrap();
let flat: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
assert_eq!(flat, vec![1., 1., -1., -2., 3., 3., -3., -4., 5., 5.]);
}
} }

View File

@@ -105,6 +105,22 @@ impl LoadedHandle {
LoadedHandle::Tp(m) => m.poisoned.load(Ordering::Acquire), LoadedHandle::Tp(m) => m.poisoned.load(Ordering::Acquire),
} }
} }
/// Modalities the loaded model supports. Stage B7. TP models are
/// always text-only today — TP-vision is tracked under issue #12.
pub fn capabilities(&self) -> Vec<String> {
let mut caps = vec!["text".to_string()];
match self {
LoadedHandle::Single(m) => {
if m.has_vision {
caps.push("vision".to_string());
}
}
#[cfg(feature = "cuda")]
LoadedHandle::Tp(_) => {}
}
caps
}
} }
/// A loaded model with its tokenizer, device placement, and architecture- /// A loaded model with its tokenizer, device placement, and architecture-
@@ -182,6 +198,25 @@ pub struct LoadedModel {
/// is used. The `NEURON_USE_CHAT_TEMPLATE=false` env var /// is used. The `NEURON_USE_CHAT_TEMPLATE=false` env var
/// forces the fallback path even when `Some`. /// forces the fallback path even when `Some`.
pub chat_template: Option<String>, pub chat_template: Option<String>,
/// Vision capability flag derived at load time. `true` iff the
/// loaded `ModelArch` exposes a vision tower (Stage A4 wires this
/// from `Qwen3_5ForCausalLM::has_vision`). Used by the chat
/// completion handler to reject image content on non-vision
/// models with a structured 400 (Stage B6) and by `/v1/models`
/// to advertise `capabilities: ["text", "vision"]` (Stage B7).
pub has_vision: bool,
/// `<|image_pad|>` token id from `config.json::image_token_id`.
/// The Stage B prompt-builder uses this to compute expansion
/// targets and the worker forward uses it to locate splice
/// positions in the LM input embeddings.
pub image_token_id: Option<u32>,
/// LM-side tokens this model's vision tower emits per image at
/// the Stage B fixed resolution (448×448 → 196 for Qwen3.6).
/// `None` for text-only models. Set at load time so the
/// hot path doesn't recompute it per request. Stage B fixed
/// resolution → constant; dynamic resolution per #14 makes it
/// per-image.
pub lm_tokens_per_image: Option<usize>,
} }
impl LoadedModel { impl LoadedModel {
@@ -333,6 +368,35 @@ impl ModelArch {
} }
} }
/// Forward step that splices vision-tower output at
/// `<|image_pad|>` token positions. Stage B2.
///
/// Only `Qwen3_5Dense` supports this — other architectures error
/// because they don't have a vision tower. The HTTP layer is
/// expected to have rejected image content for non-vision models
/// already (Stage B6); this is a defence-in-depth error path.
///
/// Returns rank-1 `[vocab_size]` logits, same shape contract as
/// `forward`.
pub fn forward_with_vision(
&mut self,
input: &Tensor,
offset: usize,
image_embeds: &Tensor,
image_token_id: u32,
) -> Result<Tensor> {
let raw = match self {
ModelArch::Qwen3_5Dense(m) => {
m.forward_with_vision(input, offset, image_embeds, image_token_id)?
}
other => anyhow::bail!(
"forward_with_vision: architecture {} has no vision tower",
std::any::type_name_of_val(other)
),
};
squeeze_to_vocab(&raw)
}
/// Encode a preprocessed image into LM-side token embeddings via /// Encode a preprocessed image into LM-side token embeddings via
/// the loaded vision tower. Stage A5. /// the loaded vision tower. Stage A5.
/// ///
@@ -1548,9 +1612,54 @@ impl CandleHarness {
.tokenizer .tokenizer
.encode(prompt.as_str(), true) .encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?; .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec(); let mut prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
let prompt_len = prompt_tokens.len();
// Stage B: when the request carries images, preprocess
// them, expand each `<|image_pad|>` sentinel to N copies
// matching the per-image patch count, and route to the
// vision-aware worker path. Non-image requests skip all
// of this and follow the existing text-only flow.
let vision_route = if request_has_images(&request) {
// Stage B6: surface a structured `vision_unsupported`
// rejection when the request asks for vision against a
// text-only model. Cheap and stops the issue-#3 silent-
// drop pattern.
if !loaded.has_vision {
return Err(InferenceError::VisionUnsupported {
model_id: request.model.clone(),
});
}
let image_token_id = loaded
.image_token_id
.ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(),
})?;
let patches_per_image = loaded
.lm_tokens_per_image
.ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(),
})?;
let profile = super::preprocess::PreprocessProfile::qwen3_6();
let images = extract_images_from_request(&request, &profile).map_err(|e| {
InferenceError::Other(anyhow::anyhow!("extract_images: {e}"))
})?;
if images.is_empty() {
// request_has_images said true but extract returned
// empty — defensive bail rather than silently dropping.
return Err(InferenceError::Other(anyhow::anyhow!(
"request has image content but extractor produced zero images"
)));
}
let per_image_counts: Vec<usize> = vec![patches_per_image; images.len()];
prompt_tokens =
expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
.map_err(InferenceError::Other)?;
Some((images, image_token_id))
} else {
None
};
let prompt_len = prompt_tokens.len();
let temperature = request.temperature.unwrap_or(0.7); let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p; let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(8192) as usize; let max_new = request.max_tokens.unwrap_or(8192) as usize;
@@ -1570,6 +1679,7 @@ impl CandleHarness {
?eos_id, ?eos_id,
vram_free_mb, vram_free_mb,
vram_total_mb, vram_total_mb,
vision = vision_route.is_some(),
"chat_completion: starting" "chat_completion: starting"
); );
@@ -1588,18 +1698,37 @@ impl CandleHarness {
// Worker path (CUDA). // Worker path (CUDA).
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
{ {
match run_inference_via_worker( let result = match &vision_route {
worker, Some((images, image_token_id)) => {
handle, run_inference_with_images_via_worker(
&prompt_tokens, worker,
max_new, handle,
temperature, &prompt_tokens,
top_p, images.clone(),
seed, *image_token_id,
eos_id, max_new,
) temperature,
.await top_p,
{ seed,
eos_id,
)
.await
}
None => {
run_inference_via_worker(
worker,
handle,
&prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
)
.await
}
};
match result {
Ok(v) => v, Ok(v) => v,
Err(e) => { Err(e) => {
let chain = format!("{e:#}"); let chain = format!("{e:#}");
@@ -2120,6 +2249,7 @@ impl Harness for CandleHarness {
}, },
devices: h.devices(), devices: h.devices(),
vram_used_mb: None, vram_used_mb: None,
capabilities: h.capabilities(),
}) })
.collect()) .collect())
} }
@@ -2189,7 +2319,7 @@ impl Harness for CandleHarness {
_ => None, _ => None,
}; };
let (tokenizer_path, arch_local, arch_handle) = if let Some(w) = &worker { let (tokenizer_path, arch_local, arch_handle, vision_meta) = if let Some(w) = &worker {
// CUDA path: resolve, then load in the worker. // CUDA path: resolve, then load in the worker.
if spec.quant.is_some() { if spec.quant.is_some() {
let (gguf_path, tokenizer_path) = self.resolve_files(spec, &source_id).await?; let (gguf_path, tokenizer_path) = self.resolve_files(spec, &source_id).await?;
@@ -2197,15 +2327,19 @@ impl Harness for CandleHarness {
.load_gguf(gguf_path, spec.model_id.clone()) .load_gguf(gguf_path, spec.model_id.clone())
.await .await
.map_err(|e| anyhow::anyhow!("worker load_gguf: {e}"))?; .map_err(|e| anyhow::anyhow!("worker load_gguf: {e}"))?;
(tokenizer_path, None, Some(handle)) // GGUF Qwen3.6 releases don't ship the vision tower
// (Qwen-VL weights are in the dense safetensors only),
// so a GGUF load is text-only by construction.
(tokenizer_path, None, Some(handle), VisionMeta::default())
} else { } else {
let (config_path, tokenizer_path, safetensors_paths) = let (config_path, tokenizer_path, safetensors_paths) =
self.resolve_dense_files(spec, &source_id).await?; self.resolve_dense_files(spec, &source_id).await?;
let meta = VisionMeta::from_config_path(&config_path);
let handle = w let handle = w
.load_dense(config_path, safetensors_paths, spec.model_id.clone()) .load_dense(config_path, safetensors_paths, spec.model_id.clone())
.await .await
.map_err(|e| anyhow::anyhow!("worker load_dense: {e}"))?; .map_err(|e| anyhow::anyhow!("worker load_dense: {e}"))?;
(tokenizer_path, None, Some(handle)) (tokenizer_path, None, Some(handle), meta)
} }
} else { } else {
// CPU path: legacy spawn_blocking + Arc<Mutex<ModelArch>>. // CPU path: legacy spawn_blocking + Arc<Mutex<ModelArch>>.
@@ -2214,7 +2348,16 @@ impl Harness for CandleHarness {
} else { } else {
self.load_arch_dense(spec, &source_id, &device).await? self.load_arch_dense(spec, &source_id, &device).await?
}; };
(tokenizer_path, Some(Arc::new(Mutex::new(arch))), None) // CPU Qwen3.6 isn't a supported deployment target — the
// 27B doesn't fit any reasonable CPU memory budget — so
// we don't attempt to reach into the arch for vision
// metadata. Stays text-only.
(
tokenizer_path,
Some(Arc::new(Mutex::new(arch))),
None,
VisionMeta::default(),
)
}; };
let tokenizer = Tokenizer::from_file(&tokenizer_path) let tokenizer = Tokenizer::from_file(&tokenizer_path)
@@ -2278,6 +2421,9 @@ impl Harness for CandleHarness {
reasoning_tokens, reasoning_tokens,
tool_call_tokens, tool_call_tokens,
chat_template, chat_template,
has_vision: vision_meta.has_vision,
image_token_id: vision_meta.image_token_id,
lm_tokens_per_image: vision_meta.lm_tokens_per_image,
}); });
let mut models = self.models.write().await; let mut models = self.models.write().await;
@@ -3434,6 +3580,16 @@ pub enum InferenceError {
"insufficient free VRAM for prefill: {free_mb} MiB free, need at least {required_mb} MiB" "insufficient free VRAM for prefill: {free_mb} MiB free, need at least {required_mb} MiB"
)] )]
InsufficientVram { free_mb: u64, required_mb: u64 }, InsufficientVram { free_mb: u64, required_mb: u64 },
/// Request carried `image_url` content but the loaded model has
/// no vision tower. Stage B6 — replaces the silent-drop pattern
/// from issue #3 with an explicit 400 + `vision_unsupported`
/// error body that clients (litellm, agent0, …) can act on.
#[error(
"model '{model_id}' does not support image input; \
load a vision-capable model (e.g. Qwen/Qwen3.6-27B) or \
remove the image_url content parts from the request"
)]
VisionUnsupported { model_id: String },
#[error(transparent)] #[error(transparent)]
Other(#[from] anyhow::Error), Other(#[from] anyhow::Error),
} }
@@ -3498,6 +3654,169 @@ fn build_prompt_for_request(
} }
} }
/// Vision metadata derived at model-load time. Stashed on
/// `LoadedModel` so the chat-completion hot path doesn't have to
/// re-parse `config.json` or reach across the worker thread to peek
/// at the loaded `ModelArch`.
#[derive(Debug, Default, Clone, Copy)]
struct VisionMeta {
has_vision: bool,
image_token_id: Option<u32>,
/// LM-side tokens this model's vision tower emits per image at
/// the Stage B fixed `PreprocessProfile::qwen3_6()` resolution
/// (448×448). Equal to `(H/patch_size/spatial_merge_size)²`.
lm_tokens_per_image: Option<usize>,
}
impl VisionMeta {
/// Peek at `config.json` for vision-related fields. Returns the
/// default (no-vision) struct on any read/parse error — vision is
/// best-effort metadata; load can still succeed for text usage.
fn from_config_path(config_path: &std::path::Path) -> Self {
let Ok(text) = std::fs::read_to_string(config_path) else {
return Self::default();
};
let Ok(v) = serde_json::from_str::<serde_json::Value>(&text) else {
return Self::default();
};
let Some(vision_config) = v.get("vision_config") else {
return Self::default();
};
let patch_size = vision_config
.get("patch_size")
.and_then(|x| x.as_u64())
.unwrap_or(16) as usize;
let spatial_merge_size = vision_config
.get("spatial_merge_size")
.and_then(|x| x.as_u64())
.unwrap_or(2) as usize;
let image_token_id = v
.get("image_token_id")
.and_then(|x| x.as_u64())
.map(|n| n as u32);
// Compute LM tokens per image at the Stage B fixed resolution
// (PreprocessProfile::qwen3_6() → 448×448). One LM token per
// spatial-merge group of patches.
let target_h = super::preprocess::PreprocessProfile::qwen3_6().target_height as usize;
let target_w = super::preprocess::PreprocessProfile::qwen3_6().target_width as usize;
let lm_tokens_per_image = if patch_size > 0 && spatial_merge_size > 0 {
let gh = target_h / patch_size / spatial_merge_size;
let gw = target_w / patch_size / spatial_merge_size;
Some(gh * gw)
} else {
None
};
Self {
has_vision: true,
image_token_id,
lm_tokens_per_image,
}
}
}
/// True iff any message in the request carries an `image_url`
/// content part. The Stage B routing decision in `chat_completion`
/// dispatches to the vision-aware worker job when this is true.
fn request_has_images(request: &ChatCompletionRequest) -> bool {
request.messages.iter().any(|m| {
matches!(&m.content, MessageContent::Parts(parts)
if parts.iter().any(|p|
p.get("type").and_then(|v| v.as_str()) == Some("image_url")))
})
}
/// Extract `image_url` content parts from a chat request and turn
/// each one into a preprocessed `ImageInput` ready for the device
/// worker. Stage B4.
///
/// Walks `request.messages`, looking for `MessageContent::Parts` and
/// pulling out entries whose `type == "image_url"`. Each is run
/// through `harness::preprocess::decode_data_uri` + `preprocess` with
/// the supplied `profile` (Stage B always uses
/// `PreprocessProfile::qwen3_6()` — fixed 448×448 — so every image
/// produces the same patch count; dynamic resolution per issue #14
/// would parameterise this).
///
/// Returns images in the order they appear in the request, which
/// matches the order the chat template emits `<|image_pad|>` tokens.
fn extract_images_from_request(
request: &ChatCompletionRequest,
profile: &super::preprocess::PreprocessProfile,
) -> anyhow::Result<Vec<super::device_worker::jobs::ImageInput>> {
let mut out = Vec::new();
for msg in &request.messages {
if let MessageContent::Parts(parts) = &msg.content {
for part in parts {
let kind = part.get("type").and_then(|v| v.as_str()).unwrap_or("");
if kind != "image_url" {
continue;
}
let url = part
.get("image_url")
.and_then(|v| v.get("url"))
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("image_url part missing url field"))?;
let pixels = super::preprocess::preprocess_data_uri(url, profile)
.with_context(|| format!("preprocess image #{}", out.len()))?;
out.push(super::device_worker::jobs::ImageInput {
pixels,
c: 3,
h: profile.target_height as usize,
w: profile.target_width as usize,
});
}
}
}
Ok(out)
}
/// Expand each occurrence of `image_token_id` in `input_ids` into
/// `patches_per_image[i]` copies (one expansion per image, in order).
/// Stage B4 helper.
///
/// The chat template emits a single `<|image_pad|>` per image; this
/// is what fits Qwen3-VL's template-then-runtime-expansion convention.
/// The runtime (us) is responsible for replacing each one with N
/// copies based on the corresponding image's patch count.
///
/// For Stage B fixed resolution every entry of `patches_per_image`
/// is the same constant (196 at 448×448). For dynamic resolution
/// (issue #14) each image gets its own value.
///
/// Errors if the number of `image_token_id` occurrences in `input_ids`
/// doesn't equal `patches_per_image.len()` — a mismatch means the
/// template emitted the wrong number of pad tokens (operator-visible
/// template bug, not a runtime error).
fn expand_image_pad_tokens(
input_ids: &[u32],
image_token_id: u32,
patches_per_image: &[usize],
) -> anyhow::Result<Vec<u32>> {
let occurrences = input_ids.iter().filter(|&&t| t == image_token_id).count();
if occurrences != patches_per_image.len() {
anyhow::bail!(
"expand_image_pad_tokens: prompt has {occurrences} image_token_id occurrences but \
{} images were preprocessed — chat template emitted the wrong number of pad tokens",
patches_per_image.len()
);
}
let total_extra: usize = patches_per_image.iter().map(|n| n.saturating_sub(1)).sum();
let mut out = Vec::with_capacity(input_ids.len() + total_extra);
let mut img_idx = 0;
for &t in input_ids {
if t == image_token_id {
let n = patches_per_image[img_idx];
for _ in 0..n {
out.push(image_token_id);
}
img_idx += 1;
} else {
out.push(t);
}
}
Ok(out)
}
/// Apply the Qwen3 chat template: /// Apply the Qwen3 chat template:
/// ///
/// ```text /// ```text
@@ -3544,6 +3863,103 @@ fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
/// would only add channel overhead with no diagnostic benefit. /// would only add channel overhead with no diagnostic benefit.
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
/// Vision-aware analogue of `run_inference_via_worker`. Stage B5.
///
/// Single-shot prefill carrying the pre-expanded prompt + the image
/// payloads. The worker encodes each image through the vision tower,
/// splices the resulting embeddings at `image_token_id` positions,
/// and returns the last-position logits. Decode steps thereafter
/// follow the existing text-only `forward_logits` path — the KV
/// cache holds the image-conditioned hidden states from prefill, so
/// no further splicing is needed.
///
/// Stage B skips chunked prefill for vision (the fixed-resolution
/// budget — 196 image tokens at 448×448 + typical text — stays well
/// under the activation-memory threshold). Long-prompt-with-images
/// chunking is a Stage D follow-up.
#[allow(clippy::too_many_arguments)]
async fn run_inference_with_images_via_worker(
worker: &super::device_worker::DeviceWorkerHandle,
handle: super::device_worker::ArchHandle,
prompt_tokens: &[u32],
images: Vec<super::device_worker::jobs::ImageInput>,
image_token_id: u32,
max_new: usize,
temperature: f64,
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
) -> Result<(Vec<u32>, String)> {
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut generated: Vec<u32> = Vec::new();
let prompt_len = prompt_tokens.len();
worker
.clear_kv_cache(handle)
.await
.map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
// Single-shot prefill with image splicing.
let logits_vec = worker
.forward_logits_with_images(handle, prompt_tokens.to_vec(), 0, images, image_token_id)
.await
.map_err(|e| anyhow::anyhow!("forward_logits_with_images: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
?health,
"chat_completion (worker, vision): prefill sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
}
generated.push(next_token);
for index in 0..max_new.saturating_sub(1) {
let logits_vec = worker
.forward_logits(handle, vec![next_token], prompt_len + index)
.await
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
step = index,
?health,
"chat_completion (worker, vision): decode sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
}
generated.push(next_token);
}
Ok((generated, "length".into()))
}
#[cfg(feature = "cuda")]
async fn run_inference_via_worker( async fn run_inference_via_worker(
worker: &super::device_worker::DeviceWorkerHandle, worker: &super::device_worker::DeviceWorkerHandle,
handle: super::device_worker::ArchHandle, handle: super::device_worker::ArchHandle,
@@ -4243,4 +4659,44 @@ mod tests {
.expect("synth huggingface source should build"); .expect("synth huggingface source should build");
assert_eq!(harness.default_source_scheme(), "huggingface"); assert_eq!(harness.default_source_scheme(), "huggingface");
} }
#[test]
fn expand_image_pad_replaces_single_token_with_n_copies() {
// Mimics the chat template's output: each image emits
// [vision_start, image_pad, vision_end]. After expansion
// with 3 patches/image we want
// [vision_start, pad×3, vision_end].
let pad = 248056_u32;
let vstart = 248053_u32;
let vend = 248054_u32;
let input = vec![1, vstart, pad, vend, 2];
let out = expand_image_pad_tokens(&input, pad, &[3]).unwrap();
assert_eq!(out, vec![1, vstart, pad, pad, pad, vend, 2]);
}
#[test]
fn expand_image_pad_handles_multiple_images() {
let pad = 248056_u32;
// Two images in one prompt; first gets 2 patches, second 3.
let input = vec![pad, 99, pad];
let out = expand_image_pad_tokens(&input, pad, &[2, 3]).unwrap();
assert_eq!(out, vec![pad, pad, 99, pad, pad, pad]);
}
#[test]
fn expand_image_pad_errors_on_count_mismatch() {
let pad = 248056_u32;
// Prompt has 2 pad tokens but caller supplied 3 images.
let input = vec![pad, 99, pad];
let err = expand_image_pad_tokens(&input, pad, &[2, 3, 4]).unwrap_err();
assert!(format!("{err:#}").contains("emitted the wrong number"));
}
#[test]
fn expand_image_pad_passes_through_when_no_images() {
let pad = 248056_u32;
let input = vec![1, 2, 3];
let out = expand_image_pad_tokens(&input, pad, &[]).unwrap();
assert_eq!(out, input);
}
} }

View File

@@ -16,10 +16,11 @@
use crate::harness::candle::ModelArch; use crate::harness::candle::ModelArch;
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
use crate::harness::device_worker::jobs::TpHandle; use crate::harness::device_worker::jobs::TpHandle;
use crate::harness::device_worker::jobs::{ArchHandle, Job}; use crate::harness::device_worker::jobs::{ArchHandle, ImageInput, Job};
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
use crate::harness::tp::TpLeaderModel; use crate::harness::tp::TpLeaderModel;
use crate::harness::tp::nccl_state::NcclState; use crate::harness::tp::nccl_state::NcclState;
use anyhow::Context as _;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
@@ -169,6 +170,24 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
let result = encode_image(&mut state, handle, pixels, c, h, w); let result = encode_image(&mut state, handle, pixels, c, h, w);
let _ = reply.send(result); let _ = reply.send(result);
} }
Job::ForwardLogitsWithImages {
handle,
tokens,
offset,
images,
image_token_id,
reply,
} => {
let result = forward_logits_with_images(
&mut state,
handle,
&tokens,
offset,
images,
image_token_id,
);
let _ = reply.send(result);
}
Job::NcclInit { Job::NcclInit {
cfg, cfg,
comm_id_hex, comm_id_hex,
@@ -751,6 +770,67 @@ fn forward_logits(
Ok(values) Ok(values)
} }
/// Run the LM forward with vision-tower image splicing. Stage B3.
///
/// Encodes each image through the vision tower (`VisionTower::forward`,
/// dispatched via `ModelArch::encode_image`), concatenates the
/// resulting embeddings into a single `(N_total, hidden)` tensor, and
/// passes it to `ModelArch::forward_with_vision` along with the
/// prompt-expanded `tokens`. Image embeddings never leave the device.
///
/// Returns CPU `[vocab]` logits — same shape contract as
/// `ForwardLogits` so the async sampler doesn't have to branch on the
/// presence of images.
fn forward_logits_with_images(
state: &mut DeviceWorkerState,
handle: ArchHandle,
tokens: &[u32],
offset: usize,
images: Vec<ImageInput>,
image_token_id: u32,
) -> anyhow::Result<Vec<f32>> {
use candle_core::{DType, Tensor};
if images.is_empty() {
anyhow::bail!("ForwardLogitsWithImages dispatched with zero images");
}
let arch = state.models.get_mut(&handle).ok_or_else(|| {
anyhow::anyhow!("ForwardLogitsWithImages: no model for handle {}", handle.0)
})?;
// Encode every image on the worker's device, collecting per-image
// post-merger embeddings as device-resident tensors.
let mut per_image: Vec<Tensor> = Vec::with_capacity(images.len());
for (idx, img) in images.into_iter().enumerate() {
anyhow::ensure!(
img.pixels.len() == img.c * img.h * img.w,
"ForwardLogitsWithImages: image[{idx}] pixels length {} does not match shape ({}, {}, {})",
img.pixels.len(),
img.c,
img.h,
img.w,
);
let image = Tensor::from_vec(img.pixels, (img.c, img.h, img.w), &state.device)?;
let embed = arch
.encode_image(&image)
.with_context(|| format!("encode image[{idx}]"))?;
per_image.push(embed);
}
// Concatenate per-image embeddings along the patch axis →
// (sum_of_patches, hidden). `Tensor::cat` keeps the result
// device-resident.
let image_embeds = Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)?;
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
let logits = arch.forward_with_vision(&input, offset, &image_embeds, image_token_id)?;
let values = logits
.to_dtype(DType::F32)?
.flatten_all()?
.to_vec1::<f32>()?;
Ok(values)
}
/// Run the vision tower on a single preprocessed image. Stage A5. /// Run the vision tower on a single preprocessed image. Stage A5.
/// ///
/// `pixels` is a row-major `(c, h, w)` f32 image that the async-side /// `pixels` is a row-major `(c, h, w)` f32 image that the async-side
@@ -830,6 +910,9 @@ fn drain_poisoned(job: Job, device_index: u32) {
Job::EncodeImage { reply, .. } => { Job::EncodeImage { reply, .. } => {
let _ = reply.send(Err(err())); let _ = reply.send(Err(err()));
} }
Job::ForwardLogitsWithImages { reply, .. } => {
let _ = reply.send(Err(err()));
}
Job::NcclInit { reply, .. } => { Job::NcclInit { reply, .. } => {
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error { let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
kind: "device_worker_poisoned".into(), kind: "device_worker_poisoned".into(),

View File

@@ -28,6 +28,17 @@ pub struct ArchHandle(pub u64);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TpHandle(pub u64); pub struct TpHandle(pub u64);
/// One image payload for `Job::ForwardLogitsWithImages` /
/// `Job::EncodeImage`. Pixels are row-major `(c, h, w)` f32 — the
/// shape `harness::preprocess::preprocess` produces. Carries the
/// shape inline since `Vec<f32>` is rank-1.
pub struct ImageInput {
pub pixels: Vec<f32>,
pub c: usize,
pub h: usize,
pub w: usize,
}
/// One unit of work for the device worker. /// One unit of work for the device worker.
/// ///
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the /// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
@@ -94,6 +105,33 @@ pub enum Job {
offset: usize, offset: usize,
reply: oneshot::Sender<Result<Vec<f32>>>, reply: oneshot::Sender<Result<Vec<f32>>>,
}, },
/// Run the LM forward with vision splicing in one round-trip.
/// Stage B3 of the vision plan.
///
/// Inputs:
/// - `tokens`: prompt-expanded token ids (the caller has already
/// replaced each `<|image_pad|>` with N copies per the
/// per-image patch count, so `tokens` already contains exactly
/// `sum(n_i)` `image_token_id` entries across all images).
/// - `offset`: KV-cache position (same contract as `ForwardLogits`).
/// - `images`: one entry per image — preprocessed pixels plus the
/// `(c, h, w)` shape. Images are encoded on the worker via the
/// model's vision tower (`VisionTower::forward`), concatenated
/// in order, and spliced into the LM input embeddings at
/// `image_token_id` positions.
/// - `image_token_id`: the sentinel token (248056 for Qwen3.6).
///
/// Returns flat CPU `[vocab]` logits, same as `ForwardLogits`.
/// Image embeddings stay device-resident — they're never copied
/// to CPU. The "tensors don't escape the worker" invariant holds.
ForwardLogitsWithImages {
handle: ArchHandle,
tokens: Vec<u32>,
offset: usize,
images: Vec<ImageInput>,
image_token_id: u32,
reply: oneshot::Sender<Result<Vec<f32>>>,
},
/// Encode one image through the model's vision tower. Stage A5 of /// Encode one image through the model's vision tower. Stage A5 of
/// the vision plan (`doc/vision-qwen3_6-spec.md`). /// the vision plan (`doc/vision-qwen3_6-spec.md`).
/// ///

View File

@@ -313,6 +313,47 @@ impl DeviceWorkerHandle {
} }
} }
/// Forward with image-aware splicing in one round-trip. Stage B3.
///
/// Encodes each image on the worker thread (device-resident), then
/// runs the LM forward with the embeddings spliced at
/// `image_token_id` positions. Returns CPU `[vocab]` logits, same
/// shape as `forward_logits`. Image embeddings never copy back to
/// CPU.
pub async fn forward_logits_with_images(
&self,
handle: ArchHandle,
tokens: Vec<u32>,
offset: usize,
images: Vec<crate::harness::device_worker::jobs::ImageInput>,
image_token_id: u32,
) -> Result<Vec<f32>, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::ForwardLogitsWithImages {
handle,
tokens,
offset,
images,
image_token_id,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Encode a preprocessed image through the model's vision tower /// Encode a preprocessed image through the model's vision tower
/// and return the resulting LM-side image embeddings as a /// and return the resulting LM-side image embeddings as a
/// flattened CPU `Vec<f32>`. Stage A5. /// flattened CPU `Vec<f32>`. Stage A5.