From 34f9b77d9d69858f1ae9851775f00a4137a7b90e Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Thu, 21 May 2026 21:15:32 +0300 Subject: [PATCH] feat(stage-8e-2d): route quantized matmul by M (prefill vs decode) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MaybeQuantLinear::forward picks between two QMatMul paths: - M > 8 (prefill): QMatMul::forward_via_f16 dequantises the weight once into f16 and runs a real cuBLAS-backed GEMM. The dequant cost is fixed per call, so it's amortised across the M tokens. - M <= 8 (decode): QMatMul::forward uses candle's GGUF GEMV kernel on the quantized blocks directly. Requires f32 inputs so we still cast in/out at the boundary in that arm. Earlier 8e-2c sent everything through the GGUF GEMV kernel, which is excellent at GEMV (decode) but doesn't have a real batched GEMM path — prefill regressed ~4x. This restores prefill to roughly the bf16 cuBLAS GEMM throughput while keeping the decode gain. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/neuron/src/harness/tp/tp_linear.rs | 40 +++++++++++++++++++---- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/crates/neuron/src/harness/tp/tp_linear.rs b/crates/neuron/src/harness/tp/tp_linear.rs index acffc2c..663ba7a 100644 --- a/crates/neuron/src/harness/tp/tp_linear.rs +++ b/crates/neuron/src/harness/tp/tp_linear.rs @@ -71,17 +71,45 @@ impl MaybeQuantLinear { } } +/// Above this M (the product of all input dims except the last) +/// dispatch the quantized matmul through `QMatMul::forward_via_f16`, +/// which dequantizes the weight to f16 once and runs cuBLAS GEMM. +/// At or below this M the GGUF GEMV kernel inside +/// `QMatMul::forward` wins (it operates on quantized blocks directly +/// and accumulates in registers). +/// +/// 8 is conservative: candle's f16 GEMM beats the GGUF GEMV anywhere +/// the M dim gets non-trivial (>=4 typically), but the dequantize +/// cost is fixed per call so the crossover is a small constant. +const QUANT_PREFILL_M_THRESHOLD: usize = 8; + impl Module for MaybeQuantLinear { fn forward(&self, x: &Tensor) -> candle_core::Result { match self { Self::Plain(l) => l.forward(x), Self::Quant(qm) => { - // candle's `QTensor::cuda_fwd` requires f32 inputs (the - // GGUF kernels dequantize on the fly and accumulate in - // f32). Our model dtype is bf16 (or f16) so we cast in - // and out at the matmul boundary. The cast itself is a - // single launch on the activation tensor — cheap vs the - // weight loads the matmul saves. + // Decode vs prefill split. `M` is the "rows of x" the + // matmul will iterate over — every dim except the last + // (which is in_features). For decode (`seq_len == 1` + // with batch 1) M is 1; for prefill with L>>1 it's L + // (or B*L). + let dims = x.dims(); + let m: usize = dims.iter().take(dims.len() - 1).product(); + + if m > QUANT_PREFILL_M_THRESHOLD { + // Prefill: dequantize the weight once into f16, + // then run a real cuBLAS-backed GEMM. The cost of + // the dequant is amortised across all M tokens. + // `forward_via_f16` handles the dtype round-trip + // internally (output matches input dtype). + return qm.forward_via_f16(x); + } + + // Decode (M <= threshold): use the on-the-fly GGUF + // GEMV kernel via `QMatMul::forward`. That kernel + // requires f32 inputs (it accumulates in f32 from the + // dequantized quant blocks); cast in/out at the + // boundary. let in_dtype = x.dtype(); let x_f32 = if in_dtype == candle_core::DType::F32 { x.clone()