feat(stage-8d-3): wire causal_conv1d_update/full CUDA kernels
Some checks failed
CI / Clippy (push) Waiting to run
CI / Test (push) Waiting to run
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 38s
build-prerelease / Build cortex binary (push) Has started running
build-prerelease / Build neuron-blackwell (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled

Replaces the per-layer conv1d + silu sequence in both single-GPU and
TP linear-attention forward paths with a shared run_causal_conv1d
helper that dispatches to:

- causal_conv1d_update for decode (seq_len=1 with existing conv_state)
- causal_conv1d_full for prefill / fresh start (zero-pads internally)

Both kernels fuse the depthwise conv + SiLU into a single launch — roughly
4× fewer CUDA launches per linear-attention layer than the candle conv1d +
candle_nn::ops::silu combo. Falls back to the original Rust path on
CPU.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 11:49:41 +03:00
parent 10c151efa5
commit 05dc0bad18
2 changed files with 141 additions and 64 deletions

View File

@@ -209,45 +209,18 @@ impl GatedDeltaNet {
let a = self.in_proj_a.forward(x)?;
// ----- Depthwise causal Conv1d + SiLU (with state continuation). -----
// If the previous step left a `conv_state`, prepend it so the
// causal kernel window sees the correct left-context.
let prepended = match &self.state.conv_state {
Some(prev) => Tensor::cat(&[prev, &mixed_qkv_chw], 2)?,
None => mixed_qkv_chw.clone(),
};
let prep_len = prepended.dims()[2];
// Update conv_state: keep the last `conv_kernel_size` columns
// of the (possibly prepended) sequence. If the sequence is
// shorter than `conv_kernel_size` (very-short prefill or first
// decode step before warmup), left-pad with zeros.
let new_state = if prep_len >= self.conv_kernel_size {
prepended.narrow(2, prep_len - self.conv_kernel_size, self.conv_kernel_size)?
} else {
let pad = Tensor::zeros(
(batch_size, self.conv_dim, self.conv_kernel_size - prep_len),
dtype,
&device,
)?;
Tensor::cat(&[&pad, &prepended], 2)?
};
self.state.conv_state = Some(new_state);
// Apply the depthwise conv with padding=kernel-1 (so output
// length = input + kernel - 1), then trim back to `prep_len`.
// Matches the reference Python which calls the same nn.Conv1d
// with its baked-in padding and slices `[..., :input_len]`.
let conv_out = prepended.conv1d(
// Dispatches to a cuda kernel that fuses conv1d + silu when
// available; falls back to candle's `conv1d` + `silu` on cpu.
let (conv_out, new_state) = run_causal_conv1d(
&mixed_qkv_chw,
&self.conv1d_weight,
self.conv_kernel_size - 1,
1,
1,
self.state.conv_state.take(),
batch_size,
self.conv_dim,
seq_len,
self.conv_kernel_size,
)?;
let conv_out = conv_out.narrow(2, 0, prep_len)?;
let conv_out = candle_nn::ops::silu(&conv_out)?;
// Keep only the last L outputs (drop the prepended-state contribution).
let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?;
self.state.conv_state = Some(new_state);
// Back to (B, L, conv_dim).
let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?;
@@ -484,6 +457,130 @@ fn run_delta_rule_rust(
Ok((core_attn_out, state))
}
/// Fused depthwise causal conv1d + SiLU with a rolling `conv_state`.
///
/// This is the single entry point shared by the single-GPU and
/// tensor-parallel linear-attention layers; it only selects a backend and
/// forwards every argument unchanged.
///
/// # Arguments
/// - `x`: `(B, conv_dim, L)` in the model dtype (f16/bf16 on cuda,
///   anything on cpu).
/// - `weight`: `(conv_dim, 1, kernel_size)` in the model dtype.
/// - `conv_state`: `Some((B, conv_dim, kernel_size))` carried over from the
///   previous call, or `None` for a fresh prefill.
/// - `batch_size`, `conv_dim`, `seq_len`, `conv_kernel_size`: the `B`,
///   `conv_dim`, `L` and `kernel_size` extents used above.
///
/// # Returns
/// `(conv_out, new_conv_state)` with shapes `(B, conv_dim, L)` and
/// `(B, conv_dim, kernel_size)`. SiLU has already been applied to
/// `conv_out`.
///
/// # Backends
/// - Cuda (only compiled with the `cuda` feature, only taken when `x`
///   lives on a cuda device): `run_causal_conv1d_cuda`, which drives the
///   `causal_conv1d_update` / `causal_conv1d_full` kernels ported from
///   mistralrs `gdn.cu`. Each kernel fuses the depthwise conv and the SiLU
///   activation in one launch — roughly 4× fewer cuda launches per
///   linear-attention layer than candle's `conv1d` + `silu`. How a prior
///   state combined with `seq_len > 1` is treated is decided there.
/// - Everything else: `run_causal_conv1d_rust`, the
///   prepend → conv1d → SiLU → narrow sequence built from candle ops.
pub(crate) fn run_causal_conv1d(
    x: &Tensor,
    weight: &Tensor,
    conv_state: Option<Tensor>,
    batch_size: usize,
    conv_dim: usize,
    seq_len: usize,
    conv_kernel_size: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
    // Without the `cuda` feature this block is compiled out and every call
    // takes the candle-op path below.
    #[cfg(feature = "cuda")]
    {
        let on_gpu = x.device().is_cuda();
        if on_gpu {
            return run_causal_conv1d_cuda(
                x,
                weight,
                conv_state,
                batch_size,
                conv_dim,
                seq_len,
                conv_kernel_size,
            );
        }
    }

    run_causal_conv1d_rust(
        x,
        weight,
        conv_state,
        batch_size,
        conv_dim,
        seq_len,
        conv_kernel_size,
    )
}
/// Cuda backend for `run_causal_conv1d`: fused depthwise causal conv1d +
/// SiLU.
///
/// Dispatch on `(conv_state, seq_len)`:
/// - `Some(state)` and `seq_len == 1` (decode step): the update kernel
///   (`causal_conv1d_cuda(.., true)`), which consumes the existing state.
/// - `None` (fresh prefill / fresh-start decode): the full kernel
///   (`causal_conv1d_cuda(.., false)`) with an all-zero state, i.e. the
///   input is zero-padded on the left.
/// - `Some(state)` and `seq_len != 1` (multi-token continuation, e.g.
///   chunked prefill or a follow-up turn on a warm cache): the full kernel
///   ignores any prior state, so using it here would silently drop the
///   left context. This case is routed to `run_causal_conv1d_rust`, which
///   prepends the state and runs candle's own `conv1d` + `silu` on
///   whatever device `x` lives on.
///
/// NOTE(review): the fallback assumes the state produced by the cuda
/// kernels has the same layout as the one produced by
/// `run_causal_conv1d_rust` (last `kernel_size` input columns, shape
/// `(B, conv_dim, kernel_size)`) — confirm against `crate::cuda::gdn`.
///
/// Returns `(conv_out, new_conv_state)`.
#[cfg(feature = "cuda")]
fn run_causal_conv1d_cuda(
    x: &Tensor,
    weight: &Tensor,
    conv_state: Option<Tensor>,
    batch_size: usize,
    conv_dim: usize,
    seq_len: usize,
    conv_kernel_size: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
    // Kernel expects weight as (conv_dim, kernel_size) — squeeze the
    // depthwise channel-multiplier dim. Built lazily so the Rust fallback
    // does not pay for a conversion it never uses.
    let kernel_weight = || -> candle_core::Result<Tensor> {
        weight.squeeze(1)?.to_dtype(x.dtype())?.contiguous()
    };

    match conv_state {
        // Decode: exactly one new token on top of an existing state.
        Some(cs) if seq_len == 1 => {
            let w = kernel_weight()?;
            let cs = cs.contiguous()?;
            let (output, new_conv_state) =
                crate::cuda::gdn::causal_conv1d_cuda(x, &w, &cs, conv_kernel_size, true)?;
            Ok((output, new_conv_state))
        }
        // Continuation with several tokens: the full kernel would discard
        // `cs`, so keep the left context by using the state-aware path.
        Some(cs) => run_causal_conv1d_rust(
            x,
            weight,
            Some(cs),
            batch_size,
            conv_dim,
            seq_len,
            conv_kernel_size,
        ),
        // Fresh start: nothing to continue from, zero left-context.
        None => {
            let w = kernel_weight()?;
            let zeros_cs = Tensor::zeros(
                (batch_size, conv_dim, conv_kernel_size),
                x.dtype(),
                x.device(),
            )?;
            let (output, new_conv_state) =
                crate::cuda::gdn::causal_conv1d_cuda(x, &w, &zeros_cs, conv_kernel_size, false)?;
            Ok((output, new_conv_state))
        }
    }
}
/// Candle-op implementation of the depthwise causal conv1d + SiLU with a
/// rolling state.
///
/// Steps:
/// 1. Prepend `conv_state` (if any) to `x` along the last axis so the
///    convolution window sees the previous left-context.
/// 2. Build the next state from the last `conv_kernel_size` columns of
///    that sequence, left-padding with zeros when it is shorter.
/// 3. Run `conv1d` with padding `conv_kernel_size - 1`, stride 1,
///    dilation 1 and `conv_dim` groups, trim the output back to the
///    prepended length, apply SiLU, and keep only the last `seq_len`
///    positions (dropping the outputs that belong to the prepended state).
///
/// Returns `(conv_out, new_state)`.
fn run_causal_conv1d_rust(
    x: &Tensor,
    weight: &Tensor,
    conv_state: Option<Tensor>,
    batch_size: usize,
    conv_dim: usize,
    seq_len: usize,
    conv_kernel_size: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
    // Carried-over context followed by the new tokens.
    let full_seq = if let Some(prev) = conv_state.as_ref() {
        Tensor::cat(&[prev, x], 2)?
    } else {
        x.clone()
    };
    let total_len = full_seq.dims()[2];

    // Next rolling state: always exactly `conv_kernel_size` columns wide.
    let next_state = if total_len < conv_kernel_size {
        let missing = conv_kernel_size - total_len;
        let zeros = Tensor::zeros((batch_size, conv_dim, missing), x.dtype(), x.device())?;
        Tensor::cat(&[&zeros, &full_seq], 2)?
    } else {
        full_seq.narrow(2, total_len - conv_kernel_size, conv_kernel_size)?
    };

    // Padding of `kernel - 1` yields `total_len + kernel - 1` outputs; the
    // first `total_len` of them are the causal ones.
    let causal = full_seq
        .conv1d(weight, conv_kernel_size - 1, 1, 1, conv_dim)?
        .narrow(2, 0, total_len)?;
    let activated = candle_nn::ops::silu(&causal)?;

    // Only the positions that correspond to the new tokens are returned.
    let fresh = activated.narrow(2, total_len - seq_len, seq_len)?;
    Ok((fresh, next_state))
}
/// Load a no-bias linear from the ShardedVarBuilder. Weight shape is
/// the standard `[out, in]` order.
fn load_linear_no_bias(

View File

@@ -247,37 +247,17 @@ impl TpQwen3_5GatedDeltaNet {
let a = self.in_proj_a.forward(x)?;
// ----- State-aware causal Conv1d + SiLU. -----
let prepended = match &self.state.conv_state {
Some(prev) => Tensor::cat(&[prev, &mixed_qkv_chw], 2)?,
None => mixed_qkv_chw.clone(),
};
let prep_len = prepended.dims()[2];
let new_state = if prep_len >= self.conv_kernel_size {
prepended.narrow(2, prep_len - self.conv_kernel_size, self.conv_kernel_size)?
} else {
let pad = Tensor::zeros(
(
// Same shared helper as single-GPU — cuda kernel when available.
let (conv_out, new_state) = crate::harness::arch::qwen3_5::linear_attn::run_causal_conv1d(
&mixed_qkv_chw,
&self.conv1d_weight,
self.state.conv_state.take(),
batch_size,
self.per_rank_conv_dim,
self.conv_kernel_size - prep_len,
),
dtype,
&device,
seq_len,
self.conv_kernel_size,
)?;
Tensor::cat(&[&pad, &prepended], 2)?
};
self.state.conv_state = Some(new_state);
let conv_out = prepended.conv1d(
&self.conv1d_weight,
self.conv_kernel_size - 1,
1,
1,
self.per_rank_conv_dim,
)?;
let conv_out = conv_out.narrow(2, 0, prep_len)?;
let conv_out = candle_nn::ops::silu(&conv_out)?;
let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?;
let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?;
// ----- Split into q, k, v (per-rank head counts). -----