feat(neuron): chunked prefill + VRAM/prompt-length pre-flight checks
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 34s
CI / Format (push) Successful in 36s
CI / Clippy (push) Successful in 2m15s
CI / Test (push) Successful in 5m9s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 5m1s
build-prerelease / Package cortex RPM (push) Successful in 1m20s
build-prerelease / Build neuron-blackwell (push) Successful in 11m7s
build-prerelease / Build neuron-ampere (push) Successful in 12m16s
build-prerelease / Build neuron-ada (push) Successful in 12m30s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m56s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m47s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m3s
Prevents the OOM-during-prefill → poisoned-context → 5-minute-reload
cycle observed on beast under agent-zero workloads. Three changes,
all keyed off env-driven knobs so an operator can tune without a
rebuild:

1. Chunked prefill (NEURON_PREFILL_CHUNK_TOKENS, default 512). The
   initial forward is split into N-token windows, each with a
   monotonically growing offset. KV cache accumulates across chunks
   exactly as it would under one big prefill; only the final chunk's
   logits are kept for sampling. Activation memory now scales with
   chunk size instead of prompt length, so a 13k-token prompt stops
   holding tens of GB of intermediate activations live at once.

   Wired into all six prefill call sites:
   - run_inference / run_inference_streaming (CPU path)
   - run_inference_via_worker / stream_inference_via_worker (CUDA
     single-GPU through device worker)
   - chat_completion_tp_inner / chat_completion_tp_stream (TP via
     WorkerPool)

   Three helpers (chunked_prefill_local, chunked_prefill_via_worker,
   chunked_prefill_tp) own the loop shape so the chunking semantics
   stay identical across paths. The worker and TP helpers emit a
   per-chunk debug log to show progress.

2. Max prompt length (NEURON_MAX_PROMPT_TOKENS, default 16384).
   Requests above the cap return a structured 400 with
   `code: prompt_too_long` rather than going through the prefill and
   discovering the limit by OOMing partway through. New
   InferenceError::PromptTooLong variant.

3. Minimum free VRAM gate (NEURON_MIN_FREE_VRAM_MB, default 1500).
   If `vram_free_mb` is below the threshold at request start (e.g.
   another concurrent request is mid-prefill), reject with a clean
   503 + `code: insufficient_vram` rather than starting work that
   will OOM. New InferenceError::InsufficientVram variant. CPU loads
   (vram=0 sentinel) skip this check.
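
The window arithmetic behind item 1 can be sketched on its own, without a model. This is a minimal illustration and not code from the commit: `chunk_windows` is a hypothetical helper that lists the `(offset, end)` pairs a chunked prefill loop would visit, assuming the same `min`-clamped stepping the helpers use.

```rust
/// Hypothetical helper (not part of the commit): returns the
/// `(offset, end)` token windows a chunked prefill would visit for a
/// prompt of `prompt_len` tokens. Each window starts where the previous
/// one ended, so offsets grow monotonically and the final window ends
/// exactly at `prompt_len`.
fn chunk_windows(prompt_len: usize, chunk_size: usize) -> Vec<(usize, usize)> {
    assert!(chunk_size > 0, "chunk size must be positive");
    let mut windows = Vec::new();
    let mut offset = 0;
    while offset < prompt_len {
        // Clamp the last window so it never runs past the prompt.
        let end = (offset + chunk_size).min(prompt_len);
        windows.push((offset, end));
        offset = end;
    }
    windows
}
```

With the default chunk size of 512, a 1300-token prompt yields the windows (0, 512), (512, 1024), (1024, 1300), and a 13000-token prompt yields 26 windows. An empty prompt yields no windows at all, which is why the real helpers bail out early on `prompt_len == 0`.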
All three gates fire BEFORE any device work, so a rejected request
costs roughly one tokenisation pass and never touches the worker
thread; poison cascades from rejected work are now impossible.
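
The order of the two rejection gates and their HTTP mapping can be sketched as a pure function. This is a simplified illustration under stated assumptions: `Rejection`, `preflight`, and `http_status` are hypothetical names, and the thresholds are passed as arguments here rather than read from the environment as the commit does.

```rust
/// Hypothetical stand-in for the two new rejection variants.
#[derive(Debug, PartialEq)]
enum Rejection {
    PromptTooLong { prompt_len: usize, max: usize },
    InsufficientVram { free_mb: u64, required_mb: u64 },
}

/// Pre-flight gate sketch: the prompt-length cap is checked first, then
/// free VRAM. A `vram_free_mb` of 0 is treated as the CPU sentinel and
/// skips the VRAM check, mirroring the commit message.
fn preflight(
    prompt_len: usize,
    vram_free_mb: u64,
    max_prompt: usize,
    min_free_mb: u64,
) -> Result<(), Rejection> {
    if prompt_len > max_prompt {
        return Err(Rejection::PromptTooLong { prompt_len, max: max_prompt });
    }
    if vram_free_mb != 0 && vram_free_mb < min_free_mb {
        return Err(Rejection::InsufficientVram {
            free_mb: vram_free_mb,
            required_mb: min_free_mb,
        });
    }
    Ok(())
}

/// Status codes as described in the commit message: 400 for an
/// over-long prompt, 503 for insufficient VRAM.
fn http_status(rejection: &Rejection) -> u16 {
    match rejection {
        Rejection::PromptTooLong { .. } => 400,
        Rejection::InsufficientVram { .. } => 503,
    }
}
```

A prompt exactly at the cap passes, since the comparison is strictly greater-than. When both conditions fail, the caller sees the prompt-length rejection because that check runs first.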
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -621,6 +621,78 @@ fn new_req_id() -> String {
     format!("{:06x}", unix_subsec_nanos() & 0xFFFFFF)
 }

+/// Read a positive `usize` from `name` in the process env, falling back
+/// to `default` if unset or unparseable. Used for runtime tuning knobs
+/// that we want operators to be able to adjust without a recompile.
+fn env_usize(name: &str, default: usize) -> usize {
+    std::env::var(name)
+        .ok()
+        .and_then(|s| s.parse().ok())
+        .filter(|v: &usize| *v > 0)
+        .unwrap_or(default)
+}
+
+/// Same as [`env_usize`] but for `u64`.
+fn env_u64(name: &str, default: u64) -> u64 {
+    std::env::var(name)
+        .ok()
+        .and_then(|s| s.parse().ok())
+        .unwrap_or(default)
+}
+
+/// Prefill chunk size in tokens. The initial forward over a long prompt
+/// is split into windows of this many tokens, each with a monotonically
+/// growing offset, so activation memory is bounded by chunk × layers ×
+/// hidden instead of prompt × layers × hidden. The default (512) keeps
+/// activation peaks under ~1 GiB on a 27B Qwen-class model while
+/// keeping the per-step overhead negligible vs. one big prefill.
+fn prefill_chunk_tokens() -> usize {
+    env_usize("NEURON_PREFILL_CHUNK_TOKENS", 512)
+}
+
+/// Maximum allowed prompt length, in tokens. Requests above this are
+/// rejected with [`InferenceError::PromptTooLong`] before any device
+/// work — this is the explicit upper bound on context size, separate
+/// from the model's `max_position_embeddings` (which can be much
+/// larger than what fits in VRAM in practice).
+fn max_prompt_tokens() -> usize {
+    env_usize("NEURON_MAX_PROMPT_TOKENS", 16384)
+}
+
+/// Minimum free VRAM (MiB) required to even attempt a prefill. Requests
+/// below this are rejected with [`InferenceError::InsufficientVram`]
+/// before any device work. Acts as a backstop when concurrent requests
+/// have eaten the headroom; intentionally conservative — a request
+/// that gets past this can still OOM, but the rejection is a clean 503
+/// rather than a poisoned context.
+fn min_free_vram_mb() -> u64 {
+    env_u64("NEURON_MIN_FREE_VRAM_MB", 1500)
+}
+
+/// Pre-flight check: reject the request if the prompt exceeds the
+/// configured max, or if there isn't enough free VRAM to safely start a
+/// prefill. Called from every chat_completion entry point right after
+/// the VRAM query. A `prompt_len == 0` is accepted (some clients send
+/// empty inputs to probe the endpoint); the prefill loop handles it.
+fn validate_request(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> {
+    let max = max_prompt_tokens();
+    if prompt_len > max {
+        return Err(InferenceError::PromptTooLong { prompt_len, max });
+    }
+    // VRAM check is skipped on CPU loads (vram_free_mb == 0 sentinel)
+    // because the (0, 0) reply from `query_vram` is also what a missing
+    // worker returns. The CPU path has no per-GPU memory limit anyway —
+    // host RAM is bounded by the OOM killer, not this check.
+    let min = min_free_vram_mb();
+    if vram_free_mb != 0 && vram_free_mb < min {
+        return Err(InferenceError::InsufficientVram {
+            free_mb: vram_free_mb,
+            required_mb: min,
+        });
+    }
+    Ok(())
+}
+
 /// Threshold above which `pool.lock().await` blocking is interesting
 /// enough to warn about. Healthy concurrent requests serialise behind
 /// the pool in single-digit ms — anything past 2 seconds is either a
@@ -687,6 +759,129 @@ fn sample_with_penalty(
     Ok(logits_processor.sample(&penalised)?)
 }

+/// Chunked prefill against an in-process [`ModelArch`]. Splits
+/// `prompt_tokens` into [`prefill_chunk_tokens()`]-sized windows, runs
+/// each through `arch.forward(chunk, offset)` with a monotonically
+/// growing offset, and returns the last chunk's logits ready for
+/// sampling. Bounds activation memory to O(chunk × layers × hidden)
+/// instead of O(prompt × layers × hidden); the KV cache grows
+/// monotonically so the model sees the full prompt at the final chunk.
+fn chunked_prefill_local(
+    arch: &mut ModelArch,
+    device: &Device,
+    prompt_tokens: &[u32],
+) -> Result<Tensor> {
+    let prompt_len = prompt_tokens.len();
+    if prompt_len == 0 {
+        anyhow::bail!("chunked_prefill_local: empty prompt");
+    }
+    let chunk_size = prefill_chunk_tokens();
+    let mut offset = 0;
+    let mut last_logits: Option<Tensor> = None;
+    while offset < prompt_len {
+        let end = (offset + chunk_size).min(prompt_len);
+        let chunk = &prompt_tokens[offset..end];
+        let input = Tensor::new(chunk, device)?.unsqueeze(0)?;
+        let logits = arch.forward(&input, offset)?;
+        if end == prompt_len {
+            last_logits = Some(logits);
+        }
+        offset = end;
+    }
+    last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_local: no chunks produced"))
+}
+
+/// Chunked prefill via the per-device worker. Same shape as
+/// [`chunked_prefill_local`] but the forward runs on the worker thread
+/// and replies with a CPU-side `Vec<f32>` of logits at the final
+/// chunk's last position. Tensors never escape the worker.
+#[cfg(feature = "cuda")]
+async fn chunked_prefill_via_worker(
+    worker: &super::device_worker::DeviceWorkerHandle,
+    handle: super::device_worker::ArchHandle,
+    prompt_tokens: &[u32],
+) -> Result<Vec<f32>> {
+    let prompt_len = prompt_tokens.len();
+    if prompt_len == 0 {
+        anyhow::bail!("chunked_prefill_via_worker: empty prompt");
+    }
+    let chunk_size = prefill_chunk_tokens();
+    let mut offset = 0;
+    let mut last_logits: Option<Vec<f32>> = None;
+    let total_chunks = prompt_len.div_ceil(chunk_size);
+    let mut chunk_idx = 0_usize;
+    while offset < prompt_len {
+        let end = (offset + chunk_size).min(prompt_len);
+        let chunk = prompt_tokens[offset..end].to_vec();
+        let chunk_len = chunk.len();
+        let step_start = std::time::Instant::now();
+        let logits = worker
+            .forward_logits(handle, chunk, offset)
+            .await
+            .map_err(|e| anyhow::anyhow!("prefill chunk {chunk_idx}/{total_chunks}: {e}"))?;
+        tracing::debug!(
+            chunk_idx,
+            total_chunks,
+            chunk_len,
+            offset,
+            elapsed_ms = step_start.elapsed().as_millis(),
+            "chunked prefill (worker): chunk done"
+        );
+        if end == prompt_len {
+            last_logits = Some(logits);
+        }
+        offset = end;
+        chunk_idx += 1;
+    }
+    last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_via_worker: no chunks produced"))
+}
+
+/// Chunked prefill via the TP `WorkerPool`. Same shape as
+/// [`chunked_prefill_via_worker`] but the forward fans out to every
+/// rank via `pool.generate_step`. Returns the leader's CPU-side
+/// `Vec<f32>` of logits at the final chunk's last position.
+#[cfg(feature = "cuda")]
+async fn chunked_prefill_tp(
+    pool: &mut super::tp::WorkerPool,
+    model_id: &str,
+    leader_handle: super::device_worker::TpHandle,
+    prompt_tokens: &[u32],
+) -> Result<Vec<f32>> {
+    let prompt_len = prompt_tokens.len();
+    if prompt_len == 0 {
+        anyhow::bail!("chunked_prefill_tp: empty prompt");
+    }
+    let chunk_size = prefill_chunk_tokens();
+    let mut offset = 0;
+    let mut last_logits: Option<Vec<f32>> = None;
+    let total_chunks = prompt_len.div_ceil(chunk_size);
+    let mut chunk_idx = 0_usize;
+    while offset < prompt_len {
+        let end = (offset + chunk_size).min(prompt_len);
+        let chunk = prompt_tokens[offset..end].to_vec();
+        let chunk_len = chunk.len();
+        let step_start = std::time::Instant::now();
+        let logits = pool
+            .generate_step(model_id, leader_handle, chunk, offset)
+            .await
+            .map_err(|e| anyhow::anyhow!("TP prefill chunk {chunk_idx}/{total_chunks}: {e}"))?;
+        tracing::debug!(
+            chunk_idx,
+            total_chunks,
+            chunk_len,
+            offset,
+            elapsed_ms = step_start.elapsed().as_millis(),
+            "chunked prefill (TP): chunk done"
+        );
+        if end == prompt_len {
+            last_logits = Some(logits);
+        }
+        offset = end;
+        chunk_idx += 1;
+    }
+    last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_tp: no chunks produced"))
+}
+
 impl CandleHarness {
     pub fn new(bind_url: String, hf_cache: Option<PathBuf>) -> Self {
         let hf_cache = resolve_hf_cache(hf_cache);
@@ -1151,6 +1346,8 @@ impl CandleHarness {
             "chat_completion: starting"
         );

+        validate_request(prompt_len, vram_free_mb)?;
+
         // Routing: CUDA loads go through the per-device worker
         // thread (introduced in Phase 1; forward/clear added in
         // Phase 2). CPU loads keep the existing spawn_blocking
@@ -1409,6 +1606,9 @@ impl CandleHarness {
                 "chat_completion (stream): starting"
             );
         }
+
+        validate_request(prompt_len, vram_free_mb)?;
+
         // Routing parallel to the non-streaming chat_completion: CUDA
         // goes through the worker (async task), CPU keeps the
         // spawn_blocking + Arc<Mutex<ModelArch>> path.
@@ -1972,6 +2172,9 @@ impl CandleHarness {
             vram_total_mb,
             "TP chat_completion (stream): starting"
         );
+
+        validate_request(prompt_len, vram_free_mb)?;
+
         let tp_for_task = Arc::clone(&tp);
         tokio::spawn(
             async move {
@@ -2001,10 +2204,17 @@ impl CandleHarness {
                     LogitsProcessor::from_sampling(seed, sampling)
                 };

-                // Prefill — every rank embeds the prompt, offset = 0.
-                let logits_vec = match pool
-                    .generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0)
-                    .await
+                // Chunked prefill — see `chunked_prefill_tp`. Each
+                // chunk fans out to every rank with a growing
+                // offset; only the final chunk's logits are kept
+                // for the first sample.
+                let logits_vec = match chunked_prefill_tp(
+                    &mut pool,
+                    &model_id,
+                    leader_handle,
+                    &prompt_tokens,
+                )
+                .await
                 {
                     Ok(l) => l,
                     Err(e) => {
@@ -2240,6 +2450,8 @@ async fn chat_completion_tp_inner(
         "TP chat_completion: starting"
     );

+    validate_request(prompt_len, vram_free_mb)?;
+
     // Acquire the pool lock for the duration of the request. After
     // Phase 3 the leader's TpLeaderModel lives in the device worker
     // thread, so the pool lock now serialises only subprocess RPC
@@ -2276,10 +2488,13 @@ async fn chat_completion_tp_inner(
     let mut generated: Vec<u32> = Vec::new();
     let mut finish_reason = "length".to_string();

-    // Prefill: every rank embeds the whole prompt, offset = 0.
+    // Prefill: chunk the prompt through `chunked_prefill_tp` so
+    // activation memory is bounded by chunk size rather than the full
+    // prompt length. Every rank still sees the prompt in order, just
+    // spread across multiple `generate_step` calls with monotonically
+    // growing offsets.
     let prefill_start = std::time::Instant::now();
-    let logits_vec = pool
-        .generate_step(&model_id, leader_handle, prompt_tokens.clone(), 0)
+    let logits_vec = chunked_prefill_tp(&mut pool, &model_id, leader_handle, &prompt_tokens)
         .await
         .map_err(InferenceError::Other)?;
     let (post_prefill_vram_free_mb, _) = tp.query_vram().await;
@@ -2454,12 +2669,19 @@ async fn emit_chunk(
 }

 /// Errors returned by `CandleHarness::chat_completion`. The
-/// `ModelNotLoaded` variant lets the HTTP handler map cleanly to 404
-/// without string-matching on anyhow messages.
+/// `ModelNotLoaded`, `PromptTooLong`, and `InsufficientVram` variants
+/// let the HTTP handler map cleanly to 404 / 400 / 503 without
+/// string-matching on anyhow messages.
 #[derive(Debug, thiserror::Error)]
 pub enum InferenceError {
     #[error("model '{0}' not loaded on this neuron")]
     ModelNotLoaded(String),
+    #[error("prompt has {prompt_len} tokens but max is {max}")]
+    PromptTooLong { prompt_len: usize, max: usize },
+    #[error(
+        "insufficient free VRAM for prefill: {free_mb} MiB free, need at least {required_mb} MiB"
+    )]
+    InsufficientVram { free_mb: u64, required_mb: u64 },
     #[error(transparent)]
     Other(#[from] anyhow::Error),
 }
@@ -2540,11 +2762,12 @@ async fn run_inference_via_worker(
         .await
         .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;

-    // Prefill — every rank embeds the prompt with offset 0.
-    let logits_vec = worker
-        .forward_logits(handle, prompt_tokens.to_vec(), 0)
-        .await
-        .map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
+    // Prefill the prompt in `prefill_chunk_tokens()`-sized chunks so
+    // activation memory is bounded per step rather than scaling with
+    // prompt length. The KV cache accumulates across chunks; we keep
+    // only the final chunk's logits for sampling the first generated
+    // token.
+    let logits_vec = chunked_prefill_via_worker(worker, handle, prompt_tokens).await?;
     let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
     let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
         Ok(t) => t,
@@ -2634,10 +2857,11 @@ async fn stream_inference_via_worker(
         .await
         .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;

-    let logits_vec = worker
-        .forward_logits(handle, prompt_tokens, 0)
-        .await
-        .map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
+    // Chunked prefill (see `chunked_prefill_via_worker`). The owning
+    // `prompt_tokens: Vec<u32>` is borrowed for the loop's duration;
+    // we still need `prompt_len` (already extracted above) for the
+    // decode-step offset arithmetic.
+    let logits_vec = chunked_prefill_via_worker(&*worker, handle, &prompt_tokens).await?;
     let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
     let mut next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
         Ok(t) => t,
@@ -2755,8 +2979,7 @@ fn run_inference(
     let mut generated: Vec<u32> = Vec::new();

     arch.clear_kv_cache()?;
-    let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
-    let logits = arch.forward(&input, 0)?;
+    let logits = chunked_prefill_local(arch, device, prompt_tokens)?;
     let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;

     if Some(next_token) == eos_id {
@@ -2816,8 +3039,7 @@ fn run_inference_streaming(
     let mut finish_reason = "length".to_string();

     arch.clear_kv_cache()?;
-    let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
-    let logits = arch.forward(&input, 0)?;
+    let logits = chunked_prefill_local(arch, device, prompt_tokens)?;
     let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;

     let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result<bool> {
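
A note on the tuning knobs: `env_usize` filters out zero while `env_u64` does not, so the two parsers treat a value of `0` differently. The sketch below reproduces that parsing logic on a plain `Option<&str>` so it can be exercised without touching the process environment. `parse_positive_usize` and `parse_u64` are hypothetical names introduced here, not functions from the commit.

```rust
/// Hypothetical pure version of the `env_usize` parsing chain: parse the
/// raw value, reject zero, and fall back to `default` otherwise.
fn parse_positive_usize(raw: Option<&str>, default: usize) -> usize {
    raw.and_then(|s| s.parse::<usize>().ok())
        .filter(|v| *v > 0)
        .unwrap_or(default)
}

/// Hypothetical pure version of the `env_u64` parsing chain: same as
/// above but without the zero filter, so an explicit 0 is honoured.
fn parse_u64(raw: Option<&str>, default: u64) -> u64 {
    raw.and_then(|s| s.parse::<u64>().ok())
        .unwrap_or(default)
}
```

Under this reading, setting NEURON_PREFILL_CHUNK_TOKENS=0 falls back to 512 (which also keeps the chunk loop from stalling on a zero step), while setting NEURON_MIN_FREE_VRAM_MB=0 is accepted and makes the VRAM gate pass for any reported free value. Integer parsing in Rust does not trim whitespace, so a value with stray spaces falls back to the default as well.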