feat(stage-8d-5b): wire fused_gdn_gating CUDA kernel
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 1m45s
build-prerelease / Build neuron-blackwell (push) Successful in 3m40s
build-prerelease / Build cortex binary (push) Successful in 4m27s
build-prerelease / Package cortex RPM (push) Successful in 1m24s
build-prerelease / Build neuron-ampere (push) Successful in 5m30s
build-prerelease / Build neuron-ada (push) Successful in 5m24s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m6s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m6s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m49s
CI / Format (push) Successful in 35s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m7s
CI / Clippy (push) Successful in 2m16s
CI / Test (push) Successful in 4m37s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 1m45s
build-prerelease / Build neuron-blackwell (push) Successful in 3m40s
build-prerelease / Build cortex binary (push) Successful in 4m27s
build-prerelease / Package cortex RPM (push) Successful in 1m24s
build-prerelease / Build neuron-ampere (push) Successful in 5m30s
build-prerelease / Build neuron-ada (push) Successful in 5m24s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m6s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m6s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m49s
CI / Format (push) Successful in 35s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m7s
CI / Clippy (push) Successful in 2m16s
CI / Test (push) Successful in 4m37s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
run_fused_gating helper consolidates the per-layer gating math: beta = sigmoid(b) g = -exp(a_log) * softplus(a + dt_bias) CUDA path issues a single launch via fused_gdn_gating_cuda; cpu path falls back to the original per-op Rust sequence. Replaces ~10 candle launches per linear-attention layer (sigmoid + 2× to_dtype + exp + neg + broadcast_add + softplus + 2× unsqueeze + broadcast_mul) across both single-GPU and TP forward paths. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -234,21 +234,10 @@ impl GatedDeltaNet {
|
|||||||
let v = v.reshape((batch_size, seq_len, self.num_v_heads, self.head_v_dim))?;
|
let v = v.reshape((batch_size, seq_len, self.num_v_heads, self.head_v_dim))?;
|
||||||
|
|
||||||
// ----- beta + g (per-head, per-token gates). -----
|
// ----- beta + g (per-head, per-token gates). -----
|
||||||
|
// Fused on cuda; per-op Rust on cpu. Both paths produce:
|
||||||
// beta = sigmoid(b)
|
// beta = sigmoid(b)
|
||||||
let beta = candle_nn::ops::sigmoid(&b)?;
|
|
||||||
// g = -exp(A_log) * softplus(a + dt_bias)
|
// g = -exp(A_log) * softplus(a + dt_bias)
|
||||||
// Promote everything to f32 — the Python does the same to
|
let (beta, g) = run_fused_gating(&b, &a, &self.a_log, &self.dt_bias)?;
|
||||||
// avoid underflow on the -exp path.
|
|
||||||
let a_log_f32 = self.a_log.to_dtype(candle_core::DType::F32)?;
|
|
||||||
let neg_a_exp = a_log_f32.exp()?.neg()?; // (num_v_heads,)
|
|
||||||
let dt_b_f32 = self.dt_bias.to_dtype(candle_core::DType::F32)?;
|
|
||||||
let a_f32 = a.to_dtype(candle_core::DType::F32)?;
|
|
||||||
// a is (B, L, num_v_heads); broadcast-add dt_bias.
|
|
||||||
let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?;
|
|
||||||
let softplus = softplus(&a_plus_dt)?;
|
|
||||||
// (1, 1, num_v_heads) × (B, L, num_v_heads).
|
|
||||||
let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?;
|
|
||||||
let g = neg_a_exp_b.broadcast_mul(&softplus)?;
|
|
||||||
|
|
||||||
// ----- GQA-style key expansion if num_v_heads > num_k_heads. -----
|
// ----- GQA-style key expansion if num_v_heads > num_k_heads. -----
|
||||||
let (q, k) = if self.num_v_heads > self.num_k_heads {
|
let (q, k) = if self.num_v_heads > self.num_k_heads {
|
||||||
@@ -275,16 +264,16 @@ impl GatedDeltaNet {
|
|||||||
let g = g.transpose(1, 2)?.contiguous()?; // (B, H, L)
|
let g = g.transpose(1, 2)?.contiguous()?; // (B, H, L)
|
||||||
let beta = beta.transpose(1, 2)?.contiguous()?; // (B, H, L)
|
let beta = beta.transpose(1, 2)?.contiguous()?; // (B, H, L)
|
||||||
|
|
||||||
// Pre-scale q by 1/sqrt(D_k) once.
|
// Pre-scale q by 1/sqrt(D_k) once. Everything goes to f32 here
|
||||||
|
// since the delta rule mixes broadcast_mul ops that candle won't
|
||||||
|
// accept across mixed dtypes. On the cuda gating path both beta
|
||||||
|
// and g come back in model dtype; on the cpu path g is already
|
||||||
|
// f32 — both casts are cheap idempotent ops.
|
||||||
let scale = 1.0_f64 / (self.head_k_dim as f64).sqrt();
|
let scale = 1.0_f64 / (self.head_k_dim as f64).sqrt();
|
||||||
let q = (q.to_dtype(candle_core::DType::F32)? * scale)?;
|
let q = (q.to_dtype(candle_core::DType::F32)? * scale)?;
|
||||||
let k = k.to_dtype(candle_core::DType::F32)?;
|
let k = k.to_dtype(candle_core::DType::F32)?;
|
||||||
let v = v.to_dtype(candle_core::DType::F32)?;
|
let v = v.to_dtype(candle_core::DType::F32)?;
|
||||||
// `g` is already F32 (constructed from A_log/dt_bias in f32 above);
|
let g = g.to_dtype(candle_core::DType::F32)?;
|
||||||
// `beta` came from sigmoid(b) which kept the model dtype, so we
|
|
||||||
// need to promote it here too — otherwise the per-token
|
|
||||||
// `(v_t - kv_mem).broadcast_mul(&beta_col)` mixes F32 LHS with
|
|
||||||
// BF16 RHS and trips candle's dtype-mismatch check.
|
|
||||||
let beta = beta.to_dtype(candle_core::DType::F32)?;
|
let beta = beta.to_dtype(candle_core::DType::F32)?;
|
||||||
|
|
||||||
// Initialise the recurrent state from cache or zeros.
|
// Initialise the recurrent state from cache or zeros.
|
||||||
@@ -560,6 +549,53 @@ fn run_causal_conv1d_cuda(
|
|||||||
Ok((output, new_conv_state))
|
Ok((output, new_conv_state))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fused GDN gating: computes `beta = sigmoid(b)` and
|
||||||
|
/// `g = -exp(a_log) * softplus(a + dt_bias)` together.
|
||||||
|
///
|
||||||
|
/// `b`, `a`: `(B, L, num_heads)` model dtype.
|
||||||
|
/// `a_log`, `dt_bias`: `(num_heads,)` model dtype (cast to f32 internally).
|
||||||
|
///
|
||||||
|
/// Returns `(beta, g)` both in model dtype on the cuda path, both in f32
|
||||||
|
/// on the cpu fallback. The caller casts to f32 before the delta rule.
|
||||||
|
///
|
||||||
|
/// Cuda path: dispatches to `fused_gdn_gating_cuda` — one kernel
|
||||||
|
/// replaces sigmoid + neg(exp) + softplus + broadcast_mul (≈10 candle
|
||||||
|
/// launches per layer).
|
||||||
|
pub(crate) fn run_fused_gating(
|
||||||
|
b: &Tensor,
|
||||||
|
a: &Tensor,
|
||||||
|
a_log: &Tensor,
|
||||||
|
dt_bias: &Tensor,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
if b.device().is_cuda() {
|
||||||
|
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?.contiguous()?;
|
||||||
|
let dt_bias_f32 = dt_bias.to_dtype(candle_core::DType::F32)?.contiguous()?;
|
||||||
|
return crate::cuda::gdn::fused_gdn_gating_cuda(b, a, &a_log_f32, &dt_bias_f32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
run_fused_gating_rust(b, a, a_log, dt_bias)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_fused_gating_rust(
|
||||||
|
b: &Tensor,
|
||||||
|
a: &Tensor,
|
||||||
|
a_log: &Tensor,
|
||||||
|
dt_bias: &Tensor,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let beta = candle_nn::ops::sigmoid(b)?;
|
||||||
|
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let neg_a_exp = a_log_f32.exp()?.neg()?;
|
||||||
|
let dt_b_f32 = dt_bias.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let a_f32 = a.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?;
|
||||||
|
let softplus_val = softplus(&a_plus_dt)?;
|
||||||
|
let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?;
|
||||||
|
let g = neg_a_exp_b.broadcast_mul(&softplus_val)?;
|
||||||
|
Ok((beta, g))
|
||||||
|
}
|
||||||
|
|
||||||
fn run_causal_conv1d_rust(
|
fn run_causal_conv1d_rust(
|
||||||
x: &Tensor,
|
x: &Tensor,
|
||||||
weight: &Tensor,
|
weight: &Tensor,
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ use std::sync::Arc;
|
|||||||
use cudarc::nccl::Comm;
|
use cudarc::nccl::Comm;
|
||||||
|
|
||||||
use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
|
use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
|
||||||
use crate::harness::arch::qwen3_5::linear_attn::{repeat_interleave, softplus};
|
use crate::harness::arch::qwen3_5::linear_attn::repeat_interleave;
|
||||||
use crate::harness::arch::qwen3_5::rmsnorm::{Qwen3_5RmsNorm, Qwen3_5RmsNormGated, l2norm};
|
use crate::harness::arch::qwen3_5::rmsnorm::{Qwen3_5RmsNorm, Qwen3_5RmsNormGated, l2norm};
|
||||||
use crate::harness::arch::qwen3_5::rope::RotaryEmbedding;
|
use crate::harness::arch::qwen3_5::rope::RotaryEmbedding;
|
||||||
pub use crate::harness::arch::qwen3_5::{Config, TextConfig};
|
pub use crate::harness::arch::qwen3_5::{Config, TextConfig};
|
||||||
@@ -285,15 +285,14 @@ impl TpQwen3_5GatedDeltaNet {
|
|||||||
))?;
|
))?;
|
||||||
|
|
||||||
// ----- beta + g (per-V-head, per-token). -----
|
// ----- beta + g (per-V-head, per-token). -----
|
||||||
let beta = candle_nn::ops::sigmoid(&b)?;
|
// Same fused gating helper as single-GPU — cuda kernel when
|
||||||
let a_log_f32 = self.a_log.to_dtype(DType::F32)?;
|
// available, per-op Rust fallback otherwise.
|
||||||
let neg_a_exp = a_log_f32.exp()?.neg()?; // (per_rank_num_v_heads,)
|
let (beta, g) = crate::harness::arch::qwen3_5::linear_attn::run_fused_gating(
|
||||||
let dt_b_f32 = self.dt_bias.to_dtype(DType::F32)?;
|
&b,
|
||||||
let a_f32 = a.to_dtype(DType::F32)?;
|
&a,
|
||||||
let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?;
|
&self.a_log,
|
||||||
let softplus_a = softplus(&a_plus_dt)?;
|
&self.dt_bias,
|
||||||
let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?;
|
)?;
|
||||||
let g = neg_a_exp_b.broadcast_mul(&softplus_a)?; // F32
|
|
||||||
|
|
||||||
// ----- GQA expansion if per-rank ratio > 1. -----
|
// ----- GQA expansion if per-rank ratio > 1. -----
|
||||||
let (q, k) = if self.per_rank_num_v_heads > self.per_rank_num_k_heads {
|
let (q, k) = if self.per_rank_num_v_heads > self.per_rank_num_k_heads {
|
||||||
@@ -321,6 +320,7 @@ impl TpQwen3_5GatedDeltaNet {
|
|||||||
let q = (q.to_dtype(DType::F32)? * scale)?;
|
let q = (q.to_dtype(DType::F32)? * scale)?;
|
||||||
let k = k.to_dtype(DType::F32)?;
|
let k = k.to_dtype(DType::F32)?;
|
||||||
let v = v.to_dtype(DType::F32)?;
|
let v = v.to_dtype(DType::F32)?;
|
||||||
|
let g = g.to_dtype(DType::F32)?;
|
||||||
let beta = beta.to_dtype(DType::F32)?;
|
let beta = beta.to_dtype(DType::F32)?;
|
||||||
|
|
||||||
let state_init = match self.state.recurrent_state.take() {
|
let state_init = match self.state.recurrent_state.take() {
|
||||||
|
|||||||
Reference in New Issue
Block a user