diff --git a/crates/neuron/src/harness/tp/tp_qwen3_5.rs b/crates/neuron/src/harness/tp/tp_qwen3_5.rs index 4968d62..f501f4e 100644 --- a/crates/neuron/src/harness/tp/tp_qwen3_5.rs +++ b/crates/neuron/src/harness/tp/tp_qwen3_5.rs @@ -35,7 +35,7 @@ use candle_core::quantized::GgmlDType; use candle_core::safetensors::MmapedSafetensors; use candle_core::{DType, Device, IndexOp, Module, Tensor}; use candle_nn::var_builder::ShardedVarBuilder; -use candle_nn::{Embedding, Linear, kv_cache::ConcatKvCache}; +use candle_nn::{Embedding, kv_cache::ConcatKvCache}; use candle_transformers::utils::repeat_kv; use std::sync::Arc; @@ -994,7 +994,7 @@ impl TpQwen3_5Model { pub struct TpQwen3_5ForCausalLM { base: TpQwen3_5Model, - lm_head: Linear, + lm_head: super::tp_linear::MaybeQuantLinear, } impl TpQwen3_5ForCausalLM { @@ -1011,7 +1011,7 @@ impl TpQwen3_5ForCausalLM { ) -> Result { let cfg = &config.text_config; let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?; - let lm_head = build_lm_head(cfg, vb, &base)?; + let lm_head = build_lm_head(cfg, vb, &base, quant)?; Ok(Self { base, lm_head }) } @@ -1026,7 +1026,7 @@ impl TpQwen3_5ForCausalLM { ) -> Result { let cfg = &config.text_config; let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?; - let lm_head = build_lm_head(cfg, vb, &base)?; + let lm_head = build_lm_head(cfg, vb, &base, quant)?; Ok(Self { base, lm_head }) } @@ -1049,9 +1049,16 @@ fn build_lm_head( cfg: &TextConfig, vb: &ShardedVarBuilder, base: &TpQwen3_5Model, -) -> Result { + quant: Option, +) -> Result { if cfg.tie_word_embeddings { - Ok(Linear::new(base.embed_weight().clone(), None)) + // Tied: lm_head shares the embedding weight. Quantizing the + // shared tensor would corrupt the embedding lookup, so keep + // the lm_head plain even when `quant` is set. The memory win + // is already taken: only one copy of the (vocab, hidden) weight + // lives in VRAM in the tied case. + super::tp_linear::MaybeQuantLinear::from_weight(base.embed_weight().clone(), None) + .context("wrap tied lm_head") } else { // lm_head sits at the top level (sibling of `model.*`), NOT // under `model.language_model`. @@ -1060,7 +1067,7 @@ fn build_lm_head( (cfg.vocab_size, cfg.hidden_size), "weight", )?; - Ok(Linear::new(weight, None)) + super::tp_linear::MaybeQuantLinear::from_weight(weight, quant).context("wrap lm_head") } }