feat(neuron,candle): log per-device VRAM at chat_completion start

Every "starting" log line now carries vram_free_mb / vram_total_mb for
the request's serving device (the leader device on TP). On the 2026-05-26
incident this would have made the 14k-token prefill OOM diagnosable from
the first log line: with ~412 MB free, that prompt was never going to
fit, and the operator could have caught the imbalance before the CUDA
context got poisoned.

`device_vram_mb` mirrors the existing helper in tp_qwen3_5.rs and is
kept separate to avoid coupling the inference path to the TP module.
TpLoadedModel gains a `leader_device: Device` clone so the request
path reads the device without locking the leader model (which would
contend with an in-flight forward).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-26 12:26:23 +03:00
parent 0a1cfcd4d0
commit 1385979e3d

View File

@@ -105,6 +105,11 @@ pub struct TpLoadedModel {
/// story.
pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
pub leader_model: Arc<tokio::sync::Mutex<super::tp::TpLeaderModel>>,
/// Candle device for rank 0. Mirrors what `leader_model.device()`
/// would return, but stored separately so the request path can
/// query VRAM without locking the leader (which would contend with
/// the in-flight forward).
pub leader_device: Device,
}
/// Architecture-specific weights. Each variant covers one (family,
@@ -354,6 +359,36 @@ fn resolve_hf_cache(explicit: Option<PathBuf>) -> Option<PathBuf> {
None
}
/// Free/total VRAM on the candle `Device` in MiB. Returns `(0, 0)` if
/// the query fails or the device is the CPU fallback so logging never
/// crashes the request path. Mirrors the existing helper in
/// `tp_qwen3_5.rs`; kept separate to avoid coupling the inference path
/// to the TP-specific module.
#[cfg(feature = "cuda")]
fn device_vram_mb(device: &Device) -> (u64, u64) {
use candle_core::cuda::cudarc::driver::result;
use candle_core::cuda_backend::WrapErr;
let Device::Cuda(dev) = device else {
return (0, 0);
};
let Ok(()) = dev.cuda_stream().context().bind_to_thread().w() else {
return (0, 0);
};
match result::mem_get_info() {
Ok((free, total)) => (
(free / (1024 * 1024)) as u64,
(total / (1024 * 1024)) as u64,
),
Err(_) => (0, 0),
}
}
#[cfg(not(feature = "cuda"))]
#[allow(dead_code)]
fn device_vram_mb(_device: &Device) -> (u64, u64) {
(0, 0)
}
/// A short hex tag used to group every log line emitted on behalf of
/// one chat-completion request. Six hex digits is unique enough across
/// a 4-hour journal window (24 bits ≈ 16M values, while a busy neuron
@@ -838,12 +873,15 @@ impl CandleHarness {
.token_to_id("<|im_end|>")
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
let (vram_free_mb, vram_total_mb) = device_vram_mb(&loaded.device);
tracing::info!(
prompt_len,
max_new,
temperature,
?top_p,
?eos_id,
vram_free_mb,
vram_total_mb,
"chat_completion: starting"
);
@@ -1018,12 +1056,15 @@ impl CandleHarness {
let span_for_task = span.clone();
{
let _g = span_for_starting.enter();
let (vram_free_mb, vram_total_mb) = device_vram_mb(&loaded.device);
tracing::info!(
prompt_len,
max_new,
temperature,
?top_p,
?eos_id,
vram_free_mb,
vram_total_mb,
"chat_completion (stream): starting"
);
}
@@ -1290,6 +1331,7 @@ impl CandleHarness {
devices: devices.clone(),
pool: TMutex::new(pool),
leader_model,
leader_device: leader_device.clone(),
});
let mut models = self.models.write().await;
@@ -1428,6 +1470,7 @@ impl CandleHarness {
model = %model_id
);
let req_start = std::time::Instant::now();
let (vram_free_mb, vram_total_mb) = device_vram_mb(&tp.leader_device);
tracing::info!(
parent: &span,
prompt_len,
@@ -1435,6 +1478,8 @@ impl CandleHarness {
temperature,
?top_p,
?eos_id,
vram_free_mb,
vram_total_mb,
"TP chat_completion (stream): starting"
);
let tp_for_task = Arc::clone(&tp);
@@ -1641,6 +1686,7 @@ async fn chat_completion_tp_inner(
.token_to_id("<|im_end|>")
.or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
let (vram_free_mb, vram_total_mb) = device_vram_mb(&tp.leader_device);
tracing::info!(
model = %model_id,
prompt_len,
@@ -1648,6 +1694,8 @@ async fn chat_completion_tp_inner(
temperature,
?top_p,
?eos_id,
vram_free_mb,
vram_total_mb,
"TP chat_completion: starting"
);