diff --git a/crates/neuron/src/harness/tp/tp_qwen3_5.rs b/crates/neuron/src/harness/tp/tp_qwen3_5.rs index 90826a0..0ec4942 100644 --- a/crates/neuron/src/harness/tp/tp_qwen3_5.rs +++ b/crates/neuron/src/harness/tp/tp_qwen3_5.rs @@ -782,8 +782,9 @@ impl TpQwen3_5Model { let vb_l = text_vb.pp("layers"); let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + log_vram(&device, rank, "before layer 0"); for i in 0..cfg.num_hidden_layers { - layers.push(TpQwen3_5DecoderLayer::load( + let layer = TpQwen3_5DecoderLayer::load( cfg, rotary.clone(), i, @@ -791,7 +792,13 @@ impl TpQwen3_5Model { rank, world_size, comm.clone(), - )?); + ) + .with_context(|| { + let (free_mb, total_mb) = cuda_mem_mb(&device); + format!("load layer {i} (rank {rank}): free={free_mb}MB / total={total_mb}MB") + })?; + layers.push(layer); + log_vram(&device, rank, &format!("after layer {i}")); } let norm = Qwen3_5RmsNorm::load(&text_vb.pp("norm"), cfg.hidden_size, cfg.rms_norm_eps)?; @@ -1053,3 +1060,48 @@ fn load_fused_qkv_slice_3d( .contiguous() .with_context(|| format!("materialise fused conv slice for rank {r}")) } + +/// Query the cuda driver for free/total VRAM on the current device. +/// Returns `(free_mb, total_mb)`. Returns `(0, 0)` if the query fails +/// (so logging never crashes the load path). +#[cfg(feature = "cuda")] +fn cuda_mem_mb(device: &Device) -> (usize, usize) { + use candle_core::cuda::cudarc::driver::result; + use candle_core::cuda_backend::WrapErr; + let Device::Cuda(dev) = device else { + return (0, 0); + }; + let Ok(()) = dev.cuda_stream().context().bind_to_thread().w() else { + return (0, 0); + }; + match result::mem_get_info() { + Ok((free, total)) => (free / (1024 * 1024), total / (1024 * 1024)), + Err(_) => (0, 0), + } +} + +#[cfg(not(feature = "cuda"))] +#[allow(dead_code)] +fn cuda_mem_mb(_device: &Device) -> (usize, usize) { + (0, 0) +} + +/// Info-log the current device's free VRAM with a tag. No-op when the +/// query fails or on cpu. +#[cfg(feature = "cuda")] +fn log_vram(device: &Device, rank: u32, tag: &str) { + let (free_mb, total_mb) = cuda_mem_mb(device); + if total_mb > 0 { + tracing::info!( + target: "neuron::tp::load", + rank, + free_mb, + total_mb, + "{tag}" + ); + } +} + +#[cfg(not(feature = "cuda"))] +#[allow(dead_code)] +fn log_vram(_device: &Device, _rank: u32, _tag: &str) {}