feat(stage-8d-2): wire gated_delta_rule_recurrence kernel into qwen3_5
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 38s
CI / Test (push) Failing after 45s
CI / Clippy (push) Successful in 2m16s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-blackwell (push) Has been cancelled
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 38s
CI / Test (push) Failing after 45s
CI / Clippy (push) Successful in 2m16s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-blackwell (push) Has been cancelled
Replaces the per-token Rust delta-rule loop in
`arch/qwen3_5/linear_attn.rs::GatedDeltaNet::forward` with a single
dispatch to the `gated_delta_rule_recurrence` kernel imported from
mistralrs in 1ebbe87.
The kernel is V-tiled with compile-time BK (one block per (V-tile,
batch*head), one thread per V-column, BK state floats in registers).
For Qwen3.6's per-rank `(B=1, H=24, D_k=128, D_v=128)` shape this
collapses ~6 candle tensor-op launches per token per layer (each
~50µs CUDA dispatch overhead, so ~300µs/token/layer × 48 linear-
attention layers = 14ms in launch overhead alone) to a single
kernel launch with full ILP / register residency.
New free function `run_delta_rule`:
- cuda branch (when q is on a CUDA device): flattens
`(B, H, ...)` → `(BH, ...)`, dispatches the kernel via
`crate::cuda::gdn::gated_delta_rule_recurrence_cuda`, reshapes
outputs back to `(B, H, L, D_v)` and state to `(B, H, D_k, D_v)`.
- cpu fallback: the original per-token Rust loop, unchanged. Keeps
cargo test --workspace passing on hosts without cuda.
Dispatch decision lives in the wrapper (`q.device().is_cuda()`).
Build: `cargo build -p neuron --features cuda` compiles + links;
clippy clean on both CPU and cuda paths. 32 lib tests still pass
(none of them exercise this code path on cuda; smoke test for the
TP variant is the deployed Tbilisi probe).
Stage 8d-3 wires the conv1d kernels; 8d-4 the chunked prefill;
8d-5 the same wiring for `tp/tp_qwen3_5.rs`.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -56,7 +56,7 @@
|
|||||||
//! prefill; can be added later without changing the surface.
|
//! prefill; can be added later without changing the surface.
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use candle_core::{IndexOp, Module, Tensor};
|
use candle_core::{Module, Tensor};
|
||||||
use candle_nn::Linear;
|
use candle_nn::Linear;
|
||||||
use candle_nn::var_builder::ShardedVarBuilder;
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
|
||||||
@@ -315,7 +315,7 @@ impl GatedDeltaNet {
|
|||||||
let beta = beta.to_dtype(candle_core::DType::F32)?;
|
let beta = beta.to_dtype(candle_core::DType::F32)?;
|
||||||
|
|
||||||
// Initialise the recurrent state from cache or zeros.
|
// Initialise the recurrent state from cache or zeros.
|
||||||
let mut state = match self.state.recurrent_state.take() {
|
let state_init = match self.state.recurrent_state.take() {
|
||||||
Some(s) => s.to_dtype(candle_core::DType::F32)?,
|
Some(s) => s.to_dtype(candle_core::DType::F32)?,
|
||||||
None => Tensor::zeros(
|
None => Tensor::zeros(
|
||||||
(
|
(
|
||||||
@@ -329,48 +329,26 @@ impl GatedDeltaNet {
|
|||||||
)?,
|
)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Per-token delta-rule loop. Slow-but-correct path; chunked
|
// The delta-rule body: cuda-accelerated `gated_delta_rule_recurrence`
|
||||||
// optimisation is for later.
|
// kernel when we have a cuda device + the kernels are linked in,
|
||||||
let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
|
// pure-Rust per-token fallback otherwise.
|
||||||
for t in 0..seq_len {
|
let (core_attn_out, new_state) = run_delta_rule(
|
||||||
// (B, H, D_k) and (B, H, D_v) for token t.
|
&q,
|
||||||
let q_t = q.i((.., .., t, ..))?; // (B, H, D_k)
|
&k,
|
||||||
let k_t = k.i((.., .., t, ..))?;
|
&v,
|
||||||
let v_t = v.i((.., .., t, ..))?;
|
&g,
|
||||||
let g_t = g.i((.., .., t))?; // (B, H)
|
&beta,
|
||||||
let beta_t = beta.i((.., .., t))?; // (B, H)
|
state_init,
|
||||||
|
batch_size,
|
||||||
// Decay: state *= exp(g_t). exp(g_t) shape (B, H) → broadcast to (B, H, 1, 1).
|
self.num_v_heads,
|
||||||
let decay = g_t
|
seq_len,
|
||||||
.exp()?
|
self.head_k_dim,
|
||||||
.unsqueeze(candle_core::D::Minus1)?
|
self.head_v_dim,
|
||||||
.unsqueeze(candle_core::D::Minus1)?; // (B, H, 1, 1)
|
)?;
|
||||||
state = state.broadcast_mul(&decay)?;
|
|
||||||
|
|
||||||
// Memory readout: sum_{d_k} state[d_k, d_v] * k_t[d_k] → (B, H, D_v).
|
|
||||||
// state: (B, H, D_k, D_v); k_t.unsqueeze(-1): (B, H, D_k, 1).
|
|
||||||
let k_col = k_t.unsqueeze(candle_core::D::Minus1)?; // (B, H, D_k, 1)
|
|
||||||
let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?; // sum over D_k → (B, H, D_v)
|
|
||||||
|
|
||||||
// delta = (v_t - kv_mem) * beta_t (broadcast beta on last dim).
|
|
||||||
let beta_col = beta_t.unsqueeze(candle_core::D::Minus1)?; // (B, H, 1)
|
|
||||||
let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?; // (B, H, D_v)
|
|
||||||
|
|
||||||
// state += outer(k_t, delta) = k_col * delta_row, broadcast to (B, H, D_k, D_v).
|
|
||||||
let delta_row = delta.unsqueeze(2)?; // (B, H, 1, D_v)
|
|
||||||
let outer = k_col.broadcast_mul(&delta_row)?; // (B, H, D_k, D_v)
|
|
||||||
state = (state + outer)?;
|
|
||||||
|
|
||||||
// out_t = sum_{d_k} state[d_k, d_v] * q_t[d_k] → (B, H, D_v).
|
|
||||||
let q_col = q_t.unsqueeze(candle_core::D::Minus1)?; // (B, H, D_k, 1)
|
|
||||||
let out_t = state.broadcast_mul(&q_col)?.sum(2)?; // (B, H, D_v)
|
|
||||||
outputs.push(out_t.unsqueeze(2)?); // (B, H, 1, D_v)
|
|
||||||
}
|
|
||||||
// Stash the updated recurrent state for the next call.
|
// Stash the updated recurrent state for the next call.
|
||||||
self.state.recurrent_state = Some(state.to_dtype(dtype)?);
|
self.state.recurrent_state = Some(new_state.to_dtype(dtype)?);
|
||||||
|
|
||||||
// core_attn_out: (B, H, L, D_v) → (B, L, H, D_v) → (B*L*H, D_v).
|
// core_attn_out: (B, H, L, D_v) → (B, L, H, D_v) → (B*L*H, D_v).
|
||||||
let core_attn_out = Tensor::cat(&outputs, 2)?; // (B, H, L, D_v)
|
|
||||||
let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; // (B, L, H, D_v)
|
let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; // (B, L, H, D_v)
|
||||||
let core_attn_out = core_attn_out.to_dtype(dtype)?;
|
let core_attn_out = core_attn_out.to_dtype(dtype)?;
|
||||||
let core_attn_flat =
|
let core_attn_flat =
|
||||||
@@ -386,6 +364,126 @@ impl GatedDeltaNet {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Run the per-token delta-rule recurrence.
|
||||||
|
///
|
||||||
|
/// `q`, `k`: `(B, H, L, D_k)` (F32). `v`: `(B, H, L, D_v)`. `g`,
|
||||||
|
/// `beta`: `(B, H, L)`. `state`: `(B, H, D_k, D_v)`.
|
||||||
|
///
|
||||||
|
/// Returns `(core_attn_out: (B, H, L, D_v), state: (B, H, D_k, D_v))`,
|
||||||
|
/// both F32. Caller is responsible for cast back to model dtype.
|
||||||
|
///
|
||||||
|
/// Cuda path: dispatches to the `gated_delta_rule_recurrence` kernel
|
||||||
|
/// ported from `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/gdn.cu`.
|
||||||
|
/// All five inputs must be cuda f32 tensors. The kernel is V-tiled
|
||||||
|
/// with compile-time BK; one block per (V-tile, batch*head) and one
|
||||||
|
/// thread per V-column. Each thread holds BK state floats in
|
||||||
|
/// registers — eliminates the launch-overhead floor we hit with
|
||||||
|
/// candle's per-op dispatch (was ~12s/token on Qwen3.6-27B).
|
||||||
|
///
|
||||||
|
/// CPU path: pure-Rust per-token loop. Correct, slow.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn run_delta_rule(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
state: Tensor,
|
||||||
|
batch_size: usize,
|
||||||
|
num_heads: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
head_k_dim: usize,
|
||||||
|
head_v_dim: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
// Only dispatch to the kernel if the inputs are on a CUDA
|
||||||
|
// device — CPU tests fall back to the Rust loop below.
|
||||||
|
if q.device().is_cuda() {
|
||||||
|
return run_delta_rule_cuda(
|
||||||
|
q, k, v, g, beta, state, batch_size, num_heads, seq_len, head_k_dim, head_v_dim,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let _ = (batch_size, num_heads, head_k_dim, head_v_dim);
|
||||||
|
run_delta_rule_rust(q, k, v, g, beta, state, seq_len)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CUDA path. Flattens (B, H, ...) → (BH, ...) at the kernel boundary
|
||||||
|
/// (the kernel uses BH = batch*heads as its outer batch axis) and
|
||||||
|
/// reshapes the kernel's outputs back to (B, H, ...) for the caller.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn run_delta_rule_cuda(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
state: Tensor,
|
||||||
|
batch_size: usize,
|
||||||
|
num_heads: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
head_k_dim: usize,
|
||||||
|
head_v_dim: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let q_bh = q.flatten(0, 1)?.contiguous()?;
|
||||||
|
let k_bh = k.flatten(0, 1)?.contiguous()?;
|
||||||
|
let v_bh = v.flatten(0, 1)?.contiguous()?;
|
||||||
|
let g_bh = g.flatten(0, 1)?.contiguous()?;
|
||||||
|
let beta_bh = beta.flatten(0, 1)?.contiguous()?;
|
||||||
|
let mut state_bh = state.flatten(0, 1)?.contiguous()?;
|
||||||
|
let output_bh = crate::cuda::gdn::gated_delta_rule_recurrence_cuda(
|
||||||
|
&q_bh,
|
||||||
|
&k_bh,
|
||||||
|
&v_bh,
|
||||||
|
&g_bh,
|
||||||
|
&beta_bh,
|
||||||
|
&mut state_bh,
|
||||||
|
)?;
|
||||||
|
let core_attn_out = output_bh.reshape((batch_size, num_heads, seq_len, head_v_dim))?;
|
||||||
|
let new_state = state_bh.reshape((batch_size, num_heads, head_k_dim, head_v_dim))?;
|
||||||
|
Ok((core_attn_out, new_state))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn run_delta_rule_rust(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
mut state: Tensor,
|
||||||
|
seq_len: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
use candle_core::IndexOp;
|
||||||
|
let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
|
||||||
|
for t in 0..seq_len {
|
||||||
|
let q_t = q.i((.., .., t, ..))?;
|
||||||
|
let k_t = k.i((.., .., t, ..))?;
|
||||||
|
let v_t = v.i((.., .., t, ..))?;
|
||||||
|
let g_t = g.i((.., .., t))?;
|
||||||
|
let beta_t = beta.i((.., .., t))?;
|
||||||
|
let decay = g_t
|
||||||
|
.exp()?
|
||||||
|
.unsqueeze(candle_core::D::Minus1)?
|
||||||
|
.unsqueeze(candle_core::D::Minus1)?;
|
||||||
|
state = state.broadcast_mul(&decay)?;
|
||||||
|
let k_col = k_t.unsqueeze(candle_core::D::Minus1)?;
|
||||||
|
let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?;
|
||||||
|
let beta_col = beta_t.unsqueeze(candle_core::D::Minus1)?;
|
||||||
|
let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?;
|
||||||
|
let delta_row = delta.unsqueeze(2)?;
|
||||||
|
let outer = k_col.broadcast_mul(&delta_row)?;
|
||||||
|
state = (state + outer)?;
|
||||||
|
let q_col = q_t.unsqueeze(candle_core::D::Minus1)?;
|
||||||
|
let out_t = state.broadcast_mul(&q_col)?.sum(2)?;
|
||||||
|
outputs.push(out_t.unsqueeze(2)?);
|
||||||
|
}
|
||||||
|
let core_attn_out = Tensor::cat(&outputs, 2)?; // (B, H, L, D_v)
|
||||||
|
Ok((core_attn_out, state))
|
||||||
|
}
|
||||||
|
|
||||||
/// Load a no-bias linear from the ShardedVarBuilder. Weight shape is
|
/// Load a no-bias linear from the ShardedVarBuilder. Weight shape is
|
||||||
/// the standard `[out, in]` order.
|
/// the standard `[out, in]` order.
|
||||||
fn load_linear_no_bias(
|
fn load_linear_no_bias(
|
||||||
|
|||||||
Reference in New Issue
Block a user