feat(neuron): chunked prefill + VRAM/prompt-length pre-flight checks

Prevents the OOM-during-prefill → poisoned-context → 5-minute-reload
cycle observed on beast under agent-zero workloads. Three changes,
all keyed off env-driven knobs so an operator can tune without a
rebuild:

1. Chunked prefill (NEURON_PREFILL_CHUNK_TOKENS, default 512). The
   initial forward is split into N-token windows, each with a
   monotonically growing offset. KV cache accumulates across chunks
   exactly as it would under one big prefill; only the final chunk's
   logits are kept for sampling. Activation memory now scales with
   chunk size instead of prompt length, so a 13 k-token prompt stops
   holding tens of GB of intermediate activations live at once.

   Wired into all six prefill call sites:
   - run_inference / run_inference_streaming (CPU path)
   - run_inference_via_worker / stream_inference_via_worker (CUDA
     single-GPU through device worker)
   - chat_completion_tp_inner / chat_completion_tp_stream (TP via
     WorkerPool)

   Three helpers — chunked_prefill_local, chunked_prefill_via_worker,
   chunked_prefill_tp — own the loop shape so the chunking semantics
   stay identical across paths. Per-chunk debug log shows progress.

2. Max prompt length (NEURON_MAX_PROMPT_TOKENS, default 16384).
   Requests above the cap return a structured 400 with
   `code: prompt_too_long` rather than going through the prefill and
   discovering the limit by OOMing partway through. New
   InferenceError::PromptTooLong variant.

3. Minimum free VRAM gate (NEURON_MIN_FREE_VRAM_MB, default 1500).
   If `vram_free_mb` is below the threshold at request start (e.g.
   another concurrent request is mid-prefill), reject with a clean
   503 + `code: insufficient_vram` rather than starting work that
   will OOM. New InferenceError::InsufficientVram variant. CPU loads
   (vram=0 sentinel) skip this check.
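
As a rough illustration of change 1 (a sketch only, not code from this
commit): `chunk_windows` is a hypothetical helper that lists the token
windows a chunked prefill would visit for a given prompt length and
chunk size, using the same `min`-clamped loop shape as the helpers in
the diff.

```rust
/// Hypothetical helper (not part of this commit): list the half-open
/// `[start, end)` token windows that a chunked prefill would visit.
/// Each window's `start` doubles as the offset passed to the forward.
fn chunk_windows(prompt_len: usize, chunk_size: usize) -> Vec<(usize, usize)> {
    assert!(chunk_size > 0, "chunk_size must be positive");
    let mut windows = Vec::with_capacity(prompt_len.div_ceil(chunk_size));
    let mut offset = 0;
    while offset < prompt_len {
        // Clamp the final window so it never runs past the prompt.
        let end = (offset + chunk_size).min(prompt_len);
        windows.push((offset, end));
        offset = end;
    }
    windows
}
```

With the default chunk size of 512, a 13,000-token prompt yields 26
windows, the last covering tokens 12,800..13,000. Only that window's
logits would be kept for sampling.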

Both gates fire BEFORE any device work, so a rejected request
costs ~one tokenisation pass and never touches the worker thread.
Poison cascades from rejected work are now impossible.
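
The two pre-flight gates can be sketched in isolation as follows. This
is a minimal stand-in, assuming a hypothetical `Limits` struct and
`Rejection` enum in place of the env-driven knobs and the
`InferenceError` variants from the diff; `http_status` mirrors the
400 / 503 mapping described above.

```rust
/// Illustrative stand-in for the two gate variants of `InferenceError`.
#[derive(Debug, PartialEq)]
enum Rejection {
    PromptTooLong { prompt_len: usize, max: usize },
    InsufficientVram { free_mb: u64, required_mb: u64 },
}

/// Hypothetical limits struct; the commit reads these from env vars.
struct Limits {
    max_prompt_tokens: usize,
    min_free_vram_mb: u64,
}

/// Run both gates in order: prompt length first, then free VRAM.
fn preflight(prompt_len: usize, vram_free_mb: u64, limits: &Limits) -> Result<(), Rejection> {
    if prompt_len > limits.max_prompt_tokens {
        return Err(Rejection::PromptTooLong {
            prompt_len,
            max: limits.max_prompt_tokens,
        });
    }
    // A reading of 0 is treated as "no VRAM figure available" (CPU load),
    // so the VRAM gate is skipped in that case.
    if vram_free_mb != 0 && vram_free_mb < limits.min_free_vram_mb {
        return Err(Rejection::InsufficientVram {
            free_mb: vram_free_mb,
            required_mb: limits.min_free_vram_mb,
        });
    }
    Ok(())
}

/// Map each rejection to the HTTP status named in the commit message.
fn http_status(rejection: &Rejection) -> u16 {
    match rejection {
        Rejection::PromptTooLong { .. } => 400,
        Rejection::InsufficientVram { .. } => 503,
    }
}
```

Note that the length cap is inclusive: a prompt of exactly
`max_prompt_tokens` tokens passes, and one token more is rejected.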

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-27 13:46:54 +03:00
parent 6e1c1dd0fc
commit 1e13889392
2 changed files with 294 additions and 22 deletions

@@ -174,6 +174,31 @@ async fn chat_completions(
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
)
.into_response(),
Err(InferenceError::PromptTooLong { prompt_len, max }) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
"code": "prompt_too_long",
"prompt_len": prompt_len,
"max": max,
})),
)
.into_response(),
Err(InferenceError::InsufficientVram {
free_mb,
required_mb,
}) => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": format!(
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
),
"code": "insufficient_vram",
"free_mb": free_mb,
"required_mb": required_mb,
})),
)
.into_response(),
Err(InferenceError::Other(e)) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})),
@@ -188,6 +213,31 @@ async fn chat_completions(
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
)
.into_response(),
Err(InferenceError::PromptTooLong { prompt_len, max }) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
"code": "prompt_too_long",
"prompt_len": prompt_len,
"max": max,
})),
)
.into_response(),
Err(InferenceError::InsufficientVram {
free_mb,
required_mb,
}) => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": format!(
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
),
"code": "insufficient_vram",
"free_mb": free_mb,
"required_mb": required_mb,
})),
)
.into_response(),
Err(InferenceError::Other(e)) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})),

@@ -621,6 +621,78 @@ fn new_req_id() -> String {
format!("{:06x}", unix_subsec_nanos() & 0xFFFFFF)
}
/// Read a positive `usize` from `name` in the process env, falling back
/// to `default` if unset or unparseable. Used for runtime tuning knobs
/// that we want operators to be able to adjust without a recompile.
fn env_usize(name: &str, default: usize) -> usize {
std::env::var(name)
.ok()
.and_then(|s| s.parse().ok())
.filter(|v: &usize| *v > 0)
.unwrap_or(default)
}
/// Like [`env_usize`] but for `u64`, except that `0` is accepted rather
/// than replaced by `default`.
fn env_u64(name: &str, default: u64) -> u64 {
std::env::var(name)
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
/// Prefill chunk size in tokens. The initial forward over a long prompt
/// is split into windows of this many tokens, each with a monotonically
/// growing offset, so activation memory is bounded by chunk × layers ×
/// hidden instead of prompt × layers × hidden. The default (512) keeps
/// activation peaks under ~1 GiB on a 27B Qwen-class model while
/// keeping the per-step overhead negligible vs. one big prefill.
fn prefill_chunk_tokens() -> usize {
env_usize("NEURON_PREFILL_CHUNK_TOKENS", 512)
}
/// Maximum allowed prompt length, in tokens. Requests above this are
/// rejected with [`InferenceError::PromptTooLong`] before any device
/// work — this is the explicit upper bound on context size, separate
/// from the model's `max_position_embeddings` (which can be much
/// larger than what fits in VRAM in practice).
fn max_prompt_tokens() -> usize {
env_usize("NEURON_MAX_PROMPT_TOKENS", 16384)
}
/// Minimum free VRAM (MiB) required to even attempt a prefill. Requests
/// below this are rejected with [`InferenceError::InsufficientVram`]
/// before any device work. Acts as a backstop when concurrent requests
/// have eaten the headroom; intentionally conservative — a request
/// that gets past this can still OOM, but the rejection is a clean 503
/// rather than a poisoned context.
fn min_free_vram_mb() -> u64 {
env_u64("NEURON_MIN_FREE_VRAM_MB", 1500)
}
/// Pre-flight check: reject the request if the prompt exceeds the
/// configured max, or if there isn't enough free VRAM to safely start a
/// prefill. Called from every chat_completion entry point right after
/// the VRAM query. A `prompt_len == 0` passes this check (some clients
/// send empty inputs to probe the endpoint); it is the chunked prefill
/// helpers that bail on an empty prompt.
fn validate_request(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> {
let max = max_prompt_tokens();
if prompt_len > max {
return Err(InferenceError::PromptTooLong { prompt_len, max });
}
// The VRAM check is skipped when `vram_free_mb == 0`: that sentinel
// is what CPU loads report, and the (0, 0) reply from `query_vram` is
// also what a missing worker returns. The CPU path has no per-GPU
// memory limit anyway; host RAM is bounded by the OOM killer, not
// this check.
let min = min_free_vram_mb();
if vram_free_mb != 0 && vram_free_mb < min {
return Err(InferenceError::InsufficientVram {
free_mb: vram_free_mb,
required_mb: min,
});
}
Ok(())
}
/// Threshold above which `pool.lock().await` blocking is interesting
/// enough to warn about. Healthy concurrent requests serialise behind
/// the pool in single-digit ms — anything past 2 seconds is either a
@@ -687,6 +759,129 @@ fn sample_with_penalty(
Ok(logits_processor.sample(&penalised)?)
}
/// Chunked prefill against an in-process [`ModelArch`]. Splits
/// `prompt_tokens` into [`prefill_chunk_tokens()`]-sized windows, runs
/// each through `arch.forward(chunk, offset)` with a monotonically
/// growing offset, and returns the last chunk's logits ready for
/// sampling. Bounds activation memory to O(chunk × layers × hidden)
/// instead of O(prompt × layers × hidden); the KV cache grows
/// monotonically so the model sees the full prompt at the final chunk.
fn chunked_prefill_local(
arch: &mut ModelArch,
device: &Device,
prompt_tokens: &[u32],
) -> Result<Tensor> {
let prompt_len = prompt_tokens.len();
if prompt_len == 0 {
anyhow::bail!("chunked_prefill_local: empty prompt");
}
let chunk_size = prefill_chunk_tokens();
let mut offset = 0;
let mut last_logits: Option<Tensor> = None;
while offset < prompt_len {
let end = (offset + chunk_size).min(prompt_len);
let chunk = &prompt_tokens[offset..end];
let input = Tensor::new(chunk, device)?.unsqueeze(0)?;
let logits = arch.forward(&input, offset)?;
if end == prompt_len {
last_logits = Some(logits);
}
offset = end;
}
last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_local: no chunks produced"))
}
/// Chunked prefill via the per-device worker. Same shape as
/// [`chunked_prefill_local`] but the forward runs on the worker thread
/// and replies with a CPU-side `Vec<f32>` of logits at the final
/// chunk's last position. Tensors never escape the worker.
#[cfg(feature = "cuda")]
async fn chunked_prefill_via_worker(
worker: &super::device_worker::DeviceWorkerHandle,
handle: super::device_worker::ArchHandle,
prompt_tokens: &[u32],
) -> Result<Vec<f32>> {
let prompt_len = prompt_tokens.len();
if prompt_len == 0 {
anyhow::bail!("chunked_prefill_via_worker: empty prompt");
}
let chunk_size = prefill_chunk_tokens();
let mut offset = 0;
let mut last_logits: Option<Vec<f32>> = None;
let total_chunks = prompt_len.div_ceil(chunk_size);
let mut chunk_idx = 0_usize;
while offset < prompt_len {
let end = (offset + chunk_size).min(prompt_len);
let chunk = prompt_tokens[offset..end].to_vec();
let chunk_len = chunk.len();
let step_start = std::time::Instant::now();
let logits = worker
.forward_logits(handle, chunk, offset)
.await
.map_err(|e| anyhow::anyhow!("prefill chunk {chunk_idx}/{total_chunks}: {e}"))?;
tracing::debug!(
chunk_idx,
total_chunks,
chunk_len,
offset,
elapsed_ms = step_start.elapsed().as_millis(),
"chunked prefill (worker): chunk done"
);
if end == prompt_len {
last_logits = Some(logits);
}
offset = end;
chunk_idx += 1;
}
last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_via_worker: no chunks produced"))
}
/// Chunked prefill via the TP `WorkerPool`. Same shape as
/// [`chunked_prefill_via_worker`] but the forward fans out to every
/// rank via `pool.generate_step`. Returns the leader's CPU-side
/// `Vec<f32>` of logits at the final chunk's last position.
#[cfg(feature = "cuda")]
async fn chunked_prefill_tp(
pool: &mut super::tp::WorkerPool,
model_id: &str,
leader_handle: super::device_worker::TpHandle,
prompt_tokens: &[u32],
) -> Result<Vec<f32>> {
let prompt_len = prompt_tokens.len();
if prompt_len == 0 {
anyhow::bail!("chunked_prefill_tp: empty prompt");
}
let chunk_size = prefill_chunk_tokens();
let mut offset = 0;
let mut last_logits: Option<Vec<f32>> = None;
let total_chunks = prompt_len.div_ceil(chunk_size);
let mut chunk_idx = 0_usize;
while offset < prompt_len {
let end = (offset + chunk_size).min(prompt_len);
let chunk = prompt_tokens[offset..end].to_vec();
let chunk_len = chunk.len();
let step_start = std::time::Instant::now();
let logits = pool
.generate_step(model_id, leader_handle, chunk, offset)
.await
.map_err(|e| anyhow::anyhow!("TP prefill chunk {chunk_idx}/{total_chunks}: {e}"))?;
tracing::debug!(
chunk_idx,
total_chunks,
chunk_len,
offset,
elapsed_ms = step_start.elapsed().as_millis(),
"chunked prefill (TP): chunk done"
);
if end == prompt_len {
last_logits = Some(logits);
}
offset = end;
chunk_idx += 1;
}
last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_tp: no chunks produced"))
}
impl CandleHarness {
pub fn new(bind_url: String, hf_cache: Option<PathBuf>) -> Self {
let hf_cache = resolve_hf_cache(hf_cache);
@@ -1151,6 +1346,8 @@ impl CandleHarness {
"chat_completion: starting"
);
validate_request(prompt_len, vram_free_mb)?;
// Routing: CUDA loads go through the per-device worker
// thread (introduced in Phase 1; forward/clear added in
// Phase 2). CPU loads keep the existing spawn_blocking
@@ -1409,6 +1606,9 @@ impl CandleHarness {
"chat_completion (stream): starting"
);
}
validate_request(prompt_len, vram_free_mb)?;
// Routing parallel to the non-streaming chat_completion: CUDA
// goes through the worker (async task), CPU keeps the
// spawn_blocking + Arc<Mutex<ModelArch>> path.
@@ -1972,6 +2172,9 @@ impl CandleHarness {
vram_total_mb,
"TP chat_completion (stream): starting"
);
validate_request(prompt_len, vram_free_mb)?;
let tp_for_task = Arc::clone(&tp);
tokio::spawn(
async move {
@@ -2001,9 +2204,16 @@ impl CandleHarness {
LogitsProcessor::from_sampling(seed, sampling)
};
-// Prefill — every rank embeds the prompt, offset = 0.
-let logits_vec = match pool
-.generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0)
+// Chunked prefill — see `chunked_prefill_tp`. Each
+// chunk fans out to every rank with a growing
+// offset; only the final chunk's logits are kept
+// for the first sample.
+let logits_vec = match chunked_prefill_tp(
+&mut pool,
+&model_id,
+leader_handle,
+&prompt_tokens,
+)
.await
{
Ok(l) => l,
@@ -2240,6 +2450,8 @@ async fn chat_completion_tp_inner(
"TP chat_completion: starting"
);
validate_request(prompt_len, vram_free_mb)?;
// Acquire the pool lock for the duration of the request. After
// Phase 3 the leader's TpLeaderModel lives in the device worker
// thread, so the pool lock now serialises only subprocess RPC
@@ -2276,10 +2488,13 @@ async fn chat_completion_tp_inner(
let mut generated: Vec<u32> = Vec::new();
let mut finish_reason = "length".to_string();
-// Prefill: every rank embeds the whole prompt, offset = 0.
+// Prefill: chunk the prompt through `chunked_prefill_tp` so
+// activation memory is bounded by chunk size rather than the full
+// prompt length. Every rank still sees the prompt in order, just
+// spread across multiple `generate_step` calls with monotonically
+// growing offsets.
let prefill_start = std::time::Instant::now();
-let logits_vec = pool
-.generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0)
+let logits_vec = chunked_prefill_tp(&mut pool, &model_id, leader_handle, &prompt_tokens)
.await
.map_err(InferenceError::Other)?;
let (post_prefill_vram_free_mb, _) = tp.query_vram().await;
@@ -2454,12 +2669,19 @@ async fn emit_chunk(
}
/// Errors returned by `CandleHarness::chat_completion`. The
-/// `ModelNotLoaded` variant lets the HTTP handler map cleanly to 404
-/// without string-matching on anyhow messages.
+/// `ModelNotLoaded`, `PromptTooLong`, and `InsufficientVram` variants
+/// let the HTTP handler map cleanly to 404 / 400 / 503 without
+/// string-matching on anyhow messages.
#[derive(Debug, thiserror::Error)]
pub enum InferenceError {
#[error("model '{0}' not loaded on this neuron")]
ModelNotLoaded(String),
+#[error("prompt has {prompt_len} tokens but max is {max}")]
+PromptTooLong { prompt_len: usize, max: usize },
+#[error(
+"insufficient free VRAM for prefill: {free_mb} MiB free, need at least {required_mb} MiB"
+)]
+InsufficientVram { free_mb: u64, required_mb: u64 },
#[error(transparent)]
Other(#[from] anyhow::Error),
}
@@ -2540,11 +2762,12 @@ async fn run_inference_via_worker(
.await
.map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
-// Prefill — every rank embeds the prompt with offset 0.
-let logits_vec = worker
-.forward_logits(handle, prompt_tokens.to_vec(), 0)
-.await
-.map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
+// Prefill the prompt in `prefill_chunk_tokens()`-sized chunks so
+// activation memory is bounded per step rather than scaling with
+// prompt length. The KV cache accumulates across chunks; we keep
+// only the final chunk's logits for sampling the first generated
+// token.
+let logits_vec = chunked_prefill_via_worker(worker, handle, prompt_tokens).await?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
@@ -2634,10 +2857,11 @@ async fn stream_inference_via_worker(
.await
.map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
-let logits_vec = worker
-.forward_logits(handle, prompt_tokens, 0)
-.await
-.map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
+// Chunked prefill (see `chunked_prefill_via_worker`). The owning
+// `prompt_tokens: Vec<u32>` is borrowed for the loop's duration;
+// we still need `prompt_len` (already extracted above) for the
+// decode-step offset arithmetic.
+let logits_vec = chunked_prefill_via_worker(&*worker, handle, &prompt_tokens).await?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
let mut next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
Ok(t) => t,
@@ -2755,8 +2979,7 @@ fn run_inference(
let mut generated: Vec<u32> = Vec::new();
arch.clear_kv_cache()?;
-let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
-let logits = arch.forward(&input, 0)?;
+let logits = chunked_prefill_local(arch, device, prompt_tokens)?;
let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;
if Some(next_token) == eos_id {
@@ -2816,8 +3039,7 @@ fn run_inference_streaming(
let mut finish_reason = "length".to_string();
arch.clear_kv_cache()?;
-let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
-let logits = arch.forward(&input, 0)?;
+let logits = chunked_prefill_local(arch, device, prompt_tokens)?;
let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result<bool> {