feat(neuron): Stage B — end-to-end text+image chat for Qwen3.6
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 33s
CI / CUDA type-check (push) Failing after 46s
CI / Clippy (push) Successful in 2m37s
build-prerelease / Build cortex binary (push) Successful in 4m32s
build-prerelease / Build neuron-blackwell (push) Failing after 5m35s
CI / Test (push) Successful in 6m40s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-ampere (push) Failing after 7m46s
build-prerelease / Package cortex RPM (push) Successful in 1m22s
build-prerelease / Build neuron-ada (push) Failing after 4m51s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped

Stage B of the vision plan (doc/vision-qwen3_6-spec.md). Wires
the vision tower from Stage A through to a complete non-streaming
chat completion: extract images from the request, preprocess,
encode on the worker thread, splice embeddings into the LM input
at `<|image_pad|>` positions, return coherent text response with
`prompt_tokens` reflecting patch tokens.

Closes the silent-drop class of failures from issue #3 — vision
requests against Qwen3.6 now condition the model on the image
instead of producing confident text-only hallucinations.

Streaming for vision is Stage C. Deferred items tracked under
#12 (TP-vision), #13 (27B production), #14 (dynamic resolution),
#15 (numerical validation).

What landed:

- **B1 — `Qwen3_5Model::forward_with_vision`**: text-only `forward`
  unchanged; new method takes `(input_ids, offset, image_embeds,
  image_token_id)`, embeds tokens, locates `image_token_id`
  positions, splices via the new `splice_runs` helper. MRoPE
  applies text-positions to image tokens for Stage B (spatial
  MRoPE is the issue #15 numerical-validation follow-up). 2 unit
  tests for `splice_runs` covering contiguous + non-contiguous
  runs.

- **B2 — `ModelArch::forward_with_vision` dispatch**: routes
  Qwen3_5Dense to the new method; other arches return an error.
  Defence-in-depth — the HTTP layer (B6) already rejects image
  content for non-vision models.

- **B3 — `Job::ForwardLogitsWithImages`**: new worker variant
  carrying tokens + per-image `(pixels, c, h, w)` payloads. The
  dispatcher encodes each image (device-resident), concatenates
  the resulting embeddings, calls `arch.forward_with_vision`, and
  returns CPU logits. Image embeddings never copy back to CPU —
  the "tensors don't escape the worker" invariant from the
  per-device worker refactor still holds. Poisoned-worker drain
  path handles the new variant.

- **B4 — Prompt builder**:
  - `request_has_images` detects image content cheaply.
  - `extract_images_from_request(request, profile)` walks
    `MessageContent::Parts`, decodes data URIs, runs
    `harness::preprocess::preprocess` per image, returns
    `Vec<ImageInput>` in request order.
  - `expand_image_pad_tokens(input_ids, image_token_id,
    patches_per_image)` walks the tokenized prompt and replaces
    each `<|image_pad|>` (id 248056 for Qwen3.6) with N copies
    matching the per-image patch count. 4 unit tests.
  - `VisionMeta::from_config_path` peeks `config.json` at load
    time for `image_token_id`, vision_config patch/merge sizes,
    and derives `lm_tokens_per_image` for the Stage B fixed
    resolution.

- **B5 — `chat_completion` vision routing**: detects image
  content, validates the loaded model has vision, expands the
  prompt, and calls a new `run_inference_with_images_via_worker`
  helper that does single-shot prefill + standard decode loop
  (KV cache holds the post-splice hidden states from prefill, so
  decode steps don't re-splice). Stage B skips chunked prefill
  for vision — at 448×448 fixed resolution the budget stays well
  under the activation-memory threshold. Long-vision chunking is
  Stage D follow-up.

- **B6 — `InferenceError::VisionUnsupported`**: structured 400
  with `code=vision_unsupported, model_id, suggestion` when an
  image request hits a non-vision model. Closes the agent0
  failure mode where vision requests degraded silently.

- **B7 — `ModelInfo.capabilities`**: per-model array (`["text"]`
  vs `["text", "vision"]`) in `/v1/models` and forwarded verbatim
  by cortex-gateway. Lets clients (litellm, agent0) gate
  image_url submission on the declared capability set. Optional
  in the wire format; defaults to empty for older clients.

CI gate: cargo fmt --check, cargo clippy --workspace --all-targets
-- -D warnings, cargo test --workspace (all 28 test groups ok,
124 lib tests). New unit-test counts: +2 splice_runs, +4
expand_image_pad.

Manual verification (after RPMs deploy on beast):

  curl http://hanzalova.internal:31313/v1/chat/completions \
    -H 'Content-Type: application/json' \
    -d "{\"model\":\"Qwen/Qwen3.6-27B\", \"messages\":[{\"role\":\"user\",\"content\":[
      {\"type\":\"text\",\"text\":\"What's in this image?\"},
      {\"type\":\"image_url\",\"image_url\":{\"url\":\"data:image/jpeg;base64,...\"}}
    ]}], \"max_tokens\":120}" | jq

  Expect prompt_tokens > 196 (text + 196 patch tokens) and a
  response that references actual image content.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-06-02 15:33:00 +03:00
parent 7df84fed8f
commit 24968e9233
7 changed files with 904 additions and 19 deletions

View File

@@ -105,6 +105,22 @@ impl LoadedHandle {
LoadedHandle::Tp(m) => m.poisoned.load(Ordering::Acquire),
}
}
/// Modalities the loaded model supports. Stage B7. TP models are
/// always text-only today — TP-vision is tracked under issue #12.
pub fn capabilities(&self) -> Vec<String> {
let mut caps = vec!["text".to_string()];
match self {
LoadedHandle::Single(m) => {
if m.has_vision {
caps.push("vision".to_string());
}
}
#[cfg(feature = "cuda")]
LoadedHandle::Tp(_) => {}
}
caps
}
}
/// A loaded model with its tokenizer, device placement, and architecture-
@@ -182,6 +198,25 @@ pub struct LoadedModel {
/// is used. The `NEURON_USE_CHAT_TEMPLATE=false` env var
/// forces the fallback path even when `Some`.
pub chat_template: Option<String>,
/// Vision capability flag derived at load time. `true` iff the
/// loaded `ModelArch` exposes a vision tower (Stage A4 wires this
/// from `Qwen3_5ForCausalLM::has_vision`). Used by the chat
/// completion handler to reject image content on non-vision
/// models with a structured 400 (Stage B6) and by `/v1/models`
/// to advertise `capabilities: ["text", "vision"]` (Stage B7).
pub has_vision: bool,
/// `<|image_pad|>` token id from `config.json::image_token_id`.
/// The Stage B prompt-builder uses this to compute expansion
/// targets and the worker forward uses it to locate splice
/// positions in the LM input embeddings.
pub image_token_id: Option<u32>,
/// LM-side tokens this model's vision tower emits per image at
/// the Stage B fixed resolution (448×448 → 196 for Qwen3.6).
/// `None` for text-only models. Set at load time so the
/// hot path doesn't recompute it per request. Stage B fixed
/// resolution → constant; dynamic resolution per #14 makes it
/// per-image.
pub lm_tokens_per_image: Option<usize>,
}
impl LoadedModel {
@@ -333,6 +368,35 @@ impl ModelArch {
}
}
/// Forward step that splices vision-tower output at
/// `<|image_pad|>` token positions. Stage B2.
///
/// Only `Qwen3_5Dense` supports this — other architectures error
/// because they don't have a vision tower. The HTTP layer is
/// expected to have rejected image content for non-vision models
/// already (Stage B6); this is a defence-in-depth error path.
///
/// Returns rank-1 `[vocab_size]` logits, same shape contract as
/// `forward`.
pub fn forward_with_vision(
&mut self,
input: &Tensor,
offset: usize,
image_embeds: &Tensor,
image_token_id: u32,
) -> Result<Tensor> {
let raw = match self {
ModelArch::Qwen3_5Dense(m) => {
m.forward_with_vision(input, offset, image_embeds, image_token_id)?
}
other => anyhow::bail!(
"forward_with_vision: architecture {} has no vision tower",
std::any::type_name_of_val(other)
),
};
squeeze_to_vocab(&raw)
}
/// Encode a preprocessed image into LM-side token embeddings via
/// the loaded vision tower. Stage A5.
///
@@ -1548,9 +1612,54 @@ impl CandleHarness {
.tokenizer
.encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
let prompt_len = prompt_tokens.len();
let mut prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
// Stage B: when the request carries images, preprocess
// them, expand each `<|image_pad|>` sentinel to N copies
// matching the per-image patch count, and route to the
// vision-aware worker path. Non-image requests skip all
// of this and follow the existing text-only flow.
let vision_route = if request_has_images(&request) {
// Stage B6: surface a structured `vision_unsupported`
// rejection when the request asks for vision against a
// text-only model. Cheap and stops the issue-#3 silent-
// drop pattern.
if !loaded.has_vision {
return Err(InferenceError::VisionUnsupported {
model_id: request.model.clone(),
});
}
let image_token_id = loaded
.image_token_id
.ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(),
})?;
let patches_per_image = loaded
.lm_tokens_per_image
.ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(),
})?;
let profile = super::preprocess::PreprocessProfile::qwen3_6();
let images = extract_images_from_request(&request, &profile).map_err(|e| {
InferenceError::Other(anyhow::anyhow!("extract_images: {e}"))
})?;
if images.is_empty() {
// request_has_images said true but extract returned
// empty — defensive bail rather than silently dropping.
return Err(InferenceError::Other(anyhow::anyhow!(
"request has image content but extractor produced zero images"
)));
}
let per_image_counts: Vec<usize> = vec![patches_per_image; images.len()];
prompt_tokens =
expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
.map_err(InferenceError::Other)?;
Some((images, image_token_id))
} else {
None
};
let prompt_len = prompt_tokens.len();
let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(8192) as usize;
@@ -1570,6 +1679,7 @@ impl CandleHarness {
?eos_id,
vram_free_mb,
vram_total_mb,
vision = vision_route.is_some(),
"chat_completion: starting"
);
@@ -1588,18 +1698,37 @@ impl CandleHarness {
// Worker path (CUDA).
#[cfg(feature = "cuda")]
{
match run_inference_via_worker(
worker,
handle,
&prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
)
.await
{
let result = match &vision_route {
Some((images, image_token_id)) => {
run_inference_with_images_via_worker(
worker,
handle,
&prompt_tokens,
images.clone(),
*image_token_id,
max_new,
temperature,
top_p,
seed,
eos_id,
)
.await
}
None => {
run_inference_via_worker(
worker,
handle,
&prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
)
.await
}
};
match result {
Ok(v) => v,
Err(e) => {
let chain = format!("{e:#}");
@@ -2120,6 +2249,7 @@ impl Harness for CandleHarness {
},
devices: h.devices(),
vram_used_mb: None,
capabilities: h.capabilities(),
})
.collect())
}
@@ -2189,7 +2319,7 @@ impl Harness for CandleHarness {
_ => None,
};
let (tokenizer_path, arch_local, arch_handle) = if let Some(w) = &worker {
let (tokenizer_path, arch_local, arch_handle, vision_meta) = if let Some(w) = &worker {
// CUDA path: resolve, then load in the worker.
if spec.quant.is_some() {
let (gguf_path, tokenizer_path) = self.resolve_files(spec, &source_id).await?;
@@ -2197,15 +2327,19 @@ impl Harness for CandleHarness {
.load_gguf(gguf_path, spec.model_id.clone())
.await
.map_err(|e| anyhow::anyhow!("worker load_gguf: {e}"))?;
(tokenizer_path, None, Some(handle))
// GGUF Qwen3.6 releases don't ship the vision tower
// (Qwen-VL weights are in the dense safetensors only),
// so a GGUF load is text-only by construction.
(tokenizer_path, None, Some(handle), VisionMeta::default())
} else {
let (config_path, tokenizer_path, safetensors_paths) =
self.resolve_dense_files(spec, &source_id).await?;
let meta = VisionMeta::from_config_path(&config_path);
let handle = w
.load_dense(config_path, safetensors_paths, spec.model_id.clone())
.await
.map_err(|e| anyhow::anyhow!("worker load_dense: {e}"))?;
(tokenizer_path, None, Some(handle))
(tokenizer_path, None, Some(handle), meta)
}
} else {
// CPU path: legacy spawn_blocking + Arc<Mutex<ModelArch>>.
@@ -2214,7 +2348,16 @@ impl Harness for CandleHarness {
} else {
self.load_arch_dense(spec, &source_id, &device).await?
};
(tokenizer_path, Some(Arc::new(Mutex::new(arch))), None)
// CPU Qwen3.6 isn't a supported deployment target — the
// 27B doesn't fit any reasonable CPU memory budget — so
// we don't attempt to reach into the arch for vision
// metadata. Stays text-only.
(
tokenizer_path,
Some(Arc::new(Mutex::new(arch))),
None,
VisionMeta::default(),
)
};
let tokenizer = Tokenizer::from_file(&tokenizer_path)
@@ -2278,6 +2421,9 @@ impl Harness for CandleHarness {
reasoning_tokens,
tool_call_tokens,
chat_template,
has_vision: vision_meta.has_vision,
image_token_id: vision_meta.image_token_id,
lm_tokens_per_image: vision_meta.lm_tokens_per_image,
});
let mut models = self.models.write().await;
@@ -3434,6 +3580,16 @@ pub enum InferenceError {
"insufficient free VRAM for prefill: {free_mb} MiB free, need at least {required_mb} MiB"
)]
InsufficientVram { free_mb: u64, required_mb: u64 },
/// Request carried `image_url` content but the loaded model has
/// no vision tower. Stage B6 — replaces the silent-drop pattern
/// from issue #3 with an explicit 400 + `vision_unsupported`
/// error body that clients (litellm, agent0, …) can act on.
#[error(
"model '{model_id}' does not support image input; \
load a vision-capable model (e.g. Qwen/Qwen3.6-27B) or \
remove the image_url content parts from the request"
)]
VisionUnsupported { model_id: String },
#[error(transparent)]
Other(#[from] anyhow::Error),
}
@@ -3498,6 +3654,169 @@ fn build_prompt_for_request(
}
}
/// Vision metadata derived at model-load time. Stashed on
/// `LoadedModel` so the chat-completion hot path doesn't have to
/// re-parse `config.json` or reach across the worker thread to peek
/// at the loaded `ModelArch`.
#[derive(Debug, Default, Clone, Copy)]
struct VisionMeta {
has_vision: bool,
image_token_id: Option<u32>,
/// LM-side tokens this model's vision tower emits per image at
/// the Stage B fixed `PreprocessProfile::qwen3_6()` resolution
/// (448×448). Equal to `(H/patch_size/spatial_merge_size)²`.
lm_tokens_per_image: Option<usize>,
}
impl VisionMeta {
/// Peek at `config.json` for vision-related fields. Returns the
/// default (no-vision) struct on any read/parse error — vision is
/// best-effort metadata; load can still succeed for text usage.
fn from_config_path(config_path: &std::path::Path) -> Self {
let Ok(text) = std::fs::read_to_string(config_path) else {
return Self::default();
};
let Ok(v) = serde_json::from_str::<serde_json::Value>(&text) else {
return Self::default();
};
let Some(vision_config) = v.get("vision_config") else {
return Self::default();
};
let patch_size = vision_config
.get("patch_size")
.and_then(|x| x.as_u64())
.unwrap_or(16) as usize;
let spatial_merge_size = vision_config
.get("spatial_merge_size")
.and_then(|x| x.as_u64())
.unwrap_or(2) as usize;
let image_token_id = v
.get("image_token_id")
.and_then(|x| x.as_u64())
.map(|n| n as u32);
// Compute LM tokens per image at the Stage B fixed resolution
// (PreprocessProfile::qwen3_6() → 448×448). One LM token per
// spatial-merge group of patches.
let target_h = super::preprocess::PreprocessProfile::qwen3_6().target_height as usize;
let target_w = super::preprocess::PreprocessProfile::qwen3_6().target_width as usize;
let lm_tokens_per_image = if patch_size > 0 && spatial_merge_size > 0 {
let gh = target_h / patch_size / spatial_merge_size;
let gw = target_w / patch_size / spatial_merge_size;
Some(gh * gw)
} else {
None
};
Self {
has_vision: true,
image_token_id,
lm_tokens_per_image,
}
}
}
/// True iff any message in the request carries an `image_url`
/// content part. The Stage B routing decision in `chat_completion`
/// dispatches to the vision-aware worker job when this is true.
fn request_has_images(request: &ChatCompletionRequest) -> bool {
request.messages.iter().any(|m| {
matches!(&m.content, MessageContent::Parts(parts)
if parts.iter().any(|p|
p.get("type").and_then(|v| v.as_str()) == Some("image_url")))
})
}
/// Extract `image_url` content parts from a chat request and turn
/// each one into a preprocessed `ImageInput` ready for the device
/// worker. Stage B4.
///
/// Walks `request.messages`, looking for `MessageContent::Parts` and
/// pulling out entries whose `type == "image_url"`. Each is run
/// through `harness::preprocess::decode_data_uri` + `preprocess` with
/// the supplied `profile` (Stage B always uses
/// `PreprocessProfile::qwen3_6()` — fixed 448×448 — so every image
/// produces the same patch count; dynamic resolution per issue #14
/// would parameterise this).
///
/// Returns images in the order they appear in the request, which
/// matches the order the chat template emits `<|image_pad|>` tokens.
fn extract_images_from_request(
request: &ChatCompletionRequest,
profile: &super::preprocess::PreprocessProfile,
) -> anyhow::Result<Vec<super::device_worker::jobs::ImageInput>> {
let mut out = Vec::new();
for msg in &request.messages {
if let MessageContent::Parts(parts) = &msg.content {
for part in parts {
let kind = part.get("type").and_then(|v| v.as_str()).unwrap_or("");
if kind != "image_url" {
continue;
}
let url = part
.get("image_url")
.and_then(|v| v.get("url"))
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("image_url part missing url field"))?;
let pixels = super::preprocess::preprocess_data_uri(url, profile)
.with_context(|| format!("preprocess image #{}", out.len()))?;
out.push(super::device_worker::jobs::ImageInput {
pixels,
c: 3,
h: profile.target_height as usize,
w: profile.target_width as usize,
});
}
}
}
Ok(out)
}
/// Expand each occurrence of `image_token_id` in `input_ids` into
/// `patches_per_image[i]` copies (one expansion per image, in order).
/// Stage B4 helper.
///
/// The chat template emits a single `<|image_pad|>` per image; this
/// is what fits Qwen3-VL's template-then-runtime-expansion convention.
/// The runtime (us) is responsible for replacing each one with N
/// copies based on the corresponding image's patch count.
///
/// For Stage B fixed resolution every entry of `patches_per_image`
/// is the same constant (196 at 448×448). For dynamic resolution
/// (issue #14) each image gets its own value.
///
/// Errors if the number of `image_token_id` occurrences in `input_ids`
/// doesn't equal `patches_per_image.len()` — a mismatch means the
/// template emitted the wrong number of pad tokens (operator-visible
/// template bug, not a runtime error).
fn expand_image_pad_tokens(
input_ids: &[u32],
image_token_id: u32,
patches_per_image: &[usize],
) -> anyhow::Result<Vec<u32>> {
let occurrences = input_ids.iter().filter(|&&t| t == image_token_id).count();
if occurrences != patches_per_image.len() {
anyhow::bail!(
"expand_image_pad_tokens: prompt has {occurrences} image_token_id occurrences but \
{} images were preprocessed — chat template emitted the wrong number of pad tokens",
patches_per_image.len()
);
}
let total_extra: usize = patches_per_image.iter().map(|n| n.saturating_sub(1)).sum();
let mut out = Vec::with_capacity(input_ids.len() + total_extra);
let mut img_idx = 0;
for &t in input_ids {
if t == image_token_id {
let n = patches_per_image[img_idx];
for _ in 0..n {
out.push(image_token_id);
}
img_idx += 1;
} else {
out.push(t);
}
}
Ok(out)
}
/// Apply the Qwen3 chat template:
///
/// ```text
@@ -3544,6 +3863,103 @@ fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
/// would only add channel overhead with no diagnostic benefit.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// Vision-aware analogue of `run_inference_via_worker`. Stage B5.
///
/// Single-shot prefill carrying the pre-expanded prompt + the image
/// payloads. The worker encodes each image through the vision tower,
/// splices the resulting embeddings at `image_token_id` positions,
/// and returns the last-position logits. Decode steps thereafter
/// follow the existing text-only `forward_logits` path — the KV
/// cache holds the image-conditioned hidden states from prefill, so
/// no further splicing is needed.
///
/// Stage B skips chunked prefill for vision (the fixed-resolution
/// budget — 196 image tokens at 448×448 + typical text — stays well
/// under the activation-memory threshold). Long-prompt-with-images
/// chunking is a Stage D follow-up.
#[allow(clippy::too_many_arguments)]
async fn run_inference_with_images_via_worker(
worker: &super::device_worker::DeviceWorkerHandle,
handle: super::device_worker::ArchHandle,
prompt_tokens: &[u32],
images: Vec<super::device_worker::jobs::ImageInput>,
image_token_id: u32,
max_new: usize,
temperature: f64,
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
) -> Result<(Vec<u32>, String)> {
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut generated: Vec<u32> = Vec::new();
let prompt_len = prompt_tokens.len();
worker
.clear_kv_cache(handle)
.await
.map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
// Single-shot prefill with image splicing.
let logits_vec = worker
.forward_logits_with_images(handle, prompt_tokens.to_vec(), 0, images, image_token_id)
.await
.map_err(|e| anyhow::anyhow!("forward_logits_with_images: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
?health,
"chat_completion (worker, vision): prefill sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
}
generated.push(next_token);
for index in 0..max_new.saturating_sub(1) {
let logits_vec = worker
.forward_logits(handle, vec![next_token], prompt_len + index)
.await
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
step = index,
?health,
"chat_completion (worker, vision): decode sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
}
generated.push(next_token);
}
Ok((generated, "length".into()))
}
#[cfg(feature = "cuda")]
async fn run_inference_via_worker(
worker: &super::device_worker::DeviceWorkerHandle,
handle: super::device_worker::ArchHandle,
@@ -4243,4 +4659,44 @@ mod tests {
.expect("synth huggingface source should build");
assert_eq!(harness.default_source_scheme(), "huggingface");
}
#[test]
fn expand_image_pad_replaces_single_token_with_n_copies() {
// Mimics the chat template's output: each image emits
// [vision_start, image_pad, vision_end]. After expansion
// with 3 patches/image we want
// [vision_start, pad×3, vision_end].
let pad = 248056_u32;
let vstart = 248053_u32;
let vend = 248054_u32;
let input = vec![1, vstart, pad, vend, 2];
let out = expand_image_pad_tokens(&input, pad, &[3]).unwrap();
assert_eq!(out, vec![1, vstart, pad, pad, pad, vend, 2]);
}
#[test]
fn expand_image_pad_handles_multiple_images() {
let pad = 248056_u32;
// Two images in one prompt; first gets 2 patches, second 3.
let input = vec![pad, 99, pad];
let out = expand_image_pad_tokens(&input, pad, &[2, 3]).unwrap();
assert_eq!(out, vec![pad, pad, 99, pad, pad, pad]);
}
#[test]
fn expand_image_pad_errors_on_count_mismatch() {
let pad = 248056_u32;
// Prompt has 2 pad tokens but caller supplied 3 images.
let input = vec![pad, 99, pad];
let err = expand_image_pad_tokens(&input, pad, &[2, 3, 4]).unwrap_err();
assert!(format!("{err:#}").contains("emitted the wrong number"));
}
#[test]
fn expand_image_pad_passes_through_when_no_images() {
let pad = 248056_u32;
let input = vec![1, 2, 3];
let out = expand_image_pad_tokens(&input, pad, &[]).unwrap();
assert_eq!(out, input);
}
}