fix(neuron): chunked TP-vision prefill + pre-flight VRAM guard
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 29s
build-prerelease / Build cortex binary (push) Successful in 4m26s
build-prerelease / Package cortex RPM (push) Successful in 1m18s
build-prerelease / Build neuron-blackwell (push) Successful in 6m6s
build-prerelease / Build neuron-ampere (push) Successful in 8m30s
CI / Format (push) Successful in 38s
CI / CUDA type-check (push) Successful in 47s
CI / Clippy (push) Successful in 2m36s
build-prerelease / Build neuron-ada (push) Successful in 5m19s
CI / Test (push) Successful in 6m3s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m1s
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m32s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m47s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 59s

agent-0 sent a ~13k-token prompt + image; the TP vision prefill was
single-shot, so it tried to materialise activations for all 12,960
positions at once and OOM'd rank 1 mid-forward. Rank 1 died before
issuing its row-parallel AllReduce, stranding rank 0 on the collective
(it hung holding the pool lock). The text path survives the same size
because it chunks the prefill.

Chunk the vision prefill the same way:

- TpQwen3_5ForCausalLM::prefill_with_images_chunked encodes the image(s)
  once, then walks the pre-expanded prompt in prefill_chunk_tokens()
  windows, splicing the patch-embedding rows into whichever chunk(s)
  carry <|image_pad|> positions (pure-text chunks take the plain
  forward). Activation is bounded by the chunk, not the prompt.
- Every rank runs the identical chunk sequence (chunk_size threaded
  through GenerateStepWithImages / TpForwardLogitsWithImages /
  generate_step_with_images), so the per-chunk AllReduces stay paired
  across ranks with no extra sync — the KV cache accumulates via the
  growing offset, only the last chunk's logits are kept.

Pre-flight guard (validate_vision_prefill): even chunked, a long
prompt's KV cache can exhaust VRAM mid-forward, and on TP that hangs
the collective. Reject up front with a clean InsufficientVram when the
estimated footprint exceeds free VRAM, so a doomed request fails fast
instead of hanging the daemon. Heuristic + tunable
(NEURON_VISION_PREFILL_MB_PER_1K_TOKENS / _BASE_MB); default permissive
so the now-working 12,960-token case still passes. Applied to every
vision path (single-GPU + TP); single-GPU vision stays single-shot for
now, so the guard is its protection until it's chunked too.

Tests: pre-flight guard behaviour; RPC round-trip carries chunk_size.
The chunked forward is cuda-gated — CI CUDA type-check validates it.

Refs #16 / TP-vision. Operational note: a TP rank OOM still hangs the
daemon (needs restart); making a worker failure abort the leader's
collective is separate, broader TP hardening.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-04 17:21:36 +03:00
parent c8bcaabc38
commit fa013505d1
8 changed files with 209 additions and 52 deletions

View File

@@ -871,6 +871,45 @@ fn min_free_vram_mb() -> u64 {
/// prefill. Called from every chat_completion entry point right after
/// the VRAM query. A `prompt_len == 0` is accepted (some clients send
/// empty inputs to probe the endpoint); the prefill loop handles it.
/// Rough MiB of VRAM a vision prefill needs per 1000 prompt tokens
/// (accumulating KV cache + per-chunk activation headroom). Tunable;
/// the default is deliberately permissive so the guard rejects only
/// clearly-too-large requests, not ones the chunked prefill handles.
fn vision_prefill_mb_per_1k_tokens() -> u64 {
env_u64("NEURON_VISION_PREFILL_MB_PER_1K_TOKENS", 500)
}
/// Fixed VRAM overhead (MiB) a vision prefill reserves on top of the
/// per-token estimate — image encode buffers + one chunk's activations.
fn vision_prefill_base_mb() -> u64 {
env_u64("NEURON_VISION_PREFILL_BASE_MB", 2000)
}
/// Pre-flight check specific to vision prefills. Even with the chunked
/// prefill bounding per-step activation, the accumulating KV cache for
/// a long prompt can exhaust VRAM mid-forward — and on the TP path a
/// mid-forward OOM strands the NCCL collective (one rank dies, the other
/// hangs on the all-reduce, holding the pool lock). Reject up front with
/// a clean `InsufficientVram` when the estimated footprint exceeds free
/// VRAM, so a doomed request fails fast instead of hanging the daemon.
///
/// Heuristic and tunable (`NEURON_VISION_PREFILL_*`); the default errs
/// permissive. Skipped on the CPU sentinel (`vram_free_mb == 0`).
fn validate_vision_prefill(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> {
if vram_free_mb == 0 {
return Ok(());
}
let required_mb = vision_prefill_base_mb()
+ (prompt_len as u64).saturating_mul(vision_prefill_mb_per_1k_tokens()) / 1000;
if required_mb > vram_free_mb {
return Err(InferenceError::InsufficientVram {
free_mb: vram_free_mb,
required_mb,
});
}
Ok(())
}
fn validate_request(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> {
let max = max_prompt_tokens();
if prompt_len > max {
@@ -1694,6 +1733,12 @@ impl CandleHarness {
);
validate_request(prompt_len, vram_free_mb)?;
if vision_route.is_some() {
validate_vision_prefill(prompt_len, vram_free_mb)?;
}
if vision_route.is_some() {
validate_vision_prefill(prompt_len, vram_free_mb)?;
}
// Routing: CUDA loads go through the per-device worker
// thread (introduced in Phase 1; forward/clear added in
@@ -2107,6 +2152,9 @@ impl CandleHarness {
}
validate_request(prompt_len, vram_free_mb)?;
if vision_route.is_some() {
validate_vision_prefill(prompt_len, vram_free_mb)?;
}
// Routing parallel to the non-streaming chat_completion: CUDA
// goes through the worker (async task), CPU keeps the
@@ -2977,6 +3025,9 @@ impl CandleHarness {
);
validate_request(prompt_len, vram_free_mb)?;
if vision_route.is_some() {
validate_vision_prefill(prompt_len, vram_free_mb)?;
}
let tp_for_task = Arc::clone(&tp);
tokio::spawn(
@@ -3023,9 +3074,10 @@ impl CandleHarness {
// chunk fans out to every rank with a growing
// offset; only the final chunk's logits are kept
// for the first sample.
// Vision requests do a single-shot image prefill;
// text requests chunk it. `vision_route` was moved
// into this task from the synchronous setup above.
// Vision requests do a chunked image prefill (encode
// once, splice per chunk); text requests chunk it the
// same way. `vision_route` was moved into this task
// from the synchronous setup above.
let prefill_result = match &vision_route {
Some((data_uris, image_token_id)) => {
pool.generate_step_with_images(
@@ -3035,6 +3087,7 @@ impl CandleHarness {
0,
*image_token_id,
data_uris.clone(),
prefill_chunk_tokens(),
)
.await
}
@@ -3449,6 +3502,9 @@ async fn chat_completion_tp_inner(
);
validate_request(prompt_len, vram_free_mb)?;
if vision_route.is_some() {
validate_vision_prefill(prompt_len, vram_free_mb)?;
}
// Acquire the pool lock for the duration of the request. After
// Phase 3 the leader's TpLeaderModel lives in the device worker
@@ -3492,8 +3548,9 @@ async fn chat_completion_tp_inner(
// spread across multiple `generate_step` calls with monotonically
// growing offsets.
let prefill_start = std::time::Instant::now();
// Vision requests do a single-shot image prefill (every rank encodes
// + splices its replicated tower); text requests chunk the prefill.
// Vision requests do a chunked image prefill (every rank encodes its
// replicated tower once, then splices per chunk); text requests
// chunk the prefill the same way.
let logits_vec = match &vision_route {
Some((data_uris, image_token_id)) => pool
.generate_step_with_images(
@@ -3503,6 +3560,7 @@ async fn chat_completion_tp_inner(
0,
*image_token_id,
data_uris.clone(),
prefill_chunk_tokens(),
)
.await
.map_err(InferenceError::Other)?,
@@ -4982,4 +5040,27 @@ mod tests {
.unwrap();
assert!(request_has_images(&with_image));
}
/// The vision pre-flight guard rejects a prefill whose estimated
/// footprint exceeds free VRAM (so a doomed request fails clean
/// instead of OOM-hanging the TP collective), passes one that fits,
/// and is skipped on the CPU sentinel.
#[test]
fn vision_prefill_guard_behaviour() {
// CPU sentinel (vram_free_mb == 0) is always allowed.
assert!(validate_vision_prefill(10_000_000, 0).is_ok());
// A clearly-oversized prompt against tiny free VRAM is rejected
// for any non-degenerate config (default: 2000 base + 500/1k).
assert!(matches!(
validate_vision_prefill(10_000_000, 50),
Err(InferenceError::InsufficientVram { .. })
));
// With defaults, the agent-0-sized 12,960-token prompt that
// OOM'd single-shot fits the estimate at ~12 GB free (2000 +
// 12960*500/1000 = 8480 MiB) — the chunked prefill handles it,
// so the guard must NOT reject it.
assert!(validate_vision_prefill(12_960, 12_445).is_ok());
}
}