feat(stage-8d-5): wire gated_delta_rule_recurrence kernel into tp_qwen3_5
Some checks failed
build-prerelease / Package cortex RPM (push) Blocked by required conditions
build-prerelease / Resolve version stamps (push) Successful in 36s
CI / Format (push) Successful in 39s
CI / Clippy (push) Successful in 2m21s
build-prerelease / Build neuron-blackwell (push) Successful in 3m36s
CI / Test (push) Successful in 4m39s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m34s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled

The TP per-token Rust loop is replaced with the shared `run_delta_rule` dispatch
from arch/qwen3_5/linear_attn.rs. Both the single-GPU and TP variants now
use the CUDA kernel when available, and fall back to the per-token Rust loop otherwise.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 11:44:12 +03:00
parent 44ae927e38
commit 10c151efa5
2 changed files with 23 additions and 30 deletions

View File

@@ -382,7 +382,7 @@ impl GatedDeltaNet {
/// ///
/// CPU path: pure-Rust per-token loop. Correct, slow. /// CPU path: pure-Rust per-token loop. Correct, slow.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn run_delta_rule( pub(crate) fn run_delta_rule(
q: &Tensor, q: &Tensor,
k: &Tensor, k: &Tensor,
v: &Tensor, v: &Tensor,

View File

@@ -31,7 +31,7 @@
//! linear-attention block, lm_head, the rotary table. //! linear-attention block, lm_head, the rotary table.
use anyhow::{Context, Result, bail}; use anyhow::{Context, Result, bail};
use candle_core::{D, DType, Device, IndexOp, Module, Tensor}; use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder; use candle_nn::var_builder::ShardedVarBuilder;
use candle_nn::{Embedding, Linear, kv_cache::ConcatKvCache}; use candle_nn::{Embedding, Linear, kv_cache::ConcatKvCache};
use candle_transformers::utils::repeat_kv; use candle_transformers::utils::repeat_kv;
@@ -343,7 +343,7 @@ impl TpQwen3_5GatedDeltaNet {
let v = v.to_dtype(DType::F32)?; let v = v.to_dtype(DType::F32)?;
let beta = beta.to_dtype(DType::F32)?; let beta = beta.to_dtype(DType::F32)?;
let mut state = match self.state.recurrent_state.take() { let state_init = match self.state.recurrent_state.take() {
Some(s) => s.to_dtype(DType::F32)?, Some(s) => s.to_dtype(DType::F32)?,
None => Tensor::zeros( None => Tensor::zeros(
( (
@@ -357,34 +357,27 @@ impl TpQwen3_5GatedDeltaNet {
)?, )?,
}; };
let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len); // Hand off to the shared delta-rule runner — same cuda-kernel
for t in 0..seq_len { // dispatch as the single-GPU `arch::qwen3_5::linear_attn`, just
let q_t = q.i((.., .., t, ..))?; // with per-rank head counts. CPU path falls back to a per-token
let k_t = k.i((.., .., t, ..))?; // Rust loop; cuda path is the V-tiled register-resident kernel
let v_t = v.i((.., .., t, ..))?; // imported from mistralrs.
let g_t = g.i((.., .., t))?; let (core_attn_out, new_state) =
let beta_t = beta.i((.., .., t))?; crate::harness::arch::qwen3_5::linear_attn::run_delta_rule(
&q,
&k,
&v,
&g,
&beta,
state_init,
batch_size,
self.per_rank_num_v_heads,
seq_len,
self.head_k_dim,
self.head_v_dim,
)?;
self.state.recurrent_state = Some(new_state.to_dtype(dtype)?);
let decay = g_t.exp()?.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?;
state = state.broadcast_mul(&decay)?;
let k_col = k_t.unsqueeze(D::Minus1)?;
let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?;
let beta_col = beta_t.unsqueeze(D::Minus1)?;
let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?;
let delta_row = delta.unsqueeze(2)?;
let outer = k_col.broadcast_mul(&delta_row)?;
state = (state + outer)?;
let q_col = q_t.unsqueeze(D::Minus1)?;
let out_t = state.broadcast_mul(&q_col)?.sum(2)?;
outputs.push(out_t.unsqueeze(2)?);
}
self.state.recurrent_state = Some(state.to_dtype(dtype)?);
let core_attn_out = Tensor::cat(&outputs, 2)?;
let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?;
let core_attn_out = core_attn_out.to_dtype(dtype)?; let core_attn_out = core_attn_out.to_dtype(dtype)?;
let core_attn_flat = core_attn_out.reshape(( let core_attn_flat = core_attn_out.reshape((