feat(stage-8d-3): wire causal_conv1d_update/full CUDA kernels
Some checks failed
CI / Clippy (push) Waiting to run
CI / Test (push) Waiting to run
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 38s
build-prerelease / Build cortex binary (push) Has started running
build-prerelease / Build neuron-blackwell (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
Some checks failed
CI / Clippy (push) Waiting to run
CI / Test (push) Waiting to run
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 38s
build-prerelease / Build cortex binary (push) Has started running
build-prerelease / Build neuron-blackwell (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
Replaces the per-layer conv1d + silu sequence in both the single-GPU and TP linear-attention forward paths with a shared run_causal_conv1d helper that dispatches to:
- causal_conv1d_update for decode (seq_len=1 with an existing conv_state)
- causal_conv1d_full for prefill / fresh start (zero-pads internally)

Both kernels fuse the depthwise conv + SiLU into a single launch — 4× fewer CUDA launches per linear-attention layer vs. the candle conv1d + candle_nn::ops::silu combo. Falls back to the original Rust path on CPU.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -209,45 +209,18 @@ impl GatedDeltaNet {
|
|||||||
let a = self.in_proj_a.forward(x)?;
|
let a = self.in_proj_a.forward(x)?;
|
||||||
|
|
||||||
// ----- Depthwise causal Conv1d + SiLU (with state continuation). -----
|
// ----- Depthwise causal Conv1d + SiLU (with state continuation). -----
|
||||||
// If the previous step left a `conv_state`, prepend it so the
|
// Dispatches to a cuda kernel that fuses conv1d + silu when
|
||||||
// causal kernel window sees the correct left-context.
|
// available; falls back to candle's `conv1d` + `silu` on cpu.
|
||||||
let prepended = match &self.state.conv_state {
|
let (conv_out, new_state) = run_causal_conv1d(
|
||||||
Some(prev) => Tensor::cat(&[prev, &mixed_qkv_chw], 2)?,
|
&mixed_qkv_chw,
|
||||||
None => mixed_qkv_chw.clone(),
|
|
||||||
};
|
|
||||||
let prep_len = prepended.dims()[2];
|
|
||||||
|
|
||||||
// Update conv_state: keep the last `conv_kernel_size` columns
|
|
||||||
// of the (possibly prepended) sequence. If the sequence is
|
|
||||||
// shorter than `conv_kernel_size` (very-short prefill or first
|
|
||||||
// decode step before warmup), left-pad with zeros.
|
|
||||||
let new_state = if prep_len >= self.conv_kernel_size {
|
|
||||||
prepended.narrow(2, prep_len - self.conv_kernel_size, self.conv_kernel_size)?
|
|
||||||
} else {
|
|
||||||
let pad = Tensor::zeros(
|
|
||||||
(batch_size, self.conv_dim, self.conv_kernel_size - prep_len),
|
|
||||||
dtype,
|
|
||||||
&device,
|
|
||||||
)?;
|
|
||||||
Tensor::cat(&[&pad, &prepended], 2)?
|
|
||||||
};
|
|
||||||
self.state.conv_state = Some(new_state);
|
|
||||||
|
|
||||||
// Apply the depthwise conv with padding=kernel-1 (so output
|
|
||||||
// length = input + kernel - 1), then trim back to `prep_len`.
|
|
||||||
// Matches the reference Python which calls the same nn.Conv1d
|
|
||||||
// with its baked-in padding and slices `[..., :input_len]`.
|
|
||||||
let conv_out = prepended.conv1d(
|
|
||||||
&self.conv1d_weight,
|
&self.conv1d_weight,
|
||||||
self.conv_kernel_size - 1,
|
self.state.conv_state.take(),
|
||||||
1,
|
batch_size,
|
||||||
1,
|
|
||||||
self.conv_dim,
|
self.conv_dim,
|
||||||
|
seq_len,
|
||||||
|
self.conv_kernel_size,
|
||||||
)?;
|
)?;
|
||||||
let conv_out = conv_out.narrow(2, 0, prep_len)?;
|
self.state.conv_state = Some(new_state);
|
||||||
let conv_out = candle_nn::ops::silu(&conv_out)?;
|
|
||||||
// Keep only the last L outputs (drop the prepended-state contribution).
|
|
||||||
let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?;
|
|
||||||
// Back to (B, L, conv_dim).
|
// Back to (B, L, conv_dim).
|
||||||
let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?;
|
let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?;
|
||||||
|
|
||||||
@@ -484,6 +457,130 @@ fn run_delta_rule_rust(
|
|||||||
Ok((core_attn_out, state))
|
Ok((core_attn_out, state))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Depthwise causal conv1d + SiLU, with rolling `conv_state`.
|
||||||
|
///
|
||||||
|
/// `x`: `(B, conv_dim, L)` model dtype (f16/bf16 on cuda, anything on cpu).
|
||||||
|
/// `weight`: `(conv_dim, 1, kernel_size)` model dtype.
|
||||||
|
/// `conv_state`: `Some((B, conv_dim, kernel_size))` for decode continuation,
|
||||||
|
/// or `None` for fresh prefill.
|
||||||
|
///
|
||||||
|
/// Returns `(conv_out: (B, conv_dim, L), new_conv_state: (B, conv_dim, kernel_size))`.
|
||||||
|
/// SiLU is baked in.
|
||||||
|
///
|
||||||
|
/// Cuda path: dispatches to `causal_conv1d_update` (decode, seq_len=1 with
|
||||||
|
/// existing state) or `causal_conv1d_full` (prefill / first call), both
|
||||||
|
/// ported from mistralrs `gdn.cu`. Each kernel fuses the depthwise conv
|
||||||
|
/// and SiLU activation in one launch — that's ~4× fewer cuda launches per
|
||||||
|
/// linear-attention layer than the candle `conv1d` + `silu` combo.
|
||||||
|
///
|
||||||
|
/// CPU path: the original prepend-narrow-conv1d-silu sequence.
|
||||||
|
pub(crate) fn run_causal_conv1d(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: Option<Tensor>,
|
||||||
|
batch_size: usize,
|
||||||
|
conv_dim: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
conv_kernel_size: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
if x.device().is_cuda() {
|
||||||
|
return run_causal_conv1d_cuda(
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
conv_state,
|
||||||
|
batch_size,
|
||||||
|
conv_dim,
|
||||||
|
seq_len,
|
||||||
|
conv_kernel_size,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
run_causal_conv1d_rust(
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
conv_state,
|
||||||
|
batch_size,
|
||||||
|
conv_dim,
|
||||||
|
seq_len,
|
||||||
|
conv_kernel_size,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Cuda implementation of the causal conv1d + SiLU step.
///
/// Takes the same arguments and returns the same
/// `(conv_out, new_conv_state)` pair as `run_causal_conv1d`.
///
/// Three cases are distinguished:
///
/// * `conv_state` is `Some` and `seq_len == 1` (decode): the fused kernel is
///   called with the existing state and the trailing flag set to `true`.
/// * `conv_state` is `None` (fresh prefill / first call): the fused kernel is
///   called with an all-zero state and the trailing flag set to `false`; it
///   zero-pads on the left internally.
/// * `conv_state` is `Some` and `seq_len != 1` (continuation with several new
///   tokens, e.g. chunked prefill or a multi-turn follow-up): the full kernel
///   ignores any prior state, so using it here would silently discard the
///   left context. This case is routed to `run_causal_conv1d_rust`, which
///   prepends the state before convolving.
///
/// # Errors
///
/// Propagates any error from the tensor ops, the cuda kernel wrapper, or the
/// Rust fallback.
#[cfg(feature = "cuda")]
fn run_causal_conv1d_cuda(
    x: &Tensor,
    weight: &Tensor,
    conv_state: Option<Tensor>,
    batch_size: usize,
    conv_dim: usize,
    seq_len: usize,
    conv_kernel_size: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
    // Continuation with more than one token: neither kernel mode can consume
    // a prior state for a multi-token input, so keep the result correct by
    // running the generic tensor-op path instead of dropping the state.
    if conv_state.is_some() && seq_len != 1 {
        return run_causal_conv1d_rust(
            x,
            weight,
            conv_state,
            batch_size,
            conv_dim,
            seq_len,
            conv_kernel_size,
        );
    }

    // Kernel expects weight as (conv_dim, kernel_size) — squeeze the
    // depthwise channel-multiplier dim and match the input dtype.
    let w = weight.squeeze(1)?.to_dtype(x.dtype())?.contiguous()?;

    match conv_state {
        // Decode: single new token on top of an existing rolling state.
        Some(cs) => {
            let cs = cs.contiguous()?;
            let (output, new_conv_state) =
                crate::cuda::gdn::causal_conv1d_cuda(x, &w, &cs, conv_kernel_size, true)?;
            Ok((output, new_conv_state))
        }
        // Prefill / fresh start: hand the kernel an all-zero state; it
        // zero-pads on the left internally.
        None => {
            let zeros_cs = Tensor::zeros(
                (batch_size, conv_dim, conv_kernel_size),
                x.dtype(),
                x.device(),
            )?;
            let (output, new_conv_state) =
                crate::cuda::gdn::causal_conv1d_cuda(x, &w, &zeros_cs, conv_kernel_size, false)?;
            Ok((output, new_conv_state))
        }
    }
}
|
||||||
|
|
||||||
|
fn run_causal_conv1d_rust(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: Option<Tensor>,
|
||||||
|
batch_size: usize,
|
||||||
|
conv_dim: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
conv_kernel_size: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let dtype = x.dtype();
|
||||||
|
let device = x.device().clone();
|
||||||
|
|
||||||
|
let prepended = match &conv_state {
|
||||||
|
Some(prev) => Tensor::cat(&[prev, x], 2)?,
|
||||||
|
None => x.clone(),
|
||||||
|
};
|
||||||
|
let prep_len = prepended.dims()[2];
|
||||||
|
|
||||||
|
let new_state = if prep_len >= conv_kernel_size {
|
||||||
|
prepended.narrow(2, prep_len - conv_kernel_size, conv_kernel_size)?
|
||||||
|
} else {
|
||||||
|
let pad = Tensor::zeros(
|
||||||
|
(batch_size, conv_dim, conv_kernel_size - prep_len),
|
||||||
|
dtype,
|
||||||
|
&device,
|
||||||
|
)?;
|
||||||
|
Tensor::cat(&[&pad, &prepended], 2)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let conv_out = prepended.conv1d(weight, conv_kernel_size - 1, 1, 1, conv_dim)?;
|
||||||
|
let conv_out = conv_out.narrow(2, 0, prep_len)?;
|
||||||
|
let conv_out = candle_nn::ops::silu(&conv_out)?;
|
||||||
|
let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?;
|
||||||
|
Ok((conv_out, new_state))
|
||||||
|
}
|
||||||
|
|
||||||
/// Load a no-bias linear from the ShardedVarBuilder. Weight shape is
|
/// Load a no-bias linear from the ShardedVarBuilder. Weight shape is
|
||||||
/// the standard `[out, in]` order.
|
/// the standard `[out, in]` order.
|
||||||
fn load_linear_no_bias(
|
fn load_linear_no_bias(
|
||||||
|
|||||||
@@ -247,37 +247,17 @@ impl TpQwen3_5GatedDeltaNet {
|
|||||||
let a = self.in_proj_a.forward(x)?;
|
let a = self.in_proj_a.forward(x)?;
|
||||||
|
|
||||||
// ----- State-aware causal Conv1d + SiLU. -----
|
// ----- State-aware causal Conv1d + SiLU. -----
|
||||||
let prepended = match &self.state.conv_state {
|
// Same shared helper as single-GPU — cuda kernel when available.
|
||||||
Some(prev) => Tensor::cat(&[prev, &mixed_qkv_chw], 2)?,
|
let (conv_out, new_state) = crate::harness::arch::qwen3_5::linear_attn::run_causal_conv1d(
|
||||||
None => mixed_qkv_chw.clone(),
|
&mixed_qkv_chw,
|
||||||
};
|
&self.conv1d_weight,
|
||||||
let prep_len = prepended.dims()[2];
|
self.state.conv_state.take(),
|
||||||
let new_state = if prep_len >= self.conv_kernel_size {
|
|
||||||
prepended.narrow(2, prep_len - self.conv_kernel_size, self.conv_kernel_size)?
|
|
||||||
} else {
|
|
||||||
let pad = Tensor::zeros(
|
|
||||||
(
|
|
||||||
batch_size,
|
batch_size,
|
||||||
self.per_rank_conv_dim,
|
self.per_rank_conv_dim,
|
||||||
self.conv_kernel_size - prep_len,
|
seq_len,
|
||||||
),
|
self.conv_kernel_size,
|
||||||
dtype,
|
|
||||||
&device,
|
|
||||||
)?;
|
)?;
|
||||||
Tensor::cat(&[&pad, &prepended], 2)?
|
|
||||||
};
|
|
||||||
self.state.conv_state = Some(new_state);
|
self.state.conv_state = Some(new_state);
|
||||||
|
|
||||||
let conv_out = prepended.conv1d(
|
|
||||||
&self.conv1d_weight,
|
|
||||||
self.conv_kernel_size - 1,
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
self.per_rank_conv_dim,
|
|
||||||
)?;
|
|
||||||
let conv_out = conv_out.narrow(2, 0, prep_len)?;
|
|
||||||
let conv_out = candle_nn::ops::silu(&conv_out)?;
|
|
||||||
let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?;
|
|
||||||
let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?;
|
let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?;
|
||||||
|
|
||||||
// ----- Split into q, k, v (per-rank head counts). -----
|
// ----- Split into q, k, v (per-rank head counts). -----
|
||||||
|
|||||||
Reference in New Issue
Block a user