feat(stage-8d-2): wire gated_delta_rule_recurrence kernel into qwen3_5
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 38s
CI / Test (push) Failing after 45s
CI / Clippy (push) Successful in 2m16s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-blackwell (push) Has been cancelled

Replaces the per-token Rust delta-rule loop in
`arch/qwen3_5/linear_attn.rs::GatedDeltaNet::forward` with a single
dispatch to the `gated_delta_rule_recurrence` kernel imported from
mistralrs in 1ebbe87.

The kernel is V-tiled with compile-time BK (one block per (V-tile,
batch*head), one thread per V-column, BK state floats in registers).
For Qwen3.6's per-rank `(B=1, H=24, D_k=128, D_v=128)` shape this
collapses ~6 candle tensor-op launches per token per layer (each
~50µs CUDA dispatch overhead, so ~300µs/token/layer × 48 linear-
attention layers = 14ms in launch overhead alone) to a single
kernel launch with full ILP / register residency.

New free function `run_delta_rule`:
- cuda branch (when q is on a CUDA device): flattens
  `(B, H, ...)` → `(BH, ...)`, dispatches the kernel via
  `crate::cuda::gdn::gated_delta_rule_recurrence_cuda`, reshapes
  outputs back to `(B, H, L, D_v)` and state to `(B, H, D_k, D_v)`.
- cpu fallback: the original per-token Rust loop, unchanged. Keeps
  cargo test --workspace passing on hosts without cuda.

Dispatch decision lives in the wrapper (`q.device().is_cuda()`).

Build: `cargo build -p neuron --features cuda` compiles + links;
clippy clean on both CPU and cuda paths. 32 lib tests still pass
(none of them exercise this code path on cuda; smoke test for the
TP variant is the deployed Tbilisi probe).

Stage 8d-3 wires the conv1d kernels; 8d-4 the chunked prefill;
8d-5 the same wiring for `tp/tp_qwen3_5.rs`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 11:39:30 +03:00
parent 1ebbe87651
commit 44ae927e38

View File

@@ -56,7 +56,7 @@
//! prefill; can be added later without changing the surface.
use anyhow::{Context, Result};
use candle_core::{IndexOp, Module, Tensor};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
@@ -315,7 +315,7 @@ impl GatedDeltaNet {
let beta = beta.to_dtype(candle_core::DType::F32)?;
// Initialise the recurrent state from cache or zeros.
let mut state = match self.state.recurrent_state.take() {
let state_init = match self.state.recurrent_state.take() {
Some(s) => s.to_dtype(candle_core::DType::F32)?,
None => Tensor::zeros(
(
@@ -329,48 +329,26 @@ impl GatedDeltaNet {
)?,
};
// Per-token delta-rule loop. Slow-but-correct path; chunked
// optimisation is for later.
let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
for t in 0..seq_len {
// (B, H, D_k) and (B, H, D_v) for token t.
let q_t = q.i((.., .., t, ..))?; // (B, H, D_k)
let k_t = k.i((.., .., t, ..))?;
let v_t = v.i((.., .., t, ..))?;
let g_t = g.i((.., .., t))?; // (B, H)
let beta_t = beta.i((.., .., t))?; // (B, H)
// Decay: state *= exp(g_t). exp(g_t) shape (B, H) → broadcast to (B, H, 1, 1).
let decay = g_t
.exp()?
.unsqueeze(candle_core::D::Minus1)?
.unsqueeze(candle_core::D::Minus1)?; // (B, H, 1, 1)
state = state.broadcast_mul(&decay)?;
// Memory readout: sum_{d_k} state[d_k, d_v] * k_t[d_k] → (B, H, D_v).
// state: (B, H, D_k, D_v); k_t.unsqueeze(-1): (B, H, D_k, 1).
let k_col = k_t.unsqueeze(candle_core::D::Minus1)?; // (B, H, D_k, 1)
let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?; // sum over D_k → (B, H, D_v)
// delta = (v_t - kv_mem) * beta_t (broadcast beta on last dim).
let beta_col = beta_t.unsqueeze(candle_core::D::Minus1)?; // (B, H, 1)
let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?; // (B, H, D_v)
// state += outer(k_t, delta) = k_col * delta_row, broadcast to (B, H, D_k, D_v).
let delta_row = delta.unsqueeze(2)?; // (B, H, 1, D_v)
let outer = k_col.broadcast_mul(&delta_row)?; // (B, H, D_k, D_v)
state = (state + outer)?;
// out_t = sum_{d_k} state[d_k, d_v] * q_t[d_k] → (B, H, D_v).
let q_col = q_t.unsqueeze(candle_core::D::Minus1)?; // (B, H, D_k, 1)
let out_t = state.broadcast_mul(&q_col)?.sum(2)?; // (B, H, D_v)
outputs.push(out_t.unsqueeze(2)?); // (B, H, 1, D_v)
}
// The delta-rule body: cuda-accelerated `gated_delta_rule_recurrence`
// kernel when we have a cuda device + the kernels are linked in,
// pure-Rust per-token fallback otherwise.
let (core_attn_out, new_state) = run_delta_rule(
&q,
&k,
&v,
&g,
&beta,
state_init,
batch_size,
self.num_v_heads,
seq_len,
self.head_k_dim,
self.head_v_dim,
)?;
// Stash the updated recurrent state for the next call.
self.state.recurrent_state = Some(state.to_dtype(dtype)?);
self.state.recurrent_state = Some(new_state.to_dtype(dtype)?);
// core_attn_out: (B, H, L, D_v) → (B, L, H, D_v) → (B*L*H, D_v).
let core_attn_out = Tensor::cat(&outputs, 2)?; // (B, H, L, D_v)
let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; // (B, L, H, D_v)
let core_attn_out = core_attn_out.to_dtype(dtype)?;
let core_attn_flat =
@@ -386,6 +364,126 @@ impl GatedDeltaNet {
}
}
/// Run the per-token delta-rule recurrence.
///
/// `q`, `k`: `(B, H, L, D_k)` (F32). `v`: `(B, H, L, D_v)`. `g`,
/// `beta`: `(B, H, L)`. `state`: `(B, H, D_k, D_v)`.
///
/// Returns `(core_attn_out: (B, H, L, D_v), state: (B, H, D_k, D_v))`,
/// both F32. Caller is responsible for cast back to model dtype.
///
/// Cuda path: dispatches to the `gated_delta_rule_recurrence` kernel
/// ported from `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/gdn.cu`.
/// All five inputs must be cuda f32 tensors. The kernel is V-tiled
/// with compile-time BK; one block per (V-tile, batch*head) and one
/// thread per V-column. Each thread holds BK state floats in
/// registers — eliminates the launch-overhead floor we hit with
/// candle's per-op dispatch (was ~12s/token on Qwen3.6-27B).
///
/// CPU path: pure-Rust per-token loop. Correct, slow.
#[allow(clippy::too_many_arguments)]
fn run_delta_rule(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: Tensor,
batch_size: usize,
num_heads: usize,
seq_len: usize,
head_k_dim: usize,
head_v_dim: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
#[cfg(feature = "cuda")]
{
// Only dispatch to the kernel if the inputs are on a CUDA
// device — CPU tests fall back to the Rust loop below.
if q.device().is_cuda() {
return run_delta_rule_cuda(
q, k, v, g, beta, state, batch_size, num_heads, seq_len, head_k_dim, head_v_dim,
);
}
}
let _ = (batch_size, num_heads, head_k_dim, head_v_dim);
run_delta_rule_rust(q, k, v, g, beta, state, seq_len)
}
/// CUDA path. Flattens (B, H, ...) → (BH, ...) at the kernel boundary
/// (the kernel uses BH = batch*heads as its outer batch axis) and
/// reshapes the kernel's outputs back to (B, H, ...) for the caller.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn run_delta_rule_cuda(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: Tensor,
batch_size: usize,
num_heads: usize,
seq_len: usize,
head_k_dim: usize,
head_v_dim: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let q_bh = q.flatten(0, 1)?.contiguous()?;
let k_bh = k.flatten(0, 1)?.contiguous()?;
let v_bh = v.flatten(0, 1)?.contiguous()?;
let g_bh = g.flatten(0, 1)?.contiguous()?;
let beta_bh = beta.flatten(0, 1)?.contiguous()?;
let mut state_bh = state.flatten(0, 1)?.contiguous()?;
let output_bh = crate::cuda::gdn::gated_delta_rule_recurrence_cuda(
&q_bh,
&k_bh,
&v_bh,
&g_bh,
&beta_bh,
&mut state_bh,
)?;
let core_attn_out = output_bh.reshape((batch_size, num_heads, seq_len, head_v_dim))?;
let new_state = state_bh.reshape((batch_size, num_heads, head_k_dim, head_v_dim))?;
Ok((core_attn_out, new_state))
}
#[allow(clippy::too_many_arguments)]
fn run_delta_rule_rust(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
mut state: Tensor,
seq_len: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
use candle_core::IndexOp;
let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
for t in 0..seq_len {
let q_t = q.i((.., .., t, ..))?;
let k_t = k.i((.., .., t, ..))?;
let v_t = v.i((.., .., t, ..))?;
let g_t = g.i((.., .., t))?;
let beta_t = beta.i((.., .., t))?;
let decay = g_t
.exp()?
.unsqueeze(candle_core::D::Minus1)?
.unsqueeze(candle_core::D::Minus1)?;
state = state.broadcast_mul(&decay)?;
let k_col = k_t.unsqueeze(candle_core::D::Minus1)?;
let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?;
let beta_col = beta_t.unsqueeze(candle_core::D::Minus1)?;
let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?;
let delta_row = delta.unsqueeze(2)?;
let outer = k_col.broadcast_mul(&delta_row)?;
state = (state + outer)?;
let q_col = q_t.unsqueeze(candle_core::D::Minus1)?;
let out_t = state.broadcast_mul(&q_col)?.sum(2)?;
outputs.push(out_t.unsqueeze(2)?);
}
let core_attn_out = Tensor::cat(&outputs, 2)?; // (B, H, L, D_v)
Ok((core_attn_out, state))
}
/// Load a no-bias linear from the ShardedVarBuilder. Weight shape is
/// the standard `[out, in]` order.
fn load_linear_no_bias(