diff --git a/crates/neuron/src/harness/tp/tp_linear.rs b/crates/neuron/src/harness/tp/tp_linear.rs index f62a07a..acffc2c 100644 --- a/crates/neuron/src/harness/tp/tp_linear.rs +++ b/crates/neuron/src/harness/tp/tp_linear.rs @@ -75,7 +75,26 @@ impl Module for MaybeQuantLinear { fn forward(&self, x: &Tensor) -> candle_core::Result { match self { Self::Plain(l) => l.forward(x), - Self::Quant(qm) => qm.forward(x), + Self::Quant(qm) => { + // candle's `QTensor::cuda_fwd` requires f32 inputs (the + // GGUF kernels dequantize on the fly and accumulate in + // f32). Our model dtype is bf16 (or f16) so we cast in + // and out at the matmul boundary. The cast itself is a + // single launch on the activation tensor — cheap vs the + // weight loads the matmul saves. + let in_dtype = x.dtype(); + let x_f32 = if in_dtype == candle_core::DType::F32 { + x.clone() + } else { + x.to_dtype(candle_core::DType::F32)? + }; + let y = qm.forward(&x_f32)?; + if y.dtype() == in_dtype { + Ok(y) + } else { + y.to_dtype(in_dtype) + } + } } } }