feat(stage-8e-2d): route quantized matmul by M (prefill vs decode)
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 41s
CI / Clippy (push) Successful in 2m20s
CI / Test (push) Successful in 4m40s
build-prerelease / Build cortex binary (push) Successful in 4m20s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m58s
build-prerelease / Build neuron-ampere (push) Successful in 5m14s
build-prerelease / Package cortex RPM (push) Successful in 9m25s
build-prerelease / Build neuron-ada (push) Successful in 5m12s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m56s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m55s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m45s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 41s
CI / Clippy (push) Successful in 2m20s
CI / Test (push) Successful in 4m40s
build-prerelease / Build cortex binary (push) Successful in 4m20s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m58s
build-prerelease / Build neuron-ampere (push) Successful in 5m14s
build-prerelease / Package cortex RPM (push) Successful in 9m25s
build-prerelease / Build neuron-ada (push) Successful in 5m12s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m56s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m55s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m45s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
MaybeQuantLinear::forward picks between two QMatMul paths: - M > 8 (prefill): QMatMul::forward_via_f16 dequantises the weight once into f16 and runs a real cuBLAS-backed GEMM. The dequant cost is fixed per call, so it's amortised across the M tokens. - M <= 8 (decode): QMatMul::forward uses candle's GGUF GEMV kernel on the quantized blocks directly. Requires f32 inputs so we still cast in/out at the boundary in that arm. Earlier 8e-2c sent everything through the GGUF GEMV kernel, which is excellent at GEMV (decode) but doesn't have a real batched GEMM path — prefill regressed ~4x. This restores prefill to roughly the bf16 cuBLAS GEMM throughput while keeping the decode gain. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -71,17 +71,45 @@ impl MaybeQuantLinear {
|
||||
}
|
||||
}
|
||||
|
||||
/// Above this M (the product of all input dims except the last)
|
||||
/// dispatch the quantized matmul through `QMatMul::forward_via_f16`,
|
||||
/// which dequantizes the weight to f16 once and runs cuBLAS GEMM.
|
||||
/// At or below this M the GGUF GEMV kernel inside
|
||||
/// `QMatMul::forward` wins (it operates on quantized blocks directly
|
||||
/// and accumulates in registers).
|
||||
///
|
||||
/// 8 is conservative: candle's f16 GEMM beats the GGUF GEMV anywhere
|
||||
/// the M dim gets non-trivial (>=4 typically), but the dequantize
|
||||
/// cost is fixed per call so the crossover is a small constant.
|
||||
const QUANT_PREFILL_M_THRESHOLD: usize = 8;
|
||||
|
||||
impl Module for MaybeQuantLinear {
|
||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
match self {
|
||||
Self::Plain(l) => l.forward(x),
|
||||
Self::Quant(qm) => {
|
||||
// candle's `QTensor::cuda_fwd` requires f32 inputs (the
|
||||
// GGUF kernels dequantize on the fly and accumulate in
|
||||
// f32). Our model dtype is bf16 (or f16) so we cast in
|
||||
// and out at the matmul boundary. The cast itself is a
|
||||
// single launch on the activation tensor — cheap vs the
|
||||
// weight loads the matmul saves.
|
||||
// Decode vs prefill split. `M` is the "rows of x" the
|
||||
// matmul will iterate over — every dim except the last
|
||||
// (which is in_features). For decode (`seq_len == 1`
|
||||
// with batch 1) M is 1; for prefill with L>>1 it's L
|
||||
// (or B*L).
|
||||
let dims = x.dims();
|
||||
let m: usize = dims.iter().take(dims.len() - 1).product();
|
||||
|
||||
if m > QUANT_PREFILL_M_THRESHOLD {
|
||||
// Prefill: dequantize the weight once into f16,
|
||||
// then run a real cuBLAS-backed GEMM. The cost of
|
||||
// the dequant is amortised across all M tokens.
|
||||
// `forward_via_f16` handles the dtype round-trip
|
||||
// internally (output matches input dtype).
|
||||
return qm.forward_via_f16(x);
|
||||
}
|
||||
|
||||
// Decode (M <= threshold): use the on-the-fly GGUF
|
||||
// GEMV kernel via `QMatMul::forward`. That kernel
|
||||
// requires f32 inputs (it accumulates in f32 from the
|
||||
// dequantized quant blocks); cast in/out at the
|
||||
// boundary.
|
||||
let in_dtype = x.dtype();
|
||||
let x_f32 = if in_dtype == candle_core::DType::F32 {
|
||||
x.clone()
|
||||
|
||||
Reference in New Issue
Block a user