feat(stage-8e-1): MaybeQuantLinear primitive + parallel-linear quant variants
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 37s
build-prerelease / Build cortex binary (push) Successful in 4m36s
build-prerelease / Build neuron-blackwell (push) Successful in 3m31s
build-prerelease / Package cortex RPM (push) Successful in 1m27s
CI / Format (push) Waiting to run
CI / Clippy (push) Waiting to run
CI / Test (push) Waiting to run
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 37s
build-prerelease / Build cortex binary (push) Successful in 4m36s
build-prerelease / Build neuron-blackwell (push) Successful in 3m31s
build-prerelease / Package cortex RPM (push) Successful in 1m27s
CI / Format (push) Waiting to run
CI / Clippy (push) Waiting to run
CI / Test (push) Waiting to run
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
Introduces MaybeQuantLinear, which wraps either a plain candle Linear or a candle QMatMul backed by a freshly-quantized QTensor. Forward dispatches identically through the Module trait so downstream code doesn't care which arm is active. ColumnParallelLinear and RowParallelLinear gain `load_with_quant` methods. The existing `load` methods stay as backward-compatible no-quantization wrappers — no churn at the 27 existing call sites. This is the foundation for in-situ quantization at load time. Wiring the user-facing quant config and switching call sites to load_with_quant follow in stages 8e-2 / 8e-3. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -24,13 +24,62 @@
|
|||||||
//! sum carries it exactly once.
|
//! sum carries it exactly once.
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
|
||||||
use candle_core::{Module, Tensor};
|
use candle_core::{Module, Tensor};
|
||||||
use candle_nn::Linear;
|
use candle_nn::Linear;
|
||||||
use candle_nn::var_builder::{Shard, ShardedVarBuilder};
|
use candle_nn::var_builder::{Shard, ShardedVarBuilder};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
use super::all_reduce::AllReduce;
|
use super::all_reduce::AllReduce;
|
||||||
|
|
||||||
|
/// Linear primitive that holds either a plain `Linear` (bf16/f16/f32)
|
||||||
|
/// or a quantized `QMatMul` (Q4K/Q5K/Q6K/Q8_0/etc.).
|
||||||
|
///
|
||||||
|
/// Constructed via [`MaybeQuantLinear::from_weight`] — pass `None` to
|
||||||
|
/// keep the weight in its loaded dtype (no quantization), or
|
||||||
|
/// `Some(dtype)` to quantize at load time.
|
||||||
|
///
|
||||||
|
/// On the forward path the two arms dispatch identically: `Module::forward`
|
||||||
|
/// returns an output in the caller's input dtype (f32 fallback for the
|
||||||
|
/// quantized matmul). Subsequent ops don't need to know whether the
|
||||||
|
/// layer was quantized.
|
||||||
|
pub enum MaybeQuantLinear {
|
||||||
|
Plain(Linear),
|
||||||
|
Quant(QMatMul),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MaybeQuantLinear {
|
||||||
|
/// Build a linear from a loaded weight tensor. If `quant` is set,
|
||||||
|
/// the weight is quantized in-situ and stored as a `QMatMul`;
|
||||||
|
/// otherwise it's wrapped in a plain `Linear`.
|
||||||
|
pub fn from_weight(weight: Tensor, quant: Option<GgmlDType>) -> Result<Self> {
|
||||||
|
match quant {
|
||||||
|
Some(dtype) => {
|
||||||
|
let qt = QTensor::quantize(&weight, dtype).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"QTensor::quantize to {dtype:?} for shape {:?}",
|
||||||
|
weight.shape()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let qmm = QMatMul::from_arc(Arc::new(qt))
|
||||||
|
.context("QMatMul::from_arc on freshly quantized weight")?;
|
||||||
|
Ok(Self::Quant(qmm))
|
||||||
|
}
|
||||||
|
None => Ok(Self::Plain(Linear::new(weight, None))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MaybeQuantLinear {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
Self::Plain(l) => l.forward(x),
|
||||||
|
Self::Quant(qm) => qm.forward(x),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Helper to build a [`Shard`] hint for a given dimension.
|
/// Helper to build a [`Shard`] hint for a given dimension.
|
||||||
pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
|
pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
|
||||||
Shard {
|
Shard {
|
||||||
@@ -40,24 +89,38 @@ pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Output-dim sharded linear (column-parallel). Holds a standard
|
/// Output-dim sharded linear (column-parallel). Holds a
|
||||||
/// `candle_nn::Linear` whose `weight` is the rank's slice of the full
|
/// [`MaybeQuantLinear`] whose underlying weight is this rank's slice
|
||||||
/// `[out_features, in_features]` tensor along dim 0.
|
/// of the full `[out_features, in_features]` tensor along dim 0.
|
||||||
pub struct ColumnParallelLinear {
|
pub struct ColumnParallelLinear {
|
||||||
inner: Linear,
|
inner: MaybeQuantLinear,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ColumnParallelLinear {
|
impl ColumnParallelLinear {
|
||||||
/// Load this rank's column-parallel slice from a
|
/// Load this rank's column-parallel slice from a
|
||||||
/// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
|
/// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
|
||||||
/// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
|
/// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
|
||||||
|
///
|
||||||
|
/// Backward-compatible variant — no in-situ quantization. For
|
||||||
|
/// quantized loads, use [`Self::load_with_quant`].
|
||||||
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
Self::load_with_quant(vb, rank, world_size, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Like [`Self::load`] but quantizes the per-rank weight in-situ
|
||||||
|
/// when `quant` is `Some(dtype)`. Saves ~3-5x vs bf16/f16.
|
||||||
|
pub fn load_with_quant(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
|
) -> Result<Self> {
|
||||||
let weight = vb
|
let weight = vb
|
||||||
.get_with_hints((), "weight", shard(0, rank, world_size))
|
.get_with_hints((), "weight", shard(0, rank, world_size))
|
||||||
.with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
|
.with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
|
||||||
Ok(Self {
|
let inner = MaybeQuantLinear::from_weight(weight, quant)
|
||||||
inner: Linear::new(weight, None),
|
.with_context(|| format!("wrap column-parallel '{}'", vb.prefix()))?;
|
||||||
})
|
Ok(Self { inner })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,10 +132,10 @@ impl Module for ColumnParallelLinear {
|
|||||||
|
|
||||||
/// Input-dim sharded linear (row-parallel).
|
/// Input-dim sharded linear (row-parallel).
|
||||||
///
|
///
|
||||||
/// Holds a sharded `Linear` plus an `AllReduce` op the forward chains
|
/// Holds a sharded [`MaybeQuantLinear`] plus an `AllReduce` op the
|
||||||
/// after the local matmul to recover the full activation.
|
/// forward chains after the local matmul to recover the full activation.
|
||||||
pub struct RowParallelLinear {
|
pub struct RowParallelLinear {
|
||||||
inner: Linear,
|
inner: MaybeQuantLinear,
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
all_reduce: AllReduce,
|
all_reduce: AllReduce,
|
||||||
/// Whether the AllReduce should run. Column-parallel ↔ row-parallel
|
/// Whether the AllReduce should run. Column-parallel ↔ row-parallel
|
||||||
@@ -89,18 +152,35 @@ impl RowParallelLinear {
|
|||||||
/// `AllReduce` runs against. On CPU builds the parameter is
|
/// `AllReduce` runs against. On CPU builds the parameter is
|
||||||
/// elided — forward returns the partial sum, which is the *wrong*
|
/// elided — forward returns the partial sum, which is the *wrong*
|
||||||
/// answer for inference but lets us compile-check the model.
|
/// answer for inference but lets us compile-check the model.
|
||||||
|
///
|
||||||
|
/// Backward-compatible variant — no in-situ quantization. For
|
||||||
|
/// quantized loads, use [`Self::load_with_quant`].
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
pub fn load(
|
pub fn load(
|
||||||
vb: &ShardedVarBuilder,
|
vb: &ShardedVarBuilder,
|
||||||
rank: u32,
|
rank: u32,
|
||||||
world_size: u32,
|
world_size: u32,
|
||||||
comm: std::sync::Arc<cudarc::nccl::Comm>,
|
comm: std::sync::Arc<cudarc::nccl::Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Self::load_with_quant(vb, rank, world_size, comm, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Like [`Self::load`] but quantizes the per-rank weight in-situ.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load_with_quant(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: std::sync::Arc<cudarc::nccl::Comm>,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let weight = vb
|
let weight = vb
|
||||||
.get_with_hints((), "weight", shard(1, rank, world_size))
|
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||||
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
let inner = MaybeQuantLinear::from_weight(weight, quant)
|
||||||
|
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
inner: Linear::new(weight, None),
|
inner,
|
||||||
all_reduce: AllReduce::new(comm),
|
all_reduce: AllReduce::new(comm),
|
||||||
needs_reduce: world_size > 1,
|
needs_reduce: world_size > 1,
|
||||||
})
|
})
|
||||||
@@ -108,11 +188,23 @@ impl RowParallelLinear {
|
|||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
Self::load_with_quant(vb, rank, world_size, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load_with_quant(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
|
) -> Result<Self> {
|
||||||
let weight = vb
|
let weight = vb
|
||||||
.get_with_hints((), "weight", shard(1, rank, world_size))
|
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||||
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
let inner = MaybeQuantLinear::from_weight(weight, quant)
|
||||||
|
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
inner: Linear::new(weight, None),
|
inner,
|
||||||
needs_reduce: world_size > 1,
|
needs_reduce: world_size > 1,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user