feat(stage-8d-4): dispatch chunked_gated_delta_rule_recurrence at prefill
Some checks failed
build-prerelease / Build cortex binary (push) Blocked by required conditions
CI / Test (push) Waiting to run
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 44s
CI / Clippy (push) Failing after 52s
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-blackwell (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
run_delta_rule_cuda now picks between the per-token kernel and the BT=64 chunked variant based on seq_len. Threshold = 64 matches mistralrs. Prefill on Qwen3.6-27B (typical seq_len in the hundreds) drops from one block-launch per token to one per 64-token chunk.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -406,14 +406,29 @@ fn run_delta_rule_cuda(
     let g_bh = g.flatten(0, 1)?.contiguous()?;
     let beta_bh = beta.flatten(0, 1)?.contiguous()?;
     let mut state_bh = state.flatten(0, 1)?.contiguous()?;
-    let output_bh = crate::cuda::gdn::gated_delta_rule_recurrence_cuda(
+    // For long prefills, the chunked kernel (BT=64) processes a chunk
+    // of tokens at a time instead of one-by-one — same delta-rule math,
+    // far fewer block launches. Threshold matches mistralrs.
+    const CHUNK_THRESHOLD: usize = 64;
+    let output_bh = if seq_len >= CHUNK_THRESHOLD {
+        crate::cuda::gdn::chunked_gated_delta_rule_recurrence_cuda(
             &q_bh,
             &k_bh,
             &v_bh,
             &g_bh,
             &beta_bh,
             &mut state_bh,
-    )?;
+        )?
+    } else {
+        crate::cuda::gdn::gated_delta_rule_recurrence_cuda(
+            &q_bh,
+            &k_bh,
+            &v_bh,
+            &g_bh,
+            &beta_bh,
+            &mut state_bh,
+        )?
+    };
     let core_attn_out = output_bh.reshape((batch_size, num_heads, seq_len, head_v_dim))?;
     let new_state = state_bh.reshape((batch_size, num_heads, head_k_dim, head_v_dim))?;
     Ok((core_attn_out, new_state))
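The dispatch rule in this hunk can be sketched in isolation, without CUDA or tensors. The snippet below is a minimal illustration, not code from this repository: `DeltaRuleKernel`, `select_kernel`, `sequential_steps`, and `CHUNK_SIZE` are hypothetical names, and rounding a partial final chunk up to one full step is an assumption about how the chunked kernel handles a ragged tail. Only the threshold of 64 and the `>=` comparison come from the diff.

```rust
/// Threshold from the diff: sequences at least this long take the chunked path.
const CHUNK_THRESHOLD: usize = 64;
/// Chunk length (BT) named in the commit message.
const CHUNK_SIZE: usize = 64;

/// Hypothetical enum naming the two kernels the dispatch chooses between.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DeltaRuleKernel {
    PerToken,
    Chunked,
}

/// Mirrors the `if seq_len >= CHUNK_THRESHOLD` branch in `run_delta_rule_cuda`.
fn select_kernel(seq_len: usize) -> DeltaRuleKernel {
    if seq_len >= CHUNK_THRESHOLD {
        DeltaRuleKernel::Chunked
    } else {
        DeltaRuleKernel::PerToken
    }
}

/// Number of sequential steps over the sequence: one per token on the
/// per-token path, one per chunk on the chunked path.
/// Assumption: a partial final chunk still costs one full step (ceiling division).
fn sequential_steps(seq_len: usize) -> usize {
    match select_kernel(seq_len) {
        DeltaRuleKernel::PerToken => seq_len,
        DeltaRuleKernel::Chunked => (seq_len + CHUNK_SIZE - 1) / CHUNK_SIZE,
    }
}
```

Under these assumptions a 512-token prefill takes 8 steps instead of 512, while single-token decode (`seq_len == 1`) stays on the per-token kernel. Because the comparison is `>=`, a sequence of exactly 64 tokens already takes the chunked path.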