From fc6ef0ee0f2f63bd02b3d2e05a1fc38f2007bd24 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Tue, 26 May 2026 12:28:42 +0300 Subject: [PATCH] feat(neuron,candle): detect CUDA context poisoning and refuse follow-ups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Once a CUDA driver error has hit a forward or kv-cache call, the device's context is unrecoverable in-process — subsequent kernels can hang (the failure mode seen on beast on 2026-05-26), return garbage, or trip another illegal-address. The harness now marks the model poisoned on any forward / spawn_blocking / TP-task failure, refuses further inference against it with a clear "unload and reload" error, and surfaces `status: "poisoned"` on `/models` so an operator running `curl beast:13131/models` (or cortex polling) can see the bad state. Without this, a single OOM on a too-large prefill quietly turned every subsequent request into a stuck wait on the pool lock; with it, the first request fails fast with the driver error in the journal and the client gets a usable 5xx instead of a hung connection. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/neuron/src/harness/candle.rs | 133 ++++++++++++++++++++++++---- 1 file changed, 118 insertions(+), 15 deletions(-) diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 357a7ca..76cdb18 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -29,6 +29,7 @@ use serde_json::json; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; #[cfg(feature = "cuda")] use std::time::Duration; use std::time::{SystemTime, UNIX_EPOCH}; @@ -74,6 +75,18 @@ impl LoadedHandle { LoadedHandle::Tp(m) => m.devices.clone(), } } + + /// True if an earlier inference left the device context in an + /// unrecoverable state. Surfaced in `/models` so cortex (and an + /// operator running `curl beast:13131/models`) can see at a glance + /// that the model needs unload+reload. + pub fn is_poisoned(&self) -> bool { + match self { + LoadedHandle::Single(m) => m.poisoned.load(Ordering::Acquire), + #[cfg(feature = "cuda")] + LoadedHandle::Tp(m) => m.poisoned.load(Ordering::Acquire), + } + } } /// A loaded model with its tokenizer, device placement, and architecture- @@ -86,6 +99,15 @@ pub struct LoadedModel { pub device: Device, pub quant: Option, pub devices: Vec, + /// Set to `true` after any forward / kv-cache call fails. A CUDA + /// driver error (OOM, illegal address) leaves the device's context + /// in an unrecoverable state — subsequent kernels can hang, return + /// garbage, or hit another illegal address. The harness refuses + /// further inference against a poisoned model and reports a clear + /// error so an operator knows to unload+reload to recover. See + /// the 2026-05-26 beast incident where a 14k-token prefill OOM + /// silently turned every subsequent request into a stuck wait. + pub poisoned: AtomicBool, } /// Tensor-parallel loaded model. Holds the leader's rank-0 shard @@ -110,6 +132,11 @@ pub struct TpLoadedModel { /// query VRAM without locking the leader (which would contend with /// the in-flight forward). pub leader_device: Device, + /// Same poisoning gate as [`LoadedModel::poisoned`]. A TP forward + /// failure (CUDA OOM on any rank, NCCL desync, illegal address) is + /// terminal: the leader's and workers' CUDA contexts cannot be + /// reliably reset without restarting the worker subprocesses. + pub poisoned: AtomicBool, } /// Architecture-specific weights. Each variant covers one (family, @@ -359,6 +386,20 @@ fn resolve_hf_cache(explicit: Option) -> Option { None } +/// Build the InferenceError reported to a client when their request +/// hits a model that's been marked poisoned by an earlier driver +/// failure. The message names the model and the recovery procedure so +/// the operator doesn't have to chase the original failure to know +/// what to do. +fn poisoned_error(model_id: &str) -> InferenceError { + InferenceError::Other(anyhow::anyhow!( + "model '{model_id}' is in a poisoned state \ + (an earlier inference hit a CUDA driver error and the device \ + context cannot be safely reused); unload and reload the model \ + to recover" + )) +} + /// Free/total VRAM on the candle `Device` in MiB. Returns `(0, 0)` if /// the query fails or the device is the CPU fallback so logging never /// crashes the request path. Mirrors the existing helper in @@ -853,6 +894,15 @@ impl CandleHarness { let span = tracing::info_span!("chat", req_id = %req_id, model = %model_id); let req_start = std::time::Instant::now(); + // Refuse the request up front if a prior inference poisoned + // the device context — otherwise we hand the doomed forward + // off to spawn_blocking and stall waiting for CUDA to fail. + if loaded.poisoned.load(Ordering::Acquire) { + let _g = span.enter(); + tracing::warn!("chat_completion: refusing request, model poisoned"); + return Err(poisoned_error(&model_id)); + } + let result = async { let prompt = format_qwen3_prompt(&request.messages); @@ -888,7 +938,7 @@ impl CandleHarness { let arch_arc = Arc::clone(&loaded.arch); let device = loaded.device.clone(); - let (generated_ids, finish_reason) = + let inference_result = tokio::task::spawn_blocking(move || -> Result<(Vec, String)> { let mut guard = arch_arc.blocking_lock(); run_inference( @@ -902,11 +952,26 @@ impl CandleHarness { eos_id, ) }) - .await - .map_err(|e| { - InferenceError::Other(anyhow::anyhow!("inference task panicked: {e}")) - })? - .map_err(InferenceError::Other)?; + .await; + + // Any failure inside the spawn_blocking touched CUDA via + // candle's forward / cache code, so we treat it as a + // device-poisoning event. The terminal log at the bottom + // of the wrapper reports the error; this flag stops the + // NEXT request from going down the same path. + let (generated_ids, finish_reason) = match inference_result { + Ok(Ok(v)) => v, + Ok(Err(e)) => { + loaded.poisoned.store(true, Ordering::Release); + return Err(InferenceError::Other(e)); + } + Err(e) => { + loaded.poisoned.store(true, Ordering::Release); + return Err(InferenceError::Other(anyhow::anyhow!( + "inference task panicked: {e}" + ))); + } + }; let completion_text = loaded .tokenizer @@ -1039,6 +1104,12 @@ impl CandleHarness { usage: None, extra: serde_json::Value::Object(Default::default()), }; + // Refuse if the model is already poisoned. No point opening + // an SSE stream just to send the role chunk and then bail. + if loaded.poisoned.load(Ordering::Acquire) { + return Err(poisoned_error(&model_id)); + } + // If sending the role chunk fails the receiver is already gone; // bail before kicking off the heavy blocking work. tx.send(role_chunk) @@ -1052,6 +1123,9 @@ impl CandleHarness { let span = tracing::info_span!("chat_stream", req_id = %req_id, model = %model_id); let prompt_len = prompt_tokens.len(); let req_start = std::time::Instant::now(); + // Cloned `Arc` so the spawned task can mark the + // model poisoned if its forward fails. + let loaded_for_task = Arc::clone(&loaded); let span_for_starting = span.clone(); let span_for_task = span.clone(); { @@ -1091,12 +1165,15 @@ impl CandleHarness { total_ms = req_start.elapsed().as_millis(), "chat_completion (stream): done" ), - Err(e) => tracing::error!( - error = %format!("{e:#}"), - prompt_tokens = prompt_len, - total_ms = req_start.elapsed().as_millis(), - "chat_completion (stream): failed" - ), + Err(e) => { + loaded_for_task.poisoned.store(true, Ordering::Release); + tracing::error!( + error = %format!("{e:#}"), + prompt_tokens = prompt_len, + total_ms = req_start.elapsed().as_millis(), + "chat_completion (stream): failed, model marked poisoned" + ); + } } }); @@ -1125,7 +1202,11 @@ impl Harness for CandleHarness { .map(|h| ModelInfo { id: h.model_id().into(), harness: "candle".into(), - status: "loaded".into(), + status: if h.is_poisoned() { + "poisoned".into() + } else { + "loaded".into() + }, devices: h.devices(), vram_used_mb: None, }) @@ -1183,6 +1264,7 @@ impl Harness for CandleHarness { device, quant: spec.quant.clone(), devices, + poisoned: AtomicBool::new(false), }); let mut models = self.models.write().await; @@ -1332,6 +1414,7 @@ impl CandleHarness { pool: TMutex::new(pool), leader_model, leader_device: leader_device.clone(), + poisoned: AtomicBool::new(false), }); let mut models = self.models.write().await; @@ -1374,6 +1457,14 @@ impl CandleHarness { let model_id = request.model.clone(); let span = tracing::info_span!("tp_chat", req_id = %req_id, model = %model_id); let req_start = std::time::Instant::now(); + + if tp.poisoned.load(Ordering::Acquire) { + let _g = span.enter(); + tracing::warn!("TP chat_completion: refusing request, model poisoned"); + return Err(poisoned_error(&model_id)); + } + + let tp_for_marker = Arc::clone(&tp); let handle = tokio::spawn(chat_completion_tp_inner(tp, request).instrument(span.clone())); let result = match handle.await { Ok(r) => r, @@ -1382,11 +1473,18 @@ impl CandleHarness { ))), }; if let Err(ref e) = result { + // Mark poisoned: a failure inside the spawned task either + // hit a CUDA/NCCL driver error directly or surfaced as a + // task panic. Both cases leave the worker subprocesses in + // an unknown state — refuse subsequent requests until an + // operator unload+reloads. This is the gate that turned + // the 2026-05-26 silent-hang into a clean 5xx. + tp_for_marker.poisoned.store(true, Ordering::Release); let _g = span.enter(); tracing::error!( error = %format!("{e:#}"), total_ms = req_start.elapsed().as_millis(), - "TP chat_completion: failed" + "TP chat_completion: failed, model marked poisoned" ); } result @@ -1409,6 +1507,10 @@ impl CandleHarness { tp: Arc, request: ChatCompletionRequest, ) -> Result, InferenceError> { + if tp.poisoned.load(Ordering::Acquire) { + return Err(poisoned_error(&request.model)); + } + let prompt = format_qwen3_prompt(&request.messages); let encoding = tp .tokenizer @@ -1605,11 +1707,12 @@ impl CandleHarness { // chunk went out and the spawned task just ended); now // there's always a log line for the operator. if let Some(err) = &failure { + tp_for_task.poisoned.store(true, Ordering::Release); tracing::error!( error = %err, completion_tokens = all_tokens.len(), total_ms = req_start.elapsed().as_millis(), - "TP chat_completion (stream): failed" + "TP chat_completion (stream): failed, model marked poisoned" ); } else { tracing::info!(