feat(neuron): dynamic-resolution images via Qwen smart_resize (#14)
Some checks failed
CI / Clippy (push) Waiting to run
CI / Test (push) Waiting to run
CI / CUDA type-check (push) Successful in 32s
CI / Format (push) Successful in 34s
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled

Replace the fixed 448×448-square preprocess with native-aspect
`smart_resize`, and thread the resulting per-image grid through the LM
so spatial structure survives non-square images (documents, screenshots,
charts, panoramas, OCR) instead of being squished into a square.

- preprocess.rs: port Qwen `smart_resize` (factor = patch×merge = 32;
  pixel budget [min,max], default 256²–1024² → 64–1024 LM tokens).
  `PreprocessProfile` drops the fixed target dims for `factor`/`min_pixels`/
  `max_pixels`; `preprocess`/`preprocess_data_uri` now return the resized
  `(h, w)`; add `resized_dims_for_uri` (decode + resize, no normalize) for
  the TP leader's token count.
- rope.rs: `compute_mrope_index`/`get_rope_index` take per-image
  `grids: &[(lm_gh, lm_gw)]` instead of assuming a square `isqrt(run)`.
  Walk image runs in order, validate `run == gh*gw`, emit row-major
  positions, resume the shared counter at `base + max(gh,gw)`. Correct
  for multiple images of differing grids interleaved with text.
- candle.rs: `VisionMeta`/`LoadedModel`/`TpLoadedModel` carry the
  `image_grid_factor` (patch×merge) instead of the constant 196; all four
  prompt-build sites compute per-image counts from each image's resized
  grid (single-GPU from the extracted `ImageInput.h/w`, TP from
  `resized_dims_for_uri`). `ModelArch` gains `vision_grid_factor`.
- single-GPU (`mod.rs`, `dispatch.rs`) and TP
  (`tp_qwen3_5.rs::prefill_with_images_chunked`, `dispatch.rs`,
  `tp/worker.rs`) thread the grids into `get_rope_index`. Each TP rank
  recomputes grids from its own deterministic preprocess — no rpc.rs
  change, single source of truth.

The vision tower itself was already grid-general (recent pos-embed
interpolation + 2D rotary fix). No patch-count cap: pos-embed is
interpolated to any grid; `max_pixels` bounds cost (O(patches²) ViT
attention + prefill) instead.

Tests: smart_resize (aspect/cap/floor/reject), `compute_mrope_index`
non-square + two-image + mismatch cases, square-grid regression guard.
The non-CUDA build, clippy, and the full workspace test suite are green.
The TP load/dispatch paths are cuda-gated, so they are checked by the
Gitea CUDA type-check job instead. Operator pixel-budget config and the
remaining doc cleanup follow in C5.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-04 22:47:27 +03:00
parent dc048ffcc9
commit c97a8654f5
8 changed files with 425 additions and 169 deletions

View File

@@ -404,7 +404,7 @@ impl Qwen3_5Model {
} }
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> { pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
self.forward_inner(input, offset, None, None) self.forward_inner(input, offset, None, None, &[])
} }
/// Forward with image-embedding splice. Stage B of the vision plan. /// Forward with image-embedding splice. Stage B of the vision plan.
@@ -437,8 +437,15 @@ impl Qwen3_5Model {
offset: usize, offset: usize,
image_embeds: &Tensor, image_embeds: &Tensor,
image_token_id: u32, image_token_id: u32,
grids: &[(usize, usize)],
) -> candle_core::Result<Tensor> { ) -> candle_core::Result<Tensor> {
self.forward_inner(input_ids, offset, Some(image_embeds), Some(image_token_id)) self.forward_inner(
input_ids,
offset,
Some(image_embeds),
Some(image_token_id),
grids,
)
} }
fn forward_inner( fn forward_inner(
@@ -447,6 +454,7 @@ impl Qwen3_5Model {
offset: usize, offset: usize,
image_embeds: Option<&Tensor>, image_embeds: Option<&Tensor>,
image_token_id: Option<u32>, image_token_id: Option<u32>,
grids: &[(usize, usize)],
) -> candle_core::Result<Tensor> { ) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?; let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?; let mut h = self.embed_tokens.forward(input)?;
@@ -483,7 +491,7 @@ impl Qwen3_5Model {
h = splice_runs(&h, &img, &positions)?; h = splice_runs(&h, &img, &positions)?;
} }
let (text, height, width, delta) = rope::get_rope_index(&ids, tok_id) let (text, height, width, delta) = rope::get_rope_index(&ids, tok_id, grids)
.map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?; .map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
self.rope_delta = delta; self.rope_delta = delta;
let pos = rope::mrope_position_tensor(&text, &height, &width, &self.device)?; let pos = rope::mrope_position_tensor(&text, &height, &width, &self.device)?;
@@ -603,11 +611,12 @@ impl Qwen3_5ForCausalLM {
offset: usize, offset: usize,
image_embeds: &Tensor, image_embeds: &Tensor,
image_token_id: u32, image_token_id: u32,
grids: &[(usize, usize)],
) -> candle_core::Result<Tensor> { ) -> candle_core::Result<Tensor> {
let (_, l) = input.dims2()?; let (_, l) = input.dims2()?;
let hidden = self let hidden =
.base self.base
.forward_with_vision(input, offset, image_embeds, image_token_id)?; .forward_with_vision(input, offset, image_embeds, image_token_id, grids)?;
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head) hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
} }

View File

@@ -260,28 +260,40 @@ pub(crate) fn mrope_enabled() -> bool {
/// off, returns plain sequential identity positions on all three axes /// off, returns plain sequential identity positions on all three axes
/// (`mrope_cos_sin` then reduces exactly to plain RoPE), restoring the /// (`mrope_cos_sin` then reduces exactly to plain RoPE), restoring the
/// pre-M-RoPE behaviour without touching the rest of the forward. /// pre-M-RoPE behaviour without touching the rest of the forward.
pub(crate) fn get_rope_index(input_ids: &[u32], image_token_id: u32) -> Result<MRopeIndex> { pub(crate) fn get_rope_index(
input_ids: &[u32],
image_token_id: u32,
grids: &[(usize, usize)],
) -> Result<MRopeIndex> {
if !mrope_enabled() { if !mrope_enabled() {
let seq: Vec<i64> = (0..input_ids.len() as i64).collect(); let seq: Vec<i64> = (0..input_ids.len() as i64).collect();
return Ok((seq.clone(), seq.clone(), seq, 0)); return Ok((seq.clone(), seq.clone(), seq, 0));
} }
compute_mrope_index(input_ids, image_token_id) compute_mrope_index(input_ids, image_token_id, grids)
} }
/// The real interleaved-M-RoPE position-id computation (always active in /// The real interleaved-M-RoPE position-id computation (always active in
/// unit tests; gated behind [`get_rope_index`] at runtime). /// unit tests; gated behind [`get_rope_index`] at runtime).
/// ///
/// Fixed-resolution assumption (Stage C): each image run is a perfect /// `grids` carries the post-merge LM grid `(lm_gh, lm_gw)` for each image
/// square with `grid_t = 1` (still image) and `grid_h = grid_w = /// run, in prompt order — a run length alone cannot recover its
/// isqrt(run_len)` — 196 → 14×14. Dynamic resolution (#14) would thread /// factorisation, so the grids must be passed (#14 dynamic resolution).
/// real per-image grids instead. /// Each image is a still frame (`grid_t = 1`); its tokens get
pub(crate) fn compute_mrope_index(input_ids: &[u32], image_token_id: u32) -> Result<MRopeIndex> { /// `[base, base + hh, base + ww]` row-major and the shared counter
/// resumes at `base + max(lm_gh, lm_gw)`. Multi-image is correct because
/// the counter threads across images and interleaved text.
pub(crate) fn compute_mrope_index(
input_ids: &[u32],
image_token_id: u32,
grids: &[(usize, usize)],
) -> Result<MRopeIndex> {
let n = input_ids.len(); let n = input_ids.len();
let mut text = Vec::with_capacity(n); let mut text = Vec::with_capacity(n);
let mut height = Vec::with_capacity(n); let mut height = Vec::with_capacity(n);
let mut width = Vec::with_capacity(n); let mut width = Vec::with_capacity(n);
let mut counter: i64 = 0; let mut counter: i64 = 0;
let mut i = 0; let mut i = 0;
let mut k = 0; // index into `grids`, one per image run
while i < n { while i < n {
if input_ids[i] == image_token_id { if input_ids[i] == image_token_id {
let start = i; let start = i;
@@ -289,25 +301,30 @@ pub(crate) fn compute_mrope_index(input_ids: &[u32], image_token_id: u32) -> Res
i += 1; i += 1;
} }
let run = i - start; let run = i - start;
let g = run.isqrt(); let (grid_h, grid_w) = *grids.get(k).ok_or_else(|| {
if g * g != run { anyhow::anyhow!(
"get_rope_index: image run #{k} (len {run}) has no matching grid \
({} grids supplied)",
grids.len()
)
})?;
k += 1;
if grid_h * grid_w != run {
anyhow::bail!( anyhow::bail!(
"get_rope_index: image run length {run} is not a perfect square \ "get_rope_index: image run #{} length {run} != grid {grid_h}×{grid_w} = {}",
(fixed-resolution Stage C assumes a square grid; dynamic resolution is #14)" k - 1,
grid_h * grid_w
); );
} }
let (grid_t, grid_h, grid_w) = (1usize, g, g);
let base = counter; let base = counter;
for tt in 0..grid_t {
for hh in 0..grid_h { for hh in 0..grid_h {
for ww in 0..grid_w { for ww in 0..grid_w {
text.push(base + tt as i64); text.push(base); // grid_t = 1 → temporal axis const
height.push(base + hh as i64); height.push(base + hh as i64);
width.push(base + ww as i64); width.push(base + ww as i64);
} }
} }
} counter = base + grid_h.max(grid_w) as i64;
counter = base + grid_t.max(grid_h).max(grid_w) as i64;
} else { } else {
text.push(counter); text.push(counter);
height.push(counter); height.push(counter);
@@ -316,6 +333,12 @@ pub(crate) fn compute_mrope_index(input_ids: &[u32], image_token_id: u32) -> Res
i += 1; i += 1;
} }
} }
if k != grids.len() {
anyhow::bail!(
"get_rope_index: prompt has {k} image run(s) but {} grid(s) were supplied",
grids.len()
);
}
let delta = counter - n as i64; let delta = counter - n as i64;
Ok((text, height, width, delta)) Ok((text, height, width, delta))
} }
@@ -447,7 +470,7 @@ mod tests {
#[test] #[test]
fn get_rope_index_text_only_is_sequential() { fn get_rope_index_text_only_is_sequential() {
let (t, h, w, delta) = compute_mrope_index(&[1, 2, 3, 4], 99).unwrap(); let (t, h, w, delta) = compute_mrope_index(&[1, 2, 3, 4], 99, &[]).unwrap();
assert_eq!(t, vec![0, 1, 2, 3]); assert_eq!(t, vec![0, 1, 2, 3]);
assert_eq!(h, vec![0, 1, 2, 3]); assert_eq!(h, vec![0, 1, 2, 3]);
assert_eq!(w, vec![0, 1, 2, 3]); assert_eq!(w, vec![0, 1, 2, 3]);
@@ -456,12 +479,12 @@ mod tests {
#[test] #[test]
fn get_rope_index_text_image_text() { fn get_rope_index_text_image_text() {
// [text, image(2x2 run of 4), text]. image_token = 99. // [text, image(2x2 run of 4), text]. image_token = 99, grid (2,2).
let ids = [1u32, 99, 99, 99, 99, 2]; let ids = [1u32, 99, 99, 99, 99, 2];
let (t, h, w, delta) = compute_mrope_index(&ids, 99).unwrap(); let (t, h, w, delta) = compute_mrope_index(&ids, 99, &[(2, 2)]).unwrap();
// token 0: text → 0. image base=1, grid 1x2x2: // token 0: text → 0. image base=1, grid 2x2:
// t all = 1; h = base+row = [1,1,2,2]; w = base+col = [1,2,1,2]. // t all = 1; h = base+row = [1,1,2,2]; w = base+col = [1,2,1,2].
// resume from base + max(1,2,2) = 3. trailing text → 3. // resume from base + max(2,2) = 3. trailing text → 3.
assert_eq!(t, vec![0, 1, 1, 1, 1, 3]); assert_eq!(t, vec![0, 1, 1, 1, 1, 3]);
assert_eq!(h, vec![0, 1, 1, 2, 2, 3]); assert_eq!(h, vec![0, 1, 1, 2, 2, 3]);
assert_eq!(w, vec![0, 1, 2, 1, 2, 3]); assert_eq!(w, vec![0, 1, 2, 1, 2, 3]);
@@ -472,25 +495,52 @@ mod tests {
assert_eq!(6 + delta, 4); assert_eq!(6 + delta, 4);
} }
#[test]
fn get_rope_index_nonsquare_single_image() {
    // Prompt layout: one text token, then a single image run of six pad
    // tokens arranged as 2 rows × 3 columns, i.e. grid `(2, 3)`.
    const IMAGE_PAD: u32 = 99;
    let mut prompt = vec![1u32];
    prompt.extend(std::iter::repeat_n(IMAGE_PAD, 6));
    let image_grids = [(2usize, 3usize)];

    let (temporal, rows, cols, delta) =
        compute_mrope_index(&prompt, IMAGE_PAD, &image_grids).unwrap();

    // The text token takes position 0, so the image base is 1. Image
    // tokens are emitted row-major: the temporal axis stays at the base,
    // the height axis is base + row, the width axis is base + column.
    let want_temporal: Vec<i64> = vec![0, 1, 1, 1, 1, 1, 1];
    let want_rows: Vec<i64> = vec![0, 1, 1, 1, 2, 2, 2];
    let want_cols: Vec<i64> = vec![0, 1, 2, 3, 1, 2, 3];
    assert_eq!(temporal, want_temporal);
    assert_eq!(rows, want_rows);
    assert_eq!(cols, want_cols);

    // The shared counter resumes at base + max(2, 3) = 4 while the
    // sequence holds 7 tokens, so delta = 4 - 7 = -3.
    let (final_counter, seq_len) = (4i64, 7i64);
    assert_eq!(delta, final_counter - seq_len);
}
#[test]
fn get_rope_index_two_images_different_grids() {
    // Prompt layout: a 2×2 image (4 pad tokens), one text token, then a
    // 1×3 image (3 pad tokens). Grids are supplied in prompt order.
    const IMAGE_PAD: u32 = 99;
    let mut prompt: Vec<u32> = Vec::with_capacity(8);
    prompt.extend(std::iter::repeat_n(IMAGE_PAD, 4));
    prompt.push(7);
    prompt.extend(std::iter::repeat_n(IMAGE_PAD, 3));
    let image_grids = [(2usize, 2usize), (1usize, 3usize)];

    let (temporal, rows, cols, delta) =
        compute_mrope_index(&prompt, IMAGE_PAD, &image_grids).unwrap();

    // First image: base 0 → temporal 0, rows [0,0,1,1], cols [0,1,0,1];
    // the counter resumes at 0 + max(2, 2) = 2.
    // Text token: sits at counter 2, which then advances to 3.
    // Second image: base 3 → temporal 3, rows [3,3,3], cols [3,4,5];
    // the counter resumes at 3 + max(1, 3) = 6.
    let want_temporal: Vec<i64> = vec![0, 0, 0, 0, 2, 3, 3, 3];
    let want_rows: Vec<i64> = vec![0, 0, 1, 1, 2, 3, 3, 3];
    let want_cols: Vec<i64> = vec![0, 1, 0, 1, 2, 3, 4, 5];
    assert_eq!(temporal, want_temporal);
    assert_eq!(rows, want_rows);
    assert_eq!(cols, want_cols);

    // Final counter 6 against a sequence of 8 tokens.
    let (final_counter, seq_len) = (6i64, 8i64);
    assert_eq!(delta, final_counter - seq_len);
}
#[test] #[test]
fn get_rope_index_on_by_default() { fn get_rope_index_on_by_default() {
// With NEURON_MROPE unset (default ON), the runtime path returns // With NEURON_MROPE unset (default ON), the runtime path returns
// the real interleaved-M-RoPE positions, so image tokens carry // the real interleaved-M-RoPE positions. (NEURON_MROPE=0 would fall
// their 2D grid coords (height differs from the text counter). // back to identity; not asserted here since it depends on env.)
// (NEURON_MROPE=0 would fall back to identity; not asserted here let (t, h, w, _delta) = get_rope_index(&[1, 99, 99, 99, 99, 2], 99, &[(2, 2)]).unwrap();
// since it depends on env.)
let (t, h, w, _delta) = get_rope_index(&[1, 99, 99, 99, 99, 2], 99).unwrap();
// Same as compute_mrope_index: 2x2 image after one text token.
assert_eq!(t, vec![0, 1, 1, 1, 1, 3]); assert_eq!(t, vec![0, 1, 1, 1, 1, 3]);
assert_eq!(h, vec![0, 1, 1, 2, 2, 3]); assert_eq!(h, vec![0, 1, 1, 2, 2, 3]);
assert_eq!(w, vec![0, 1, 2, 1, 2, 3]); assert_eq!(w, vec![0, 1, 2, 1, 2, 3]);
} }
#[test] #[test]
fn get_rope_index_rejects_non_square_image_run() { fn get_rope_index_grid_mismatches_error() {
// 196 is square (14x14) — ok. 195 is not. // run length != grid product.
assert!(compute_mrope_index(&[99u32; 196], 99).is_ok()); assert!(compute_mrope_index(&[99u32; 6], 99, &[(2, 2)]).is_err());
assert!(compute_mrope_index(&[99u32; 195], 99).is_err()); // too few grids for the number of image runs.
assert!(compute_mrope_index(&[99, 99, 7, 99], 99, &[(1, 2)]).is_err());
// too many grids.
assert!(compute_mrope_index(&[99, 99], 99, &[(1, 2), (1, 1)]).is_err());
} }
#[test] #[test]
@@ -501,7 +551,7 @@ mod tests {
let dev = Device::Cpu; let dev = Device::Cpu;
let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap(); let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
let ids = [1u32, 99, 99, 99, 99]; // text + 2x2 image let ids = [1u32, 99, 99, 99, 99]; // text + 2x2 image
let (t, h, w, _d) = compute_mrope_index(&ids, 99).unwrap(); let (t, h, w, _d) = compute_mrope_index(&ids, 99, &[(2, 2)]).unwrap();
let pos = mrope_position_tensor(&t, &h, &w, &dev).unwrap(); let pos = mrope_position_tensor(&t, &h, &w, &dev).unwrap();
assert_eq!(pos.dims(), &[3, 5]); assert_eq!(pos.dims(), &[3, 5]);
let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap(); let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap();
@@ -518,7 +568,7 @@ mod tests {
fn get_rope_index_196_is_14x14() { fn get_rope_index_196_is_14x14() {
let mut ids = vec![1u32]; // one text token let mut ids = vec![1u32]; // one text token
ids.extend(std::iter::repeat_n(99u32, 196)); ids.extend(std::iter::repeat_n(99u32, 196));
let (t, h, w, _delta) = compute_mrope_index(&ids, 99).unwrap(); let (t, h, w, _delta) = compute_mrope_index(&ids, 99, &[(14, 14)]).unwrap();
// image base = 1. Last image token (index 196) is grid (h=13,w=13). // image base = 1. Last image token (index 196) is grid (h=13,w=13).
assert_eq!(*t.last().unwrap(), 1, "grid_t=1 → temporal const at base"); assert_eq!(*t.last().unwrap(), 1, "grid_t=1 → temporal const at base");
assert_eq!(h[1], 1, "first image row at base"); assert_eq!(h[1], 1, "first image row at base");

View File

@@ -210,13 +210,11 @@ pub struct LoadedModel {
/// targets and the worker forward uses it to locate splice /// targets and the worker forward uses it to locate splice
/// positions in the LM input embeddings. /// positions in the LM input embeddings.
pub image_token_id: Option<u32>, pub image_token_id: Option<u32>,
/// LM-side tokens this model's vision tower emits per image at /// `patch_size × spatial_merge_size` — divides a resized pixel
/// the Stage B fixed resolution (448×448 → 196 for Qwen3.6). /// dimension into LM-grid units. Per-image LM token count is
/// `None` for text-only models. Set at load time so the /// `(h/factor) × (w/factor)` (#14 dynamic resolution). `None` for
/// hot path doesn't recompute it per request. Stage B fixed /// text-only models. Set at load time.
/// resolution → constant; dynamic resolution per #14 makes it pub image_grid_factor: Option<usize>,
/// per-image.
pub lm_tokens_per_image: Option<usize>,
} }
impl LoadedModel { impl LoadedModel {
@@ -288,9 +286,9 @@ pub struct TpLoadedModel {
pub has_vision: bool, pub has_vision: bool,
/// `<|image_pad|>` token id — same as [`LoadedModel::image_token_id`]. /// `<|image_pad|>` token id — same as [`LoadedModel::image_token_id`].
pub image_token_id: Option<u32>, pub image_token_id: Option<u32>,
/// LM-side tokens per image at the fixed 448×448 resolution — same /// Pixel→LM-grid divisor — same as
/// as [`LoadedModel::lm_tokens_per_image`]. /// [`LoadedModel::image_grid_factor`].
pub lm_tokens_per_image: Option<usize>, pub image_grid_factor: Option<usize>,
} }
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
@@ -394,10 +392,11 @@ impl ModelArch {
offset: usize, offset: usize,
image_embeds: &Tensor, image_embeds: &Tensor,
image_token_id: u32, image_token_id: u32,
grids: &[(usize, usize)],
) -> Result<Tensor> { ) -> Result<Tensor> {
let raw = match self { let raw = match self {
ModelArch::Qwen3_5Dense(m) => { ModelArch::Qwen3_5Dense(m) => {
m.forward_with_vision(input, offset, image_embeds, image_token_id)? m.forward_with_vision(input, offset, image_embeds, image_token_id, grids)?
} }
other => anyhow::bail!( other => anyhow::bail!(
"forward_with_vision: architecture {} has no vision tower", "forward_with_vision: architecture {} has no vision tower",
@@ -407,6 +406,20 @@ impl ModelArch {
squeeze_to_vocab(&raw) squeeze_to_vocab(&raw)
} }
/// Pixel→LM-grid divisor of the loaded vision tower:
/// `patch_size × spatial_merge_size`, read from the tower's config.
///
/// Dividing each resized pixel dimension by this factor gives the LM
/// grid, so an image resized to `(h, w)` maps to `(h/factor, w/factor)`.
///
/// Returns `None` when the architecture is not `Qwen3_5Dense`, or when
/// the model reports no vision tower.
pub fn vision_grid_factor(&self) -> Option<usize> {
    let ModelArch::Qwen3_5Dense(model) = self else {
        return None;
    };
    let tower = model.vision()?;
    let cfg = tower.config();
    Some(cfg.patch_size * cfg.spatial_merge_size)
}
/// Encode a preprocessed image into LM-side token embeddings via /// Encode a preprocessed image into LM-side token embeddings via
/// the loaded vision tower. Stage A5. /// the loaded vision tower. Stage A5.
/// ///
@@ -1683,10 +1696,10 @@ impl CandleHarness {
.ok_or_else(|| InferenceError::VisionUnsupported { .ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(), model_id: request.model.clone(),
})?; })?;
let patches_per_image = loaded let factor = loaded.image_grid_factor.ok_or_else(|| {
.lm_tokens_per_image InferenceError::VisionUnsupported {
.ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(), model_id: request.model.clone(),
}
})?; })?;
let profile = super::preprocess::PreprocessProfile::qwen3_6(); let profile = super::preprocess::PreprocessProfile::qwen3_6();
let images = extract_images_from_request(&request, &profile).map_err(|e| { let images = extract_images_from_request(&request, &profile).map_err(|e| {
@@ -1699,7 +1712,12 @@ impl CandleHarness {
"request has image content but extractor produced zero images" "request has image content but extractor produced zero images"
))); )));
} }
let per_image_counts: Vec<usize> = vec![patches_per_image; images.len()]; // Per-image LM token count from each image's resized grid
// (#14 dynamic resolution; was a constant 196).
let per_image_counts: Vec<usize> = images
.iter()
.map(|im| (im.h / factor) * (im.w / factor))
.collect();
prompt_tokens = prompt_tokens =
expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts) expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
.map_err(InferenceError::Other)?; .map_err(InferenceError::Other)?;
@@ -2059,10 +2077,11 @@ impl CandleHarness {
.ok_or_else(|| InferenceError::VisionUnsupported { .ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(), model_id: request.model.clone(),
})?; })?;
let patches_per_image = loaded.lm_tokens_per_image.ok_or_else(|| { let factor =
InferenceError::VisionUnsupported { loaded
.image_grid_factor
.ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(), model_id: request.model.clone(),
}
})?; })?;
let profile = super::preprocess::PreprocessProfile::qwen3_6(); let profile = super::preprocess::PreprocessProfile::qwen3_6();
let images = extract_images_from_request(&request, &profile) let images = extract_images_from_request(&request, &profile)
@@ -2072,7 +2091,11 @@ impl CandleHarness {
"request has image content but extractor produced zero images" "request has image content but extractor produced zero images"
))); )));
} }
let per_image_counts: Vec<usize> = vec![patches_per_image; images.len()]; // Per-image LM token count from each image's resized grid (#14).
let per_image_counts: Vec<usize> = images
.iter()
.map(|im| (im.h / factor) * (im.w / factor))
.collect();
prompt_tokens = prompt_tokens =
expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts) expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
.map_err(InferenceError::Other)?; .map_err(InferenceError::Other)?;
@@ -2526,7 +2549,7 @@ impl Harness for CandleHarness {
chat_template, chat_template,
has_vision: vision_meta.has_vision, has_vision: vision_meta.has_vision,
image_token_id: vision_meta.image_token_id, image_token_id: vision_meta.image_token_id,
lm_tokens_per_image: vision_meta.lm_tokens_per_image, image_grid_factor: vision_meta.image_grid_factor,
}); });
let mut models = self.models.write().await; let mut models = self.models.write().await;
@@ -2742,7 +2765,7 @@ impl CandleHarness {
tracing::info!( tracing::info!(
model = %spec.model_id, model = %spec.model_id,
image_token_id = ?vision_meta.image_token_id, image_token_id = ?vision_meta.image_token_id,
lm_tokens_per_image = ?vision_meta.lm_tokens_per_image, image_grid_factor = ?vision_meta.image_grid_factor,
"TP load: vision tower present, advertising vision capability" "TP load: vision tower present, advertising vision capability"
); );
} }
@@ -2764,7 +2787,7 @@ impl CandleHarness {
chat_template, chat_template,
has_vision: vision_meta.has_vision, has_vision: vision_meta.has_vision,
image_token_id: vision_meta.image_token_id, image_token_id: vision_meta.image_token_id,
lm_tokens_per_image: vision_meta.lm_tokens_per_image, image_grid_factor: vision_meta.image_grid_factor,
}); });
let mut models = self.models.write().await; let mut models = self.models.write().await;
@@ -2938,8 +2961,8 @@ impl CandleHarness {
.ok_or_else(|| InferenceError::VisionUnsupported { .ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(), model_id: request.model.clone(),
})?; })?;
let patches_per_image = let factor = tp
tp.lm_tokens_per_image .image_grid_factor
.ok_or_else(|| InferenceError::VisionUnsupported { .ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(), model_id: request.model.clone(),
})?; })?;
@@ -2949,7 +2972,21 @@ impl CandleHarness {
"request has image content but extractor produced zero data URIs" "request has image content but extractor produced zero data URIs"
))); )));
} }
let per_image_counts: Vec<usize> = vec![patches_per_image; data_uris.len()]; // Per-image LM token count from each image's resized grid (#14).
// Decode header + smart_resize only; the workers re-derive the
// same dims when they preprocess for the replicated tower.
let profile = super::preprocess::PreprocessProfile::qwen3_6();
let per_image_counts: Vec<usize> = data_uris
.iter()
.enumerate()
.map(|(i, uri)| {
let (h, w) =
super::preprocess::resized_dims_for_uri(uri, &profile).map_err(|e| {
InferenceError::Other(anyhow::anyhow!("resized_dims image #{i}: {e}"))
})?;
Ok::<usize, InferenceError>((h as usize / factor) * (w as usize / factor))
})
.collect::<Result<Vec<_>, _>>()?;
prompt_tokens = prompt_tokens =
expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts) expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
.map_err(InferenceError::Other)?; .map_err(InferenceError::Other)?;
@@ -3457,8 +3494,8 @@ async fn chat_completion_tp_inner(
.ok_or_else(|| InferenceError::VisionUnsupported { .ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(), model_id: request.model.clone(),
})?; })?;
let patches_per_image = let factor = tp
tp.lm_tokens_per_image .image_grid_factor
.ok_or_else(|| InferenceError::VisionUnsupported { .ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(), model_id: request.model.clone(),
})?; })?;
@@ -3468,7 +3505,19 @@ async fn chat_completion_tp_inner(
"request has image content but extractor produced zero data URIs" "request has image content but extractor produced zero data URIs"
))); )));
} }
let per_image_counts: Vec<usize> = vec![patches_per_image; data_uris.len()]; // Per-image LM token count from each image's resized grid (#14).
let profile = super::preprocess::PreprocessProfile::qwen3_6();
let per_image_counts: Vec<usize> = data_uris
.iter()
.enumerate()
.map(|(i, uri)| {
let (h, w) =
super::preprocess::resized_dims_for_uri(uri, &profile).map_err(|e| {
InferenceError::Other(anyhow::anyhow!("resized_dims image #{i}: {e}"))
})?;
Ok::<usize, InferenceError>((h as usize / factor) * (w as usize / factor))
})
.collect::<Result<Vec<_>, _>>()?;
prompt_tokens = expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts) prompt_tokens = expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
.map_err(InferenceError::Other)?; .map_err(InferenceError::Other)?;
Some((data_uris, image_token_id)) Some((data_uris, image_token_id))
@@ -3917,10 +3966,12 @@ fn build_prompt_for_request(
struct VisionMeta { struct VisionMeta {
has_vision: bool, has_vision: bool,
image_token_id: Option<u32>, image_token_id: Option<u32>,
/// LM-side tokens this model's vision tower emits per image at /// `patch_size × spatial_merge_size` — the divisor that turns a
/// the Stage B fixed `PreprocessProfile::qwen3_6()` resolution /// resized pixel dimension into an LM-grid dimension. An image of
/// (448×448). Equal to `(H/patch_size/spatial_merge_size)²`. /// resized `(h, w)` emits `(h/factor) × (w/factor)` LM tokens (#14
lm_tokens_per_image: Option<usize>, /// dynamic resolution; was a constant 196 at the old fixed 448²).
/// `None` for text-only models.
image_grid_factor: Option<usize>,
} }
impl VisionMeta { impl VisionMeta {
@@ -3949,22 +4000,18 @@ impl VisionMeta {
.get("image_token_id") .get("image_token_id")
.and_then(|x| x.as_u64()) .and_then(|x| x.as_u64())
.map(|n| n as u32); .map(|n| n as u32);
// Compute LM tokens per image at the Stage B fixed resolution // The pixel→LM-grid divisor. An image resized to (h, w) emits
// (PreprocessProfile::qwen3_6() → 448×448). One LM token per // (h/factor) × (w/factor) LM tokens — computed per image at
// spatial-merge group of patches. // request time now that resolution is dynamic (#14).
let target_h = super::preprocess::PreprocessProfile::qwen3_6().target_height as usize; let image_grid_factor = if patch_size > 0 && spatial_merge_size > 0 {
let target_w = super::preprocess::PreprocessProfile::qwen3_6().target_width as usize; Some(patch_size * spatial_merge_size)
let lm_tokens_per_image = if patch_size > 0 && spatial_merge_size > 0 {
let gh = target_h / patch_size / spatial_merge_size;
let gw = target_w / patch_size / spatial_merge_size;
Some(gh * gw)
} else { } else {
None None
}; };
Self { Self {
has_vision: true, has_vision: true,
image_token_id, image_token_id,
lm_tokens_per_image, image_grid_factor,
} }
} }
} }
@@ -4011,13 +4058,13 @@ fn extract_images_from_request(
.and_then(|v| v.get("url")) .and_then(|v| v.get("url"))
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("image_url part missing url field"))?; .ok_or_else(|| anyhow::anyhow!("image_url part missing url field"))?;
let pixels = super::preprocess::preprocess_data_uri(url, profile) let (pixels, h, w) = super::preprocess::preprocess_data_uri(url, profile)
.with_context(|| format!("preprocess image #{}", out.len()))?; .with_context(|| format!("preprocess image #{}", out.len()))?;
out.push(super::device_worker::jobs::ImageInput { out.push(super::device_worker::jobs::ImageInput {
pixels, pixels,
c: 3, c: 3,
h: profile.target_height as usize, h: h as usize,
w: profile.target_width as usize, w: w as usize,
}); });
} }
} }

View File

@@ -779,19 +779,17 @@ fn tp_forward_logits_with_images(
anyhow::bail!("TpForwardLogitsWithImages dispatched with zero images"); anyhow::bail!("TpForwardLogitsWithImages dispatched with zero images");
} }
// Preprocess every image into a device-resident (C, H, W) tensor. // Preprocess every image into a device-resident (C, H, W) tensor at
// Same fixed-resolution profile + decode path the subprocess workers // its native-aspect resized dims (#14). Same `smart_resize` + decode
// run, so the encoded embeddings match across ranks bit-for-bit. // path the subprocess workers run, so the encoded embeddings — and
// the per-image grids derived from these dims — match across ranks
// bit-for-bit.
let profile = PreprocessProfile::qwen3_6(); let profile = PreprocessProfile::qwen3_6();
let (h, w) = (
profile.target_height as usize,
profile.target_width as usize,
);
let mut pixels: Vec<Tensor> = Vec::with_capacity(image_data_uris.len()); let mut pixels: Vec<Tensor> = Vec::with_capacity(image_data_uris.len());
for (idx, uri) in image_data_uris.iter().enumerate() { for (idx, uri) in image_data_uris.iter().enumerate() {
let px = preprocess_data_uri(uri, &profile) let (px, h, w) = preprocess_data_uri(uri, &profile)
.with_context(|| format!("preprocess image[{idx}] (TP leader)"))?; .with_context(|| format!("preprocess image[{idx}] (TP leader)"))?;
let t = Tensor::from_vec(px, (3, h, w), &state.device)?; let t = Tensor::from_vec(px, (3, h as usize, w as usize), &state.device)?;
pixels.push(t); pixels.push(t);
} }
@@ -877,9 +875,17 @@ fn forward_logits_with_images(
anyhow::anyhow!("ForwardLogitsWithImages: no model for handle {}", handle.0) anyhow::anyhow!("ForwardLogitsWithImages: no model for handle {}", handle.0)
})?; })?;
// pixel→LM-grid divisor (patch×merge) for this tower; each image's
// LM grid is (h/factor, w/factor) (#14 dynamic resolution).
let factor = arch.vision_grid_factor().ok_or_else(|| {
anyhow::anyhow!("ForwardLogitsWithImages: loaded model has no vision tower")
})?;
// Encode every image on the worker's device, collecting per-image // Encode every image on the worker's device, collecting per-image
// post-merger embeddings as device-resident tensors. // post-merger embeddings as device-resident tensors plus their LM
// grids (for the interleaved-M-RoPE position ids).
let mut per_image: Vec<Tensor> = Vec::with_capacity(images.len()); let mut per_image: Vec<Tensor> = Vec::with_capacity(images.len());
let mut grids: Vec<(usize, usize)> = Vec::with_capacity(images.len());
for (idx, img) in images.into_iter().enumerate() { for (idx, img) in images.into_iter().enumerate() {
anyhow::ensure!( anyhow::ensure!(
img.pixels.len() == img.c * img.h * img.w, img.pixels.len() == img.c * img.h * img.w,
@@ -889,6 +895,7 @@ fn forward_logits_with_images(
img.h, img.h,
img.w, img.w,
); );
grids.push((img.h / factor, img.w / factor));
let image = Tensor::from_vec(img.pixels, (img.c, img.h, img.w), &state.device)?; let image = Tensor::from_vec(img.pixels, (img.c, img.h, img.w), &state.device)?;
let embed = arch let embed = arch
.encode_image(&image) .encode_image(&image)
@@ -901,7 +908,7 @@ fn forward_logits_with_images(
let image_embeds = Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)?; let image_embeds = Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)?;
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?; let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
let logits = arch.forward_with_vision(&input, offset, &image_embeds, image_token_id)?; let logits = arch.forward_with_vision(&input, offset, &image_embeds, image_token_id, &grids)?;
let values = logits let values = logits
.to_dtype(DType::F32)? .to_dtype(DType::F32)?
.flatten_all()? .flatten_all()?

View File

@@ -36,8 +36,13 @@ pub struct TpHandle(pub u64);
/// `Clone` so the vision-aware dispatch in `chat_completion` can /// `Clone` so the vision-aware dispatch in `chat_completion` can
/// match `&vision_route` (carrying borrowed images) and still hand /// match `&vision_route` (carrying borrowed images) and still hand
/// owned `Vec<ImageInput>` to the worker job. The clone cost is one /// owned `Vec<ImageInput>` to the worker job. The clone cost is one
/// pixel-buffer memcpy per image — fine at fixed-resolution sizes /// pixel-buffer memcpy per image — now variable with dynamic resolution
/// (3 × 448 × 448 × 4 bytes = ~2.4 MiB per image). /// (#14): `3 × h × w × 4` bytes, up to 12 MiB at the default 1024²
/// `max_pixels` budget.
///
/// `h`/`w` are the **resized** dims (factor-aligned), so the per-image LM
/// grid is `(h/factor, w/factor)` — derived downstream for the splice
/// and the interleaved-M-RoPE position ids.
#[derive(Clone)] #[derive(Clone)]
pub struct ImageInput { pub struct ImageInput {
pub pixels: Vec<f32>, pub pixels: Vec<f32>,

View File

@@ -2,11 +2,11 @@
//! //!
//! Decodes `data:image/...;base64,...` URIs from OpenAI-style //! Decodes `data:image/...;base64,...` URIs from OpenAI-style
//! `image_url` content parts into the patch tensors a candle vision //! `image_url` content parts into the patch tensors a candle vision
//! tower expects. Stage A ships **fixed resolution** — every image //! tower expects. Resolution is **dynamic** (#14): each image is
//! is resized to the same target dimensions (default 448×448 for //! resized to its native aspect via Qwen `smart_resize` — a
//! Qwen3.6, configurable per-call) so the patch count is constant //! factor-aligned `(h, w)` whose pixel count lands in the profile's
//! per image. Variable resolution per [Qwen2VL convention] is tracked //! `[min_pixels, max_pixels]` budget — so the LM token count varies per
//! as issue #14. //! image (`(h/factor) × (w/factor)`).
//! //!
//! Spec reference: `doc/vision-qwen3_6-spec.md` — preprocessor //! Spec reference: `doc/vision-qwen3_6-spec.md` — preprocessor
//! section. //! section.
@@ -21,7 +21,7 @@
//! Pipeline (per image): //! Pipeline (per image):
//! 1. data: URI → base64 decode → bytes //! 1. data: URI → base64 decode → bytes
//! 2. bytes → image::DynamicImage (PNG/JPEG/WebP/etc) //! 2. bytes → image::DynamicImage (PNG/JPEG/WebP/etc)
//! 3. resize_exact to target H×W (pixel space) //! 3. smart_resize to a native-aspect, factor-aligned H×W (pixel space)
//! 4. RGB→f32, normalise per mean/std //! 4. RGB→f32, normalise per mean/std
//! 5. layout to (C, H, W) tensor //! 5. layout to (C, H, W) tensor
//! //!
@@ -34,39 +34,93 @@ use base64::Engine;
use image::DynamicImage; use image::DynamicImage;
use image::imageops::FilterType; use image::imageops::FilterType;
/// Preprocessing target. Captures the resize dimensions and the /// Preprocessing target. Captures the resize policy (Qwen `smart_resize`
/// channel-wise normalisation constants from the model's /// factor + pixel budget) and the channel-wise normalisation constants
/// `preprocessor_config.json`. Stage A ships a single `qwen3_6()` /// from the model's `preprocessor_config.json`. Images are resized to
/// constructor for fixed-resolution Qwen3.6 preprocessing; other /// their **native aspect** — a factor-aligned `(h, w)` whose pixel count
/// models can ship their own profile when added. /// lands in `[min_pixels, max_pixels]` — not a fixed square (#14).
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct PreprocessProfile { pub struct PreprocessProfile {
pub target_height: u32, /// Both output dims are multiples of this. For Qwen3.6 it is
pub target_width: u32, /// `patch_size(16) × spatial_merge_size(2) = 32`, so the post-merge
/// LM grid is exactly `(h/factor, w/factor)`.
pub factor: u32,
/// Lower pixel bound — tiny images are upscaled to at least this.
pub min_pixels: u32,
/// Upper pixel bound — large images are downscaled to at most this.
/// Caps per-image LM tokens (`max_pixels / factor²`) and the
/// O(patches²) ViT attention cost.
pub max_pixels: u32,
pub image_mean: [f32; 3], pub image_mean: [f32; 3],
pub image_std: [f32; 3], pub image_std: [f32; 3],
} }
impl PreprocessProfile { impl PreprocessProfile {
/// Stage A profile for Qwen3.6. Resize to 448×448, normalise to /// Profile for Qwen3.6. Native-aspect `smart_resize` (factor 32),
/// `[-1, 1]` via mean=std=0.5. Fits within the model's /// normalise to `[-1, 1]` via mean=std=0.5. Pixel budget defaults:
/// `num_position_embeddings=2304` budget at 28×28 = 784 patches /// `min = 256² = 65536` (→ 8×8 = 64 LM tokens) and
/// before merging. /// `max = 1024² = 1048576` (→ 32×32 = 1024 LM tokens) — generous for
/// documents/OCR, bounded for serving on 2×RTX5090. (Operator
/// override lands with the `[harness.candle.vision]` config in #14 C5.)
pub fn qwen3_6() -> Self { pub fn qwen3_6() -> Self {
Self { Self {
target_height: 448, factor: 32,
target_width: 448, min_pixels: 65_536,
max_pixels: 1_048_576,
image_mean: [0.5, 0.5, 0.5], image_mean: [0.5, 0.5, 0.5],
image_std: [0.5, 0.5, 0.5], image_std: [0.5, 0.5, 0.5],
} }
} }
/// Per-channel CHW tensor length: 3 * H * W. /// The factor-aligned `(h, w)` this profile would resize a source
pub fn pixels_chw(&self) -> usize { /// `src_h × src_w` image to. Pure integer policy — no pixel work.
3 * (self.target_height as usize) * (self.target_width as usize) pub fn resized_dims(&self, src_h: u32, src_w: u32) -> Result<(u32, u32)> {
smart_resize(src_h, src_w, self.factor, self.min_pixels, self.max_pixels)
} }
} }
/// Qwen `smart_resize`: a `factor`-aligned `(h_bar, w_bar)` that preserves
/// the source aspect ratio as closely as factor rounding allows while
/// keeping the pixel count within `[min_pixels, max_pixels]`. Port of the
/// canonical Qwen2-VL / Qwen3-VL image-processor function (so neuron's
/// grid matches what the model was trained on).
///
/// Policy, in order:
/// 1. Round each dim to the nearest multiple of `factor`, with ties going
///    to the even multiple — this is what Python's built-in `round` does
///    upstream (e.g. `2000 / 32 = 62.5` → `62`, not `63`).
/// 2. If the rounded area exceeds `max_pixels`, scale both source dims
///    down by a common `beta` and floor to `factor` (never below one
///    `factor`).
/// 3. Else if the rounded area is below `min_pixels`, scale both source
///    dims up by a common `beta` and ceil to `factor`.
///
/// Zero-sized inputs are treated as 1 pixel on that side.
///
/// Returns `(height, width)`.
///
/// # Errors
///
/// * `factor == 0` — there is no grid to align to.
/// * The aspect ratio exceeds 200:1 (degenerate input — a 1-pixel-tall
///   strip), matching upstream.
pub fn smart_resize(
    height: u32,
    width: u32,
    factor: u32,
    min_pixels: u32,
    max_pixels: u32,
) -> Result<(u32, u32)> {
    // A zero factor would turn every division below into inf/NaN and the
    // final cast into a silent (0, 0).
    if factor == 0 {
        anyhow::bail!("smart_resize: factor must be non-zero");
    }
    let h = f64::from(height.max(1));
    let w = f64::from(width.max(1));
    let ratio = h.max(w) / h.min(w);
    if ratio > 200.0 {
        anyhow::bail!(
            "image aspect ratio {ratio:.1}:1 exceeds the 200:1 limit ({height}×{width}); \
             refusing to resize"
        );
    }
    let f = f64::from(factor);
    let (minp, maxp) = (f64::from(min_pixels), f64::from(max_pixels));
    // Round-to-nearest-factor with ties-to-even, mirroring Python's
    // `round()`. May be 0 for sub-factor inputs; the min-pixels branch
    // below grows it back up.
    let mut h_bar = (h / f).round_ties_even() * f;
    let mut w_bar = (w / f).round_ties_even() * f;
    if h_bar * w_bar > maxp {
        let beta = (h * w / maxp).sqrt();
        h_bar = f.max((h / beta / f).floor() * f);
        w_bar = f.max((w / beta / f).floor() * f);
    } else if h_bar * w_bar < minp {
        let beta = (minp / (h * w)).sqrt();
        h_bar = (h * beta / f).ceil() * f;
        w_bar = (w * beta / f).ceil() * f;
    }
    // Only reachable with a zero dim when `min_pixels == 0` lets a
    // sub-factor side through unchanged; never hand a 0-sized dim to the
    // resizer.
    Ok((h_bar.max(f) as u32, w_bar.max(f) as u32))
}
/// Decode a `data:image/...;base64,...` URI into an in-memory image. /// Decode a `data:image/...;base64,...` URI into an in-memory image.
/// ///
/// Accepts the OpenAI Chat Completions `image_url` shape — a string /// Accepts the OpenAI Chat Completions `image_url` shape — a string
@@ -106,16 +160,13 @@ pub fn decode_data_uri(uri: &str) -> Result<DynamicImage> {
/// faster on CPU. Quality difference is marginal for downstream /// faster on CPU. Quality difference is marginal for downstream
/// vision-encoder consumption. The numerical-validation issue (#15) /// vision-encoder consumption. The numerical-validation issue (#15)
/// will quantify any discrepancy. /// will quantify any discrepancy.
pub fn preprocess(img: &DynamicImage, profile: &PreprocessProfile) -> Vec<f32> { pub fn preprocess(img: &DynamicImage, profile: &PreprocessProfile) -> Result<(Vec<f32>, u32, u32)> {
let (h_bar, w_bar) = profile.resized_dims(img.height(), img.width())?;
let rgb = img let rgb = img
.resize_exact( .resize_exact(w_bar, h_bar, FilterType::Triangle)
profile.target_width,
profile.target_height,
FilterType::Triangle,
)
.to_rgb8(); .to_rgb8();
let h = profile.target_height as usize; let h = h_bar as usize;
let w = profile.target_width as usize; let w = w_bar as usize;
let mut out = vec![0.0_f32; 3 * h * w]; let mut out = vec![0.0_f32; 3 * h * w];
// Row-major (C, H, W). Candle's Conv2d expects NCHW, so this is // Row-major (C, H, W). Candle's Conv2d expects NCHW, so this is
// the natural layout — the caller stacks `n` of these along the // the natural layout — the caller stacks `n` of these along the
@@ -131,16 +182,27 @@ pub fn preprocess(img: &DynamicImage, profile: &PreprocessProfile) -> Vec<f32> {
} }
} }
} }
out Ok((out, h_bar, w_bar))
} }
/// Combined helper: decode + preprocess in one call. Most call /// Combined helper: decode + preprocess in one call. Returns the
/// sites just want the final tensor; the two-step path exists for /// `(3, h, w)` row-major pixels plus the resized `(h, w)` — the caller
/// callers (tests, future video preprocessing) that need the /// needs the dims to build the tensor and to derive the LM token grid
/// `(h/factor, w/factor)`. Most call sites use this; the two-step path
/// exists for callers (tests, future video preprocessing) that need the
/// intermediate `DynamicImage`. /// intermediate `DynamicImage`.
pub fn preprocess_data_uri(uri: &str, profile: &PreprocessProfile) -> Result<Vec<f32>> { pub fn preprocess_data_uri(uri: &str, profile: &PreprocessProfile) -> Result<(Vec<f32>, u32, u32)> {
let img = decode_data_uri(uri)?; let img = decode_data_uri(uri)?;
Ok(preprocess(&img, profile)) preprocess(&img, profile)
}
/// Resized `(h, w)` for a data-URI image **without** running the pixel
/// resize or normalisation — image decode + `smart_resize` only. Lets a
/// caller that just needs the LM token count (e.g. the TP leader expanding
/// the prompt) avoid materialising the full pixel tensor twice.
pub fn resized_dims_for_uri(uri: &str, profile: &PreprocessProfile) -> Result<(u32, u32)> {
let img = decode_data_uri(uri)?;
profile.resized_dims(img.height(), img.width())
} }
#[cfg(test)] #[cfg(test)]
@@ -205,13 +267,17 @@ mod tests {
// decoding so this test isolates the resize+normalise path. // decoding so this test isolates the resize+normalise path.
let img: ImageBuffer<Rgb<u8>, Vec<u8>> = ImageBuffer::from_pixel(2, 2, Rgb([255, 0, 0])); let img: ImageBuffer<Rgb<u8>, Vec<u8>> = ImageBuffer::from_pixel(2, 2, Rgb([255, 0, 0]));
let dyn_img = DynamicImage::ImageRgb8(img); let dyn_img = DynamicImage::ImageRgb8(img);
let out = preprocess(&dyn_img, &profile); let (out, h_bar, w_bar) = preprocess(&dyn_img, &profile).expect("preprocess");
assert_eq!(out.len(), profile.pixels_chw()); let h = h_bar as usize;
let w = w_bar as usize;
assert_eq!(out.len(), 3 * h * w);
// Dims are factor-aligned and at least the min-pixel floor.
assert_eq!(h_bar % profile.factor, 0);
assert_eq!(w_bar % profile.factor, 0);
assert!(h * w >= profile.min_pixels as usize);
// After mean=0.5, std=0.5: red channel (255/255=1.0) → (1.0 - 0.5)/0.5 = 1.0 // After mean=0.5, std=0.5: red channel (255/255=1.0) → (1.0 - 0.5)/0.5 = 1.0
// green/blue (0.0) → (0.0 - 0.5)/0.5 = -1.0 // green/blue (0.0) → (0.0 - 0.5)/0.5 = -1.0
let h = profile.target_height as usize;
let w = profile.target_width as usize;
assert!( assert!(
(out[0] - 1.0).abs() < 1e-5, (out[0] - 1.0).abs() < 1e-5,
"R[0] should be 1.0, got {}", "R[0] should be 1.0, got {}",
@@ -229,9 +295,12 @@ mod tests {
#[test] #[test]
fn preprocess_data_uri_end_to_end() { fn preprocess_data_uri_end_to_end() {
let profile = PreprocessProfile::qwen3_6(); let profile = PreprocessProfile::qwen3_6();
let out = preprocess_data_uri(&red_png_uri(), &profile).expect("e2e preprocess"); let (out, h, w) = preprocess_data_uri(&red_png_uri(), &profile).expect("e2e preprocess");
assert_eq!(out.len(), profile.pixels_chw()); assert_eq!(out.len(), 3 * h as usize * w as usize);
assert!(out.iter().all(|v| v.is_finite())); assert!(out.iter().all(|v| v.is_finite()));
// resized_dims_for_uri agrees with the full preprocess.
let (h2, w2) = resized_dims_for_uri(&red_png_uri(), &profile).expect("dims");
assert_eq!((h, w), (h2, w2));
} }
#[test] #[test]
@@ -240,10 +309,10 @@ mod tests {
// 1x1 grayscale = 200 → after conversion to RGB, all three // 1x1 grayscale = 200 → after conversion to RGB, all three
// channels equal 200, normalised → (200/255 - 0.5)/0.5 ≈ 0.569 // channels equal 200, normalised → (200/255 - 0.5)/0.5 ≈ 0.569
let gray = DynamicImage::ImageLuma8(ImageBuffer::from_pixel(1, 1, image::Luma([200]))); let gray = DynamicImage::ImageLuma8(ImageBuffer::from_pixel(1, 1, image::Luma([200])));
let out = preprocess(&gray, &profile); let (out, h_bar, w_bar) = preprocess(&gray, &profile).expect("preprocess");
let expected = ((200.0 / 255.0) - 0.5) / 0.5; let expected = ((200.0 / 255.0) - 0.5) / 0.5;
let h = profile.target_height as usize; let h = h_bar as usize;
let w = profile.target_width as usize; let w = w_bar as usize;
for c in 0..3 { for c in 0..3 {
let v = out[c * h * w]; let v = out[c * h * w];
assert!( assert!(
@@ -252,4 +321,52 @@ mod tests {
); );
} }
} }
#[test]
fn smart_resize_keeps_factor_aligned_square_in_budget() {
    // Regression guard for the old fixed-resolution sweet spot: 448×448 is
    // already a multiple of 32 and lies inside [65536, 1048576], so the
    // policy must hand it back untouched.
    let dims = smart_resize(448, 448, 32, 65_536, 1_048_576).unwrap();
    assert_eq!(dims, (448, 448));
}
#[test]
fn smart_resize_preserves_aspect_and_caps_at_max() {
    // A 3000-tall × 4000-wide landscape source is over budget: it must be
    // shrunk under max_pixels without losing its shape.
    let (out_h, out_w) = smart_resize(3000, 4000, 32, 65_536, 1_048_576).unwrap();
    assert_eq!(out_h % 32, 0);
    assert_eq!(out_w % 32, 0);
    let area = u64::from(out_h) * u64::from(out_w);
    assert!(area <= 1_048_576, "must respect max_pixels");
    assert!(out_w > out_h, "landscape orientation preserved");
    // Source aspect is 4000/3000 ≈ 1.333; factor rounding earns some slack.
    let ar = f64::from(out_w) / f64::from(out_h);
    assert!((ar - 4.0 / 3.0).abs() < 0.15, "aspect ~4:3, got {ar:.3}");
}
#[test]
fn smart_resize_floors_tiny_image_at_min() {
    // A 16×16 thumbnail is below the pixel floor: expect an upscale to a
    // factor-aligned size of at least min_pixels.
    let (out_h, out_w) = smart_resize(16, 16, 32, 65_536, 1_048_576).unwrap();
    assert_eq!(out_h % 32, 0);
    assert_eq!(out_w % 32, 0);
    let area = u64::from(out_h) * u64::from(out_w);
    assert!(area >= 65_536, "must respect min_pixels");
}
#[test]
fn smart_resize_tall_nonsquare_stays_nonsquare() {
    // Portrait input (e.g. a tall screenshot) must come out portrait.
    let (out_h, out_w) = smart_resize(2000, 500, 32, 65_536, 1_048_576).unwrap();
    assert!(out_h > out_w, "portrait orientation preserved");
    assert_eq!(out_h % 32, 0);
    assert_eq!(out_w % 32, 0);
}
#[test]
fn smart_resize_rejects_extreme_aspect() {
    // 1×500 is a 500:1 strip — well past the 200:1 limit, so this must
    // fail and the error must name the limit.
    let msg = format!(
        "{:#}",
        smart_resize(1, 500, 32, 65_536, 1_048_576).unwrap_err()
    );
    assert!(msg.contains("200:1"));
}
} }

View File

@@ -1288,15 +1288,39 @@ impl TpQwen3_5ForCausalLM {
let device = self.device().clone(); let device = self.device().clone();
let image_embeds = self.encode_images_concat(image_pixels)?; let image_embeds = self.encode_images_concat(image_pixels)?;
// Each image's LM grid (lm_gh, lm_gw) = (h/factor, w/factor),
// factor = patch×merge. Recomputed per rank from this rank's own
// pixel tensors — deterministic, so every rank's grids (and hence
// M-RoPE positions) match without crossing the RPC (#14).
let factor = self
.vision
.as_ref()
.map(|v| {
let c = v.config();
c.patch_size * c.spatial_merge_size
})
.ok_or_else(|| {
candle_core::Error::Msg(
"prefill_with_images_chunked: loaded without a vision tower".into(),
)
})?;
let grids: Vec<(usize, usize)> = image_pixels
.iter()
.map(|t| {
let (_, h, w) = t.dims3()?;
Ok::<(usize, usize), candle_core::Error>((h / factor, w / factor))
})
.collect::<candle_core::Result<Vec<_>>>()?;
// Interleaved-M-RoPE 3D position ids for the whole prompt, // Interleaved-M-RoPE 3D position ids for the whole prompt,
// computed once and sliced per chunk so every rank assigns image // computed once and sliced per chunk so every rank assigns image
// tokens their 14×14 grid coordinates (and text after the image // tokens their grid coordinates (and text after an image resumes
// resumes from the compressed counter). `rope_delta` is stored on // from the compressed counter). `rope_delta` is stored on the base
// the base model for the decode that follows this prefill. Every // model for the decode that follows this prefill. Every chunk —
// chunk — text or image — uses the M-RoPE slice, because the image // text or image — uses the M-RoPE slice, because each image shifts
// shifts the positions of the text around it. // the positions of the text around it.
let (text, height, width, delta) = let (text, height, width, delta) =
crate::harness::arch::qwen3_5::rope::get_rope_index(tokens, image_token_id) crate::harness::arch::qwen3_5::rope::get_rope_index(tokens, image_token_id, &grids)
.map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?; .map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
self.base.set_rope_delta(delta); self.base.set_rope_delta(delta);
let full_pos = crate::harness::arch::qwen3_5::rope::mrope_position_tensor( let full_pos = crate::harness::arch::qwen3_5::rope::mrope_position_tensor(

View File

@@ -494,16 +494,13 @@ impl WorkerState {
let device = model.device().clone(); let device = model.device().clone();
// Preprocess each image identically to the leader so the encoded // Preprocess each image identically to the leader so the encoded
// embeddings — and thus the spliced hidden state — match across // embeddings — and thus the spliced hidden state and per-image
// ranks. Fixed 448×448 profile. // grids — match across ranks. Native-aspect `smart_resize` (#14);
// deterministic, so each rank derives the same dims.
let profile = PreprocessProfile::qwen3_6(); let profile = PreprocessProfile::qwen3_6();
let (h, w) = (
profile.target_height as usize,
profile.target_width as usize,
);
let mut pixels: Vec<Tensor> = Vec::with_capacity(image_data_uris.len()); let mut pixels: Vec<Tensor> = Vec::with_capacity(image_data_uris.len());
for (idx, uri) in image_data_uris.iter().enumerate() { for (idx, uri) in image_data_uris.iter().enumerate() {
let px = match preprocess_data_uri(uri, &profile) { let (px, h, w) = match preprocess_data_uri(uri, &profile) {
Ok(p) => p, Ok(p) => p,
Err(e) => { Err(e) => {
return WorkerResponse::Error { return WorkerResponse::Error {
@@ -512,7 +509,7 @@ impl WorkerState {
}; };
} }
}; };
match Tensor::from_vec(px, (3, h, w), &device) { match Tensor::from_vec(px, (3, h as usize, w as usize), &device) {
Ok(t) => pixels.push(t), Ok(t) => pixels.push(t),
Err(e) => { Err(e) => {
return WorkerResponse::Error { return WorkerResponse::Error {