diff --git a/crates/neuron/src/harness/tp/tp_linear.rs b/crates/neuron/src/harness/tp/tp_linear.rs index acffc2c..663ba7a 100644 --- a/crates/neuron/src/harness/tp/tp_linear.rs +++ b/crates/neuron/src/harness/tp/tp_linear.rs @@ -71,17 +71,45 @@ impl MaybeQuantLinear { } } +/// Above this M (the product of all input dims except the last) +/// dispatch the quantized matmul through `QMatMul::forward_via_f16`, +/// which dequantizes the weight to f16 once and runs cuBLAS GEMM. +/// At or below this M the GGUF GEMV kernel inside +/// `QMatMul::forward` wins (it operates on quantized blocks directly +/// and accumulates in registers). +/// +/// 8 is conservative: candle's f16 GEMM beats the GGUF GEMV anywhere +/// the M dim gets non-trivial (>=4 typically), but the dequantize +/// cost is fixed per call so the crossover is a small constant. +const QUANT_PREFILL_M_THRESHOLD: usize = 8; + impl Module for MaybeQuantLinear { fn forward(&self, x: &Tensor) -> candle_core::Result { match self { Self::Plain(l) => l.forward(x), Self::Quant(qm) => { - // candle's `QTensor::cuda_fwd` requires f32 inputs (the - // GGUF kernels dequantize on the fly and accumulate in - // f32). Our model dtype is bf16 (or f16) so we cast in - // and out at the matmul boundary. The cast itself is a - // single launch on the activation tensor — cheap vs the - // weight loads the matmul saves. + // Decode vs prefill split. `M` is the "rows of x" the + // matmul will iterate over — every dim except the last + // (which is in_features). For decode (`seq_len == 1` + // with batch 1) M is 1; for prefill with L>>1 it's L + // (or B*L). + let dims = x.dims(); + let m: usize = dims.iter().take(dims.len() - 1).product(); + + if m > QUANT_PREFILL_M_THRESHOLD { + // Prefill: dequantize the weight once into f16, + // then run a real cuBLAS-backed GEMM. The cost of + // the dequant is amortised across all M tokens. + // `forward_via_f16` handles the dtype round-trip + // internally (output matches input dtype). + return qm.forward_via_f16(x); + } + + // Decode (M <= threshold): use the on-the-fly GGUF + // GEMV kernel via `QMatMul::forward`. That kernel + // requires f32 inputs (it accumulates in f32 from the + // dequantized quant blocks); cast in/out at the + // boundary. let in_dtype = x.dtype(); let x_f32 = if in_dtype == candle_core::DType::F32 { x.clone()