feat(stage-8e-2): plumb quant config from ModelSpec to TP load path
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 36s
CI / Clippy (push) Successful in 2m7s
CI / Test (push) Successful in 4m21s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m47s
build-prerelease / Build neuron-ampere (push) Successful in 5m17s
build-prerelease / Build neuron-ada (push) Successful in 5m14s
build-prerelease / Build cortex binary (push) Successful in 18m31s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m57s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m7s
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 36s
CI / Clippy (push) Successful in 2m7s
CI / Test (push) Successful in 4m21s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m47s
build-prerelease / Build neuron-ampere (push) Successful in 5m17s
build-prerelease / Build neuron-ada (push) Successful in 5m14s
build-prerelease / Build cortex binary (push) Successful in 18m31s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m57s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m7s
- LoadDenseShard RPC gains an optional `quant` string field. - WorkerPool::load_dense_shard takes a `quant: Option<String>`, passes it via the RPC to workers and via parse_quant_string to the leader's local load. - The Qwen3-Next TP load chain (ForCausalLM → Model → DecoderLayer → Attention / GatedDeltaNet / MLP) takes `quant: Option<GgmlDType>` end-to-end, calling Column/RowParallelLinear::load_with_quant. - The fused in_proj_qkv inside TpQwen3_5GatedDeltaNet is now a MaybeQuantLinear so it also picks up quantization. - parse_quant_string accepts q4_0/q4_1/q5_0/q5_1/q8_0/q8_1, q2k..q8k (with or without underscore), and f16/bf16/f32. Empty / None means no quantization. Callers from candle.rs forward spec.quant through pool.load_dense_shard. This means a `quant = "q5k"` in models.toml now flows end-to-end to a QTensor-backed QMatMul for every per-rank linear in the Qwen3-Next TP path. Leaves lm_head and the small replicated bias/log tensors in their loaded dtype (Stage 8e-3). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1141,6 +1141,7 @@ impl CandleHarness {
|
||||
&safetensors_paths,
|
||||
&leader_device,
|
||||
candle_core::DType::BF16,
|
||||
spec.quant.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -474,6 +474,7 @@ impl WorkerPool {
|
||||
/// `init_nccl` must have completed first. Bails if the leader's
|
||||
/// NCCL comm isn't set up yet.
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn load_dense_shard(
|
||||
&mut self,
|
||||
model_id: &str,
|
||||
@@ -481,6 +482,7 @@ impl WorkerPool {
|
||||
safetensors_paths: &[std::path::PathBuf],
|
||||
leader_device: &candle_core::Device,
|
||||
dtype: candle_core::DType,
|
||||
quant: Option<String>,
|
||||
) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
|
||||
use candle_nn::var_builder::ShardedSafeTensors;
|
||||
use std::sync::Arc;
|
||||
@@ -510,6 +512,7 @@ impl WorkerPool {
|
||||
model_id: model_id.to_string(),
|
||||
config_json: config_json.to_string(),
|
||||
safetensors_paths: safetensors_str.clone(),
|
||||
quant: quant.clone(),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
@@ -531,6 +534,7 @@ impl WorkerPool {
|
||||
let comm_for_leader = leader_comm;
|
||||
let model_id_for_log = model_id.to_string();
|
||||
let config_json_for_leader = config_json.to_string();
|
||||
let quant_for_leader = quant.clone();
|
||||
|
||||
let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
|
||||
// SAFETY: same invariant as the single-GPU dense path —
|
||||
@@ -558,8 +562,16 @@ impl WorkerPool {
|
||||
let cfg: super::tp::tp_qwen3_5::Config =
|
||||
serde_json::from_str(&config_json_for_leader)
|
||||
.context("parse Qwen3-Next Config JSON for leader load")?;
|
||||
let quant_dtype =
|
||||
super::tp::worker::parse_quant_string(quant_for_leader.as_deref())?;
|
||||
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
|
||||
cfg, &vb, &mmap, 0, world_size, comm,
|
||||
cfg,
|
||||
&vb,
|
||||
&mmap,
|
||||
0,
|
||||
world_size,
|
||||
comm,
|
||||
quant_dtype,
|
||||
)?)
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
|
||||
@@ -63,6 +63,13 @@ pub enum WorkerRequest {
|
||||
/// Absolute paths the worker should mmap. The same set on every
|
||||
/// rank; ShardedVarBuilder slices into them per rank.
|
||||
safetensors_paths: Vec<String>,
|
||||
/// Optional in-situ quantization dtype (e.g. "q5k", "q8_0",
|
||||
/// "q6k"). When set, each linear-layer weight is quantized
|
||||
/// at load time to the named ggml format — saves ~3-5x vs
|
||||
/// bf16/f16 at the cost of some accuracy. `None` keeps the
|
||||
/// weights in the on-disk dtype (typically bf16).
|
||||
#[serde(default)]
|
||||
quant: Option<String>,
|
||||
},
|
||||
|
||||
/// Run one forward step on this rank's loaded model. The worker
|
||||
|
||||
@@ -31,6 +31,7 @@
|
||||
//! linear-attention block, lm_head, the rotary table.
|
||||
|
||||
use anyhow::{Context, Result, bail};
|
||||
use candle_core::quantized::GgmlDType;
|
||||
use candle_core::safetensors::MmapedSafetensors;
|
||||
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
@@ -59,7 +60,7 @@ pub struct TpGatedDeltaNetState {
|
||||
}
|
||||
|
||||
pub(crate) struct TpQwen3_5GatedDeltaNet {
|
||||
in_proj_qkv: Linear,
|
||||
in_proj_qkv: super::tp_linear::MaybeQuantLinear,
|
||||
in_proj_z: ColumnParallelLinear,
|
||||
in_proj_b: ColumnParallelLinear,
|
||||
in_proj_a: ColumnParallelLinear,
|
||||
@@ -92,6 +93,7 @@ pub(crate) struct TpQwen3_5GatedDeltaNet {
|
||||
|
||||
impl TpQwen3_5GatedDeltaNet {
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
vb: &ShardedVarBuilder,
|
||||
@@ -99,8 +101,9 @@ impl TpQwen3_5GatedDeltaNet {
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
Self::load_inner(cfg, vb, mmap, rank, world_size, comm)
|
||||
Self::load_inner(cfg, vb, mmap, rank, world_size, comm, quant)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
@@ -110,10 +113,12 @@ impl TpQwen3_5GatedDeltaNet {
|
||||
mmap: &MmapedSafetensors,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
Self::load_inner(cfg, vb, mmap, rank, world_size)
|
||||
Self::load_inner(cfg, vb, mmap, rank, world_size, quant)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn load_inner(
|
||||
cfg: &TextConfig,
|
||||
vb: &ShardedVarBuilder,
|
||||
@@ -121,6 +126,7 @@ impl TpQwen3_5GatedDeltaNet {
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
#[cfg(feature = "cuda")] comm: Arc<Comm>,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
let ws = world_size as usize;
|
||||
let num_v_heads = cfg.linear_num_value_heads;
|
||||
@@ -177,7 +183,9 @@ impl TpQwen3_5GatedDeltaNet {
|
||||
dtype,
|
||||
&device,
|
||||
)?;
|
||||
let in_proj_qkv = Linear::new(in_proj_qkv_weight, None);
|
||||
let in_proj_qkv =
|
||||
super::tp_linear::MaybeQuantLinear::from_weight(in_proj_qkv_weight, quant)
|
||||
.with_context(|| format!("wrap fused in_proj_qkv for '{}'", vb.prefix()))?;
|
||||
|
||||
let conv1d_name = format!("{}.conv1d.weight", vb.prefix());
|
||||
let conv1d_weight = super::fused_load::load_fused_qkv_3d(
|
||||
@@ -195,10 +203,13 @@ impl TpQwen3_5GatedDeltaNet {
|
||||
|
||||
// ----- Uniformly-sharded projections (along output dim 0). -----
|
||||
// in_proj_z: hidden → value_dim, sharded along value_dim (V-head).
|
||||
let in_proj_z = ColumnParallelLinear::load(&vb.pp("in_proj_z"), rank, world_size)?;
|
||||
let in_proj_z =
|
||||
ColumnParallelLinear::load_with_quant(&vb.pp("in_proj_z"), rank, world_size, quant)?;
|
||||
// in_proj_b, in_proj_a: hidden → num_v_heads, sharded along output.
|
||||
let in_proj_b = ColumnParallelLinear::load(&vb.pp("in_proj_b"), rank, world_size)?;
|
||||
let in_proj_a = ColumnParallelLinear::load(&vb.pp("in_proj_a"), rank, world_size)?;
|
||||
let in_proj_b =
|
||||
ColumnParallelLinear::load_with_quant(&vb.pp("in_proj_b"), rank, world_size, quant)?;
|
||||
let in_proj_a =
|
||||
ColumnParallelLinear::load_with_quant(&vb.pp("in_proj_a"), rank, world_size, quant)?;
|
||||
|
||||
// ----- Per-V-head 1D params (sharded uniformly). -----
|
||||
let a_log = vb
|
||||
@@ -213,9 +224,11 @@ impl TpQwen3_5GatedDeltaNet {
|
||||
|
||||
// ----- Output projection: row-parallel + AllReduce. -----
|
||||
#[cfg(feature = "cuda")]
|
||||
let out_proj = RowParallelLinear::load(&vb.pp("out_proj"), rank, world_size, comm)?;
|
||||
let out_proj =
|
||||
RowParallelLinear::load_with_quant(&vb.pp("out_proj"), rank, world_size, comm, quant)?;
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
let out_proj = RowParallelLinear::load(&vb.pp("out_proj"), rank, world_size)?;
|
||||
let out_proj =
|
||||
RowParallelLinear::load_with_quant(&vb.pp("out_proj"), rank, world_size, quant)?;
|
||||
|
||||
Ok(Self {
|
||||
in_proj_qkv,
|
||||
@@ -418,6 +431,7 @@ pub(crate) struct TpQwen3_5Attention {
|
||||
|
||||
impl TpQwen3_5Attention {
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
rotary: Arc<RotaryEmbedding>,
|
||||
@@ -425,8 +439,9 @@ impl TpQwen3_5Attention {
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
Self::load_inner(cfg, rotary, vb, rank, world_size, comm)
|
||||
Self::load_inner(cfg, rotary, vb, rank, world_size, comm, quant)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
@@ -436,10 +451,12 @@ impl TpQwen3_5Attention {
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
Self::load_inner(cfg, rotary, vb, rank, world_size)
|
||||
Self::load_inner(cfg, rotary, vb, rank, world_size, quant)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn load_inner(
|
||||
cfg: &TextConfig,
|
||||
rotary: Arc<RotaryEmbedding>,
|
||||
@@ -447,6 +464,7 @@ impl TpQwen3_5Attention {
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
#[cfg(feature = "cuda")] comm: Arc<Comm>,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
let ws = world_size as usize;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
@@ -468,13 +486,17 @@ impl TpQwen3_5Attention {
|
||||
// consistently — rank R holds heads `[R*per_rank, (R+1)*per_rank)`
|
||||
// for both query AND gate, so the post-attention `gate.sigmoid()`
|
||||
// multiply against the per-rank attention output matches up.
|
||||
let q_proj = ColumnParallelLinear::load(&vb.pp("q_proj"), rank, world_size)?;
|
||||
let k_proj = ColumnParallelLinear::load(&vb.pp("k_proj"), rank, world_size)?;
|
||||
let v_proj = ColumnParallelLinear::load(&vb.pp("v_proj"), rank, world_size)?;
|
||||
let q_proj =
|
||||
ColumnParallelLinear::load_with_quant(&vb.pp("q_proj"), rank, world_size, quant)?;
|
||||
let k_proj =
|
||||
ColumnParallelLinear::load_with_quant(&vb.pp("k_proj"), rank, world_size, quant)?;
|
||||
let v_proj =
|
||||
ColumnParallelLinear::load_with_quant(&vb.pp("v_proj"), rank, world_size, quant)?;
|
||||
#[cfg(feature = "cuda")]
|
||||
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size, comm)?;
|
||||
let o_proj =
|
||||
RowParallelLinear::load_with_quant(&vb.pp("o_proj"), rank, world_size, comm, quant)?;
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size)?;
|
||||
let o_proj = RowParallelLinear::load_with_quant(&vb.pp("o_proj"), rank, world_size, quant)?;
|
||||
|
||||
let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||
let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||
@@ -572,12 +594,14 @@ pub(crate) struct TpQwen3_5MLP {
|
||||
|
||||
impl TpQwen3_5MLP {
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
||||
bail!(
|
||||
@@ -587,9 +611,25 @@ impl TpQwen3_5MLP {
|
||||
);
|
||||
}
|
||||
Ok(Self {
|
||||
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
|
||||
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
|
||||
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size, comm)?,
|
||||
gate_proj: ColumnParallelLinear::load_with_quant(
|
||||
&vb.pp("gate_proj"),
|
||||
rank,
|
||||
world_size,
|
||||
quant,
|
||||
)?,
|
||||
up_proj: ColumnParallelLinear::load_with_quant(
|
||||
&vb.pp("up_proj"),
|
||||
rank,
|
||||
world_size,
|
||||
quant,
|
||||
)?,
|
||||
down_proj: RowParallelLinear::load_with_quant(
|
||||
&vb.pp("down_proj"),
|
||||
rank,
|
||||
world_size,
|
||||
comm,
|
||||
quant,
|
||||
)?,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -599,6 +639,7 @@ impl TpQwen3_5MLP {
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
||||
bail!(
|
||||
@@ -608,9 +649,24 @@ impl TpQwen3_5MLP {
|
||||
);
|
||||
}
|
||||
Ok(Self {
|
||||
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
|
||||
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
|
||||
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size)?,
|
||||
gate_proj: ColumnParallelLinear::load_with_quant(
|
||||
&vb.pp("gate_proj"),
|
||||
rank,
|
||||
world_size,
|
||||
quant,
|
||||
)?,
|
||||
up_proj: ColumnParallelLinear::load_with_quant(
|
||||
&vb.pp("up_proj"),
|
||||
rank,
|
||||
world_size,
|
||||
quant,
|
||||
)?,
|
||||
down_proj: RowParallelLinear::load_with_quant(
|
||||
&vb.pp("down_proj"),
|
||||
rank,
|
||||
world_size,
|
||||
quant,
|
||||
)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -649,6 +705,7 @@ impl TpQwen3_5DecoderLayer {
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
let layer_type = cfg
|
||||
.layer_types
|
||||
@@ -663,6 +720,7 @@ impl TpQwen3_5DecoderLayer {
|
||||
rank,
|
||||
world_size,
|
||||
comm.clone(),
|
||||
quant,
|
||||
)?),
|
||||
"linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load(
|
||||
cfg,
|
||||
@@ -671,10 +729,11 @@ impl TpQwen3_5DecoderLayer {
|
||||
rank,
|
||||
world_size,
|
||||
comm.clone(),
|
||||
quant,
|
||||
)?),
|
||||
other => bail!("unknown layer_type '{other}' for layer {layer_idx}"),
|
||||
};
|
||||
let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm)?;
|
||||
let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm, quant)?;
|
||||
let input_layernorm =
|
||||
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||
let post_attention_layernorm = Qwen3_5RmsNorm::load(
|
||||
@@ -691,6 +750,7 @@ impl TpQwen3_5DecoderLayer {
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
rotary: Arc<RotaryEmbedding>,
|
||||
@@ -699,6 +759,7 @@ impl TpQwen3_5DecoderLayer {
|
||||
mmap: &MmapedSafetensors,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
let layer_type = cfg
|
||||
.layer_types
|
||||
@@ -712,6 +773,7 @@ impl TpQwen3_5DecoderLayer {
|
||||
&vb.pp("self_attn"),
|
||||
rank,
|
||||
world_size,
|
||||
quant,
|
||||
)?),
|
||||
"linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load(
|
||||
cfg,
|
||||
@@ -719,10 +781,11 @@ impl TpQwen3_5DecoderLayer {
|
||||
mmap,
|
||||
rank,
|
||||
world_size,
|
||||
quant,
|
||||
)?),
|
||||
other => bail!("unknown layer_type '{other}' for layer {layer_idx}"),
|
||||
};
|
||||
let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size)?;
|
||||
let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size, quant)?;
|
||||
let input_layernorm =
|
||||
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||
let post_attention_layernorm = Qwen3_5RmsNorm::load(
|
||||
@@ -775,6 +838,7 @@ pub struct TpQwen3_5Model {
|
||||
|
||||
impl TpQwen3_5Model {
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
vb: &ShardedVarBuilder,
|
||||
@@ -782,6 +846,7 @@ impl TpQwen3_5Model {
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
let dtype = vb.dtype();
|
||||
let device = vb.device().clone();
|
||||
@@ -817,6 +882,7 @@ impl TpQwen3_5Model {
|
||||
rank,
|
||||
world_size,
|
||||
comm.clone(),
|
||||
quant,
|
||||
)
|
||||
.with_context(|| {
|
||||
let (free_mb, total_mb) = cuda_mem_mb(&device);
|
||||
@@ -844,6 +910,7 @@ impl TpQwen3_5Model {
|
||||
mmap: &MmapedSafetensors,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
let dtype = vb.dtype();
|
||||
let device = vb.device().clone();
|
||||
@@ -877,6 +944,7 @@ impl TpQwen3_5Model {
|
||||
mmap,
|
||||
rank,
|
||||
world_size,
|
||||
quant,
|
||||
)?);
|
||||
}
|
||||
|
||||
@@ -931,6 +999,7 @@ pub struct TpQwen3_5ForCausalLM {
|
||||
|
||||
impl TpQwen3_5ForCausalLM {
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn load(
|
||||
config: Config,
|
||||
vb: &ShardedVarBuilder,
|
||||
@@ -938,9 +1007,10 @@ impl TpQwen3_5ForCausalLM {
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
let cfg = &config.text_config;
|
||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm)?;
|
||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?;
|
||||
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||
Ok(Self { base, lm_head })
|
||||
}
|
||||
@@ -952,9 +1022,10 @@ impl TpQwen3_5ForCausalLM {
|
||||
mmap: &MmapedSafetensors,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
let cfg = &config.text_config;
|
||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size)?;
|
||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?;
|
||||
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||
Ok(Self { base, lm_head })
|
||||
}
|
||||
|
||||
@@ -160,7 +160,8 @@ impl WorkerState {
|
||||
model_id,
|
||||
config_json,
|
||||
safetensors_paths,
|
||||
} => self.handle_load_dense_shard(model_id, config_json, safetensors_paths),
|
||||
quant,
|
||||
} => self.handle_load_dense_shard(model_id, config_json, safetensors_paths, quant),
|
||||
WorkerRequest::GenerateStep {
|
||||
model_id,
|
||||
tokens,
|
||||
@@ -178,6 +179,7 @@ impl WorkerState {
|
||||
model_id: String,
|
||||
config_json: String,
|
||||
safetensors_paths: Vec<String>,
|
||||
quant: Option<String>,
|
||||
) -> WorkerResponse {
|
||||
use crate::harness::arch::qwen3_5 as qwen3_5_arch;
|
||||
use candle_core::{DType, Device};
|
||||
@@ -185,6 +187,16 @@ impl WorkerState {
|
||||
use candle_transformers::models::qwen3 as qwen3_dense;
|
||||
use std::path::PathBuf;
|
||||
|
||||
let quant_dtype = match parse_quant_string(quant.as_deref()) {
|
||||
Ok(q) => q,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "bad_request".into(),
|
||||
message: format!("parse quant: {e}"),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
if self.models.contains_key(&model_id) {
|
||||
return WorkerResponse::Error {
|
||||
kind: "already_loaded".into(),
|
||||
@@ -290,6 +302,7 @@ impl WorkerState {
|
||||
self.config.rank,
|
||||
self.config.world_size,
|
||||
comm,
|
||||
quant_dtype,
|
||||
) {
|
||||
Ok(m) => WorkerModel::Qwen3_5(m),
|
||||
Err(e) => {
|
||||
@@ -326,6 +339,7 @@ impl WorkerState {
|
||||
_model_id: String,
|
||||
_config_json: String,
|
||||
_safetensors_paths: Vec<String>,
|
||||
_quant: Option<String>,
|
||||
) -> WorkerResponse {
|
||||
WorkerResponse::Error {
|
||||
kind: "cuda_feature_not_enabled".into(),
|
||||
@@ -444,3 +458,45 @@ impl WorkerState {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a `ModelSpec.quant` string into a `GgmlDType`. Accepts the
|
||||
/// common ggml format names (case-insensitive). `None` and `Some("")`
|
||||
/// both map to "no quantization".
|
||||
///
|
||||
/// Supported: `q4_0`, `q4_1`, `q5_0`, `q5_1`, `q8_0`, `q8_1`,
|
||||
/// `q2k`/`q2_k`, `q3k`/`q3_k`, `q4k`/`q4_k`, `q5k`/`q5_k`,
|
||||
/// `q6k`/`q6_k`, `q8k`/`q8_k`, `f16`, `bf16`, `f32`. The underscore
|
||||
/// is optional and the prefix is case-insensitive.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub(crate) fn parse_quant_string(
|
||||
s: Option<&str>,
|
||||
) -> anyhow::Result<Option<candle_core::quantized::GgmlDType>> {
|
||||
use candle_core::quantized::GgmlDType;
|
||||
let s = match s {
|
||||
Some(s) if !s.is_empty() => s,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
let normalised = s.to_ascii_lowercase().replace('_', "");
|
||||
let dtype = match normalised.as_str() {
|
||||
"q40" => GgmlDType::Q4_0,
|
||||
"q41" => GgmlDType::Q4_1,
|
||||
"q50" => GgmlDType::Q5_0,
|
||||
"q51" => GgmlDType::Q5_1,
|
||||
"q80" => GgmlDType::Q8_0,
|
||||
"q81" => GgmlDType::Q8_1,
|
||||
"q2k" => GgmlDType::Q2K,
|
||||
"q3k" => GgmlDType::Q3K,
|
||||
"q4k" | "q4km" => GgmlDType::Q4K,
|
||||
"q5k" | "q5km" => GgmlDType::Q5K,
|
||||
"q6k" => GgmlDType::Q6K,
|
||||
"q8k" => GgmlDType::Q8K,
|
||||
"f16" => GgmlDType::F16,
|
||||
"bf16" => GgmlDType::BF16,
|
||||
"f32" => GgmlDType::F32,
|
||||
other => anyhow::bail!(
|
||||
"unknown quant '{other}' (expected one of: q4_0, q4_1, q5_0, q5_1, q8_0, \
|
||||
q8_1, q2k, q3k, q4k, q5k, q6k, q8k, f16, bf16, f32)"
|
||||
),
|
||||
};
|
||||
Ok(Some(dtype))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user