feat(neuron,candle): detect CUDA context poisoning and refuse follow-ups
Once a CUDA driver error has hit a forward or kv-cache call, the device's context is unrecoverable in-process — subsequent kernels can hang (the failure mode seen on beast on 2026-05-26), return garbage, or trip another illegal-address. The harness now marks the model poisoned on any forward / spawn_blocking / TP-task failure, refuses further inference against it with a clear "unload and reload" error, and surfaces `status: "poisoned"` on `/models` so an operator running `curl beast:13131/models` (or cortex polling) can see the bad state. Without this, a single OOM on a too-large prefill quietly turned every subsequent request into a stuck wait on the pool lock; with it, the first request fails fast with the driver error in the journal and the client gets a usable 5xx instead of a hung connection. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -29,6 +29,7 @@ use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
#[cfg(feature = "cuda")]
|
||||
use std::time::Duration;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
@@ -74,6 +75,18 @@ impl LoadedHandle {
|
||||
LoadedHandle::Tp(m) => m.devices.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// True if an earlier inference left the device context in an
|
||||
/// unrecoverable state. Surfaced in `/models` so cortex (and an
|
||||
/// operator running `curl beast:13131/models`) can see at a glance
|
||||
/// that the model needs unload+reload.
|
||||
pub fn is_poisoned(&self) -> bool {
|
||||
match self {
|
||||
LoadedHandle::Single(m) => m.poisoned.load(Ordering::Acquire),
|
||||
#[cfg(feature = "cuda")]
|
||||
LoadedHandle::Tp(m) => m.poisoned.load(Ordering::Acquire),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A loaded model with its tokenizer, device placement, and architecture-
|
||||
@@ -86,6 +99,15 @@ pub struct LoadedModel {
|
||||
pub device: Device,
|
||||
pub quant: Option<String>,
|
||||
pub devices: Vec<u32>,
|
||||
/// Set to `true` after any forward / kv-cache call fails. A CUDA
|
||||
/// driver error (OOM, illegal address) leaves the device's context
|
||||
/// in an unrecoverable state — subsequent kernels can hang, return
|
||||
/// garbage, or hit another illegal address. The harness refuses
|
||||
/// further inference against a poisoned model and reports a clear
|
||||
/// error so an operator knows to unload+reload to recover. See
|
||||
/// the 2026-05-26 beast incident where a 14k-token prefill OOM
|
||||
/// silently turned every subsequent request into a stuck wait.
|
||||
pub poisoned: AtomicBool,
|
||||
}
|
||||
|
||||
/// Tensor-parallel loaded model. Holds the leader's rank-0 shard
|
||||
@@ -110,6 +132,11 @@ pub struct TpLoadedModel {
|
||||
/// query VRAM without locking the leader (which would contend with
|
||||
/// the in-flight forward).
|
||||
pub leader_device: Device,
|
||||
/// Same poisoning gate as [`LoadedModel::poisoned`]. A TP forward
|
||||
/// failure (CUDA OOM on any rank, NCCL desync, illegal address) is
|
||||
/// terminal: the leader's and workers' CUDA contexts cannot be
|
||||
/// reliably reset without restarting the worker subprocesses.
|
||||
pub poisoned: AtomicBool,
|
||||
}
|
||||
|
||||
/// Architecture-specific weights. Each variant covers one (family,
|
||||
@@ -359,6 +386,20 @@ fn resolve_hf_cache(explicit: Option<PathBuf>) -> Option<PathBuf> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Build the InferenceError reported to a client when their request
|
||||
/// hits a model that's been marked poisoned by an earlier driver
|
||||
/// failure. The message names the model and the recovery procedure so
|
||||
/// the operator doesn't have to chase the original failure to know
|
||||
/// what to do.
|
||||
fn poisoned_error(model_id: &str) -> InferenceError {
|
||||
InferenceError::Other(anyhow::anyhow!(
|
||||
"model '{model_id}' is in a poisoned state \
|
||||
(an earlier inference hit a CUDA driver error and the device \
|
||||
context cannot be safely reused); unload and reload the model \
|
||||
to recover"
|
||||
))
|
||||
}
|
||||
|
||||
/// Free/total VRAM on the candle `Device` in MiB. Returns `(0, 0)` if
|
||||
/// the query fails or the device is the CPU fallback so logging never
|
||||
/// crashes the request path. Mirrors the existing helper in
|
||||
@@ -853,6 +894,15 @@ impl CandleHarness {
|
||||
let span = tracing::info_span!("chat", req_id = %req_id, model = %model_id);
|
||||
let req_start = std::time::Instant::now();
|
||||
|
||||
// Refuse the request up front if a prior inference poisoned
|
||||
// the device context — otherwise we hand the doomed forward
|
||||
// off to spawn_blocking and stall waiting for CUDA to fail.
|
||||
if loaded.poisoned.load(Ordering::Acquire) {
|
||||
let _g = span.enter();
|
||||
tracing::warn!("chat_completion: refusing request, model poisoned");
|
||||
return Err(poisoned_error(&model_id));
|
||||
}
|
||||
|
||||
let result = async {
|
||||
let prompt = format_qwen3_prompt(&request.messages);
|
||||
|
||||
@@ -888,7 +938,7 @@ impl CandleHarness {
|
||||
let arch_arc = Arc::clone(&loaded.arch);
|
||||
let device = loaded.device.clone();
|
||||
|
||||
let (generated_ids, finish_reason) =
|
||||
let inference_result =
|
||||
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
|
||||
let mut guard = arch_arc.blocking_lock();
|
||||
run_inference(
|
||||
@@ -902,11 +952,26 @@ impl CandleHarness {
|
||||
eos_id,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| {
|
||||
InferenceError::Other(anyhow::anyhow!("inference task panicked: {e}"))
|
||||
})?
|
||||
.map_err(InferenceError::Other)?;
|
||||
.await;
|
||||
|
||||
// Any failure inside the spawn_blocking touched CUDA via
|
||||
// candle's forward / cache code, so we treat it as a
|
||||
// device-poisoning event. The terminal log at the bottom
|
||||
// of the wrapper reports the error; this flag stops the
|
||||
// NEXT request from going down the same path.
|
||||
let (generated_ids, finish_reason) = match inference_result {
|
||||
Ok(Ok(v)) => v,
|
||||
Ok(Err(e)) => {
|
||||
loaded.poisoned.store(true, Ordering::Release);
|
||||
return Err(InferenceError::Other(e));
|
||||
}
|
||||
Err(e) => {
|
||||
loaded.poisoned.store(true, Ordering::Release);
|
||||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||||
"inference task panicked: {e}"
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let completion_text = loaded
|
||||
.tokenizer
|
||||
@@ -1039,6 +1104,12 @@ impl CandleHarness {
|
||||
usage: None,
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
};
|
||||
// Refuse if the model is already poisoned. No point opening
|
||||
// an SSE stream just to send the role chunk and then bail.
|
||||
if loaded.poisoned.load(Ordering::Acquire) {
|
||||
return Err(poisoned_error(&model_id));
|
||||
}
|
||||
|
||||
// If sending the role chunk fails the receiver is already gone;
|
||||
// bail before kicking off the heavy blocking work.
|
||||
tx.send(role_chunk)
|
||||
@@ -1052,6 +1123,9 @@ impl CandleHarness {
|
||||
let span = tracing::info_span!("chat_stream", req_id = %req_id, model = %model_id);
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let req_start = std::time::Instant::now();
|
||||
// Cloned `Arc<LoadedModel>` so the spawned task can mark the
|
||||
// model poisoned if its forward fails.
|
||||
let loaded_for_task = Arc::clone(&loaded);
|
||||
let span_for_starting = span.clone();
|
||||
let span_for_task = span.clone();
|
||||
{
|
||||
@@ -1091,12 +1165,15 @@ impl CandleHarness {
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"chat_completion (stream): done"
|
||||
),
|
||||
Err(e) => tracing::error!(
|
||||
Err(e) => {
|
||||
loaded_for_task.poisoned.store(true, Ordering::Release);
|
||||
tracing::error!(
|
||||
error = %format!("{e:#}"),
|
||||
prompt_tokens = prompt_len,
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"chat_completion (stream): failed"
|
||||
),
|
||||
"chat_completion (stream): failed, model marked poisoned"
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1125,7 +1202,11 @@ impl Harness for CandleHarness {
|
||||
.map(|h| ModelInfo {
|
||||
id: h.model_id().into(),
|
||||
harness: "candle".into(),
|
||||
status: "loaded".into(),
|
||||
status: if h.is_poisoned() {
|
||||
"poisoned".into()
|
||||
} else {
|
||||
"loaded".into()
|
||||
},
|
||||
devices: h.devices(),
|
||||
vram_used_mb: None,
|
||||
})
|
||||
@@ -1183,6 +1264,7 @@ impl Harness for CandleHarness {
|
||||
device,
|
||||
quant: spec.quant.clone(),
|
||||
devices,
|
||||
poisoned: AtomicBool::new(false),
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
@@ -1332,6 +1414,7 @@ impl CandleHarness {
|
||||
pool: TMutex::new(pool),
|
||||
leader_model,
|
||||
leader_device: leader_device.clone(),
|
||||
poisoned: AtomicBool::new(false),
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
@@ -1374,6 +1457,14 @@ impl CandleHarness {
|
||||
let model_id = request.model.clone();
|
||||
let span = tracing::info_span!("tp_chat", req_id = %req_id, model = %model_id);
|
||||
let req_start = std::time::Instant::now();
|
||||
|
||||
if tp.poisoned.load(Ordering::Acquire) {
|
||||
let _g = span.enter();
|
||||
tracing::warn!("TP chat_completion: refusing request, model poisoned");
|
||||
return Err(poisoned_error(&model_id));
|
||||
}
|
||||
|
||||
let tp_for_marker = Arc::clone(&tp);
|
||||
let handle = tokio::spawn(chat_completion_tp_inner(tp, request).instrument(span.clone()));
|
||||
let result = match handle.await {
|
||||
Ok(r) => r,
|
||||
@@ -1382,11 +1473,18 @@ impl CandleHarness {
|
||||
))),
|
||||
};
|
||||
if let Err(ref e) = result {
|
||||
// Mark poisoned: a failure inside the spawned task either
|
||||
// hit a CUDA/NCCL driver error directly or surfaced as a
|
||||
// task panic. Both cases leave the worker subprocesses in
|
||||
// an unknown state — refuse subsequent requests until an
|
||||
// operator unload+reloads. This is the gate that turned
|
||||
// the 2026-05-26 silent-hang into a clean 5xx.
|
||||
tp_for_marker.poisoned.store(true, Ordering::Release);
|
||||
let _g = span.enter();
|
||||
tracing::error!(
|
||||
error = %format!("{e:#}"),
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"TP chat_completion: failed"
|
||||
"TP chat_completion: failed, model marked poisoned"
|
||||
);
|
||||
}
|
||||
result
|
||||
@@ -1409,6 +1507,10 @@ impl CandleHarness {
|
||||
tp: Arc<TpLoadedModel>,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
|
||||
if tp.poisoned.load(Ordering::Acquire) {
|
||||
return Err(poisoned_error(&request.model));
|
||||
}
|
||||
|
||||
let prompt = format_qwen3_prompt(&request.messages);
|
||||
let encoding = tp
|
||||
.tokenizer
|
||||
@@ -1605,11 +1707,12 @@ impl CandleHarness {
|
||||
// chunk went out and the spawned task just ended); now
|
||||
// there's always a log line for the operator.
|
||||
if let Some(err) = &failure {
|
||||
tp_for_task.poisoned.store(true, Ordering::Release);
|
||||
tracing::error!(
|
||||
error = %err,
|
||||
completion_tokens = all_tokens.len(),
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"TP chat_completion (stream): failed"
|
||||
"TP chat_completion (stream): failed, model marked poisoned"
|
||||
);
|
||||
} else {
|
||||
tracing::info!(
|
||||
|
||||
Reference in New Issue
Block a user