feat(stage-8e-2d): route quantized matmul by M (prefill vs decode)
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 41s
CI / Clippy (push) Successful in 2m20s
CI / Test (push) Successful in 4m40s
build-prerelease / Build cortex binary (push) Successful in 4m20s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m58s
build-prerelease / Build neuron-ampere (push) Successful in 5m14s
build-prerelease / Package cortex RPM (push) Successful in 9m25s
build-prerelease / Build neuron-ada (push) Successful in 5m12s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m56s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m55s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m45s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s

MaybeQuantLinear::forward picks between two QMatMul paths:

- M > 8 (prefill): QMatMul::forward_via_f16 dequantises the weight
  once into f16 and runs a real cuBLAS-backed GEMM. The dequant cost
  is fixed per call, so it's amortised across the M tokens.
- M <= 8 (decode): QMatMul::forward uses candle's GGUF GEMV kernel
  on the quantized blocks directly. Requires f32 inputs so we still
  cast in/out at the boundary in that arm.

Earlier 8e-2c sent everything through the GGUF GEMV kernel, which
is excellent at GEMV (decode) but doesn't have a real batched GEMM
path — prefill regressed ~4x. This restores prefill to roughly the
bf16 cuBLAS GEMM throughput while keeping the decode gain.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 21:15:32 +03:00
parent f084aaab8e
commit 34f9b77d9d

View File

@@ -71,17 +71,45 @@ impl MaybeQuantLinear {
} }
} }
/// Above this M (the product of all input dims except the last)
/// dispatch the quantized matmul through `QMatMul::forward_via_f16`,
/// which dequantizes the weight to f16 once and runs cuBLAS GEMM.
/// At or below this M the GGUF GEMV kernel inside
/// `QMatMul::forward` wins (it operates on quantized blocks directly
/// and accumulates in registers).
///
/// 8 is conservative: candle's f16 GEMM beats the GGUF GEMV anywhere
/// the M dim gets non-trivial (>=4 typically), but the dequantize
/// cost is fixed per call so the crossover is a small constant.
const QUANT_PREFILL_M_THRESHOLD: usize = 8;
impl Module for MaybeQuantLinear { impl Module for MaybeQuantLinear {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> { fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
match self { match self {
Self::Plain(l) => l.forward(x), Self::Plain(l) => l.forward(x),
Self::Quant(qm) => { Self::Quant(qm) => {
// candle's `QTensor::cuda_fwd` requires f32 inputs (the // Decode vs prefill split. `M` is the "rows of x" the
// GGUF kernels dequantize on the fly and accumulate in // matmul will iterate over — every dim except the last
// f32). Our model dtype is bf16 (or f16) so we cast in // (which is in_features). For decode (`seq_len == 1`
// and out at the matmul boundary. The cast itself is a // with batch 1) M is 1; for prefill with L>>1 it's L
// single launch on the activation tensor — cheap vs the // (or B*L).
// weight loads the matmul saves. let dims = x.dims();
let m: usize = dims.iter().take(dims.len() - 1).product();
if m > QUANT_PREFILL_M_THRESHOLD {
// Prefill: dequantize the weight once into f16,
// then run a real cuBLAS-backed GEMM. The cost of
// the dequant is amortised across all M tokens.
// `forward_via_f16` handles the dtype round-trip
// internally (output matches input dtype).
return qm.forward_via_f16(x);
}
// Decode (M <= threshold): use the on-the-fly GGUF
// GEMV kernel via `QMatMul::forward`. That kernel
// requires f32 inputs (it accumulates in f32 from the
// dequantized quant blocks); cast in/out at the
// boundary.
let in_dtype = x.dtype(); let in_dtype = x.dtype();
let x_f32 = if in_dtype == candle_core::DType::F32 { let x_f32 = if in_dtype == candle_core::DType::F32 {
x.clone() x.clone()