feat(stage-8e-2): plumb quant config from ModelSpec to TP load path
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 36s
CI / Clippy (push) Successful in 2m7s
CI / Test (push) Successful in 4m21s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m47s
build-prerelease / Build neuron-ampere (push) Successful in 5m17s
build-prerelease / Build neuron-ada (push) Successful in 5m14s
build-prerelease / Build cortex binary (push) Successful in 18m31s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m57s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m7s

- LoadDenseShard RPC gains an optional `quant` string field.
- WorkerPool::load_dense_shard takes a `quant: Option<String>`,
  passes it via the RPC to workers and via parse_quant_string to
  the leader's local load.
- The Qwen3-Next TP load chain (ForCausalLM → Model → DecoderLayer
  → Attention / GatedDeltaNet / MLP) takes `quant: Option<GgmlDType>`
  end-to-end, calling Column/RowParallelLinear::load_with_quant.
- The fused in_proj_qkv inside TpQwen3_5GatedDeltaNet is now a
  MaybeQuantLinear so it also picks up quantization.
- parse_quant_string accepts q4_0/q4_1/q5_0/q5_1/q8_0/q8_1,
  q2k/q3k/q4k/q5k/q6k/q8k (with or without underscores; q4_k_m and
  q5_k_m map to q4k and q5k), and f16/bf16/f32. Matching is
  case-insensitive. Empty / None means no quantization.

Callers from candle.rs forward spec.quant through pool.load_dense_shard.
This means a `quant = "q5k"` in models.toml now flows end-to-end to a
QTensor-backed QMatMul for every per-rank linear in the Qwen3-Next
TP path. This change leaves lm_head and the small replicated bias/log
tensors in their loaded dtype (Stage 8e-3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 18:03:36 +03:00
parent bef159b21c
commit 4aa71902d0
5 changed files with 175 additions and 28 deletions

View File

@@ -1141,6 +1141,7 @@ impl CandleHarness {
&safetensors_paths, &safetensors_paths,
&leader_device, &leader_device,
candle_core::DType::BF16, candle_core::DType::BF16,
spec.quant.clone(),
) )
.await?; .await?;

View File

@@ -474,6 +474,7 @@ impl WorkerPool {
/// `init_nccl` must have completed first. Bails if the leader's /// `init_nccl` must have completed first. Bails if the leader's
/// NCCL comm isn't set up yet. /// NCCL comm isn't set up yet.
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub async fn load_dense_shard( pub async fn load_dense_shard(
&mut self, &mut self,
model_id: &str, model_id: &str,
@@ -481,6 +482,7 @@ impl WorkerPool {
safetensors_paths: &[std::path::PathBuf], safetensors_paths: &[std::path::PathBuf],
leader_device: &candle_core::Device, leader_device: &candle_core::Device,
dtype: candle_core::DType, dtype: candle_core::DType,
quant: Option<String>,
) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> { ) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
use candle_nn::var_builder::ShardedSafeTensors; use candle_nn::var_builder::ShardedSafeTensors;
use std::sync::Arc; use std::sync::Arc;
@@ -510,6 +512,7 @@ impl WorkerPool {
model_id: model_id.to_string(), model_id: model_id.to_string(),
config_json: config_json.to_string(), config_json: config_json.to_string(),
safetensors_paths: safetensors_str.clone(), safetensors_paths: safetensors_str.clone(),
quant: quant.clone(),
}) })
.await?; .await?;
} }
@@ -531,6 +534,7 @@ impl WorkerPool {
let comm_for_leader = leader_comm; let comm_for_leader = leader_comm;
let model_id_for_log = model_id.to_string(); let model_id_for_log = model_id.to_string();
let config_json_for_leader = config_json.to_string(); let config_json_for_leader = config_json.to_string();
let quant_for_leader = quant.clone();
let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> { let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
// SAFETY: same invariant as the single-GPU dense path — // SAFETY: same invariant as the single-GPU dense path —
@@ -558,8 +562,16 @@ impl WorkerPool {
let cfg: super::tp::tp_qwen3_5::Config = let cfg: super::tp::tp_qwen3_5::Config =
serde_json::from_str(&config_json_for_leader) serde_json::from_str(&config_json_for_leader)
.context("parse Qwen3-Next Config JSON for leader load")?; .context("parse Qwen3-Next Config JSON for leader load")?;
let quant_dtype =
super::tp::worker::parse_quant_string(quant_for_leader.as_deref())?;
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load( TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
cfg, &vb, &mmap, 0, world_size, comm, cfg,
&vb,
&mmap,
0,
world_size,
comm,
quant_dtype,
)?) )?)
} }
other => anyhow::bail!( other => anyhow::bail!(

View File

@@ -63,6 +63,13 @@ pub enum WorkerRequest {
/// Absolute paths the worker should mmap. The same set on every /// Absolute paths the worker should mmap. The same set on every
/// rank; ShardedVarBuilder slices into them per rank. /// rank; ShardedVarBuilder slices into them per rank.
safetensors_paths: Vec<String>, safetensors_paths: Vec<String>,
/// Optional in-situ quantization dtype (e.g. "q5k", "q8_0",
/// "q6k"). When set, each linear-layer weight is quantized
/// at load time to the named ggml format — saves ~3-5x vs
/// bf16/f16 at the cost of some accuracy. `None` keeps the
/// weights in the on-disk dtype (typically bf16).
#[serde(default)]
quant: Option<String>,
}, },
/// Run one forward step on this rank's loaded model. The worker /// Run one forward step on this rank's loaded model. The worker

View File

@@ -31,6 +31,7 @@
//! linear-attention block, lm_head, the rotary table. //! linear-attention block, lm_head, the rotary table.
use anyhow::{Context, Result, bail}; use anyhow::{Context, Result, bail};
use candle_core::quantized::GgmlDType;
use candle_core::safetensors::MmapedSafetensors; use candle_core::safetensors::MmapedSafetensors;
use candle_core::{DType, Device, IndexOp, Module, Tensor}; use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder; use candle_nn::var_builder::ShardedVarBuilder;
@@ -59,7 +60,7 @@ pub struct TpGatedDeltaNetState {
} }
pub(crate) struct TpQwen3_5GatedDeltaNet { pub(crate) struct TpQwen3_5GatedDeltaNet {
in_proj_qkv: Linear, in_proj_qkv: super::tp_linear::MaybeQuantLinear,
in_proj_z: ColumnParallelLinear, in_proj_z: ColumnParallelLinear,
in_proj_b: ColumnParallelLinear, in_proj_b: ColumnParallelLinear,
in_proj_a: ColumnParallelLinear, in_proj_a: ColumnParallelLinear,
@@ -92,6 +93,7 @@ pub(crate) struct TpQwen3_5GatedDeltaNet {
impl TpQwen3_5GatedDeltaNet { impl TpQwen3_5GatedDeltaNet {
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn load( pub fn load(
cfg: &TextConfig, cfg: &TextConfig,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
@@ -99,8 +101,9 @@ impl TpQwen3_5GatedDeltaNet {
rank: u32, rank: u32,
world_size: u32, world_size: u32,
comm: Arc<Comm>, comm: Arc<Comm>,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
Self::load_inner(cfg, vb, mmap, rank, world_size, comm) Self::load_inner(cfg, vb, mmap, rank, world_size, comm, quant)
} }
#[cfg(not(feature = "cuda"))] #[cfg(not(feature = "cuda"))]
@@ -110,10 +113,12 @@ impl TpQwen3_5GatedDeltaNet {
mmap: &MmapedSafetensors, mmap: &MmapedSafetensors,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
Self::load_inner(cfg, vb, mmap, rank, world_size) Self::load_inner(cfg, vb, mmap, rank, world_size, quant)
} }
#[allow(clippy::too_many_arguments)]
fn load_inner( fn load_inner(
cfg: &TextConfig, cfg: &TextConfig,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
@@ -121,6 +126,7 @@ impl TpQwen3_5GatedDeltaNet {
rank: u32, rank: u32,
world_size: u32, world_size: u32,
#[cfg(feature = "cuda")] comm: Arc<Comm>, #[cfg(feature = "cuda")] comm: Arc<Comm>,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
let ws = world_size as usize; let ws = world_size as usize;
let num_v_heads = cfg.linear_num_value_heads; let num_v_heads = cfg.linear_num_value_heads;
@@ -177,7 +183,9 @@ impl TpQwen3_5GatedDeltaNet {
dtype, dtype,
&device, &device,
)?; )?;
let in_proj_qkv = Linear::new(in_proj_qkv_weight, None); let in_proj_qkv =
super::tp_linear::MaybeQuantLinear::from_weight(in_proj_qkv_weight, quant)
.with_context(|| format!("wrap fused in_proj_qkv for '{}'", vb.prefix()))?;
let conv1d_name = format!("{}.conv1d.weight", vb.prefix()); let conv1d_name = format!("{}.conv1d.weight", vb.prefix());
let conv1d_weight = super::fused_load::load_fused_qkv_3d( let conv1d_weight = super::fused_load::load_fused_qkv_3d(
@@ -195,10 +203,13 @@ impl TpQwen3_5GatedDeltaNet {
// ----- Uniformly-sharded projections (along output dim 0). ----- // ----- Uniformly-sharded projections (along output dim 0). -----
// in_proj_z: hidden → value_dim, sharded along value_dim (V-head). // in_proj_z: hidden → value_dim, sharded along value_dim (V-head).
let in_proj_z = ColumnParallelLinear::load(&vb.pp("in_proj_z"), rank, world_size)?; let in_proj_z =
ColumnParallelLinear::load_with_quant(&vb.pp("in_proj_z"), rank, world_size, quant)?;
// in_proj_b, in_proj_a: hidden → num_v_heads, sharded along output. // in_proj_b, in_proj_a: hidden → num_v_heads, sharded along output.
let in_proj_b = ColumnParallelLinear::load(&vb.pp("in_proj_b"), rank, world_size)?; let in_proj_b =
let in_proj_a = ColumnParallelLinear::load(&vb.pp("in_proj_a"), rank, world_size)?; ColumnParallelLinear::load_with_quant(&vb.pp("in_proj_b"), rank, world_size, quant)?;
let in_proj_a =
ColumnParallelLinear::load_with_quant(&vb.pp("in_proj_a"), rank, world_size, quant)?;
// ----- Per-V-head 1D params (sharded uniformly). ----- // ----- Per-V-head 1D params (sharded uniformly). -----
let a_log = vb let a_log = vb
@@ -213,9 +224,11 @@ impl TpQwen3_5GatedDeltaNet {
// ----- Output projection: row-parallel + AllReduce. ----- // ----- Output projection: row-parallel + AllReduce. -----
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
let out_proj = RowParallelLinear::load(&vb.pp("out_proj"), rank, world_size, comm)?; let out_proj =
RowParallelLinear::load_with_quant(&vb.pp("out_proj"), rank, world_size, comm, quant)?;
#[cfg(not(feature = "cuda"))] #[cfg(not(feature = "cuda"))]
let out_proj = RowParallelLinear::load(&vb.pp("out_proj"), rank, world_size)?; let out_proj =
RowParallelLinear::load_with_quant(&vb.pp("out_proj"), rank, world_size, quant)?;
Ok(Self { Ok(Self {
in_proj_qkv, in_proj_qkv,
@@ -418,6 +431,7 @@ pub(crate) struct TpQwen3_5Attention {
impl TpQwen3_5Attention { impl TpQwen3_5Attention {
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn load( pub fn load(
cfg: &TextConfig, cfg: &TextConfig,
rotary: Arc<RotaryEmbedding>, rotary: Arc<RotaryEmbedding>,
@@ -425,8 +439,9 @@ impl TpQwen3_5Attention {
rank: u32, rank: u32,
world_size: u32, world_size: u32,
comm: Arc<Comm>, comm: Arc<Comm>,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
Self::load_inner(cfg, rotary, vb, rank, world_size, comm) Self::load_inner(cfg, rotary, vb, rank, world_size, comm, quant)
} }
#[cfg(not(feature = "cuda"))] #[cfg(not(feature = "cuda"))]
@@ -436,10 +451,12 @@ impl TpQwen3_5Attention {
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
Self::load_inner(cfg, rotary, vb, rank, world_size) Self::load_inner(cfg, rotary, vb, rank, world_size, quant)
} }
#[allow(clippy::too_many_arguments)]
fn load_inner( fn load_inner(
cfg: &TextConfig, cfg: &TextConfig,
rotary: Arc<RotaryEmbedding>, rotary: Arc<RotaryEmbedding>,
@@ -447,6 +464,7 @@ impl TpQwen3_5Attention {
rank: u32, rank: u32,
world_size: u32, world_size: u32,
#[cfg(feature = "cuda")] comm: Arc<Comm>, #[cfg(feature = "cuda")] comm: Arc<Comm>,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
let ws = world_size as usize; let ws = world_size as usize;
let num_heads = cfg.num_attention_heads; let num_heads = cfg.num_attention_heads;
@@ -468,13 +486,17 @@ impl TpQwen3_5Attention {
// consistently — rank R holds heads `[R*per_rank, (R+1)*per_rank)` // consistently — rank R holds heads `[R*per_rank, (R+1)*per_rank)`
// for both query AND gate, so the post-attention `gate.sigmoid()` // for both query AND gate, so the post-attention `gate.sigmoid()`
// multiply against the per-rank attention output matches up. // multiply against the per-rank attention output matches up.
let q_proj = ColumnParallelLinear::load(&vb.pp("q_proj"), rank, world_size)?; let q_proj =
let k_proj = ColumnParallelLinear::load(&vb.pp("k_proj"), rank, world_size)?; ColumnParallelLinear::load_with_quant(&vb.pp("q_proj"), rank, world_size, quant)?;
let v_proj = ColumnParallelLinear::load(&vb.pp("v_proj"), rank, world_size)?; let k_proj =
ColumnParallelLinear::load_with_quant(&vb.pp("k_proj"), rank, world_size, quant)?;
let v_proj =
ColumnParallelLinear::load_with_quant(&vb.pp("v_proj"), rank, world_size, quant)?;
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size, comm)?; let o_proj =
RowParallelLinear::load_with_quant(&vb.pp("o_proj"), rank, world_size, comm, quant)?;
#[cfg(not(feature = "cuda"))] #[cfg(not(feature = "cuda"))]
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size)?; let o_proj = RowParallelLinear::load_with_quant(&vb.pp("o_proj"), rank, world_size, quant)?;
let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?; let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?; let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
@@ -572,12 +594,14 @@ pub(crate) struct TpQwen3_5MLP {
impl TpQwen3_5MLP { impl TpQwen3_5MLP {
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn load( pub fn load(
cfg: &TextConfig, cfg: &TextConfig,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
comm: Arc<Comm>, comm: Arc<Comm>,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
if !cfg.intermediate_size.is_multiple_of(world_size as usize) { if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
bail!( bail!(
@@ -587,9 +611,25 @@ impl TpQwen3_5MLP {
); );
} }
Ok(Self { Ok(Self {
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?, gate_proj: ColumnParallelLinear::load_with_quant(
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?, &vb.pp("gate_proj"),
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size, comm)?, rank,
world_size,
quant,
)?,
up_proj: ColumnParallelLinear::load_with_quant(
&vb.pp("up_proj"),
rank,
world_size,
quant,
)?,
down_proj: RowParallelLinear::load_with_quant(
&vb.pp("down_proj"),
rank,
world_size,
comm,
quant,
)?,
}) })
} }
@@ -599,6 +639,7 @@ impl TpQwen3_5MLP {
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
if !cfg.intermediate_size.is_multiple_of(world_size as usize) { if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
bail!( bail!(
@@ -608,9 +649,24 @@ impl TpQwen3_5MLP {
); );
} }
Ok(Self { Ok(Self {
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?, gate_proj: ColumnParallelLinear::load_with_quant(
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?, &vb.pp("gate_proj"),
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size)?, rank,
world_size,
quant,
)?,
up_proj: ColumnParallelLinear::load_with_quant(
&vb.pp("up_proj"),
rank,
world_size,
quant,
)?,
down_proj: RowParallelLinear::load_with_quant(
&vb.pp("down_proj"),
rank,
world_size,
quant,
)?,
}) })
} }
} }
@@ -649,6 +705,7 @@ impl TpQwen3_5DecoderLayer {
rank: u32, rank: u32,
world_size: u32, world_size: u32,
comm: Arc<Comm>, comm: Arc<Comm>,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
let layer_type = cfg let layer_type = cfg
.layer_types .layer_types
@@ -663,6 +720,7 @@ impl TpQwen3_5DecoderLayer {
rank, rank,
world_size, world_size,
comm.clone(), comm.clone(),
quant,
)?), )?),
"linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load( "linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load(
cfg, cfg,
@@ -671,10 +729,11 @@ impl TpQwen3_5DecoderLayer {
rank, rank,
world_size, world_size,
comm.clone(), comm.clone(),
quant,
)?), )?),
other => bail!("unknown layer_type '{other}' for layer {layer_idx}"), other => bail!("unknown layer_type '{other}' for layer {layer_idx}"),
}; };
let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm)?; let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm, quant)?;
let input_layernorm = let input_layernorm =
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?; Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
let post_attention_layernorm = Qwen3_5RmsNorm::load( let post_attention_layernorm = Qwen3_5RmsNorm::load(
@@ -691,6 +750,7 @@ impl TpQwen3_5DecoderLayer {
} }
#[cfg(not(feature = "cuda"))] #[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn load( pub fn load(
cfg: &TextConfig, cfg: &TextConfig,
rotary: Arc<RotaryEmbedding>, rotary: Arc<RotaryEmbedding>,
@@ -699,6 +759,7 @@ impl TpQwen3_5DecoderLayer {
mmap: &MmapedSafetensors, mmap: &MmapedSafetensors,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
let layer_type = cfg let layer_type = cfg
.layer_types .layer_types
@@ -712,6 +773,7 @@ impl TpQwen3_5DecoderLayer {
&vb.pp("self_attn"), &vb.pp("self_attn"),
rank, rank,
world_size, world_size,
quant,
)?), )?),
"linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load( "linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load(
cfg, cfg,
@@ -719,10 +781,11 @@ impl TpQwen3_5DecoderLayer {
mmap, mmap,
rank, rank,
world_size, world_size,
quant,
)?), )?),
other => bail!("unknown layer_type '{other}' for layer {layer_idx}"), other => bail!("unknown layer_type '{other}' for layer {layer_idx}"),
}; };
let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size)?; let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size, quant)?;
let input_layernorm = let input_layernorm =
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?; Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
let post_attention_layernorm = Qwen3_5RmsNorm::load( let post_attention_layernorm = Qwen3_5RmsNorm::load(
@@ -775,6 +838,7 @@ pub struct TpQwen3_5Model {
impl TpQwen3_5Model { impl TpQwen3_5Model {
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn load( pub fn load(
cfg: &TextConfig, cfg: &TextConfig,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
@@ -782,6 +846,7 @@ impl TpQwen3_5Model {
rank: u32, rank: u32,
world_size: u32, world_size: u32,
comm: Arc<Comm>, comm: Arc<Comm>,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
let dtype = vb.dtype(); let dtype = vb.dtype();
let device = vb.device().clone(); let device = vb.device().clone();
@@ -817,6 +882,7 @@ impl TpQwen3_5Model {
rank, rank,
world_size, world_size,
comm.clone(), comm.clone(),
quant,
) )
.with_context(|| { .with_context(|| {
let (free_mb, total_mb) = cuda_mem_mb(&device); let (free_mb, total_mb) = cuda_mem_mb(&device);
@@ -844,6 +910,7 @@ impl TpQwen3_5Model {
mmap: &MmapedSafetensors, mmap: &MmapedSafetensors,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
let dtype = vb.dtype(); let dtype = vb.dtype();
let device = vb.device().clone(); let device = vb.device().clone();
@@ -877,6 +944,7 @@ impl TpQwen3_5Model {
mmap, mmap,
rank, rank,
world_size, world_size,
quant,
)?); )?);
} }
@@ -931,6 +999,7 @@ pub struct TpQwen3_5ForCausalLM {
impl TpQwen3_5ForCausalLM { impl TpQwen3_5ForCausalLM {
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn load( pub fn load(
config: Config, config: Config,
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
@@ -938,9 +1007,10 @@ impl TpQwen3_5ForCausalLM {
rank: u32, rank: u32,
world_size: u32, world_size: u32,
comm: Arc<Comm>, comm: Arc<Comm>,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
let cfg = &config.text_config; let cfg = &config.text_config;
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm)?; let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?;
let lm_head = build_lm_head(cfg, vb, &base)?; let lm_head = build_lm_head(cfg, vb, &base)?;
Ok(Self { base, lm_head }) Ok(Self { base, lm_head })
} }
@@ -952,9 +1022,10 @@ impl TpQwen3_5ForCausalLM {
mmap: &MmapedSafetensors, mmap: &MmapedSafetensors,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
let cfg = &config.text_config; let cfg = &config.text_config;
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size)?; let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?;
let lm_head = build_lm_head(cfg, vb, &base)?; let lm_head = build_lm_head(cfg, vb, &base)?;
Ok(Self { base, lm_head }) Ok(Self { base, lm_head })
} }

View File

@@ -160,7 +160,8 @@ impl WorkerState {
model_id, model_id,
config_json, config_json,
safetensors_paths, safetensors_paths,
} => self.handle_load_dense_shard(model_id, config_json, safetensors_paths), quant,
} => self.handle_load_dense_shard(model_id, config_json, safetensors_paths, quant),
WorkerRequest::GenerateStep { WorkerRequest::GenerateStep {
model_id, model_id,
tokens, tokens,
@@ -178,6 +179,7 @@ impl WorkerState {
model_id: String, model_id: String,
config_json: String, config_json: String,
safetensors_paths: Vec<String>, safetensors_paths: Vec<String>,
quant: Option<String>,
) -> WorkerResponse { ) -> WorkerResponse {
use crate::harness::arch::qwen3_5 as qwen3_5_arch; use crate::harness::arch::qwen3_5 as qwen3_5_arch;
use candle_core::{DType, Device}; use candle_core::{DType, Device};
@@ -185,6 +187,16 @@ impl WorkerState {
use candle_transformers::models::qwen3 as qwen3_dense; use candle_transformers::models::qwen3 as qwen3_dense;
use std::path::PathBuf; use std::path::PathBuf;
let quant_dtype = match parse_quant_string(quant.as_deref()) {
Ok(q) => q,
Err(e) => {
return WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("parse quant: {e}"),
};
}
};
if self.models.contains_key(&model_id) { if self.models.contains_key(&model_id) {
return WorkerResponse::Error { return WorkerResponse::Error {
kind: "already_loaded".into(), kind: "already_loaded".into(),
@@ -290,6 +302,7 @@ impl WorkerState {
self.config.rank, self.config.rank,
self.config.world_size, self.config.world_size,
comm, comm,
quant_dtype,
) { ) {
Ok(m) => WorkerModel::Qwen3_5(m), Ok(m) => WorkerModel::Qwen3_5(m),
Err(e) => { Err(e) => {
@@ -326,6 +339,7 @@ impl WorkerState {
_model_id: String, _model_id: String,
_config_json: String, _config_json: String,
_safetensors_paths: Vec<String>, _safetensors_paths: Vec<String>,
_quant: Option<String>,
) -> WorkerResponse { ) -> WorkerResponse {
WorkerResponse::Error { WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(), kind: "cuda_feature_not_enabled".into(),
@@ -444,3 +458,45 @@ impl WorkerState {
} }
} }
} }
/// Parse a `ModelSpec.quant` string into a `GgmlDType`.
///
/// `None`, `Some("")` and whitespace-only strings all map to
/// `Ok(None)`, i.e. "no quantization".
///
/// Matching is lenient: surrounding whitespace is trimmed, the whole
/// name is ASCII-lowercased, and every underscore is removed before
/// the lookup, so `Q5_K`, `q5_k` and `q5k` are equivalent.
///
/// Supported: `q4_0`, `q4_1`, `q5_0`, `q5_1`, `q8_0`, `q8_1`,
/// `q2k`/`q2_k`, `q3k`/`q3_k`, `q4k`/`q4_k`, `q5k`/`q5_k`,
/// `q6k`/`q6_k`, `q8k`/`q8_k`, `f16`, `bf16`, `f32`. The spellings
/// `q4km`/`q4_k_m` and `q5km`/`q5_k_m` are also accepted and map to
/// the same dtype as `q4k` and `q5k` respectively.
///
/// # Errors
///
/// Returns an error when the (non-empty) name matches none of the
/// spellings above. The message quotes the name as the caller wrote
/// it, not the normalised lookup key.
#[cfg(feature = "cuda")]
pub(crate) fn parse_quant_string(
    s: Option<&str>,
) -> anyhow::Result<Option<candle_core::quantized::GgmlDType>> {
    use candle_core::quantized::GgmlDType;

    // Trim first so stray whitespace from a config file neither turns an
    // "unset" value into an error nor breaks an otherwise valid name.
    let raw = match s.map(str::trim) {
        Some(name) if !name.is_empty() => name,
        _ => return Ok(None),
    };

    // Lookup key: lowercase with all underscores stripped, so one match
    // arm covers every accepted spelling of a format.
    let normalised = raw.to_ascii_lowercase().replace('_', "");
    let dtype = match normalised.as_str() {
        "q40" => GgmlDType::Q4_0,
        "q41" => GgmlDType::Q4_1,
        "q50" => GgmlDType::Q5_0,
        "q51" => GgmlDType::Q5_1,
        "q80" => GgmlDType::Q8_0,
        "q81" => GgmlDType::Q8_1,
        "q2k" => GgmlDType::Q2K,
        "q3k" => GgmlDType::Q3K,
        "q4k" | "q4km" => GgmlDType::Q4K,
        "q5k" | "q5km" => GgmlDType::Q5K,
        "q6k" => GgmlDType::Q6K,
        "q8k" => GgmlDType::Q8K,
        "f16" => GgmlDType::F16,
        "bf16" => GgmlDType::BF16,
        "f32" => GgmlDType::F32,
        // Report the caller's spelling: the normalised key can look
        // nothing like what was typed (e.g. "_" normalises to "").
        _ => anyhow::bail!(
            "unknown quant '{raw}' (expected one of: q4_0, q4_1, q5_0, q5_1, q8_0, \
             q8_1, q2k, q3k, q4k, q5k, q6k, q8k, f16, bf16, f32)"
        ),
    };
    Ok(Some(dtype))
}