fix(neuron): chunked TP-vision prefill + pre-flight VRAM guard
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 29s
build-prerelease / Build cortex binary (push) Successful in 4m26s
build-prerelease / Package cortex RPM (push) Successful in 1m18s
build-prerelease / Build neuron-blackwell (push) Successful in 6m6s
build-prerelease / Build neuron-ampere (push) Successful in 8m30s
CI / Format (push) Successful in 38s
CI / CUDA type-check (push) Successful in 47s
CI / Clippy (push) Successful in 2m36s
build-prerelease / Build neuron-ada (push) Successful in 5m19s
CI / Test (push) Successful in 6m3s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m1s
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m32s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m47s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 59s

agent-0 sent a ~13k-token prompt + image; the TP vision prefill was
single-shot, so it tried to materialise activations for all 12,960
positions at once and OOM'd rank 1 mid-forward. Rank 1 died before
issuing its row-parallel AllReduce, stranding rank 0 on the collective
(it hung holding the pool lock). The text path survives the same size
because it chunks the prefill.

Chunk the vision prefill the same way:

- TpQwen3_5ForCausalLM::prefill_with_images_chunked encodes the image(s)
  once, then walks the pre-expanded prompt in prefill_chunk_tokens()
  windows, splicing the patch-embedding rows into whichever chunk(s)
  carry <|image_pad|> positions (pure-text chunks take the plain
  forward). Activation is bounded by the chunk, not the prompt.
- Every rank runs the identical chunk sequence (chunk_size threaded
  through GenerateStepWithImages / TpForwardLogitsWithImages /
  generate_step_with_images), so the per-chunk AllReduces stay paired
  across ranks with no extra sync — the KV cache accumulates via the
  growing offset, only the last chunk's logits are kept.

Pre-flight guard (validate_vision_prefill): even chunked, a long
prompt's KV cache can exhaust VRAM mid-forward, and on TP that hangs
the collective. Reject up front with a clean InsufficientVram when the
estimated footprint exceeds free VRAM, so a doomed request fails fast
instead of hanging the daemon. Heuristic + tunable
(NEURON_VISION_PREFILL_MB_PER_1K_TOKENS / _BASE_MB); default permissive
so the now-working 12,960-token case still passes. Applied to every
vision path (single-GPU + TP); single-GPU vision stays single-shot for
now, so the guard is its protection until it's chunked too.

Tests: pre-flight guard behaviour; RPC round-trip carries chunk_size.
The chunked forward is cuda-gated — CI CUDA type-check validates it.

Refs #16 / TP-vision. Operational note: a TP rank OOM still hangs the
daemon (needs restart); making a worker failure abort the leader's
collective is separate, broader TP hardening.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-04 17:21:36 +03:00
parent c8bcaabc38
commit fa013505d1
8 changed files with 209 additions and 52 deletions

View File

@@ -109,6 +109,10 @@ pub enum WorkerRequest {
/// image in prompt order. Each rank decodes + preprocesses these
/// identically; tens of KB each, so cheap over the stdin pipe.
image_data_uris: Vec<String>,
/// Prefill chunk size (tokens). Sent explicitly so every rank
/// walks the prompt in identical windows and the per-chunk
/// row-parallel collectives stay paired across ranks.
chunk_size: usize,
},
/// Reset the KV cache for this model on this rank. Sent at the
@@ -222,6 +226,7 @@ mod tests {
offset: 0,
image_token_id: 248056,
image_data_uris: vec!["data:image/png;base64,AAA=".into()],
chunk_size: 512,
};
let wire = serde_json::to_string(&req).unwrap();
assert!(wire.contains(r#""op":"generate_step_with_images""#));