feat(neuron): auto-recover poisoned models (#17 Stage 1c)
When an inference hit a device fault, the model was flagged poisoned and every subsequent request was rejected with "unload and reload the model to recover" — until a *human* did exactly that. Now the harness rebuilds the context automatically.

- Retain the loading `ModelSpec` on `LoadedModel`/`TpLoadedModel` (+ `LoadedHandle::spec()`) so a poisoned model can be reloaded without an operator reconstructing the spec.
- A background recovery task (held via `Weak<CandleHarness>`, spawned in `new()` when a runtime is present) drains poisoned model ids and runs `unload_model` → `load_model(spec)`. Unload drops the model → cudarc `Comm::drop` aborts NCCL + releases the context; reload re-runs NCCL init + sanity inside the load path, so a successful reload yields a fresh, healthy model. A failed reload leaves it unloaded (next load retries) — never poisoned forever.
- The request-entry poison gates now `trigger_recovery` (single-flight per model via a `recovering` set) and return a transient "recovering, retry shortly" error instead of the manual-reload message. Requests that arrive during the brief reload gap (model absent from the registry) also get "recovering" rather than a misleading "not loaded".

`new()` now returns `Arc<Self>`. Recovery runs only on the background task — never inline on the request path, which holds `inference_lock` and would deadlock on the `models` write lock.

Stage 1c of the #17 plan (verified-healthy auto-recovery). Watchdog (1b) + a fault-injection hook for beast verification follow. The in-process rank-0 leader's own context fault still needs a reload that can't rebind it (Stage 3); comm-desync + worker faults recover here.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -60,6 +60,17 @@ pub struct CandleHarness {
|
|||||||
/// can still load on CPU for tests, just without worker threads).
|
/// can still load on CPU for tests, just without worker threads).
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
device_workers: Arc<RwLock<HashMap<u32, Arc<super::device_worker::DeviceWorkerHandle>>>>,
|
device_workers: Arc<RwLock<HashMap<u32, Arc<super::device_worker::DeviceWorkerHandle>>>>,
|
||||||
|
/// Auto-recovery (#17): model ids whose poisoned context is being
|
||||||
|
/// rebuilt via unload+reload. Insert is the single-flight gate (one
|
||||||
|
/// recovery per model in flight); membership also lets the request
|
||||||
|
/// path answer "recovering, retry shortly" during the reload gap
|
||||||
|
/// rather than a bare "not loaded".
|
||||||
|
recovering: Arc<RwLock<std::collections::HashSet<String>>>,
|
||||||
|
/// Sender to the background recovery task. The request path enqueues
|
||||||
|
/// a poisoned model id here; the task (holding a `Weak<Self>`) runs
|
||||||
|
/// the unload→reload→health-gate. Unbounded + tiny (model ids), and
|
||||||
|
/// the `recovering` set dedupes, so it can't back up.
|
||||||
|
recovery_tx: tokio::sync::mpsc::UnboundedSender<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// One entry in the harness's loaded-model registry. Single-GPU loads
|
/// One entry in the harness's loaded-model registry. Single-GPU loads
|
||||||
@@ -86,6 +97,15 @@ impl LoadedHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The spec this model was loaded from (for auto-recovery #17).
|
||||||
|
pub fn spec(&self) -> &ModelSpec {
|
||||||
|
match self {
|
||||||
|
LoadedHandle::Single(m) => &m.spec,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
LoadedHandle::Tp(m) => &m.spec,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn devices(&self) -> Vec<u32> {
|
pub fn devices(&self) -> Vec<u32> {
|
||||||
match self {
|
match self {
|
||||||
LoadedHandle::Single(m) => m.devices.clone(),
|
LoadedHandle::Single(m) => m.devices.clone(),
|
||||||
@@ -215,6 +235,10 @@ pub struct LoadedModel {
|
|||||||
/// `(h/factor) × (w/factor)` (#14 dynamic resolution). `None` for
|
/// `(h/factor) × (w/factor)` (#14 dynamic resolution). `None` for
|
||||||
/// text-only models. Set at load time.
|
/// text-only models. Set at load time.
|
||||||
pub image_grid_factor: Option<usize>,
|
pub image_grid_factor: Option<usize>,
|
||||||
|
/// The spec this model was loaded from — retained so auto-recovery
|
||||||
|
/// (#17) can `unload_model` + `load_model(spec)` a poisoned model
|
||||||
|
/// without an operator reconstructing it.
|
||||||
|
pub spec: ModelSpec,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LoadedModel {
|
impl LoadedModel {
|
||||||
@@ -289,6 +313,9 @@ pub struct TpLoadedModel {
|
|||||||
/// Pixel→LM-grid divisor — same as
|
/// Pixel→LM-grid divisor — same as
|
||||||
/// [`LoadedModel::image_grid_factor`].
|
/// [`LoadedModel::image_grid_factor`].
|
||||||
pub image_grid_factor: Option<usize>,
|
pub image_grid_factor: Option<usize>,
|
||||||
|
/// Loading spec, retained for auto-recovery (#17) — see
|
||||||
|
/// [`LoadedModel::spec`].
|
||||||
|
pub spec: ModelSpec,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
@@ -792,6 +819,34 @@ fn poisoned_error(model_id: &str) -> InferenceError {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Reported while auto-recovery (#17) is rebuilding a poisoned model's
|
||||||
|
/// context. Unlike [`poisoned_error`] this is a *transient* state — the
|
||||||
|
/// model is being reloaded automatically; the client should retry.
|
||||||
|
fn recovering_error(model_id: &str) -> InferenceError {
|
||||||
|
InferenceError::Other(anyhow::anyhow!(
|
||||||
|
"model '{model_id}' is recovering (its device context was poisoned \
|
||||||
|
by an earlier failure and is being automatically rebuilt); retry \
|
||||||
|
shortly"
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Background auto-recovery task (#17). Drains poisoned model ids and
|
||||||
|
/// rebuilds each via [`CandleHarness::recover_one`]. Holds a `Weak` so a
|
||||||
|
/// shutting-down harness lets the task exit; processes one id at a time,
|
||||||
|
/// which (with the `recovering` set deduping enqueues) keeps recovery
|
||||||
|
/// single-flight per model.
|
||||||
|
async fn recovery_loop(
|
||||||
|
weak: std::sync::Weak<CandleHarness>,
|
||||||
|
mut rx: tokio::sync::mpsc::UnboundedReceiver<String>,
|
||||||
|
) {
|
||||||
|
while let Some(model_id) = rx.recv().await {
|
||||||
|
let Some(this) = weak.upgrade() else {
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
this.recover_one(&model_id).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Free/total VRAM on the candle `Device` in MiB. Returns `(0, 0)` if
|
/// Free/total VRAM on the candle `Device` in MiB. Returns `(0, 0)` if
|
||||||
/// the query fails or the device is the CPU fallback so logging never
|
/// the query fails or the device is the CPU fallback so logging never
|
||||||
/// crashes the request path. Mirrors the existing helper in
|
/// crashes the request path. Mirrors the existing helper in
|
||||||
@@ -1146,7 +1201,7 @@ impl CandleHarness {
|
|||||||
/// Construct a new harness for `bind_url` using `config`. Resolves
|
/// Construct a new harness for `bind_url` using `config`. Resolves
|
||||||
/// every configured source's auth env var and cache dir up front so
|
/// every configured source's auth env var and cache dir up front so
|
||||||
/// the hot load path (`hf_api_for`) is a pure HashMap lookup.
|
/// the hot load path (`hf_api_for`) is a pure HashMap lookup.
|
||||||
pub fn new(bind_url: String, config: &crate::config::CandleHarnessConfig) -> Self {
|
pub fn new(bind_url: String, config: &crate::config::CandleHarnessConfig) -> Arc<Self> {
|
||||||
let raw_sources = config.effective_sources();
|
let raw_sources = config.effective_sources();
|
||||||
let default_source = config.effective_default_source().to_string();
|
let default_source = config.effective_default_source().to_string();
|
||||||
let mut sources = HashMap::with_capacity(raw_sources.len());
|
let mut sources = HashMap::with_capacity(raw_sources.len());
|
||||||
@@ -1196,13 +1251,25 @@ impl CandleHarness {
|
|||||||
bare model ids will fail to resolve until this is fixed"
|
bare model ids will fail to resolve until this is fixed"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
Self {
|
let (recovery_tx, recovery_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||||||
|
let this = Arc::new(Self {
|
||||||
models: Arc::new(RwLock::new(HashMap::new())),
|
models: Arc::new(RwLock::new(HashMap::new())),
|
||||||
sources,
|
sources,
|
||||||
default_source,
|
default_source,
|
||||||
bind_url,
|
bind_url,
|
||||||
device_workers: Arc::new(RwLock::new(HashMap::new())),
|
device_workers: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
recovering: Arc::new(RwLock::new(std::collections::HashSet::new())),
|
||||||
|
recovery_tx,
|
||||||
|
});
|
||||||
|
// Background auto-recovery task (#17). Holds a `Weak` so it can't
|
||||||
|
// keep the harness alive. Spawned only when a tokio runtime is
|
||||||
|
// present — sync unit tests that build a harness without one
|
||||||
|
// simply skip it (they don't exercise recovery).
|
||||||
|
if tokio::runtime::Handle::try_current().is_ok() {
|
||||||
|
let weak = Arc::downgrade(&this);
|
||||||
|
tokio::spawn(recovery_loop(weak, recovery_rx));
|
||||||
}
|
}
|
||||||
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Scheme to substitute for bare `org/name` model ids. Mirrors the
|
/// Scheme to substitute for bare `org/name` model ids. Mirrors the
|
||||||
@@ -1627,7 +1694,17 @@ impl CandleHarness {
|
|||||||
let models = self.models.read().await;
|
let models = self.models.read().await;
|
||||||
models.get(&request.model).cloned()
|
models.get(&request.model).cloned()
|
||||||
};
|
};
|
||||||
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
let handle = match handle {
|
||||||
|
Some(h) => h,
|
||||||
|
// Absent from the registry: distinguish a genuinely unloaded
|
||||||
|
// model from one whose slot is briefly gone mid auto-recovery
|
||||||
|
// (#17), so the client gets a transient "retry shortly" instead
|
||||||
|
// of a misleading "not loaded".
|
||||||
|
None if self.is_recovering(&request.model).await => {
|
||||||
|
return Err(recovering_error(&request.model));
|
||||||
|
}
|
||||||
|
None => return Err(InferenceError::ModelNotLoaded(request.model.clone())),
|
||||||
|
};
|
||||||
// The match is technically infallible without `cuda` (only Single
|
// The match is technically infallible without `cuda` (only Single
|
||||||
// exists), but the cfg-gated Tp arm makes this the right shape
|
// exists), but the cfg-gated Tp arm makes this the right shape
|
||||||
// under both feature flags.
|
// under both feature flags.
|
||||||
@@ -1657,7 +1734,7 @@ impl CandleHarness {
|
|||||||
if loaded.poisoned.load(Ordering::Acquire) {
|
if loaded.poisoned.load(Ordering::Acquire) {
|
||||||
let _g = span.enter();
|
let _g = span.enter();
|
||||||
tracing::warn!("chat_completion: refusing request, model poisoned");
|
tracing::warn!("chat_completion: refusing request, model poisoned");
|
||||||
return Err(poisoned_error(&model_id));
|
return Err(self.trigger_recovery(&model_id).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serialise concurrent requests against this model. Holds for
|
// Serialise concurrent requests against this model. Holds for
|
||||||
@@ -2036,7 +2113,17 @@ impl CandleHarness {
|
|||||||
let models = self.models.read().await;
|
let models = self.models.read().await;
|
||||||
models.get(&request.model).cloned()
|
models.get(&request.model).cloned()
|
||||||
};
|
};
|
||||||
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
let handle = match handle {
|
||||||
|
Some(h) => h,
|
||||||
|
// Absent from the registry: distinguish a genuinely unloaded
|
||||||
|
// model from one whose slot is briefly gone mid auto-recovery
|
||||||
|
// (#17), so the client gets a transient "retry shortly" instead
|
||||||
|
// of a misleading "not loaded".
|
||||||
|
None if self.is_recovering(&request.model).await => {
|
||||||
|
return Err(recovering_error(&request.model));
|
||||||
|
}
|
||||||
|
None => return Err(InferenceError::ModelNotLoaded(request.model.clone())),
|
||||||
|
};
|
||||||
// The match is technically infallible without `cuda` (only Single
|
// The match is technically infallible without `cuda` (only Single
|
||||||
// exists), but the cfg-gated Tp arm makes this the right shape
|
// exists), but the cfg-gated Tp arm makes this the right shape
|
||||||
// under both feature flags.
|
// under both feature flags.
|
||||||
@@ -2129,7 +2216,7 @@ impl CandleHarness {
|
|||||||
// Refuse if the model is already poisoned. No point opening
|
// Refuse if the model is already poisoned. No point opening
|
||||||
// an SSE stream just to send the Start event and then bail.
|
// an SSE stream just to send the Start event and then bail.
|
||||||
if loaded.poisoned.load(Ordering::Acquire) {
|
if loaded.poisoned.load(Ordering::Acquire) {
|
||||||
return Err(poisoned_error(&model_id));
|
return Err(self.trigger_recovery(&model_id).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start event: tells the wire projector to emit its
|
// Start event: tells the wire projector to emit its
|
||||||
@@ -2347,6 +2434,69 @@ pub struct InferenceStream {
|
|||||||
pub reasoning_markers: Option<ReasoningTokenPair>,
|
pub reasoning_markers: Option<ReasoningTokenPair>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Auto-recovery (#17) — rebuild a poisoned model's device context
|
||||||
|
/// automatically instead of leaving it bricked until a human reloads.
|
||||||
|
impl CandleHarness {
|
||||||
|
/// True while `model_id` is being auto-recovered (its slot is briefly
|
||||||
|
/// absent from the registry during the reload).
|
||||||
|
pub async fn is_recovering(&self, model_id: &str) -> bool {
|
||||||
|
self.recovering.read().await.contains(model_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Single-flight trigger from the request path: enqueue a rebuild for a
|
||||||
|
/// poisoned model (only the first caller per model enqueues) and return
|
||||||
|
/// the transient "recovering" error to hand back to the client.
|
||||||
|
async fn trigger_recovery(&self, model_id: &str) -> InferenceError {
|
||||||
|
let newly = self.recovering.write().await.insert(model_id.to_string());
|
||||||
|
if newly {
|
||||||
|
tracing::warn!(model = %model_id, "auto-recovery: poisoned, enqueueing rebuild");
|
||||||
|
if self.recovery_tx.send(model_id.to_string()).is_err() {
|
||||||
|
// Background task gone (harness shutting down). Drop the
|
||||||
|
// marker and fall back to the manual-reload message.
|
||||||
|
self.recovering.write().await.remove(model_id);
|
||||||
|
tracing::error!(model = %model_id, "auto-recovery: task unavailable");
|
||||||
|
return poisoned_error(model_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
recovering_error(model_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rebuild a poisoned model: `unload_model` (drops it → cudarc aborts
|
||||||
|
/// NCCL + releases the context) then `load_model` from the retained
|
||||||
|
/// spec. A successful reload re-runs NCCL init + sanity inside the load
|
||||||
|
/// path, so it returns a fresh, healthy model; a failed reload leaves
|
||||||
|
/// the model unloaded (recoverable by the next load), never poisoned
|
||||||
|
/// forever. Runs on the background task — never inline on the request
|
||||||
|
/// path (would deadlock on the `models` write lock).
|
||||||
|
async fn recover_one(&self, model_id: &str) {
|
||||||
|
let spec = {
|
||||||
|
let models = self.models.read().await;
|
||||||
|
models.get(model_id).map(|h| h.spec().clone())
|
||||||
|
};
|
||||||
|
let Some(spec) = spec else {
|
||||||
|
self.recovering.write().await.remove(model_id);
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
tracing::warn!(model = %model_id, "auto-recovery: unload+reload starting");
|
||||||
|
if let Err(e) = self.unload_model(model_id).await {
|
||||||
|
tracing::error!(
|
||||||
|
model = %model_id,
|
||||||
|
error = %format!("{e:#}"),
|
||||||
|
"auto-recovery: unload failed (continuing to reload)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
match self.load_model(&spec).await {
|
||||||
|
Ok(()) => tracing::info!(model = %model_id, "auto-recovery: reloaded; model healthy"),
|
||||||
|
Err(e) => tracing::error!(
|
||||||
|
model = %model_id,
|
||||||
|
error = %format!("{e:#}"),
|
||||||
|
"auto-recovery: reload failed; model left unloaded"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
self.recovering.write().await.remove(model_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Harness for CandleHarness {
|
impl Harness for CandleHarness {
|
||||||
fn name(&self) -> &str {
|
fn name(&self) -> &str {
|
||||||
@@ -2550,6 +2700,7 @@ impl Harness for CandleHarness {
|
|||||||
has_vision: vision_meta.has_vision,
|
has_vision: vision_meta.has_vision,
|
||||||
image_token_id: vision_meta.image_token_id,
|
image_token_id: vision_meta.image_token_id,
|
||||||
image_grid_factor: vision_meta.image_grid_factor,
|
image_grid_factor: vision_meta.image_grid_factor,
|
||||||
|
spec: spec.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut models = self.models.write().await;
|
let mut models = self.models.write().await;
|
||||||
@@ -2788,6 +2939,7 @@ impl CandleHarness {
|
|||||||
has_vision: vision_meta.has_vision,
|
has_vision: vision_meta.has_vision,
|
||||||
image_token_id: vision_meta.image_token_id,
|
image_token_id: vision_meta.image_token_id,
|
||||||
image_grid_factor: vision_meta.image_grid_factor,
|
image_grid_factor: vision_meta.image_grid_factor,
|
||||||
|
spec: spec.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut models = self.models.write().await;
|
let mut models = self.models.write().await;
|
||||||
@@ -2834,7 +2986,7 @@ impl CandleHarness {
|
|||||||
if tp.poisoned.load(Ordering::Acquire) {
|
if tp.poisoned.load(Ordering::Acquire) {
|
||||||
let _g = span.enter();
|
let _g = span.enter();
|
||||||
tracing::warn!("TP chat_completion: refusing request, model poisoned");
|
tracing::warn!("TP chat_completion: refusing request, model poisoned");
|
||||||
return Err(poisoned_error(&model_id));
|
return Err(self.trigger_recovery(&model_id).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject image-bearing requests against a TP model with no
|
// Reject image-bearing requests against a TP model with no
|
||||||
@@ -2923,7 +3075,7 @@ impl CandleHarness {
|
|||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> Result<InferenceStream, InferenceError> {
|
) -> Result<InferenceStream, InferenceError> {
|
||||||
if tp.poisoned.load(Ordering::Acquire) {
|
if tp.poisoned.load(Ordering::Acquire) {
|
||||||
return Err(poisoned_error(&request.model));
|
return Err(self.trigger_recovery(&request.model).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject image requests against a non-vision TP model before
|
// Reject image requests against a non-vision TP model before
|
||||||
|
|||||||
@@ -114,10 +114,8 @@ impl HarnessRegistry {
|
|||||||
for config in configs {
|
for config in configs {
|
||||||
match config.name.as_str() {
|
match config.name.as_str() {
|
||||||
"candle" => {
|
"candle" => {
|
||||||
let harness = Arc::new(candle::CandleHarness::new(
|
let harness =
|
||||||
bind_url.to_string(),
|
candle::CandleHarness::new(bind_url.to_string(), &settings.candle);
|
||||||
&settings.candle,
|
|
||||||
));
|
|
||||||
registry.candle = Some(Arc::clone(&harness));
|
registry.candle = Some(Arc::clone(&harness));
|
||||||
registry.harnesses.insert("candle".into(), harness);
|
registry.harnesses.insert("candle".into(), harness);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user