feat(stage-8e-2): plumb quant config from ModelSpec to TP load path
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 36s
CI / Clippy (push) Successful in 2m7s
CI / Test (push) Successful in 4m21s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m47s
build-prerelease / Build neuron-ampere (push) Successful in 5m17s
build-prerelease / Build neuron-ada (push) Successful in 5m14s
build-prerelease / Build cortex binary (push) Successful in 18m31s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m57s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m7s

- LoadDenseShard RPC gains an optional `quant` string field.
- WorkerPool::load_dense_shard takes a `quant: Option<String>`,
  passes it via the RPC to workers and via parse_quant_string to
  the leader's local load.
- The Qwen3-Next TP load chain (ForCausalLM → Model → DecoderLayer
  → Attention / GatedDeltaNet / MLP) takes `quant: Option<GgmlDType>`
  end-to-end, calling Column/RowParallelLinear::load_with_quant.
- The fused in_proj_qkv inside TpQwen3_5GatedDeltaNet is now a
  MaybeQuantLinear so it also picks up quantization.
- parse_quant_string accepts q4_0/q4_1/q5_0/q5_1/q8_0/q8_1, q2k..q8k
  (with or without underscore), and f16/bf16/f32. Empty / None means
  no quantization.

Callers from candle.rs forward spec.quant through pool.load_dense_shard.
This means a `quant = "q5k"` in models.toml now flows end-to-end to a
QTensor-backed QMatMul for every per-rank linear in the Qwen3-Next
TP path. Leaves lm_head and the small replicated bias/log tensors in
their loaded dtype (Stage 8e-3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 18:03:36 +03:00
parent bef159b21c
commit 4aa71902d0
5 changed files with 175 additions and 28 deletions

View File

@@ -474,6 +474,7 @@ impl WorkerPool {
/// `init_nccl` must have completed first. Bails if the leader's
/// NCCL comm isn't set up yet.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub async fn load_dense_shard(
&mut self,
model_id: &str,
@@ -481,6 +482,7 @@ impl WorkerPool {
safetensors_paths: &[std::path::PathBuf],
leader_device: &candle_core::Device,
dtype: candle_core::DType,
quant: Option<String>,
) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
use candle_nn::var_builder::ShardedSafeTensors;
use std::sync::Arc;
@@ -510,6 +512,7 @@ impl WorkerPool {
model_id: model_id.to_string(),
config_json: config_json.to_string(),
safetensors_paths: safetensors_str.clone(),
quant: quant.clone(),
})
.await?;
}
@@ -531,6 +534,7 @@ impl WorkerPool {
let comm_for_leader = leader_comm;
let model_id_for_log = model_id.to_string();
let config_json_for_leader = config_json.to_string();
let quant_for_leader = quant.clone();
let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
// SAFETY: same invariant as the single-GPU dense path —
@@ -558,8 +562,16 @@ impl WorkerPool {
let cfg: super::tp::tp_qwen3_5::Config =
serde_json::from_str(&config_json_for_leader)
.context("parse Qwen3-Next Config JSON for leader load")?;
let quant_dtype =
super::tp::worker::parse_quant_string(quant_for_leader.as_deref())?;
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
cfg, &vb, &mmap, 0, world_size, comm,
cfg,
&vb,
&mmap,
0,
world_size,
comm,
quant_dtype,
)?)
}
other => anyhow::bail!(