From e71181499ed948efd7643737b6738239190c032d Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Thu, 21 May 2026 21:53:14 +0300 Subject: [PATCH] feat(stage-8e-3): quantize lm_head in TP Qwen3-Next MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TpQwen3_5ForCausalLM::lm_head is now a MaybeQuantLinear. When the load spec has quant set and tie_word_embeddings is false, lm_head's (vocab_size, hidden_size) weight is quantized in situ at load time along with all the per-layer linears. The non-tied case on Qwen3.6-27B saves ~1.7 GB per rank vs bf16 (248320 x 5120 x 2 bytes = 2.54 GB -> ~870 MB at Q5K) and shaves a small amount of decode latency from the per-token logits matmul. Tied case (tie_word_embeddings=true) keeps the lm_head plain even when quant is set — quantizing the shared tensor would corrupt the embedding lookup, and the tied case already gets the memory win from only holding one copy. This is the last MaybeQuantLinear hookup in the Qwen3-Next TP path. The dense Qwen3 path (tp_qwen3.rs) is unchanged — deferred until it is the bottleneck for a model that actually needs TP at consumer scale. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/neuron/src/harness/tp/tp_qwen3_5.rs | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/crates/neuron/src/harness/tp/tp_qwen3_5.rs b/crates/neuron/src/harness/tp/tp_qwen3_5.rs index 4968d62..f501f4e 100644 --- a/crates/neuron/src/harness/tp/tp_qwen3_5.rs +++ b/crates/neuron/src/harness/tp/tp_qwen3_5.rs @@ -35,7 +35,7 @@ use candle_core::quantized::GgmlDType; use candle_core::safetensors::MmapedSafetensors; use candle_core::{DType, Device, IndexOp, Module, Tensor}; use candle_nn::var_builder::ShardedVarBuilder; -use candle_nn::{Embedding, Linear, kv_cache::ConcatKvCache}; +use candle_nn::{Embedding, kv_cache::ConcatKvCache}; use candle_transformers::utils::repeat_kv; use std::sync::Arc; @@ -994,7 +994,7 @@ impl TpQwen3_5Model { pub struct TpQwen3_5ForCausalLM { base: TpQwen3_5Model, - lm_head: Linear, + lm_head: super::tp_linear::MaybeQuantLinear, } impl TpQwen3_5ForCausalLM { @@ -1011,7 +1011,7 @@ impl TpQwen3_5ForCausalLM { ) -> Result { let cfg = &config.text_config; let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?; - let lm_head = build_lm_head(cfg, vb, &base)?; + let lm_head = build_lm_head(cfg, vb, &base, quant)?; Ok(Self { base, lm_head }) } @@ -1026,7 +1026,7 @@ impl TpQwen3_5ForCausalLM { ) -> Result { let cfg = &config.text_config; let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?; - let lm_head = build_lm_head(cfg, vb, &base)?; + let lm_head = build_lm_head(cfg, vb, &base, quant)?; Ok(Self { base, lm_head }) } @@ -1049,9 +1049,16 @@ fn build_lm_head( cfg: &TextConfig, vb: &ShardedVarBuilder, base: &TpQwen3_5Model, -) -> Result { + quant: Option, +) -> Result { if cfg.tie_word_embeddings { - Ok(Linear::new(base.embed_weight().clone(), None)) + // Tied: lm_head shares the embedding weight. Quantizing the + // shared tensor would corrupt the embedding lookup, so keep + // the lm_head plain even when `quant` is set. 
The memory win + // is already taken: only one copy of the (vocab, hidden) weight + // lives in VRAM in the tied case. + super::tp_linear::MaybeQuantLinear::from_weight(base.embed_weight().clone(), None) + .context("wrap tied lm_head") } else { // lm_head sits at the top level (sibling of `model.*`), NOT // under `model.language_model`. @@ -1060,7 +1067,7 @@ fn build_lm_head( (cfg.vocab_size, cfg.hidden_size), "weight", )?; - Ok(Linear::new(weight, None)) + super::tp_linear::MaybeQuantLinear::from_weight(weight, quant).context("wrap lm_head") } }