diff --git a/crates/neuron/src/harness/tp/tp_linear.rs b/crates/neuron/src/harness/tp/tp_linear.rs index 324f2dd..f62a07a 100644 --- a/crates/neuron/src/harness/tp/tp_linear.rs +++ b/crates/neuron/src/harness/tp/tp_linear.rs @@ -24,13 +24,62 @@ //! sum carries it exactly once. use anyhow::{Context, Result}; +use candle_core::quantized::{GgmlDType, QMatMul, QTensor}; use candle_core::{Module, Tensor}; use candle_nn::Linear; use candle_nn::var_builder::{Shard, ShardedVarBuilder}; +use std::sync::Arc; #[cfg(feature = "cuda")] use super::all_reduce::AllReduce; +/// Linear primitive that holds either a plain `Linear` (bf16/f16/f32) +/// or a quantized `QMatMul` (Q4K/Q5K/Q6K/Q8_0/etc.). +/// +/// Constructed via [`MaybeQuantLinear::from_weight`] — pass `None` to +/// keep the weight in its loaded dtype (no quantization), or +/// `Some(dtype)` to quantize at load time. +/// +/// On the forward path the two arms dispatch identically: `Module::forward` +/// returns an output in the caller's input dtype (f32 fallback for the +/// quantized matmul). Subsequent ops don't need to know whether the +/// layer was quantized. +pub enum MaybeQuantLinear { + Plain(Linear), + Quant(QMatMul), +} + +impl MaybeQuantLinear { + /// Build a linear from a loaded weight tensor. If `quant` is set, + /// the weight is quantized in-situ and stored as a `QMatMul`; + /// otherwise it's wrapped in a plain `Linear`. + pub fn from_weight(weight: Tensor, quant: Option) -> Result { + match quant { + Some(dtype) => { + let qt = QTensor::quantize(&weight, dtype).with_context(|| { + format!( + "QTensor::quantize to {dtype:?} for shape {:?}", + weight.shape() + ) + })?; + let qmm = QMatMul::from_arc(Arc::new(qt)) + .context("QMatMul::from_arc on freshly quantized weight")?; + Ok(Self::Quant(qmm)) + } + None => Ok(Self::Plain(Linear::new(weight, None))), + } + } +} + +impl Module for MaybeQuantLinear { + fn forward(&self, x: &Tensor) -> candle_core::Result { + match self { + Self::Plain(l) => l.forward(x), + Self::Quant(qm) => qm.forward(x), + } + } +} + /// Helper to build a [`Shard`] hint for a given dimension. pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard { Shard { @@ -40,24 +89,38 @@ pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard { } } -/// Output-dim sharded linear (column-parallel). Holds a standard -/// `candle_nn::Linear` whose `weight` is the rank's slice of the full -/// `[out_features, in_features]` tensor along dim 0. +/// Output-dim sharded linear (column-parallel). Holds a +/// [`MaybeQuantLinear`] whose underlying weight is this rank's slice +/// of the full `[out_features, in_features]` tensor along dim 0. pub struct ColumnParallelLinear { - inner: Linear, + inner: MaybeQuantLinear, } impl ColumnParallelLinear { /// Load this rank's column-parallel slice from a /// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed /// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`). + /// + /// Backward-compatible variant — no in-situ quantization. For + /// quantized loads, use [`Self::load_with_quant`]. pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result { + Self::load_with_quant(vb, rank, world_size, None) + } + + /// Like [`Self::load`] but quantizes the per-rank weight in-situ + /// when `quant` is `Some(dtype)`. Saves ~3-5x vs bf16/f16. + pub fn load_with_quant( + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + quant: Option, + ) -> Result { let weight = vb .get_with_hints((), "weight", shard(0, rank, world_size)) .with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?; - Ok(Self { - inner: Linear::new(weight, None), - }) + let inner = MaybeQuantLinear::from_weight(weight, quant) + .with_context(|| format!("wrap column-parallel '{}'", vb.prefix()))?; + Ok(Self { inner }) } } @@ -69,10 +132,10 @@ impl Module for ColumnParallelLinear { /// Input-dim sharded linear (row-parallel). /// -/// Holds a sharded `Linear` plus an `AllReduce` op the forward chains -/// after the local matmul to recover the full activation. +/// Holds a sharded [`MaybeQuantLinear`] plus an `AllReduce` op the +/// forward chains after the local matmul to recover the full activation. pub struct RowParallelLinear { - inner: Linear, + inner: MaybeQuantLinear, #[cfg(feature = "cuda")] all_reduce: AllReduce, /// Whether the AllReduce should run. Column-parallel ↔ row-parallel @@ -89,18 +152,35 @@ impl RowParallelLinear { /// `AllReduce` runs against. On CPU builds the parameter is /// elided — forward returns the partial sum, which is the *wrong* /// answer for inference but lets us compile-check the model. + /// + /// Backward-compatible variant — no in-situ quantization. For + /// quantized loads, use [`Self::load_with_quant`]. #[cfg(feature = "cuda")] pub fn load( vb: &ShardedVarBuilder, rank: u32, world_size: u32, comm: std::sync::Arc, + ) -> Result { + Self::load_with_quant(vb, rank, world_size, comm, None) + } + + /// Like [`Self::load`] but quantizes the per-rank weight in-situ. + #[cfg(feature = "cuda")] + pub fn load_with_quant( + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + comm: std::sync::Arc, + quant: Option, ) -> Result { let weight = vb .get_with_hints((), "weight", shard(1, rank, world_size)) .with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?; + let inner = MaybeQuantLinear::from_weight(weight, quant) + .with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?; Ok(Self { - inner: Linear::new(weight, None), + inner, all_reduce: AllReduce::new(comm), needs_reduce: world_size > 1, }) @@ -108,11 +188,23 @@ impl RowParallelLinear { #[cfg(not(feature = "cuda"))] pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result { + Self::load_with_quant(vb, rank, world_size, None) + } + + #[cfg(not(feature = "cuda"))] + pub fn load_with_quant( + vb: &ShardedVarBuilder, + rank: u32, + world_size: u32, + quant: Option, + ) -> Result { let weight = vb .get_with_hints((), "weight", shard(1, rank, world_size)) .with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?; + let inner = MaybeQuantLinear::from_weight(weight, quant) + .with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?; Ok(Self { - inner: Linear::new(weight, None), + inner, needs_reduce: world_size > 1, }) }