From cc95fe28d9226a813153ae5b0692bfd2b87163cd Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Thu, 21 May 2026 11:52:38 +0300 Subject: [PATCH] feat(stage-8d-5b): wire fused_gdn_gating CUDA kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit run_fused_gating helper consolidates the per-layer gating math: beta = sigmoid(b) g = -exp(a_log) * softplus(a + dt_bias) CUDA path issues a single launch via fused_gdn_gating_cuda; cpu path falls back to the original per-op Rust sequence. Replaces ~10 candle launches per linear-attention layer (sigmoid + 2× to_dtype + exp + neg + broadcast_add + softplus + 2× unsqueeze + broadcast_mul) across both single-GPU and TP forward paths. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/harness/arch/qwen3_5/linear_attn.rs | 78 ++++++++++++++----- crates/neuron/src/harness/tp/tp_qwen3_5.rs | 20 ++--- 2 files changed, 67 insertions(+), 31 deletions(-) diff --git a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs index 5c3cb2c..58e73b3 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs @@ -234,21 +234,10 @@ impl GatedDeltaNet { let v = v.reshape((batch_size, seq_len, self.num_v_heads, self.head_v_dim))?; // ----- beta + g (per-head, per-token gates). ----- - // beta = sigmoid(b) - let beta = candle_nn::ops::sigmoid(&b)?; - // g = -exp(A_log) * softplus(a + dt_bias) - // Promote everything to f32 — the Python does the same to - // avoid underflow on the -exp path. - let a_log_f32 = self.a_log.to_dtype(candle_core::DType::F32)?; - let neg_a_exp = a_log_f32.exp()?.neg()?; // (num_v_heads,) - let dt_b_f32 = self.dt_bias.to_dtype(candle_core::DType::F32)?; - let a_f32 = a.to_dtype(candle_core::DType::F32)?; - // a is (B, L, num_v_heads); broadcast-add dt_bias. 
- let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?; - let softplus = softplus(&a_plus_dt)?; - // (1, 1, num_v_heads) × (B, L, num_v_heads). - let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?; - let g = neg_a_exp_b.broadcast_mul(&softplus)?; + // Fused on cuda; per-op Rust on cpu. Both paths produce: + // beta = sigmoid(b) + // g = -exp(A_log) * softplus(a + dt_bias) + let (beta, g) = run_fused_gating(&b, &a, &self.a_log, &self.dt_bias)?; // ----- GQA-style key expansion if num_v_heads > num_k_heads. ----- let (q, k) = if self.num_v_heads > self.num_k_heads { @@ -275,16 +264,16 @@ impl GatedDeltaNet { let g = g.transpose(1, 2)?.contiguous()?; // (B, H, L) let beta = beta.transpose(1, 2)?.contiguous()?; // (B, H, L) - // Pre-scale q by 1/sqrt(D_k) once. + // Pre-scale q by 1/sqrt(D_k) once. Everything goes to f32 here + // since the delta rule mixes broadcast_mul ops that candle won't + // accept across mixed dtypes. On the cuda gating path both beta + // and g come back in model dtype; on the cpu path g is already + // f32 — both casts are cheap idempotent ops. let scale = 1.0_f64 / (self.head_k_dim as f64).sqrt(); let q = (q.to_dtype(candle_core::DType::F32)? * scale)?; let k = k.to_dtype(candle_core::DType::F32)?; let v = v.to_dtype(candle_core::DType::F32)?; - // `g` is already F32 (constructed from A_log/dt_bias in f32 above); - // `beta` came from sigmoid(b) which kept the model dtype, so we - // need to promote it here too — otherwise the per-token - // `(v_t - kv_mem).broadcast_mul(&beta_col)` mixes F32 LHS with - // BF16 RHS and trips candle's dtype-mismatch check. + let g = g.to_dtype(candle_core::DType::F32)?; let beta = beta.to_dtype(candle_core::DType::F32)?; // Initialise the recurrent state from cache or zeros. @@ -560,6 +549,53 @@ fn run_causal_conv1d_cuda( Ok((output, new_conv_state)) } +/// Fused GDN gating: computes `beta = sigmoid(b)` and +/// `g = -exp(a_log) * softplus(a + dt_bias)` together. 
+/// +/// `b`, `a`: `(B, L, num_heads)` model dtype. +/// `a_log`, `dt_bias`: `(num_heads,)` model dtype (cast to f32 internally). +/// +/// Returns `(beta, g)`: both in model dtype on the cuda path; on the cpu +/// fallback `beta` keeps the dtype of `b` and only `g` is f32. The caller casts both to f32 before the delta rule. +/// +/// Cuda path: dispatches to `fused_gdn_gating_cuda` — one kernel +/// replaces sigmoid + neg(exp) + softplus + broadcast_mul (≈10 candle +/// launches per layer). +pub(crate) fn run_fused_gating( + b: &Tensor, + a: &Tensor, + a_log: &Tensor, + dt_bias: &Tensor, +) -> candle_core::Result<(Tensor, Tensor)> { + #[cfg(feature = "cuda")] + { + if b.device().is_cuda() { + let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?.contiguous()?; + let dt_bias_f32 = dt_bias.to_dtype(candle_core::DType::F32)?.contiguous()?; + return crate::cuda::gdn::fused_gdn_gating_cuda(b, a, &a_log_f32, &dt_bias_f32); + } + } + run_fused_gating_rust(b, a, a_log, dt_bias) +} + +fn run_fused_gating_rust( + b: &Tensor, + a: &Tensor, + a_log: &Tensor, + dt_bias: &Tensor, +) -> candle_core::Result<(Tensor, Tensor)> { + let beta = candle_nn::ops::sigmoid(b)?; + let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?; + let neg_a_exp = a_log_f32.exp()?.neg()?; + let dt_b_f32 = dt_bias.to_dtype(candle_core::DType::F32)?; + let a_f32 = a.to_dtype(candle_core::DType::F32)?; + let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?; + let softplus_val = softplus(&a_plus_dt)?; + let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?; + let g = neg_a_exp_b.broadcast_mul(&softplus_val)?; + Ok((beta, g)) +} + fn run_causal_conv1d_rust( x: &Tensor, weight: &Tensor, diff --git a/crates/neuron/src/harness/tp/tp_qwen3_5.rs b/crates/neuron/src/harness/tp/tp_qwen3_5.rs index f50f024..90826a0 100644 --- a/crates/neuron/src/harness/tp/tp_qwen3_5.rs +++ b/crates/neuron/src/harness/tp/tp_qwen3_5.rs @@ -41,7 +41,7 @@ use std::sync::Arc; use cudarc::nccl::Comm; use super::tp_linear::{ColumnParallelLinear, RowParallelLinear}; -use 
crate::harness::arch::qwen3_5::linear_attn::{repeat_interleave, softplus}; +use crate::harness::arch::qwen3_5::linear_attn::repeat_interleave; use crate::harness::arch::qwen3_5::rmsnorm::{Qwen3_5RmsNorm, Qwen3_5RmsNormGated, l2norm}; use crate::harness::arch::qwen3_5::rope::RotaryEmbedding; pub use crate::harness::arch::qwen3_5::{Config, TextConfig}; @@ -285,15 +285,14 @@ impl TpQwen3_5GatedDeltaNet { ))?; // ----- beta + g (per-V-head, per-token). ----- - let beta = candle_nn::ops::sigmoid(&b)?; - let a_log_f32 = self.a_log.to_dtype(DType::F32)?; - let neg_a_exp = a_log_f32.exp()?.neg()?; // (per_rank_num_v_heads,) - let dt_b_f32 = self.dt_bias.to_dtype(DType::F32)?; - let a_f32 = a.to_dtype(DType::F32)?; - let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?; - let softplus_a = softplus(&a_plus_dt)?; - let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?; - let g = neg_a_exp_b.broadcast_mul(&softplus_a)?; // F32 + // Same fused gating helper as single-GPU — cuda kernel when + // available, per-op Rust fallback otherwise. + let (beta, g) = crate::harness::arch::qwen3_5::linear_attn::run_fused_gating( + &b, + &a, + &self.a_log, + &self.dt_bias, + )?; // ----- GQA expansion if per-rank ratio > 1. ----- let (q, k) = if self.per_rank_num_v_heads > self.per_rank_num_k_heads { @@ -321,6 +320,7 @@ impl TpQwen3_5GatedDeltaNet { let q = (q.to_dtype(DType::F32)? * scale)?; let k = k.to_dtype(DType::F32)?; let v = v.to_dtype(DType::F32)?; + let g = g.to_dtype(DType::F32)?; let beta = beta.to_dtype(DType::F32)?; let state_init = match self.state.recurrent_state.take() {