From 4f2957af9ea0ed5e03ca600d4f4c8ac647580023 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Mon, 8 Jun 2026 09:05:02 +0300 Subject: [PATCH] feat(neuron): auto-recover poisoned models (#17 Stage 1c) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When an inference hit a device fault, the model was flagged poisoned and every subsequent request was rejected with "unload and reload the model to recover" — until a *human* did exactly that. Now the harness rebuilds the context automatically. - Retain the loading `ModelSpec` on `LoadedModel`/`TpLoadedModel` (+ `LoadedHandle::spec()`) so a poisoned model can be reloaded without an operator reconstructing the spec. - A background recovery task (held via `Weak`, spawned in `new()` when a runtime is present) drains poisoned model ids and runs `unload_model` → `load_model(spec)`. Unload drops the model → cudarc `Comm::drop` aborts NCCL + releases the context; reload re-runs NCCL init + sanity inside the load path, so a successful reload yields a fresh, healthy model. A failed reload leaves it unloaded (next load retries) — never poisoned forever. - The request-entry poison gates now call `trigger_recovery` (single-flight per model via a `recovering` set) and return a transient "recovering, retry shortly" error instead of the manual-reload message. Requests that arrive during the brief reload gap (model absent from the registry) also get "recovering" rather than a misleading "not loaded". `new()` now returns `Arc<Self>`. Recovery runs only on the background task — never inline on the request path, which holds `inference_lock` and would deadlock on the `models` write lock. Stage 1c of the #17 plan (verified-healthy auto-recovery). Watchdog (1b) + a fault-injection hook for beast verification follow. The in-process rank-0 leader's own context fault still needs a reload that can't rebind it (Stage 3); comm-desync + worker faults recover here.
Co-Authored-By: Claude Opus 4.8 (1M context) --- crates/neuron/src/harness/candle.rs | 168 ++++++++++++++++++++++++++-- crates/neuron/src/harness/mod.rs | 6 +- 2 files changed, 162 insertions(+), 12 deletions(-) diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 2dc8da7..d92ab20 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -60,6 +60,17 @@ pub struct CandleHarness { /// can still load on CPU for tests, just without worker threads). #[allow(dead_code)] device_workers: Arc>>>, + /// Auto-recovery (#17): model ids whose poisoned context is being + /// rebuilt via unload+reload. Insert is the single-flight gate (one + /// recovery per model in flight); membership also lets the request + /// path answer "recovering, retry shortly" during the reload gap + /// rather than a bare "not loaded". + recovering: Arc>>, + /// Sender to the background recovery task. The request path enqueues + /// a poisoned model id here; the task (holding a `Weak`) runs + /// the unload→reload→health-gate. Unbounded + tiny (model ids), and + /// the `recovering` set dedupes, so it can't back up. + recovery_tx: tokio::sync::mpsc::UnboundedSender, } /// One entry in the harness's loaded-model registry. Single-GPU loads @@ -86,6 +97,15 @@ impl LoadedHandle { } } + /// The spec this model was loaded from (for auto-recovery #17). + pub fn spec(&self) -> &ModelSpec { + match self { + LoadedHandle::Single(m) => &m.spec, + #[cfg(feature = "cuda")] + LoadedHandle::Tp(m) => &m.spec, + } + } + pub fn devices(&self) -> Vec { match self { LoadedHandle::Single(m) => m.devices.clone(), @@ -215,6 +235,10 @@ pub struct LoadedModel { /// `(h/factor) × (w/factor)` (#14 dynamic resolution). `None` for /// text-only models. Set at load time. 
pub image_grid_factor: Option, + /// The spec this model was loaded from — retained so auto-recovery + /// (#17) can `unload_model` + `load_model(spec)` a poisoned model + /// without an operator reconstructing it. + pub spec: ModelSpec, } impl LoadedModel { @@ -289,6 +313,9 @@ pub struct TpLoadedModel { /// Pixel→LM-grid divisor — same as /// [`LoadedModel::image_grid_factor`]. pub image_grid_factor: Option, + /// Loading spec, retained for auto-recovery (#17) — see + /// [`LoadedModel::spec`]. + pub spec: ModelSpec, } #[cfg(feature = "cuda")] @@ -792,6 +819,34 @@ fn poisoned_error(model_id: &str) -> InferenceError { )) } +/// Reported while auto-recovery (#17) is rebuilding a poisoned model's +/// context. Unlike [`poisoned_error`] this is a *transient* state — the +/// model is being reloaded automatically; the client should retry. +fn recovering_error(model_id: &str) -> InferenceError { + InferenceError::Other(anyhow::anyhow!( + "model '{model_id}' is recovering (its device context was poisoned \ + by an earlier failure and is being automatically rebuilt); retry \ + shortly" + )) +} + +/// Background auto-recovery task (#17). Drains poisoned model ids and +/// rebuilds each via [`CandleHarness::recover_one`]. Holds a `Weak` so a +/// shutting-down harness lets the task exit; processes one id at a time, +/// which (with the `recovering` set deduping enqueues) keeps recovery +/// single-flight per model. +async fn recovery_loop( + weak: std::sync::Weak, + mut rx: tokio::sync::mpsc::UnboundedReceiver, +) { + while let Some(model_id) = rx.recv().await { + let Some(this) = weak.upgrade() else { + break; + }; + this.recover_one(&model_id).await; + } +} + /// Free/total VRAM on the candle `Device` in MiB. Returns `(0, 0)` if /// the query fails or the device is the CPU fallback so logging never /// crashes the request path. Mirrors the existing helper in @@ -1146,7 +1201,7 @@ impl CandleHarness { /// Construct a new harness for `bind_url` using `config`. 
Resolves /// every configured source's auth env var and cache dir up front so /// the hot load path (`hf_api_for`) is a pure HashMap lookup. - pub fn new(bind_url: String, config: &crate::config::CandleHarnessConfig) -> Self { + pub fn new(bind_url: String, config: &crate::config::CandleHarnessConfig) -> Arc { let raw_sources = config.effective_sources(); let default_source = config.effective_default_source().to_string(); let mut sources = HashMap::with_capacity(raw_sources.len()); @@ -1196,13 +1251,25 @@ impl CandleHarness { bare model ids will fail to resolve until this is fixed" ); } - Self { + let (recovery_tx, recovery_rx) = tokio::sync::mpsc::unbounded_channel::(); + let this = Arc::new(Self { models: Arc::new(RwLock::new(HashMap::new())), sources, default_source, bind_url, device_workers: Arc::new(RwLock::new(HashMap::new())), + recovering: Arc::new(RwLock::new(std::collections::HashSet::new())), + recovery_tx, + }); + // Background auto-recovery task (#17). Holds a `Weak` so it can't + // keep the harness alive. Spawned only when a tokio runtime is + // present — sync unit tests that build a harness without one + // simply skip it (they don't exercise recovery). + if tokio::runtime::Handle::try_current().is_ok() { + let weak = Arc::downgrade(&this); + tokio::spawn(recovery_loop(weak, recovery_rx)); } + this } /// Scheme to substitute for bare `org/name` model ids. Mirrors the @@ -1627,7 +1694,17 @@ impl CandleHarness { let models = self.models.read().await; models.get(&request.model).cloned() }; - let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?; + let handle = match handle { + Some(h) => h, + // Absent from the registry: distinguish a genuinely unloaded + // model from one whose slot is briefly gone mid auto-recovery + // (#17), so the client gets a transient "retry shortly" instead + // of a misleading "not loaded". 
+ None if self.is_recovering(&request.model).await => { + return Err(recovering_error(&request.model)); + } + None => return Err(InferenceError::ModelNotLoaded(request.model.clone())), + }; // The match is technically infallible without `cuda` (only Single // exists), but the cfg-gated Tp arm makes this the right shape // under both feature flags. @@ -1657,7 +1734,7 @@ impl CandleHarness { if loaded.poisoned.load(Ordering::Acquire) { let _g = span.enter(); tracing::warn!("chat_completion: refusing request, model poisoned"); - return Err(poisoned_error(&model_id)); + return Err(self.trigger_recovery(&model_id).await); } // Serialise concurrent requests against this model. Holds for @@ -2036,7 +2113,17 @@ impl CandleHarness { let models = self.models.read().await; models.get(&request.model).cloned() }; - let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?; + let handle = match handle { + Some(h) => h, + // Absent from the registry: distinguish a genuinely unloaded + // model from one whose slot is briefly gone mid auto-recovery + // (#17), so the client gets a transient "retry shortly" instead + // of a misleading "not loaded". + None if self.is_recovering(&request.model).await => { + return Err(recovering_error(&request.model)); + } + None => return Err(InferenceError::ModelNotLoaded(request.model.clone())), + }; // The match is technically infallible without `cuda` (only Single // exists), but the cfg-gated Tp arm makes this the right shape // under both feature flags. @@ -2129,7 +2216,7 @@ impl CandleHarness { // Refuse if the model is already poisoned. No point opening // an SSE stream just to send the Start event and then bail. 
if loaded.poisoned.load(Ordering::Acquire) { - return Err(poisoned_error(&model_id)); + return Err(self.trigger_recovery(&model_id).await); } // Start event: tells the wire projector to emit its @@ -2347,6 +2434,69 @@ pub struct InferenceStream { pub reasoning_markers: Option, } +/// Auto-recovery (#17) — rebuild a poisoned model's device context +/// automatically instead of leaving it bricked until a human reloads. +impl CandleHarness { + /// True while `model_id` is being auto-recovered (its slot is briefly + /// absent from the registry during the reload). + pub async fn is_recovering(&self, model_id: &str) -> bool { + self.recovering.read().await.contains(model_id) + } + + /// Single-flight trigger from the request path: enqueue a rebuild for a + /// poisoned model (only the first caller per model enqueues) and return + /// the transient "recovering" error to hand back to the client. + async fn trigger_recovery(&self, model_id: &str) -> InferenceError { + let newly = self.recovering.write().await.insert(model_id.to_string()); + if newly { + tracing::warn!(model = %model_id, "auto-recovery: poisoned, enqueueing rebuild"); + if self.recovery_tx.send(model_id.to_string()).is_err() { + // Background task gone (harness shutting down). Drop the + // marker and fall back to the manual-reload message. + self.recovering.write().await.remove(model_id); + tracing::error!(model = %model_id, "auto-recovery: task unavailable"); + return poisoned_error(model_id); + } + } + recovering_error(model_id) + } + + /// Rebuild a poisoned model: `unload_model` (drops it → cudarc aborts + /// NCCL + releases the context) then `load_model` from the retained + /// spec. A successful reload re-runs NCCL init + sanity inside the load + /// path, so it returns a fresh, healthy model; a failed reload leaves + /// the model unloaded (recoverable by the next load), never poisoned + /// forever. 
Runs on the background task — never inline on the request + /// path (would deadlock on the `models` write lock). + async fn recover_one(&self, model_id: &str) { + let spec = { + let models = self.models.read().await; + models.get(model_id).map(|h| h.spec().clone()) + }; + let Some(spec) = spec else { + self.recovering.write().await.remove(model_id); + return; + }; + tracing::warn!(model = %model_id, "auto-recovery: unload+reload starting"); + if let Err(e) = self.unload_model(model_id).await { + tracing::error!( + model = %model_id, + error = %format!("{e:#}"), + "auto-recovery: unload failed (continuing to reload)" + ); + } + match self.load_model(&spec).await { + Ok(()) => tracing::info!(model = %model_id, "auto-recovery: reloaded; model healthy"), + Err(e) => tracing::error!( + model = %model_id, + error = %format!("{e:#}"), + "auto-recovery: reload failed; model left unloaded" + ), + } + self.recovering.write().await.remove(model_id); + } +} + #[async_trait] impl Harness for CandleHarness { fn name(&self) -> &str { @@ -2550,6 +2700,7 @@ impl Harness for CandleHarness { has_vision: vision_meta.has_vision, image_token_id: vision_meta.image_token_id, image_grid_factor: vision_meta.image_grid_factor, + spec: spec.clone(), }); let mut models = self.models.write().await; @@ -2788,6 +2939,7 @@ impl CandleHarness { has_vision: vision_meta.has_vision, image_token_id: vision_meta.image_token_id, image_grid_factor: vision_meta.image_grid_factor, + spec: spec.clone(), }); let mut models = self.models.write().await; @@ -2834,7 +2986,7 @@ impl CandleHarness { if tp.poisoned.load(Ordering::Acquire) { let _g = span.enter(); tracing::warn!("TP chat_completion: refusing request, model poisoned"); - return Err(poisoned_error(&model_id)); + return Err(self.trigger_recovery(&model_id).await); } // Reject image-bearing requests against a TP model with no @@ -2923,7 +3075,7 @@ impl CandleHarness { request: ChatCompletionRequest, ) -> Result { if tp.poisoned.load(Ordering::Acquire) { 
- return Err(poisoned_error(&request.model)); + return Err(self.trigger_recovery(&request.model).await); } // Reject image requests against a non-vision TP model before diff --git a/crates/neuron/src/harness/mod.rs b/crates/neuron/src/harness/mod.rs index 4b3b501..d2846ee 100644 --- a/crates/neuron/src/harness/mod.rs +++ b/crates/neuron/src/harness/mod.rs @@ -114,10 +114,8 @@ impl HarnessRegistry { for config in configs { match config.name.as_str() { "candle" => { - let harness = Arc::new(candle::CandleHarness::new( - bind_url.to_string(), - &settings.candle, - )); + let harness = + candle::CandleHarness::new(bind_url.to_string(), &settings.candle); registry.candle = Some(Arc::clone(&harness)); registry.harnesses.insert("candle".into(), harness); }