feat(stage-8e-2): plumb quant config from ModelSpec to TP load path
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 36s
CI / Clippy (push) Successful in 2m7s
CI / Test (push) Successful in 4m21s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m47s
build-prerelease / Build neuron-ampere (push) Successful in 5m17s
build-prerelease / Build neuron-ada (push) Successful in 5m14s
build-prerelease / Build cortex binary (push) Successful in 18m31s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m57s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m7s
- LoadDenseShard RPC gains an optional `quant` string field.
- WorkerPool::load_dense_shard takes a `quant: Option<String>`, passes it via the RPC to workers and via parse_quant_string to the leader's local load.
- The Qwen3-Next TP load chain (ForCausalLM → Model → DecoderLayer → Attention / GatedDeltaNet / MLP) takes `quant: Option<GgmlDType>` end-to-end, calling Column/RowParallelLinear::load_with_quant.
- The fused in_proj_qkv inside TpQwen3_5GatedDeltaNet is now a MaybeQuantLinear so it also picks up quantization.
- parse_quant_string accepts q4_0/q4_1/q5_0/q5_1/q8_0/q8_1, q2k..q8k (with or without underscore), and f16/bf16/f32. Empty / None means no quantization.

Callers from candle.rs forward spec.quant through pool.load_dense_shard. This means a `quant = "q5k"` in models.toml now flows end-to-end to a QTensor-backed QMatMul for every per-rank linear in the Qwen3-Next TP path.

Leaves lm_head and the small replicated bias/log tensors in their loaded dtype (Stage 8e-3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
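
The accepted spellings listed for parse_quant_string can be sketched roughly as below. This is a minimal illustration, not the code from this commit: `QuantDType` is a hypothetical stand-in for candle's `GgmlDType`, the function name `parse_quant_sketch` and the `String` error type are assumptions, and the k-quant list assumes "q2k..q8k" means q2k, q3k, q4k, q5k, q6k and q8k.

```rust
/// Hypothetical stand-in for candle's `GgmlDType` (assumption: the real
/// function returns the candle type, not this local enum).
#[allow(non_camel_case_types, dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantDType {
    Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1,
    Q2K, Q3K, Q4K, Q5K, Q6K, Q8K,
    F16, BF16, F32,
}

/// Map an optional config string to an optional dtype.
/// `None` and the empty string both mean "no quantization".
#[allow(dead_code)]
pub fn parse_quant_sketch(s: Option<&str>) -> Result<Option<QuantDType>, String> {
    let raw = match s {
        None | Some("") => return Ok(None),
        Some(r) => r,
    };
    let dtype = match raw {
        "q4_0" => QuantDType::Q4_0,
        "q4_1" => QuantDType::Q4_1,
        "q5_0" => QuantDType::Q5_0,
        "q5_1" => QuantDType::Q5_1,
        "q8_0" => QuantDType::Q8_0,
        "q8_1" => QuantDType::Q8_1,
        // k-quants are accepted with or without the underscore.
        "q2k" | "q2_k" => QuantDType::Q2K,
        "q3k" | "q3_k" => QuantDType::Q3K,
        "q4k" | "q4_k" => QuantDType::Q4K,
        "q5k" | "q5_k" => QuantDType::Q5K,
        "q6k" | "q6_k" => QuantDType::Q6K,
        "q8k" | "q8_k" => QuantDType::Q8K,
        "f16" => QuantDType::F16,
        "bf16" => QuantDType::BF16,
        "f32" => QuantDType::F32,
        other => return Err(format!("unknown quant string: {other:?}")),
    };
    Ok(Some(dtype))
}
```

Under these assumptions, `parse_quant_sketch(Some("q5k"))` and `parse_quant_sketch(Some("q5_k"))` both yield the same variant, while an unrecognized string is an error rather than a silent fallback to the unquantized path.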
@@ -474,6 +474,7 @@ impl WorkerPool {
     /// `init_nccl` must have completed first. Bails if the leader's
     /// NCCL comm isn't set up yet.
     #[cfg(feature = "cuda")]
+    #[allow(clippy::too_many_arguments)]
     pub async fn load_dense_shard(
         &mut self,
         model_id: &str,
@@ -481,6 +482,7 @@ impl WorkerPool {
         safetensors_paths: &[std::path::PathBuf],
         leader_device: &candle_core::Device,
         dtype: candle_core::DType,
+        quant: Option<String>,
     ) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
         use candle_nn::var_builder::ShardedSafeTensors;
         use std::sync::Arc;
@@ -510,6 +512,7 @@ impl WorkerPool {
                     model_id: model_id.to_string(),
                     config_json: config_json.to_string(),
                     safetensors_paths: safetensors_str.clone(),
+                    quant: quant.clone(),
                 })
                 .await?;
         }
@@ -531,6 +534,7 @@ impl WorkerPool {
         let comm_for_leader = leader_comm;
         let model_id_for_log = model_id.to_string();
         let config_json_for_leader = config_json.to_string();
+        let quant_for_leader = quant.clone();
 
         let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
             // SAFETY: same invariant as the single-GPU dense path —
@@ -558,8 +562,16 @@ impl WorkerPool {
                     let cfg: super::tp::tp_qwen3_5::Config =
                         serde_json::from_str(&config_json_for_leader)
                             .context("parse Qwen3-Next Config JSON for leader load")?;
+                    let quant_dtype =
+                        super::tp::worker::parse_quant_string(quant_for_leader.as_deref())?;
                     TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
-                        cfg, &vb, &mmap, 0, world_size, comm,
+                        cfg,
+                        &vb,
+                        &mmap,
+                        0,
+                        world_size,
+                        comm,
+                        quant_dtype,
                     )?)
                 }
                 other => anyhow::bail!(
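
The hunks above follow one pattern: the optional `quant` string is cloned once per worker request and once more for the leader's blocking load, where it is borrowed as `Option<&str>` via `as_deref`. A minimal standalone sketch of that pattern follows. It makes several assumptions: `LoadShardRequest` and `fan_out` are hypothetical names, the request is trimmed to two fields, and `std::thread::spawn` stands in for `tokio::task::spawn_blocking` so the example needs no external crates.

```rust
use std::thread;

/// Hypothetical, trimmed-down request; the real RPC message carries more
/// fields (config JSON, safetensors paths, and so on).
#[derive(Debug, Clone, PartialEq)]
pub struct LoadShardRequest {
    pub model_id: String,
    pub quant: Option<String>,
}

/// Build one request per worker and hand a separate clone of `quant`
/// to the leader's own load, which runs on another thread.
#[allow(dead_code)]
pub fn fan_out(
    model_id: &str,
    quant: Option<String>,
    workers: usize,
) -> (Vec<LoadShardRequest>, Option<String>) {
    // Each worker request owns its own copy of the optional string.
    let requests: Vec<LoadShardRequest> = (0..workers)
        .map(|_| LoadShardRequest {
            model_id: model_id.to_string(),
            quant: quant.clone(),
        })
        .collect();

    // The leader closure must own its data, so clone before moving.
    let quant_for_leader = quant.clone();
    let leader_seen = thread::spawn(move || {
        // Borrow as Option<&str>, the shape a string parser would take.
        let borrowed: Option<&str> = quant_for_leader.as_deref();
        borrowed.map(str::to_owned)
    })
    .join()
    .expect("leader thread panicked");

    (requests, leader_seen)
}
```

Cloning an `Option<String>` per rank is cheap next to loading a shard, and it keeps the worker requests and the leader closure free of shared borrows.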