feat(stage-8d-4): dispatch chunked_gated_delta_rule_recurrence at prefill
Some checks failed
build-prerelease / Build cortex binary (push) Blocked by required conditions
CI / Test (push) Waiting to run
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 44s
CI / Clippy (push) Failing after 52s
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-blackwell (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled

run_delta_rule_cuda now picks between the per-token kernel and the
BT=64 chunked variant based on seq_len. Threshold = 64 matches mistralrs.
Prefill on Qwen3.6-27B (typical seq_len in the hundreds) drops from
one block-launch per token to one per 64-token chunk.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 11:50:30 +03:00
parent 05dc0bad18
commit 09c945f81e

View File

@@ -406,14 +406,29 @@ fn run_delta_rule_cuda(
let g_bh = g.flatten(0, 1)?.contiguous()?; let g_bh = g.flatten(0, 1)?.contiguous()?;
let beta_bh = beta.flatten(0, 1)?.contiguous()?; let beta_bh = beta.flatten(0, 1)?.contiguous()?;
let mut state_bh = state.flatten(0, 1)?.contiguous()?; let mut state_bh = state.flatten(0, 1)?.contiguous()?;
let output_bh = crate::cuda::gdn::gated_delta_rule_recurrence_cuda( // For long prefills, the chunked kernel (BT=64) processes a chunk
// of tokens at a time instead of one-by-one — same delta-rule math,
// far fewer block launches. Threshold matches mistralrs.
const CHUNK_THRESHOLD: usize = 64;
let output_bh = if seq_len >= CHUNK_THRESHOLD {
crate::cuda::gdn::chunked_gated_delta_rule_recurrence_cuda(
&q_bh, &q_bh,
&k_bh, &k_bh,
&v_bh, &v_bh,
&g_bh, &g_bh,
&beta_bh, &beta_bh,
&mut state_bh, &mut state_bh,
)?; )?
} else {
crate::cuda::gdn::gated_delta_rule_recurrence_cuda(
&q_bh,
&k_bh,
&v_bh,
&g_bh,
&beta_bh,
&mut state_bh,
)?
};
let core_attn_out = output_bh.reshape((batch_size, num_heads, seq_len, head_v_dim))?; let core_attn_out = output_bh.reshape((batch_size, num_heads, seq_len, head_v_dim))?;
let new_state = state_bh.reshape((batch_size, num_heads, head_k_dim, head_v_dim))?; let new_state = state_bh.reshape((batch_size, num_heads, head_k_dim, head_v_dim))?;
Ok((core_attn_out, new_state)) Ok((core_attn_out, new_state))