From f084aaab8e4b0af425cfc022462921ac2a5e647d Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Thu, 21 May 2026 20:05:19 +0300 Subject: [PATCH] fix(stage-8e-2c): cast bf16/f16 activations to f32 around QMatMul MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit candle's QTensor::cuda_fwd requires f32 inputs — its on-the-fly GGUF dequantize accumulates in f32. The model dtype flowing into MaybeQuantLinear::forward is bf16, so QMatMul::forward errored with "unexpected dtype, expected: F32, got: BF16". Wrap the Quant arm to cast the activation to f32 before the matmul and cast the result back to the input dtype. The cast is a single launch on the activation tensor (small relative to weight traffic); it's the price of in-situ GGUF-style quantization, and what mistralrs does inside its own Linear wrapper. The Plain arm is unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/neuron/src/harness/tp/tp_linear.rs | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/crates/neuron/src/harness/tp/tp_linear.rs b/crates/neuron/src/harness/tp/tp_linear.rs index f62a07a..acffc2c 100644 --- a/crates/neuron/src/harness/tp/tp_linear.rs +++ b/crates/neuron/src/harness/tp/tp_linear.rs @@ -75,7 +75,26 @@ impl Module for MaybeQuantLinear { fn forward(&self, x: &Tensor) -> candle_core::Result { match self { Self::Plain(l) => l.forward(x), - Self::Quant(qm) => qm.forward(x), + Self::Quant(qm) => { + // candle's `QTensor::cuda_fwd` requires f32 inputs (the + // GGUF kernels dequantize on the fly and accumulate in + // f32). Our model dtype is bf16 (or f16) so we cast in + // and out at the matmul boundary. The cast itself is a + // single launch on the activation tensor — cheap vs the + // weight loads the matmul saves. + let in_dtype = x.dtype(); + let x_f32 = if in_dtype == candle_core::DType::F32 { + x.clone() + } else { + x.to_dtype(candle_core::DType::F32)? + }; + let y = qm.forward(&x_f32)?; + if y.dtype() == in_dtype { + Ok(y) + } else { + y.to_dtype(in_dtype) + } + } } } }