diff --git a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs index dce17ed..6bc1190 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs @@ -382,7 +382,7 @@ impl GatedDeltaNet { /// /// CPU path: pure-Rust per-token loop. Correct, slow. #[allow(clippy::too_many_arguments)] -fn run_delta_rule( +pub(crate) fn run_delta_rule( q: &Tensor, k: &Tensor, v: &Tensor, diff --git a/crates/neuron/src/harness/tp/tp_qwen3_5.rs b/crates/neuron/src/harness/tp/tp_qwen3_5.rs index 1b45235..4765dd0 100644 --- a/crates/neuron/src/harness/tp/tp_qwen3_5.rs +++ b/crates/neuron/src/harness/tp/tp_qwen3_5.rs @@ -31,7 +31,7 @@ //! linear-attention block, lm_head, the rotary table. use anyhow::{Context, Result, bail}; -use candle_core::{D, DType, Device, IndexOp, Module, Tensor}; +use candle_core::{DType, Device, IndexOp, Module, Tensor}; use candle_nn::var_builder::ShardedVarBuilder; use candle_nn::{Embedding, Linear, kv_cache::ConcatKvCache}; use candle_transformers::utils::repeat_kv; @@ -343,7 +343,7 @@ impl TpQwen3_5GatedDeltaNet { let v = v.to_dtype(DType::F32)?; let beta = beta.to_dtype(DType::F32)?; - let mut state = match self.state.recurrent_state.take() { + let state_init = match self.state.recurrent_state.take() { Some(s) => s.to_dtype(DType::F32)?, None => Tensor::zeros( ( @@ -357,34 +357,27 @@ impl TpQwen3_5GatedDeltaNet { )?, }; - let mut outputs: Vec = Vec::with_capacity(seq_len); - for t in 0..seq_len { - let q_t = q.i((.., .., t, ..))?; - let k_t = k.i((.., .., t, ..))?; - let v_t = v.i((.., .., t, ..))?; - let g_t = g.i((.., .., t))?; - let beta_t = beta.i((.., .., t))?; + // Hand off to the shared delta-rule runner — same cuda-kernel + // dispatch as the single-GPU `arch::qwen3_5::linear_attn`, just + // with per-rank head counts. CPU path falls back to a per-token + // Rust loop; cuda path is the V-tiled register-resident kernel + // imported from mistralrs. + let (core_attn_out, new_state) = + crate::harness::arch::qwen3_5::linear_attn::run_delta_rule( + &q, + &k, + &v, + &g, + &beta, + state_init, + batch_size, + self.per_rank_num_v_heads, + seq_len, + self.head_k_dim, + self.head_v_dim, + )?; + self.state.recurrent_state = Some(new_state.to_dtype(dtype)?); - let decay = g_t.exp()?.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?; - state = state.broadcast_mul(&decay)?; - - let k_col = k_t.unsqueeze(D::Minus1)?; - let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?; - - let beta_col = beta_t.unsqueeze(D::Minus1)?; - let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?; - - let delta_row = delta.unsqueeze(2)?; - let outer = k_col.broadcast_mul(&delta_row)?; - state = (state + outer)?; - - let q_col = q_t.unsqueeze(D::Minus1)?; - let out_t = state.broadcast_mul(&q_col)?.sum(2)?; - outputs.push(out_t.unsqueeze(2)?); - } - self.state.recurrent_state = Some(state.to_dtype(dtype)?); - - let core_attn_out = Tensor::cat(&outputs, 2)?; let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; let core_attn_out = core_attn_out.to_dtype(dtype)?; let core_attn_flat = core_attn_out.reshape((