From 4aa71902d081d16176b6a84576e7568ff00d624d Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Thu, 21 May 2026 18:03:36 +0300 Subject: [PATCH] feat(stage-8e-2): plumb quant config from ModelSpec to TP load path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - LoadDenseShard RPC gains an optional `quant` string field. - WorkerPool::load_dense_shard takes a `quant: Option<String>`, passes it via the RPC to workers and via parse_quant_string to the leader's local load. - The Qwen3-Next TP load chain (ForCausalLM → Model → DecoderLayer → Attention / GatedDeltaNet / MLP) takes `quant: Option<GgmlDType>` end-to-end, calling Column/RowParallelLinear::load_with_quant. - The fused in_proj_qkv inside TpQwen3_5GatedDeltaNet is now a MaybeQuantLinear so it also picks up quantization. - parse_quant_string accepts q4_0/q4_1/q5_0/q5_1/q8_0/q8_1, q2k/q3k/q4k/q5k/q6k/q8k (each with or without underscores, case-insensitive; q4_k_m and q5_k_m are accepted as aliases for q4k and q5k), and f16/bf16/f32. Empty / None means no quantization. The call site in candle.rs forwards spec.quant through pool.load_dense_shard. This means a `quant = "q5k"` in models.toml now flows end-to-end to a QTensor-backed QMatMul for every per-rank linear in the Qwen3-Next TP path. Leaves lm_head and the small replicated bias/log tensors in their loaded dtype (Stage 8e-3). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/neuron/src/harness/candle.rs | 1 + crates/neuron/src/harness/tp/mod.rs | 14 ++- crates/neuron/src/harness/tp/rpc.rs | 7 ++ crates/neuron/src/harness/tp/tp_qwen3_5.rs | 123 ++++++++++++++++----- crates/neuron/src/harness/tp/worker.rs | 58 +++++++++- 5 files changed, 175 insertions(+), 28 deletions(-) diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 24dbec6..0d57f05 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -1141,6 +1141,7 @@ impl CandleHarness { &safetensors_paths, &leader_device, candle_core::DType::BF16, + spec.quant.clone(), ) .await?; diff --git a/crates/neuron/src/harness/tp/mod.rs b/crates/neuron/src/harness/tp/mod.rs index bc0e22d..a637b9a 100644 --- a/crates/neuron/src/harness/tp/mod.rs +++ b/crates/neuron/src/harness/tp/mod.rs @@ -474,6 +474,7 @@ impl WorkerPool { /// `init_nccl` must have completed first. Bails if the leader's /// NCCL comm isn't set up yet. 
#[cfg(feature = "cuda")] + #[allow(clippy::too_many_arguments)] pub async fn load_dense_shard( &mut self, model_id: &str, @@ -481,6 +482,7 @@ impl WorkerPool { safetensors_paths: &[std::path::PathBuf], leader_device: &candle_core::Device, dtype: candle_core::DType, + quant: Option, ) -> Result>> { use candle_nn::var_builder::ShardedSafeTensors; use std::sync::Arc; @@ -510,6 +512,7 @@ impl WorkerPool { model_id: model_id.to_string(), config_json: config_json.to_string(), safetensors_paths: safetensors_str.clone(), + quant: quant.clone(), }) .await?; } @@ -531,6 +534,7 @@ impl WorkerPool { let comm_for_leader = leader_comm; let model_id_for_log = model_id.to_string(); let config_json_for_leader = config_json.to_string(); + let quant_for_leader = quant.clone(); let leader_model = tokio::task::spawn_blocking(move || -> Result { // SAFETY: same invariant as the single-GPU dense path — @@ -558,8 +562,16 @@ impl WorkerPool { let cfg: super::tp::tp_qwen3_5::Config = serde_json::from_str(&config_json_for_leader) .context("parse Qwen3-Next Config JSON for leader load")?; + let quant_dtype = + super::tp::worker::parse_quant_string(quant_for_leader.as_deref())?; TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load( - cfg, &vb, &mmap, 0, world_size, comm, + cfg, + &vb, + &mmap, + 0, + world_size, + comm, + quant_dtype, )?) } other => anyhow::bail!( diff --git a/crates/neuron/src/harness/tp/rpc.rs b/crates/neuron/src/harness/tp/rpc.rs index 2a444a6..5c0c540 100644 --- a/crates/neuron/src/harness/tp/rpc.rs +++ b/crates/neuron/src/harness/tp/rpc.rs @@ -63,6 +63,13 @@ pub enum WorkerRequest { /// Absolute paths the worker should mmap. The same set on every /// rank; ShardedVarBuilder slices into them per rank. safetensors_paths: Vec, + /// Optional in-situ quantization dtype (e.g. "q5k", "q8_0", + /// "q6k"). When set, each linear-layer weight is quantized + /// at load time to the named ggml format — saves ~3-5x vs + /// bf16/f16 at the cost of some accuracy. 
`None` keeps the + /// weights in the on-disk dtype (typically bf16). + #[serde(default)] + quant: Option, }, /// Run one forward step on this rank's loaded model. The worker diff --git a/crates/neuron/src/harness/tp/tp_qwen3_5.rs b/crates/neuron/src/harness/tp/tp_qwen3_5.rs index a023f70..4968d62 100644 --- a/crates/neuron/src/harness/tp/tp_qwen3_5.rs +++ b/crates/neuron/src/harness/tp/tp_qwen3_5.rs @@ -31,6 +31,7 @@ //! linear-attention block, lm_head, the rotary table. use anyhow::{Context, Result, bail}; +use candle_core::quantized::GgmlDType; use candle_core::safetensors::MmapedSafetensors; use candle_core::{DType, Device, IndexOp, Module, Tensor}; use candle_nn::var_builder::ShardedVarBuilder; @@ -59,7 +60,7 @@ pub struct TpGatedDeltaNetState { } pub(crate) struct TpQwen3_5GatedDeltaNet { - in_proj_qkv: Linear, + in_proj_qkv: super::tp_linear::MaybeQuantLinear, in_proj_z: ColumnParallelLinear, in_proj_b: ColumnParallelLinear, in_proj_a: ColumnParallelLinear, @@ -92,6 +93,7 @@ pub(crate) struct TpQwen3_5GatedDeltaNet { impl TpQwen3_5GatedDeltaNet { #[cfg(feature = "cuda")] + #[allow(clippy::too_many_arguments)] pub fn load( cfg: &TextConfig, vb: &ShardedVarBuilder, @@ -99,8 +101,9 @@ impl TpQwen3_5GatedDeltaNet { rank: u32, world_size: u32, comm: Arc, + quant: Option, ) -> Result { - Self::load_inner(cfg, vb, mmap, rank, world_size, comm) + Self::load_inner(cfg, vb, mmap, rank, world_size, comm, quant) } #[cfg(not(feature = "cuda"))] @@ -110,10 +113,12 @@ impl TpQwen3_5GatedDeltaNet { mmap: &MmapedSafetensors, rank: u32, world_size: u32, + quant: Option, ) -> Result { - Self::load_inner(cfg, vb, mmap, rank, world_size) + Self::load_inner(cfg, vb, mmap, rank, world_size, quant) } + #[allow(clippy::too_many_arguments)] fn load_inner( cfg: &TextConfig, vb: &ShardedVarBuilder, @@ -121,6 +126,7 @@ impl TpQwen3_5GatedDeltaNet { rank: u32, world_size: u32, #[cfg(feature = "cuda")] comm: Arc, + quant: Option, ) -> Result { let ws = world_size as usize; let num_v_heads 
= cfg.linear_num_value_heads; @@ -177,7 +183,9 @@ impl TpQwen3_5GatedDeltaNet { dtype, &device, )?; - let in_proj_qkv = Linear::new(in_proj_qkv_weight, None); + let in_proj_qkv = + super::tp_linear::MaybeQuantLinear::from_weight(in_proj_qkv_weight, quant) + .with_context(|| format!("wrap fused in_proj_qkv for '{}'", vb.prefix()))?; let conv1d_name = format!("{}.conv1d.weight", vb.prefix()); let conv1d_weight = super::fused_load::load_fused_qkv_3d( @@ -195,10 +203,13 @@ impl TpQwen3_5GatedDeltaNet { // ----- Uniformly-sharded projections (along output dim 0). ----- // in_proj_z: hidden → value_dim, sharded along value_dim (V-head). - let in_proj_z = ColumnParallelLinear::load(&vb.pp("in_proj_z"), rank, world_size)?; + let in_proj_z = + ColumnParallelLinear::load_with_quant(&vb.pp("in_proj_z"), rank, world_size, quant)?; // in_proj_b, in_proj_a: hidden → num_v_heads, sharded along output. - let in_proj_b = ColumnParallelLinear::load(&vb.pp("in_proj_b"), rank, world_size)?; - let in_proj_a = ColumnParallelLinear::load(&vb.pp("in_proj_a"), rank, world_size)?; + let in_proj_b = + ColumnParallelLinear::load_with_quant(&vb.pp("in_proj_b"), rank, world_size, quant)?; + let in_proj_a = + ColumnParallelLinear::load_with_quant(&vb.pp("in_proj_a"), rank, world_size, quant)?; // ----- Per-V-head 1D params (sharded uniformly). ----- let a_log = vb @@ -213,9 +224,11 @@ impl TpQwen3_5GatedDeltaNet { // ----- Output projection: row-parallel + AllReduce. 
----- #[cfg(feature = "cuda")] - let out_proj = RowParallelLinear::load(&vb.pp("out_proj"), rank, world_size, comm)?; + let out_proj = + RowParallelLinear::load_with_quant(&vb.pp("out_proj"), rank, world_size, comm, quant)?; #[cfg(not(feature = "cuda"))] - let out_proj = RowParallelLinear::load(&vb.pp("out_proj"), rank, world_size)?; + let out_proj = + RowParallelLinear::load_with_quant(&vb.pp("out_proj"), rank, world_size, quant)?; Ok(Self { in_proj_qkv, @@ -418,6 +431,7 @@ pub(crate) struct TpQwen3_5Attention { impl TpQwen3_5Attention { #[cfg(feature = "cuda")] + #[allow(clippy::too_many_arguments)] pub fn load( cfg: &TextConfig, rotary: Arc, @@ -425,8 +439,9 @@ impl TpQwen3_5Attention { rank: u32, world_size: u32, comm: Arc, + quant: Option, ) -> Result { - Self::load_inner(cfg, rotary, vb, rank, world_size, comm) + Self::load_inner(cfg, rotary, vb, rank, world_size, comm, quant) } #[cfg(not(feature = "cuda"))] @@ -436,10 +451,12 @@ impl TpQwen3_5Attention { vb: &ShardedVarBuilder, rank: u32, world_size: u32, + quant: Option, ) -> Result { - Self::load_inner(cfg, rotary, vb, rank, world_size) + Self::load_inner(cfg, rotary, vb, rank, world_size, quant) } + #[allow(clippy::too_many_arguments)] fn load_inner( cfg: &TextConfig, rotary: Arc, @@ -447,6 +464,7 @@ impl TpQwen3_5Attention { rank: u32, world_size: u32, #[cfg(feature = "cuda")] comm: Arc, + quant: Option, ) -> Result { let ws = world_size as usize; let num_heads = cfg.num_attention_heads; @@ -468,13 +486,17 @@ impl TpQwen3_5Attention { // consistently — rank R holds heads `[R*per_rank, (R+1)*per_rank)` // for both query AND gate, so the post-attention `gate.sigmoid()` // multiply against the per-rank attention output matches up. 
- let q_proj = ColumnParallelLinear::load(&vb.pp("q_proj"), rank, world_size)?; - let k_proj = ColumnParallelLinear::load(&vb.pp("k_proj"), rank, world_size)?; - let v_proj = ColumnParallelLinear::load(&vb.pp("v_proj"), rank, world_size)?; + let q_proj = + ColumnParallelLinear::load_with_quant(&vb.pp("q_proj"), rank, world_size, quant)?; + let k_proj = + ColumnParallelLinear::load_with_quant(&vb.pp("k_proj"), rank, world_size, quant)?; + let v_proj = + ColumnParallelLinear::load_with_quant(&vb.pp("v_proj"), rank, world_size, quant)?; #[cfg(feature = "cuda")] - let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size, comm)?; + let o_proj = + RowParallelLinear::load_with_quant(&vb.pp("o_proj"), rank, world_size, comm, quant)?; #[cfg(not(feature = "cuda"))] - let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size)?; + let o_proj = RowParallelLinear::load_with_quant(&vb.pp("o_proj"), rank, world_size, quant)?; let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?; let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?; @@ -572,12 +594,14 @@ pub(crate) struct TpQwen3_5MLP { impl TpQwen3_5MLP { #[cfg(feature = "cuda")] + #[allow(clippy::too_many_arguments)] pub fn load( cfg: &TextConfig, vb: &ShardedVarBuilder, rank: u32, world_size: u32, comm: Arc, + quant: Option, ) -> Result { if !cfg.intermediate_size.is_multiple_of(world_size as usize) { bail!( @@ -587,9 +611,25 @@ impl TpQwen3_5MLP { ); } Ok(Self { - gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?, - up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?, - down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size, comm)?, + gate_proj: ColumnParallelLinear::load_with_quant( + &vb.pp("gate_proj"), + rank, + world_size, + quant, + )?, + up_proj: ColumnParallelLinear::load_with_quant( + &vb.pp("up_proj"), + rank, + world_size, + quant, + )?, + down_proj: 
RowParallelLinear::load_with_quant( + &vb.pp("down_proj"), + rank, + world_size, + comm, + quant, + )?, }) } @@ -599,6 +639,7 @@ impl TpQwen3_5MLP { vb: &ShardedVarBuilder, rank: u32, world_size: u32, + quant: Option, ) -> Result { if !cfg.intermediate_size.is_multiple_of(world_size as usize) { bail!( @@ -608,9 +649,24 @@ impl TpQwen3_5MLP { ); } Ok(Self { - gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?, - up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?, - down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size)?, + gate_proj: ColumnParallelLinear::load_with_quant( + &vb.pp("gate_proj"), + rank, + world_size, + quant, + )?, + up_proj: ColumnParallelLinear::load_with_quant( + &vb.pp("up_proj"), + rank, + world_size, + quant, + )?, + down_proj: RowParallelLinear::load_with_quant( + &vb.pp("down_proj"), + rank, + world_size, + quant, + )?, }) } } @@ -649,6 +705,7 @@ impl TpQwen3_5DecoderLayer { rank: u32, world_size: u32, comm: Arc, + quant: Option, ) -> Result { let layer_type = cfg .layer_types @@ -663,6 +720,7 @@ impl TpQwen3_5DecoderLayer { rank, world_size, comm.clone(), + quant, )?), "linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load( cfg, @@ -671,10 +729,11 @@ impl TpQwen3_5DecoderLayer { rank, world_size, comm.clone(), + quant, )?), other => bail!("unknown layer_type '{other}' for layer {layer_idx}"), }; - let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm)?; + let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm, quant)?; let input_layernorm = Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?; let post_attention_layernorm = Qwen3_5RmsNorm::load( @@ -691,6 +750,7 @@ impl TpQwen3_5DecoderLayer { } #[cfg(not(feature = "cuda"))] + #[allow(clippy::too_many_arguments)] pub fn load( cfg: &TextConfig, rotary: Arc, @@ -699,6 +759,7 @@ impl TpQwen3_5DecoderLayer { mmap: &MmapedSafetensors, rank: 
u32, world_size: u32, + quant: Option, ) -> Result { let layer_type = cfg .layer_types @@ -712,6 +773,7 @@ impl TpQwen3_5DecoderLayer { &vb.pp("self_attn"), rank, world_size, + quant, )?), "linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load( cfg, @@ -719,10 +781,11 @@ impl TpQwen3_5DecoderLayer { mmap, rank, world_size, + quant, )?), other => bail!("unknown layer_type '{other}' for layer {layer_idx}"), }; - let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size)?; + let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size, quant)?; let input_layernorm = Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?; let post_attention_layernorm = Qwen3_5RmsNorm::load( @@ -775,6 +838,7 @@ pub struct TpQwen3_5Model { impl TpQwen3_5Model { #[cfg(feature = "cuda")] + #[allow(clippy::too_many_arguments)] pub fn load( cfg: &TextConfig, vb: &ShardedVarBuilder, @@ -782,6 +846,7 @@ impl TpQwen3_5Model { rank: u32, world_size: u32, comm: Arc, + quant: Option, ) -> Result { let dtype = vb.dtype(); let device = vb.device().clone(); @@ -817,6 +882,7 @@ impl TpQwen3_5Model { rank, world_size, comm.clone(), + quant, ) .with_context(|| { let (free_mb, total_mb) = cuda_mem_mb(&device); @@ -844,6 +910,7 @@ impl TpQwen3_5Model { mmap: &MmapedSafetensors, rank: u32, world_size: u32, + quant: Option, ) -> Result { let dtype = vb.dtype(); let device = vb.device().clone(); @@ -877,6 +944,7 @@ impl TpQwen3_5Model { mmap, rank, world_size, + quant, )?); } @@ -931,6 +999,7 @@ pub struct TpQwen3_5ForCausalLM { impl TpQwen3_5ForCausalLM { #[cfg(feature = "cuda")] + #[allow(clippy::too_many_arguments)] pub fn load( config: Config, vb: &ShardedVarBuilder, @@ -938,9 +1007,10 @@ impl TpQwen3_5ForCausalLM { rank: u32, world_size: u32, comm: Arc, + quant: Option, ) -> Result { let cfg = &config.text_config; - let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm)?; + let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, 
world_size, comm, quant)?; let lm_head = build_lm_head(cfg, vb, &base)?; Ok(Self { base, lm_head }) } @@ -952,9 +1022,10 @@ impl TpQwen3_5ForCausalLM { mmap: &MmapedSafetensors, rank: u32, world_size: u32, + quant: Option, ) -> Result { let cfg = &config.text_config; - let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size)?; + let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?; let lm_head = build_lm_head(cfg, vb, &base)?; Ok(Self { base, lm_head }) } diff --git a/crates/neuron/src/harness/tp/worker.rs b/crates/neuron/src/harness/tp/worker.rs index 98b70a5..6ec9fd8 100644 --- a/crates/neuron/src/harness/tp/worker.rs +++ b/crates/neuron/src/harness/tp/worker.rs @@ -160,7 +160,8 @@ impl WorkerState { model_id, config_json, safetensors_paths, - } => self.handle_load_dense_shard(model_id, config_json, safetensors_paths), + quant, + } => self.handle_load_dense_shard(model_id, config_json, safetensors_paths, quant), WorkerRequest::GenerateStep { model_id, tokens, @@ -178,6 +179,7 @@ impl WorkerState { model_id: String, config_json: String, safetensors_paths: Vec, + quant: Option, ) -> WorkerResponse { use crate::harness::arch::qwen3_5 as qwen3_5_arch; use candle_core::{DType, Device}; @@ -185,6 +187,16 @@ impl WorkerState { use candle_transformers::models::qwen3 as qwen3_dense; use std::path::PathBuf; + let quant_dtype = match parse_quant_string(quant.as_deref()) { + Ok(q) => q, + Err(e) => { + return WorkerResponse::Error { + kind: "bad_request".into(), + message: format!("parse quant: {e}"), + }; + } + }; + if self.models.contains_key(&model_id) { return WorkerResponse::Error { kind: "already_loaded".into(), @@ -290,6 +302,7 @@ impl WorkerState { self.config.rank, self.config.world_size, comm, + quant_dtype, ) { Ok(m) => WorkerModel::Qwen3_5(m), Err(e) => { @@ -326,6 +339,7 @@ impl WorkerState { _model_id: String, _config_json: String, _safetensors_paths: Vec, + _quant: Option, ) -> WorkerResponse { WorkerResponse::Error { kind: 
"cuda_feature_not_enabled".into(), @@ -444,3 +458,45 @@ impl WorkerState { } } } + +/// Parse a `ModelSpec.quant` string into a `GgmlDType`. Accepts the +/// common ggml format names (case-insensitive). `None` and `Some("")` +/// both map to "no quantization". +/// +/// Supported: `q4_0`, `q4_1`, `q5_0`, `q5_1`, `q8_0`, `q8_1`, +/// `q2k`/`q2_k`, `q3k`/`q3_k`, `q4k`/`q4_k`, `q5k`/`q5_k`, +/// `q6k`/`q6_k`, `q8k`/`q8_k`, `f16`, `bf16`, `f32`. The underscore +/// is optional and the prefix is case-insensitive. +#[cfg(feature = "cuda")] +pub(crate) fn parse_quant_string( + s: Option<&str>, +) -> anyhow::Result> { + use candle_core::quantized::GgmlDType; + let s = match s { + Some(s) if !s.is_empty() => s, + _ => return Ok(None), + }; + let normalised = s.to_ascii_lowercase().replace('_', ""); + let dtype = match normalised.as_str() { + "q40" => GgmlDType::Q4_0, + "q41" => GgmlDType::Q4_1, + "q50" => GgmlDType::Q5_0, + "q51" => GgmlDType::Q5_1, + "q80" => GgmlDType::Q8_0, + "q81" => GgmlDType::Q8_1, + "q2k" => GgmlDType::Q2K, + "q3k" => GgmlDType::Q3K, + "q4k" | "q4km" => GgmlDType::Q4K, + "q5k" | "q5km" => GgmlDType::Q5K, + "q6k" => GgmlDType::Q6K, + "q8k" => GgmlDType::Q8K, + "f16" => GgmlDType::F16, + "bf16" => GgmlDType::BF16, + "f32" => GgmlDType::F32, + other => anyhow::bail!( + "unknown quant '{other}' (expected one of: q4_0, q4_1, q5_0, q5_1, q8_0, \ + q8_1, q2k, q3k, q4k, q5k, q6k, q8k, f16, bf16, f32)" + ), + }; + Ok(Some(dtype)) +}