perf(neuron): chunked delta-rule prefill for Gated DeltaNet (#23) #39

Merged
grenade merged 2 commits from perf/23-chunked-gdn-prefill into main 2026-06-12 18:44:23 +00:00

Prefill (seq_len ≥ 64) now runs the chunk-parallel gated delta rule, a faithful port of HF's torch_chunk_gated_delta_rule (chunk_size = 64) minus the steps our caller already performs (q/k L2-norm, q pre-scaling). Decode steps and short prompts keep the recurrent paths (CUDA kernel / Rust loop) untouched, per the issue's "recurrent path retained for decode".
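Below is a minimal sketch of the dispatch rule just described, with the `NEURON_GDN_CHUNKED=0` override from the design notes folded in. `choose_path` and `DeltaRulePath` are hypothetical names. The real decision sits inside `run_delta_rule`, whose code is not shown here. The sketch assumes that padding up to a whole number of 64-token chunks follows the HF reference. The override value is passed in rather than read from the environment, which keeps the rule testable.

```rust
/// Chunk length of the chunked path (HF's chunk_size = 64).
const CHUNK: usize = 64;

/// Which delta-rule implementation a forward pass would take (hypothetical).
#[derive(Debug, PartialEq, Eq)]
enum DeltaRulePath {
    /// Decode steps and short prompts: CUDA kernel / Rust loop.
    Recurrent,
    /// Prefill: `n_chunks` chunks of CHUNK tokens, after adding `pad` padding tokens.
    Chunked { n_chunks: usize, pad: usize },
}

/// Hypothetical stand-in for the seq_len test inside `run_delta_rule`.
/// `chunked_env` is the value of NEURON_GDN_CHUNKED if set; real code would
/// obtain it with `std::env::var("NEURON_GDN_CHUNKED").ok()`.
fn choose_path(seq_len: usize, chunked_env: Option<&str>) -> DeltaRulePath {
    if chunked_env == Some("0") || seq_len < CHUNK {
        return DeltaRulePath::Recurrent;
    }
    // Pad the sequence up to a whole number of chunks (0 if already aligned).
    let pad = (CHUNK - seq_len % CHUNK) % CHUNK;
    DeltaRulePath::Chunked {
        n_chunks: (seq_len + pad) / CHUNK,
        pad,
    }
}
```

For the L=130 parity case this yields three chunks with 62 padding tokens. A decode step (seq_len = 1) always stays on the recurrent path.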

Design notes

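  • One deviation from the reference: its in-place row-by-row UT transform computes (I − T)⁻¹ − I by forward substitution over 63 sequential row updates. T is strictly lower-triangular, hence nilpotent at chunk size 64 (T⁶⁴ = 0), so (I − T)⁻¹ = Σ_{k<64} T^k = Π_{j=0..5}(I + T^(2^j)). That is a six-factor product of batched matmuls (five squarings of T plus the running product), which suits candle's immutable tensors and parallelises across every chunk at once. Parity tests pin the equivalence. Superseded before merge by 812d191e50: the squaring form produced NaN logits on real prompts, so the merged code ports the reference's forward substitution.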
  • Chunk-local math runs rank-3 over a flattened B·H·N batch dim (candle's matmul supports at most two batch dims); the inter-chunk recurrence is a handful of rank-4 matmuls per 64 tokens.
  • Initial-state continuation is supported (the reference's initial_state path), so chunked prefill composes with #11's restored prefix snapshots — a restored conversation's divergent suffix prefills chunked too.
  • Both single-GPU and TP paths pick this up through the shared run_delta_rule dispatch (the TP shard forward calls the same function). NEURON_GDN_CHUNKED=0 forces the recurrent paths for live A/B.
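Below is a minimal sketch of the two forms of (I − T)⁻¹ from the first design note, written on plain f32 matrices rather than candle tensors. `Mat`, `inverse_forward_sub`, `inverse_squaring` and `max_abs_diff` are illustrative names, not the crate's. `inverse_forward_sub` mirrors the reference's row-by-row loop: it solves M = T + T·M and returns I + M = (I − T)⁻¹. `inverse_squaring` computes the telescoping product. On mild data the two agree to rounding. The follow-up commit 812d191e50 found that the squaring form loses everything in f32 when keys are highly correlated. The idealised adversarial matrix in the test (every lower entry −1) reproduces that failure.

```rust
/// Row-major square matrix in f32, the precision the prefill path runs in.
type Mat = Vec<Vec<f32>>;

fn identity(n: usize) -> Mat {
    (0..n)
        .map(|i| (0..n).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
        .collect()
}

fn matmul(a: &Mat, b: &Mat) -> Mat {
    let n = a.len();
    let mut c = vec![vec![0.0f32; n]; n];
    for i in 0..n {
        for k in 0..n {
            for j in 0..n {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    c
}

/// Reference-style UT transform for strictly lower-triangular T: solve
/// M = T + T·M row by row (row i reads only finished rows < i), return I + M.
fn inverse_forward_sub(t: &Mat) -> Mat {
    let n = t.len();
    let mut m = t.clone();
    for i in 1..n {
        for j in 0..i {
            let mut acc = t[i][j];
            for k in (j + 1)..i {
                acc += t[i][k] * m[k][j];
            }
            m[i][j] = acc;
        }
    }
    for (i, row) in m.iter_mut().enumerate() {
        row[i] += 1.0;
    }
    m
}

/// Nilpotent-squaring form: Π_{j<s} (I + T^(2^j)) = Σ_{k<2^s} T^k, which is
/// (I − T)⁻¹ once T^(2^s) = 0 (s = 6 for a 64×64 strictly lower T).
/// Note that it materialises the raw powers T^(2^j).
fn inverse_squaring(t: &Mat, s: u32) -> Mat {
    let n = t.len();
    let mut result = identity(n);
    let mut power = t.clone(); // holds T^(2^j)
    for j in 0..s {
        let mut factor = power.clone();
        for (i, row) in factor.iter_mut().enumerate() {
            row[i] += 1.0; // I + T^(2^j)
        }
        result = matmul(&result, &factor);
        if j + 1 < s {
            power = matmul(&power, &power);
        }
    }
    result
}

/// Largest |a - b| over all entries; a NaN anywhere makes the result NaN.
fn max_abs_diff(a: &Mat, b: &Mat) -> f32 {
    a.iter()
        .zip(b)
        .flat_map(|(ra, rb)| ra.iter().zip(rb).map(|(x, y)| (x - y).abs()))
        .fold(0.0, |m: f32, d: f32| if d.is_nan() || d > m { d } else { m })
}
```

In exact arithmetic the two functions are interchangeable. The difference is what the intermediates look like. Forward substitution only ever holds entries of M, while the squaring form holds T, T², …, T³², which is where the f32 range runs out.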

Parity (the issue's deliverable)

The chunked_matches_recurrent_* tests pin chunked against recurrent to 2e-4 absolute across three cases: padding (L=130 = 2×64+2), exact multiples continuing from a non-zero initial state (L=128 after a 50-token recurrent prefix), and a single exact chunk (L=64, B=2, H=3). Follow-up commit 812d191e50 adds an adversarial case (near-identical keys, β ≈ 1).
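The parity tests measure the chunked path against the recurrent one. Below is a minimal single-head sketch of that recurrence. It assumes the same shape as HF's `torch_recurrent_gated_delta_rule`, with q/k already L2-normalised and q pre-scaled, as the caller does here. Each token does four things: decay the state by exp(g_t), compute the delta-rule correction β_t(v_t − Sᵀk_t), apply it as a rank-1 write, and read out with q_t. `recurrent_gdn` is a hypothetical name, not the crate's API. Threading the state through the call is what "continuing from a non-zero initial state" means: the state after the 50-token prefix seeds the remaining 78 tokens.

```rust
/// One head of the recurrent gated delta rule (hypothetical helper).
/// Shapes: q, k are [len][dk], v is [len][dv], g and beta are [len];
/// `state` is the dk × dv memory S, updated in place so a later call can
/// continue from it.
fn recurrent_gdn(
    q: &[Vec<f32>],
    k: &[Vec<f32>],
    v: &[Vec<f32>],
    g: &[f32],
    beta: &[f32],
    state: &mut [Vec<f32>],
) -> Vec<Vec<f32>> {
    let (dk, dv) = (state.len(), state[0].len());
    let mut out: Vec<Vec<f32>> = Vec::with_capacity(q.len());
    for t in 0..q.len() {
        // Gate: S <- exp(g_t) * S
        let decay = g[t].exp();
        for row in state.iter_mut() {
            for x in row.iter_mut() {
                *x *= decay;
            }
        }
        // Delta: d = beta_t * (v_t - S^T k_t), the error of the memory's prediction for v_t
        let mut delta = vec![0.0f32; dv];
        for (j, d) in delta.iter_mut().enumerate() {
            let predicted: f32 = (0..dk).map(|i| state[i][j] * k[t][i]).sum();
            *d = beta[t] * (v[t][j] - predicted);
        }
        // Rank-1 write: S <- S + k_t d^T
        for i in 0..dk {
            for j in 0..dv {
                state[i][j] += k[t][i] * delta[j];
            }
        }
        // Read-out: o_t = S^T q_t
        let o: Vec<f32> = (0..dv)
            .map(|j| (0..dk).map(|i| state[i][j] * q[t][i]).sum::<f32>())
            .collect();
        out.push(o);
    }
    out
}
```

Splitting a recurrent run performs exactly the same per-token operations, so a 50 + 78 split matches a single 128-token run bit for bit. The chunked path reorders the floating-point work, which is presumably why its parity is stated as a tolerance (2e-4) rather than exact equality.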

Before/after numbers from the #22 harness will follow after deploy. The before-number is the 27B's 8.07 s cold prefill at ~5k tokens; cold cells are benchmarked with cache-miss prompts.

🤖 Generated with Claude Code

grenade added 1 commit 2026-06-12 17:52:16 +00:00
perf(neuron): chunked delta-rule prefill for Gated DeltaNet (#23)
All checks were successful
CI / Format (push) Successful in 32s
CI / Format (pull_request) Successful in 24s
CI / CUDA type-check (push) Successful in 1m38s
CI / CUDA type-check (pull_request) Successful in 2m10s
CI / Clippy (push) Successful in 2m34s
CI / Test (push) Successful in 4m20s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
CI / Clippy (pull_request) Successful in 2m29s
CI / Test (pull_request) Successful in 4m21s
CI / Build cortex SRPM (pull_request) Has been skipped
CI / Build neuron SRPM (pull_request) Has been skipped
CI / Publish cortex to COPR (pull_request) Has been skipped
CI / Publish neuron to COPR (pull_request) Has been skipped
CI / Bump version in source (pull_request) Has been skipped
2a9def6d2d
Prefill (seq_len >= 64) now runs the chunk-parallel gated delta rule
ported from the HF reference torch_chunk_gated_delta_rule
(chunk_size=64): identical math reorganised into per-chunk batched
matmuls (cuBLAS/tensor cores on CUDA, gemm on CPU) instead of the
O(L)-sequential per-token recurrence. Decode steps and short prompts
keep the recurrent paths (CUDA kernel / Rust loop) unchanged.

One deliberate deviation from the reference: its in-place row-by-row
UT-transform computes (I - T)^-1 - I by forward substitution; T is
strictly lower triangular and therefore nilpotent at chunk size 64,
so (I - T)^-1 itself is the six-factor product
prod_{j=0..5}(I + T^(2^j)): batched matmuls instead of 63 sequential
row updates, which suits candle's immutable tensors. Chunk-local math
runs rank-3 over a flattened B*H*N batch dim (candle matmul supports
at most two batch dims).

Initial-state continuation is supported, so chunked prefill composes
with #11's restored prefix snapshots. Both single-GPU and TP paths
pick this up through the shared run_delta_rule dispatch.
NEURON_GDN_CHUNKED=0 forces the recurrent paths for A/B measurement.

Parity tests pin chunked against recurrent (2e-4 abs) across padding
(L=130), exact multiples with non-zero initial state (L=128 after a
50-token prefix), and a single exact chunk.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
grenade added 1 commit 2026-06-12 18:18:36 +00:00
fix(neuron): UT transform by forward substitution, not nilpotent squaring
All checks were successful
CI / Format (push) Successful in 32s
CI / Format (pull_request) Successful in 53s
CI / CUDA type-check (push) Successful in 1m52s
CI / CUDA type-check (pull_request) Successful in 2m12s
CI / Clippy (push) Successful in 2m18s
CI / Clippy (pull_request) Successful in 2m36s
CI / Test (push) Successful in 4m18s
CI / Test (pull_request) Successful in 4m22s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
CI / Build cortex SRPM (pull_request) Has been skipped
CI / Publish cortex to COPR (pull_request) Has been skipped
CI / Build neuron SRPM (pull_request) Has been skipped
CI / Publish neuron to COPR (pull_request) Has been skipped
CI / Bump version in source (pull_request) Has been skipped
812d191e50
Live A/B on beast produced NaN logits ("!!!" replies) on real prompts:
the nilpotent-squaring form of (I - T)^-1 computes raw powers of T,
whose entries grow combinatorially (path counts ~ C(62,31)) before
nilpotency collapses them — fine on uncorrelated test data, f32
precision death on real prompts whose repetitive text makes keys
highly correlated. The reference's forward-substitution loop never
forms raw powers; its intermediates are the convergent M entries.

Port the reference loop faithfully (rows accumulate into a fresh
tensor). New adversarial parity test with near-identical keys and
beta ~= 1 diverges to 8e30 under the squaring form and passes under
forward substitution.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
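The growth described in the fix commit can be worked out exactly for an idealised version of the adversarial case. Assume identical unit keys, β = 1 and no decay. Then every strictly-lower entry of T is −1, so T = −J, where J is the all-ones strictly lower-triangular 64×64 matrix. This is an illustrative derivation under that idealisation, not the actual test data. Entries of J^k count strictly decreasing index paths of length k:

```latex
(J^k)_{ij} = \#\{\, i = a_0 > a_1 > \dots > a_k = j \,\} = \binom{i-j-1}{k-1},
\qquad (T^k)_{ij} = (-1)^k \binom{i-j-1}{k-1},

\max_{i,j} \bigl| (T^{32})_{ij} \bigr| = \binom{62}{31} \approx 4.65 \times 10^{17} \gg 2^{24} \approx 1.68 \times 10^{7},

(I - T)^{-1} = (I + J)^{-1} = I - S, \qquad S_{ij} = [\, i = j + 1 \,].
```

So the true inverse has entries in {−1, 0, 1}. The squaring form reaches it only through cancellation of terms near 10^17, where a single f32 rounding step is 2^35 ≈ 3.4 × 10^10. The forward-substitution intermediates (the entries of M = −S and their partial sums) never leave {−1, 0}.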
grenade merged commit 128b3818cb into main 2026-06-12 18:44:23 +00:00
grenade deleted branch perf/23-chunked-gdn-prefill 2026-06-12 18:44:23 +00:00