feat(neuron,candle): detect CUDA context poisoning and refuse follow-ups
Once a CUDA driver error has hit a forward or kv-cache call, the device's context is unrecoverable in-process — subsequent kernels can hang (the failure mode seen on beast on 2026-05-26), return garbage, or trip another illegal-address error. The harness now marks the model poisoned on any forward / spawn_blocking / TP-task failure, refuses further inference against it with a clear "unload and reload" error, and surfaces `status: "poisoned"` on `/models` so an operator running `curl beast:13131/models` (or cortex polling) can see the bad state. Without this, a single OOM on a too-large prefill quietly turned every subsequent request into a stuck wait on the pool lock; with it, the failing request reports the driver error in the journal, and every later request fails fast with a usable 5xx instead of a hung connection.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -29,6 +29,7 @@ use serde_json::json;
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
@@ -74,6 +75,18 @@ impl LoadedHandle {
|
|||||||
LoadedHandle::Tp(m) => m.devices.clone(),
|
LoadedHandle::Tp(m) => m.devices.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// True if an earlier inference failure marked this model poisoned,
|
||||||
|
/// i.e. its device context must be treated as unrecoverable. Surfaced
|
||||||
|
/// in `/models` so cortex (and an operator running
|
||||||
|
/// `curl beast:13131/models`) can see that it needs unload+reload.
|
||||||
|
pub fn is_poisoned(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
LoadedHandle::Single(m) => m.poisoned.load(Ordering::Acquire),
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
LoadedHandle::Tp(m) => m.poisoned.load(Ordering::Acquire),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A loaded model with its tokenizer, device placement, and architecture-
|
/// A loaded model with its tokenizer, device placement, and architecture-
|
||||||
@@ -86,6 +99,15 @@ pub struct LoadedModel {
|
|||||||
pub device: Device,
|
pub device: Device,
|
||||||
pub quant: Option<String>,
|
pub quant: Option<String>,
|
||||||
pub devices: Vec<u32>,
|
pub devices: Vec<u32>,
|
||||||
|
/// Set to `true` after any forward / kv-cache call fails. A CUDA
|
||||||
|
/// driver error (OOM, illegal address) leaves the device's context
|
||||||
|
/// in an unrecoverable state — subsequent kernels can hang, return
|
||||||
|
/// garbage, or hit another illegal address. The harness refuses
|
||||||
|
/// further inference against a poisoned model and reports a clear
|
||||||
|
/// error so an operator knows to unload+reload to recover. See
|
||||||
|
/// the 2026-05-26 beast incident where a 14k-token prefill OOM
|
||||||
|
/// silently turned every subsequent request into a stuck wait.
|
||||||
|
pub poisoned: AtomicBool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Tensor-parallel loaded model. Holds the leader's rank-0 shard
|
/// Tensor-parallel loaded model. Holds the leader's rank-0 shard
|
||||||
@@ -110,6 +132,11 @@ pub struct TpLoadedModel {
|
|||||||
/// query VRAM without locking the leader (which would contend with
|
/// query VRAM without locking the leader (which would contend with
|
||||||
/// the in-flight forward).
|
/// the in-flight forward).
|
||||||
pub leader_device: Device,
|
pub leader_device: Device,
|
||||||
|
/// Same poisoning gate as [`LoadedModel::poisoned`]. A TP forward
|
||||||
|
/// failure (CUDA OOM on any rank, NCCL desync, illegal address) is
|
||||||
|
/// terminal: the leader's and workers' CUDA contexts cannot be
|
||||||
|
/// reliably reset without restarting the worker subprocesses.
|
||||||
|
pub poisoned: AtomicBool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Architecture-specific weights. Each variant covers one (family,
|
/// Architecture-specific weights. Each variant covers one (family,
|
||||||
@@ -359,6 +386,20 @@ fn resolve_hf_cache(explicit: Option<PathBuf>) -> Option<PathBuf> {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build the InferenceError reported to a client when their request
|
||||||
|
/// hits a model that's been marked poisoned by an earlier driver
|
||||||
|
/// failure. The message names the model and the recovery procedure so
|
||||||
|
/// the operator doesn't have to chase the original failure to know
|
||||||
|
/// what to do.
|
||||||
|
fn poisoned_error(model_id: &str) -> InferenceError {
|
||||||
|
InferenceError::Other(anyhow::anyhow!(
|
||||||
|
"model '{model_id}' is in a poisoned state \
|
||||||
|
(an earlier inference hit a CUDA driver error and the device \
|
||||||
|
context cannot be safely reused); unload and reload the model \
|
||||||
|
to recover"
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
/// Free/total VRAM on the candle `Device` in MiB. Returns `(0, 0)` if
|
/// Free/total VRAM on the candle `Device` in MiB. Returns `(0, 0)` if
|
||||||
/// the query fails or the device is the CPU fallback so logging never
|
/// the query fails or the device is the CPU fallback so logging never
|
||||||
/// crashes the request path. Mirrors the existing helper in
|
/// crashes the request path. Mirrors the existing helper in
|
||||||
@@ -853,6 +894,15 @@ impl CandleHarness {
|
|||||||
let span = tracing::info_span!("chat", req_id = %req_id, model = %model_id);
|
let span = tracing::info_span!("chat", req_id = %req_id, model = %model_id);
|
||||||
let req_start = std::time::Instant::now();
|
let req_start = std::time::Instant::now();
|
||||||
|
|
||||||
|
// Refuse the request up front if a prior inference poisoned
|
||||||
|
// the device context — otherwise we hand the doomed forward
|
||||||
|
// off to spawn_blocking and stall waiting for CUDA to fail.
|
||||||
|
if loaded.poisoned.load(Ordering::Acquire) {
|
||||||
|
let _g = span.enter();
|
||||||
|
tracing::warn!("chat_completion: refusing request, model poisoned");
|
||||||
|
return Err(poisoned_error(&model_id));
|
||||||
|
}
|
||||||
|
|
||||||
let result = async {
|
let result = async {
|
||||||
let prompt = format_qwen3_prompt(&request.messages);
|
let prompt = format_qwen3_prompt(&request.messages);
|
||||||
|
|
||||||
@@ -888,7 +938,7 @@ impl CandleHarness {
|
|||||||
let arch_arc = Arc::clone(&loaded.arch);
|
let arch_arc = Arc::clone(&loaded.arch);
|
||||||
let device = loaded.device.clone();
|
let device = loaded.device.clone();
|
||||||
|
|
||||||
let (generated_ids, finish_reason) =
|
let inference_result =
|
||||||
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
|
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
|
||||||
let mut guard = arch_arc.blocking_lock();
|
let mut guard = arch_arc.blocking_lock();
|
||||||
run_inference(
|
run_inference(
|
||||||
@@ -902,11 +952,26 @@ impl CandleHarness {
|
|||||||
eos_id,
|
eos_id,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.await
|
.await;
|
||||||
.map_err(|e| {
|
|
||||||
InferenceError::Other(anyhow::anyhow!("inference task panicked: {e}"))
|
// Any failure inside the spawn_blocking touched CUDA via
|
||||||
})?
|
// candle's forward / cache code, so we treat it as a
|
||||||
.map_err(InferenceError::Other)?;
|
// device-poisoning event. The terminal log at the bottom
|
||||||
|
// of the wrapper reports the error; this flag stops the
|
||||||
|
// NEXT request from going down the same path.
|
||||||
|
let (generated_ids, finish_reason) = match inference_result {
|
||||||
|
Ok(Ok(v)) => v,
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
loaded.poisoned.store(true, Ordering::Release);
|
||||||
|
return Err(InferenceError::Other(e));
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
loaded.poisoned.store(true, Ordering::Release);
|
||||||
|
return Err(InferenceError::Other(anyhow::anyhow!(
|
||||||
|
"inference task panicked: {e}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let completion_text = loaded
|
let completion_text = loaded
|
||||||
.tokenizer
|
.tokenizer
|
||||||
@@ -1039,6 +1104,12 @@ impl CandleHarness {
|
|||||||
usage: None,
|
usage: None,
|
||||||
extra: serde_json::Value::Object(Default::default()),
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
};
|
};
|
||||||
|
// Refuse if the model is already poisoned. No point opening
|
||||||
|
// an SSE stream just to send the role chunk and then bail.
|
||||||
|
if loaded.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(poisoned_error(&model_id));
|
||||||
|
}
|
||||||
|
|
||||||
// If sending the role chunk fails the receiver is already gone;
|
// If sending the role chunk fails the receiver is already gone;
|
||||||
// bail before kicking off the heavy blocking work.
|
// bail before kicking off the heavy blocking work.
|
||||||
tx.send(role_chunk)
|
tx.send(role_chunk)
|
||||||
@@ -1052,6 +1123,9 @@ impl CandleHarness {
|
|||||||
let span = tracing::info_span!("chat_stream", req_id = %req_id, model = %model_id);
|
let span = tracing::info_span!("chat_stream", req_id = %req_id, model = %model_id);
|
||||||
let prompt_len = prompt_tokens.len();
|
let prompt_len = prompt_tokens.len();
|
||||||
let req_start = std::time::Instant::now();
|
let req_start = std::time::Instant::now();
|
||||||
|
// Cloned `Arc<LoadedModel>` so the spawned task can mark the
|
||||||
|
// model poisoned if its forward fails.
|
||||||
|
let loaded_for_task = Arc::clone(&loaded);
|
||||||
let span_for_starting = span.clone();
|
let span_for_starting = span.clone();
|
||||||
let span_for_task = span.clone();
|
let span_for_task = span.clone();
|
||||||
{
|
{
|
||||||
@@ -1091,12 +1165,15 @@ impl CandleHarness {
|
|||||||
total_ms = req_start.elapsed().as_millis(),
|
total_ms = req_start.elapsed().as_millis(),
|
||||||
"chat_completion (stream): done"
|
"chat_completion (stream): done"
|
||||||
),
|
),
|
||||||
Err(e) => tracing::error!(
|
Err(e) => {
|
||||||
|
loaded_for_task.poisoned.store(true, Ordering::Release);
|
||||||
|
tracing::error!(
|
||||||
error = %format!("{e:#}"),
|
error = %format!("{e:#}"),
|
||||||
prompt_tokens = prompt_len,
|
prompt_tokens = prompt_len,
|
||||||
total_ms = req_start.elapsed().as_millis(),
|
total_ms = req_start.elapsed().as_millis(),
|
||||||
"chat_completion (stream): failed"
|
"chat_completion (stream): failed, model marked poisoned"
|
||||||
),
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -1125,7 +1202,11 @@ impl Harness for CandleHarness {
|
|||||||
.map(|h| ModelInfo {
|
.map(|h| ModelInfo {
|
||||||
id: h.model_id().into(),
|
id: h.model_id().into(),
|
||||||
harness: "candle".into(),
|
harness: "candle".into(),
|
||||||
status: "loaded".into(),
|
status: if h.is_poisoned() {
|
||||||
|
"poisoned".into()
|
||||||
|
} else {
|
||||||
|
"loaded".into()
|
||||||
|
},
|
||||||
devices: h.devices(),
|
devices: h.devices(),
|
||||||
vram_used_mb: None,
|
vram_used_mb: None,
|
||||||
})
|
})
|
||||||
@@ -1183,6 +1264,7 @@ impl Harness for CandleHarness {
|
|||||||
device,
|
device,
|
||||||
quant: spec.quant.clone(),
|
quant: spec.quant.clone(),
|
||||||
devices,
|
devices,
|
||||||
|
poisoned: AtomicBool::new(false),
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut models = self.models.write().await;
|
let mut models = self.models.write().await;
|
||||||
@@ -1332,6 +1414,7 @@ impl CandleHarness {
|
|||||||
pool: TMutex::new(pool),
|
pool: TMutex::new(pool),
|
||||||
leader_model,
|
leader_model,
|
||||||
leader_device: leader_device.clone(),
|
leader_device: leader_device.clone(),
|
||||||
|
poisoned: AtomicBool::new(false),
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut models = self.models.write().await;
|
let mut models = self.models.write().await;
|
||||||
@@ -1374,6 +1457,14 @@ impl CandleHarness {
|
|||||||
let model_id = request.model.clone();
|
let model_id = request.model.clone();
|
||||||
let span = tracing::info_span!("tp_chat", req_id = %req_id, model = %model_id);
|
let span = tracing::info_span!("tp_chat", req_id = %req_id, model = %model_id);
|
||||||
let req_start = std::time::Instant::now();
|
let req_start = std::time::Instant::now();
|
||||||
|
|
||||||
|
if tp.poisoned.load(Ordering::Acquire) {
|
||||||
|
let _g = span.enter();
|
||||||
|
tracing::warn!("TP chat_completion: refusing request, model poisoned");
|
||||||
|
return Err(poisoned_error(&model_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
let tp_for_marker = Arc::clone(&tp);
|
||||||
let handle = tokio::spawn(chat_completion_tp_inner(tp, request).instrument(span.clone()));
|
let handle = tokio::spawn(chat_completion_tp_inner(tp, request).instrument(span.clone()));
|
||||||
let result = match handle.await {
|
let result = match handle.await {
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
@@ -1382,11 +1473,18 @@ impl CandleHarness {
|
|||||||
))),
|
))),
|
||||||
};
|
};
|
||||||
if let Err(ref e) = result {
|
if let Err(ref e) = result {
|
||||||
|
// Mark poisoned: a failure inside the spawned task either
|
||||||
|
// hit a CUDA/NCCL driver error directly or surfaced as a
|
||||||
|
// task panic. Both cases leave the worker subprocesses in
|
||||||
|
// an unknown state — refuse subsequent requests until an
|
||||||
|
// operator unloads and reloads the model. This is the gate that turns
|
||||||
|
// a 2026-05-26-style silent hang into a clean 5xx.
|
||||||
|
tp_for_marker.poisoned.store(true, Ordering::Release);
|
||||||
let _g = span.enter();
|
let _g = span.enter();
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
error = %format!("{e:#}"),
|
error = %format!("{e:#}"),
|
||||||
total_ms = req_start.elapsed().as_millis(),
|
total_ms = req_start.elapsed().as_millis(),
|
||||||
"TP chat_completion: failed"
|
"TP chat_completion: failed, model marked poisoned"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
result
|
result
|
||||||
@@ -1409,6 +1507,10 @@ impl CandleHarness {
|
|||||||
tp: Arc<TpLoadedModel>,
|
tp: Arc<TpLoadedModel>,
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
|
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
|
||||||
|
if tp.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(poisoned_error(&request.model));
|
||||||
|
}
|
||||||
|
|
||||||
let prompt = format_qwen3_prompt(&request.messages);
|
let prompt = format_qwen3_prompt(&request.messages);
|
||||||
let encoding = tp
|
let encoding = tp
|
||||||
.tokenizer
|
.tokenizer
|
||||||
@@ -1605,11 +1707,12 @@ impl CandleHarness {
|
|||||||
// chunk went out and the spawned task just ended); now
|
// chunk went out and the spawned task just ended); now
|
||||||
// there's always a log line for the operator.
|
// there's always a log line for the operator.
|
||||||
if let Some(err) = &failure {
|
if let Some(err) = &failure {
|
||||||
|
tp_for_task.poisoned.store(true, Ordering::Release);
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
error = %err,
|
error = %err,
|
||||||
completion_tokens = all_tokens.len(),
|
completion_tokens = all_tokens.len(),
|
||||||
total_ms = req_start.elapsed().as_millis(),
|
total_ms = req_start.elapsed().as_millis(),
|
||||||
"TP chat_completion (stream): failed"
|
"TP chat_completion (stream): failed, model marked poisoned"
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
|
|||||||
Reference in New Issue
Block a user