feat(neuron,candle): detect CUDA context poisoning and refuse follow-ups

Once a CUDA driver error has hit a forward or kv-cache call, the
device's context is unrecoverable in-process — subsequent kernels can
hang (the failure mode seen on beast on 2026-05-26), return garbage,
or trip another illegal-address. The harness now marks the model
poisoned on any forward / spawn_blocking / TP-task failure, refuses
further inference against it with a clear "unload and reload" error,
and surfaces `status: "poisoned"` on `/models` so an operator running
`curl beast:13131/models` (or cortex polling) can see the bad state.

Without this, a single OOM on a too-large prefill quietly turned every
subsequent request into a stuck wait on the pool lock; with it, the
first request fails fast with the driver error in the journal and the
client gets a usable 5xx instead of a hung connection.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-26 12:28:42 +03:00
parent 1385979e3d
commit fc6ef0ee0f

View File

@@ -29,6 +29,7 @@ use serde_json::json;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
#[cfg(feature = "cuda")]
use std::time::Duration;
use std::time::{SystemTime, UNIX_EPOCH};
@@ -74,6 +75,18 @@ impl LoadedHandle {
LoadedHandle::Tp(m) => m.devices.clone(),
}
}
/// True if an earlier inference left the device context in an
/// unrecoverable state. Surfaced in `/models` so cortex (and an
/// operator running `curl beast:13131/models`) can see at a glance
/// that the model needs unload+reload.
pub fn is_poisoned(&self) -> bool {
match self {
LoadedHandle::Single(m) => m.poisoned.load(Ordering::Acquire),
#[cfg(feature = "cuda")]
LoadedHandle::Tp(m) => m.poisoned.load(Ordering::Acquire),
}
}
}
/// A loaded model with its tokenizer, device placement, and architecture-
@@ -86,6 +99,15 @@ pub struct LoadedModel {
pub device: Device,
pub quant: Option<String>,
pub devices: Vec<u32>,
/// Set to `true` after any forward / kv-cache call fails. A CUDA
/// driver error (OOM, illegal address) leaves the device's context
/// in an unrecoverable state — subsequent kernels can hang, return
/// garbage, or hit another illegal address. The harness refuses
/// further inference against a poisoned model and reports a clear
/// error so an operator knows to unload+reload to recover. See
/// the 2026-05-26 beast incident where a 14k-token prefill OOM
/// silently turned every subsequent request into a stuck wait.
pub poisoned: AtomicBool,
}
/// Tensor-parallel loaded model. Holds the leader's rank-0 shard
@@ -110,6 +132,11 @@ pub struct TpLoadedModel {
/// query VRAM without locking the leader (which would contend with
/// the in-flight forward).
pub leader_device: Device,
/// Same poisoning gate as [`LoadedModel::poisoned`]. A TP forward
/// failure (CUDA OOM on any rank, NCCL desync, illegal address) is
/// terminal: the leader's and workers' CUDA contexts cannot be
/// reliably reset without restarting the worker subprocesses.
pub poisoned: AtomicBool,
}
/// Architecture-specific weights. Each variant covers one (family,
@@ -359,6 +386,20 @@ fn resolve_hf_cache(explicit: Option<PathBuf>) -> Option<PathBuf> {
None
}
/// Build the InferenceError reported to a client when their request
/// hits a model that's been marked poisoned by an earlier driver
/// failure. The message names the model and the recovery procedure so
/// the operator doesn't have to chase the original failure to know
/// what to do.
fn poisoned_error(model_id: &str) -> InferenceError {
InferenceError::Other(anyhow::anyhow!(
"model '{model_id}' is in a poisoned state \
(an earlier inference hit a CUDA driver error and the device \
context cannot be safely reused); unload and reload the model \
to recover"
))
}
/// Free/total VRAM on the candle `Device` in MiB. Returns `(0, 0)` if
/// the query fails or the device is the CPU fallback so logging never
/// crashes the request path. Mirrors the existing helper in
@@ -853,6 +894,15 @@ impl CandleHarness {
let span = tracing::info_span!("chat", req_id = %req_id, model = %model_id);
let req_start = std::time::Instant::now();
// Refuse the request up front if a prior inference poisoned
// the device context — otherwise we hand the doomed forward
// off to spawn_blocking and stall waiting for CUDA to fail.
if loaded.poisoned.load(Ordering::Acquire) {
let _g = span.enter();
tracing::warn!("chat_completion: refusing request, model poisoned");
return Err(poisoned_error(&model_id));
}
let result = async {
let prompt = format_qwen3_prompt(&request.messages);
@@ -888,7 +938,7 @@ impl CandleHarness {
let arch_arc = Arc::clone(&loaded.arch);
let device = loaded.device.clone();
let (generated_ids, finish_reason) =
let inference_result =
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
let mut guard = arch_arc.blocking_lock();
run_inference(
@@ -902,11 +952,26 @@ impl CandleHarness {
eos_id,
)
})
.await
.map_err(|e| {
InferenceError::Other(anyhow::anyhow!("inference task panicked: {e}"))
})?
.map_err(InferenceError::Other)?;
.await;
// Any failure inside the spawn_blocking touched CUDA via
// candle's forward / cache code, so we treat it as a
// device-poisoning event. The terminal log at the bottom
// of the wrapper reports the error; this flag stops the
// NEXT request from going down the same path.
let (generated_ids, finish_reason) = match inference_result {
Ok(Ok(v)) => v,
Ok(Err(e)) => {
loaded.poisoned.store(true, Ordering::Release);
return Err(InferenceError::Other(e));
}
Err(e) => {
loaded.poisoned.store(true, Ordering::Release);
return Err(InferenceError::Other(anyhow::anyhow!(
"inference task panicked: {e}"
)));
}
};
let completion_text = loaded
.tokenizer
@@ -1039,6 +1104,12 @@ impl CandleHarness {
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
// Refuse if the model is already poisoned. No point opening
// an SSE stream just to send the role chunk and then bail.
if loaded.poisoned.load(Ordering::Acquire) {
return Err(poisoned_error(&model_id));
}
// If sending the role chunk fails the receiver is already gone;
// bail before kicking off the heavy blocking work.
tx.send(role_chunk)
@@ -1052,6 +1123,9 @@ impl CandleHarness {
let span = tracing::info_span!("chat_stream", req_id = %req_id, model = %model_id);
let prompt_len = prompt_tokens.len();
let req_start = std::time::Instant::now();
// Cloned `Arc<LoadedModel>` so the spawned task can mark the
// model poisoned if its forward fails.
let loaded_for_task = Arc::clone(&loaded);
let span_for_starting = span.clone();
let span_for_task = span.clone();
{
@@ -1091,12 +1165,15 @@ impl CandleHarness {
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): done"
),
Err(e) => tracing::error!(
Err(e) => {
loaded_for_task.poisoned.store(true, Ordering::Release);
tracing::error!(
error = %format!("{e:#}"),
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): failed"
),
"chat_completion (stream): failed, model marked poisoned"
);
}
}
});
@@ -1125,7 +1202,11 @@ impl Harness for CandleHarness {
.map(|h| ModelInfo {
id: h.model_id().into(),
harness: "candle".into(),
status: "loaded".into(),
status: if h.is_poisoned() {
"poisoned".into()
} else {
"loaded".into()
},
devices: h.devices(),
vram_used_mb: None,
})
@@ -1183,6 +1264,7 @@ impl Harness for CandleHarness {
device,
quant: spec.quant.clone(),
devices,
poisoned: AtomicBool::new(false),
});
let mut models = self.models.write().await;
@@ -1332,6 +1414,7 @@ impl CandleHarness {
pool: TMutex::new(pool),
leader_model,
leader_device: leader_device.clone(),
poisoned: AtomicBool::new(false),
});
let mut models = self.models.write().await;
@@ -1374,6 +1457,14 @@ impl CandleHarness {
let model_id = request.model.clone();
let span = tracing::info_span!("tp_chat", req_id = %req_id, model = %model_id);
let req_start = std::time::Instant::now();
if tp.poisoned.load(Ordering::Acquire) {
let _g = span.enter();
tracing::warn!("TP chat_completion: refusing request, model poisoned");
return Err(poisoned_error(&model_id));
}
let tp_for_marker = Arc::clone(&tp);
let handle = tokio::spawn(chat_completion_tp_inner(tp, request).instrument(span.clone()));
let result = match handle.await {
Ok(r) => r,
@@ -1382,11 +1473,18 @@ impl CandleHarness {
))),
};
if let Err(ref e) = result {
// Mark poisoned: a failure inside the spawned task either
// hit a CUDA/NCCL driver error directly or surfaced as a
// task panic. Both cases leave the worker subprocesses in
// an unknown state — refuse subsequent requests until an
// operator unload+reloads. This is the gate that turned
// the 2026-05-26 silent-hang into a clean 5xx.
tp_for_marker.poisoned.store(true, Ordering::Release);
let _g = span.enter();
tracing::error!(
error = %format!("{e:#}"),
total_ms = req_start.elapsed().as_millis(),
"TP chat_completion: failed"
"TP chat_completion: failed, model marked poisoned"
);
}
result
@@ -1409,6 +1507,10 @@ impl CandleHarness {
tp: Arc<TpLoadedModel>,
request: ChatCompletionRequest,
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
if tp.poisoned.load(Ordering::Acquire) {
return Err(poisoned_error(&request.model));
}
let prompt = format_qwen3_prompt(&request.messages);
let encoding = tp
.tokenizer
@@ -1605,11 +1707,12 @@ impl CandleHarness {
// chunk went out and the spawned task just ended); now
// there's always a log line for the operator.
if let Some(err) = &failure {
tp_for_task.poisoned.store(true, Ordering::Release);
tracing::error!(
error = %err,
completion_tokens = all_tokens.len(),
total_ms = req_start.elapsed().as_millis(),
"TP chat_completion (stream): failed"
"TP chat_completion (stream): failed, model marked poisoned"
);
} else {
tracing::info!(