perf(neuron): chunked delta-rule prefill for Gated DeltaNet (#23) #39
Branch: `perf/23-chunked-gdn-prefill`
Prefill (`seq_len ≥ 64`) now runs the chunk-parallel gated delta rule, a faithful port of HF's `torch_chunk_gated_delta_rule` (`chunk_size = 64`) minus the steps our caller already performs (q/k L2-norm, q pre-scaling). Decode steps and short prompts keep the recurrent paths (CUDA kernel / Rust loop) untouched, per the issue's "recurrent path retained for decode".

Design notes
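The dispatch rule above is a threshold check plus an environment override. The sketch below assumes the decision depends only on the sequence length and the raw `NEURON_GDN_CHUNKED` value. `choose_path` and `DeltaRulePath` are hypothetical names, not the actual `run_delta_rule` code. Treating any value other than `0` as "use the default" is also an assumption.

```rust
use std::env;

/// Chunk size of the chunk-parallel path (the reference's chunk_size = 64).
const CHUNK: usize = 64;

/// Which delta-rule implementation a forward pass should take (hypothetical type).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DeltaRulePath {
    Chunked,
    Recurrent,
}

/// Hypothetical dispatch helper. `override_var` is the raw NEURON_GDN_CHUNKED
/// value (None when unset). Only "0" is documented, so any other value is
/// assumed to mean "use the default rule".
fn choose_path(seq_len: usize, override_var: Option<&str>) -> DeltaRulePath {
    if override_var == Some("0") {
        return DeltaRulePath::Recurrent; // forced off, for live A/B
    }
    if seq_len >= CHUNK {
        DeltaRulePath::Chunked // prefill covering at least one full chunk
    } else {
        DeltaRulePath::Recurrent // decode steps and short prompts
    }
}

fn main() {
    let var = env::var("NEURON_GDN_CHUNKED").ok();
    for seq_len in [1, 63, 64, 5000] {
        println!("seq_len {seq_len:>4} -> {:?}", choose_path(seq_len, var.as_deref()));
    }
}
```

Passing the variable's value in, rather than reading it inside the helper, keeps the rule easy to unit-test.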
- The reference's in-place, row-by-row UT-transform computes `(I − T)⁻¹ − I` by forward substitution over 63 sequential row updates. `T` is strictly lower-triangular, hence nilpotent at chunk size 64 (`T^64 = 0`), so `(I − T)⁻¹` equals `Π_{j=0..5}(I + T^(2^j))`. That is a product of six factors, computable with a handful of batched matmuls, which suits candle's immutable tensors and parallelises across every chunk at once. Parity tests pin the equivalence.
- Chunk-local math runs rank-3 over a flattened `B·H·N` batch dim (candle's matmul supports at most two batch dims); the inter-chunk recurrence is a handful of rank-4 matmuls per 64 tokens.
- Initial-state continuation is supported (the `initial_state` path), so chunked prefill composes with #11's restored prefix snapshots: a restored conversation's divergent suffix prefills chunked too.
- Both single-GPU and TP paths pick this up through the shared `run_delta_rule` dispatch (the TP shard forward calls the same function).
- `NEURON_GDN_CHUNKED=0` forces the recurrent paths for live A/B.

Parity (the issue's deliverable)
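To make the parity target concrete, here is a single-head sketch of both paths in plain Rust (no candle, no batching). `recurrent`, `chunked`, `parity_error` and the LCG test data are illustrative, not the crate's code.

The chunked function mirrors the structure of HF's `torch_chunk_gated_delta_rule`: cumulative chunk-local decay, a UT-style correction of the values, then an inter-chunk state carry. It assumes q/k rows arrive L2-normalised. It computes `(I − T)⁻¹ − I` with the reference's forward-substitution loop rather than the squaring form.

`parity_error` runs a recurrent prefix, continues chunked from the state that prefix leaves behind, and reports the largest deviation from an all-recurrent run. A chunk size of 4 stands in for 64 to keep the example small.

```rust
// Single-head, plain-Rust sketch of the gated delta rule; q/k rows are assumed L2-normalised.
type Mat = Vec<Vec<f32>>;

fn dot(a: &[f32], b: &[f32]) -> f32 { a.iter().zip(b).map(|(x, y)| x * y).sum() }

/// Recurrent path: one token at a time, updating the [dk][dv] state `s` in place.
fn recurrent(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>], g: &[f32], beta: &[f32], s: &mut Mat) -> Mat {
    let (dk, dv) = (s.len(), s[0].len());
    let mut out: Mat = Vec::with_capacity(q.len());
    for t in 0..q.len() {
        s.iter_mut().flatten().for_each(|x| *x *= g[t].exp()); // gate: S <- exp(g_t) S
        for b in 0..dv {
            // delta = beta_t (v_t - S^T k_t), then S += k_t delta^T; columns update independently.
            let delta = beta[t] * (v[t][b] - (0..dk).map(|a| s[a][b] * k[t][a]).sum::<f32>());
            (0..dk).for_each(|a| s[a][b] += k[t][a] * delta);
        }
        out.push((0..dv).map(|b| (0..dk).map(|a| s[a][b] * q[t][a]).sum::<f32>()).collect()); // o_t = S^T q_t
    }
    out
}

/// Chunk-parallel path, chunk size `c`; a short last chunk equals zero padding (k = 0, g = 0 tokens are no-ops).
fn chunked(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>], g: &[f32], beta: &[f32], c: usize, s: &mut Mat) -> Mat {
    let (dk, dv) = (s.len(), s[0].len());
    let mut out: Mat = Vec::with_capacity(q.len());
    for st in (0..q.len()).step_by(c) {
        let n = c.min(q.len() - st);
        let (q, k, v, beta) = (&q[st..st + n], &k[st..st + n], &v[st..st + n], &beta[st..st + n]); // chunk views
        let gc: Vec<f32> = g[st..st + n].iter().scan(0.0f32, |acc, x| { *acc += x; Some(*acc) }).collect(); // G_i
        // M = (I - T)^-1 - I, T[i][j] = -beta_i (k_i.k_j) exp(G_i - G_j) for j < i, by forward substitution.
        let mut m = vec![vec![0.0f32; n]; n];
        for i in 1..n {
            let t: Vec<f32> = (0..i).map(|j| -beta[i] * dot(&k[i], &k[j]) * (gc[i] - gc[j]).exp()).collect();
            for j in 0..i {
                m[i][j] = t[j] + (j + 1..i).map(|p| t[p] * m[p][j]).sum::<f32>(); // rows < i are final
            }
        }
        // r_j = beta_j (v_j - exp(G_j) S^T k_j); the chunk's corrected values are u = (I + M) r.
        let sk = |x: &Vec<f32>, b: usize| (0..dk).map(|a| s[a][b] * x[a]).sum::<f32>(); // (S^T x)[b]
        let r: Mat = (0..n).map(|j| (0..dv).map(|b| beta[j] * (v[j][b] - gc[j].exp() * sk(&k[j], b))).collect()).collect();
        let u: Mat = (0..n).map(|i| (0..dv).map(|b| r[i][b] + (0..i).map(|j| m[i][j] * r[j][b]).sum::<f32>()).collect()).collect();
        // o_i = exp(G_i) S^T q_i + sum_{j <= i} (q_i.k_j) exp(G_i - G_j) u_j
        for i in 0..n {
            let w: Vec<f32> = (0..=i).map(|j| dot(&q[i], &k[j]) * (gc[i] - gc[j]).exp()).collect();
            out.push((0..dv).map(|b| gc[i].exp() * sk(&q[i], b) + (0..=i).map(|j| w[j] * u[j][b]).sum::<f32>()).collect());
        }
        // Carry the state: S <- exp(G_last) S + sum_j exp(G_last - G_j) k_j u_j^T.
        let gl = gc[n - 1];
        for a in 0..dk {
            for b in 0..dv {
                let add: f32 = (0..n).map(|j| (gl - gc[j]).exp() * k[j][a] * u[j][b]).sum();
                s[a][b] = gl.exp() * s[a][b] + add;
            }
        }
    }
    out
}

/// Tiny deterministic LCG (Knuth's MMIX constants) so the sketch needs no crates.
struct Lcg(u64);
impl Lcg {
    fn uniform(&mut self) -> f32 {
        self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        (self.0 >> 40) as f32 / (1u64 << 24) as f32 // uniform in [0, 1)
    }
}

/// `l` random L2-normalised rows of width `d` (used for q, k and v alike).
fn rows(rng: &mut Lcg, l: usize, d: usize) -> Mat {
    (0..l).map(|_| {
        let r: Vec<f32> = (0..d).map(|_| rng.uniform() - 0.5).collect();
        let norm = dot(&r, &r).sqrt().max(1e-6);
        r.iter().map(|x| x / norm).collect()
    }).collect()
}

/// Max |difference| (outputs and final state) between an all-recurrent run and
/// "recurrent for the first `p` tokens, then chunked from the state it leaves".
fn parity_error(l: usize, p: usize, c: usize) -> f32 {
    let (dk, dv, mut rng) = (4, 3, Lcg(42));
    let (q, k, v) = (rows(&mut rng, l, dk), rows(&mut rng, l, dk), rows(&mut rng, l, dv));
    let g: Vec<f32> = (0..l).map(|_| -0.5 * rng.uniform()).collect(); // log-decay in (-0.5, 0]
    let beta: Vec<f32> = (0..l).map(|_| 0.1 + 0.9 * rng.uniform()).collect();
    let (mut s_ref, mut s) = (vec![vec![0.0f32; dv]; dk], vec![vec![0.0f32; dv]; dk]);
    let out_ref = recurrent(&q, &k, &v, &g, &beta, &mut s_ref);
    recurrent(&q[..p], &k[..p], &v[..p], &g[..p], &beta[..p], &mut s);
    let out = chunked(&q[p..], &k[p..], &v[p..], &g[p..], &beta[p..], c, &mut s);
    let pairs = out_ref[p..].iter().zip(&out).chain(s_ref.iter().zip(&s));
    pairs.flat_map(|(x, y)| x.iter().zip(y).map(|(a, b)| (a - b).abs())).fold(0.0, f32::max)
}

fn main() {
    for (l, prefix, c) in [(10, 0, 4), (13, 5, 4), (4, 0, 4)] {
        println!("L={l} prefix={prefix} chunk={c}: max abs diff {:e}", parity_error(l, prefix, c));
    }
}
```

The three cases in `main` are scaled-down analogues of the padded, continued-from-state and single-chunk cases listed below.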
`chunked_matches_recurrent_*` tests pin chunked vs recurrent to 2e-4 abs across: padding (L=130 = 2×64+2), exact multiples continuing from a non-zero initial state (L=128 after a 50-token recurrent prefix), and a single exact chunk (L=64, B=2, H=3).

Before/after numbers from the #22 harness will follow after deploy. The before-number is the 27B's 8.07 s cold prefill at ~5k tokens; cold cells are benchmarked with cache-miss prompts.
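The padded L=130 case and the flattened `B·H·N` batch dim both reduce to index arithmetic. The sketch below assumes a row-major `[B, H, N, C, D]` layout with `B` outermost. `chunk_layout`, `flat_batch` and `bmm` are hypothetical helpers; `bmm` is a naive triple loop standing in for a rank-3 batched matmul.

```rust
/// Chunk size of the chunked prefill path.
const CHUNK: usize = 64;

/// (number of chunks, trailing zero-padding) for a prompt of `seq_len` tokens.
fn chunk_layout(seq_len: usize) -> (usize, usize) {
    let n_chunks = (seq_len + CHUNK - 1) / CHUNK;
    (n_chunks, n_chunks * CHUNK - seq_len)
}

/// Flat batch index of chunk `n` of head `h` in batch row `b` once a row-major
/// [B, H, N, C, D] tensor is viewed as [B*H*N, C, D] (assumed layout).
fn flat_batch(b: usize, h: usize, n: usize, heads: usize, n_chunks: usize) -> usize {
    (b * heads + h) * n_chunks + n
}

/// Naive rank-3 batched matmul: a is [batch, m, k], bm is [batch, k, p], both row-major.
/// After flattening, one batch dim is all the chunk-local math needs.
fn bmm(a: &[f32], bm: &[f32], batch: usize, m: usize, k: usize, p: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; batch * m * p];
    for z in 0..batch {
        for i in 0..m {
            for j in 0..p {
                out[(z * m + i) * p + j] =
                    (0..k).map(|t| a[(z * m + i) * k + t] * bm[(z * k + t) * p + j]).sum::<f32>();
            }
        }
    }
    out
}

fn main() {
    let (n_chunks, pad) = chunk_layout(130); // L = 130 = 2 * 64 + 2
    println!("L=130: {n_chunks} chunks, {pad} padded tokens");
    let (batch, heads) = (2, 3);
    let last = flat_batch(batch - 1, heads - 1, n_chunks - 1, heads, n_chunks);
    println!("B={batch}, H={heads}: {} chunk problems, last flat index {last}", batch * heads * n_chunks);
}
```

For L=130 this yields 3 chunks with 62 zero-padded tokens. With B=2 and H=3, each (b, h, n) triple lands in its own slot of one 18-problem batch, so a single batch dim is enough.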
🤖 Generated with Claude Code
Prefill (seq_len >= 64) now runs the chunk-parallel gated delta rule ported from the HF reference torch_chunk_gated_delta_rule (chunk_size=64): identical math reorganised into per-chunk batched matmuls (cuBLAS/tensor cores on CUDA, gemm on CPU) instead of the O(L)-sequential per-token recurrence. Decode steps and short prompts keep the recurrent paths (CUDA kernel / Rust loop) unchanged.

One deliberate deviation from the reference: its in-place row-by-row UT-transform computes (I - T)^-1 - I by forward substitution. T is strictly lower triangular and therefore nilpotent at chunk size 64, so the same inverse is the product of six factors prod_{j=0..5}(I + T^(2^j)), built by repeated squaring. That replaces 63 sequential row updates with batched matmuls, which suits candle's immutable tensors. Chunk-local math runs rank-3 over a flattened B*H*N batch dim (candle matmul supports at most two batch dims). Initial-state continuation is supported, so chunked prefill composes with #11's restored prefix snapshots. Both single-GPU and TP paths pick this up through the shared run_delta_rule dispatch. NEURON_GDN_CHUNKED=0 forces the recurrent paths for A/B measurement.

Parity tests pin chunked against recurrent (2e-4 abs) across padding (L=130), exact multiples with non-zero initial state (L=128 after a 50-token prefix), and a single exact chunk.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

Live A/B on beast produced NaN logits ("!!!" replies) on real prompts. The nilpotent-squaring form of (I - T)^-1 computes raw powers of T, whose entries grow combinatorially (path counts ~ C(62,31)) before nilpotency collapses them. That is fine on uncorrelated test data, but it destroys f32 precision on real prompts, whose repetitive text makes keys highly correlated. The reference's forward-substitution loop never forms raw powers; its intermediates are the converged entries of M = (I - T)^-1 - I. Port the reference loop faithfully (rows accumulate into a fresh tensor).

A new adversarial parity test with near-identical keys and beta ~= 1 diverges to 8e30 under the squaring form and passes under forward substitution.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
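The failure mode can be reproduced outside candle. The sketch below uses plain Rust and hypothetical helper names. It computes `M = (I − T)⁻¹ − I` both ways for a 64×64 strictly lower-triangular `T`.

The adversarial input is an idealised version of the new test. It assumes identical unit keys, beta = 1 and no decay. Taking `T[i][j] = −β_i (k_i·k_j) exp(G_i − G_j)` as in the HF reference, every below-diagonal entry is then −1.

For that input the exact answer is −1 on the subdiagonal and 0 elsewhere, and forward substitution reproduces it exactly. The squaring form, by contrast, has to cancel raw powers of `T` whose entries reach C(62,31) ≈ 4.65e17, in f32. A benign input (every entry −0.01) is included for contrast.

```rust
// M = (I - T)^-1 - I for a strictly lower-triangular n x n T, computed two ways in f32.
type Mat = Vec<Vec<f32>>;

fn identity(n: usize) -> Mat {
    (0..n).map(|i| (0..n).map(|j| if i == j { 1.0 } else { 0.0 }).collect::<Vec<f32>>()).collect()
}

/// T[i][j] = value for j < i, zero elsewhere.
fn constant_lower(n: usize, value: f32) -> Mat {
    (0..n).map(|i| (0..n).map(|j| if j < i { value } else { 0.0 }).collect::<Vec<f32>>()).collect()
}

fn matmul(a: &Mat, b: &Mat) -> Mat {
    let n = a.len();
    (0..n)
        .map(|i| (0..n).map(|j| (0..n).map(|p| a[i][p] * b[p][j]).sum::<f32>()).collect::<Vec<f32>>())
        .collect()
}

/// Reference-style forward substitution: row i uses only rows < i, and every value
/// it stores is a final entry of M (no raw powers of T are ever formed).
fn forward_substitution(t: &Mat) -> Mat {
    let n = t.len();
    let mut m = vec![vec![0.0f32; n]; n];
    for i in 1..n {
        for j in 0..i {
            m[i][j] = t[i][j] + (j + 1..i).map(|p| t[i][p] * m[p][j]).sum::<f32>();
        }
    }
    m
}

/// Nilpotent-squaring form: (I - T)^-1 = prod_{j < levels} (I + T^(2^j)) once
/// T^(2^levels) = 0. Exact in real arithmetic, but it materialises raw powers of T.
fn squaring(t: &Mat, levels: u32) -> Mat {
    let n = t.len();
    let (mut power, mut prod) = (t.clone(), identity(n)); // power = T^(2^level)
    for level in 0..levels {
        let mut factor = power.clone();
        (0..n).for_each(|i| factor[i][i] += 1.0); // I + T^(2^level)
        prod = matmul(&prod, &factor);
        if level + 1 < levels {
            power = matmul(&power, &power);
        }
    }
    (0..n).for_each(|i| prod[i][i] -= 1.0); // subtract I to get M
    prod
}

fn max_abs_diff(a: &Mat, b: &Mat) -> f32 {
    a.iter().flatten().zip(b.iter().flatten()).map(|(x, y)| (x - y).abs()).fold(0.0, f32::max)
}

fn main() {
    // 64 = 2^6, so six factors suffice; -1.0 models identical unit keys with beta = 1 and no decay.
    for (name, value) in [("benign", -0.01f32), ("adversarial", -1.0)] {
        let t = constant_lower(64, value);
        let err = max_abs_diff(&squaring(&t, 6), &forward_substitution(&t));
        println!("{name:>11}: max |squaring - forward substitution| = {err:e}");
    }
}
```

On the benign matrix the two forms agree to within float rounding. On the adversarial matrix the squaring result is dominated by rounding error, which is consistent with the NaN logits described above.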