diff --git a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs index 8ddac01..5c3cb2c 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs @@ -406,14 +406,29 @@ fn run_delta_rule_cuda( let g_bh = g.flatten(0, 1)?.contiguous()?; let beta_bh = beta.flatten(0, 1)?.contiguous()?; let mut state_bh = state.flatten(0, 1)?.contiguous()?; - let output_bh = crate::cuda::gdn::gated_delta_rule_recurrence_cuda( - &q_bh, - &k_bh, - &v_bh, - &g_bh, - &beta_bh, - &mut state_bh, - )?; + // For long prefills, the chunked kernel (BT=64) processes a chunk + // of tokens at a time instead of one-by-one — same delta-rule math, + // far fewer block launches. Threshold matches mistralrs. + const CHUNK_THRESHOLD: usize = 64; + let output_bh = if seq_len >= CHUNK_THRESHOLD { + crate::cuda::gdn::chunked_gated_delta_rule_recurrence_cuda( + &q_bh, + &k_bh, + &v_bh, + &g_bh, + &beta_bh, + &mut state_bh, + )? + } else { + crate::cuda::gdn::gated_delta_rule_recurrence_cuda( + &q_bh, + &k_bh, + &v_bh, + &g_bh, + &beta_bh, + &mut state_bh, + )? + }; let core_attn_out = output_bh.reshape((batch_size, num_heads, seq_len, head_v_dim))?; let new_state = state_bh.reshape((batch_size, num_heads, head_k_dim, head_v_dim))?; Ok((core_attn_out, new_state))