diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 6930fc9..0194b0e 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -581,6 +581,41 @@ fn logits_health_slice(values: &[f32]) -> LogitsHealth { } } +/// Classify an inference-failure error string: should we mark the +/// model poisoned, or is this a logic / numerical / tokenizer failure +/// that leaves the device context healthy? Default is "yes, poison" — +/// the cost of failing to poison a genuinely-corrupt context (next +/// request hangs or returns garbage) outweighs the cost of +/// over-poisoning (operator unload+reloads). The opt-out list covers +/// errors we know don't touch device state. +/// +/// Pass the `format!("{err:#}")` rendering of an anyhow::Error (or the +/// already-stringified error in paths that stringify failures, like +/// the TP streaming task). Matching against the full chain lets the +/// classification survive `.context("…")` and `format!("…: {e}")` +/// wrappers in the call sites. +fn is_device_fault(chain_text: &str) -> bool { + let chain = chain_text.to_lowercase(); + // Non-device patterns: shape errors are pre-kernel and don't touch + // GPU state; NaN-logits failures happen on the CPU side after the + // forward; tokenize/detokenize is pure CPU; missing-handle lookups + // are pre-dispatch. Everything else we treat conservatively as a + // potential device fault. + let non_device_markers = [ + "shape mismatch", + "broadcast", + "cannot broadcast", + "logits unhealthy", + "tokenize", + "detokenize", + "decode_stream", + "no model for handle", + "no tp model for handle", + "empty prompt", + ]; + !non_device_markers.iter().any(|m| chain.contains(m)) +} + /// Build the InferenceError reported to a client when their request /// hits a model that's been marked poisoned by an earlier driver /// failure. The message names the model and the recovery procedure so @@ -1395,7 +1430,19 @@ impl CandleHarness { { Ok(v) => v, Err(e) => { - loaded.poisoned.store(true, Ordering::Release); + let chain = format!("{e:#}"); + if is_device_fault(&chain) { + loaded.poisoned.store(true, Ordering::Release); + tracing::warn!( + error = %chain, + "chat_completion: failed with device fault, model marked poisoned" + ); + } else { + tracing::warn!( + error = %chain, + "chat_completion: failed (non-device fault); model NOT marked poisoned" + ); + } return Err(InferenceError::Other(e)); } } @@ -1438,7 +1485,19 @@ impl CandleHarness { match inference_result { Ok(Ok(v)) => v, Ok(Err(e)) => { - loaded.poisoned.store(true, Ordering::Release); + let chain = format!("{e:#}"); + if is_device_fault(&chain) { + loaded.poisoned.store(true, Ordering::Release); + tracing::warn!( + error = %chain, + "chat_completion: failed with device fault, model marked poisoned" + ); + } else { + tracing::warn!( + error = %chain, + "chat_completion: failed (non-device fault); model NOT marked poisoned" + ); + } return Err(InferenceError::Other(e)); } Err(join_err) => { @@ -1682,13 +1741,23 @@ impl CandleHarness { "chat_completion (stream): done" ), Err(e) => { - loaded_for_task.poisoned.store(true, Ordering::Release); - tracing::error!( - error = %format!("{e:#}"), - prompt_tokens = prompt_len, - total_ms = req_start.elapsed().as_millis(), - "chat_completion (stream): failed, model marked poisoned" - ); + let chain = format!("{e:#}"); + if is_device_fault(&chain) { + loaded_for_task.poisoned.store(true, Ordering::Release); + tracing::error!( + error = %chain, + prompt_tokens = prompt_len, + total_ms = req_start.elapsed().as_millis(), + "chat_completion (stream): failed with device fault, model marked poisoned" + ); + } else { + tracing::error!( + error = %chain, + prompt_tokens = prompt_len, + total_ms = req_start.elapsed().as_millis(), + "chat_completion (stream): failed (non-device fault); model NOT marked poisoned" + ); + } } } } @@ -1729,13 +1798,23 @@ impl CandleHarness { "chat_completion (stream): done" ), Err(e) => { - loaded_for_task.poisoned.store(true, Ordering::Release); - tracing::error!( - error = %format!("{e:#}"), - prompt_tokens = prompt_len, - total_ms = req_start.elapsed().as_millis(), - "chat_completion (stream): failed, model marked poisoned" - ); + let chain = format!("{e:#}"); + if is_device_fault(&chain) { + loaded_for_task.poisoned.store(true, Ordering::Release); + tracing::error!( + error = %chain, + prompt_tokens = prompt_len, + total_ms = req_start.elapsed().as_millis(), + "chat_completion (stream): failed with device fault, model marked poisoned" + ); + } else { + tracing::error!( + error = %chain, + prompt_tokens = prompt_len, + total_ms = req_start.elapsed().as_millis(), + "chat_completion (stream): failed (non-device fault); model NOT marked poisoned" + ); + } } } }); @@ -2103,18 +2182,28 @@ impl CandleHarness { match handle.await { Ok(Ok(resp)) => Ok(resp), Ok(Err(e)) => { - // The inner task returned Err — a real inference - // failure that propagated through `?`. CUDA / NCCL - // driver errors leave the device context unrecoverable, - // so poison the model. This is the gate that turned - // the 2026-05-26 silent-hang into a clean 5xx. - tp_for_marker.poisoned.store(true, Ordering::Release); + // The inner task returned Err. Only poison when the + // failure indicates a CUDA / NCCL driver fault — shape + // mismatches, NaN logits, tokenizer errors etc. don't + // touch the device context and shouldn't take the + // model down for everyone else. + let chain = format!("{e:#}"); let _g = span.enter(); - tracing::error!( - error = %format!("{e:#}"), - total_ms = req_start.elapsed().as_millis(), - "TP chat_completion: failed, model marked poisoned" - ); + if matches!(&e, InferenceError::Other(inner) if is_device_fault(&format!("{inner:#}"))) + { + tp_for_marker.poisoned.store(true, Ordering::Release); + tracing::error!( + error = %chain, + total_ms = req_start.elapsed().as_millis(), + "TP chat_completion: failed with device fault, model marked poisoned" + ); + } else { + tracing::error!( + error = %chain, + total_ms = req_start.elapsed().as_millis(), + "TP chat_completion: failed (non-device fault); model NOT marked poisoned" + ); + } Err(e) } Err(join_err) => { @@ -2426,13 +2515,22 @@ impl CandleHarness { // chunk went out and the spawned task just ended); now // there's always a log line for the operator. if let Some(err) = &failure { - tp_for_task.poisoned.store(true, Ordering::Release); - tracing::error!( - error = %err, - completion_tokens = all_tokens.len(), - total_ms = req_start.elapsed().as_millis(), - "TP chat_completion (stream): failed, model marked poisoned" - ); + if is_device_fault(err) { + tp_for_task.poisoned.store(true, Ordering::Release); + tracing::error!( + error = %err, + completion_tokens = all_tokens.len(), + total_ms = req_start.elapsed().as_millis(), + "TP chat_completion (stream): failed with device fault, model marked poisoned" + ); + } else { + tracing::error!( + error = %err, + completion_tokens = all_tokens.len(), + total_ms = req_start.elapsed().as_millis(), + "TP chat_completion (stream): failed (non-device fault); model NOT marked poisoned" + ); + } } else { tracing::info!( prompt_tokens = prompt_len, @@ -3278,4 +3376,38 @@ mod tests { "message should mention config.json" ); } + + #[test] + fn is_device_fault_rejects_known_non_device_errors() { + // Shape mismatches happen pre-kernel; device is healthy. + assert!(!is_device_fault( + "prefill chunk 0/9: shape mismatch in broadcast_add, lhs: [1, 32, 512, 1024], rhs: [1, 1, 512, 512]" + )); + // NaN logits are CPU-side numerical, not driver. + assert!(!is_device_fault( + "prefill sample failed; logits unhealthy nan: 248320/248320" + )); + // Tokenizer/detokenizer errors are pure host. + assert!(!is_device_fault("tokenize: invalid utf-8 sequence")); + assert!(!is_device_fault("detokenize: byte fallback failed")); + // Missing handle is a dispatch-side bug, not a device fault. + assert!(!is_device_fault("ForwardLogits: no model for handle 42")); + // DecodeStream errors during SSE are not device faults. + assert!(!is_device_fault("decode_stream step failed: invalid prefix")); + } + + #[test] + fn is_device_fault_defaults_to_poisoning() { + // Unknown errors default to "poison" — better to over-reject + // than to keep serving from a corrupted context. + assert!(is_device_fault("some unrecognised candle error")); + // Real driver faults — these strings come from cudarc's + // DriverError Display impl and we want them to poison. + assert!(is_device_fault( + "leader forward failed: DriverError(CUDA_ERROR_OUT_OF_MEMORY, \"out of memory\")" + )); + assert!(is_device_fault( + "DriverError(CUDA_ERROR_ILLEGAL_ADDRESS, \"an illegal memory access was encountered\")" + )); + } }