feat(neuron): construction-complete vram/config dump + logits health + per-step vram
All checks were successful
CI / Format (push) Successful in 40s
build-prerelease / Resolve version stamps (push) Successful in 45s
CI / Clippy (push) Successful in 2m27s
build-prerelease / Build cortex binary (push) Successful in 4m24s
build-prerelease / Build neuron-blackwell (push) Successful in 4m0s
build-prerelease / Package cortex RPM (push) Successful in 1m18s
build-prerelease / Build neuron-ampere (push) Successful in 5m10s
build-prerelease / Build neuron-ada (push) Successful in 4m56s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m1s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m57s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m47s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
CI / Test (push) Successful in 4m24s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped

Three additive diagnostics that turn the 2026-05-27 q5k Qwen3.6-27B
incident from "guess at KV cache / quant sizes" into "read the
journal":

1. Construction-complete summary in TpQwen3_5ForCausalLM::load and
   TpQwen3ForCausalLM::load. After the last "after layer N" log fires,
   each rank emits a single info line with: free_mb/total_mb (the
   number that drops by ~9 GB between per-layer and first-request on
   beast, with no inference traffic), every resolved config knob
   (vocab_size, hidden_size, num_layers, head_dim, num_kv_heads,
   max_position_embeddings), and a per-token KV-cache byte estimate.
   For Qwen3-Next also includes the linear/full-attention layer split
   so the hybrid architecture's cache cost is unambiguous.

2. Logits health snapshot on sample failure. Today the failure logs
   "A weight is negative, too large or not a valid number" with no
   context — was it a NaN cascade, an Inf, a negative weight?
   `logits_health(&logits)` computes nan/pos_inf/neg_inf/neg counts
   plus finite_min/max/mean on the failure path (zero cost on the
   success path) and emits a warn line just before the wrapper's
   terminal "failed, model marked poisoned" log. Wired into both the
   prefill and decode sample sites of the non-streaming AND streaming
   TP chat paths.

3. VRAM snapshot at prefill complete + every decode step. The
   "prefill complete" info line now carries vram_free_mb so the
   activations + KV growth from the prefill itself is visible. The
   per-step trace line gets vram_free_mb too, so an operator running
   with RUST_LOG=trace can watch headroom shrink token by token.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-27 09:04:55 +03:00
parent 24e20dcb5c
commit 7c19da9361
3 changed files with 304 additions and 8 deletions

View File

@@ -386,6 +386,95 @@ fn resolve_hf_cache(explicit: Option<PathBuf>) -> Option<PathBuf> {
None
}
/// Summary stats over a 1-D logits tensor, used for the failure log
/// when sampling rejects the distribution. Gathers nan/inf/negative
/// counts and finite min/max/mean — enough to distinguish a NaN
/// cascade (all-NaN, typical of softmax overflow propagating) from
/// an Inf at a single position (numerical edge case) from negative
/// weights (different bug entirely).
///
/// Computed only on the failure path, so the to_vec1 copy cost is
/// paid at most once per poisoned model.
#[derive(Debug)]
#[allow(dead_code)]
struct LogitsHealth {
len: usize,
nan: usize,
pos_inf: usize,
neg_inf: usize,
neg: usize,
finite_min: Option<f32>,
finite_max: Option<f32>,
finite_mean: Option<f32>,
}
#[allow(dead_code)]
fn logits_health(t: &Tensor) -> LogitsHealth {
let values: Vec<f32> = match t
.to_dtype(candle_core::DType::F32)
.and_then(|t| t.flatten_all())
.and_then(|t| t.to_vec1::<f32>())
{
Ok(v) => v,
Err(_) => {
return LogitsHealth {
len: 0,
nan: 0,
pos_inf: 0,
neg_inf: 0,
neg: 0,
finite_min: None,
finite_max: None,
finite_mean: None,
};
}
};
let mut nan = 0usize;
let mut pos_inf = 0usize;
let mut neg_inf = 0usize;
let mut neg = 0usize;
let mut finite_min = f32::INFINITY;
let mut finite_max = f32::NEG_INFINITY;
let mut finite_sum = 0.0_f64;
let mut finite_count = 0usize;
for &v in &values {
if v.is_nan() {
nan += 1;
} else if v == f32::INFINITY {
pos_inf += 1;
} else if v == f32::NEG_INFINITY {
neg_inf += 1;
} else {
if v < 0.0 {
neg += 1;
}
if v < finite_min {
finite_min = v;
}
if v > finite_max {
finite_max = v;
}
finite_sum += v as f64;
finite_count += 1;
}
}
let finite_mean = if finite_count > 0 {
Some((finite_sum / finite_count as f64) as f32)
} else {
None
};
LogitsHealth {
len: values.len(),
nan,
pos_inf,
neg_inf,
neg,
finite_min: (finite_count > 0).then_some(finite_min),
finite_max: (finite_count > 0).then_some(finite_max),
finite_mean,
}
}
/// Build the InferenceError reported to a client when their request
/// hits a model that's been marked poisoned by an earlier driver
/// failure. The message names the model and the recovery procedure so
@@ -1624,10 +1713,24 @@ impl CandleHarness {
break 'work;
}
};
let (post_prefill_vram_free_mb, _) =
device_vram_mb(&tp_for_task.leader_device);
tracing::info!(
model = %model_id,
prompt_len,
vram_free_mb = post_prefill_vram_free_mb,
"TP chat_completion (stream): prefill complete"
);
let mut next_token =
match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health(&logits);
tracing::warn!(
model = %model_id,
?health,
"TP chat_completion (stream): prefill sample failed; logits unhealthy"
);
failure = Some(format!("prefill sample: {e:#}"));
break 'work;
}
@@ -1676,10 +1779,24 @@ impl CandleHarness {
) {
Ok(t) => t,
Err(e) => {
let health = logits_health(&logits);
tracing::warn!(
model = %model_id,
step = index,
?health,
"TP chat_completion (stream): decode sample failed; logits unhealthy"
);
failure = Some(format!("decode sample {index}: {e:#}"));
break 'work;
}
};
tracing::trace!(
model = %model_id,
step = index,
next_token,
vram_free_mb = device_vram_mb(&tp_for_task.leader_device).0,
"TP chat_completion (stream): decode step"
);
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
@@ -1845,14 +1962,31 @@ async fn chat_completion_tp_inner(
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
.await
.map_err(InferenceError::Other)?;
let (post_prefill_vram_free_mb, _) = device_vram_mb(&tp.leader_device);
tracing::info!(
model = %model_id,
prompt_len,
elapsed_ms = prefill_start.elapsed().as_millis(),
vram_free_mb = post_prefill_vram_free_mb,
"TP chat_completion: prefill complete"
);
let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
.map_err(InferenceError::Other)?;
let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
// Logits health snapshot — the surrounding wrapper logs
// "failed, model marked poisoned" with the error chain;
// this WARN sits just above that and carries the actual
// numerical state so an operator can tell at a glance
// whether it was a NaN cascade, an Inf, or something else.
let health = logits_health(&logits);
tracing::warn!(
model = %model_id,
?health,
"TP chat_completion: prefill sample failed; logits unhealthy"
);
return Err(InferenceError::Other(e));
}
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
@@ -1870,13 +2004,26 @@ async fn chat_completion_tp_inner(
)
.await
.map_err(InferenceError::Other)?;
next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
.map_err(InferenceError::Other)?;
next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health(&logits);
tracing::warn!(
model = %model_id,
step = index,
?health,
"TP chat_completion: decode sample failed; logits unhealthy"
);
return Err(InferenceError::Other(e));
}
};
let step_vram_free_mb = device_vram_mb(&tp.leader_device).0;
tracing::trace!(
model = %model_id,
step = index,
next_token,
step_ms = step_start.elapsed().as_millis(),
vram_free_mb = step_vram_free_mb,
"TP chat_completion: decode step"
);
if Some(next_token) == eos_id {