feat(stage-8e-2d): route quantized matmul by M (prefill vs decode)
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 41s
CI / Clippy (push) Successful in 2m20s
CI / Test (push) Successful in 4m40s
build-prerelease / Build cortex binary (push) Successful in 4m20s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m58s
build-prerelease / Build neuron-ampere (push) Successful in 5m14s
build-prerelease / Package cortex RPM (push) Successful in 9m25s
build-prerelease / Build neuron-ada (push) Successful in 5m12s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m56s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m55s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m45s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 41s
CI / Clippy (push) Successful in 2m20s
CI / Test (push) Successful in 4m40s
build-prerelease / Build cortex binary (push) Successful in 4m20s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m58s
build-prerelease / Build neuron-ampere (push) Successful in 5m14s
build-prerelease / Package cortex RPM (push) Successful in 9m25s
build-prerelease / Build neuron-ada (push) Successful in 5m12s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m56s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m55s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m45s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
MaybeQuantLinear::forward picks between two QMatMul paths: - M > 8 (prefill): QMatMul::forward_via_f16 dequantises the weight once into f16 and runs a real cuBLAS-backed GEMM. The dequant cost is fixed per call, so it's amortised across the M tokens. - M <= 8 (decode): QMatMul::forward uses candle's GGUF GEMV kernel on the quantized blocks directly. Requires f32 inputs so we still cast in/out at the boundary in that arm. Earlier 8e-2c sent everything through the GGUF GEMV kernel, which is excellent at GEMV (decode) but doesn't have a real batched GEMM path — prefill regressed ~4x. This restores prefill to roughly the bf16 cuBLAS GEMM throughput while keeping the decode gain. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -71,17 +71,45 @@ impl MaybeQuantLinear {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Above this M (the product of all input dims except the last)
|
||||||
|
/// dispatch the quantized matmul through `QMatMul::forward_via_f16`,
|
||||||
|
/// which dequantizes the weight to f16 once and runs cuBLAS GEMM.
|
||||||
|
/// At or below this M the GGUF GEMV kernel inside
|
||||||
|
/// `QMatMul::forward` wins (it operates on quantized blocks directly
|
||||||
|
/// and accumulates in registers).
|
||||||
|
///
|
||||||
|
/// 8 is conservative: candle's f16 GEMM beats the GGUF GEMV anywhere
|
||||||
|
/// the M dim gets non-trivial (>=4 typically), but the dequantize
|
||||||
|
/// cost is fixed per call so the crossover is a small constant.
|
||||||
|
const QUANT_PREFILL_M_THRESHOLD: usize = 8;
|
||||||
|
|
||||||
impl Module for MaybeQuantLinear {
|
impl Module for MaybeQuantLinear {
|
||||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
match self {
|
match self {
|
||||||
Self::Plain(l) => l.forward(x),
|
Self::Plain(l) => l.forward(x),
|
||||||
Self::Quant(qm) => {
|
Self::Quant(qm) => {
|
||||||
// candle's `QTensor::cuda_fwd` requires f32 inputs (the
|
// Decode vs prefill split. `M` is the "rows of x" the
|
||||||
// GGUF kernels dequantize on the fly and accumulate in
|
// matmul will iterate over — every dim except the last
|
||||||
// f32). Our model dtype is bf16 (or f16) so we cast in
|
// (which is in_features). For decode (`seq_len == 1`
|
||||||
// and out at the matmul boundary. The cast itself is a
|
// with batch 1) M is 1; for prefill with L>>1 it's L
|
||||||
// single launch on the activation tensor — cheap vs the
|
// (or B*L).
|
||||||
// weight loads the matmul saves.
|
let dims = x.dims();
|
||||||
|
let m: usize = dims.iter().take(dims.len() - 1).product();
|
||||||
|
|
||||||
|
if m > QUANT_PREFILL_M_THRESHOLD {
|
||||||
|
// Prefill: dequantize the weight once into f16,
|
||||||
|
// then run a real cuBLAS-backed GEMM. The cost of
|
||||||
|
// the dequant is amortised across all M tokens.
|
||||||
|
// `forward_via_f16` handles the dtype round-trip
|
||||||
|
// internally (output matches input dtype).
|
||||||
|
return qm.forward_via_f16(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode (M <= threshold): use the on-the-fly GGUF
|
||||||
|
// GEMV kernel via `QMatMul::forward`. That kernel
|
||||||
|
// requires f32 inputs (it accumulates in f32 from the
|
||||||
|
// dequantized quant blocks); cast in/out at the
|
||||||
|
// boundary.
|
||||||
let in_dtype = x.dtype();
|
let in_dtype = x.dtype();
|
||||||
let x_f32 = if in_dtype == candle_core::DType::F32 {
|
let x_f32 = if in_dtype == candle_core::DType::F32 {
|
||||||
x.clone()
|
x.clone()
|
||||||
|
|||||||
Reference in New Issue
Block a user