From 1e1388939205f5bd6a567ebc7d34c0a8fd3491b5 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Wed, 27 May 2026 13:46:54 +0300 Subject: [PATCH] feat(neuron): chunked prefill + VRAM/prompt-length pre-flight checks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Prevents the OOM-during-prefill → poisoned-context → 5-minute-reload cycle observed on beast under agent-zero workloads. Three changes, all keyed off env-driven knobs so an operator can tune without a rebuild: 1. Chunked prefill (NEURON_PREFILL_CHUNK_TOKENS, default 512). The initial forward is split into N-token windows, each with a monotonically growing offset. KV cache accumulates across chunks exactly as it would under one big prefill; only the final chunk's logits are kept for sampling. Activation memory now scales with chunk size instead of prompt length, so a 13 k-token prompt stops holding tens of GB of intermediate activations live at once. Wired into all six prefill call sites: - run_inference / run_inference_streaming (CPU path) - run_inference_via_worker / stream_inference_via_worker (CUDA single-GPU through device worker) - chat_completion_tp_inner / chat_completion_tp_stream (TP via WorkerPool) Three helpers — chunked_prefill_local, chunked_prefill_via_worker, chunked_prefill_tp — own the loop shape so the chunking semantics stay identical across paths. Per-chunk debug log shows progress. 2. Max prompt length (NEURON_MAX_PROMPT_TOKENS, default 16384). Requests above the cap return a structured 400 with `code: prompt_too_long` rather than going through the prefill and discovering the limit by OOMing partway through. New InferenceError::PromptTooLong variant. 3. Minimum free VRAM gate (NEURON_MIN_FREE_VRAM_MB, default 1500). If `vram_free_mb` is below the threshold at request start (e.g. another concurrent request is mid-prefill), reject with a clean 503 + `code: insufficient_vram` rather than starting work that will OOM. New InferenceError::InsufficientVram variant. CPU loads (vram=0 sentinel) skip this check. All three gates fire BEFORE any device work, so a rejected request costs ~one tokenisation pass and never touches the worker thread — poison cascades from rejected work are now impossible. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/neuron/src/api.rs | 50 ++++++ crates/neuron/src/harness/candle.rs | 266 +++++++++++++++++++++++++--- 2 files changed, 294 insertions(+), 22 deletions(-) diff --git a/crates/neuron/src/api.rs b/crates/neuron/src/api.rs index bd72836..f400563 100644 --- a/crates/neuron/src/api.rs +++ b/crates/neuron/src/api.rs @@ -174,6 +174,31 @@ async fn chat_completions( Json(json!({"error": format!("model '{id}' not loaded on this neuron")})), ) .into_response(), + Err(InferenceError::PromptTooLong { prompt_len, max }) => ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": format!("prompt has {prompt_len} tokens but max is {max}"), + "code": "prompt_too_long", + "prompt_len": prompt_len, + "max": max, + })), + ) + .into_response(), + Err(InferenceError::InsufficientVram { + free_mb, + required_mb, + }) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ + "error": format!( + "insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB" + ), + "code": "insufficient_vram", + "free_mb": free_mb, + "required_mb": required_mb, + })), + ) + .into_response(), Err(InferenceError::Other(e)) => ( StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": format!("{e:#}")})), @@ -188,6 +213,31 @@ async fn chat_completions( Json(json!({"error": format!("model '{id}' not loaded on this neuron")})), ) .into_response(), + Err(InferenceError::PromptTooLong { prompt_len, max }) => ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": format!("prompt has {prompt_len} tokens but max is {max}"), + "code": "prompt_too_long", + "prompt_len": prompt_len, + "max": max, + })), + ) + .into_response(), + Err(InferenceError::InsufficientVram { + free_mb, + required_mb, + }) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ + "error": format!( + "insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB" + ), + "code": "insufficient_vram", + "free_mb": free_mb, + "required_mb": required_mb, + })), + ) + .into_response(), Err(InferenceError::Other(e)) => ( StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": format!("{e:#}")})), diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index afbcd27..1320349 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -621,6 +621,78 @@ fn new_req_id() -> String { format!("{:06x}", unix_subsec_nanos() & 0xFFFFFF) } +/// Read a positive `usize` from `name` in the process env, falling back +/// to `default` if unset or unparseable. Used for runtime tuning knobs +/// that we want operators to be able to adjust without a recompile. +fn env_usize(name: &str, default: usize) -> usize { + std::env::var(name) + .ok() + .and_then(|s| s.parse().ok()) + .filter(|v: &usize| *v > 0) + .unwrap_or(default) +} + +/// Same as [`env_usize`] but for `u64`. +fn env_u64(name: &str, default: u64) -> u64 { + std::env::var(name) + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(default) +} + +/// Prefill chunk size in tokens. The initial forward over a long prompt +/// is split into windows of this many tokens, each with a monotonically +/// growing offset, so activation memory is bounded by chunk × layers × +/// hidden instead of prompt × layers × hidden. The default (512) keeps +/// activation peaks under ~1 GiB on a 27B Qwen-class model while +/// keeping the per-step overhead negligible vs. one big prefill. +fn prefill_chunk_tokens() -> usize { + env_usize("NEURON_PREFILL_CHUNK_TOKENS", 512) +} + +/// Maximum allowed prompt length, in tokens. Requests above this are +/// rejected with [`InferenceError::PromptTooLong`] before any device +/// work — this is the explicit upper bound on context size, separate +/// from the model's `max_position_embeddings` (which can be much +/// larger than what fits in VRAM in practice). +fn max_prompt_tokens() -> usize { + env_usize("NEURON_MAX_PROMPT_TOKENS", 16384) +} + +/// Minimum free VRAM (MiB) required to even attempt a prefill. Requests +/// below this are rejected with [`InferenceError::InsufficientVram`] +/// before any device work. Acts as a backstop when concurrent requests +/// have eaten the headroom; intentionally conservative — a request +/// that gets past this can still OOM, but the rejection is a clean 503 +/// rather than a poisoned context. +fn min_free_vram_mb() -> u64 { + env_u64("NEURON_MIN_FREE_VRAM_MB", 1500) +} + +/// Pre-flight check: reject the request if the prompt exceeds the +/// configured max, or if there isn't enough free VRAM to safely start a +/// prefill. Called from every chat_completion entry point right after +/// the VRAM query. A `prompt_len == 0` is accepted (some clients send +/// empty inputs to probe the endpoint); the prefill loop handles it. +fn validate_request(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> { + let max = max_prompt_tokens(); + if prompt_len > max { + return Err(InferenceError::PromptTooLong { prompt_len, max }); + } + // VRAM check is skipped on CPU loads (vram_free_mb == 0 sentinel) + // because the (0, 0) reply from `query_vram` is also what a missing + // worker returns. The CPU path has no per-GPU memory limit anyway — + // host RAM is bounded by the OOM killer, not this check. + let min = min_free_vram_mb(); + if vram_free_mb != 0 && vram_free_mb < min { + return Err(InferenceError::InsufficientVram { + free_mb: vram_free_mb, + required_mb: min, + }); + } + Ok(()) +} + /// Threshold above which `pool.lock().await` blocking is interesting /// enough to warn about. Healthy concurrent requests serialise behind /// the pool in single-digit ms — anything past 2 seconds is either a @@ -687,6 +759,129 @@ fn sample_with_penalty( Ok(logits_processor.sample(&penalised)?) } +/// Chunked prefill against an in-process [`ModelArch`]. Splits +/// `prompt_tokens` into [`prefill_chunk_tokens()`]-sized windows, runs +/// each through `arch.forward(chunk, offset)` with a monotonically +/// growing offset, and returns the last chunk's logits ready for +/// sampling. Bounds activation memory to O(chunk × layers × hidden) +/// instead of O(prompt × layers × hidden); the KV cache grows +/// monotonically so the model sees the full prompt at the final chunk. +fn chunked_prefill_local( + arch: &mut ModelArch, + device: &Device, + prompt_tokens: &[u32], +) -> Result { + let prompt_len = prompt_tokens.len(); + if prompt_len == 0 { + anyhow::bail!("chunked_prefill_local: empty prompt"); + } + let chunk_size = prefill_chunk_tokens(); + let mut offset = 0; + let mut last_logits: Option = None; + while offset < prompt_len { + let end = (offset + chunk_size).min(prompt_len); + let chunk = &prompt_tokens[offset..end]; + let input = Tensor::new(chunk, device)?.unsqueeze(0)?; + let logits = arch.forward(&input, offset)?; + if end == prompt_len { + last_logits = Some(logits); + } + offset = end; + } + last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_local: no chunks produced")) +} + +/// Chunked prefill via the per-device worker. Same shape as +/// [`chunked_prefill_local`] but the forward runs on the worker thread +/// and replies with a CPU-side `Vec` of logits at the final +/// chunk's last position. Tensors never escape the worker. +#[cfg(feature = "cuda")] +async fn chunked_prefill_via_worker( + worker: &super::device_worker::DeviceWorkerHandle, + handle: super::device_worker::ArchHandle, + prompt_tokens: &[u32], +) -> Result> { + let prompt_len = prompt_tokens.len(); + if prompt_len == 0 { + anyhow::bail!("chunked_prefill_via_worker: empty prompt"); + } + let chunk_size = prefill_chunk_tokens(); + let mut offset = 0; + let mut last_logits: Option> = None; + let total_chunks = prompt_len.div_ceil(chunk_size); + let mut chunk_idx = 0_usize; + while offset < prompt_len { + let end = (offset + chunk_size).min(prompt_len); + let chunk = prompt_tokens[offset..end].to_vec(); + let chunk_len = chunk.len(); + let step_start = std::time::Instant::now(); + let logits = worker + .forward_logits(handle, chunk, offset) + .await + .map_err(|e| anyhow::anyhow!("prefill chunk {chunk_idx}/{total_chunks}: {e}"))?; + tracing::debug!( + chunk_idx, + total_chunks, + chunk_len, + offset, + elapsed_ms = step_start.elapsed().as_millis(), + "chunked prefill (worker): chunk done" + ); + if end == prompt_len { + last_logits = Some(logits); + } + offset = end; + chunk_idx += 1; + } + last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_via_worker: no chunks produced")) +} + +/// Chunked prefill via the TP `WorkerPool`. Same shape as +/// [`chunked_prefill_via_worker`] but the forward fans out to every +/// rank via `pool.generate_step`. Returns the leader's CPU-side +/// `Vec` of logits at the final chunk's last position. +#[cfg(feature = "cuda")] +async fn chunked_prefill_tp( + pool: &mut super::tp::WorkerPool, + model_id: &str, + leader_handle: super::device_worker::TpHandle, + prompt_tokens: &[u32], +) -> Result> { + let prompt_len = prompt_tokens.len(); + if prompt_len == 0 { + anyhow::bail!("chunked_prefill_tp: empty prompt"); + } + let chunk_size = prefill_chunk_tokens(); + let mut offset = 0; + let mut last_logits: Option> = None; + let total_chunks = prompt_len.div_ceil(chunk_size); + let mut chunk_idx = 0_usize; + while offset < prompt_len { + let end = (offset + chunk_size).min(prompt_len); + let chunk = prompt_tokens[offset..end].to_vec(); + let chunk_len = chunk.len(); + let step_start = std::time::Instant::now(); + let logits = pool + .generate_step(model_id, leader_handle, chunk, offset) + .await + .map_err(|e| anyhow::anyhow!("TP prefill chunk {chunk_idx}/{total_chunks}: {e}"))?; + tracing::debug!( + chunk_idx, + total_chunks, + chunk_len, + offset, + elapsed_ms = step_start.elapsed().as_millis(), + "chunked prefill (TP): chunk done" + ); + if end == prompt_len { + last_logits = Some(logits); + } + offset = end; + chunk_idx += 1; + } + last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_tp: no chunks produced")) +} + impl CandleHarness { pub fn new(bind_url: String, hf_cache: Option) -> Self { let hf_cache = resolve_hf_cache(hf_cache); @@ -1151,6 +1346,8 @@ impl CandleHarness { "chat_completion: starting" ); + validate_request(prompt_len, vram_free_mb)?; + // Routing: CUDA loads go through the per-device worker // thread (introduced in Phase 1; forward/clear added in // Phase 2). CPU loads keep the existing spawn_blocking @@ -1409,6 +1606,9 @@ impl CandleHarness { "chat_completion (stream): starting" ); } + + validate_request(prompt_len, vram_free_mb)?; + // Routing parallel to the non-streaming chat_completion: CUDA // goes through the worker (async task), CPU keeps the // spawn_blocking + Arc> path. @@ -1972,6 +2172,9 @@ impl CandleHarness { vram_total_mb, "TP chat_completion (stream): starting" ); + + validate_request(prompt_len, vram_free_mb)?; + let tp_for_task = Arc::clone(&tp); tokio::spawn( async move { @@ -2001,10 +2204,17 @@ impl CandleHarness { LogitsProcessor::from_sampling(seed, sampling) }; - // Prefill — every rank embeds the prompt, offset = 0. - let logits_vec = match pool - .generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0) - .await + // Chunked prefill — see `chunked_prefill_tp`. Each + // chunk fans out to every rank with a growing + // offset; only the final chunk's logits are kept + // for the first sample. + let logits_vec = match chunked_prefill_tp( + &mut pool, + &model_id, + leader_handle, + &prompt_tokens, + ) + .await { Ok(l) => l, Err(e) => { @@ -2240,6 +2450,8 @@ async fn chat_completion_tp_inner( "TP chat_completion: starting" ); + validate_request(prompt_len, vram_free_mb)?; + // Acquire the pool lock for the duration of the request. After // Phase 3 the leader's TpLeaderModel lives in the device worker // thread, so the pool lock now serialises only subprocess RPC @@ -2276,10 +2488,13 @@ async fn chat_completion_tp_inner( let mut generated: Vec = Vec::new(); let mut finish_reason = "length".to_string(); - // Prefill: every rank embeds the whole prompt, offset = 0. + // Prefill: chunk the prompt through `chunked_prefill_tp` so + // activation memory is bounded by chunk size rather than the full + // prompt length. Every rank still sees the prompt in order, just + // spread across multiple `generate_step` calls with monotonically + // growing offsets. let prefill_start = std::time::Instant::now(); - let logits_vec = pool - .generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0) + let logits_vec = chunked_prefill_tp(&mut pool, &model_id, leader_handle, &prompt_tokens) .await .map_err(InferenceError::Other)?; let (post_prefill_vram_free_mb, _) = tp.query_vram().await; @@ -2454,12 +2669,19 @@ async fn emit_chunk( } /// Errors returned by `CandleHarness::chat_completion`. The -/// `ModelNotLoaded` variant lets the HTTP handler map cleanly to 404 -/// without string-matching on anyhow messages. +/// `ModelNotLoaded`, `PromptTooLong`, and `InsufficientVram` variants +/// let the HTTP handler map cleanly to 404 / 400 / 503 without +/// string-matching on anyhow messages. #[derive(Debug, thiserror::Error)] pub enum InferenceError { #[error("model '{0}' not loaded on this neuron")] ModelNotLoaded(String), + #[error("prompt has {prompt_len} tokens but max is {max}")] + PromptTooLong { prompt_len: usize, max: usize }, + #[error( + "insufficient free VRAM for prefill: {free_mb} MiB free, need at least {required_mb} MiB" + )] + InsufficientVram { free_mb: u64, required_mb: u64 }, #[error(transparent)] Other(#[from] anyhow::Error), } @@ -2540,11 +2762,12 @@ async fn run_inference_via_worker( .await .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?; - // Prefill — every rank embeds the prompt with offset 0. - let logits_vec = worker - .forward_logits(handle, prompt_tokens.to_vec(), 0) - .await - .map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?; + // Prefill the prompt in `prefill_chunk_tokens()`-sized chunks so + // activation memory is bounded per step rather than scaling with + // prompt length. The KV cache accumulates across chunks; we keep + // only the final chunk's logits for sampling the first generated + // token. + let logits_vec = chunked_prefill_via_worker(worker, handle, prompt_tokens).await?; let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?; let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) { Ok(t) => t, @@ -2634,10 +2857,11 @@ async fn stream_inference_via_worker( .await .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?; - let logits_vec = worker - .forward_logits(handle, prompt_tokens, 0) - .await - .map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?; + // Chunked prefill (see `chunked_prefill_via_worker`). The owning + // `prompt_tokens: Vec` is borrowed for the loop's duration; + // we still need `prompt_len` (already extracted above) for the + // decode-step offset arithmetic. + let logits_vec = chunked_prefill_via_worker(&*worker, handle, &prompt_tokens).await?; let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?; let mut next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) { Ok(t) => t, @@ -2755,8 +2979,7 @@ fn run_inference( let mut generated: Vec = Vec::new(); arch.clear_kv_cache()?; - let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?; - let logits = arch.forward(&input, 0)?; + let logits = chunked_prefill_local(arch, device, prompt_tokens)?; let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?; if Some(next_token) == eos_id { @@ -2816,8 +3039,7 @@ fn run_inference_streaming( let mut finish_reason = "length".to_string(); arch.clear_kv_cache()?; - let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?; - let logits = arch.forward(&input, 0)?; + let logits = chunked_prefill_local(arch, device, prompt_tokens)?; let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?; let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result {