fix(neuron): chunked TP-vision prefill + pre-flight VRAM guard
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 29s
build-prerelease / Build cortex binary (push) Successful in 4m26s
build-prerelease / Package cortex RPM (push) Successful in 1m18s
build-prerelease / Build neuron-blackwell (push) Successful in 6m6s
build-prerelease / Build neuron-ampere (push) Successful in 8m30s
CI / Format (push) Successful in 38s
CI / CUDA type-check (push) Successful in 47s
CI / Clippy (push) Successful in 2m36s
build-prerelease / Build neuron-ada (push) Successful in 5m19s
CI / Test (push) Successful in 6m3s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m1s
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m32s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m47s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 59s
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 29s
build-prerelease / Build cortex binary (push) Successful in 4m26s
build-prerelease / Package cortex RPM (push) Successful in 1m18s
build-prerelease / Build neuron-blackwell (push) Successful in 6m6s
build-prerelease / Build neuron-ampere (push) Successful in 8m30s
CI / Format (push) Successful in 38s
CI / CUDA type-check (push) Successful in 47s
CI / Clippy (push) Successful in 2m36s
build-prerelease / Build neuron-ada (push) Successful in 5m19s
CI / Test (push) Successful in 6m3s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m1s
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m32s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m47s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 59s
agent-0 sent a ~13k-token prompt + image; the TP vision prefill was single-shot, so it tried to materialise activations for all 12,960 positions at once and OOM'd rank 1 mid-forward. Rank 1 died before issuing its row-parallel AllReduce, stranding rank 0 on the collective (it hung holding the pool lock). The text path survives the same size because it chunks the prefill. Chunk the vision prefill the same way: - TpQwen3_5ForCausalLM::prefill_with_images_chunked encodes the image(s) once, then walks the pre-expanded prompt in prefill_chunk_tokens() windows, splicing the patch-embedding rows into whichever chunk(s) carry <|image_pad|> positions (pure-text chunks take the plain forward). Activation is bounded by the chunk, not the prompt. - Every rank runs the identical chunk sequence (chunk_size threaded through GenerateStepWithImages / TpForwardLogitsWithImages / generate_step_with_images), so the per-chunk AllReduces stay paired across ranks with no extra sync — the KV cache accumulates via the growing offset, only the last chunk's logits are kept. Pre-flight guard (validate_vision_prefill): even chunked, a long prompt's KV cache can exhaust VRAM mid-forward, and on TP that hangs the collective. Reject up front with a clean InsufficientVram when the estimated footprint exceeds free VRAM, so a doomed request fails fast instead of hanging the daemon. Heuristic + tunable (NEURON_VISION_PREFILL_MB_PER_1K_TOKENS / _BASE_MB); default permissive so the now-working 12,960-token case still passes. Applied to every vision path (single-GPU + TP); single-GPU vision stays single-shot for now, so the guard is its protection until it's chunked too. Tests: pre-flight guard behaviour; RPC round-trip carries chunk_size. The chunked forward is cuda-gated — CI CUDA type-check validates it. Refs #16 / TP-vision. Operational note: a TP rank OOM still hangs the daemon (needs restart); making a worker failure abort the leader's collective is separate, broader TP hardening. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -871,6 +871,45 @@ fn min_free_vram_mb() -> u64 {
|
||||
/// prefill. Called from every chat_completion entry point right after
|
||||
/// the VRAM query. A `prompt_len == 0` is accepted (some clients send
|
||||
/// empty inputs to probe the endpoint); the prefill loop handles it.
|
||||
/// Rough MiB of VRAM a vision prefill needs per 1000 prompt tokens
|
||||
/// (accumulating KV cache + per-chunk activation headroom). Tunable;
|
||||
/// the default is deliberately permissive so the guard rejects only
|
||||
/// clearly-too-large requests, not ones the chunked prefill handles.
|
||||
fn vision_prefill_mb_per_1k_tokens() -> u64 {
|
||||
env_u64("NEURON_VISION_PREFILL_MB_PER_1K_TOKENS", 500)
|
||||
}
|
||||
|
||||
/// Fixed VRAM overhead (MiB) a vision prefill reserves on top of the
|
||||
/// per-token estimate — image encode buffers + one chunk's activations.
|
||||
fn vision_prefill_base_mb() -> u64 {
|
||||
env_u64("NEURON_VISION_PREFILL_BASE_MB", 2000)
|
||||
}
|
||||
|
||||
/// Pre-flight check specific to vision prefills. Even with the chunked
|
||||
/// prefill bounding per-step activation, the accumulating KV cache for
|
||||
/// a long prompt can exhaust VRAM mid-forward — and on the TP path a
|
||||
/// mid-forward OOM strands the NCCL collective (one rank dies, the other
|
||||
/// hangs on the all-reduce, holding the pool lock). Reject up front with
|
||||
/// a clean `InsufficientVram` when the estimated footprint exceeds free
|
||||
/// VRAM, so a doomed request fails fast instead of hanging the daemon.
|
||||
///
|
||||
/// Heuristic and tunable (`NEURON_VISION_PREFILL_*`); the default errs
|
||||
/// permissive. Skipped on the CPU sentinel (`vram_free_mb == 0`).
|
||||
fn validate_vision_prefill(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> {
|
||||
if vram_free_mb == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let required_mb = vision_prefill_base_mb()
|
||||
+ (prompt_len as u64).saturating_mul(vision_prefill_mb_per_1k_tokens()) / 1000;
|
||||
if required_mb > vram_free_mb {
|
||||
return Err(InferenceError::InsufficientVram {
|
||||
free_mb: vram_free_mb,
|
||||
required_mb,
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_request(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> {
|
||||
let max = max_prompt_tokens();
|
||||
if prompt_len > max {
|
||||
@@ -1694,6 +1733,12 @@ impl CandleHarness {
|
||||
);
|
||||
|
||||
validate_request(prompt_len, vram_free_mb)?;
|
||||
if vision_route.is_some() {
|
||||
validate_vision_prefill(prompt_len, vram_free_mb)?;
|
||||
}
|
||||
if vision_route.is_some() {
|
||||
validate_vision_prefill(prompt_len, vram_free_mb)?;
|
||||
}
|
||||
|
||||
// Routing: CUDA loads go through the per-device worker
|
||||
// thread (introduced in Phase 1; forward/clear added in
|
||||
@@ -2107,6 +2152,9 @@ impl CandleHarness {
|
||||
}
|
||||
|
||||
validate_request(prompt_len, vram_free_mb)?;
|
||||
if vision_route.is_some() {
|
||||
validate_vision_prefill(prompt_len, vram_free_mb)?;
|
||||
}
|
||||
|
||||
// Routing parallel to the non-streaming chat_completion: CUDA
|
||||
// goes through the worker (async task), CPU keeps the
|
||||
@@ -2977,6 +3025,9 @@ impl CandleHarness {
|
||||
);
|
||||
|
||||
validate_request(prompt_len, vram_free_mb)?;
|
||||
if vision_route.is_some() {
|
||||
validate_vision_prefill(prompt_len, vram_free_mb)?;
|
||||
}
|
||||
|
||||
let tp_for_task = Arc::clone(&tp);
|
||||
tokio::spawn(
|
||||
@@ -3023,9 +3074,10 @@ impl CandleHarness {
|
||||
// chunk fans out to every rank with a growing
|
||||
// offset; only the final chunk's logits are kept
|
||||
// for the first sample.
|
||||
// Vision requests do a single-shot image prefill;
|
||||
// text requests chunk it. `vision_route` was moved
|
||||
// into this task from the synchronous setup above.
|
||||
// Vision requests do a chunked image prefill (encode
|
||||
// once, splice per chunk); text requests chunk it the
|
||||
// same way. `vision_route` was moved into this task
|
||||
// from the synchronous setup above.
|
||||
let prefill_result = match &vision_route {
|
||||
Some((data_uris, image_token_id)) => {
|
||||
pool.generate_step_with_images(
|
||||
@@ -3035,6 +3087,7 @@ impl CandleHarness {
|
||||
0,
|
||||
*image_token_id,
|
||||
data_uris.clone(),
|
||||
prefill_chunk_tokens(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -3449,6 +3502,9 @@ async fn chat_completion_tp_inner(
|
||||
);
|
||||
|
||||
validate_request(prompt_len, vram_free_mb)?;
|
||||
if vision_route.is_some() {
|
||||
validate_vision_prefill(prompt_len, vram_free_mb)?;
|
||||
}
|
||||
|
||||
// Acquire the pool lock for the duration of the request. After
|
||||
// Phase 3 the leader's TpLeaderModel lives in the device worker
|
||||
@@ -3492,8 +3548,9 @@ async fn chat_completion_tp_inner(
|
||||
// spread across multiple `generate_step` calls with monotonically
|
||||
// growing offsets.
|
||||
let prefill_start = std::time::Instant::now();
|
||||
// Vision requests do a single-shot image prefill (every rank encodes
|
||||
// + splices its replicated tower); text requests chunk the prefill.
|
||||
// Vision requests do a chunked image prefill (every rank encodes its
|
||||
// replicated tower once, then splices per chunk); text requests
|
||||
// chunk the prefill the same way.
|
||||
let logits_vec = match &vision_route {
|
||||
Some((data_uris, image_token_id)) => pool
|
||||
.generate_step_with_images(
|
||||
@@ -3503,6 +3560,7 @@ async fn chat_completion_tp_inner(
|
||||
0,
|
||||
*image_token_id,
|
||||
data_uris.clone(),
|
||||
prefill_chunk_tokens(),
|
||||
)
|
||||
.await
|
||||
.map_err(InferenceError::Other)?,
|
||||
@@ -4982,4 +5040,27 @@ mod tests {
|
||||
.unwrap();
|
||||
assert!(request_has_images(&with_image));
|
||||
}
|
||||
|
||||
/// The vision pre-flight guard rejects a prefill whose estimated
|
||||
/// footprint exceeds free VRAM (so a doomed request fails clean
|
||||
/// instead of OOM-hanging the TP collective), passes one that fits,
|
||||
/// and is skipped on the CPU sentinel.
|
||||
#[test]
|
||||
fn vision_prefill_guard_behaviour() {
|
||||
// CPU sentinel (vram_free_mb == 0) is always allowed.
|
||||
assert!(validate_vision_prefill(10_000_000, 0).is_ok());
|
||||
|
||||
// A clearly-oversized prompt against tiny free VRAM is rejected
|
||||
// for any non-degenerate config (default: 2000 base + 500/1k).
|
||||
assert!(matches!(
|
||||
validate_vision_prefill(10_000_000, 50),
|
||||
Err(InferenceError::InsufficientVram { .. })
|
||||
));
|
||||
|
||||
// With defaults, the agent-0-sized 12,960-token prompt that
|
||||
// OOM'd single-shot fits the estimate at ~12 GB free (2000 +
|
||||
// 12960*500/1000 = 8480 MiB) — the chunked prefill handles it,
|
||||
// so the guard must NOT reject it.
|
||||
assert!(validate_vision_prefill(12_960, 12_445).is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -269,6 +269,7 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
||||
offset,
|
||||
image_token_id,
|
||||
image_data_uris,
|
||||
chunk_size,
|
||||
reply,
|
||||
} => {
|
||||
let result = tp_forward_logits_with_images(
|
||||
@@ -278,6 +279,7 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
||||
offset,
|
||||
image_token_id,
|
||||
&image_data_uris,
|
||||
chunk_size,
|
||||
);
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
@@ -768,6 +770,7 @@ fn tp_forward_logits_with_images(
|
||||
offset: usize,
|
||||
image_token_id: u32,
|
||||
image_data_uris: &[String],
|
||||
chunk_size: usize,
|
||||
) -> anyhow::Result<Vec<f32>> {
|
||||
use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri};
|
||||
use candle_core::{DType, Tensor};
|
||||
@@ -792,8 +795,6 @@ fn tp_forward_logits_with_images(
|
||||
pixels.push(t);
|
||||
}
|
||||
|
||||
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
|
||||
|
||||
let model = state.tp_models.get_mut(&handle).ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"TpForwardLogitsWithImages: no model for handle {}",
|
||||
@@ -801,7 +802,10 @@ fn tp_forward_logits_with_images(
|
||||
)
|
||||
})?;
|
||||
|
||||
let logits = model.forward_with_images(&input, offset, &pixels, image_token_id)?;
|
||||
// Chunked prefill (encode once, splice per chunk) — bounded
|
||||
// activation, in lockstep with the subprocess ranks.
|
||||
let logits =
|
||||
model.prefill_with_images_chunked(tokens, offset, &pixels, image_token_id, chunk_size)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?;
|
||||
let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
|
||||
let values = logits.to_vec1::<f32>()?;
|
||||
|
||||
@@ -246,6 +246,7 @@ pub enum Job {
|
||||
offset: usize,
|
||||
image_token_id: u32,
|
||||
image_data_uris: Vec<String>,
|
||||
chunk_size: usize,
|
||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||
},
|
||||
/// Tell the worker to break its dispatch loop and exit. Any jobs
|
||||
|
||||
@@ -579,6 +579,7 @@ impl DeviceWorkerHandle {
|
||||
/// matching `GenerateStepWithImages` out to subprocess ranks so the
|
||||
/// row-parallel collectives complete.
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn tp_forward_logits_with_images(
|
||||
&self,
|
||||
handle: TpHandle,
|
||||
@@ -586,6 +587,7 @@ impl DeviceWorkerHandle {
|
||||
offset: usize,
|
||||
image_token_id: u32,
|
||||
image_data_uris: Vec<String>,
|
||||
chunk_size: usize,
|
||||
) -> Result<Vec<f32>, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
@@ -600,6 +602,7 @@ impl DeviceWorkerHandle {
|
||||
offset,
|
||||
image_token_id,
|
||||
image_data_uris,
|
||||
chunk_size,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
|
||||
@@ -62,21 +62,26 @@ impl TpLeaderModel {
|
||||
}
|
||||
}
|
||||
|
||||
/// Image-bearing forward on rank 0. Only the vision-capable
|
||||
/// Chunked image prefill on rank 0. Only the vision-capable
|
||||
/// `qwen3_5` arch supports it; the dense `qwen3` arch has no tower.
|
||||
pub fn forward_with_images(
|
||||
pub fn prefill_with_images_chunked(
|
||||
&mut self,
|
||||
input: &candle_core::Tensor,
|
||||
offset: usize,
|
||||
tokens: &[u32],
|
||||
base_offset: usize,
|
||||
image_pixels: &[candle_core::Tensor],
|
||||
image_token_id: u32,
|
||||
chunk_size: usize,
|
||||
) -> candle_core::Result<candle_core::Tensor> {
|
||||
match self {
|
||||
TpLeaderModel::Qwen3_5(m) => {
|
||||
m.forward_with_images(input, offset, image_pixels, image_token_id)
|
||||
}
|
||||
TpLeaderModel::Qwen3_5(m) => m.prefill_with_images_chunked(
|
||||
tokens,
|
||||
base_offset,
|
||||
image_pixels,
|
||||
image_token_id,
|
||||
chunk_size,
|
||||
),
|
||||
TpLeaderModel::Qwen3(_) => {
|
||||
candle_core::bail!("forward_with_images: qwen3 (dense) has no vision tower")
|
||||
candle_core::bail!("prefill_with_images_chunked: qwen3 (dense) has no vision tower")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -722,6 +727,7 @@ impl WorkerPool {
|
||||
/// embedding broadcast. Only used for prefill; decode reuses
|
||||
/// `generate_step`.
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn generate_step_with_images(
|
||||
&mut self,
|
||||
model_id: &str,
|
||||
@@ -730,6 +736,7 @@ impl WorkerPool {
|
||||
offset: usize,
|
||||
image_token_id: u32,
|
||||
image_data_uris: Vec<String>,
|
||||
chunk_size: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
let step_start = std::time::Instant::now();
|
||||
let tokens_len = tokens.len();
|
||||
@@ -738,6 +745,7 @@ impl WorkerPool {
|
||||
tokens = tokens_len,
|
||||
offset,
|
||||
images = image_data_uris.len(),
|
||||
chunk_size,
|
||||
"WorkerPool::generate_step_with_images: fan-out"
|
||||
);
|
||||
|
||||
@@ -749,6 +757,7 @@ impl WorkerPool {
|
||||
offset,
|
||||
image_token_id,
|
||||
image_data_uris: image_data_uris.clone(),
|
||||
chunk_size,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
@@ -766,6 +775,7 @@ impl WorkerPool {
|
||||
offset,
|
||||
image_token_id,
|
||||
image_data_uris,
|
||||
chunk_size,
|
||||
)
|
||||
.await;
|
||||
let leader_ok = leader_result.is_ok();
|
||||
|
||||
@@ -109,6 +109,10 @@ pub enum WorkerRequest {
|
||||
/// image in prompt order. Each rank decodes + preprocesses these
|
||||
/// identically; tens of KB each, so cheap over the stdin pipe.
|
||||
image_data_uris: Vec<String>,
|
||||
/// Prefill chunk size (tokens). Sent explicitly so every rank
|
||||
/// walks the prompt in identical windows and the per-chunk
|
||||
/// row-parallel collectives stay paired across ranks.
|
||||
chunk_size: usize,
|
||||
},
|
||||
|
||||
/// Reset the KV cache for this model on this rank. Sent at the
|
||||
@@ -222,6 +226,7 @@ mod tests {
|
||||
offset: 0,
|
||||
image_token_id: 248056,
|
||||
image_data_uris: vec!["data:image/png;base64,AAA=".into()],
|
||||
chunk_size: 512,
|
||||
};
|
||||
let wire = serde_json::to_string(&req).unwrap();
|
||||
assert!(wire.contains(r#""op":"generate_step_with_images""#));
|
||||
|
||||
@@ -1200,19 +1200,10 @@ impl TpQwen3_5ForCausalLM {
|
||||
/// identical encode → splice → forward and keeps the replicated
|
||||
/// hidden state in lockstep. Returns last-position logits
|
||||
/// `(B, 1, vocab)`, same contract as `forward`.
|
||||
pub fn forward_with_images(
|
||||
&mut self,
|
||||
input: &Tensor,
|
||||
offset: usize,
|
||||
image_pixels: &[Tensor],
|
||||
image_token_id: u32,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
if image_pixels.is_empty() {
|
||||
candle_core::bail!("forward_with_images: called with zero images");
|
||||
}
|
||||
// Encode each image (immutable borrows of the tower) before the
|
||||
// mutable forward below; the borrows end as each owned embedding
|
||||
// is pushed.
|
||||
/// Encode every preprocessed `(C,H,W)` image once through this
|
||||
/// rank's replicated tower and concatenate along the patch axis →
|
||||
/// `(sum_patches, hidden)`. Done once per prefill, not per chunk.
|
||||
fn encode_images_concat(&self, image_pixels: &[Tensor]) -> candle_core::Result<Tensor> {
|
||||
let mut per_image = Vec::with_capacity(image_pixels.len());
|
||||
for (idx, img) in image_pixels.iter().enumerate() {
|
||||
let embed = self
|
||||
@@ -1220,8 +1211,66 @@ impl TpQwen3_5ForCausalLM {
|
||||
.map_err(|e| candle_core::Error::Msg(format!("encode image[{idx}]: {e:#}")))?;
|
||||
per_image.push(embed);
|
||||
}
|
||||
let image_embeds = Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)?;
|
||||
self.forward_with_vision(input, offset, &image_embeds, image_token_id)
|
||||
Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)
|
||||
}
|
||||
|
||||
/// Chunked image prefill on one rank. Encodes the image(s) once,
|
||||
/// then walks the (pre-expanded) prompt in `chunk_size`-token
|
||||
/// windows — exactly like the text `chunked_prefill_tp` — splicing
|
||||
/// the patch embeddings into whichever chunk(s) carry `<|image_pad|>`
|
||||
/// positions. Activation memory is bounded by the chunk, not the
|
||||
/// full prompt, so a long vision context no longer single-shot-OOMs.
|
||||
///
|
||||
/// Every rank runs the identical chunk sequence (same `tokens.len()`
|
||||
/// and `chunk_size`), so the row-parallel `AllReduce`s pair up
|
||||
/// chunk-by-chunk across ranks with no extra synchronisation. The KV
|
||||
/// cache accumulates across chunks via the growing offset; only the
|
||||
/// final chunk's last-position logits are returned (intermediate
|
||||
/// chunks just populate the cache, same as the text path).
|
||||
pub fn prefill_with_images_chunked(
|
||||
&mut self,
|
||||
tokens: &[u32],
|
||||
base_offset: usize,
|
||||
image_pixels: &[Tensor],
|
||||
image_token_id: u32,
|
||||
chunk_size: usize,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
if image_pixels.is_empty() {
|
||||
candle_core::bail!("prefill_with_images_chunked: called with zero images");
|
||||
}
|
||||
if tokens.is_empty() {
|
||||
candle_core::bail!("prefill_with_images_chunked: empty prompt");
|
||||
}
|
||||
let chunk_size = chunk_size.max(1);
|
||||
let device = self.device().clone();
|
||||
let image_embeds = self.encode_images_concat(image_pixels)?;
|
||||
|
||||
let mut last_logits: Option<Tensor> = None;
|
||||
// Rows of `image_embeds` already spliced by earlier chunks. The
|
||||
// `<|image_pad|>` run is contiguous, so chunks consume embedding
|
||||
// rows in order.
|
||||
let mut img_off = 0usize;
|
||||
let mut start = 0usize;
|
||||
while start < tokens.len() {
|
||||
let end = (start + chunk_size).min(tokens.len());
|
||||
let chunk = &tokens[start..end];
|
||||
let input = Tensor::new(chunk, &device)?.unsqueeze(0)?;
|
||||
let n_here = chunk.iter().filter(|&&t| t == image_token_id).count();
|
||||
let logits = if n_here == 0 {
|
||||
// Pure-text chunk — same forward the text prefill runs.
|
||||
self.forward(&input, base_offset + start)?
|
||||
} else {
|
||||
// Splice the next `n_here` patch rows at this chunk's
|
||||
// local image-pad positions.
|
||||
let rows = image_embeds.narrow(0, img_off, n_here)?;
|
||||
img_off += n_here;
|
||||
self.forward_with_vision(&input, base_offset + start, &rows, image_token_id)?
|
||||
};
|
||||
last_logits = Some(logits);
|
||||
start = end;
|
||||
}
|
||||
last_logits
|
||||
.ok_or_else(|| candle_core::Error::Msg("prefill_with_images_chunked: no chunks".into()))
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
|
||||
@@ -47,24 +47,30 @@ impl WorkerModel {
|
||||
}
|
||||
}
|
||||
|
||||
/// Image-bearing forward on this rank. Only the vision-capable
|
||||
/// Chunked image prefill on this rank. Only the vision-capable
|
||||
/// `qwen3_5` arch has a replicated tower; the dense `qwen3` arch
|
||||
/// errors. The returned logits are discarded by the caller (the
|
||||
/// leader samples from its own rank-0 copy) — the value is the NCCL
|
||||
/// collectives the forward issues.
|
||||
fn forward_with_images(
|
||||
/// collectives the forward issues, chunk by chunk in lockstep with
|
||||
/// the leader.
|
||||
fn prefill_with_images_chunked(
|
||||
&mut self,
|
||||
input: &candle_core::Tensor,
|
||||
offset: usize,
|
||||
tokens: &[u32],
|
||||
base_offset: usize,
|
||||
image_pixels: &[candle_core::Tensor],
|
||||
image_token_id: u32,
|
||||
chunk_size: usize,
|
||||
) -> candle_core::Result<candle_core::Tensor> {
|
||||
match self {
|
||||
WorkerModel::Qwen3_5(m) => {
|
||||
m.forward_with_images(input, offset, image_pixels, image_token_id)
|
||||
}
|
||||
WorkerModel::Qwen3_5(m) => m.prefill_with_images_chunked(
|
||||
tokens,
|
||||
base_offset,
|
||||
image_pixels,
|
||||
image_token_id,
|
||||
chunk_size,
|
||||
),
|
||||
WorkerModel::Qwen3(_) => {
|
||||
candle_core::bail!("forward_with_images: qwen3 (dense) has no vision tower")
|
||||
candle_core::bail!("prefill_with_images_chunked: qwen3 (dense) has no vision tower")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -195,12 +201,14 @@ impl WorkerState {
|
||||
offset,
|
||||
image_token_id,
|
||||
image_data_uris,
|
||||
chunk_size,
|
||||
} => self.handle_generate_step_with_images(
|
||||
&model_id,
|
||||
tokens,
|
||||
offset,
|
||||
image_token_id,
|
||||
image_data_uris,
|
||||
chunk_size,
|
||||
),
|
||||
WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
|
||||
WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
|
||||
@@ -466,6 +474,7 @@ impl WorkerState {
|
||||
offset: usize,
|
||||
image_token_id: u32,
|
||||
image_data_uris: Vec<String>,
|
||||
chunk_size: usize,
|
||||
) -> WorkerResponse {
|
||||
use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri};
|
||||
use candle_core::Tensor;
|
||||
@@ -514,16 +523,6 @@ impl WorkerState {
|
||||
}
|
||||
}
|
||||
|
||||
let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "forward_failed".into(),
|
||||
message: format!("build input tensor: {e}"),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
tracing::debug!(
|
||||
rank = self.config.rank,
|
||||
@@ -531,10 +530,14 @@ impl WorkerState {
|
||||
tokens = tokens.len(),
|
||||
offset,
|
||||
images = pixels.len(),
|
||||
"worker GenerateStepWithImages: forward starting"
|
||||
chunk_size,
|
||||
"worker GenerateStepWithImages: chunked prefill starting"
|
||||
);
|
||||
// Drop the logits — the leader samples from its own rank-0 copy.
|
||||
if let Err(e) = model.forward_with_images(&input, offset, &pixels, image_token_id) {
|
||||
// The chunked prefill builds its own per-chunk input tensors.
|
||||
if let Err(e) =
|
||||
model.prefill_with_images_chunked(&tokens, offset, &pixels, image_token_id, chunk_size)
|
||||
{
|
||||
tracing::warn!(
|
||||
rank = self.config.rank,
|
||||
model = %model_id,
|
||||
@@ -564,6 +567,7 @@ impl WorkerState {
|
||||
_offset: usize,
|
||||
_image_token_id: u32,
|
||||
_image_data_uris: Vec<String>,
|
||||
_chunk_size: usize,
|
||||
) -> WorkerResponse {
|
||||
WorkerResponse::Error {
|
||||
kind: "cuda_feature_not_enabled".into(),
|
||||
|
||||
Reference in New Issue
Block a user