fix(neuron): only poison the model on actual device faults
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 38s
CI / Clippy (push) Successful in 2m22s
CI / Test (push) Successful in 4m55s
build-prerelease / Build cortex binary (push) Successful in 4m24s
build-prerelease / Build neuron-blackwell (push) Successful in 5m49s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 8m7s
build-prerelease / Build neuron-ada (push) Successful in 5m0s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m6s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m6s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m48s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m5s
CI / Format (push) Failing after 33s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped

Previously every inference Err — shape mismatch, NaN logits, tokenizer
error, missing handle — marked the model poisoned and rejected every
subsequent request until an operator unload+reloaded. The benjy
incident on 2026-05-27 showed how this misfires: a concurrency bug
produced a `broadcast_add: shape mismatch` error that had nothing to
do with CUDA, but the model was taken down anyway.

Add `is_device_fault(err_chain: &str)` — a conservative classifier
that returns false only for errors we know are pre-kernel / CPU-side
(shape mismatches, NaN logits, tokenize/detokenize, missing handle,
DecodeStream, empty prompt). Everything else defaults to true so a
genuine driver fault still poisons.

Applied at all six poisoning sites:
  - chat_completion CUDA worker path
  - chat_completion CPU spawn_blocking path
  - chat_completion_stream CUDA worker path
  - chat_completion_stream CPU spawn_blocking path
  - chat_completion_tp non-streaming wrapper
  - chat_completion_tp_stream spawned task

Each site now logs either "model marked poisoned" (device fault) or
"model NOT marked poisoned" (non-device) so the journal makes the
classification visible. Tests cover the known non-device patterns and
a couple of real CUDA driver messages.

Pairs with the inference_lock commit (c59da83): together they
eliminate both the cause of the spurious-poisoning we just observed
(the shape mismatch) AND the over-reaction to it (the unconditional
poison). Each fix is independently useful but the combination is
what makes the system actually robust to concurrent agent workloads.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-27 18:57:48 +03:00
parent c59da83636
commit 249b2e5c98

View File

@@ -581,6 +581,41 @@ fn logits_health_slice(values: &[f32]) -> LogitsHealth {
} }
} }
/// Classify an inference-failure error string: should we mark the
/// model poisoned, or is this a logic / numerical / tokenizer failure
/// that leaves the device context healthy? Default is "yes, poison" —
/// the cost of failing to poison a genuinely-corrupt context (next
/// request hangs or returns garbage) outweighs the cost of
/// over-poisoning (operator unload+reloads). The opt-out list covers
/// errors we know don't touch device state.
///
/// Pass the `format!("{err:#}")` rendering of an anyhow::Error (or the
/// already-stringified error in paths that stringify failures, like
/// the TP streaming task). Matching against the full chain lets the
/// classification survive `.context("…")` and `format!("…: {e}")`
/// wrappers in the call sites.
fn is_device_fault(chain_text: &str) -> bool {
let chain = chain_text.to_lowercase();
// Non-device patterns: shape errors are pre-kernel and don't touch
// GPU state; NaN-logits failures happen on the CPU side after the
// forward; tokenize/detokenize is pure CPU; missing-handle lookups
// are pre-dispatch. Everything else we treat conservatively as a
// potential device fault.
let non_device_markers = [
"shape mismatch",
"broadcast",
"cannot broadcast",
"logits unhealthy",
"tokenize",
"detokenize",
"decode_stream",
"no model for handle",
"no tp model for handle",
"empty prompt",
];
!non_device_markers.iter().any(|m| chain.contains(m))
}
/// Build the InferenceError reported to a client when their request /// Build the InferenceError reported to a client when their request
/// hits a model that's been marked poisoned by an earlier driver /// hits a model that's been marked poisoned by an earlier driver
/// failure. The message names the model and the recovery procedure so /// failure. The message names the model and the recovery procedure so
@@ -1395,7 +1430,19 @@ impl CandleHarness {
{ {
Ok(v) => v, Ok(v) => v,
Err(e) => { Err(e) => {
let chain = format!("{e:#}");
if is_device_fault(&chain) {
loaded.poisoned.store(true, Ordering::Release); loaded.poisoned.store(true, Ordering::Release);
tracing::warn!(
error = %chain,
"chat_completion: failed with device fault, model marked poisoned"
);
} else {
tracing::warn!(
error = %chain,
"chat_completion: failed (non-device fault); model NOT marked poisoned"
);
}
return Err(InferenceError::Other(e)); return Err(InferenceError::Other(e));
} }
} }
@@ -1438,7 +1485,19 @@ impl CandleHarness {
match inference_result { match inference_result {
Ok(Ok(v)) => v, Ok(Ok(v)) => v,
Ok(Err(e)) => { Ok(Err(e)) => {
let chain = format!("{e:#}");
if is_device_fault(&chain) {
loaded.poisoned.store(true, Ordering::Release); loaded.poisoned.store(true, Ordering::Release);
tracing::warn!(
error = %chain,
"chat_completion: failed with device fault, model marked poisoned"
);
} else {
tracing::warn!(
error = %chain,
"chat_completion: failed (non-device fault); model NOT marked poisoned"
);
}
return Err(InferenceError::Other(e)); return Err(InferenceError::Other(e));
} }
Err(join_err) => { Err(join_err) => {
@@ -1682,13 +1741,23 @@ impl CandleHarness {
"chat_completion (stream): done" "chat_completion (stream): done"
), ),
Err(e) => { Err(e) => {
let chain = format!("{e:#}");
if is_device_fault(&chain) {
loaded_for_task.poisoned.store(true, Ordering::Release); loaded_for_task.poisoned.store(true, Ordering::Release);
tracing::error!( tracing::error!(
error = %format!("{e:#}"), error = %chain,
prompt_tokens = prompt_len, prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(), total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): failed, model marked poisoned" "chat_completion (stream): failed with device fault, model marked poisoned"
); );
} else {
tracing::error!(
error = %chain,
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): failed (non-device fault); model NOT marked poisoned"
);
}
} }
} }
} }
@@ -1729,13 +1798,23 @@ impl CandleHarness {
"chat_completion (stream): done" "chat_completion (stream): done"
), ),
Err(e) => { Err(e) => {
let chain = format!("{e:#}");
if is_device_fault(&chain) {
loaded_for_task.poisoned.store(true, Ordering::Release); loaded_for_task.poisoned.store(true, Ordering::Release);
tracing::error!( tracing::error!(
error = %format!("{e:#}"), error = %chain,
prompt_tokens = prompt_len, prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(), total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): failed, model marked poisoned" "chat_completion (stream): failed with device fault, model marked poisoned"
); );
} else {
tracing::error!(
error = %chain,
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): failed (non-device fault); model NOT marked poisoned"
);
}
} }
} }
}); });
@@ -2103,18 +2182,28 @@ impl CandleHarness {
match handle.await { match handle.await {
Ok(Ok(resp)) => Ok(resp), Ok(Ok(resp)) => Ok(resp),
Ok(Err(e)) => { Ok(Err(e)) => {
// The inner task returned Err — a real inference // The inner task returned Err. Only poison when the
// failure that propagated through `?`. CUDA / NCCL // failure indicates a CUDA / NCCL driver fault — shape
// driver errors leave the device context unrecoverable, // mismatches, NaN logits, tokenizer errors etc. don't
// so poison the model. This is the gate that turned // touch the device context and shouldn't take the
// the 2026-05-26 silent-hang into a clean 5xx. // model down for everyone else.
tp_for_marker.poisoned.store(true, Ordering::Release); let chain = format!("{e:#}");
let _g = span.enter(); let _g = span.enter();
if matches!(&e, InferenceError::Other(inner) if is_device_fault(&format!("{inner:#}")))
{
tp_for_marker.poisoned.store(true, Ordering::Release);
tracing::error!( tracing::error!(
error = %format!("{e:#}"), error = %chain,
total_ms = req_start.elapsed().as_millis(), total_ms = req_start.elapsed().as_millis(),
"TP chat_completion: failed, model marked poisoned" "TP chat_completion: failed with device fault, model marked poisoned"
); );
} else {
tracing::error!(
error = %chain,
total_ms = req_start.elapsed().as_millis(),
"TP chat_completion: failed (non-device fault); model NOT marked poisoned"
);
}
Err(e) Err(e)
} }
Err(join_err) => { Err(join_err) => {
@@ -2426,13 +2515,22 @@ impl CandleHarness {
// chunk went out and the spawned task just ended); now // chunk went out and the spawned task just ended); now
// there's always a log line for the operator. // there's always a log line for the operator.
if let Some(err) = &failure { if let Some(err) = &failure {
if is_device_fault(err) {
tp_for_task.poisoned.store(true, Ordering::Release); tp_for_task.poisoned.store(true, Ordering::Release);
tracing::error!( tracing::error!(
error = %err, error = %err,
completion_tokens = all_tokens.len(), completion_tokens = all_tokens.len(),
total_ms = req_start.elapsed().as_millis(), total_ms = req_start.elapsed().as_millis(),
"TP chat_completion (stream): failed, model marked poisoned" "TP chat_completion (stream): failed with device fault, model marked poisoned"
); );
} else {
tracing::error!(
error = %err,
completion_tokens = all_tokens.len(),
total_ms = req_start.elapsed().as_millis(),
"TP chat_completion (stream): failed (non-device fault); model NOT marked poisoned"
);
}
} else { } else {
tracing::info!( tracing::info!(
prompt_tokens = prompt_len, prompt_tokens = prompt_len,
@@ -3278,4 +3376,38 @@ mod tests {
"message should mention config.json" "message should mention config.json"
); );
} }
#[test]
fn is_device_fault_rejects_known_non_device_errors() {
// Shape mismatches happen pre-kernel; device is healthy.
assert!(!is_device_fault(
"prefill chunk 0/9: shape mismatch in broadcast_add, lhs: [1, 32, 512, 1024], rhs: [1, 1, 512, 512]"
));
// NaN logits are CPU-side numerical, not driver.
assert!(!is_device_fault(
"prefill sample failed; logits unhealthy nan: 248320/248320"
));
// Tokenizer/detokenizer errors are pure host.
assert!(!is_device_fault("tokenize: invalid utf-8 sequence"));
assert!(!is_device_fault("detokenize: byte fallback failed"));
// Missing handle is a dispatch-side bug, not a device fault.
assert!(!is_device_fault("ForwardLogits: no model for handle 42"));
// DecodeStream errors during SSE are not device faults.
assert!(!is_device_fault("decode_stream step failed: invalid prefix"));
}
#[test]
fn is_device_fault_defaults_to_poisoning() {
// Unknown errors default to "poison" — better to over-reject
// than to keep serving from a corrupted context.
assert!(is_device_fault("some unrecognised candle error"));
// Real driver faults — these strings come from cudarc's
// DriverError Display impl and we want them to poison.
assert!(is_device_fault(
"leader forward failed: DriverError(CUDA_ERROR_OUT_OF_MEMORY, \"out of memory\")"
));
assert!(is_device_fault(
"DriverError(CUDA_ERROR_ILLEGAL_ADDRESS, \"an illegal memory access was encountered\")"
));
}
} }