feat(neuron): dynamic-resolution images via Qwen smart_resize (#14)
Some checks failed
CI / Clippy (push) Waiting to run
CI / Test (push) Waiting to run
CI / CUDA type-check (push) Successful in 32s
CI / Format (push) Successful in 34s
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
Some checks failed
CI / Clippy (push) Waiting to run
CI / Test (push) Waiting to run
CI / CUDA type-check (push) Successful in 32s
CI / Format (push) Successful in 34s
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
Replace the fixed 448×448-square preprocess with native-aspect `smart_resize`, and thread the resulting per-image grid through the LM so spatial structure survives non-square images (documents, screenshots, charts, panoramas, OCR) instead of being squished into a square.

- preprocess.rs: port Qwen `smart_resize` (factor = patch×merge = 32; pixel budget [min,max], default 256²–1024² → 64–1024 LM tokens). `PreprocessProfile` replaces the fixed target dims with `factor`/`min_pixels`/`max_pixels`; `preprocess`/`preprocess_data_uri` now return the resized `(h, w)`; add `resized_dims_for_uri` (decode + resize, no normalize) for the TP leader's token count.
- rope.rs: `compute_mrope_index`/`get_rope_index` take per-image `grids: &[(lm_gh, lm_gw)]` instead of assuming a square `isqrt(run)`. Walk image runs in order, validate `run == gh*gw`, emit row-major positions, resume the shared counter at `base + max(gh,gw)`. Correct for multiple images of differing grids interleaved with text.
- candle.rs: `VisionMeta`/`LoadedModel`/`TpLoadedModel` carry the `image_grid_factor` (patch×merge) instead of the constant 196; all four prompt-build sites compute per-image counts from each image's resized grid (single-GPU from the extracted `ImageInput.h/w`, TP from `resized_dims_for_uri`). `ModelArch` gains `vision_grid_factor`.
- single-GPU (`mod.rs`, `dispatch.rs`) and TP (`tp_qwen3_5.rs::prefill_with_images_chunked`, `dispatch.rs`, `tp/worker.rs`) thread the grids into `get_rope_index`. Each TP rank recomputes grids from its own deterministic preprocess — no rpc.rs change, single source of truth.

The vision tower itself was already grid-general (recent pos-embed interpolation + 2D rotary fix). No patch-count cap: pos-embed is interpolated to any grid; `max_pixels` bounds cost (O(patches²) ViT attention + prefill) instead.

Tests: smart_resize (aspect/cap/floor/reject), `compute_mrope_index` non-square + two-image + mismatch cases, square-grid regression guard.
Non-cuda build + clippy + full workspace tests are green; the TP load/dispatch paths are cuda-gated, so they are covered by the Gitea CUDA type-check job instead.

Operator pixel-budget config + remaining doc cleanup follow in C5.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -210,13 +210,11 @@ pub struct LoadedModel {
|
||||
/// targets and the worker forward uses it to locate splice
|
||||
/// positions in the LM input embeddings.
|
||||
pub image_token_id: Option<u32>,
|
||||
/// LM-side tokens this model's vision tower emits per image at
|
||||
/// the Stage B fixed resolution (448×448 → 196 for Qwen3.6).
|
||||
/// `None` for text-only models. Set at load time so the
|
||||
/// hot path doesn't recompute it per request. Stage B fixed
|
||||
/// resolution → constant; dynamic resolution per #14 makes it
|
||||
/// per-image.
|
||||
pub lm_tokens_per_image: Option<usize>,
|
||||
/// `patch_size × spatial_merge_size` — divides a resized pixel
|
||||
/// dimension into LM-grid units. Per-image LM token count is
|
||||
/// `(h/factor) × (w/factor)` (#14 dynamic resolution). `None` for
|
||||
/// text-only models. Set at load time.
|
||||
pub image_grid_factor: Option<usize>,
|
||||
}
|
||||
|
||||
impl LoadedModel {
|
||||
@@ -288,9 +286,9 @@ pub struct TpLoadedModel {
|
||||
pub has_vision: bool,
|
||||
/// `<|image_pad|>` token id — same as [`LoadedModel::image_token_id`].
|
||||
pub image_token_id: Option<u32>,
|
||||
/// LM-side tokens per image at the fixed 448×448 resolution — same
|
||||
/// as [`LoadedModel::lm_tokens_per_image`].
|
||||
pub lm_tokens_per_image: Option<usize>,
|
||||
/// Pixel→LM-grid divisor — same as
|
||||
/// [`LoadedModel::image_grid_factor`].
|
||||
pub image_grid_factor: Option<usize>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
@@ -394,10 +392,11 @@ impl ModelArch {
|
||||
offset: usize,
|
||||
image_embeds: &Tensor,
|
||||
image_token_id: u32,
|
||||
grids: &[(usize, usize)],
|
||||
) -> Result<Tensor> {
|
||||
let raw = match self {
|
||||
ModelArch::Qwen3_5Dense(m) => {
|
||||
m.forward_with_vision(input, offset, image_embeds, image_token_id)?
|
||||
m.forward_with_vision(input, offset, image_embeds, image_token_id, grids)?
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"forward_with_vision: architecture {} has no vision tower",
|
||||
@@ -407,6 +406,20 @@ impl ModelArch {
|
||||
squeeze_to_vocab(&raw)
|
||||
}
|
||||
|
||||
/// `patch_size × spatial_merge_size` for the loaded vision tower —
|
||||
/// divides a resized pixel dim into LM-grid units (an image of
|
||||
/// resized `(h, w)` yields the LM grid `(h/factor, w/factor)`).
|
||||
/// `None` for architectures/checkpoints without a vision tower.
|
||||
pub fn vision_grid_factor(&self) -> Option<usize> {
|
||||
match self {
|
||||
ModelArch::Qwen3_5Dense(m) => m.vision().map(|v| {
|
||||
let c = v.config();
|
||||
c.patch_size * c.spatial_merge_size
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode a preprocessed image into LM-side token embeddings via
|
||||
/// the loaded vision tower. Stage A5.
|
||||
///
|
||||
@@ -1683,11 +1696,11 @@ impl CandleHarness {
|
||||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
})?;
|
||||
let patches_per_image = loaded
|
||||
.lm_tokens_per_image
|
||||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||||
let factor = loaded.image_grid_factor.ok_or_else(|| {
|
||||
InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
})?;
|
||||
}
|
||||
})?;
|
||||
let profile = super::preprocess::PreprocessProfile::qwen3_6();
|
||||
let images = extract_images_from_request(&request, &profile).map_err(|e| {
|
||||
InferenceError::Other(anyhow::anyhow!("extract_images: {e}"))
|
||||
@@ -1699,7 +1712,12 @@ impl CandleHarness {
|
||||
"request has image content but extractor produced zero images"
|
||||
)));
|
||||
}
|
||||
let per_image_counts: Vec<usize> = vec![patches_per_image; images.len()];
|
||||
// Per-image LM token count from each image's resized grid
|
||||
// (#14 dynamic resolution; was a constant 196).
|
||||
let per_image_counts: Vec<usize> = images
|
||||
.iter()
|
||||
.map(|im| (im.h / factor) * (im.w / factor))
|
||||
.collect();
|
||||
prompt_tokens =
|
||||
expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
|
||||
.map_err(InferenceError::Other)?;
|
||||
@@ -2059,11 +2077,12 @@ impl CandleHarness {
|
||||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
})?;
|
||||
let patches_per_image = loaded.lm_tokens_per_image.ok_or_else(|| {
|
||||
InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
}
|
||||
})?;
|
||||
let factor =
|
||||
loaded
|
||||
.image_grid_factor
|
||||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
})?;
|
||||
let profile = super::preprocess::PreprocessProfile::qwen3_6();
|
||||
let images = extract_images_from_request(&request, &profile)
|
||||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("extract_images: {e}")))?;
|
||||
@@ -2072,7 +2091,11 @@ impl CandleHarness {
|
||||
"request has image content but extractor produced zero images"
|
||||
)));
|
||||
}
|
||||
let per_image_counts: Vec<usize> = vec![patches_per_image; images.len()];
|
||||
// Per-image LM token count from each image's resized grid (#14).
|
||||
let per_image_counts: Vec<usize> = images
|
||||
.iter()
|
||||
.map(|im| (im.h / factor) * (im.w / factor))
|
||||
.collect();
|
||||
prompt_tokens =
|
||||
expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
|
||||
.map_err(InferenceError::Other)?;
|
||||
@@ -2526,7 +2549,7 @@ impl Harness for CandleHarness {
|
||||
chat_template,
|
||||
has_vision: vision_meta.has_vision,
|
||||
image_token_id: vision_meta.image_token_id,
|
||||
lm_tokens_per_image: vision_meta.lm_tokens_per_image,
|
||||
image_grid_factor: vision_meta.image_grid_factor,
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
@@ -2742,7 +2765,7 @@ impl CandleHarness {
|
||||
tracing::info!(
|
||||
model = %spec.model_id,
|
||||
image_token_id = ?vision_meta.image_token_id,
|
||||
lm_tokens_per_image = ?vision_meta.lm_tokens_per_image,
|
||||
image_grid_factor = ?vision_meta.image_grid_factor,
|
||||
"TP load: vision tower present, advertising vision capability"
|
||||
);
|
||||
}
|
||||
@@ -2764,7 +2787,7 @@ impl CandleHarness {
|
||||
chat_template,
|
||||
has_vision: vision_meta.has_vision,
|
||||
image_token_id: vision_meta.image_token_id,
|
||||
lm_tokens_per_image: vision_meta.lm_tokens_per_image,
|
||||
image_grid_factor: vision_meta.image_grid_factor,
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
@@ -2938,18 +2961,32 @@ impl CandleHarness {
|
||||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
})?;
|
||||
let patches_per_image =
|
||||
tp.lm_tokens_per_image
|
||||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
})?;
|
||||
let factor = tp
|
||||
.image_grid_factor
|
||||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
})?;
|
||||
let data_uris = extract_image_data_uris(&request);
|
||||
if data_uris.is_empty() {
|
||||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||||
"request has image content but extractor produced zero data URIs"
|
||||
)));
|
||||
}
|
||||
let per_image_counts: Vec<usize> = vec![patches_per_image; data_uris.len()];
|
||||
// Per-image LM token count from each image's resized grid (#14).
|
||||
// Decode header + smart_resize only; the workers re-derive the
|
||||
// same dims when they preprocess for the replicated tower.
|
||||
let profile = super::preprocess::PreprocessProfile::qwen3_6();
|
||||
let per_image_counts: Vec<usize> = data_uris
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, uri)| {
|
||||
let (h, w) =
|
||||
super::preprocess::resized_dims_for_uri(uri, &profile).map_err(|e| {
|
||||
InferenceError::Other(anyhow::anyhow!("resized_dims image #{i}: {e}"))
|
||||
})?;
|
||||
Ok::<usize, InferenceError>((h as usize / factor) * (w as usize / factor))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
prompt_tokens =
|
||||
expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
|
||||
.map_err(InferenceError::Other)?;
|
||||
@@ -3457,18 +3494,30 @@ async fn chat_completion_tp_inner(
|
||||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
})?;
|
||||
let patches_per_image =
|
||||
tp.lm_tokens_per_image
|
||||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
})?;
|
||||
let factor = tp
|
||||
.image_grid_factor
|
||||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||||
model_id: request.model.clone(),
|
||||
})?;
|
||||
let data_uris = extract_image_data_uris(&request);
|
||||
if data_uris.is_empty() {
|
||||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||||
"request has image content but extractor produced zero data URIs"
|
||||
)));
|
||||
}
|
||||
let per_image_counts: Vec<usize> = vec![patches_per_image; data_uris.len()];
|
||||
// Per-image LM token count from each image's resized grid (#14).
|
||||
let profile = super::preprocess::PreprocessProfile::qwen3_6();
|
||||
let per_image_counts: Vec<usize> = data_uris
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, uri)| {
|
||||
let (h, w) =
|
||||
super::preprocess::resized_dims_for_uri(uri, &profile).map_err(|e| {
|
||||
InferenceError::Other(anyhow::anyhow!("resized_dims image #{i}: {e}"))
|
||||
})?;
|
||||
Ok::<usize, InferenceError>((h as usize / factor) * (w as usize / factor))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
prompt_tokens = expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
|
||||
.map_err(InferenceError::Other)?;
|
||||
Some((data_uris, image_token_id))
|
||||
@@ -3917,10 +3966,12 @@ fn build_prompt_for_request(
|
||||
struct VisionMeta {
|
||||
has_vision: bool,
|
||||
image_token_id: Option<u32>,
|
||||
/// LM-side tokens this model's vision tower emits per image at
|
||||
/// the Stage B fixed `PreprocessProfile::qwen3_6()` resolution
|
||||
/// (448×448). Equal to `(H/patch_size/spatial_merge_size)²`.
|
||||
lm_tokens_per_image: Option<usize>,
|
||||
/// `patch_size × spatial_merge_size` — the divisor that turns a
|
||||
/// resized pixel dimension into an LM-grid dimension. An image of
|
||||
/// resized `(h, w)` emits `(h/factor) × (w/factor)` LM tokens (#14
|
||||
/// dynamic resolution; was a constant 196 at the old fixed 448²).
|
||||
/// `None` for text-only models.
|
||||
image_grid_factor: Option<usize>,
|
||||
}
|
||||
|
||||
impl VisionMeta {
|
||||
@@ -3949,22 +4000,18 @@ impl VisionMeta {
|
||||
.get("image_token_id")
|
||||
.and_then(|x| x.as_u64())
|
||||
.map(|n| n as u32);
|
||||
// Compute LM tokens per image at the Stage B fixed resolution
|
||||
// (PreprocessProfile::qwen3_6() → 448×448). One LM token per
|
||||
// spatial-merge group of patches.
|
||||
let target_h = super::preprocess::PreprocessProfile::qwen3_6().target_height as usize;
|
||||
let target_w = super::preprocess::PreprocessProfile::qwen3_6().target_width as usize;
|
||||
let lm_tokens_per_image = if patch_size > 0 && spatial_merge_size > 0 {
|
||||
let gh = target_h / patch_size / spatial_merge_size;
|
||||
let gw = target_w / patch_size / spatial_merge_size;
|
||||
Some(gh * gw)
|
||||
// The pixel→LM-grid divisor. An image resized to (h, w) emits
|
||||
// (h/factor) × (w/factor) LM tokens — computed per image at
|
||||
// request time now that resolution is dynamic (#14).
|
||||
let image_grid_factor = if patch_size > 0 && spatial_merge_size > 0 {
|
||||
Some(patch_size * spatial_merge_size)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self {
|
||||
has_vision: true,
|
||||
image_token_id,
|
||||
lm_tokens_per_image,
|
||||
image_grid_factor,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4011,13 +4058,13 @@ fn extract_images_from_request(
|
||||
.and_then(|v| v.get("url"))
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("image_url part missing url field"))?;
|
||||
let pixels = super::preprocess::preprocess_data_uri(url, profile)
|
||||
let (pixels, h, w) = super::preprocess::preprocess_data_uri(url, profile)
|
||||
.with_context(|| format!("preprocess image #{}", out.len()))?;
|
||||
out.push(super::device_worker::jobs::ImageInput {
|
||||
pixels,
|
||||
c: 3,
|
||||
h: profile.target_height as usize,
|
||||
w: profile.target_width as usize,
|
||||
h: h as usize,
|
||||
w: w as usize,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user