feat(stage-8d-5): wire gated_delta_rule_recurrence kernel into tp_qwen3_5
Some checks failed
build-prerelease / Package cortex RPM (push) Blocked by required conditions
build-prerelease / Resolve version stamps (push) Successful in 36s
CI / Format (push) Successful in 39s
CI / Clippy (push) Successful in 2m21s
build-prerelease / Build neuron-blackwell (push) Successful in 3m36s
CI / Test (push) Successful in 4m39s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m34s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
Some checks failed
build-prerelease / Package cortex RPM (push) Blocked by required conditions
build-prerelease / Resolve version stamps (push) Successful in 36s
CI / Format (push) Successful in 39s
CI / Clippy (push) Successful in 2m21s
build-prerelease / Build neuron-blackwell (push) Successful in 3m36s
CI / Test (push) Successful in 4m39s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m34s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
The TP per-token Rust loop is replaced with the shared run_delta_rule dispatch from arch/qwen3_5/linear_attn.rs. Both the single-GPU and TP variants now use the CUDA kernel when it is available and fall back to the per-token Rust loop otherwise.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -382,7 +382,7 @@ impl GatedDeltaNet {
|
|||||||
///
|
///
|
||||||
/// CPU path: pure-Rust per-token loop. Correct, slow.
|
/// CPU path: pure-Rust per-token loop. Correct, slow.
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn run_delta_rule(
|
pub(crate) fn run_delta_rule(
|
||||||
q: &Tensor,
|
q: &Tensor,
|
||||||
k: &Tensor,
|
k: &Tensor,
|
||||||
v: &Tensor,
|
v: &Tensor,
|
||||||
|
|||||||
@@ -31,7 +31,7 @@
|
|||||||
//! linear-attention block, lm_head, the rotary table.
|
//! linear-attention block, lm_head, the rotary table.
|
||||||
|
|
||||||
use anyhow::{Context, Result, bail};
|
use anyhow::{Context, Result, bail};
|
||||||
use candle_core::{D, DType, Device, IndexOp, Module, Tensor};
|
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
||||||
use candle_nn::var_builder::ShardedVarBuilder;
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
use candle_nn::{Embedding, Linear, kv_cache::ConcatKvCache};
|
use candle_nn::{Embedding, Linear, kv_cache::ConcatKvCache};
|
||||||
use candle_transformers::utils::repeat_kv;
|
use candle_transformers::utils::repeat_kv;
|
||||||
@@ -343,7 +343,7 @@ impl TpQwen3_5GatedDeltaNet {
|
|||||||
let v = v.to_dtype(DType::F32)?;
|
let v = v.to_dtype(DType::F32)?;
|
||||||
let beta = beta.to_dtype(DType::F32)?;
|
let beta = beta.to_dtype(DType::F32)?;
|
||||||
|
|
||||||
let mut state = match self.state.recurrent_state.take() {
|
let state_init = match self.state.recurrent_state.take() {
|
||||||
Some(s) => s.to_dtype(DType::F32)?,
|
Some(s) => s.to_dtype(DType::F32)?,
|
||||||
None => Tensor::zeros(
|
None => Tensor::zeros(
|
||||||
(
|
(
|
||||||
@@ -357,34 +357,27 @@ impl TpQwen3_5GatedDeltaNet {
|
|||||||
)?,
|
)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
|
// Hand off to the shared delta-rule runner — same cuda-kernel
|
||||||
for t in 0..seq_len {
|
// dispatch as the single-GPU `arch::qwen3_5::linear_attn`, just
|
||||||
let q_t = q.i((.., .., t, ..))?;
|
// with per-rank head counts. CPU path falls back to a per-token
|
||||||
let k_t = k.i((.., .., t, ..))?;
|
// Rust loop; cuda path is the V-tiled register-resident kernel
|
||||||
let v_t = v.i((.., .., t, ..))?;
|
// imported from mistralrs.
|
||||||
let g_t = g.i((.., .., t))?;
|
let (core_attn_out, new_state) =
|
||||||
let beta_t = beta.i((.., .., t))?;
|
crate::harness::arch::qwen3_5::linear_attn::run_delta_rule(
|
||||||
|
&q,
|
||||||
|
&k,
|
||||||
|
&v,
|
||||||
|
&g,
|
||||||
|
&beta,
|
||||||
|
state_init,
|
||||||
|
batch_size,
|
||||||
|
self.per_rank_num_v_heads,
|
||||||
|
seq_len,
|
||||||
|
self.head_k_dim,
|
||||||
|
self.head_v_dim,
|
||||||
|
)?;
|
||||||
|
self.state.recurrent_state = Some(new_state.to_dtype(dtype)?);
|
||||||
|
|
||||||
let decay = g_t.exp()?.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?;
|
|
||||||
state = state.broadcast_mul(&decay)?;
|
|
||||||
|
|
||||||
let k_col = k_t.unsqueeze(D::Minus1)?;
|
|
||||||
let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?;
|
|
||||||
|
|
||||||
let beta_col = beta_t.unsqueeze(D::Minus1)?;
|
|
||||||
let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?;
|
|
||||||
|
|
||||||
let delta_row = delta.unsqueeze(2)?;
|
|
||||||
let outer = k_col.broadcast_mul(&delta_row)?;
|
|
||||||
state = (state + outer)?;
|
|
||||||
|
|
||||||
let q_col = q_t.unsqueeze(D::Minus1)?;
|
|
||||||
let out_t = state.broadcast_mul(&q_col)?.sum(2)?;
|
|
||||||
outputs.push(out_t.unsqueeze(2)?);
|
|
||||||
}
|
|
||||||
self.state.recurrent_state = Some(state.to_dtype(dtype)?);
|
|
||||||
|
|
||||||
let core_attn_out = Tensor::cat(&outputs, 2)?;
|
|
||||||
let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?;
|
let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?;
|
||||||
let core_attn_out = core_attn_out.to_dtype(dtype)?;
|
let core_attn_out = core_attn_out.to_dtype(dtype)?;
|
||||||
let core_attn_flat = core_attn_out.reshape((
|
let core_attn_flat = core_attn_out.reshape((
|
||||||
|
|||||||
Reference in New Issue
Block a user