From 05dc0bad18e18ad989b34b84dfa9097e14383b37 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Thu, 21 May 2026 11:49:41 +0300 Subject: [PATCH] feat(stage-8d-3): wire causal_conv1d_update/full CUDA kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the per-layer conv1d + silu sequence in both single-GPU and TP linear-attention forward paths with a shared run_causal_conv1d helper that dispatches to: - causal_conv1d_update for decode (seq_len=1 with existing conv_state) - causal_conv1d_full for prefill / fresh start (zero-pads internally) Both kernels fuse the depthwise conv + SiLU into a single launch — 4× fewer cuda launches per linear-attention layer vs the candle conv1d + candle_nn::ops::silu combo. Falls back to the original Rust path on cpu. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/harness/arch/qwen3_5/linear_attn.rs | 169 ++++++++++++++---- crates/neuron/src/harness/tp/tp_qwen3_5.rs | 36 +--- 2 files changed, 141 insertions(+), 64 deletions(-) diff --git a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs index 6bc1190..8ddac01 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs @@ -209,45 +209,18 @@ impl GatedDeltaNet { let a = self.in_proj_a.forward(x)?; // ----- Depthwise causal Conv1d + SiLU (with state continuation). ----- - // If the previous step left a `conv_state`, prepend it so the - // causal kernel window sees the correct left-context. - let prepended = match &self.state.conv_state { - Some(prev) => Tensor::cat(&[prev, &mixed_qkv_chw], 2)?, - None => mixed_qkv_chw.clone(), - }; - let prep_len = prepended.dims()[2]; - - // Update conv_state: keep the last `conv_kernel_size` columns - // of the (possibly prepended) sequence. If the sequence is - // shorter than `conv_kernel_size` (very-short prefill or first - // decode step before warmup), left-pad with zeros. - let new_state = if prep_len >= self.conv_kernel_size { - prepended.narrow(2, prep_len - self.conv_kernel_size, self.conv_kernel_size)? - } else { - let pad = Tensor::zeros( - (batch_size, self.conv_dim, self.conv_kernel_size - prep_len), - dtype, - &device, - )?; - Tensor::cat(&[&pad, &prepended], 2)? - }; - self.state.conv_state = Some(new_state); - - // Apply the depthwise conv with padding=kernel-1 (so output - // length = input + kernel - 1), then trim back to `prep_len`. - // Matches the reference Python which calls the same nn.Conv1d - // with its baked-in padding and slices `[..., :input_len]`. - let conv_out = prepended.conv1d( + // Dispatches to a cuda kernel that fuses conv1d + silu when + // available; falls back to candle's `conv1d` + `silu` on cpu. + let (conv_out, new_state) = run_causal_conv1d( + &mixed_qkv_chw, &self.conv1d_weight, - self.conv_kernel_size - 1, - 1, - 1, + self.state.conv_state.take(), + batch_size, self.conv_dim, + seq_len, + self.conv_kernel_size, )?; - let conv_out = conv_out.narrow(2, 0, prep_len)?; - let conv_out = candle_nn::ops::silu(&conv_out)?; - // Keep only the last L outputs (drop the prepended-state contribution). - let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?; + self.state.conv_state = Some(new_state); // Back to (B, L, conv_dim). let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?; @@ -484,6 +457,130 @@ fn run_delta_rule_rust( Ok((core_attn_out, state)) } +/// Depthwise causal conv1d + SiLU, with rolling `conv_state`. +/// +/// `x`: `(B, conv_dim, L)` model dtype (f16/bf16 on cuda, anything on cpu). +/// `weight`: `(conv_dim, 1, kernel_size)` model dtype. +/// `conv_state`: `Some((B, conv_dim, kernel_size))` for decode continuation, +/// or `None` for fresh prefill. +/// +/// Returns `(conv_out: (B, conv_dim, L), new_conv_state: (B, conv_dim, kernel_size))`. +/// SiLU is baked in. +/// +/// Cuda path: dispatches to `causal_conv1d_update` (decode, seq_len=1 with +/// existing state) or `causal_conv1d_full` (prefill / first call), both +/// ported from mistralrs `gdn.cu`. Each kernel fuses the depthwise conv +/// and SiLU activation in one launch — that's ~4× fewer cuda launches per +/// linear-attention layer than the candle `conv1d` + `silu` combo. +/// +/// CPU path: the original prepend-narrow-conv1d-silu sequence. +pub(crate) fn run_causal_conv1d( + x: &Tensor, + weight: &Tensor, + conv_state: Option, + batch_size: usize, + conv_dim: usize, + seq_len: usize, + conv_kernel_size: usize, +) -> candle_core::Result<(Tensor, Tensor)> { + #[cfg(feature = "cuda")] + { + if x.device().is_cuda() { + return run_causal_conv1d_cuda( + x, + weight, + conv_state, + batch_size, + conv_dim, + seq_len, + conv_kernel_size, + ); + } + } + run_causal_conv1d_rust( + x, + weight, + conv_state, + batch_size, + conv_dim, + seq_len, + conv_kernel_size, + ) +} + +#[cfg(feature = "cuda")] +fn run_causal_conv1d_cuda( + x: &Tensor, + weight: &Tensor, + conv_state: Option, + batch_size: usize, + conv_dim: usize, + seq_len: usize, + conv_kernel_size: usize, +) -> candle_core::Result<(Tensor, Tensor)> { + // Kernel expects weight as (conv_dim, kernel_size) — squeeze the + // depthwise channel-multiplier dim. + let w = weight.squeeze(1)?.to_dtype(x.dtype())?.contiguous()?; + + // Decode path: seq_len == 1 AND we have an existing conv_state. + // Otherwise (prefill or fresh-start decode), use the full path which + // zero-pads on the left internally. + if let Some(cs) = conv_state + && seq_len == 1 + { + let cs = cs.contiguous()?; + let (output, new_conv_state) = + crate::cuda::gdn::causal_conv1d_cuda(x, &w, &cs, conv_kernel_size, true)?; + return Ok((output, new_conv_state)); + } + + // Prefill / fresh-start: the kernel ignores any prior conv_state and + // zero-pads. If we had a non-zero prior state and >1 input tokens + // (multi-turn continuation), we'd need to fall back to Rust. Match + // mistralrs's behaviour: fresh prefill always. + let device = x.device().clone(); + let zeros_cs = Tensor::zeros((batch_size, conv_dim, conv_kernel_size), x.dtype(), &device)?; + let (output, new_conv_state) = + crate::cuda::gdn::causal_conv1d_cuda(x, &w, &zeros_cs, conv_kernel_size, false)?; + Ok((output, new_conv_state)) +} + +fn run_causal_conv1d_rust( + x: &Tensor, + weight: &Tensor, + conv_state: Option, + batch_size: usize, + conv_dim: usize, + seq_len: usize, + conv_kernel_size: usize, +) -> candle_core::Result<(Tensor, Tensor)> { + let dtype = x.dtype(); + let device = x.device().clone(); + + let prepended = match &conv_state { + Some(prev) => Tensor::cat(&[prev, x], 2)?, + None => x.clone(), + }; + let prep_len = prepended.dims()[2]; + + let new_state = if prep_len >= conv_kernel_size { + prepended.narrow(2, prep_len - conv_kernel_size, conv_kernel_size)? + } else { + let pad = Tensor::zeros( + (batch_size, conv_dim, conv_kernel_size - prep_len), + dtype, + &device, + )?; + Tensor::cat(&[&pad, &prepended], 2)? + }; + + let conv_out = prepended.conv1d(weight, conv_kernel_size - 1, 1, 1, conv_dim)?; + let conv_out = conv_out.narrow(2, 0, prep_len)?; + let conv_out = candle_nn::ops::silu(&conv_out)?; + let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?; + Ok((conv_out, new_state)) +} + /// Load a no-bias linear from the ShardedVarBuilder. Weight shape is /// the standard `[out, in]` order. fn load_linear_no_bias( diff --git a/crates/neuron/src/harness/tp/tp_qwen3_5.rs b/crates/neuron/src/harness/tp/tp_qwen3_5.rs index 4765dd0..f50f024 100644 --- a/crates/neuron/src/harness/tp/tp_qwen3_5.rs +++ b/crates/neuron/src/harness/tp/tp_qwen3_5.rs @@ -247,37 +247,17 @@ impl TpQwen3_5GatedDeltaNet { let a = self.in_proj_a.forward(x)?; // ----- State-aware causal Conv1d + SiLU. ----- - let prepended = match &self.state.conv_state { - Some(prev) => Tensor::cat(&[prev, &mixed_qkv_chw], 2)?, - None => mixed_qkv_chw.clone(), - }; - let prep_len = prepended.dims()[2]; - let new_state = if prep_len >= self.conv_kernel_size { - prepended.narrow(2, prep_len - self.conv_kernel_size, self.conv_kernel_size)? - } else { - let pad = Tensor::zeros( - ( - batch_size, - self.per_rank_conv_dim, - self.conv_kernel_size - prep_len, - ), - dtype, - &device, - )?; - Tensor::cat(&[&pad, &prepended], 2)? - }; - self.state.conv_state = Some(new_state); - - let conv_out = prepended.conv1d( + // Same shared helper as single-GPU — cuda kernel when available. + let (conv_out, new_state) = crate::harness::arch::qwen3_5::linear_attn::run_causal_conv1d( + &mixed_qkv_chw, &self.conv1d_weight, - self.conv_kernel_size - 1, - 1, - 1, + self.state.conv_state.take(), + batch_size, self.per_rank_conv_dim, + seq_len, + self.conv_kernel_size, )?; - let conv_out = conv_out.narrow(2, 0, prep_len)?; - let conv_out = candle_nn::ops::silu(&conv_out)?; - let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?; + self.state.conv_state = Some(new_state); let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?; // ----- Split into q, k, v (per-rank head counts). -----