From 44ae927e3838a7a02845bb7d371ab0bf039de53c Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Thu, 21 May 2026 11:39:30 +0300 Subject: [PATCH] feat(stage-8d-2): wire gated_delta_rule_recurrence kernel into qwen3_5 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the per-token Rust delta-rule loop in `arch/qwen3_5/linear_attn.rs::GatedDeltaNet::forward` with a single dispatch to the `gated_delta_rule_recurrence` kernel imported from mistralrs in 1ebbe87. The kernel is V-tiled with compile-time BK (one block per (V-tile, batch*head), one thread per V-column, BK state floats in registers). For Qwen3.6's per-rank `(B=1, H=24, D_k=128, D_v=128)` shape this collapses ~6 candle tensor-op launches per token per layer (each ~50µs CUDA dispatch overhead, so ~300µs/token/layer × 48 linear- attention layers = 14ms in launch overhead alone) to a single kernel launch with full ILP / register residency. New free function `run_delta_rule`: - cuda branch (when q is on a CUDA device): flattens `(B, H, ...)` → `(BH, ...)`, dispatches the kernel via `crate::cuda::gdn::gated_delta_rule_recurrence_cuda`, reshapes outputs back to `(B, H, L, D_v)` and state to `(B, H, D_k, D_v)`. - cpu fallback: the original per-token Rust loop, unchanged. Keeps cargo test --workspace passing on hosts without cuda. Dispatch decision lives in the wrapper (`q.device().is_cuda()`). Build: `cargo build -p neuron --features cuda` compiles + links; clippy clean on both CPU and cuda paths. 32 lib tests still pass (none of them exercise this code path on cuda; smoke test for the TP variant is the deployed Tbilisi probe). Stage 8d-3 wires the conv1d kernels; 8d-4 the chunked prefill; 8d-5 the same wiring for `tp/tp_qwen3_5.rs`. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/harness/arch/qwen3_5/linear_attn.rs | 180 ++++++++++++++---- 1 file changed, 139 insertions(+), 41 deletions(-) diff --git a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs index 7ced2da..dce17ed 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs @@ -56,7 +56,7 @@ //! prefill; can be added later without changing the surface. use anyhow::{Context, Result}; -use candle_core::{IndexOp, Module, Tensor}; +use candle_core::{Module, Tensor}; use candle_nn::Linear; use candle_nn::var_builder::ShardedVarBuilder; @@ -315,7 +315,7 @@ impl GatedDeltaNet { let beta = beta.to_dtype(candle_core::DType::F32)?; // Initialise the recurrent state from cache or zeros. - let mut state = match self.state.recurrent_state.take() { + let state_init = match self.state.recurrent_state.take() { Some(s) => s.to_dtype(candle_core::DType::F32)?, None => Tensor::zeros( ( @@ -329,48 +329,26 @@ impl GatedDeltaNet { )?, }; - // Per-token delta-rule loop. Slow-but-correct path; chunked - // optimisation is for later. - let mut outputs: Vec = Vec::with_capacity(seq_len); - for t in 0..seq_len { - // (B, H, D_k) and (B, H, D_v) for token t. - let q_t = q.i((.., .., t, ..))?; // (B, H, D_k) - let k_t = k.i((.., .., t, ..))?; - let v_t = v.i((.., .., t, ..))?; - let g_t = g.i((.., .., t))?; // (B, H) - let beta_t = beta.i((.., .., t))?; // (B, H) - - // Decay: state *= exp(g_t). exp(g_t) shape (B, H) → broadcast to (B, H, 1, 1). - let decay = g_t - .exp()? - .unsqueeze(candle_core::D::Minus1)? - .unsqueeze(candle_core::D::Minus1)?; // (B, H, 1, 1) - state = state.broadcast_mul(&decay)?; - - // Memory readout: sum_{d_k} state[d_k, d_v] * k_t[d_k] → (B, H, D_v). - // state: (B, H, D_k, D_v); k_t.unsqueeze(-1): (B, H, D_k, 1). - let k_col = k_t.unsqueeze(candle_core::D::Minus1)?; // (B, H, D_k, 1) - let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?; // sum over D_k → (B, H, D_v) - - // delta = (v_t - kv_mem) * beta_t (broadcast beta on last dim). - let beta_col = beta_t.unsqueeze(candle_core::D::Minus1)?; // (B, H, 1) - let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?; // (B, H, D_v) - - // state += outer(k_t, delta) = k_col * delta_row, broadcast to (B, H, D_k, D_v). - let delta_row = delta.unsqueeze(2)?; // (B, H, 1, D_v) - let outer = k_col.broadcast_mul(&delta_row)?; // (B, H, D_k, D_v) - state = (state + outer)?; - - // out_t = sum_{d_k} state[d_k, d_v] * q_t[d_k] → (B, H, D_v). - let q_col = q_t.unsqueeze(candle_core::D::Minus1)?; // (B, H, D_k, 1) - let out_t = state.broadcast_mul(&q_col)?.sum(2)?; // (B, H, D_v) - outputs.push(out_t.unsqueeze(2)?); // (B, H, 1, D_v) - } + // The delta-rule body: cuda-accelerated `gated_delta_rule_recurrence` + // kernel when we have a cuda device + the kernels are linked in, + // pure-Rust per-token fallback otherwise. + let (core_attn_out, new_state) = run_delta_rule( + &q, + &k, + &v, + &g, + &beta, + state_init, + batch_size, + self.num_v_heads, + seq_len, + self.head_k_dim, + self.head_v_dim, + )?; // Stash the updated recurrent state for the next call. - self.state.recurrent_state = Some(state.to_dtype(dtype)?); + self.state.recurrent_state = Some(new_state.to_dtype(dtype)?); // core_attn_out: (B, H, L, D_v) → (B, L, H, D_v) → (B*L*H, D_v). - let core_attn_out = Tensor::cat(&outputs, 2)?; // (B, H, L, D_v) let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; // (B, L, H, D_v) let core_attn_out = core_attn_out.to_dtype(dtype)?; let core_attn_flat = @@ -386,6 +364,126 @@ impl GatedDeltaNet { } } +/// Run the per-token delta-rule recurrence. +/// +/// `q`, `k`: `(B, H, L, D_k)` (F32). `v`: `(B, H, L, D_v)`. `g`, +/// `beta`: `(B, H, L)`. `state`: `(B, H, D_k, D_v)`. +/// +/// Returns `(core_attn_out: (B, H, L, D_v), state: (B, H, D_k, D_v))`, +/// both F32. Caller is responsible for cast back to model dtype. +/// +/// Cuda path: dispatches to the `gated_delta_rule_recurrence` kernel +/// ported from `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/gdn.cu`. +/// All five inputs must be cuda f32 tensors. The kernel is V-tiled +/// with compile-time BK; one block per (V-tile, batch*head) and one +/// thread per V-column. Each thread holds BK state floats in +/// registers — eliminates the launch-overhead floor we hit with +/// candle's per-op dispatch (was ~12s/token on Qwen3.6-27B). +/// +/// CPU path: pure-Rust per-token loop. Correct, slow. +#[allow(clippy::too_many_arguments)] +fn run_delta_rule( + q: &Tensor, + k: &Tensor, + v: &Tensor, + g: &Tensor, + beta: &Tensor, + state: Tensor, + batch_size: usize, + num_heads: usize, + seq_len: usize, + head_k_dim: usize, + head_v_dim: usize, +) -> candle_core::Result<(Tensor, Tensor)> { + #[cfg(feature = "cuda")] + { + // Only dispatch to the kernel if the inputs are on a CUDA + // device — CPU tests fall back to the Rust loop below. + if q.device().is_cuda() { + return run_delta_rule_cuda( + q, k, v, g, beta, state, batch_size, num_heads, seq_len, head_k_dim, head_v_dim, + ); + } + } + let _ = (batch_size, num_heads, head_k_dim, head_v_dim); + run_delta_rule_rust(q, k, v, g, beta, state, seq_len) +} + +/// CUDA path. Flattens (B, H, ...) → (BH, ...) at the kernel boundary +/// (the kernel uses BH = batch*heads as its outer batch axis) and +/// reshapes the kernel's outputs back to (B, H, ...) for the caller. +#[cfg(feature = "cuda")] +#[allow(clippy::too_many_arguments)] +fn run_delta_rule_cuda( + q: &Tensor, + k: &Tensor, + v: &Tensor, + g: &Tensor, + beta: &Tensor, + state: Tensor, + batch_size: usize, + num_heads: usize, + seq_len: usize, + head_k_dim: usize, + head_v_dim: usize, +) -> candle_core::Result<(Tensor, Tensor)> { + let q_bh = q.flatten(0, 1)?.contiguous()?; + let k_bh = k.flatten(0, 1)?.contiguous()?; + let v_bh = v.flatten(0, 1)?.contiguous()?; + let g_bh = g.flatten(0, 1)?.contiguous()?; + let beta_bh = beta.flatten(0, 1)?.contiguous()?; + let mut state_bh = state.flatten(0, 1)?.contiguous()?; + let output_bh = crate::cuda::gdn::gated_delta_rule_recurrence_cuda( + &q_bh, + &k_bh, + &v_bh, + &g_bh, + &beta_bh, + &mut state_bh, + )?; + let core_attn_out = output_bh.reshape((batch_size, num_heads, seq_len, head_v_dim))?; + let new_state = state_bh.reshape((batch_size, num_heads, head_k_dim, head_v_dim))?; + Ok((core_attn_out, new_state)) +} + +#[allow(clippy::too_many_arguments)] +fn run_delta_rule_rust( + q: &Tensor, + k: &Tensor, + v: &Tensor, + g: &Tensor, + beta: &Tensor, + mut state: Tensor, + seq_len: usize, +) -> candle_core::Result<(Tensor, Tensor)> { + use candle_core::IndexOp; + let mut outputs: Vec = Vec::with_capacity(seq_len); + for t in 0..seq_len { + let q_t = q.i((.., .., t, ..))?; + let k_t = k.i((.., .., t, ..))?; + let v_t = v.i((.., .., t, ..))?; + let g_t = g.i((.., .., t))?; + let beta_t = beta.i((.., .., t))?; + let decay = g_t + .exp()? + .unsqueeze(candle_core::D::Minus1)? + .unsqueeze(candle_core::D::Minus1)?; + state = state.broadcast_mul(&decay)?; + let k_col = k_t.unsqueeze(candle_core::D::Minus1)?; + let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?; + let beta_col = beta_t.unsqueeze(candle_core::D::Minus1)?; + let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?; + let delta_row = delta.unsqueeze(2)?; + let outer = k_col.broadcast_mul(&delta_row)?; + state = (state + outer)?; + let q_col = q_t.unsqueeze(candle_core::D::Minus1)?; + let out_t = state.broadcast_mul(&q_col)?.sum(2)?; + outputs.push(out_t.unsqueeze(2)?); + } + let core_attn_out = Tensor::cat(&outputs, 2)?; // (B, H, L, D_v) + Ok((core_attn_out, state)) +} + /// Load a no-bias linear from the ShardedVarBuilder. Weight shape is /// the standard `[out, in]` order. fn load_linear_no_bias(