feat(neuron): Stage B — end-to-end text+image chat for Qwen3.6
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 33s
CI / CUDA type-check (push) Failing after 46s
CI / Clippy (push) Successful in 2m37s
build-prerelease / Build cortex binary (push) Successful in 4m32s
build-prerelease / Build neuron-blackwell (push) Failing after 5m35s
CI / Test (push) Successful in 6m40s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-ampere (push) Failing after 7m46s
build-prerelease / Package cortex RPM (push) Successful in 1m22s
build-prerelease / Build neuron-ada (push) Failing after 4m51s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 33s
CI / CUDA type-check (push) Failing after 46s
CI / Clippy (push) Successful in 2m37s
build-prerelease / Build cortex binary (push) Successful in 4m32s
build-prerelease / Build neuron-blackwell (push) Failing after 5m35s
CI / Test (push) Successful in 6m40s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-ampere (push) Failing after 7m46s
build-prerelease / Package cortex RPM (push) Successful in 1m22s
build-prerelease / Build neuron-ada (push) Failing after 4m51s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Stage B of the vision plan (doc/vision-qwen3_6-spec.md). Wires the vision tower from Stage A through to a complete non-streaming chat completion: extract images from the request, preprocess, encode on the worker thread, splice embeddings into the LM input at `<|image_pad|>` positions, return coherent text response with `prompt_tokens` reflecting patch tokens. Closes the silent-drop class of failures from issue #3 — vision requests against Qwen3.6 now condition the model on the image instead of producing confident text-only hallucinations. Streaming for vision is Stage C. Deferred items tracked under #12 (TP-vision), #13 (27B production), #14 (dynamic resolution), #15 (numerical validation). What landed: - **B1 — `Qwen3_5Model::forward_with_vision`**: text-only `forward` unchanged; new method takes `(input_ids, offset, image_embeds, image_token_id)`, embeds tokens, locates `image_token_id` positions, splices via the new `splice_runs` helper. MRoPE applies text-positions to image tokens for Stage B (spatial MRoPE is the issue #15 numerical-validation follow-up). 2 unit tests for `splice_runs` covering contiguous + non-contiguous runs. - **B2 — `ModelArch::forward_with_vision` dispatch**: routes Qwen3_5Dense to the new method; other arches return an error. Defence-in-depth — the HTTP layer (B6) already rejects image content for non-vision models. - **B3 — `Job::ForwardLogitsWithImages`**: new worker variant carrying tokens + per-image `(pixels, c, h, w)` payloads. The dispatcher encodes each image (device-resident), concatenates the resulting embeddings, calls `arch.forward_with_vision`, and returns CPU logits. Image embeddings never copy back to CPU — the "tensors don't escape the worker" invariant from the per-device worker refactor still holds. Poisoned-worker drain path handles the new variant. - **B4 — Prompt builder**: - `request_has_images` detects image content cheaply. - `extract_images_from_request(request, profile)` walks `MessageContent::Parts`, decodes data URIs, runs `harness::preprocess::preprocess` per image, returns `Vec<ImageInput>` in request order. - `expand_image_pad_tokens(input_ids, image_token_id, patches_per_image)` walks the tokenized prompt and replaces each `<|image_pad|>` (id 248056 for Qwen3.6) with N copies matching the per-image patch count. 4 unit tests. - `VisionMeta::from_config_path` peeks `config.json` at load time for `image_token_id`, vision_config patch/merge sizes, and derives `lm_tokens_per_image` for the Stage B fixed resolution. - **B5 — `chat_completion` vision routing**: detects image content, validates the loaded model has vision, expands the prompt, and calls a new `run_inference_with_images_via_worker` helper that does single-shot prefill + standard decode loop (KV cache holds the post-splice hidden states from prefill, so decode steps don't re-splice). Stage B skips chunked prefill for vision — at 448×448 fixed resolution the budget stays well under the activation-memory threshold. Long-vision chunking is Stage D follow-up. - **B6 — `InferenceError::VisionUnsupported`**: structured 400 with `code=vision_unsupported, model_id, suggestion` when an image request hits a non-vision model. Closes the agent0 failure mode where vision requests degraded silently. - **B7 — `ModelInfo.capabilities`**: per-model array (`["text"]` vs `["text", "vision"]`) in `/v1/models` and forwarded verbatim by cortex-gateway. Lets clients (litellm, agent0) gate image_url submission on the declared capability set. Optional in the wire format; defaults to empty for older clients. CI gate: cargo fmt --check, cargo clippy --workspace --all-targets -- -D warnings, cargo test --workspace (all 28 test groups ok, 124 lib tests). New unit-test counts: +2 splice_runs, +4 expand_image_pad. Manual verification (after RPMs deploy on beast): curl http://hanzalova.internal:31313/v1/chat/completions \ -H 'Content-Type: application/json' \ -d "{\"model\":\"Qwen/Qwen3.6-27B\", \"messages\":[{\"role\":\"user\",\"content\":[ {\"type\":\"text\",\"text\":\"What's in this image?\"}, {\"type\":\"image_url\",\"image_url\":{\"url\":\"data:image/jpeg;base64,...\"}} ]}], \"max_tokens\":120}" | jq Expect prompt_tokens > 196 (text + 196 patch tokens) and a response that references actual image content. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -250,6 +250,18 @@ async fn chat_completions(
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::VisionUnsupported { model_id }) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"model '{model_id}' does not support image input"
|
||||
),
|
||||
"code": "vision_unsupported",
|
||||
"model_id": model_id,
|
||||
"suggestion": "load a vision-capable model or remove image_url content parts",
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::Other(e)) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
@@ -289,6 +301,18 @@ async fn chat_completions(
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::VisionUnsupported { model_id }) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"model '{model_id}' does not support image input"
|
||||
),
|
||||
"code": "vision_unsupported",
|
||||
"model_id": model_id,
|
||||
"suggestion": "load a vision-capable model or remove image_url content parts",
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::Other(e)) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
@@ -452,6 +476,18 @@ fn inference_error_response(err: InferenceError) -> axum::response::Response {
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
InferenceError::VisionUnsupported { model_id } => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"model '{model_id}' does not support image input"
|
||||
),
|
||||
"code": "vision_unsupported",
|
||||
"model_id": model_id,
|
||||
"suggestion": "load a vision-capable model or remove image_url content parts",
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
InferenceError::Other(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
|
||||
@@ -221,6 +221,76 @@ fn default_partial_rotary_factor() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
/// Splice rows from `img` into `h` at `positions`. Stage B helper.
|
||||
///
|
||||
/// `h`: `(1, L, hidden)` — the LM's input embedding tensor after
|
||||
/// `embed_tokens.forward`.
|
||||
/// `img`: `(N_img, hidden)` — image embeddings, one row per
|
||||
/// `<|image_pad|>` token in the prompt. Must already be in `h.dtype()`.
|
||||
/// `positions`: indices into the `L` axis where image rows go;
|
||||
/// `positions.len() == N_img`.
|
||||
///
|
||||
/// Approach: group `positions` into contiguous runs (because the chat
|
||||
/// template emits `<|vision_start|><|image_pad|>×N<|vision_end|>` —
|
||||
/// the pad tokens for each image land in one contiguous span), then
|
||||
/// `slice_assign` per run. For typical Qwen3.6 requests this is one
|
||||
/// or two runs per image; `slice_assign` does one tensor copy per
|
||||
/// run, which is cheap relative to the decoder forward pass.
|
||||
fn splice_runs(h: &Tensor, img: &Tensor, positions: &[u32]) -> candle_core::Result<Tensor> {
|
||||
debug_assert!(
|
||||
!positions.is_empty(),
|
||||
"splice_runs precondition: non-empty positions"
|
||||
);
|
||||
let hidden = h.dim(2)?;
|
||||
let mut out = h.clone();
|
||||
let mut img_offset = 0_usize;
|
||||
let mut run_start = positions[0] as usize;
|
||||
let mut run_end_exclusive = run_start + 1;
|
||||
for &p in &positions[1..] {
|
||||
let p = p as usize;
|
||||
if p == run_end_exclusive {
|
||||
run_end_exclusive = p + 1;
|
||||
} else {
|
||||
apply_run(
|
||||
&mut out,
|
||||
img,
|
||||
&mut img_offset,
|
||||
run_start,
|
||||
run_end_exclusive,
|
||||
hidden,
|
||||
)?;
|
||||
run_start = p;
|
||||
run_end_exclusive = p + 1;
|
||||
}
|
||||
}
|
||||
apply_run(
|
||||
&mut out,
|
||||
img,
|
||||
&mut img_offset,
|
||||
run_start,
|
||||
run_end_exclusive,
|
||||
hidden,
|
||||
)?;
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn apply_run(
|
||||
out: &mut Tensor,
|
||||
img: &Tensor,
|
||||
img_offset: &mut usize,
|
||||
run_start: usize,
|
||||
run_end_exclusive: usize,
|
||||
hidden: usize,
|
||||
) -> candle_core::Result<()> {
|
||||
let run_len = run_end_exclusive - run_start;
|
||||
let slice = img
|
||||
.narrow(0, *img_offset, run_len)?
|
||||
.reshape((1, run_len, hidden))?;
|
||||
*out = out.slice_assign(&[0..1, run_start..run_end_exclusive, 0..hidden], &slice)?;
|
||||
*img_offset += run_len;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Qwen3-Next base transformer (embedding + decoder stack + final
|
||||
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
|
||||
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
|
||||
@@ -304,8 +374,95 @@ impl Qwen3_5Model {
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||
self.forward_inner(input, offset, None, None)
|
||||
}
|
||||
|
||||
/// Forward with image-embedding splice. Stage B of the vision plan.
|
||||
///
|
||||
/// `input_ids`: `(1, L)` token ids — same shape the text-only
|
||||
/// `forward` accepts (single-batch; multi-batch vision is not in
|
||||
/// scope today).
|
||||
/// `image_embeds`: `(N_image_tokens, hidden_size)` — concatenation
|
||||
/// of every image's post-merger embedding (`VisionTower::forward`
|
||||
/// output), in the same order images appear in the input. The
|
||||
/// caller has already done the per-image patch-count expansion of
|
||||
/// `<|image_pad|>` tokens in `input_ids`, so `N_image_tokens`
|
||||
/// equals the number of `image_token_id` positions in `input_ids`.
|
||||
/// `image_token_id`: the sentinel token (e.g. 248056 for Qwen3.6).
|
||||
///
|
||||
/// The splice replaces the LM's text-side embedding at each
|
||||
/// `image_token_id` position with the corresponding row from
|
||||
/// `image_embeds`. After the splice the decoder runs unchanged.
|
||||
///
|
||||
/// **MRoPE gap.** Qwen3.6's `rope_parameters` declares MRoPE
|
||||
/// (interleaved text/height/width axes); Stage B applies plain
|
||||
/// text-position RoPE to image tokens. The model still attends
|
||||
/// to image content but loses spatial structure that MRoPE-aware
|
||||
/// position encoding would preserve. Tracked under issue #15
|
||||
/// (numerical validation) — quality benchmark from Stage D should
|
||||
/// surface the impact, and the fix lives in `rope::RotaryEmbedding`.
|
||||
pub fn forward_with_vision(
|
||||
&mut self,
|
||||
input_ids: &Tensor,
|
||||
offset: usize,
|
||||
image_embeds: &Tensor,
|
||||
image_token_id: u32,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
self.forward_inner(input_ids, offset, Some(image_embeds), Some(image_token_id))
|
||||
}
|
||||
|
||||
fn forward_inner(
|
||||
&mut self,
|
||||
input: &Tensor,
|
||||
offset: usize,
|
||||
image_embeds: Option<&Tensor>,
|
||||
image_token_id: Option<u32>,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (b, l) = input.dims2()?;
|
||||
let mut h = self.embed_tokens.forward(input)?;
|
||||
// Splice image embeddings at `image_token_id` positions. The
|
||||
// caller pre-expanded the prompt so every patch token in the
|
||||
// image_embeds tensor has a matching position in `input`. We
|
||||
// index_put the rows in place.
|
||||
if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
|
||||
// Locate image-token positions in input_ids. Operate on
|
||||
// CPU since the input ids are tiny (max ~10k entries
|
||||
// including the patch expansion) and the comparison is
|
||||
// not in the per-step hot path.
|
||||
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
|
||||
let mut positions: Vec<u32> = Vec::with_capacity(img.dim(0)?);
|
||||
for (idx, id) in ids.iter().enumerate() {
|
||||
if *id == tok_id {
|
||||
positions.push(idx as u32);
|
||||
}
|
||||
}
|
||||
let n_img_tokens = img.dim(0)?;
|
||||
if positions.len() != n_img_tokens {
|
||||
candle_core::bail!(
|
||||
"forward_with_vision: prompt has {} image-token positions but \
|
||||
image_embeds carries {} tokens — call build_prompt_for_request to \
|
||||
ensure the per-image patch-count expansion has been applied",
|
||||
positions.len(),
|
||||
n_img_tokens,
|
||||
);
|
||||
}
|
||||
if !positions.is_empty() {
|
||||
// Cast image_embeds to the LM's dtype so the splice
|
||||
// produces a uniform tensor for the decoder stack.
|
||||
let img = img.to_dtype(self.dtype)?;
|
||||
// index_select would return the rows; we want to put.
|
||||
// candle's slice_assign with explicit positions ranges
|
||||
// doesn't exist; use scatter via index_select + an
|
||||
// accumulator: build a `(B, L, hidden)` zero tensor,
|
||||
// scatter the image rows in, then add to a masked
|
||||
// version of `h`. Simpler approach: walk positions
|
||||
// and use `slice_assign` for contiguous runs. Since
|
||||
// image_pad runs are contiguous (template emits
|
||||
// `<|vision_start|><|image_pad|>×N<|vision_end|>`),
|
||||
// we group positions and assign per run.
|
||||
h = splice_runs(&h, &img, &positions)?;
|
||||
}
|
||||
}
|
||||
// Causal mask only needed for L > 1 prefill; full-attention
|
||||
// layers consume it via broadcast_add. Linear-attention layers
|
||||
// ignore the mask.
|
||||
@@ -406,6 +563,24 @@ impl Qwen3_5ForCausalLM {
|
||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
/// Stage B: forward with image-embedding splice. Mirrors `forward`
|
||||
/// but routes through `Qwen3_5Model::forward_with_vision` so the
|
||||
/// LM's input embeddings get the image patches spliced in at
|
||||
/// `image_token_id` positions before the decoder stack runs.
|
||||
pub fn forward_with_vision(
|
||||
&mut self,
|
||||
input: &Tensor,
|
||||
offset: usize,
|
||||
image_embeds: &Tensor,
|
||||
image_token_id: u32,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (_, l) = input.dims2()?;
|
||||
let hidden = self
|
||||
.base
|
||||
.forward_with_vision(input, offset, image_embeds, image_token_id)?;
|
||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.base.clear_kv_cache();
|
||||
}
|
||||
@@ -463,4 +638,50 @@ mod tests {
|
||||
assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0);
|
||||
assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6);
|
||||
}
|
||||
|
||||
/// `splice_runs` replaces (1, L, H) embedding rows at the given
|
||||
/// positions with rows from a (N_img, H) image-embedding tensor,
|
||||
/// in the order positions are supplied.
|
||||
#[test]
|
||||
fn splice_runs_replaces_at_contiguous_positions() {
|
||||
use candle_core::{DType, Device};
|
||||
|
||||
let dev = Device::Cpu;
|
||||
// (1, L=5, H=2) text embeddings — encoded as floats so the
|
||||
// assertion can spot the change without dtype conversion.
|
||||
let h_vals: Vec<f32> = vec![
|
||||
10., 11., // pos 0
|
||||
20., 21., // pos 1
|
||||
30., 31., // pos 2
|
||||
40., 41., // pos 3
|
||||
50., 51., // pos 4
|
||||
];
|
||||
let h = Tensor::from_vec(h_vals, (1, 5, 2), &dev).unwrap();
|
||||
|
||||
// Two image embeddings to splice at positions 1 and 2 (a
|
||||
// contiguous run — single image emitting two patch tokens).
|
||||
let img_vals: Vec<f32> = vec![-1., -2., -3., -4.];
|
||||
let img = Tensor::from_vec(img_vals, (2, 2), &dev).unwrap();
|
||||
|
||||
let out = splice_runs(&h, &img, &[1, 2]).unwrap();
|
||||
let flat: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
|
||||
assert_eq!(flat, vec![10., 11., -1., -2., -3., -4., 40., 41., 50., 51.]);
|
||||
let _ = DType::F32;
|
||||
}
|
||||
|
||||
/// Non-contiguous positions: two images at positions [1] and [3]
|
||||
/// each contributing one patch. `splice_runs` should iterate
|
||||
/// runs and place the corresponding image rows.
|
||||
#[test]
|
||||
fn splice_runs_handles_non_contiguous_runs() {
|
||||
use candle_core::Device;
|
||||
let dev = Device::Cpu;
|
||||
let h_vals: Vec<f32> = vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.];
|
||||
let h = Tensor::from_vec(h_vals, (1, 5, 2), &dev).unwrap();
|
||||
let img_vals: Vec<f32> = vec![-1., -2., -3., -4.];
|
||||
let img = Tensor::from_vec(img_vals, (2, 2), &dev).unwrap();
|
||||
let out = splice_runs(&h, &img, &[1, 3]).unwrap();
|
||||
let flat: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
|
||||
assert_eq!(flat, vec![1., 1., -1., -2., 3., 3., -3., -4., 5., 5.]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,6 +105,22 @@ impl LoadedHandle {
|
||||
LoadedHandle::Tp(m) => m.poisoned.load(Ordering::Acquire),
|
||||
}
|
||||
}
|
||||
|
||||
/// Modalities the loaded model supports. Stage B7. TP models are
|
||||
/// always text-only today — TP-vision is tracked under issue #12.
|
||||
pub fn capabilities(&self) -> Vec<String> {
|
||||
let mut caps = vec!["text".to_string()];
|
||||
match self {
|
||||
LoadedHandle::Single(m) => {
|
||||
if m.has_vision {
|
||||
caps.push("vision".to_string());
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
LoadedHandle::Tp(_) => {}
|
||||
}
|
||||
caps
|
||||
}
|
||||
}
|
||||
|
||||
/// A loaded model with its tokenizer, device placement, and architecture-
|
||||
@@ -182,6 +198,25 @@ pub struct LoadedModel {
|
||||
/// is used. The `NEURON_USE_CHAT_TEMPLATE=false` env var
|
||||
/// forces the fallback path even when `Some`.
|
||||
pub chat_template: Option<String>,
|
||||
/// Vision capability flag derived at load time. `true` iff the
|
||||
/// loaded `ModelArch` exposes a vision tower (Stage A4 wires this
|
||||
/// from `Qwen3_5ForCausalLM::has_vision`). Used by the chat
|
||||
/// completion handler to reject image content on non-vision
|
||||
/// models with a structured 400 (Stage B6) and by `/v1/models`
|
||||
/// to advertise `capabilities: ["text", "vision"]` (Stage B7).
|
||||
pub has_vision: bool,
|
||||
/// `<|image_pad|>` token id from `config.json::image_token_id`.
|
||||
/// The Stage B prompt-builder uses this to compute expansion
|
||||
/// targets and the worker forward uses it to locate splice
|
||||
/// positions in the LM input embeddings.
|
||||
pub image_token_id: Option<u32>,
|
||||
/// LM-side tokens this model's vision tower emits per image at
|
||||
/// the Stage B fixed resolution (448×448 → 196 for Qwen3.6).
|
||||
/// `None` for text-only models. Set at load time so the
|
||||
/// hot path doesn't recompute it per request. Stage B fixed
|
||||
/// resolution → constant; dynamic resolution per #14 makes it
|
||||
/// per-image.
|
||||
pub lm_tokens_per_image: Option<usize>,
|
||||
}
|
||||
|
||||
impl LoadedModel {
|
||||
@@ -333,6 +368,35 @@ impl ModelArch {
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward step that splices vision-tower output at
|
||||
/// `<|image_pad|>` token positions. Stage B2.
|
||||
///
|
||||
/// Only `Qwen3_5Dense` supports this — other architectures error
|
||||
/// because they don't have a vision tower. The HTTP layer is
|
||||
/// expected to have rejected image content for non-vision models
|
||||
/// already (Stage B6); this is a defence-in-depth error path.
|
||||
///
|
||||
/// Returns rank-1 `[vocab_size]` logits, same shape contract as
|
||||
/// `forward`.
|
||||
pub fn forward_with_vision(
|
||||
&mut self,
|
||||
input: &Tensor,
|
||||
offset: usize,
|
||||
image_embeds: &Tensor,
|
||||
image_token_id: u32,
|
||||
) -> Result<Tensor> {
|
||||
let raw = match self {
|
||||
ModelArch::Qwen3_5Dense(m) => {
|
||||
m.forward_with_vision(input, offset, image_embeds, image_token_id)?
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"forward_with_vision: architecture {} has no vision tower",
|
||||
std::any::type_name_of_val(other)
|
||||
),
|
||||
};
|
||||
squeeze_to_vocab(&raw)
|
||||
}
|
||||
|
||||
/// Encode a preprocessed image into LM-side token embeddings via
|
||||
/// the loaded vision tower. Stage A5.
|
||||
///
|
||||
@@ -1548,9 +1612,54 @@ impl CandleHarness {
|
||||
.tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
|
||||
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
|
||||
|
||||
// Stage B: when the request carries images, preprocess
|
||||
// them, expand each `<|image_pad|>` sentinel to N copies
|
||||
// matching the per-image patch count, and route to the
|
||||
// vision-aware worker path. Non-image requests skip all
|
||||
// of this and follow the existing text-only flow.
|
||||
let vision_route = if request_has_images(&request) {
|
||||
// Stage B6: surface a structured `vision_unsupported`
|
||||
// rejection when the request asks for vision against a
|
||||
// text-only model. Cheap and stops the issue-#3 silent-
|
||||
// drop pattern.
|
||||
if !loaded.has_vision {
|
||||
return Err(InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
});
|
||||
}
|
||||
let image_token_id = loaded
|
||||
.image_token_id
|
||||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
})?;
|
||||
let patches_per_image = loaded
|
||||
.lm_tokens_per_image
|
||||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
})?;
|
||||
let profile = super::preprocess::PreprocessProfile::qwen3_6();
|
||||
let images = extract_images_from_request(&request, &profile).map_err(|e| {
|
||||
InferenceError::Other(anyhow::anyhow!("extract_images: {e}"))
|
||||
})?;
|
||||
if images.is_empty() {
|
||||
// request_has_images said true but extract returned
|
||||
// empty — defensive bail rather than silently dropping.
|
||||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||||
"request has image content but extractor produced zero images"
|
||||
)));
|
||||
}
|
||||
let per_image_counts: Vec<usize> = vec![patches_per_image; images.len()];
|
||||
prompt_tokens =
|
||||
expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
|
||||
.map_err(InferenceError::Other)?;
|
||||
Some((images, image_token_id))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let temperature = request.temperature.unwrap_or(0.7);
|
||||
let top_p = request.top_p;
|
||||
let max_new = request.max_tokens.unwrap_or(8192) as usize;
|
||||
@@ -1570,6 +1679,7 @@ impl CandleHarness {
|
||||
?eos_id,
|
||||
vram_free_mb,
|
||||
vram_total_mb,
|
||||
vision = vision_route.is_some(),
|
||||
"chat_completion: starting"
|
||||
);
|
||||
|
||||
@@ -1588,18 +1698,37 @@ impl CandleHarness {
|
||||
// Worker path (CUDA).
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
match run_inference_via_worker(
|
||||
worker,
|
||||
handle,
|
||||
&prompt_tokens,
|
||||
max_new,
|
||||
temperature,
|
||||
top_p,
|
||||
seed,
|
||||
eos_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let result = match &vision_route {
|
||||
Some((images, image_token_id)) => {
|
||||
run_inference_with_images_via_worker(
|
||||
worker,
|
||||
handle,
|
||||
&prompt_tokens,
|
||||
images.clone(),
|
||||
*image_token_id,
|
||||
max_new,
|
||||
temperature,
|
||||
top_p,
|
||||
seed,
|
||||
eos_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
None => {
|
||||
run_inference_via_worker(
|
||||
worker,
|
||||
handle,
|
||||
&prompt_tokens,
|
||||
max_new,
|
||||
temperature,
|
||||
top_p,
|
||||
seed,
|
||||
eos_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
};
|
||||
match result {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
let chain = format!("{e:#}");
|
||||
@@ -2120,6 +2249,7 @@ impl Harness for CandleHarness {
|
||||
},
|
||||
devices: h.devices(),
|
||||
vram_used_mb: None,
|
||||
capabilities: h.capabilities(),
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
@@ -2189,7 +2319,7 @@ impl Harness for CandleHarness {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let (tokenizer_path, arch_local, arch_handle) = if let Some(w) = &worker {
|
||||
let (tokenizer_path, arch_local, arch_handle, vision_meta) = if let Some(w) = &worker {
|
||||
// CUDA path: resolve, then load in the worker.
|
||||
if spec.quant.is_some() {
|
||||
let (gguf_path, tokenizer_path) = self.resolve_files(spec, &source_id).await?;
|
||||
@@ -2197,15 +2327,19 @@ impl Harness for CandleHarness {
|
||||
.load_gguf(gguf_path, spec.model_id.clone())
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("worker load_gguf: {e}"))?;
|
||||
(tokenizer_path, None, Some(handle))
|
||||
// GGUF Qwen3.6 releases don't ship the vision tower
|
||||
// (Qwen-VL weights are in the dense safetensors only),
|
||||
// so a GGUF load is text-only by construction.
|
||||
(tokenizer_path, None, Some(handle), VisionMeta::default())
|
||||
} else {
|
||||
let (config_path, tokenizer_path, safetensors_paths) =
|
||||
self.resolve_dense_files(spec, &source_id).await?;
|
||||
let meta = VisionMeta::from_config_path(&config_path);
|
||||
let handle = w
|
||||
.load_dense(config_path, safetensors_paths, spec.model_id.clone())
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("worker load_dense: {e}"))?;
|
||||
(tokenizer_path, None, Some(handle))
|
||||
(tokenizer_path, None, Some(handle), meta)
|
||||
}
|
||||
} else {
|
||||
// CPU path: legacy spawn_blocking + Arc<Mutex<ModelArch>>.
|
||||
@@ -2214,7 +2348,16 @@ impl Harness for CandleHarness {
|
||||
} else {
|
||||
self.load_arch_dense(spec, &source_id, &device).await?
|
||||
};
|
||||
(tokenizer_path, Some(Arc::new(Mutex::new(arch))), None)
|
||||
// CPU Qwen3.6 isn't a supported deployment target — the
|
||||
// 27B doesn't fit any reasonable CPU memory budget — so
|
||||
// we don't attempt to reach into the arch for vision
|
||||
// metadata. Stays text-only.
|
||||
(
|
||||
tokenizer_path,
|
||||
Some(Arc::new(Mutex::new(arch))),
|
||||
None,
|
||||
VisionMeta::default(),
|
||||
)
|
||||
};
|
||||
|
||||
let tokenizer = Tokenizer::from_file(&tokenizer_path)
|
||||
@@ -2278,6 +2421,9 @@ impl Harness for CandleHarness {
|
||||
reasoning_tokens,
|
||||
tool_call_tokens,
|
||||
chat_template,
|
||||
has_vision: vision_meta.has_vision,
|
||||
image_token_id: vision_meta.image_token_id,
|
||||
lm_tokens_per_image: vision_meta.lm_tokens_per_image,
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
@@ -3434,6 +3580,16 @@ pub enum InferenceError {
|
||||
"insufficient free VRAM for prefill: {free_mb} MiB free, need at least {required_mb} MiB"
|
||||
)]
|
||||
InsufficientVram { free_mb: u64, required_mb: u64 },
|
||||
/// Request carried `image_url` content but the loaded model has
|
||||
/// no vision tower. Stage B6 — replaces the silent-drop pattern
|
||||
/// from issue #3 with an explicit 400 + `vision_unsupported`
|
||||
/// error body that clients (litellm, agent0, …) can act on.
|
||||
#[error(
|
||||
"model '{model_id}' does not support image input; \
|
||||
load a vision-capable model (e.g. Qwen/Qwen3.6-27B) or \
|
||||
remove the image_url content parts from the request"
|
||||
)]
|
||||
VisionUnsupported { model_id: String },
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
@@ -3498,6 +3654,169 @@ fn build_prompt_for_request(
|
||||
}
|
||||
}
|
||||
|
||||
/// Vision metadata derived at model-load time. Stashed on
|
||||
/// `LoadedModel` so the chat-completion hot path doesn't have to
|
||||
/// re-parse `config.json` or reach across the worker thread to peek
|
||||
/// at the loaded `ModelArch`.
|
||||
#[derive(Debug, Default, Clone, Copy)]
|
||||
struct VisionMeta {
|
||||
has_vision: bool,
|
||||
image_token_id: Option<u32>,
|
||||
/// LM-side tokens this model's vision tower emits per image at
|
||||
/// the Stage B fixed `PreprocessProfile::qwen3_6()` resolution
|
||||
/// (448×448). Equal to `(H/patch_size/spatial_merge_size)²`.
|
||||
lm_tokens_per_image: Option<usize>,
|
||||
}
|
||||
|
||||
impl VisionMeta {
|
||||
/// Peek at `config.json` for vision-related fields. Returns the
|
||||
/// default (no-vision) struct on any read/parse error — vision is
|
||||
/// best-effort metadata; load can still succeed for text usage.
|
||||
fn from_config_path(config_path: &std::path::Path) -> Self {
|
||||
let Ok(text) = std::fs::read_to_string(config_path) else {
|
||||
return Self::default();
|
||||
};
|
||||
let Ok(v) = serde_json::from_str::<serde_json::Value>(&text) else {
|
||||
return Self::default();
|
||||
};
|
||||
let Some(vision_config) = v.get("vision_config") else {
|
||||
return Self::default();
|
||||
};
|
||||
let patch_size = vision_config
|
||||
.get("patch_size")
|
||||
.and_then(|x| x.as_u64())
|
||||
.unwrap_or(16) as usize;
|
||||
let spatial_merge_size = vision_config
|
||||
.get("spatial_merge_size")
|
||||
.and_then(|x| x.as_u64())
|
||||
.unwrap_or(2) as usize;
|
||||
let image_token_id = v
|
||||
.get("image_token_id")
|
||||
.and_then(|x| x.as_u64())
|
||||
.map(|n| n as u32);
|
||||
// Compute LM tokens per image at the Stage B fixed resolution
|
||||
// (PreprocessProfile::qwen3_6() → 448×448). One LM token per
|
||||
// spatial-merge group of patches.
|
||||
let target_h = super::preprocess::PreprocessProfile::qwen3_6().target_height as usize;
|
||||
let target_w = super::preprocess::PreprocessProfile::qwen3_6().target_width as usize;
|
||||
let lm_tokens_per_image = if patch_size > 0 && spatial_merge_size > 0 {
|
||||
let gh = target_h / patch_size / spatial_merge_size;
|
||||
let gw = target_w / patch_size / spatial_merge_size;
|
||||
Some(gh * gw)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self {
|
||||
has_vision: true,
|
||||
image_token_id,
|
||||
lm_tokens_per_image,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// True iff any message in the request carries an `image_url`
|
||||
/// content part. The Stage B routing decision in `chat_completion`
|
||||
/// dispatches to the vision-aware worker job when this is true.
|
||||
fn request_has_images(request: &ChatCompletionRequest) -> bool {
|
||||
request.messages.iter().any(|m| {
|
||||
matches!(&m.content, MessageContent::Parts(parts)
|
||||
if parts.iter().any(|p|
|
||||
p.get("type").and_then(|v| v.as_str()) == Some("image_url")))
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract `image_url` content parts from a chat request and turn
|
||||
/// each one into a preprocessed `ImageInput` ready for the device
|
||||
/// worker. Stage B4.
|
||||
///
|
||||
/// Walks `request.messages`, looking for `MessageContent::Parts` and
|
||||
/// pulling out entries whose `type == "image_url"`. Each is run
|
||||
/// through `harness::preprocess::decode_data_uri` + `preprocess` with
|
||||
/// the supplied `profile` (Stage B always uses
|
||||
/// `PreprocessProfile::qwen3_6()` — fixed 448×448 — so every image
|
||||
/// produces the same patch count; dynamic resolution per issue #14
|
||||
/// would parameterise this).
|
||||
///
|
||||
/// Returns images in the order they appear in the request, which
|
||||
/// matches the order the chat template emits `<|image_pad|>` tokens.
|
||||
fn extract_images_from_request(
|
||||
request: &ChatCompletionRequest,
|
||||
profile: &super::preprocess::PreprocessProfile,
|
||||
) -> anyhow::Result<Vec<super::device_worker::jobs::ImageInput>> {
|
||||
let mut out = Vec::new();
|
||||
for msg in &request.messages {
|
||||
if let MessageContent::Parts(parts) = &msg.content {
|
||||
for part in parts {
|
||||
let kind = part.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if kind != "image_url" {
|
||||
continue;
|
||||
}
|
||||
let url = part
|
||||
.get("image_url")
|
||||
.and_then(|v| v.get("url"))
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("image_url part missing url field"))?;
|
||||
let pixels = super::preprocess::preprocess_data_uri(url, profile)
|
||||
.with_context(|| format!("preprocess image #{}", out.len()))?;
|
||||
out.push(super::device_worker::jobs::ImageInput {
|
||||
pixels,
|
||||
c: 3,
|
||||
h: profile.target_height as usize,
|
||||
w: profile.target_width as usize,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Expand each occurrence of `image_token_id` in `input_ids` into
|
||||
/// `patches_per_image[i]` copies (one expansion per image, in order).
|
||||
/// Stage B4 helper.
|
||||
///
|
||||
/// The chat template emits a single `<|image_pad|>` per image; this
|
||||
/// is what fits Qwen3-VL's template-then-runtime-expansion convention.
|
||||
/// The runtime (us) is responsible for replacing each one with N
|
||||
/// copies based on the corresponding image's patch count.
|
||||
///
|
||||
/// For Stage B fixed resolution every entry of `patches_per_image`
|
||||
/// is the same constant (196 at 448×448). For dynamic resolution
|
||||
/// (issue #14) each image gets its own value.
|
||||
///
|
||||
/// Errors if the number of `image_token_id` occurrences in `input_ids`
|
||||
/// doesn't equal `patches_per_image.len()` — a mismatch means the
|
||||
/// template emitted the wrong number of pad tokens (operator-visible
|
||||
/// template bug, not a runtime error).
|
||||
fn expand_image_pad_tokens(
|
||||
input_ids: &[u32],
|
||||
image_token_id: u32,
|
||||
patches_per_image: &[usize],
|
||||
) -> anyhow::Result<Vec<u32>> {
|
||||
let occurrences = input_ids.iter().filter(|&&t| t == image_token_id).count();
|
||||
if occurrences != patches_per_image.len() {
|
||||
anyhow::bail!(
|
||||
"expand_image_pad_tokens: prompt has {occurrences} image_token_id occurrences but \
|
||||
{} images were preprocessed — chat template emitted the wrong number of pad tokens",
|
||||
patches_per_image.len()
|
||||
);
|
||||
}
|
||||
let total_extra: usize = patches_per_image.iter().map(|n| n.saturating_sub(1)).sum();
|
||||
let mut out = Vec::with_capacity(input_ids.len() + total_extra);
|
||||
let mut img_idx = 0;
|
||||
for &t in input_ids {
|
||||
if t == image_token_id {
|
||||
let n = patches_per_image[img_idx];
|
||||
for _ in 0..n {
|
||||
out.push(image_token_id);
|
||||
}
|
||||
img_idx += 1;
|
||||
} else {
|
||||
out.push(t);
|
||||
}
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Apply the Qwen3 chat template:
|
||||
///
|
||||
/// ```text
|
||||
@@ -3544,6 +3863,103 @@ fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
|
||||
/// would only add channel overhead with no diagnostic benefit.
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Vision-aware analogue of `run_inference_via_worker`. Stage B5.
|
||||
///
|
||||
/// Single-shot prefill carrying the pre-expanded prompt + the image
|
||||
/// payloads. The worker encodes each image through the vision tower,
|
||||
/// splices the resulting embeddings at `image_token_id` positions,
|
||||
/// and returns the last-position logits. Decode steps thereafter
|
||||
/// follow the existing text-only `forward_logits` path — the KV
|
||||
/// cache holds the image-conditioned hidden states from prefill, so
|
||||
/// no further splicing is needed.
|
||||
///
|
||||
/// Stage B skips chunked prefill for vision (the fixed-resolution
|
||||
/// budget — 196 image tokens at 448×448 + typical text — stays well
|
||||
/// under the activation-memory threshold). Long-prompt-with-images
|
||||
/// chunking is a Stage D follow-up.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn run_inference_with_images_via_worker(
|
||||
worker: &super::device_worker::DeviceWorkerHandle,
|
||||
handle: super::device_worker::ArchHandle,
|
||||
prompt_tokens: &[u32],
|
||||
images: Vec<super::device_worker::jobs::ImageInput>,
|
||||
image_token_id: u32,
|
||||
max_new: usize,
|
||||
temperature: f64,
|
||||
top_p: Option<f64>,
|
||||
seed: u64,
|
||||
eos_id: Option<u32>,
|
||||
) -> Result<(Vec<u32>, String)> {
|
||||
let mut logits_processor = {
|
||||
let sampling = if temperature <= 0.0 {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match top_p {
|
||||
Some(p) => Sampling::TopP { p, temperature },
|
||||
None => Sampling::All { temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(seed, sampling)
|
||||
};
|
||||
|
||||
let mut generated: Vec<u32> = Vec::new();
|
||||
let prompt_len = prompt_tokens.len();
|
||||
|
||||
worker
|
||||
.clear_kv_cache(handle)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
|
||||
|
||||
// Single-shot prefill with image splicing.
|
||||
let logits_vec = worker
|
||||
.forward_logits_with_images(handle, prompt_tokens.to_vec(), 0, images, image_token_id)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("forward_logits_with_images: {e}"))?;
|
||||
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
|
||||
let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let health = logits_health_slice(&logits_vec);
|
||||
tracing::warn!(
|
||||
?health,
|
||||
"chat_completion (worker, vision): prefill sample failed; logits unhealthy"
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
if Some(next_token) == eos_id {
|
||||
return Ok((generated, "stop".into()));
|
||||
}
|
||||
generated.push(next_token);
|
||||
|
||||
for index in 0..max_new.saturating_sub(1) {
|
||||
let logits_vec = worker
|
||||
.forward_logits(handle, vec![next_token], prompt_len + index)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
|
||||
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
|
||||
next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let health = logits_health_slice(&logits_vec);
|
||||
tracing::warn!(
|
||||
step = index,
|
||||
?health,
|
||||
"chat_completion (worker, vision): decode sample failed; logits unhealthy"
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
if Some(next_token) == eos_id {
|
||||
return Ok((generated, "stop".into()));
|
||||
}
|
||||
generated.push(next_token);
|
||||
}
|
||||
Ok((generated, "length".into()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
async fn run_inference_via_worker(
|
||||
worker: &super::device_worker::DeviceWorkerHandle,
|
||||
handle: super::device_worker::ArchHandle,
|
||||
@@ -4243,4 +4659,44 @@ mod tests {
|
||||
.expect("synth huggingface source should build");
|
||||
assert_eq!(harness.default_source_scheme(), "huggingface");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_image_pad_replaces_single_token_with_n_copies() {
|
||||
// Mimics the chat template's output: each image emits
|
||||
// [vision_start, image_pad, vision_end]. After expansion
|
||||
// with 3 patches/image we want
|
||||
// [vision_start, pad×3, vision_end].
|
||||
let pad = 248056_u32;
|
||||
let vstart = 248053_u32;
|
||||
let vend = 248054_u32;
|
||||
let input = vec![1, vstart, pad, vend, 2];
|
||||
let out = expand_image_pad_tokens(&input, pad, &[3]).unwrap();
|
||||
assert_eq!(out, vec![1, vstart, pad, pad, pad, vend, 2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_image_pad_handles_multiple_images() {
|
||||
let pad = 248056_u32;
|
||||
// Two images in one prompt; first gets 2 patches, second 3.
|
||||
let input = vec![pad, 99, pad];
|
||||
let out = expand_image_pad_tokens(&input, pad, &[2, 3]).unwrap();
|
||||
assert_eq!(out, vec![pad, pad, 99, pad, pad, pad]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_image_pad_errors_on_count_mismatch() {
|
||||
let pad = 248056_u32;
|
||||
// Prompt has 2 pad tokens but caller supplied 3 images.
|
||||
let input = vec![pad, 99, pad];
|
||||
let err = expand_image_pad_tokens(&input, pad, &[2, 3, 4]).unwrap_err();
|
||||
assert!(format!("{err:#}").contains("emitted the wrong number"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_image_pad_passes_through_when_no_images() {
|
||||
let pad = 248056_u32;
|
||||
let input = vec![1, 2, 3];
|
||||
let out = expand_image_pad_tokens(&input, pad, &[]).unwrap();
|
||||
assert_eq!(out, input);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,10 +16,11 @@
|
||||
use crate::harness::candle::ModelArch;
|
||||
#[cfg(feature = "cuda")]
|
||||
use crate::harness::device_worker::jobs::TpHandle;
|
||||
use crate::harness::device_worker::jobs::{ArchHandle, Job};
|
||||
use crate::harness::device_worker::jobs::{ArchHandle, ImageInput, Job};
|
||||
#[cfg(feature = "cuda")]
|
||||
use crate::harness::tp::TpLeaderModel;
|
||||
use crate::harness::tp::nccl_state::NcclState;
|
||||
use anyhow::Context as _;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
@@ -169,6 +170,24 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
||||
let result = encode_image(&mut state, handle, pixels, c, h, w);
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
Job::ForwardLogitsWithImages {
|
||||
handle,
|
||||
tokens,
|
||||
offset,
|
||||
images,
|
||||
image_token_id,
|
||||
reply,
|
||||
} => {
|
||||
let result = forward_logits_with_images(
|
||||
&mut state,
|
||||
handle,
|
||||
&tokens,
|
||||
offset,
|
||||
images,
|
||||
image_token_id,
|
||||
);
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
Job::NcclInit {
|
||||
cfg,
|
||||
comm_id_hex,
|
||||
@@ -751,6 +770,67 @@ fn forward_logits(
|
||||
Ok(values)
|
||||
}
|
||||
|
||||
/// Run the LM forward with vision-tower image splicing. Stage B3.
|
||||
///
|
||||
/// Encodes each image through the vision tower (`VisionTower::forward`,
|
||||
/// dispatched via `ModelArch::encode_image`), concatenates the
|
||||
/// resulting embeddings into a single `(N_total, hidden)` tensor, and
|
||||
/// passes it to `ModelArch::forward_with_vision` along with the
|
||||
/// prompt-expanded `tokens`. Image embeddings never leave the device.
|
||||
///
|
||||
/// Returns CPU `[vocab]` logits — same shape contract as
|
||||
/// `ForwardLogits` so the async sampler doesn't have to branch on the
|
||||
/// presence of images.
|
||||
fn forward_logits_with_images(
|
||||
state: &mut DeviceWorkerState,
|
||||
handle: ArchHandle,
|
||||
tokens: &[u32],
|
||||
offset: usize,
|
||||
images: Vec<ImageInput>,
|
||||
image_token_id: u32,
|
||||
) -> anyhow::Result<Vec<f32>> {
|
||||
use candle_core::{DType, Tensor};
|
||||
|
||||
if images.is_empty() {
|
||||
anyhow::bail!("ForwardLogitsWithImages dispatched with zero images");
|
||||
}
|
||||
|
||||
let arch = state.models.get_mut(&handle).ok_or_else(|| {
|
||||
anyhow::anyhow!("ForwardLogitsWithImages: no model for handle {}", handle.0)
|
||||
})?;
|
||||
|
||||
// Encode every image on the worker's device, collecting per-image
|
||||
// post-merger embeddings as device-resident tensors.
|
||||
let mut per_image: Vec<Tensor> = Vec::with_capacity(images.len());
|
||||
for (idx, img) in images.into_iter().enumerate() {
|
||||
anyhow::ensure!(
|
||||
img.pixels.len() == img.c * img.h * img.w,
|
||||
"ForwardLogitsWithImages: image[{idx}] pixels length {} does not match shape ({}, {}, {})",
|
||||
img.pixels.len(),
|
||||
img.c,
|
||||
img.h,
|
||||
img.w,
|
||||
);
|
||||
let image = Tensor::from_vec(img.pixels, (img.c, img.h, img.w), &state.device)?;
|
||||
let embed = arch
|
||||
.encode_image(&image)
|
||||
.with_context(|| format!("encode image[{idx}]"))?;
|
||||
per_image.push(embed);
|
||||
}
|
||||
// Concatenate per-image embeddings along the patch axis →
|
||||
// (sum_of_patches, hidden). `Tensor::cat` keeps the result
|
||||
// device-resident.
|
||||
let image_embeds = Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)?;
|
||||
|
||||
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
|
||||
let logits = arch.forward_with_vision(&input, offset, &image_embeds, image_token_id)?;
|
||||
let values = logits
|
||||
.to_dtype(DType::F32)?
|
||||
.flatten_all()?
|
||||
.to_vec1::<f32>()?;
|
||||
Ok(values)
|
||||
}
|
||||
|
||||
/// Run the vision tower on a single preprocessed image. Stage A5.
|
||||
///
|
||||
/// `pixels` is a row-major `(c, h, w)` f32 image that the async-side
|
||||
@@ -830,6 +910,9 @@ fn drain_poisoned(job: Job, device_index: u32) {
|
||||
Job::EncodeImage { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::ForwardLogitsWithImages { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::NcclInit { reply, .. } => {
|
||||
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
||||
kind: "device_worker_poisoned".into(),
|
||||
|
||||
@@ -28,6 +28,17 @@ pub struct ArchHandle(pub u64);
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct TpHandle(pub u64);
|
||||
|
||||
/// One image payload for `Job::ForwardLogitsWithImages` /
|
||||
/// `Job::EncodeImage`. Pixels are row-major `(c, h, w)` f32 — the
|
||||
/// shape `harness::preprocess::preprocess` produces. Carries the
|
||||
/// shape inline since `Vec<f32>` is rank-1.
|
||||
pub struct ImageInput {
|
||||
pub pixels: Vec<f32>,
|
||||
pub c: usize,
|
||||
pub h: usize,
|
||||
pub w: usize,
|
||||
}
|
||||
|
||||
/// One unit of work for the device worker.
|
||||
///
|
||||
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
|
||||
@@ -94,6 +105,33 @@ pub enum Job {
|
||||
offset: usize,
|
||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||
},
|
||||
/// Run the LM forward with vision splicing in one round-trip.
|
||||
/// Stage B3 of the vision plan.
|
||||
///
|
||||
/// Inputs:
|
||||
/// - `tokens`: prompt-expanded token ids (the caller has already
|
||||
/// replaced each `<|image_pad|>` with N copies per the
|
||||
/// per-image patch count, so `tokens` already contains exactly
|
||||
/// `sum(n_i)` `image_token_id` entries across all images).
|
||||
/// - `offset`: KV-cache position (same contract as `ForwardLogits`).
|
||||
/// - `images`: one entry per image — preprocessed pixels plus the
|
||||
/// `(c, h, w)` shape. Images are encoded on the worker via the
|
||||
/// model's vision tower (`VisionTower::forward`), concatenated
|
||||
/// in order, and spliced into the LM input embeddings at
|
||||
/// `image_token_id` positions.
|
||||
/// - `image_token_id`: the sentinel token (248056 for Qwen3.6).
|
||||
///
|
||||
/// Returns flat CPU `[vocab]` logits, same as `ForwardLogits`.
|
||||
/// Image embeddings stay device-resident — they're never copied
|
||||
/// to CPU. The "tensors don't escape the worker" invariant holds.
|
||||
ForwardLogitsWithImages {
|
||||
handle: ArchHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
images: Vec<ImageInput>,
|
||||
image_token_id: u32,
|
||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||
},
|
||||
/// Encode one image through the model's vision tower. Stage A5 of
|
||||
/// the vision plan (`doc/vision-qwen3_6-spec.md`).
|
||||
///
|
||||
|
||||
@@ -313,6 +313,47 @@ impl DeviceWorkerHandle {
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward with image-aware splicing in one round-trip. Stage B3.
|
||||
///
|
||||
/// Encodes each image on the worker thread (device-resident), then
|
||||
/// runs the LM forward with the embeddings spliced at
|
||||
/// `image_token_id` positions. Returns CPU `[vocab]` logits, same
|
||||
/// shape as `forward_logits`. Image embeddings never copy back to
|
||||
/// CPU.
|
||||
pub async fn forward_logits_with_images(
|
||||
&self,
|
||||
handle: ArchHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
images: Vec<crate::harness::device_worker::jobs::ImageInput>,
|
||||
image_token_id: u32,
|
||||
) -> Result<Vec<f32>, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::ForwardLogitsWithImages {
|
||||
handle,
|
||||
tokens,
|
||||
offset,
|
||||
images,
|
||||
image_token_id,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode a preprocessed image through the model's vision tower
|
||||
/// and return the resulting LM-side image embeddings as a
|
||||
/// flattened CPU `Vec<f32>`. Stage A5.
|
||||
|
||||
Reference in New Issue
Block a user