feat(stage-8e-2): plumb quant config from ModelSpec to TP load path
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 36s
CI / Clippy (push) Successful in 2m7s
CI / Test (push) Successful in 4m21s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m47s
build-prerelease / Build neuron-ampere (push) Successful in 5m17s
build-prerelease / Build neuron-ada (push) Successful in 5m14s
build-prerelease / Build cortex binary (push) Successful in 18m31s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m57s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m7s
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 36s
CI / Clippy (push) Successful in 2m7s
CI / Test (push) Successful in 4m21s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m47s
build-prerelease / Build neuron-ampere (push) Successful in 5m17s
build-prerelease / Build neuron-ada (push) Successful in 5m14s
build-prerelease / Build cortex binary (push) Successful in 18m31s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m57s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m7s
- LoadDenseShard RPC gains an optional `quant` string field.
- WorkerPool::load_dense_shard takes a `quant: Option<String>`, passes it via the RPC to workers and via parse_quant_string to the leader's local load.
- The Qwen3-Next TP load chain (ForCausalLM → Model → DecoderLayer → Attention / GatedDeltaNet / MLP) takes `quant: Option<GgmlDType>` end-to-end, calling Column/RowParallelLinear::load_with_quant.
- The fused in_proj_qkv inside TpQwen3_5GatedDeltaNet is now a MaybeQuantLinear so it also picks up quantization.
- parse_quant_string accepts q4_0/q4_1/q5_0/q5_1/q8_0/q8_1, q2k/q3k/q4k/q5k/q6k/q8k (with or without underscore), and f16/bf16/f32. Empty / None means no quantization.

Callers from candle.rs forward spec.quant through pool.load_dense_shard. This means a `quant = "q5k"` in models.toml now flows end-to-end to a QTensor-backed QMatMul for every per-rank linear in the Qwen3-Next TP path.

Leaves lm_head and the small replicated bias/log tensors in their loaded dtype (Stage 8e-3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1141,6 +1141,7 @@ impl CandleHarness {
|
|||||||
&safetensors_paths,
|
&safetensors_paths,
|
||||||
&leader_device,
|
&leader_device,
|
||||||
candle_core::DType::BF16,
|
candle_core::DType::BF16,
|
||||||
|
spec.quant.clone(),
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|||||||
@@ -474,6 +474,7 @@ impl WorkerPool {
|
|||||||
/// `init_nccl` must have completed first. Bails if the leader's
|
/// `init_nccl` must have completed first. Bails if the leader's
|
||||||
/// NCCL comm isn't set up yet.
|
/// NCCL comm isn't set up yet.
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn load_dense_shard(
|
pub async fn load_dense_shard(
|
||||||
&mut self,
|
&mut self,
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
@@ -481,6 +482,7 @@ impl WorkerPool {
|
|||||||
safetensors_paths: &[std::path::PathBuf],
|
safetensors_paths: &[std::path::PathBuf],
|
||||||
leader_device: &candle_core::Device,
|
leader_device: &candle_core::Device,
|
||||||
dtype: candle_core::DType,
|
dtype: candle_core::DType,
|
||||||
|
quant: Option<String>,
|
||||||
) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
|
) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
|
||||||
use candle_nn::var_builder::ShardedSafeTensors;
|
use candle_nn::var_builder::ShardedSafeTensors;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -510,6 +512,7 @@ impl WorkerPool {
|
|||||||
model_id: model_id.to_string(),
|
model_id: model_id.to_string(),
|
||||||
config_json: config_json.to_string(),
|
config_json: config_json.to_string(),
|
||||||
safetensors_paths: safetensors_str.clone(),
|
safetensors_paths: safetensors_str.clone(),
|
||||||
|
quant: quant.clone(),
|
||||||
})
|
})
|
||||||
.await?;
|
.await?;
|
||||||
}
|
}
|
||||||
@@ -531,6 +534,7 @@ impl WorkerPool {
|
|||||||
let comm_for_leader = leader_comm;
|
let comm_for_leader = leader_comm;
|
||||||
let model_id_for_log = model_id.to_string();
|
let model_id_for_log = model_id.to_string();
|
||||||
let config_json_for_leader = config_json.to_string();
|
let config_json_for_leader = config_json.to_string();
|
||||||
|
let quant_for_leader = quant.clone();
|
||||||
|
|
||||||
let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
|
let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
|
||||||
// SAFETY: same invariant as the single-GPU dense path —
|
// SAFETY: same invariant as the single-GPU dense path —
|
||||||
@@ -558,8 +562,16 @@ impl WorkerPool {
|
|||||||
let cfg: super::tp::tp_qwen3_5::Config =
|
let cfg: super::tp::tp_qwen3_5::Config =
|
||||||
serde_json::from_str(&config_json_for_leader)
|
serde_json::from_str(&config_json_for_leader)
|
||||||
.context("parse Qwen3-Next Config JSON for leader load")?;
|
.context("parse Qwen3-Next Config JSON for leader load")?;
|
||||||
|
let quant_dtype =
|
||||||
|
super::tp::worker::parse_quant_string(quant_for_leader.as_deref())?;
|
||||||
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
|
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
|
||||||
cfg, &vb, &mmap, 0, world_size, comm,
|
cfg,
|
||||||
|
&vb,
|
||||||
|
&mmap,
|
||||||
|
0,
|
||||||
|
world_size,
|
||||||
|
comm,
|
||||||
|
quant_dtype,
|
||||||
)?)
|
)?)
|
||||||
}
|
}
|
||||||
other => anyhow::bail!(
|
other => anyhow::bail!(
|
||||||
|
|||||||
@@ -63,6 +63,13 @@ pub enum WorkerRequest {
|
|||||||
/// Absolute paths the worker should mmap. The same set on every
|
/// Absolute paths the worker should mmap. The same set on every
|
||||||
/// rank; ShardedVarBuilder slices into them per rank.
|
/// rank; ShardedVarBuilder slices into them per rank.
|
||||||
safetensors_paths: Vec<String>,
|
safetensors_paths: Vec<String>,
|
||||||
|
/// Optional in-situ quantization dtype (e.g. "q5k", "q8_0",
|
||||||
|
/// "q6k"). When set, each linear-layer weight is quantized
|
||||||
|
/// at load time to the named ggml format — saves ~3-5x vs
|
||||||
|
/// bf16/f16 at the cost of some accuracy. `None` keeps the
|
||||||
|
/// weights in the on-disk dtype (typically bf16).
|
||||||
|
#[serde(default)]
|
||||||
|
quant: Option<String>,
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Run one forward step on this rank's loaded model. The worker
|
/// Run one forward step on this rank's loaded model. The worker
|
||||||
|
|||||||
@@ -31,6 +31,7 @@
|
|||||||
//! linear-attention block, lm_head, the rotary table.
|
//! linear-attention block, lm_head, the rotary table.
|
||||||
|
|
||||||
use anyhow::{Context, Result, bail};
|
use anyhow::{Context, Result, bail};
|
||||||
|
use candle_core::quantized::GgmlDType;
|
||||||
use candle_core::safetensors::MmapedSafetensors;
|
use candle_core::safetensors::MmapedSafetensors;
|
||||||
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
||||||
use candle_nn::var_builder::ShardedVarBuilder;
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
@@ -59,7 +60,7 @@ pub struct TpGatedDeltaNetState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct TpQwen3_5GatedDeltaNet {
|
pub(crate) struct TpQwen3_5GatedDeltaNet {
|
||||||
in_proj_qkv: Linear,
|
in_proj_qkv: super::tp_linear::MaybeQuantLinear,
|
||||||
in_proj_z: ColumnParallelLinear,
|
in_proj_z: ColumnParallelLinear,
|
||||||
in_proj_b: ColumnParallelLinear,
|
in_proj_b: ColumnParallelLinear,
|
||||||
in_proj_a: ColumnParallelLinear,
|
in_proj_a: ColumnParallelLinear,
|
||||||
@@ -92,6 +93,7 @@ pub(crate) struct TpQwen3_5GatedDeltaNet {
|
|||||||
|
|
||||||
impl TpQwen3_5GatedDeltaNet {
|
impl TpQwen3_5GatedDeltaNet {
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn load(
|
pub fn load(
|
||||||
cfg: &TextConfig,
|
cfg: &TextConfig,
|
||||||
vb: &ShardedVarBuilder,
|
vb: &ShardedVarBuilder,
|
||||||
@@ -99,8 +101,9 @@ impl TpQwen3_5GatedDeltaNet {
|
|||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
comm: Arc<Comm>,
|
comm: Arc<Comm>,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Self::load_inner(cfg, vb, mmap, rank, world_size, comm)
|
Self::load_inner(cfg, vb, mmap, rank, world_size, comm, quant)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
@@ -110,10 +113,12 @@ impl TpQwen3_5GatedDeltaNet {
|
|||||||
mmap: &MmapedSafetensors,
|
mmap: &MmapedSafetensors,
|
||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Self::load_inner(cfg, vb, mmap, rank, world_size)
|
Self::load_inner(cfg, vb, mmap, rank, world_size, quant)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn load_inner(
|
fn load_inner(
|
||||||
cfg: &TextConfig,
|
cfg: &TextConfig,
|
||||||
vb: &ShardedVarBuilder,
|
vb: &ShardedVarBuilder,
|
||||||
@@ -121,6 +126,7 @@ impl TpQwen3_5GatedDeltaNet {
|
|||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
#[cfg(feature = "cuda")] comm: Arc<Comm>,
|
#[cfg(feature = "cuda")] comm: Arc<Comm>,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let ws = world_size as usize;
|
let ws = world_size as usize;
|
||||||
let num_v_heads = cfg.linear_num_value_heads;
|
let num_v_heads = cfg.linear_num_value_heads;
|
||||||
@@ -177,7 +183,9 @@ impl TpQwen3_5GatedDeltaNet {
|
|||||||
dtype,
|
dtype,
|
||||||
&device,
|
&device,
|
||||||
)?;
|
)?;
|
||||||
let in_proj_qkv = Linear::new(in_proj_qkv_weight, None);
|
let in_proj_qkv =
|
||||||
|
super::tp_linear::MaybeQuantLinear::from_weight(in_proj_qkv_weight, quant)
|
||||||
|
.with_context(|| format!("wrap fused in_proj_qkv for '{}'", vb.prefix()))?;
|
||||||
|
|
||||||
let conv1d_name = format!("{}.conv1d.weight", vb.prefix());
|
let conv1d_name = format!("{}.conv1d.weight", vb.prefix());
|
||||||
let conv1d_weight = super::fused_load::load_fused_qkv_3d(
|
let conv1d_weight = super::fused_load::load_fused_qkv_3d(
|
||||||
@@ -195,10 +203,13 @@ impl TpQwen3_5GatedDeltaNet {
|
|||||||
|
|
||||||
// ----- Uniformly-sharded projections (along output dim 0). -----
|
// ----- Uniformly-sharded projections (along output dim 0). -----
|
||||||
// in_proj_z: hidden → value_dim, sharded along value_dim (V-head).
|
// in_proj_z: hidden → value_dim, sharded along value_dim (V-head).
|
||||||
let in_proj_z = ColumnParallelLinear::load(&vb.pp("in_proj_z"), rank, world_size)?;
|
let in_proj_z =
|
||||||
|
ColumnParallelLinear::load_with_quant(&vb.pp("in_proj_z"), rank, world_size, quant)?;
|
||||||
// in_proj_b, in_proj_a: hidden → num_v_heads, sharded along output.
|
// in_proj_b, in_proj_a: hidden → num_v_heads, sharded along output.
|
||||||
let in_proj_b = ColumnParallelLinear::load(&vb.pp("in_proj_b"), rank, world_size)?;
|
let in_proj_b =
|
||||||
let in_proj_a = ColumnParallelLinear::load(&vb.pp("in_proj_a"), rank, world_size)?;
|
ColumnParallelLinear::load_with_quant(&vb.pp("in_proj_b"), rank, world_size, quant)?;
|
||||||
|
let in_proj_a =
|
||||||
|
ColumnParallelLinear::load_with_quant(&vb.pp("in_proj_a"), rank, world_size, quant)?;
|
||||||
|
|
||||||
// ----- Per-V-head 1D params (sharded uniformly). -----
|
// ----- Per-V-head 1D params (sharded uniformly). -----
|
||||||
let a_log = vb
|
let a_log = vb
|
||||||
@@ -213,9 +224,11 @@ impl TpQwen3_5GatedDeltaNet {
|
|||||||
|
|
||||||
// ----- Output projection: row-parallel + AllReduce. -----
|
// ----- Output projection: row-parallel + AllReduce. -----
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
let out_proj = RowParallelLinear::load(&vb.pp("out_proj"), rank, world_size, comm)?;
|
let out_proj =
|
||||||
|
RowParallelLinear::load_with_quant(&vb.pp("out_proj"), rank, world_size, comm, quant)?;
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
let out_proj = RowParallelLinear::load(&vb.pp("out_proj"), rank, world_size)?;
|
let out_proj =
|
||||||
|
RowParallelLinear::load_with_quant(&vb.pp("out_proj"), rank, world_size, quant)?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
in_proj_qkv,
|
in_proj_qkv,
|
||||||
@@ -418,6 +431,7 @@ pub(crate) struct TpQwen3_5Attention {
|
|||||||
|
|
||||||
impl TpQwen3_5Attention {
|
impl TpQwen3_5Attention {
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn load(
|
pub fn load(
|
||||||
cfg: &TextConfig,
|
cfg: &TextConfig,
|
||||||
rotary: Arc<RotaryEmbedding>,
|
rotary: Arc<RotaryEmbedding>,
|
||||||
@@ -425,8 +439,9 @@ impl TpQwen3_5Attention {
|
|||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
comm: Arc<Comm>,
|
comm: Arc<Comm>,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Self::load_inner(cfg, rotary, vb, rank, world_size, comm)
|
Self::load_inner(cfg, rotary, vb, rank, world_size, comm, quant)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
@@ -436,10 +451,12 @@ impl TpQwen3_5Attention {
|
|||||||
vb: &ShardedVarBuilder,
|
vb: &ShardedVarBuilder,
|
||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Self::load_inner(cfg, rotary, vb, rank, world_size)
|
Self::load_inner(cfg, rotary, vb, rank, world_size, quant)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn load_inner(
|
fn load_inner(
|
||||||
cfg: &TextConfig,
|
cfg: &TextConfig,
|
||||||
rotary: Arc<RotaryEmbedding>,
|
rotary: Arc<RotaryEmbedding>,
|
||||||
@@ -447,6 +464,7 @@ impl TpQwen3_5Attention {
|
|||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
#[cfg(feature = "cuda")] comm: Arc<Comm>,
|
#[cfg(feature = "cuda")] comm: Arc<Comm>,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let ws = world_size as usize;
|
let ws = world_size as usize;
|
||||||
let num_heads = cfg.num_attention_heads;
|
let num_heads = cfg.num_attention_heads;
|
||||||
@@ -468,13 +486,17 @@ impl TpQwen3_5Attention {
|
|||||||
// consistently — rank R holds heads `[R*per_rank, (R+1)*per_rank)`
|
// consistently — rank R holds heads `[R*per_rank, (R+1)*per_rank)`
|
||||||
// for both query AND gate, so the post-attention `gate.sigmoid()`
|
// for both query AND gate, so the post-attention `gate.sigmoid()`
|
||||||
// multiply against the per-rank attention output matches up.
|
// multiply against the per-rank attention output matches up.
|
||||||
let q_proj = ColumnParallelLinear::load(&vb.pp("q_proj"), rank, world_size)?;
|
let q_proj =
|
||||||
let k_proj = ColumnParallelLinear::load(&vb.pp("k_proj"), rank, world_size)?;
|
ColumnParallelLinear::load_with_quant(&vb.pp("q_proj"), rank, world_size, quant)?;
|
||||||
let v_proj = ColumnParallelLinear::load(&vb.pp("v_proj"), rank, world_size)?;
|
let k_proj =
|
||||||
|
ColumnParallelLinear::load_with_quant(&vb.pp("k_proj"), rank, world_size, quant)?;
|
||||||
|
let v_proj =
|
||||||
|
ColumnParallelLinear::load_with_quant(&vb.pp("v_proj"), rank, world_size, quant)?;
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size, comm)?;
|
let o_proj =
|
||||||
|
RowParallelLinear::load_with_quant(&vb.pp("o_proj"), rank, world_size, comm, quant)?;
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size)?;
|
let o_proj = RowParallelLinear::load_with_quant(&vb.pp("o_proj"), rank, world_size, quant)?;
|
||||||
|
|
||||||
let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
|
let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
|
let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
@@ -572,12 +594,14 @@ pub(crate) struct TpQwen3_5MLP {
|
|||||||
|
|
||||||
impl TpQwen3_5MLP {
|
impl TpQwen3_5MLP {
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn load(
|
pub fn load(
|
||||||
cfg: &TextConfig,
|
cfg: &TextConfig,
|
||||||
vb: &ShardedVarBuilder,
|
vb: &ShardedVarBuilder,
|
||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
comm: Arc<Comm>,
|
comm: Arc<Comm>,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
||||||
bail!(
|
bail!(
|
||||||
@@ -587,9 +611,25 @@ impl TpQwen3_5MLP {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
|
gate_proj: ColumnParallelLinear::load_with_quant(
|
||||||
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
|
&vb.pp("gate_proj"),
|
||||||
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size, comm)?,
|
rank,
|
||||||
|
world_size,
|
||||||
|
quant,
|
||||||
|
)?,
|
||||||
|
up_proj: ColumnParallelLinear::load_with_quant(
|
||||||
|
&vb.pp("up_proj"),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
quant,
|
||||||
|
)?,
|
||||||
|
down_proj: RowParallelLinear::load_with_quant(
|
||||||
|
&vb.pp("down_proj"),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
comm,
|
||||||
|
quant,
|
||||||
|
)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -599,6 +639,7 @@ impl TpQwen3_5MLP {
|
|||||||
vb: &ShardedVarBuilder,
|
vb: &ShardedVarBuilder,
|
||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
||||||
bail!(
|
bail!(
|
||||||
@@ -608,9 +649,24 @@ impl TpQwen3_5MLP {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
|
gate_proj: ColumnParallelLinear::load_with_quant(
|
||||||
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
|
&vb.pp("gate_proj"),
|
||||||
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size)?,
|
rank,
|
||||||
|
world_size,
|
||||||
|
quant,
|
||||||
|
)?,
|
||||||
|
up_proj: ColumnParallelLinear::load_with_quant(
|
||||||
|
&vb.pp("up_proj"),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
quant,
|
||||||
|
)?,
|
||||||
|
down_proj: RowParallelLinear::load_with_quant(
|
||||||
|
&vb.pp("down_proj"),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
quant,
|
||||||
|
)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -649,6 +705,7 @@ impl TpQwen3_5DecoderLayer {
|
|||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
comm: Arc<Comm>,
|
comm: Arc<Comm>,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let layer_type = cfg
|
let layer_type = cfg
|
||||||
.layer_types
|
.layer_types
|
||||||
@@ -663,6 +720,7 @@ impl TpQwen3_5DecoderLayer {
|
|||||||
rank,
|
rank,
|
||||||
world_size,
|
world_size,
|
||||||
comm.clone(),
|
comm.clone(),
|
||||||
|
quant,
|
||||||
)?),
|
)?),
|
||||||
"linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load(
|
"linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load(
|
||||||
cfg,
|
cfg,
|
||||||
@@ -671,10 +729,11 @@ impl TpQwen3_5DecoderLayer {
|
|||||||
rank,
|
rank,
|
||||||
world_size,
|
world_size,
|
||||||
comm.clone(),
|
comm.clone(),
|
||||||
|
quant,
|
||||||
)?),
|
)?),
|
||||||
other => bail!("unknown layer_type '{other}' for layer {layer_idx}"),
|
other => bail!("unknown layer_type '{other}' for layer {layer_idx}"),
|
||||||
};
|
};
|
||||||
let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm)?;
|
let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm, quant)?;
|
||||||
let input_layernorm =
|
let input_layernorm =
|
||||||
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
let post_attention_layernorm = Qwen3_5RmsNorm::load(
|
let post_attention_layernorm = Qwen3_5RmsNorm::load(
|
||||||
@@ -691,6 +750,7 @@ impl TpQwen3_5DecoderLayer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn load(
|
pub fn load(
|
||||||
cfg: &TextConfig,
|
cfg: &TextConfig,
|
||||||
rotary: Arc<RotaryEmbedding>,
|
rotary: Arc<RotaryEmbedding>,
|
||||||
@@ -699,6 +759,7 @@ impl TpQwen3_5DecoderLayer {
|
|||||||
mmap: &MmapedSafetensors,
|
mmap: &MmapedSafetensors,
|
||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let layer_type = cfg
|
let layer_type = cfg
|
||||||
.layer_types
|
.layer_types
|
||||||
@@ -712,6 +773,7 @@ impl TpQwen3_5DecoderLayer {
|
|||||||
&vb.pp("self_attn"),
|
&vb.pp("self_attn"),
|
||||||
rank,
|
rank,
|
||||||
world_size,
|
world_size,
|
||||||
|
quant,
|
||||||
)?),
|
)?),
|
||||||
"linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load(
|
"linear_attention" => TpAttentionKind::Linear(TpQwen3_5GatedDeltaNet::load(
|
||||||
cfg,
|
cfg,
|
||||||
@@ -719,10 +781,11 @@ impl TpQwen3_5DecoderLayer {
|
|||||||
mmap,
|
mmap,
|
||||||
rank,
|
rank,
|
||||||
world_size,
|
world_size,
|
||||||
|
quant,
|
||||||
)?),
|
)?),
|
||||||
other => bail!("unknown layer_type '{other}' for layer {layer_idx}"),
|
other => bail!("unknown layer_type '{other}' for layer {layer_idx}"),
|
||||||
};
|
};
|
||||||
let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size)?;
|
let mlp = TpQwen3_5MLP::load(cfg, &vb.pp("mlp"), rank, world_size, quant)?;
|
||||||
let input_layernorm =
|
let input_layernorm =
|
||||||
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
let post_attention_layernorm = Qwen3_5RmsNorm::load(
|
let post_attention_layernorm = Qwen3_5RmsNorm::load(
|
||||||
@@ -775,6 +838,7 @@ pub struct TpQwen3_5Model {
|
|||||||
|
|
||||||
impl TpQwen3_5Model {
|
impl TpQwen3_5Model {
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn load(
|
pub fn load(
|
||||||
cfg: &TextConfig,
|
cfg: &TextConfig,
|
||||||
vb: &ShardedVarBuilder,
|
vb: &ShardedVarBuilder,
|
||||||
@@ -782,6 +846,7 @@ impl TpQwen3_5Model {
|
|||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
comm: Arc<Comm>,
|
comm: Arc<Comm>,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let dtype = vb.dtype();
|
let dtype = vb.dtype();
|
||||||
let device = vb.device().clone();
|
let device = vb.device().clone();
|
||||||
@@ -817,6 +882,7 @@ impl TpQwen3_5Model {
|
|||||||
rank,
|
rank,
|
||||||
world_size,
|
world_size,
|
||||||
comm.clone(),
|
comm.clone(),
|
||||||
|
quant,
|
||||||
)
|
)
|
||||||
.with_context(|| {
|
.with_context(|| {
|
||||||
let (free_mb, total_mb) = cuda_mem_mb(&device);
|
let (free_mb, total_mb) = cuda_mem_mb(&device);
|
||||||
@@ -844,6 +910,7 @@ impl TpQwen3_5Model {
|
|||||||
mmap: &MmapedSafetensors,
|
mmap: &MmapedSafetensors,
|
||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let dtype = vb.dtype();
|
let dtype = vb.dtype();
|
||||||
let device = vb.device().clone();
|
let device = vb.device().clone();
|
||||||
@@ -877,6 +944,7 @@ impl TpQwen3_5Model {
|
|||||||
mmap,
|
mmap,
|
||||||
rank,
|
rank,
|
||||||
world_size,
|
world_size,
|
||||||
|
quant,
|
||||||
)?);
|
)?);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -931,6 +999,7 @@ pub struct TpQwen3_5ForCausalLM {
|
|||||||
|
|
||||||
impl TpQwen3_5ForCausalLM {
|
impl TpQwen3_5ForCausalLM {
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn load(
|
pub fn load(
|
||||||
config: Config,
|
config: Config,
|
||||||
vb: &ShardedVarBuilder,
|
vb: &ShardedVarBuilder,
|
||||||
@@ -938,9 +1007,10 @@ impl TpQwen3_5ForCausalLM {
|
|||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
comm: Arc<Comm>,
|
comm: Arc<Comm>,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let cfg = &config.text_config;
|
let cfg = &config.text_config;
|
||||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm)?;
|
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?;
|
||||||
let lm_head = build_lm_head(cfg, vb, &base)?;
|
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||||
Ok(Self { base, lm_head })
|
Ok(Self { base, lm_head })
|
||||||
}
|
}
|
||||||
@@ -952,9 +1022,10 @@ impl TpQwen3_5ForCausalLM {
|
|||||||
mmap: &MmapedSafetensors,
|
mmap: &MmapedSafetensors,
|
||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let cfg = &config.text_config;
|
let cfg = &config.text_config;
|
||||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size)?;
|
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?;
|
||||||
let lm_head = build_lm_head(cfg, vb, &base)?;
|
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||||
Ok(Self { base, lm_head })
|
Ok(Self { base, lm_head })
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -160,7 +160,8 @@ impl WorkerState {
|
|||||||
model_id,
|
model_id,
|
||||||
config_json,
|
config_json,
|
||||||
safetensors_paths,
|
safetensors_paths,
|
||||||
} => self.handle_load_dense_shard(model_id, config_json, safetensors_paths),
|
quant,
|
||||||
|
} => self.handle_load_dense_shard(model_id, config_json, safetensors_paths, quant),
|
||||||
WorkerRequest::GenerateStep {
|
WorkerRequest::GenerateStep {
|
||||||
model_id,
|
model_id,
|
||||||
tokens,
|
tokens,
|
||||||
@@ -178,6 +179,7 @@ impl WorkerState {
|
|||||||
model_id: String,
|
model_id: String,
|
||||||
config_json: String,
|
config_json: String,
|
||||||
safetensors_paths: Vec<String>,
|
safetensors_paths: Vec<String>,
|
||||||
|
quant: Option<String>,
|
||||||
) -> WorkerResponse {
|
) -> WorkerResponse {
|
||||||
use crate::harness::arch::qwen3_5 as qwen3_5_arch;
|
use crate::harness::arch::qwen3_5 as qwen3_5_arch;
|
||||||
use candle_core::{DType, Device};
|
use candle_core::{DType, Device};
|
||||||
@@ -185,6 +187,16 @@ impl WorkerState {
|
|||||||
use candle_transformers::models::qwen3 as qwen3_dense;
|
use candle_transformers::models::qwen3 as qwen3_dense;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
let quant_dtype = match parse_quant_string(quant.as_deref()) {
|
||||||
|
Ok(q) => q,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "bad_request".into(),
|
||||||
|
message: format!("parse quant: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if self.models.contains_key(&model_id) {
|
if self.models.contains_key(&model_id) {
|
||||||
return WorkerResponse::Error {
|
return WorkerResponse::Error {
|
||||||
kind: "already_loaded".into(),
|
kind: "already_loaded".into(),
|
||||||
@@ -290,6 +302,7 @@ impl WorkerState {
|
|||||||
self.config.rank,
|
self.config.rank,
|
||||||
self.config.world_size,
|
self.config.world_size,
|
||||||
comm,
|
comm,
|
||||||
|
quant_dtype,
|
||||||
) {
|
) {
|
||||||
Ok(m) => WorkerModel::Qwen3_5(m),
|
Ok(m) => WorkerModel::Qwen3_5(m),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -326,6 +339,7 @@ impl WorkerState {
|
|||||||
_model_id: String,
|
_model_id: String,
|
||||||
_config_json: String,
|
_config_json: String,
|
||||||
_safetensors_paths: Vec<String>,
|
_safetensors_paths: Vec<String>,
|
||||||
|
_quant: Option<String>,
|
||||||
) -> WorkerResponse {
|
) -> WorkerResponse {
|
||||||
WorkerResponse::Error {
|
WorkerResponse::Error {
|
||||||
kind: "cuda_feature_not_enabled".into(),
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
@@ -444,3 +458,45 @@ impl WorkerState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parse a `ModelSpec.quant` string into a `GgmlDType`.
///
/// Matching is case-insensitive and ignores underscores, so `Q4_K`,
/// `q4_k` and `q4k` are all equivalent. Surrounding whitespace is
/// trimmed before matching. `None`, `Some("")` and whitespace-only
/// strings all map to `Ok(None)`, i.e. "no quantization".
///
/// Supported: `q4_0`, `q4_1`, `q5_0`, `q5_1`, `q8_0`, `q8_1`,
/// `q2k`/`q2_k`, `q3k`/`q3_k`, `q4k`/`q4_k`, `q5k`/`q5_k`,
/// `q6k`/`q6_k`, `q8k`/`q8_k`, `f16`, `bf16`, `f32`. `q4_k_m` and
/// `q5_k_m` are additionally accepted as aliases for `q4k` and `q5k`.
///
/// # Errors
///
/// Returns an error when the input matches none of the supported
/// formats. The message quotes the input as the caller wrote it (not
/// the normalised form) so it can be found in the config verbatim.
#[cfg(feature = "cuda")]
pub(crate) fn parse_quant_string(
    s: Option<&str>,
) -> anyhow::Result<Option<candle_core::quantized::GgmlDType>> {
    use candle_core::quantized::GgmlDType;

    // Absent, empty, or blank all mean "leave weights in their loaded dtype".
    let Some(raw) = s.map(str::trim).filter(|t| !t.is_empty()) else {
        return Ok(None);
    };

    // Fold case and drop underscores so `Q4_K`, `q4_k` and `q4k` share one key.
    let normalised = raw.to_ascii_lowercase().replace('_', "");
    let dtype = match normalised.as_str() {
        "q40" => GgmlDType::Q4_0,
        "q41" => GgmlDType::Q4_1,
        "q50" => GgmlDType::Q5_0,
        "q51" => GgmlDType::Q5_1,
        "q80" => GgmlDType::Q8_0,
        "q81" => GgmlDType::Q8_1,
        "q2k" => GgmlDType::Q2K,
        "q3k" => GgmlDType::Q3K,
        // The `_m` spellings are treated as plain aliases of the base K-quant.
        "q4k" | "q4km" => GgmlDType::Q4K,
        "q5k" | "q5km" => GgmlDType::Q5K,
        "q6k" => GgmlDType::Q6K,
        "q8k" => GgmlDType::Q8K,
        "f16" => GgmlDType::F16,
        "bf16" => GgmlDType::BF16,
        "f32" => GgmlDType::F32,
        // Report the caller's spelling, not the normalised key.
        _ => anyhow::bail!(
            "unknown quant '{raw}' (expected one of: q4_0, q4_1, q5_0, q5_1, q8_0, \
             q8_1, q2k, q3k, q4k, q5k, q6k, q8k, f16, bf16, f32)"
        ),
    };
    Ok(Some(dtype))
}
|
||||||
|
|||||||
Reference in New Issue
Block a user