feat(stage-8e-1): MaybeQuantLinear primitive + parallel-linear quant variants
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 37s
build-prerelease / Build cortex binary (push) Successful in 4m36s
build-prerelease / Build neuron-blackwell (push) Successful in 3m31s
build-prerelease / Package cortex RPM (push) Successful in 1m27s
CI / Format (push) Waiting to run
CI / Clippy (push) Waiting to run
CI / Test (push) Waiting to run
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled

Introduces MaybeQuantLinear, which wraps either a plain candle Linear
or a candle QMatMul backed by a freshly-quantized QTensor. Forward
dispatches identically through the Module trait so downstream code
doesn't care which arm is active.

ColumnParallelLinear and RowParallelLinear gain `load_with_quant`
methods. The existing `load` methods stay as backward-compatible
no-quantization wrappers — no churn at the 27 existing call sites.

This is the foundation for in-situ quantization at load time. Wiring
the user-facing quant config and switching call sites to
load_with_quant follow in stages 8e-2 / 8e-3.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 17:55:26 +03:00
parent 8d7b099b36
commit bef159b21c

View File

@@ -24,13 +24,62 @@
//! sum carries it exactly once. //! sum carries it exactly once.
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Module, Tensor}; use candle_core::{Module, Tensor};
use candle_nn::Linear; use candle_nn::Linear;
use candle_nn::var_builder::{Shard, ShardedVarBuilder}; use candle_nn::var_builder::{Shard, ShardedVarBuilder};
use std::sync::Arc;
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
use super::all_reduce::AllReduce; use super::all_reduce::AllReduce;
/// Linear primitive that holds either a plain `Linear` (bf16/f16/f32)
/// or a quantized `QMatMul` (Q4K/Q5K/Q6K/Q8_0/etc.).
///
/// Constructed via [`MaybeQuantLinear::from_weight`]: pass `None` to
/// wrap the weight as loaded in a `Linear` (no quantization), or
/// `Some(dtype)` to quantize the weight to that GGML dtype at load time.
///
/// Both arms are driven through `Module::forward`, so callers issue the
/// same call whichever variant is held.
///
/// NOTE(review): candle's quantized matmul is understood to operate on
/// f32 activations; whether non-f32 inputs are accepted, and which dtype
/// comes back, is decided by the `Module` impl for this type — confirm
/// there before relying on a particular output dtype.
pub enum MaybeQuantLinear {
/// Unquantized weight wrapped in a `candle_nn::Linear`.
Plain(Linear),
/// Weight quantized with `QTensor::quantize` and wrapped in a `QMatMul`.
Quant(QMatMul),
}
impl MaybeQuantLinear {
    /// Wrap an already-loaded weight tensor in a linear layer.
    ///
    /// * `weight` — the weight tensor to wrap.
    /// * `quant` — `None` keeps the weight as-is inside a `Linear`
    ///   (constructed with `None` as its second argument); `Some(dtype)`
    ///   quantizes the weight to `dtype` and stores it as a `QMatMul`.
    ///
    /// # Errors
    ///
    /// Returns an error, annotated with the target dtype and the weight
    /// shape, when `QTensor::quantize` fails, or when `QMatMul::from_arc`
    /// rejects the quantized tensor.
    pub fn from_weight(weight: Tensor, quant: Option<GgmlDType>) -> Result<Self> {
        // Guard: without a target dtype there is nothing to quantize.
        let dtype = match quant {
            None => return Ok(Self::Plain(Linear::new(weight, None))),
            Some(requested) => requested,
        };

        let quantized = QTensor::quantize(&weight, dtype).with_context(|| {
            format!(
                "QTensor::quantize to {dtype:?} for shape {:?}",
                weight.shape()
            )
        })?;

        QMatMul::from_arc(Arc::new(quantized))
            .map(Self::Quant)
            .context("QMatMul::from_arc on freshly quantized weight")
    }
}
impl Module for MaybeQuantLinear {
    /// Apply the layer to `x`.
    ///
    /// * `Plain` — delegates directly to `Linear::forward`.
    /// * `Quant` — f32 inputs are handed to `QMatMul::forward` unchanged.
    ///   Any other input dtype is cast to f32 for the quantized matmul and
    ///   the result is cast back, so the output dtype matches the
    ///   caller's input dtype on both arms.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying `forward` call or from
    /// the dtype conversions.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        match self {
            Self::Plain(l) => l.forward(x),
            Self::Quant(qm) => {
                let in_dtype = x.dtype();
                if in_dtype == candle_core::DType::F32 {
                    // Already in the dtype the quantized matmul consumes.
                    return qm.forward(x);
                }
                // The quantized matmul path consumes f32 activations
                // (NOTE(review): per candle's QMatMul implementation —
                // confirm for the pinned candle version). Without this
                // cast a bf16/f16 activation would be rejected instead of
                // falling back to f32 as the type-level docs describe.
                qm.forward(&x.to_dtype(candle_core::DType::F32)?)?
                    .to_dtype(in_dtype)
            }
        }
    }
}
/// Helper to build a [`Shard`] hint for a given dimension. /// Helper to build a [`Shard`] hint for a given dimension.
pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard { pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
Shard { Shard {
@@ -40,24 +89,38 @@ pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
} }
} }
/// Output-dim sharded linear (column-parallel). Holds a standard /// Output-dim sharded linear (column-parallel). Holds a
/// `candle_nn::Linear` whose `weight` is the rank's slice of the full /// [`MaybeQuantLinear`] whose underlying weight is this rank's slice
/// `[out_features, in_features]` tensor along dim 0. /// of the full `[out_features, in_features]` tensor along dim 0.
pub struct ColumnParallelLinear { pub struct ColumnParallelLinear {
inner: Linear, inner: MaybeQuantLinear,
} }
impl ColumnParallelLinear { impl ColumnParallelLinear {
/// Load this rank's column-parallel slice from a /// Load this rank's column-parallel slice from a
/// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed /// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
/// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`). /// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
///
/// Backward-compatible variant — no in-situ quantization. For
/// quantized loads, use [`Self::load_with_quant`].
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> { pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
Self::load_with_quant(vb, rank, world_size, None)
}
/// Like [`Self::load`] but quantizes the per-rank weight in-situ
/// when `quant` is `Some(dtype)`. Saves ~3-5x vs bf16/f16.
pub fn load_with_quant(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> {
let weight = vb let weight = vb
.get_with_hints((), "weight", shard(0, rank, world_size)) .get_with_hints((), "weight", shard(0, rank, world_size))
.with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?; .with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
Ok(Self { let inner = MaybeQuantLinear::from_weight(weight, quant)
inner: Linear::new(weight, None), .with_context(|| format!("wrap column-parallel '{}'", vb.prefix()))?;
}) Ok(Self { inner })
} }
} }
@@ -69,10 +132,10 @@ impl Module for ColumnParallelLinear {
/// Input-dim sharded linear (row-parallel). /// Input-dim sharded linear (row-parallel).
/// ///
/// Holds a sharded `Linear` plus an `AllReduce` op the forward chains /// Holds a sharded [`MaybeQuantLinear`] plus an `AllReduce` op the
/// after the local matmul to recover the full activation. /// forward chains after the local matmul to recover the full activation.
pub struct RowParallelLinear { pub struct RowParallelLinear {
inner: Linear, inner: MaybeQuantLinear,
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
all_reduce: AllReduce, all_reduce: AllReduce,
/// Whether the AllReduce should run. Column-parallel ↔ row-parallel /// Whether the AllReduce should run. Column-parallel ↔ row-parallel
@@ -89,18 +152,35 @@ impl RowParallelLinear {
/// `AllReduce` runs against. On CPU builds the parameter is /// `AllReduce` runs against. On CPU builds the parameter is
/// elided — forward returns the partial sum, which is the *wrong* /// elided — forward returns the partial sum, which is the *wrong*
/// answer for inference but lets us compile-check the model. /// answer for inference but lets us compile-check the model.
///
/// Backward-compatible variant — no in-situ quantization. For
/// quantized loads, use [`Self::load_with_quant`].
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
pub fn load( pub fn load(
vb: &ShardedVarBuilder, vb: &ShardedVarBuilder,
rank: u32, rank: u32,
world_size: u32, world_size: u32,
comm: std::sync::Arc<cudarc::nccl::Comm>, comm: std::sync::Arc<cudarc::nccl::Comm>,
) -> Result<Self> {
Self::load_with_quant(vb, rank, world_size, comm, None)
}
/// Like [`Self::load`] but quantizes the per-rank weight in-situ.
#[cfg(feature = "cuda")]
pub fn load_with_quant(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: std::sync::Arc<cudarc::nccl::Comm>,
quant: Option<GgmlDType>,
) -> Result<Self> { ) -> Result<Self> {
let weight = vb let weight = vb
.get_with_hints((), "weight", shard(1, rank, world_size)) .get_with_hints((), "weight", shard(1, rank, world_size))
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?; .with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
let inner = MaybeQuantLinear::from_weight(weight, quant)
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
Ok(Self { Ok(Self {
inner: Linear::new(weight, None), inner,
all_reduce: AllReduce::new(comm), all_reduce: AllReduce::new(comm),
needs_reduce: world_size > 1, needs_reduce: world_size > 1,
}) })
@@ -108,11 +188,23 @@ impl RowParallelLinear {
#[cfg(not(feature = "cuda"))] #[cfg(not(feature = "cuda"))]
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> { pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
Self::load_with_quant(vb, rank, world_size, None)
}
#[cfg(not(feature = "cuda"))]
pub fn load_with_quant(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> {
let weight = vb let weight = vb
.get_with_hints((), "weight", shard(1, rank, world_size)) .get_with_hints((), "weight", shard(1, rank, world_size))
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?; .with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
let inner = MaybeQuantLinear::from_weight(weight, quant)
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
Ok(Self { Ok(Self {
inner: Linear::new(weight, None), inner,
needs_reduce: world_size > 1, needs_reduce: world_size > 1,
}) })
} }