feat(neuron): construction-complete vram/config dump + logits health + per-step vram
All checks were successful
CI / Format (push) Successful in 40s
build-prerelease / Resolve version stamps (push) Successful in 45s
CI / Clippy (push) Successful in 2m27s
build-prerelease / Build cortex binary (push) Successful in 4m24s
build-prerelease / Build neuron-blackwell (push) Successful in 4m0s
build-prerelease / Package cortex RPM (push) Successful in 1m18s
build-prerelease / Build neuron-ampere (push) Successful in 5m10s
build-prerelease / Build neuron-ada (push) Successful in 4m56s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m1s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m57s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m47s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
CI / Test (push) Successful in 4m24s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
All checks were successful
CI / Format (push) Successful in 40s
build-prerelease / Resolve version stamps (push) Successful in 45s
CI / Clippy (push) Successful in 2m27s
build-prerelease / Build cortex binary (push) Successful in 4m24s
build-prerelease / Build neuron-blackwell (push) Successful in 4m0s
build-prerelease / Package cortex RPM (push) Successful in 1m18s
build-prerelease / Build neuron-ampere (push) Successful in 5m10s
build-prerelease / Build neuron-ada (push) Successful in 4m56s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m1s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m57s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m47s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
CI / Test (push) Successful in 4m24s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
Three additive diagnostics that turn the 2026-05-27 q5k Qwen3.6-27B incident from "guess at KV cache / quant sizes" into "read the journal": 1. Construction-complete summary in TpQwen3_5ForCausalLM::load and TpQwen3ForCausalLM::load. After the last "after layer N" log fires, each rank emits a single info line with: free_mb/total_mb (the number that drops by ~9 GB between per-layer and first-request on beast, with no inference traffic), every resolved config knob (vocab_size, hidden_size, num_layers, head_dim, num_kv_heads, max_position_embeddings), and a per-token KV-cache byte estimate. For Qwen3-Next also includes the linear/full-attention layer split so the hybrid architecture's cache cost is unambiguous. 2. Logits health snapshot on sample failure. Today the failure logs "A weight is negative, too large or not a valid number" with no context — was it a NaN cascade, an Inf, a negative weight? `logits_health(&logits)` computes nan/pos_inf/neg_inf/neg counts plus finite_min/max/mean on the failure path (zero cost on the success path) and emits a warn line just before the wrapper's terminal "failed, model marked poisoned" log. Wired into both the prefill and decode sample sites of the non-streaming AND streaming TP chat paths. 3. VRAM snapshot at prefill complete + every decode step. The "prefill complete" info line now carries vram_free_mb so the activations + KV growth from the prefill itself is visible. The per-step trace line gets vram_free_mb too, so an operator running with RUST_LOG=trace can watch headroom shrink token by token. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -386,6 +386,95 @@ fn resolve_hf_cache(explicit: Option<PathBuf>) -> Option<PathBuf> {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Summary stats over a 1-D logits tensor, used for the failure log
|
||||||
|
/// when sampling rejects the distribution. Gathers nan/inf/negative
|
||||||
|
/// counts and finite min/max/mean — enough to distinguish a NaN
|
||||||
|
/// cascade (all-NaN, typical of softmax overflow propagating) from
|
||||||
|
/// an Inf at a single position (numerical edge case) from negative
|
||||||
|
/// weights (different bug entirely).
|
||||||
|
///
|
||||||
|
/// Computed only on the failure path, so the to_vec1 copy cost is
|
||||||
|
/// paid at most once per poisoned model.
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
struct LogitsHealth {
|
||||||
|
len: usize,
|
||||||
|
nan: usize,
|
||||||
|
pos_inf: usize,
|
||||||
|
neg_inf: usize,
|
||||||
|
neg: usize,
|
||||||
|
finite_min: Option<f32>,
|
||||||
|
finite_max: Option<f32>,
|
||||||
|
finite_mean: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn logits_health(t: &Tensor) -> LogitsHealth {
|
||||||
|
let values: Vec<f32> = match t
|
||||||
|
.to_dtype(candle_core::DType::F32)
|
||||||
|
.and_then(|t| t.flatten_all())
|
||||||
|
.and_then(|t| t.to_vec1::<f32>())
|
||||||
|
{
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => {
|
||||||
|
return LogitsHealth {
|
||||||
|
len: 0,
|
||||||
|
nan: 0,
|
||||||
|
pos_inf: 0,
|
||||||
|
neg_inf: 0,
|
||||||
|
neg: 0,
|
||||||
|
finite_min: None,
|
||||||
|
finite_max: None,
|
||||||
|
finite_mean: None,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let mut nan = 0usize;
|
||||||
|
let mut pos_inf = 0usize;
|
||||||
|
let mut neg_inf = 0usize;
|
||||||
|
let mut neg = 0usize;
|
||||||
|
let mut finite_min = f32::INFINITY;
|
||||||
|
let mut finite_max = f32::NEG_INFINITY;
|
||||||
|
let mut finite_sum = 0.0_f64;
|
||||||
|
let mut finite_count = 0usize;
|
||||||
|
for &v in &values {
|
||||||
|
if v.is_nan() {
|
||||||
|
nan += 1;
|
||||||
|
} else if v == f32::INFINITY {
|
||||||
|
pos_inf += 1;
|
||||||
|
} else if v == f32::NEG_INFINITY {
|
||||||
|
neg_inf += 1;
|
||||||
|
} else {
|
||||||
|
if v < 0.0 {
|
||||||
|
neg += 1;
|
||||||
|
}
|
||||||
|
if v < finite_min {
|
||||||
|
finite_min = v;
|
||||||
|
}
|
||||||
|
if v > finite_max {
|
||||||
|
finite_max = v;
|
||||||
|
}
|
||||||
|
finite_sum += v as f64;
|
||||||
|
finite_count += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let finite_mean = if finite_count > 0 {
|
||||||
|
Some((finite_sum / finite_count as f64) as f32)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
LogitsHealth {
|
||||||
|
len: values.len(),
|
||||||
|
nan,
|
||||||
|
pos_inf,
|
||||||
|
neg_inf,
|
||||||
|
neg,
|
||||||
|
finite_min: (finite_count > 0).then_some(finite_min),
|
||||||
|
finite_max: (finite_count > 0).then_some(finite_max),
|
||||||
|
finite_mean,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Build the InferenceError reported to a client when their request
|
/// Build the InferenceError reported to a client when their request
|
||||||
/// hits a model that's been marked poisoned by an earlier driver
|
/// hits a model that's been marked poisoned by an earlier driver
|
||||||
/// failure. The message names the model and the recovery procedure so
|
/// failure. The message names the model and the recovery procedure so
|
||||||
@@ -1624,10 +1713,24 @@ impl CandleHarness {
|
|||||||
break 'work;
|
break 'work;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
let (post_prefill_vram_free_mb, _) =
|
||||||
|
device_vram_mb(&tp_for_task.leader_device);
|
||||||
|
tracing::info!(
|
||||||
|
model = %model_id,
|
||||||
|
prompt_len,
|
||||||
|
vram_free_mb = post_prefill_vram_free_mb,
|
||||||
|
"TP chat_completion (stream): prefill complete"
|
||||||
|
);
|
||||||
let mut next_token =
|
let mut next_token =
|
||||||
match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
|
match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
|
||||||
Ok(t) => t,
|
Ok(t) => t,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
let health = logits_health(&logits);
|
||||||
|
tracing::warn!(
|
||||||
|
model = %model_id,
|
||||||
|
?health,
|
||||||
|
"TP chat_completion (stream): prefill sample failed; logits unhealthy"
|
||||||
|
);
|
||||||
failure = Some(format!("prefill sample: {e:#}"));
|
failure = Some(format!("prefill sample: {e:#}"));
|
||||||
break 'work;
|
break 'work;
|
||||||
}
|
}
|
||||||
@@ -1676,10 +1779,24 @@ impl CandleHarness {
|
|||||||
) {
|
) {
|
||||||
Ok(t) => t,
|
Ok(t) => t,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
let health = logits_health(&logits);
|
||||||
|
tracing::warn!(
|
||||||
|
model = %model_id,
|
||||||
|
step = index,
|
||||||
|
?health,
|
||||||
|
"TP chat_completion (stream): decode sample failed; logits unhealthy"
|
||||||
|
);
|
||||||
failure = Some(format!("decode sample {index}: {e:#}"));
|
failure = Some(format!("decode sample {index}: {e:#}"));
|
||||||
break 'work;
|
break 'work;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
tracing::trace!(
|
||||||
|
model = %model_id,
|
||||||
|
step = index,
|
||||||
|
next_token,
|
||||||
|
vram_free_mb = device_vram_mb(&tp_for_task.leader_device).0,
|
||||||
|
"TP chat_completion (stream): decode step"
|
||||||
|
);
|
||||||
if Some(next_token) == eos_id {
|
if Some(next_token) == eos_id {
|
||||||
finish_reason = "stop".into();
|
finish_reason = "stop".into();
|
||||||
break;
|
break;
|
||||||
@@ -1845,14 +1962,31 @@ async fn chat_completion_tp_inner(
|
|||||||
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
|
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
|
||||||
.await
|
.await
|
||||||
.map_err(InferenceError::Other)?;
|
.map_err(InferenceError::Other)?;
|
||||||
|
let (post_prefill_vram_free_mb, _) = device_vram_mb(&tp.leader_device);
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
model = %model_id,
|
model = %model_id,
|
||||||
prompt_len,
|
prompt_len,
|
||||||
elapsed_ms = prefill_start.elapsed().as_millis(),
|
elapsed_ms = prefill_start.elapsed().as_millis(),
|
||||||
|
vram_free_mb = post_prefill_vram_free_mb,
|
||||||
"TP chat_completion: prefill complete"
|
"TP chat_completion: prefill complete"
|
||||||
);
|
);
|
||||||
let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
|
let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
|
||||||
.map_err(InferenceError::Other)?;
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
// Logits health snapshot — the surrounding wrapper logs
|
||||||
|
// "failed, model marked poisoned" with the error chain;
|
||||||
|
// this WARN sits just above that and carries the actual
|
||||||
|
// numerical state so an operator can tell at a glance
|
||||||
|
// whether it was a NaN cascade, an Inf, or something else.
|
||||||
|
let health = logits_health(&logits);
|
||||||
|
tracing::warn!(
|
||||||
|
model = %model_id,
|
||||||
|
?health,
|
||||||
|
"TP chat_completion: prefill sample failed; logits unhealthy"
|
||||||
|
);
|
||||||
|
return Err(InferenceError::Other(e));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if Some(next_token) == eos_id {
|
if Some(next_token) == eos_id {
|
||||||
finish_reason = "stop".into();
|
finish_reason = "stop".into();
|
||||||
@@ -1870,13 +2004,26 @@ async fn chat_completion_tp_inner(
|
|||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(InferenceError::Other)?;
|
.map_err(InferenceError::Other)?;
|
||||||
next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
|
next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
|
||||||
.map_err(InferenceError::Other)?;
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
let health = logits_health(&logits);
|
||||||
|
tracing::warn!(
|
||||||
|
model = %model_id,
|
||||||
|
step = index,
|
||||||
|
?health,
|
||||||
|
"TP chat_completion: decode sample failed; logits unhealthy"
|
||||||
|
);
|
||||||
|
return Err(InferenceError::Other(e));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let step_vram_free_mb = device_vram_mb(&tp.leader_device).0;
|
||||||
tracing::trace!(
|
tracing::trace!(
|
||||||
model = %model_id,
|
model = %model_id,
|
||||||
step = index,
|
step = index,
|
||||||
next_token,
|
next_token,
|
||||||
step_ms = step_start.elapsed().as_millis(),
|
step_ms = step_start.elapsed().as_millis(),
|
||||||
|
vram_free_mb = step_vram_free_mb,
|
||||||
"TP chat_completion: decode step"
|
"TP chat_completion: decode step"
|
||||||
);
|
);
|
||||||
if Some(next_token) == eos_id {
|
if Some(next_token) == eos_id {
|
||||||
|
|||||||
@@ -562,14 +562,18 @@ impl TpQwen3ForCausalLM {
|
|||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let base = TpQwen3Model::load(cfg, vb, rank, world_size, comm)?;
|
let base = TpQwen3Model::load(cfg, vb, rank, world_size, comm)?;
|
||||||
let lm_head = build_lm_head(cfg, vb, &base)?;
|
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||||
Ok(Self { base, lm_head })
|
let model = Self { base, lm_head };
|
||||||
|
log_construction_complete(cfg, rank, world_size, model.device());
|
||||||
|
Ok(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
let base = TpQwen3Model::load(cfg, vb, rank, world_size)?;
|
let base = TpQwen3Model::load(cfg, vb, rank, world_size)?;
|
||||||
let lm_head = build_lm_head(cfg, vb, &base)?;
|
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||||
Ok(Self { base, lm_head })
|
let model = Self { base, lm_head };
|
||||||
|
log_construction_complete(cfg, rank, world_size, model.device());
|
||||||
|
Ok(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
@@ -603,3 +607,72 @@ fn build_lm_head(cfg: &Config, vb: &ShardedVarBuilder, base: &TpQwen3Model) -> R
|
|||||||
Ok(Linear::new(weight, None))
|
Ok(Linear::new(weight, None))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// VRAM accounting + config dump emitted at the end of
|
||||||
|
/// `TpQwen3ForCausalLM::load`. Same intent as the Qwen3-Next variant
|
||||||
|
/// in tp_qwen3_5.rs — surface the resolved hyperparameters and
|
||||||
|
/// per-rank free VRAM in one line so an operator chasing an OOM or a
|
||||||
|
/// numerical issue doesn't have to grep the per-layer load logs.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn log_construction_complete(cfg: &Config, rank: u32, world_size: u32, device: &Device) {
|
||||||
|
use candle_core::cuda::cudarc::driver::result;
|
||||||
|
use candle_core::cuda_backend::WrapErr;
|
||||||
|
let (free_mb, total_mb) = if let Device::Cuda(dev) = device {
|
||||||
|
if dev.cuda_stream().context().bind_to_thread().w().is_ok() {
|
||||||
|
match result::mem_get_info() {
|
||||||
|
Ok((free, total)) => (free / (1024 * 1024), total / (1024 * 1024)),
|
||||||
|
Err(_) => (0, 0),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
(0, 0)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
(0, 0)
|
||||||
|
};
|
||||||
|
// Per-rank KV cache cost at one token: K + V × bf16. Vanilla
|
||||||
|
// Qwen3 is dense attention end-to-end, so every layer
|
||||||
|
// contributes. Knowing per-token bytes lets the operator estimate
|
||||||
|
// headroom for a given prompt length before hitting an edge.
|
||||||
|
let per_rank_num_kv_heads = (cfg.num_key_value_heads / world_size as usize).max(1);
|
||||||
|
let kv_bytes_per_token_per_layer = per_rank_num_kv_heads * cfg.head_dim * 2 * 2;
|
||||||
|
let kv_bytes_per_token = kv_bytes_per_token_per_layer * cfg.num_hidden_layers;
|
||||||
|
tracing::info!(
|
||||||
|
target: "neuron::tp::load",
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
free_mb,
|
||||||
|
total_mb,
|
||||||
|
vocab_size = cfg.vocab_size,
|
||||||
|
hidden_size = cfg.hidden_size,
|
||||||
|
num_hidden_layers = cfg.num_hidden_layers,
|
||||||
|
num_attention_heads = cfg.num_attention_heads,
|
||||||
|
num_key_value_heads = cfg.num_key_value_heads,
|
||||||
|
head_dim = cfg.head_dim,
|
||||||
|
max_position_embeddings = cfg.max_position_embeddings,
|
||||||
|
per_rank_num_kv_heads,
|
||||||
|
kv_bytes_per_token,
|
||||||
|
"Qwen3 model construction complete"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn log_construction_complete(cfg: &Config, rank: u32, world_size: u32, _device: &Device) {
|
||||||
|
let per_rank_num_kv_heads = (cfg.num_key_value_heads / world_size as usize).max(1);
|
||||||
|
let kv_bytes_per_token_per_layer = per_rank_num_kv_heads * cfg.head_dim * 2 * 2;
|
||||||
|
let kv_bytes_per_token = kv_bytes_per_token_per_layer * cfg.num_hidden_layers;
|
||||||
|
tracing::info!(
|
||||||
|
target: "neuron::tp::load",
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
vocab_size = cfg.vocab_size,
|
||||||
|
hidden_size = cfg.hidden_size,
|
||||||
|
num_hidden_layers = cfg.num_hidden_layers,
|
||||||
|
num_attention_heads = cfg.num_attention_heads,
|
||||||
|
num_key_value_heads = cfg.num_key_value_heads,
|
||||||
|
head_dim = cfg.head_dim,
|
||||||
|
max_position_embeddings = cfg.max_position_embeddings,
|
||||||
|
per_rank_num_kv_heads,
|
||||||
|
kv_bytes_per_token,
|
||||||
|
"Qwen3 model construction complete"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|||||||
@@ -1012,7 +1012,9 @@ impl TpQwen3_5ForCausalLM {
|
|||||||
let cfg = &config.text_config;
|
let cfg = &config.text_config;
|
||||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?;
|
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?;
|
||||||
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
||||||
Ok(Self { base, lm_head })
|
let model = Self { base, lm_head };
|
||||||
|
log_construction_complete(cfg, rank, world_size, quant, model.device());
|
||||||
|
Ok(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
@@ -1027,7 +1029,9 @@ impl TpQwen3_5ForCausalLM {
|
|||||||
let cfg = &config.text_config;
|
let cfg = &config.text_config;
|
||||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?;
|
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?;
|
||||||
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
||||||
Ok(Self { base, lm_head })
|
let model = Self { base, lm_head };
|
||||||
|
log_construction_complete(cfg, rank, world_size, quant, model.device());
|
||||||
|
Ok(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
@@ -1129,3 +1133,75 @@ fn log_vram(device: &Device, rank: u32, tag: &str) {
|
|||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn log_vram(_device: &Device, _rank: u32, _tag: &str) {}
|
fn log_vram(_device: &Device, _rank: u32, _tag: &str) {}
|
||||||
|
|
||||||
|
/// Summary line emitted at end of `TpQwen3_5ForCausalLM::load`, after
|
||||||
|
/// the per-layer load loop AND after the lm_head + any post-construct
|
||||||
|
/// allocations. Logs the resolved config knobs (the ones an operator
|
||||||
|
/// would want to know when chasing a numerical or OOM issue) plus a
|
||||||
|
/// final free/total VRAM snapshot per rank.
|
||||||
|
///
|
||||||
|
/// The free_mb here is the most diagnostic number we have at this
|
||||||
|
/// stage: the gap between the last "after layer N" log and this line
|
||||||
|
/// is everything else the model construction allocated — lm_head,
|
||||||
|
/// embedding (if not tied), per-layer buffers held by candle's
|
||||||
|
/// allocator, the RotaryEmbedding tables, and any working space.
|
||||||
|
///
|
||||||
|
/// `kv_cache_per_layer_per_token_bytes` is a back-of-envelope estimate
|
||||||
|
/// — the actual cache grows as inference proceeds, but knowing the
|
||||||
|
/// per-token cost at this point lets an operator estimate "for a
|
||||||
|
/// 14k-token prompt I need ~X GB extra VRAM" without having to dig
|
||||||
|
/// into the architecture's attention modules.
|
||||||
|
fn log_construction_complete(
|
||||||
|
cfg: &TextConfig,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
|
device: &Device,
|
||||||
|
) {
|
||||||
|
let (free_mb, total_mb) = cuda_mem_mb(device);
|
||||||
|
// Distribution of attention kinds across layers. Qwen3-Next is
|
||||||
|
// hybrid: most layers are linear (Gated DeltaNet), a few are full
|
||||||
|
// softmax attention. Knowing the split at a glance helps when
|
||||||
|
// reasoning about KV cache size — only full-attention layers
|
||||||
|
// contribute to the standard kv cache.
|
||||||
|
let mut full_attn_layers = 0;
|
||||||
|
let mut linear_attn_layers = 0;
|
||||||
|
for kind in &cfg.layer_types {
|
||||||
|
match kind.as_str() {
|
||||||
|
"full_attention" => full_attn_layers += 1,
|
||||||
|
"linear_attention" => linear_attn_layers += 1,
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// KV cache per-layer-per-token byte estimate for the per-rank
|
||||||
|
// full-attention layers. bf16 = 2 bytes, K + V doubles it, and
|
||||||
|
// sharded across world_size. Linear-attention layers carry a
|
||||||
|
// fixed-size state instead of a growing cache.
|
||||||
|
let per_rank_num_kv_heads = (cfg.num_key_value_heads / world_size as usize).max(1);
|
||||||
|
let kv_bytes_per_token_per_layer = per_rank_num_kv_heads * cfg.head_dim * 2 /* K+V */ * 2 /* bf16 */;
|
||||||
|
let kv_bytes_per_token = kv_bytes_per_token_per_layer * full_attn_layers;
|
||||||
|
tracing::info!(
|
||||||
|
target: "neuron::tp::load",
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
quant = ?quant,
|
||||||
|
free_mb,
|
||||||
|
total_mb,
|
||||||
|
vocab_size = cfg.vocab_size,
|
||||||
|
hidden_size = cfg.hidden_size,
|
||||||
|
num_hidden_layers = cfg.num_hidden_layers,
|
||||||
|
num_attention_heads = cfg.num_attention_heads,
|
||||||
|
num_key_value_heads = cfg.num_key_value_heads,
|
||||||
|
head_dim = cfg.head_dim,
|
||||||
|
max_position_embeddings = cfg.max_position_embeddings,
|
||||||
|
full_attn_layers,
|
||||||
|
linear_attn_layers,
|
||||||
|
linear_num_value_heads = cfg.linear_num_value_heads,
|
||||||
|
linear_num_key_heads = cfg.linear_num_key_heads,
|
||||||
|
linear_key_head_dim = cfg.linear_key_head_dim,
|
||||||
|
linear_value_head_dim = cfg.linear_value_head_dim,
|
||||||
|
per_rank_num_kv_heads,
|
||||||
|
kv_bytes_per_token,
|
||||||
|
"Qwen3-Next model construction complete"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user