From 1385979e3d1ed724863fc46eb4df76750a15abbc Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Tue, 26 May 2026 12:26:23 +0300 Subject: [PATCH] feat(neuron,candle): log per-device VRAM at chat_completion start Every "starting" log line now carries vram_free_mb / vram_total_mb for the request's serving device (the leader device on TP). On the 2026-05-26 incident this would have made the 14k-token prefill OOM diagnosable from the first log line: with ~412 MB free, that prompt was never going to fit, and the operator could have caught the imbalance before the CUDA context got poisoned. `device_vram_mb` mirrors the existing helper in tp_qwen3_5.rs and is kept separate to avoid coupling the inference path to the TP module. TpLoadedModel gains a `leader_device: Device` clone so the request path reads the device without locking the leader model (which would contend with an in-flight forward). Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/neuron/src/harness/candle.rs | 48 +++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 654cf47..357a7ca 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -105,6 +105,11 @@ pub struct TpLoadedModel { /// story. pub pool: tokio::sync::Mutex, pub leader_model: Arc>, + /// Candle device for rank 0. Mirrors what `leader_model.device()` + /// would return, but stored separately so the request path can + /// query VRAM without locking the leader (which would contend with + /// the in-flight forward). + pub leader_device: Device, } /// Architecture-specific weights. Each variant covers one (family, @@ -354,6 +359,36 @@ fn resolve_hf_cache(explicit: Option) -> Option { None } +/// Free/total VRAM on the candle `Device` in MiB. Returns `(0, 0)` if +/// the query fails or the device is the CPU fallback so logging never +/// crashes the request path. Mirrors the existing helper in +/// `tp_qwen3_5.rs`; kept separate to avoid coupling the inference path +/// to the TP-specific module. +#[cfg(feature = "cuda")] +fn device_vram_mb(device: &Device) -> (u64, u64) { + use candle_core::cuda::cudarc::driver::result; + use candle_core::cuda_backend::WrapErr; + let Device::Cuda(dev) = device else { + return (0, 0); + }; + let Ok(()) = dev.cuda_stream().context().bind_to_thread().w() else { + return (0, 0); + }; + match result::mem_get_info() { + Ok((free, total)) => ( + (free / (1024 * 1024)) as u64, + (total / (1024 * 1024)) as u64, + ), + Err(_) => (0, 0), + } +} + +#[cfg(not(feature = "cuda"))] +#[allow(dead_code)] +fn device_vram_mb(_device: &Device) -> (u64, u64) { + (0, 0) +} + /// A short hex tag used to group every log line emitted on behalf of /// one chat-completion request. Six hex digits is unique enough across /// a 4-hour journal window (24 bits ≈ 16M values, while a busy neuron @@ -838,12 +873,15 @@ impl CandleHarness { .token_to_id("<|im_end|>") .or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>")); + let (vram_free_mb, vram_total_mb) = device_vram_mb(&loaded.device); tracing::info!( prompt_len, max_new, temperature, ?top_p, ?eos_id, + vram_free_mb, + vram_total_mb, "chat_completion: starting" ); @@ -1018,12 +1056,15 @@ impl CandleHarness { let span_for_task = span.clone(); { let _g = span_for_starting.enter(); + let (vram_free_mb, vram_total_mb) = device_vram_mb(&loaded.device); tracing::info!( prompt_len, max_new, temperature, ?top_p, ?eos_id, + vram_free_mb, + vram_total_mb, "chat_completion (stream): starting" ); } @@ -1290,6 +1331,7 @@ impl CandleHarness { devices: devices.clone(), pool: TMutex::new(pool), leader_model, + leader_device: leader_device.clone(), }); let mut models = self.models.write().await; @@ -1428,6 +1470,7 @@ impl CandleHarness { model = %model_id ); let req_start = std::time::Instant::now(); + let (vram_free_mb, vram_total_mb) = device_vram_mb(&tp.leader_device); tracing::info!( parent: &span, prompt_len, @@ -1435,6 +1478,8 @@ impl CandleHarness { temperature, ?top_p, ?eos_id, + vram_free_mb, + vram_total_mb, "TP chat_completion (stream): starting" ); let tp_for_task = Arc::clone(&tp); @@ -1641,6 +1686,7 @@ async fn chat_completion_tp_inner( .token_to_id("<|im_end|>") .or_else(|| tp.tokenizer.token_to_id("<|endoftext|>")); + let (vram_free_mb, vram_total_mb) = device_vram_mb(&tp.leader_device); tracing::info!( model = %model_id, prompt_len, @@ -1648,6 +1694,8 @@ async fn chat_completion_tp_inner( temperature, ?top_p, ?eos_id, + vram_free_mb, + vram_total_mb, "TP chat_completion: starting" );