fix(neuron): only poison the model on actual device faults
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 38s
CI / Clippy (push) Successful in 2m22s
CI / Test (push) Successful in 4m55s
build-prerelease / Build cortex binary (push) Successful in 4m24s
build-prerelease / Build neuron-blackwell (push) Successful in 5m49s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 8m7s
build-prerelease / Build neuron-ada (push) Successful in 5m0s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m6s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m6s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m48s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m5s
CI / Format (push) Failing after 33s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 38s
CI / Clippy (push) Successful in 2m22s
CI / Test (push) Successful in 4m55s
build-prerelease / Build cortex binary (push) Successful in 4m24s
build-prerelease / Build neuron-blackwell (push) Successful in 5m49s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 8m7s
build-prerelease / Build neuron-ada (push) Successful in 5m0s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m6s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m6s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m48s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m5s
CI / Format (push) Failing after 33s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
Previously every inference Err — shape mismatch, NaN logits, tokenizer
error, missing handle — marked the model poisoned and rejected every
subsequent request until an operator unload+reloaded. The benjy
incident on 2026-05-27 showed how this misfires: a concurrency bug
produced a `broadcast_add: shape mismatch` error that had nothing to
do with CUDA, but the model was taken down anyway.
Add `is_device_fault(err_chain: &str)` — a conservative classifier
that returns false only for errors we know are pre-kernel / CPU-side
(shape mismatches, NaN logits, tokenize/detokenize, missing handle,
DecodeStream, empty prompt). Everything else defaults to true so a
genuine driver fault still poisons.
Applied at all six poisoning sites:
- chat_completion CUDA worker path
- chat_completion CPU spawn_blocking path
- chat_completion_stream CUDA worker path
- chat_completion_stream CPU spawn_blocking path
- chat_completion_tp non-streaming wrapper
- chat_completion_tp_stream spawned task
Each site now logs either "model marked poisoned" (device fault) or
"model NOT marked poisoned" (non-device) so the journal makes the
classification visible. Tests cover the known non-device patterns and
a couple of real CUDA driver messages.
Pairs with the inference_lock commit (c59da83): together they
eliminate both the cause of the spurious-poisoning we just observed
(the shape mismatch) AND the over-reaction to it (the unconditional
poison). Each fix is independently useful but the combination is
what makes the system actually robust to concurrent agent workloads.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -581,6 +581,41 @@ fn logits_health_slice(values: &[f32]) -> LogitsHealth {
|
||||
}
|
||||
}
|
||||
|
||||
/// Classify an inference-failure error string: should we mark the
|
||||
/// model poisoned, or is this a logic / numerical / tokenizer failure
|
||||
/// that leaves the device context healthy? Default is "yes, poison" —
|
||||
/// the cost of failing to poison a genuinely-corrupt context (next
|
||||
/// request hangs or returns garbage) outweighs the cost of
|
||||
/// over-poisoning (operator unload+reloads). The opt-out list covers
|
||||
/// errors we know don't touch device state.
|
||||
///
|
||||
/// Pass the `format!("{err:#}")` rendering of an anyhow::Error (or the
|
||||
/// already-stringified error in paths that stringify failures, like
|
||||
/// the TP streaming task). Matching against the full chain lets the
|
||||
/// classification survive `.context("…")` and `format!("…: {e}")`
|
||||
/// wrappers in the call sites.
|
||||
fn is_device_fault(chain_text: &str) -> bool {
|
||||
let chain = chain_text.to_lowercase();
|
||||
// Non-device patterns: shape errors are pre-kernel and don't touch
|
||||
// GPU state; NaN-logits failures happen on the CPU side after the
|
||||
// forward; tokenize/detokenize is pure CPU; missing-handle lookups
|
||||
// are pre-dispatch. Everything else we treat conservatively as a
|
||||
// potential device fault.
|
||||
let non_device_markers = [
|
||||
"shape mismatch",
|
||||
"broadcast",
|
||||
"cannot broadcast",
|
||||
"logits unhealthy",
|
||||
"tokenize",
|
||||
"detokenize",
|
||||
"decode_stream",
|
||||
"no model for handle",
|
||||
"no tp model for handle",
|
||||
"empty prompt",
|
||||
];
|
||||
!non_device_markers.iter().any(|m| chain.contains(m))
|
||||
}
|
||||
|
||||
/// Build the InferenceError reported to a client when their request
|
||||
/// hits a model that's been marked poisoned by an earlier driver
|
||||
/// failure. The message names the model and the recovery procedure so
|
||||
@@ -1395,7 +1430,19 @@ impl CandleHarness {
|
||||
{
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
let chain = format!("{e:#}");
|
||||
if is_device_fault(&chain) {
|
||||
loaded.poisoned.store(true, Ordering::Release);
|
||||
tracing::warn!(
|
||||
error = %chain,
|
||||
"chat_completion: failed with device fault, model marked poisoned"
|
||||
);
|
||||
} else {
|
||||
tracing::warn!(
|
||||
error = %chain,
|
||||
"chat_completion: failed (non-device fault); model NOT marked poisoned"
|
||||
);
|
||||
}
|
||||
return Err(InferenceError::Other(e));
|
||||
}
|
||||
}
|
||||
@@ -1438,7 +1485,19 @@ impl CandleHarness {
|
||||
match inference_result {
|
||||
Ok(Ok(v)) => v,
|
||||
Ok(Err(e)) => {
|
||||
let chain = format!("{e:#}");
|
||||
if is_device_fault(&chain) {
|
||||
loaded.poisoned.store(true, Ordering::Release);
|
||||
tracing::warn!(
|
||||
error = %chain,
|
||||
"chat_completion: failed with device fault, model marked poisoned"
|
||||
);
|
||||
} else {
|
||||
tracing::warn!(
|
||||
error = %chain,
|
||||
"chat_completion: failed (non-device fault); model NOT marked poisoned"
|
||||
);
|
||||
}
|
||||
return Err(InferenceError::Other(e));
|
||||
}
|
||||
Err(join_err) => {
|
||||
@@ -1682,13 +1741,23 @@ impl CandleHarness {
|
||||
"chat_completion (stream): done"
|
||||
),
|
||||
Err(e) => {
|
||||
let chain = format!("{e:#}");
|
||||
if is_device_fault(&chain) {
|
||||
loaded_for_task.poisoned.store(true, Ordering::Release);
|
||||
tracing::error!(
|
||||
error = %format!("{e:#}"),
|
||||
error = %chain,
|
||||
prompt_tokens = prompt_len,
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"chat_completion (stream): failed, model marked poisoned"
|
||||
"chat_completion (stream): failed with device fault, model marked poisoned"
|
||||
);
|
||||
} else {
|
||||
tracing::error!(
|
||||
error = %chain,
|
||||
prompt_tokens = prompt_len,
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"chat_completion (stream): failed (non-device fault); model NOT marked poisoned"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1729,13 +1798,23 @@ impl CandleHarness {
|
||||
"chat_completion (stream): done"
|
||||
),
|
||||
Err(e) => {
|
||||
let chain = format!("{e:#}");
|
||||
if is_device_fault(&chain) {
|
||||
loaded_for_task.poisoned.store(true, Ordering::Release);
|
||||
tracing::error!(
|
||||
error = %format!("{e:#}"),
|
||||
error = %chain,
|
||||
prompt_tokens = prompt_len,
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"chat_completion (stream): failed, model marked poisoned"
|
||||
"chat_completion (stream): failed with device fault, model marked poisoned"
|
||||
);
|
||||
} else {
|
||||
tracing::error!(
|
||||
error = %chain,
|
||||
prompt_tokens = prompt_len,
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"chat_completion (stream): failed (non-device fault); model NOT marked poisoned"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -2103,18 +2182,28 @@ impl CandleHarness {
|
||||
match handle.await {
|
||||
Ok(Ok(resp)) => Ok(resp),
|
||||
Ok(Err(e)) => {
|
||||
// The inner task returned Err — a real inference
|
||||
// failure that propagated through `?`. CUDA / NCCL
|
||||
// driver errors leave the device context unrecoverable,
|
||||
// so poison the model. This is the gate that turned
|
||||
// the 2026-05-26 silent-hang into a clean 5xx.
|
||||
tp_for_marker.poisoned.store(true, Ordering::Release);
|
||||
// The inner task returned Err. Only poison when the
|
||||
// failure indicates a CUDA / NCCL driver fault — shape
|
||||
// mismatches, NaN logits, tokenizer errors etc. don't
|
||||
// touch the device context and shouldn't take the
|
||||
// model down for everyone else.
|
||||
let chain = format!("{e:#}");
|
||||
let _g = span.enter();
|
||||
if matches!(&e, InferenceError::Other(inner) if is_device_fault(&format!("{inner:#}")))
|
||||
{
|
||||
tp_for_marker.poisoned.store(true, Ordering::Release);
|
||||
tracing::error!(
|
||||
error = %format!("{e:#}"),
|
||||
error = %chain,
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"TP chat_completion: failed, model marked poisoned"
|
||||
"TP chat_completion: failed with device fault, model marked poisoned"
|
||||
);
|
||||
} else {
|
||||
tracing::error!(
|
||||
error = %chain,
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"TP chat_completion: failed (non-device fault); model NOT marked poisoned"
|
||||
);
|
||||
}
|
||||
Err(e)
|
||||
}
|
||||
Err(join_err) => {
|
||||
@@ -2426,13 +2515,22 @@ impl CandleHarness {
|
||||
// chunk went out and the spawned task just ended); now
|
||||
// there's always a log line for the operator.
|
||||
if let Some(err) = &failure {
|
||||
if is_device_fault(err) {
|
||||
tp_for_task.poisoned.store(true, Ordering::Release);
|
||||
tracing::error!(
|
||||
error = %err,
|
||||
completion_tokens = all_tokens.len(),
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"TP chat_completion (stream): failed, model marked poisoned"
|
||||
"TP chat_completion (stream): failed with device fault, model marked poisoned"
|
||||
);
|
||||
} else {
|
||||
tracing::error!(
|
||||
error = %err,
|
||||
completion_tokens = all_tokens.len(),
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"TP chat_completion (stream): failed (non-device fault); model NOT marked poisoned"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
tracing::info!(
|
||||
prompt_tokens = prompt_len,
|
||||
@@ -3278,4 +3376,38 @@ mod tests {
|
||||
"message should mention config.json"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_device_fault_rejects_known_non_device_errors() {
|
||||
// Shape mismatches happen pre-kernel; device is healthy.
|
||||
assert!(!is_device_fault(
|
||||
"prefill chunk 0/9: shape mismatch in broadcast_add, lhs: [1, 32, 512, 1024], rhs: [1, 1, 512, 512]"
|
||||
));
|
||||
// NaN logits are CPU-side numerical, not driver.
|
||||
assert!(!is_device_fault(
|
||||
"prefill sample failed; logits unhealthy nan: 248320/248320"
|
||||
));
|
||||
// Tokenizer/detokenizer errors are pure host.
|
||||
assert!(!is_device_fault("tokenize: invalid utf-8 sequence"));
|
||||
assert!(!is_device_fault("detokenize: byte fallback failed"));
|
||||
// Missing handle is a dispatch-side bug, not a device fault.
|
||||
assert!(!is_device_fault("ForwardLogits: no model for handle 42"));
|
||||
// DecodeStream errors during SSE are not device faults.
|
||||
assert!(!is_device_fault("decode_stream step failed: invalid prefix"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_device_fault_defaults_to_poisoning() {
|
||||
// Unknown errors default to "poison" — better to over-reject
|
||||
// than to keep serving from a corrupted context.
|
||||
assert!(is_device_fault("some unrecognised candle error"));
|
||||
// Real driver faults — these strings come from cudarc's
|
||||
// DriverError Display impl and we want them to poison.
|
||||
assert!(is_device_fault(
|
||||
"leader forward failed: DriverError(CUDA_ERROR_OUT_OF_MEMORY, \"out of memory\")"
|
||||
));
|
||||
assert!(is_device_fault(
|
||||
"DriverError(CUDA_ERROR_ILLEGAL_ADDRESS, \"an illegal memory access was encountered\")"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user