fix(neuron): chunked TP-vision prefill + pre-flight VRAM guard

agent-0 sent a ~13k-token prompt + image; the TP vision prefill was
single-shot, so it tried to materialise activations for all 12,960
positions at once and OOM'd rank 1 mid-forward. Rank 1 died before
issuing its row-parallel AllReduce, stranding rank 0 on the collective
(it hung holding the pool lock). The text path survives the same size
because it chunks the prefill.

Chunk the vision prefill the same way:

- TpQwen3_5ForCausalLM::prefill_with_images_chunked encodes the image(s)
  once, then walks the pre-expanded prompt in prefill_chunk_tokens()
  windows, splicing the patch-embedding rows into whichever chunk(s)
  carry <|image_pad|> positions (pure-text chunks take the plain
  forward). Activation is bounded by the chunk, not the prompt.
- Every rank runs the identical chunk sequence (chunk_size is threaded
  through GenerateStepWithImages / TpForwardLogitsWithImages /
  generate_step_with_images), so the per-chunk AllReduces stay paired
  across ranks with no extra sync. The KV cache accumulates via the
  growing offset, and only the last chunk's logits are kept.
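
The window walk can be sketched without any tensor code. This is a
minimal illustration, not the commit's implementation: plan_chunks and
ChunkPlan are hypothetical names, and the token ids are made up. It only
shows how each chunk's <|image_pad|> count decides which embedding rows
that chunk would splice, assuming the rows are consumed in prompt order.

```rust
/// One prefill window: its token range plus the slice of patch-embedding
/// rows it would consume. Hypothetical type, for illustration only.
#[derive(Debug, PartialEq)]
struct ChunkPlan {
    start: usize,
    end: usize,
    img_off: usize, // first embedding row this chunk splices
    n_img: usize,   // number of image-pad positions in this chunk
}

/// Walk `tokens` in `chunk_size` windows and record, per window, how many
/// image-pad tokens it holds and where its embedding rows begin.
fn plan_chunks(tokens: &[u32], image_token_id: u32, chunk_size: usize) -> Vec<ChunkPlan> {
    let chunk_size = chunk_size.max(1); // a zero chunk size would never advance
    let mut plans = Vec::new();
    let mut img_off = 0usize;
    let mut start = 0usize;
    while start < tokens.len() {
        let end = (start + chunk_size).min(tokens.len());
        let n_img = tokens[start..end]
            .iter()
            .filter(|&&t| t == image_token_id)
            .count();
        plans.push(ChunkPlan { start, end, img_off, n_img });
        img_off += n_img;
        start = end;
    }
    plans
}

fn main() {
    // Made-up prompt: 2 text tokens, 5 image-pad tokens (id 9), 3 text tokens.
    let tokens = [1, 2, 9, 9, 9, 9, 9, 3, 4, 5];
    for plan in plan_chunks(&tokens, 9, 4) {
        println!("{plan:?}");
    }
}
```

A window with n_img == 0 corresponds to the plain text forward; any other
window would narrow the encoded embeddings to rows img_off..img_off+n_img.
Because the plan depends only on the tokens and chunk_size, every rank
that receives the same two inputs derives the same sequence of windows.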

Pre-flight guard (validate_vision_prefill): even chunked, a long
prompt's KV cache can exhaust VRAM mid-forward, and on TP that hangs
the collective. Reject up front with a clean InsufficientVram when the
estimated footprint exceeds free VRAM, so a doomed request fails fast
instead of hanging the daemon. The estimate is a heuristic, tunable via
NEURON_VISION_PREFILL_MB_PER_1K_TOKENS and NEURON_VISION_PREFILL_BASE_MB
(defaults: 500 and 2000 MiB); the defaults are permissive so the
now-working 12,960-token case still passes. Applied to every vision
path (single-GPU + TP); single-GPU vision stays single-shot for now, so
the guard is its protection until it's chunked too.
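
A worked instance of the estimate, as a sketch: estimate_required_mb and
fits are hypothetical stand-alone mirrors of the arithmetic in
validate_vision_prefill, with the default values passed in as arguments
instead of read from the environment. The sketch uses saturating_add for
the final sum, which is an assumption of this example rather than what
the commit does.

```rust
/// Estimated MiB a vision prefill needs: a fixed base plus a per-1k-token
/// term. Hypothetical helper mirroring the guard's arithmetic.
fn estimate_required_mb(prompt_len: usize, base_mb: u64, mb_per_1k: u64) -> u64 {
    let per_token_mb = (prompt_len as u64).saturating_mul(mb_per_1k) / 1000;
    base_mb.saturating_add(per_token_mb)
}

/// True when the estimate fits in free VRAM. A free value of 0 is treated
/// as the CPU sentinel and always passes, as in the guard.
fn fits(prompt_len: usize, vram_free_mb: u64, base_mb: u64, mb_per_1k: u64) -> bool {
    vram_free_mb == 0 || estimate_required_mb(prompt_len, base_mb, mb_per_1k) <= vram_free_mb
}

fn main() {
    // Defaults from the commit: 2000 MiB base, 500 MiB per 1k tokens.
    let required = estimate_required_mb(12_960, 2000, 500);
    println!("12,960 tokens -> {required} MiB"); // 2000 + 6480 = 8480
    println!("fits in 12,445 MiB free: {}", fits(12_960, 12_445, 2000, 500));
    println!("fits in 8,000 MiB free: {}", fits(12_960, 8_000, 2000, 500));
}
```

Under the defaults the 12,960-token prompt is estimated at 8480 MiB, so
it passes with 12,445 MiB free and would be rejected with 8,000 MiB free.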

Tests: pre-flight guard behaviour; RPC round-trip carries chunk_size.
The chunked forward is cuda-gated, so CI covers it through the CUDA
type-check.

Refs #16 / TP-vision. Operational note: a TP rank OOM still hangs the
daemon (needs restart); making a worker failure abort the leader's
collective is separate, broader TP hardening.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-04 17:21:36 +03:00
parent c8bcaabc38
commit fa013505d1
8 changed files with 209 additions and 52 deletions

View File

@@ -871,6 +871,45 @@ fn min_free_vram_mb() -> u64 {
 /// prefill. Called from every chat_completion entry point right after
 /// the VRAM query. A `prompt_len == 0` is accepted (some clients send
 /// empty inputs to probe the endpoint); the prefill loop handles it.
+/// Rough MiB of VRAM a vision prefill needs per 1000 prompt tokens
+/// (accumulating KV cache + per-chunk activation headroom). Tunable;
+/// the default is deliberately permissive so the guard rejects only
+/// clearly-too-large requests, not ones the chunked prefill handles.
+fn vision_prefill_mb_per_1k_tokens() -> u64 {
+    env_u64("NEURON_VISION_PREFILL_MB_PER_1K_TOKENS", 500)
+}
+/// Fixed VRAM overhead (MiB) a vision prefill reserves on top of the
+/// per-token estimate — image encode buffers + one chunk's activations.
+fn vision_prefill_base_mb() -> u64 {
+    env_u64("NEURON_VISION_PREFILL_BASE_MB", 2000)
+}
+/// Pre-flight check specific to vision prefills. Even with the chunked
+/// prefill bounding per-step activation, the accumulating KV cache for
+/// a long prompt can exhaust VRAM mid-forward — and on the TP path a
+/// mid-forward OOM strands the NCCL collective (one rank dies, the other
+/// hangs on the all-reduce, holding the pool lock). Reject up front with
+/// a clean `InsufficientVram` when the estimated footprint exceeds free
+/// VRAM, so a doomed request fails fast instead of hanging the daemon.
+///
+/// Heuristic and tunable (`NEURON_VISION_PREFILL_*`); the default errs
+/// permissive. Skipped on the CPU sentinel (`vram_free_mb == 0`).
+fn validate_vision_prefill(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> {
+    if vram_free_mb == 0 {
+        return Ok(());
+    }
+    let required_mb = vision_prefill_base_mb()
+        + (prompt_len as u64).saturating_mul(vision_prefill_mb_per_1k_tokens()) / 1000;
+    if required_mb > vram_free_mb {
+        return Err(InferenceError::InsufficientVram {
+            free_mb: vram_free_mb,
+            required_mb,
+        });
+    }
+    Ok(())
+}
 fn validate_request(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> {
     let max = max_prompt_tokens();
     if prompt_len > max {
@@ -1694,6 +1733,12 @@ impl CandleHarness {
 );
 validate_request(prompt_len, vram_free_mb)?;
+if vision_route.is_some() {
+    validate_vision_prefill(prompt_len, vram_free_mb)?;
+}
+if vision_route.is_some() {
+    validate_vision_prefill(prompt_len, vram_free_mb)?;
+}
 // Routing: CUDA loads go through the per-device worker
 // thread (introduced in Phase 1; forward/clear added in
@@ -2107,6 +2152,9 @@ impl CandleHarness {
 }
 validate_request(prompt_len, vram_free_mb)?;
+if vision_route.is_some() {
+    validate_vision_prefill(prompt_len, vram_free_mb)?;
+}
 // Routing parallel to the non-streaming chat_completion: CUDA
 // goes through the worker (async task), CPU keeps the
@@ -2977,6 +3025,9 @@ impl CandleHarness {
 );
 validate_request(prompt_len, vram_free_mb)?;
+if vision_route.is_some() {
+    validate_vision_prefill(prompt_len, vram_free_mb)?;
+}
 let tp_for_task = Arc::clone(&tp);
 tokio::spawn(
@@ -3023,9 +3074,10 @@ impl CandleHarness {
 // chunk fans out to every rank with a growing
 // offset; only the final chunk's logits are kept
 // for the first sample.
-// Vision requests do a single-shot image prefill;
-// text requests chunk it. `vision_route` was moved
-// into this task from the synchronous setup above.
+// Vision requests do a chunked image prefill (encode
+// once, splice per chunk); text requests chunk it the
+// same way. `vision_route` was moved into this task
+// from the synchronous setup above.
 let prefill_result = match &vision_route {
     Some((data_uris, image_token_id)) => {
         pool.generate_step_with_images(
@@ -3035,6 +3087,7 @@ impl CandleHarness {
         0,
         *image_token_id,
         data_uris.clone(),
+        prefill_chunk_tokens(),
     )
     .await
 }
@@ -3449,6 +3502,9 @@ async fn chat_completion_tp_inner(
 );
 validate_request(prompt_len, vram_free_mb)?;
+if vision_route.is_some() {
+    validate_vision_prefill(prompt_len, vram_free_mb)?;
+}
 // Acquire the pool lock for the duration of the request. After
 // Phase 3 the leader's TpLeaderModel lives in the device worker
@@ -3492,8 +3548,9 @@ async fn chat_completion_tp_inner(
 // spread across multiple `generate_step` calls with monotonically
 // growing offsets.
 let prefill_start = std::time::Instant::now();
-// Vision requests do a single-shot image prefill (every rank encodes
-// + splices its replicated tower); text requests chunk the prefill.
+// Vision requests do a chunked image prefill (every rank encodes its
+// replicated tower once, then splices per chunk); text requests
+// chunk the prefill the same way.
 let logits_vec = match &vision_route {
     Some((data_uris, image_token_id)) => pool
         .generate_step_with_images(
@@ -3503,6 +3560,7 @@ async fn chat_completion_tp_inner(
     0,
     *image_token_id,
     data_uris.clone(),
+    prefill_chunk_tokens(),
 )
 .await
 .map_err(InferenceError::Other)?,
@@ -4982,4 +5040,27 @@ mod tests {
             .unwrap();
         assert!(request_has_images(&with_image));
     }
+    /// The vision pre-flight guard rejects a prefill whose estimated
+    /// footprint exceeds free VRAM (so a doomed request fails clean
+    /// instead of OOM-hanging the TP collective), passes one that fits,
+    /// and is skipped on the CPU sentinel.
+    #[test]
+    fn vision_prefill_guard_behaviour() {
+        // CPU sentinel (vram_free_mb == 0) is always allowed.
+        assert!(validate_vision_prefill(10_000_000, 0).is_ok());
+        // A clearly-oversized prompt against tiny free VRAM is rejected
+        // for any non-degenerate config (default: 2000 base + 500/1k).
+        assert!(matches!(
+            validate_vision_prefill(10_000_000, 50),
+            Err(InferenceError::InsufficientVram { .. })
+        ));
+        // With defaults, the agent-0-sized 12,960-token prompt that
+        // OOM'd single-shot fits the estimate at ~12 GB free (2000 +
+        // 12960*500/1000 = 8480 MiB) — the chunked prefill handles it,
+        // so the guard must NOT reject it.
+        assert!(validate_vision_prefill(12_960, 12_445).is_ok());
+    }
 }

View File

@@ -269,6 +269,7 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
     offset,
     image_token_id,
     image_data_uris,
+    chunk_size,
     reply,
 } => {
     let result = tp_forward_logits_with_images(
@@ -278,6 +279,7 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
         offset,
         image_token_id,
         &image_data_uris,
+        chunk_size,
     );
     let _ = reply.send(result);
 }
@@ -768,6 +770,7 @@ fn tp_forward_logits_with_images(
     offset: usize,
     image_token_id: u32,
     image_data_uris: &[String],
+    chunk_size: usize,
 ) -> anyhow::Result<Vec<f32>> {
     use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri};
     use candle_core::{DType, Tensor};
@@ -792,8 +795,6 @@ fn tp_forward_logits_with_images(
     pixels.push(t);
 }
-let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
 let model = state.tp_models.get_mut(&handle).ok_or_else(|| {
     anyhow::anyhow!(
         "TpForwardLogitsWithImages: no model for handle {}",
@@ -801,7 +802,10 @@ fn tp_forward_logits_with_images(
     )
 })?;
-let logits = model.forward_with_images(&input, offset, &pixels, image_token_id)?;
+// Chunked prefill (encode once, splice per chunk) — bounded
+// activation, in lockstep with the subprocess ranks.
+let logits =
+    model.prefill_with_images_chunked(tokens, offset, &pixels, image_token_id, chunk_size)?;
 let logits = logits.squeeze(0)?.squeeze(0)?;
 let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
 let values = logits.to_vec1::<f32>()?;

View File

@@ -246,6 +246,7 @@ pub enum Job {
     offset: usize,
     image_token_id: u32,
     image_data_uris: Vec<String>,
+    chunk_size: usize,
     reply: oneshot::Sender<Result<Vec<f32>>>,
 },
 /// Tell the worker to break its dispatch loop and exit. Any jobs

View File

@@ -579,6 +579,7 @@ impl DeviceWorkerHandle {
 /// matching `GenerateStepWithImages` out to subprocess ranks so the
 /// row-parallel collectives complete.
 #[cfg(feature = "cuda")]
+#[allow(clippy::too_many_arguments)]
 pub async fn tp_forward_logits_with_images(
     &self,
     handle: TpHandle,
@@ -586,6 +587,7 @@ impl DeviceWorkerHandle {
     offset: usize,
     image_token_id: u32,
     image_data_uris: Vec<String>,
+    chunk_size: usize,
 ) -> Result<Vec<f32>, WorkerError> {
     if self.poisoned.load(Ordering::Acquire) {
         return Err(WorkerError::Poisoned {
@@ -600,6 +602,7 @@ impl DeviceWorkerHandle {
     offset,
     image_token_id,
     image_data_uris,
+    chunk_size,
     reply: reply_tx,
 })
 .map_err(|_| WorkerError::Gone {

View File

@@ -62,21 +62,26 @@ impl TpLeaderModel {
         }
     }
-    /// Image-bearing forward on rank 0. Only the vision-capable
+    /// Chunked image prefill on rank 0. Only the vision-capable
     /// `qwen3_5` arch supports it; the dense `qwen3` arch has no tower.
-    pub fn forward_with_images(
+    pub fn prefill_with_images_chunked(
         &mut self,
-        input: &candle_core::Tensor,
-        offset: usize,
+        tokens: &[u32],
+        base_offset: usize,
         image_pixels: &[candle_core::Tensor],
         image_token_id: u32,
+        chunk_size: usize,
     ) -> candle_core::Result<candle_core::Tensor> {
         match self {
-            TpLeaderModel::Qwen3_5(m) => {
-                m.forward_with_images(input, offset, image_pixels, image_token_id)
-            }
+            TpLeaderModel::Qwen3_5(m) => m.prefill_with_images_chunked(
+                tokens,
+                base_offset,
+                image_pixels,
+                image_token_id,
+                chunk_size,
+            ),
             TpLeaderModel::Qwen3(_) => {
-                candle_core::bail!("forward_with_images: qwen3 (dense) has no vision tower")
+                candle_core::bail!("prefill_with_images_chunked: qwen3 (dense) has no vision tower")
             }
         }
     }
@@ -722,6 +727,7 @@ impl WorkerPool {
 /// embedding broadcast. Only used for prefill; decode reuses
 /// `generate_step`.
 #[cfg(feature = "cuda")]
+#[allow(clippy::too_many_arguments)]
 pub async fn generate_step_with_images(
     &mut self,
     model_id: &str,
@@ -730,6 +736,7 @@ impl WorkerPool {
     offset: usize,
     image_token_id: u32,
     image_data_uris: Vec<String>,
+    chunk_size: usize,
 ) -> Result<Vec<f32>> {
     let step_start = std::time::Instant::now();
     let tokens_len = tokens.len();
@@ -738,6 +745,7 @@ impl WorkerPool {
     tokens = tokens_len,
     offset,
     images = image_data_uris.len(),
+    chunk_size,
     "WorkerPool::generate_step_with_images: fan-out"
 );
@@ -749,6 +757,7 @@ impl WorkerPool {
         offset,
         image_token_id,
         image_data_uris: image_data_uris.clone(),
+        chunk_size,
     })
     .await?;
 }
@@ -766,6 +775,7 @@ impl WorkerPool {
     offset,
     image_token_id,
     image_data_uris,
+    chunk_size,
 )
 .await;
 let leader_ok = leader_result.is_ok();

View File

@@ -109,6 +109,10 @@ pub enum WorkerRequest {
     /// image in prompt order. Each rank decodes + preprocesses these
     /// identically; tens of KB each, so cheap over the stdin pipe.
     image_data_uris: Vec<String>,
+    /// Prefill chunk size (tokens). Sent explicitly so every rank
+    /// walks the prompt in identical windows and the per-chunk
+    /// row-parallel collectives stay paired across ranks.
+    chunk_size: usize,
 },
 /// Reset the KV cache for this model on this rank. Sent at the
@@ -222,6 +226,7 @@ mod tests {
     offset: 0,
     image_token_id: 248056,
     image_data_uris: vec!["data:image/png;base64,AAA=".into()],
+    chunk_size: 512,
 };
 let wire = serde_json::to_string(&req).unwrap();
 assert!(wire.contains(r#""op":"generate_step_with_images""#));

View File

@@ -1200,19 +1200,10 @@ impl TpQwen3_5ForCausalLM {
     /// identical encode → splice → forward and keeps the replicated
     /// hidden state in lockstep. Returns last-position logits
     /// `(B, 1, vocab)`, same contract as `forward`.
-    pub fn forward_with_images(
-        &mut self,
-        input: &Tensor,
-        offset: usize,
-        image_pixels: &[Tensor],
-        image_token_id: u32,
-    ) -> candle_core::Result<Tensor> {
-        if image_pixels.is_empty() {
-            candle_core::bail!("forward_with_images: called with zero images");
-        }
-        // Encode each image (immutable borrows of the tower) before the
-        // mutable forward below; the borrows end as each owned embedding
-        // is pushed.
+    /// Encode every preprocessed `(C,H,W)` image once through this
+    /// rank's replicated tower and concatenate along the patch axis →
+    /// `(sum_patches, hidden)`. Done once per prefill, not per chunk.
+    fn encode_images_concat(&self, image_pixels: &[Tensor]) -> candle_core::Result<Tensor> {
         let mut per_image = Vec::with_capacity(image_pixels.len());
         for (idx, img) in image_pixels.iter().enumerate() {
             let embed = self
@@ -1220,8 +1211,66 @@ impl TpQwen3_5ForCausalLM {
                 .map_err(|e| candle_core::Error::Msg(format!("encode image[{idx}]: {e:#}")))?;
             per_image.push(embed);
         }
-        let image_embeds = Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)?;
-        self.forward_with_vision(input, offset, &image_embeds, image_token_id)
+        Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)
+    }
+    /// Chunked image prefill on one rank. Encodes the image(s) once,
+    /// then walks the (pre-expanded) prompt in `chunk_size`-token
+    /// windows — exactly like the text `chunked_prefill_tp` — splicing
+    /// the patch embeddings into whichever chunk(s) carry `<|image_pad|>`
+    /// positions. Activation memory is bounded by the chunk, not the
+    /// full prompt, so a long vision context no longer single-shot-OOMs.
+    ///
+    /// Every rank runs the identical chunk sequence (same `tokens.len()`
+    /// and `chunk_size`), so the row-parallel `AllReduce`s pair up
+    /// chunk-by-chunk across ranks with no extra synchronisation. The KV
+    /// cache accumulates across chunks via the growing offset; only the
+    /// final chunk's last-position logits are returned (intermediate
+    /// chunks just populate the cache, same as the text path).
+    pub fn prefill_with_images_chunked(
+        &mut self,
+        tokens: &[u32],
+        base_offset: usize,
+        image_pixels: &[Tensor],
+        image_token_id: u32,
+        chunk_size: usize,
+    ) -> candle_core::Result<Tensor> {
+        if image_pixels.is_empty() {
+            candle_core::bail!("prefill_with_images_chunked: called with zero images");
+        }
+        if tokens.is_empty() {
+            candle_core::bail!("prefill_with_images_chunked: empty prompt");
+        }
+        let chunk_size = chunk_size.max(1);
+        let device = self.device().clone();
+        let image_embeds = self.encode_images_concat(image_pixels)?;
+        let mut last_logits: Option<Tensor> = None;
+        // Rows of `image_embeds` already spliced by earlier chunks. The
+        // `<|image_pad|>` run is contiguous, so chunks consume embedding
+        // rows in order.
+        let mut img_off = 0usize;
+        let mut start = 0usize;
+        while start < tokens.len() {
+            let end = (start + chunk_size).min(tokens.len());
+            let chunk = &tokens[start..end];
+            let input = Tensor::new(chunk, &device)?.unsqueeze(0)?;
+            let n_here = chunk.iter().filter(|&&t| t == image_token_id).count();
+            let logits = if n_here == 0 {
+                // Pure-text chunk — same forward the text prefill runs.
+                self.forward(&input, base_offset + start)?
+            } else {
+                // Splice the next `n_here` patch rows at this chunk's
+                // local image-pad positions.
+                let rows = image_embeds.narrow(0, img_off, n_here)?;
+                img_off += n_here;
+                self.forward_with_vision(&input, base_offset + start, &rows, image_token_id)?
+            };
+            last_logits = Some(logits);
+            start = end;
+        }
+        last_logits
+            .ok_or_else(|| candle_core::Error::Msg("prefill_with_images_chunked: no chunks".into()))
     }
     pub fn clear_kv_cache(&mut self) {

View File

@@ -47,24 +47,30 @@ impl WorkerModel {
         }
     }
-    /// Image-bearing forward on this rank. Only the vision-capable
+    /// Chunked image prefill on this rank. Only the vision-capable
     /// `qwen3_5` arch has a replicated tower; the dense `qwen3` arch
     /// errors. The returned logits are discarded by the caller (the
     /// leader samples from its own rank-0 copy) — the value is the NCCL
-    /// collectives the forward issues.
-    fn forward_with_images(
+    /// collectives the forward issues, chunk by chunk in lockstep with
+    /// the leader.
+    fn prefill_with_images_chunked(
         &mut self,
-        input: &candle_core::Tensor,
-        offset: usize,
+        tokens: &[u32],
+        base_offset: usize,
         image_pixels: &[candle_core::Tensor],
         image_token_id: u32,
+        chunk_size: usize,
     ) -> candle_core::Result<candle_core::Tensor> {
         match self {
-            WorkerModel::Qwen3_5(m) => {
-                m.forward_with_images(input, offset, image_pixels, image_token_id)
-            }
+            WorkerModel::Qwen3_5(m) => m.prefill_with_images_chunked(
+                tokens,
+                base_offset,
+                image_pixels,
+                image_token_id,
+                chunk_size,
+            ),
             WorkerModel::Qwen3(_) => {
-                candle_core::bail!("forward_with_images: qwen3 (dense) has no vision tower")
+                candle_core::bail!("prefill_with_images_chunked: qwen3 (dense) has no vision tower")
             }
         }
     }
@@ -195,12 +201,14 @@ impl WorkerState {
     offset,
     image_token_id,
     image_data_uris,
+    chunk_size,
 } => self.handle_generate_step_with_images(
     &model_id,
     tokens,
     offset,
     image_token_id,
     image_data_uris,
+    chunk_size,
 ),
 WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
 WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
@@ -466,6 +474,7 @@ impl WorkerState {
     offset: usize,
     image_token_id: u32,
     image_data_uris: Vec<String>,
+    chunk_size: usize,
 ) -> WorkerResponse {
     use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri};
     use candle_core::Tensor;
@@ -514,16 +523,6 @@ impl WorkerState {
     }
 }
-let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
-    Ok(t) => t,
-    Err(e) => {
-        return WorkerResponse::Error {
-            kind: "forward_failed".into(),
-            message: format!("build input tensor: {e}"),
-        };
-    }
-};
 let start = std::time::Instant::now();
 tracing::debug!(
     rank = self.config.rank,
@@ -531,10 +530,14 @@ impl WorkerState {
     tokens = tokens.len(),
     offset,
     images = pixels.len(),
-    "worker GenerateStepWithImages: forward starting"
+    chunk_size,
+    "worker GenerateStepWithImages: chunked prefill starting"
 );
 // Drop the logits — the leader samples from its own rank-0 copy.
-if let Err(e) = model.forward_with_images(&input, offset, &pixels, image_token_id) {
+// The chunked prefill builds its own per-chunk input tensors.
+if let Err(e) =
+    model.prefill_with_images_chunked(&tokens, offset, &pixels, image_token_id, chunk_size)
+{
     tracing::warn!(
         rank = self.config.rank,
         model = %model_id,
@@ -564,6 +567,7 @@ impl WorkerState {
     _offset: usize,
     _image_token_id: u32,
     _image_data_uris: Vec<String>,
+    _chunk_size: usize,
 ) -> WorkerResponse {
     WorkerResponse::Error {
         kind: "cuda_feature_not_enabled".into(),