From fa013505d11c0053ab5f4da837941eb2b1b2d9f5 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Thu, 4 Jun 2026 17:21:36 +0300 Subject: [PATCH] fix(neuron): chunked TP-vision prefill + pre-flight VRAM guard MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit agent-0 sent a ~13k-token prompt + image; the TP vision prefill was single-shot, so it tried to materialise activations for all 12,960 positions at once and OOM'd rank 1 mid-forward. Rank 1 died before issuing its row-parallel AllReduce, stranding rank 0 on the collective (it hung holding the pool lock). The text path survives the same size because it chunks the prefill. Chunk the vision prefill the same way: - TpQwen3_5ForCausalLM::prefill_with_images_chunked encodes the image(s) once, then walks the pre-expanded prompt in prefill_chunk_tokens() windows, splicing the patch-embedding rows into whichever chunk(s) carry <|image_pad|> positions (pure-text chunks take the plain forward). Activation is bounded by the chunk, not the prompt. - Every rank runs the identical chunk sequence (chunk_size threaded through GenerateStepWithImages / TpForwardLogitsWithImages / generate_step_with_images), so the per-chunk AllReduces stay paired across ranks with no extra sync — the KV cache accumulates via the growing offset, only the last chunk's logits are kept. Pre-flight guard (validate_vision_prefill): even chunked, a long prompt's KV cache can exhaust VRAM mid-forward, and on TP that hangs the collective. Reject up front with a clean InsufficientVram when the estimated footprint exceeds free VRAM, so a doomed request fails fast instead of hanging the daemon. Heuristic + tunable (NEURON_VISION_PREFILL_MB_PER_1K_TOKENS / _BASE_MB); default permissive so the now-working 12,960-token case still passes. Applied to every vision path (single-GPU + TP); single-GPU vision stays single-shot for now, so the guard is its protection until it's chunked too. Tests: pre-flight guard behaviour; RPC round-trip carries chunk_size. The chunked forward is cuda-gated — CI CUDA type-check validates it. Refs #16 / TP-vision. Operational note: a TP rank OOM still hangs the daemon (needs restart); making a worker failure abort the leader's collective is separate, broader TP hardening. Co-Authored-By: Claude Opus 4.8 (1M context) --- crates/neuron/src/harness/candle.rs | 91 ++++++++++++++++++- .../src/harness/device_worker/dispatch.rs | 10 +- .../neuron/src/harness/device_worker/jobs.rs | 1 + .../neuron/src/harness/device_worker/mod.rs | 3 + crates/neuron/src/harness/tp/mod.rs | 26 ++++-- crates/neuron/src/harness/tp/rpc.rs | 5 + crates/neuron/src/harness/tp/tp_qwen3_5.rs | 79 +++++++++++++--- crates/neuron/src/harness/tp/worker.rs | 46 +++++----- 8 files changed, 209 insertions(+), 52 deletions(-) diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 000e590..613d8d3 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -871,6 +871,45 @@ fn min_free_vram_mb() -> u64 { /// prefill. Called from every chat_completion entry point right after /// the VRAM query. A `prompt_len == 0` is accepted (some clients send /// empty inputs to probe the endpoint); the prefill loop handles it. +/// Rough MiB of VRAM a vision prefill needs per 1000 prompt tokens +/// (accumulating KV cache + per-chunk activation headroom). Tunable; +/// the default is deliberately permissive so the guard rejects only +/// clearly-too-large requests, not ones the chunked prefill handles. +fn vision_prefill_mb_per_1k_tokens() -> u64 { + env_u64("NEURON_VISION_PREFILL_MB_PER_1K_TOKENS", 500) +} + +/// Fixed VRAM overhead (MiB) a vision prefill reserves on top of the +/// per-token estimate — image encode buffers + one chunk's activations. +fn vision_prefill_base_mb() -> u64 { + env_u64("NEURON_VISION_PREFILL_BASE_MB", 2000) +} + +/// Pre-flight check specific to vision prefills. Even with the chunked +/// prefill bounding per-step activation, the accumulating KV cache for +/// a long prompt can exhaust VRAM mid-forward — and on the TP path a +/// mid-forward OOM strands the NCCL collective (one rank dies, the other +/// hangs on the all-reduce, holding the pool lock). Reject up front with +/// a clean `InsufficientVram` when the estimated footprint exceeds free +/// VRAM, so a doomed request fails fast instead of hanging the daemon. +/// +/// Heuristic and tunable (`NEURON_VISION_PREFILL_*`); the default errs +/// permissive. Skipped on the CPU sentinel (`vram_free_mb == 0`). +fn validate_vision_prefill(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> { + if vram_free_mb == 0 { + return Ok(()); + } + let required_mb = vision_prefill_base_mb() + + (prompt_len as u64).saturating_mul(vision_prefill_mb_per_1k_tokens()) / 1000; + if required_mb > vram_free_mb { + return Err(InferenceError::InsufficientVram { + free_mb: vram_free_mb, + required_mb, + }); + } + Ok(()) +} + fn validate_request(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> { let max = max_prompt_tokens(); if prompt_len > max { @@ -1694,6 +1733,12 @@ impl CandleHarness { ); validate_request(prompt_len, vram_free_mb)?; + if vision_route.is_some() { + validate_vision_prefill(prompt_len, vram_free_mb)?; + } + if vision_route.is_some() { + validate_vision_prefill(prompt_len, vram_free_mb)?; + } // Routing: CUDA loads go through the per-device worker // thread (introduced in Phase 1; forward/clear added in @@ -2107,6 +2152,9 @@ impl CandleHarness { } validate_request(prompt_len, vram_free_mb)?; + if vision_route.is_some() { + validate_vision_prefill(prompt_len, vram_free_mb)?; + } // Routing parallel to the non-streaming chat_completion: CUDA // goes through the worker (async task), CPU keeps the @@ -2977,6 +3025,9 @@ impl CandleHarness { ); validate_request(prompt_len, vram_free_mb)?; + if vision_route.is_some() { + validate_vision_prefill(prompt_len, vram_free_mb)?; + } let tp_for_task = Arc::clone(&tp); tokio::spawn( @@ -3023,9 +3074,10 @@ impl CandleHarness { // chunk fans out to every rank with a growing // offset; only the final chunk's logits are kept // for the first sample. - // Vision requests do a single-shot image prefill; - // text requests chunk it. `vision_route` was moved - // into this task from the synchronous setup above. + // Vision requests do a chunked image prefill (encode + // once, splice per chunk); text requests chunk it the + // same way. `vision_route` was moved into this task + // from the synchronous setup above. let prefill_result = match &vision_route { Some((data_uris, image_token_id)) => { pool.generate_step_with_images( @@ -3035,6 +3087,7 @@ impl CandleHarness { 0, *image_token_id, data_uris.clone(), + prefill_chunk_tokens(), ) .await } @@ -3449,6 +3502,9 @@ async fn chat_completion_tp_inner( ); validate_request(prompt_len, vram_free_mb)?; + if vision_route.is_some() { + validate_vision_prefill(prompt_len, vram_free_mb)?; + } // Acquire the pool lock for the duration of the request. After // Phase 3 the leader's TpLeaderModel lives in the device worker @@ -3492,8 +3548,9 @@ async fn chat_completion_tp_inner( // spread across multiple `generate_step` calls with monotonically // growing offsets. let prefill_start = std::time::Instant::now(); - // Vision requests do a single-shot image prefill (every rank encodes - // + splices its replicated tower); text requests chunk the prefill. + // Vision requests do a chunked image prefill (every rank encodes its + // replicated tower once, then splices per chunk); text requests + // chunk the prefill the same way. let logits_vec = match &vision_route { Some((data_uris, image_token_id)) => pool .generate_step_with_images( @@ -3503,6 +3560,7 @@ async fn chat_completion_tp_inner( 0, *image_token_id, data_uris.clone(), + prefill_chunk_tokens(), ) .await .map_err(InferenceError::Other)?, @@ -4982,4 +5040,27 @@ mod tests { .unwrap(); assert!(request_has_images(&with_image)); } + + /// The vision pre-flight guard rejects a prefill whose estimated + /// footprint exceeds free VRAM (so a doomed request fails clean + /// instead of OOM-hanging the TP collective), passes one that fits, + /// and is skipped on the CPU sentinel. + #[test] + fn vision_prefill_guard_behaviour() { + // CPU sentinel (vram_free_mb == 0) is always allowed. + assert!(validate_vision_prefill(10_000_000, 0).is_ok()); + + // A clearly-oversized prompt against tiny free VRAM is rejected + // for any non-degenerate config (default: 2000 base + 500/1k). + assert!(matches!( + validate_vision_prefill(10_000_000, 50), + Err(InferenceError::InsufficientVram { .. }) + )); + + // With defaults, the agent-0-sized 12,960-token prompt that + // OOM'd single-shot fits the estimate at ~12 GB free (2000 + + // 12960*500/1000 = 8480 MiB) — the chunked prefill handles it, + // so the guard must NOT reject it. + assert!(validate_vision_prefill(12_960, 12_445).is_ok()); + } } diff --git a/crates/neuron/src/harness/device_worker/dispatch.rs b/crates/neuron/src/harness/device_worker/dispatch.rs index f1b1586..62c60c2 100644 --- a/crates/neuron/src/harness/device_worker/dispatch.rs +++ b/crates/neuron/src/harness/device_worker/dispatch.rs @@ -269,6 +269,7 @@ pub(crate) fn run(device_index: u32, rx: Receiver, poisoned: Arc { let result = tp_forward_logits_with_images( @@ -278,6 +279,7 @@ pub(crate) fn run(device_index: u32, rx: Receiver, poisoned: Arc anyhow::Result> { use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri}; use candle_core::{DType, Tensor}; @@ -792,8 +795,6 @@ fn tp_forward_logits_with_images( pixels.push(t); } - let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?; - let model = state.tp_models.get_mut(&handle).ok_or_else(|| { anyhow::anyhow!( "TpForwardLogitsWithImages: no model for handle {}", @@ -801,7 +802,10 @@ fn tp_forward_logits_with_images( ) })?; - let logits = model.forward_with_images(&input, offset, &pixels, image_token_id)?; + // Chunked prefill (encode once, splice per chunk) — bounded + // activation, in lockstep with the subprocess ranks. + let logits = + model.prefill_with_images_chunked(tokens, offset, &pixels, image_token_id, chunk_size)?; let logits = logits.squeeze(0)?.squeeze(0)?; let logits = logits.to_dtype(DType::F32)?.flatten_all()?; let values = logits.to_vec1::()?; diff --git a/crates/neuron/src/harness/device_worker/jobs.rs b/crates/neuron/src/harness/device_worker/jobs.rs index 534dabb..fc3587a 100644 --- a/crates/neuron/src/harness/device_worker/jobs.rs +++ b/crates/neuron/src/harness/device_worker/jobs.rs @@ -246,6 +246,7 @@ pub enum Job { offset: usize, image_token_id: u32, image_data_uris: Vec, + chunk_size: usize, reply: oneshot::Sender>>, }, /// Tell the worker to break its dispatch loop and exit. Any jobs diff --git a/crates/neuron/src/harness/device_worker/mod.rs b/crates/neuron/src/harness/device_worker/mod.rs index cab5bb6..7305787 100644 --- a/crates/neuron/src/harness/device_worker/mod.rs +++ b/crates/neuron/src/harness/device_worker/mod.rs @@ -579,6 +579,7 @@ impl DeviceWorkerHandle { /// matching `GenerateStepWithImages` out to subprocess ranks so the /// row-parallel collectives complete. #[cfg(feature = "cuda")] + #[allow(clippy::too_many_arguments)] pub async fn tp_forward_logits_with_images( &self, handle: TpHandle, @@ -586,6 +587,7 @@ impl DeviceWorkerHandle { offset: usize, image_token_id: u32, image_data_uris: Vec, + chunk_size: usize, ) -> Result, WorkerError> { if self.poisoned.load(Ordering::Acquire) { return Err(WorkerError::Poisoned { @@ -600,6 +602,7 @@ impl DeviceWorkerHandle { offset, image_token_id, image_data_uris, + chunk_size, reply: reply_tx, }) .map_err(|_| WorkerError::Gone { diff --git a/crates/neuron/src/harness/tp/mod.rs b/crates/neuron/src/harness/tp/mod.rs index 8d6802d..be83208 100644 --- a/crates/neuron/src/harness/tp/mod.rs +++ b/crates/neuron/src/harness/tp/mod.rs @@ -62,21 +62,26 @@ impl TpLeaderModel { } } - /// Image-bearing forward on rank 0. Only the vision-capable + /// Chunked image prefill on rank 0. Only the vision-capable /// `qwen3_5` arch supports it; the dense `qwen3` arch has no tower. - pub fn forward_with_images( + pub fn prefill_with_images_chunked( &mut self, - input: &candle_core::Tensor, - offset: usize, + tokens: &[u32], + base_offset: usize, image_pixels: &[candle_core::Tensor], image_token_id: u32, + chunk_size: usize, ) -> candle_core::Result { match self { - TpLeaderModel::Qwen3_5(m) => { - m.forward_with_images(input, offset, image_pixels, image_token_id) - } + TpLeaderModel::Qwen3_5(m) => m.prefill_with_images_chunked( + tokens, + base_offset, + image_pixels, + image_token_id, + chunk_size, + ), TpLeaderModel::Qwen3(_) => { - candle_core::bail!("forward_with_images: qwen3 (dense) has no vision tower") + candle_core::bail!("prefill_with_images_chunked: qwen3 (dense) has no vision tower") } } } @@ -722,6 +727,7 @@ impl WorkerPool { /// embedding broadcast. Only used for prefill; decode reuses /// `generate_step`. #[cfg(feature = "cuda")] + #[allow(clippy::too_many_arguments)] pub async fn generate_step_with_images( &mut self, model_id: &str, @@ -730,6 +736,7 @@ impl WorkerPool { offset: usize, image_token_id: u32, image_data_uris: Vec, + chunk_size: usize, ) -> Result> { let step_start = std::time::Instant::now(); let tokens_len = tokens.len(); @@ -738,6 +745,7 @@ impl WorkerPool { tokens = tokens_len, offset, images = image_data_uris.len(), + chunk_size, "WorkerPool::generate_step_with_images: fan-out" ); @@ -749,6 +757,7 @@ impl WorkerPool { offset, image_token_id, image_data_uris: image_data_uris.clone(), + chunk_size, }) .await?; } @@ -766,6 +775,7 @@ impl WorkerPool { offset, image_token_id, image_data_uris, + chunk_size, ) .await; let leader_ok = leader_result.is_ok(); diff --git a/crates/neuron/src/harness/tp/rpc.rs b/crates/neuron/src/harness/tp/rpc.rs index dd2f243..42ff707 100644 --- a/crates/neuron/src/harness/tp/rpc.rs +++ b/crates/neuron/src/harness/tp/rpc.rs @@ -109,6 +109,10 @@ pub enum WorkerRequest { /// image in prompt order. Each rank decodes + preprocesses these /// identically; tens of KB each, so cheap over the stdin pipe. image_data_uris: Vec, + /// Prefill chunk size (tokens). Sent explicitly so every rank + /// walks the prompt in identical windows and the per-chunk + /// row-parallel collectives stay paired across ranks. + chunk_size: usize, }, /// Reset the KV cache for this model on this rank. Sent at the @@ -222,6 +226,7 @@ mod tests { offset: 0, image_token_id: 248056, image_data_uris: vec!["data:image/png;base64,AAA=".into()], + chunk_size: 512, }; let wire = serde_json::to_string(&req).unwrap(); assert!(wire.contains(r#""op":"generate_step_with_images""#)); diff --git a/crates/neuron/src/harness/tp/tp_qwen3_5.rs b/crates/neuron/src/harness/tp/tp_qwen3_5.rs index 19cc1e5..d08e475 100644 --- a/crates/neuron/src/harness/tp/tp_qwen3_5.rs +++ b/crates/neuron/src/harness/tp/tp_qwen3_5.rs @@ -1200,19 +1200,10 @@ impl TpQwen3_5ForCausalLM { /// identical encode → splice → forward and keeps the replicated /// hidden state in lockstep. Returns last-position logits /// `(B, 1, vocab)`, same contract as `forward`. - pub fn forward_with_images( - &mut self, - input: &Tensor, - offset: usize, - image_pixels: &[Tensor], - image_token_id: u32, - ) -> candle_core::Result { - if image_pixels.is_empty() { - candle_core::bail!("forward_with_images: called with zero images"); - } - // Encode each image (immutable borrows of the tower) before the - // mutable forward below; the borrows end as each owned embedding - // is pushed. + /// Encode every preprocessed `(C,H,W)` image once through this + /// rank's replicated tower and concatenate along the patch axis → + /// `(sum_patches, hidden)`. Done once per prefill, not per chunk. + fn encode_images_concat(&self, image_pixels: &[Tensor]) -> candle_core::Result { let mut per_image = Vec::with_capacity(image_pixels.len()); for (idx, img) in image_pixels.iter().enumerate() { let embed = self @@ -1220,8 +1211,66 @@ impl TpQwen3_5ForCausalLM { .map_err(|e| candle_core::Error::Msg(format!("encode image[{idx}]: {e:#}")))?; per_image.push(embed); } - let image_embeds = Tensor::cat(&per_image.iter().collect::>(), 0)?; - self.forward_with_vision(input, offset, &image_embeds, image_token_id) + Tensor::cat(&per_image.iter().collect::>(), 0) + } + + /// Chunked image prefill on one rank. Encodes the image(s) once, + /// then walks the (pre-expanded) prompt in `chunk_size`-token + /// windows — exactly like the text `chunked_prefill_tp` — splicing + /// the patch embeddings into whichever chunk(s) carry `<|image_pad|>` + /// positions. Activation memory is bounded by the chunk, not the + /// full prompt, so a long vision context no longer single-shot-OOMs. + /// + /// Every rank runs the identical chunk sequence (same `tokens.len()` + /// and `chunk_size`), so the row-parallel `AllReduce`s pair up + /// chunk-by-chunk across ranks with no extra synchronisation. The KV + /// cache accumulates across chunks via the growing offset; only the + /// final chunk's last-position logits are returned (intermediate + /// chunks just populate the cache, same as the text path). + pub fn prefill_with_images_chunked( + &mut self, + tokens: &[u32], + base_offset: usize, + image_pixels: &[Tensor], + image_token_id: u32, + chunk_size: usize, + ) -> candle_core::Result { + if image_pixels.is_empty() { + candle_core::bail!("prefill_with_images_chunked: called with zero images"); + } + if tokens.is_empty() { + candle_core::bail!("prefill_with_images_chunked: empty prompt"); + } + let chunk_size = chunk_size.max(1); + let device = self.device().clone(); + let image_embeds = self.encode_images_concat(image_pixels)?; + + let mut last_logits: Option = None; + // Rows of `image_embeds` already spliced by earlier chunks. The + // `<|image_pad|>` run is contiguous, so chunks consume embedding + // rows in order. + let mut img_off = 0usize; + let mut start = 0usize; + while start < tokens.len() { + let end = (start + chunk_size).min(tokens.len()); + let chunk = &tokens[start..end]; + let input = Tensor::new(chunk, &device)?.unsqueeze(0)?; + let n_here = chunk.iter().filter(|&&t| t == image_token_id).count(); + let logits = if n_here == 0 { + // Pure-text chunk — same forward the text prefill runs. + self.forward(&input, base_offset + start)? + } else { + // Splice the next `n_here` patch rows at this chunk's + // local image-pad positions. + let rows = image_embeds.narrow(0, img_off, n_here)?; + img_off += n_here; + self.forward_with_vision(&input, base_offset + start, &rows, image_token_id)? + }; + last_logits = Some(logits); + start = end; + } + last_logits + .ok_or_else(|| candle_core::Error::Msg("prefill_with_images_chunked: no chunks".into())) } pub fn clear_kv_cache(&mut self) { diff --git a/crates/neuron/src/harness/tp/worker.rs b/crates/neuron/src/harness/tp/worker.rs index f2a8077..7dd34a1 100644 --- a/crates/neuron/src/harness/tp/worker.rs +++ b/crates/neuron/src/harness/tp/worker.rs @@ -47,24 +47,30 @@ impl WorkerModel { } } - /// Image-bearing forward on this rank. Only the vision-capable + /// Chunked image prefill on this rank. Only the vision-capable /// `qwen3_5` arch has a replicated tower; the dense `qwen3` arch /// errors. The returned logits are discarded by the caller (the /// leader samples from its own rank-0 copy) — the value is the NCCL - /// collectives the forward issues. - fn forward_with_images( + /// collectives the forward issues, chunk by chunk in lockstep with + /// the leader. + fn prefill_with_images_chunked( &mut self, - input: &candle_core::Tensor, - offset: usize, + tokens: &[u32], + base_offset: usize, image_pixels: &[candle_core::Tensor], image_token_id: u32, + chunk_size: usize, ) -> candle_core::Result { match self { - WorkerModel::Qwen3_5(m) => { - m.forward_with_images(input, offset, image_pixels, image_token_id) - } + WorkerModel::Qwen3_5(m) => m.prefill_with_images_chunked( + tokens, + base_offset, + image_pixels, + image_token_id, + chunk_size, + ), WorkerModel::Qwen3(_) => { - candle_core::bail!("forward_with_images: qwen3 (dense) has no vision tower") + candle_core::bail!("prefill_with_images_chunked: qwen3 (dense) has no vision tower") } } } @@ -195,12 +201,14 @@ impl WorkerState { offset, image_token_id, image_data_uris, + chunk_size, } => self.handle_generate_step_with_images( &model_id, tokens, offset, image_token_id, image_data_uris, + chunk_size, ), WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id), WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id), @@ -466,6 +474,7 @@ impl WorkerState { offset: usize, image_token_id: u32, image_data_uris: Vec, + chunk_size: usize, ) -> WorkerResponse { use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri}; use candle_core::Tensor; @@ -514,16 +523,6 @@ impl WorkerState { } } - let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) { - Ok(t) => t, - Err(e) => { - return WorkerResponse::Error { - kind: "forward_failed".into(), - message: format!("build input tensor: {e}"), - }; - } - }; - let start = std::time::Instant::now(); tracing::debug!( rank = self.config.rank, @@ -531,10 +530,14 @@ impl WorkerState { tokens = tokens.len(), offset, images = pixels.len(), - "worker GenerateStepWithImages: forward starting" + chunk_size, + "worker GenerateStepWithImages: chunked prefill starting" ); // Drop the logits — the leader samples from its own rank-0 copy. - if let Err(e) = model.forward_with_images(&input, offset, &pixels, image_token_id) { + // The chunked prefill builds its own per-chunk input tensors. + if let Err(e) = + model.prefill_with_images_chunked(&tokens, offset, &pixels, image_token_id, chunk_size) + { tracing::warn!( rank = self.config.rank, model = %model_id, @@ -564,6 +567,7 @@ impl WorkerState { _offset: usize, _image_token_id: u32, _image_data_uris: Vec, + _chunk_size: usize, ) -> WorkerResponse { WorkerResponse::Error { kind: "cuda_feature_not_enabled".into(),