feat(stage-8d-1): import mistralrs GDN CUDA kernels — build infra only
Some checks failed
build-prerelease / Build cortex binary (push) Blocked by required conditions
CI / Test (push) Waiting to run
build-prerelease / Resolve version stamps (push) Successful in 29s
CI / Format (push) Successful in 40s
CI / Clippy (push) Successful in 2m23s
build-prerelease / Build neuron-blackwell (push) Has started running
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
Some checks failed
build-prerelease / Build cortex binary (push) Blocked by required conditions
CI / Test (push) Waiting to run
build-prerelease / Resolve version stamps (push) Successful in 29s
CI / Format (push) Successful in 40s
CI / Clippy (push) Successful in 2m23s
build-prerelease / Build neuron-blackwell (push) Has started running
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
Stage 8d (new): port the Gated DeltaNet CUDA kernels from
EricLBuehler/mistral.rs to close the ~500x decode performance gap
we measured on Qwen3.6-27B TP-2 (~12s/token in our pure-candle path
vs ~37 T/s in mistralrs on the same hardware).
This commit lays the build infrastructure with zero behavioural
change. Subsequent commits (8d-2 .. 8d-5) wire each kernel into the
qwen3_5 architecture and TP variant.
Added:
- `crates/neuron/build.rs` — uses `cudaforge::KernelBuilder` to compile
every `src/cuda/*.cu` file into `libneuroncuda.a` under the `cuda`
feature, then links it + `cudart`. Mirrors mistralrs's
`mistralrs-core/build.rs` setup verbatim (same NVCC flag set, same
sm_<80 bf16 gate).
- `crates/neuron/src/cuda/gdn.cu` — five kernels ported verbatim from
upstream:
* `gated_delta_rule_recurrence` (V-tiled per-token decode)
* `chunked_gated_delta_rule_recurrence` (BT=64 chunked prefill)
* `causal_conv1d_update` (single-token conv decode)
* `causal_conv1d_full` (multi-token conv prefill)
* `fused_gdn_gating` (beta = sigmoid(b); g = -exp(A_log) *
softplus(a + dt_bias))
- `crates/neuron/src/cuda/gdn.rs` — Rust wrappers around the kernels,
cudarc::CudaSlice::device_ptr boilerplate identical to upstream.
- `crates/neuron/src/cuda/ffi.rs` — `extern "C"` decls (subset of
upstream's ffi.rs covering only the five GDN kernels; MoE / SSM /
top-k decls land here when we absorb those too).
- `crates/neuron/src/cuda/mod.rs` — re-exports + module docs.
Cargo wiring: `cudaforge` added as an optional build-dep, activated
by the `cuda` feature. CPU build is unchanged (the `cuda/` module is
fully `#[cfg(feature = "cuda")]`). The cuda feature build inside the
patched container compiles `gdn.cu` (1 of 1 kernels) and links
clean.
Licensing: upstream files preserve their MIT origin via per-file
comment banners pointing to the mistralrs path. No behaviour-relevant
edits to the .cu kernels — local diff against upstream is just the
banner. The `.rs` wrappers and `ffi.rs` subset are also from upstream;
their structure (module path `crate::cuda::ffi::*`) matches identically
so future kernel imports drop in unchanged.
CPU clippy + 32 lib tests pass; `cargo clippy --features cuda` clean
inside the runner container.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2113,6 +2113,7 @@ dependencies = [
|
|||||||
"candle-transformers",
|
"candle-transformers",
|
||||||
"clap",
|
"clap",
|
||||||
"cortex-core",
|
"cortex-core",
|
||||||
|
"cudaforge",
|
||||||
"cudarc 0.19.7",
|
"cudarc 0.19.7",
|
||||||
"figment",
|
"figment",
|
||||||
"futures",
|
"futures",
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ cuda = [
|
|||||||
"candle-transformers/cuda",
|
"candle-transformers/cuda",
|
||||||
"dep:cudarc",
|
"dep:cudarc",
|
||||||
"dep:half",
|
"dep:half",
|
||||||
|
"dep:cudaforge",
|
||||||
]
|
]
|
||||||
# Use cuDNN for convolution / attention kernels. Requires CUDA.
|
# Use cuDNN for convolution / attention kernels. Requires CUDA.
|
||||||
cudnn = [
|
cudnn = [
|
||||||
@@ -79,3 +80,13 @@ hf-hub = { version = "0.4", features = ["tokio"] }
|
|||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { workspace = true, features = ["test-util"] }
|
tokio = { workspace = true, features = ["test-util"] }
|
||||||
reqwest.workspace = true
|
reqwest.workspace = true
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
# Used by `build.rs` to compile `src/cuda/*.cu` into `libneuroncuda.a`
|
||||||
|
# under the `cuda` feature. Matches mistralrs's upstream build setup
|
||||||
|
# (their `mistralrs-core/build.rs` uses the same constructor).
|
||||||
|
cudaforge = { version = "0.1", optional = true }
|
||||||
|
|
||||||
|
[package.metadata.docs.rs]
|
||||||
|
# Skip the CUDA path on docs.rs (it lacks nvcc).
|
||||||
|
no-default-features = true
|
||||||
|
|||||||
66
crates/neuron/build.rs
Normal file
66
crates/neuron/build.rs
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
//! Build script: compile the CUDA kernels in `src/cuda/*.cu` into a
|
||||||
|
//! static library and link it under the `cuda` feature.
|
||||||
|
//!
|
||||||
|
//! Patterned on `EricLBuehler/mistral.rs::mistralrs-core/build.rs` —
|
||||||
|
//! same `cudaforge::KernelBuilder` invocation, same NVCC flag set.
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
use std::path::PathBuf;
|
||||||
|
println!("cargo:rerun-if-changed=build.rs");
|
||||||
|
println!("cargo:rerun-if-changed=src/cuda/");
|
||||||
|
|
||||||
|
let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
|
||||||
|
|
||||||
|
let mut builder = cudaforge::KernelBuilder::new()
|
||||||
|
.source_glob("src/cuda/*.cu")
|
||||||
|
.out_dir(&build_dir)
|
||||||
|
.arg("-std=c++17")
|
||||||
|
.arg("-O3")
|
||||||
|
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||||
|
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||||
|
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||||
|
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||||
|
.arg("--expt-relaxed-constexpr")
|
||||||
|
.arg("--expt-extended-lambda")
|
||||||
|
.arg("--use_fast_math")
|
||||||
|
.arg("--compiler-options")
|
||||||
|
.arg("-fPIC");
|
||||||
|
|
||||||
|
// sm_<80 doesn't have bf16 intrinsics for WMMA — gate the
|
||||||
|
// bf16-only kernels off in that case. (Mirrors upstream.)
|
||||||
|
if let Some(compute_cap) = builder.get_compute_cap()
|
||||||
|
&& compute_cap < 80
|
||||||
|
{
|
||||||
|
builder = builder.arg("-DNO_BF16_KERNEL");
|
||||||
|
}
|
||||||
|
|
||||||
|
let target = std::env::var("TARGET").unwrap();
|
||||||
|
let out_file = if target.contains("msvc") {
|
||||||
|
build_dir.join("neuroncuda.lib")
|
||||||
|
} else {
|
||||||
|
build_dir.join("libneuroncuda.a")
|
||||||
|
};
|
||||||
|
|
||||||
|
builder
|
||||||
|
.build_lib(out_file)
|
||||||
|
.expect("neuron cuda build failed");
|
||||||
|
println!("cargo:rustc-link-search={}", build_dir.display());
|
||||||
|
println!("cargo:rustc-link-lib=neuroncuda");
|
||||||
|
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||||
|
|
||||||
|
if target.contains("msvc") {
|
||||||
|
// No extra runtime library needed.
|
||||||
|
} else if target.contains("apple")
|
||||||
|
|| target.contains("freebsd")
|
||||||
|
|| target.contains("openbsd")
|
||||||
|
{
|
||||||
|
println!("cargo:rustc-link-lib=dylib=c++");
|
||||||
|
} else if target.contains("android") {
|
||||||
|
println!("cargo:rustc-link-lib=dylib=c++_shared");
|
||||||
|
} else {
|
||||||
|
println!("cargo:rustc-link-lib=dylib=stdc++");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
84
crates/neuron/src/cuda/ffi.rs
Normal file
84
crates/neuron/src/cuda/ffi.rs
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
//! FFI declarations for the CUDA kernels in `gdn.cu`.
|
||||||
|
//!
|
||||||
|
//! Subset of `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/ffi.rs`
|
||||||
|
//! covering only the Gated DeltaNet kernels we currently use. Other
|
||||||
|
//! kernels in the upstream file (MoE GEMM, top-k, Mamba selective
|
||||||
|
//! scan, etc.) would land here too as we absorb them.
|
||||||
|
//!
|
||||||
|
//! All function declarations are MIT-licensed from upstream and
|
||||||
|
//! unchanged apart from this header.
|
||||||
|
|
||||||
|
use std::ffi::c_void;
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
unsafe extern "C" {
|
||||||
|
// GDN (Gated Delta Net) kernels for qwen3_5 / Qwen3-Next.
|
||||||
|
pub(crate) fn gated_delta_rule_recurrence(
|
||||||
|
q: *const f32,
|
||||||
|
k: *const f32,
|
||||||
|
v: *const f32,
|
||||||
|
g: *const f32,
|
||||||
|
beta: *const f32,
|
||||||
|
state: *mut f32,
|
||||||
|
output: *mut f32,
|
||||||
|
bh: i32,
|
||||||
|
seq_len: i32,
|
||||||
|
k_dim: i32,
|
||||||
|
v_dim: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
|
||||||
|
/// Chunked GDN recurrence for prefill (processes tokens in BT=64 chunks).
|
||||||
|
pub(crate) fn chunked_gated_delta_rule_recurrence(
|
||||||
|
q: *const f32,
|
||||||
|
k: *const f32,
|
||||||
|
v: *const f32,
|
||||||
|
g: *const f32,
|
||||||
|
beta: *const f32,
|
||||||
|
state: *mut f32,
|
||||||
|
output: *mut f32,
|
||||||
|
bh: i32,
|
||||||
|
seq_len: i32,
|
||||||
|
k_dim: i32,
|
||||||
|
v_dim: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
|
||||||
|
pub(crate) fn causal_conv1d_update(
|
||||||
|
x: *const c_void,
|
||||||
|
weight: *const c_void,
|
||||||
|
conv_state: *mut c_void,
|
||||||
|
output: *mut c_void,
|
||||||
|
batch_size: i32,
|
||||||
|
conv_dim: i32,
|
||||||
|
kernel_size: i32,
|
||||||
|
dtype: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
|
||||||
|
pub(crate) fn causal_conv1d_full(
|
||||||
|
x: *const c_void,
|
||||||
|
weight: *const c_void,
|
||||||
|
conv_state_out: *mut c_void,
|
||||||
|
output: *mut c_void,
|
||||||
|
batch_size: i32,
|
||||||
|
conv_dim: i32,
|
||||||
|
seq_len: i32,
|
||||||
|
kernel_size: i32,
|
||||||
|
dtype: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
|
||||||
|
pub(crate) fn fused_gdn_gating(
|
||||||
|
b: *const c_void,
|
||||||
|
a: *const c_void,
|
||||||
|
a_log: *const f32,
|
||||||
|
dt_bias: *const f32,
|
||||||
|
beta_out: *mut c_void,
|
||||||
|
g_out: *mut c_void,
|
||||||
|
total_elements: i32,
|
||||||
|
num_heads: i32,
|
||||||
|
dtype: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
}
|
||||||
711
crates/neuron/src/cuda/gdn.cu
Normal file
711
crates/neuron/src/cuda/gdn.cu
Normal file
@@ -0,0 +1,711 @@
|
|||||||
|
// Gated DeltaNet CUDA kernels for Qwen3-Next (`model_type = "qwen3_5"`).
|
||||||
|
//
|
||||||
|
// Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
|
||||||
|
// Upstream path: mistralrs-core/src/cuda/gdn.cu. Local edits in this
|
||||||
|
// file are limited to this banner; the kernels are unchanged so a
|
||||||
|
// diff against upstream stays minimal.
|
||||||
|
//
|
||||||
|
// Five kernels exposed via `extern "C"` shims at the bottom:
|
||||||
|
// - gated_delta_rule_recurrence (per-token decode)
|
||||||
|
// - chunked_gated_delta_rule_recurrence (BT=64 chunked prefill)
|
||||||
|
// - causal_conv1d_update (single-token conv decode)
|
||||||
|
// - causal_conv1d_full (multi-token conv prefill)
|
||||||
|
// - fused_gdn_gating (beta = sigmoid(b);
|
||||||
|
// g = -exp(A_log) * softplus(a + dt_bias))
|
||||||
|
|
||||||
|
#include "cuda_bf16.h"
|
||||||
|
#include "cuda_fp16.h"
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 1: gated_delta_rule_recurrence (optimized)
|
||||||
|
//
|
||||||
|
// V-tiled recurrence with compile-time K dimension for register residency.
|
||||||
|
// Grid: (ceil(V/BV), B*H), Block: (BV,). Each thread owns BK registers of
|
||||||
|
// state. Shared memory holds k_buf and q_buf (2*BK floats).
|
||||||
|
//
|
||||||
|
// Optimizations over naive version:
|
||||||
|
// - Template BK -> float s[BK] lives in true registers (1 cycle vs ~30)
|
||||||
|
// - #pragma unroll on all k-loops -> full ILP
|
||||||
|
// - Fused decay+kv_mem pass and fused state_update+output pass
|
||||||
|
// - __fmaf_rn intrinsics for guaranteed fused multiply-add
|
||||||
|
// - BV=64 threads -> 2 warps, 6 blocks/SM on Ampere
|
||||||
|
//
|
||||||
|
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
|
||||||
|
// state: [BH, K, V] (in/out) output: [BH, S, V]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
// Optimized kernel: BK known at compile time -> registers + full unrolling
|
||||||
|
template <int BK, int BV>
|
||||||
|
__global__ void gated_delta_rule_recurrence_kernel_tiled(
|
||||||
|
const float *__restrict__ q, // [BH, S, K]
|
||||||
|
const float *__restrict__ k, // [BH, S, K]
|
||||||
|
const float *__restrict__ v, // [BH, S, V]
|
||||||
|
const float *__restrict__ g, // [BH, S]
|
||||||
|
const float *__restrict__ beta, // [BH, S]
|
||||||
|
float *__restrict__ state, // [BH, K, V]
|
||||||
|
float *__restrict__ output, // [BH, S, V]
|
||||||
|
int seq_len, int v_dim) {
|
||||||
|
|
||||||
|
const int v_tile = blockIdx.x; // which V-tile
|
||||||
|
const int bh = blockIdx.y; // batch*head index
|
||||||
|
const int tid = threadIdx.x; // thread within tile [0, BV)
|
||||||
|
const int v_idx = v_tile * BV + tid; // global V index
|
||||||
|
|
||||||
|
if (v_idx >= v_dim)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Pointers for this (batch, head)
|
||||||
|
const float *q_bh = q + bh * seq_len * BK;
|
||||||
|
const float *k_bh = k + bh * seq_len * BK;
|
||||||
|
const float *v_bh = v + bh * seq_len * v_dim;
|
||||||
|
const float *g_bh = g + bh * seq_len;
|
||||||
|
const float *beta_bh = beta + bh * seq_len;
|
||||||
|
float *state_bh = state + bh * BK * v_dim;
|
||||||
|
float *out_bh = output + bh * seq_len * v_dim;
|
||||||
|
|
||||||
|
// Shared memory: k_buf[BK] + q_buf[BK]
|
||||||
|
__shared__ float k_buf[BK];
|
||||||
|
__shared__ float q_buf[BK];
|
||||||
|
|
||||||
|
// Load state column into registers — BK is compile-time, so this is
|
||||||
|
// a true register array (not spilled to local memory)
|
||||||
|
float s[BK];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
s[j] = state_bh[j * v_dim + v_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int t = 0; t < seq_len; t++) {
|
||||||
|
// Collaboratively load k_t into shared memory
|
||||||
|
// BK / BV loads per thread (e.g. 128/64 = 2)
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = tid; j < BK; j += BV) {
|
||||||
|
k_buf[j] = k_bh[t * BK + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Load scalars for this timestep
|
||||||
|
float decay = expf(g_bh[t]);
|
||||||
|
float beta_t = beta_bh[t];
|
||||||
|
float v_t = v_bh[t * v_dim + v_idx];
|
||||||
|
|
||||||
|
// Fused pass 1: decay state + compute kv_mem
|
||||||
|
float kv_mem = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
s[j] *= decay;
|
||||||
|
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delta rule
|
||||||
|
float delta = (v_t - kv_mem) * beta_t;
|
||||||
|
|
||||||
|
// Collaboratively load q_t into shared memory
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = tid; j < BK; j += BV) {
|
||||||
|
q_buf[j] = q_bh[t * BK + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Fused pass 2: update state + compute output
|
||||||
|
float y_t = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
|
||||||
|
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
out_bh[t * v_dim + v_idx] = y_t;
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write state back
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
state_bh[j * v_dim + v_idx] = s[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback kernel: runtime k_dim, still V-tiled for occupancy
|
||||||
|
template <int BV, int MAX_K>
|
||||||
|
__global__ void gated_delta_rule_recurrence_kernel_fallback(
|
||||||
|
const float *__restrict__ q, const float *__restrict__ k,
|
||||||
|
const float *__restrict__ v, const float *__restrict__ g,
|
||||||
|
const float *__restrict__ beta, float *__restrict__ state,
|
||||||
|
float *__restrict__ output, int seq_len, int k_dim, int v_dim) {
|
||||||
|
|
||||||
|
const int v_tile = blockIdx.x;
|
||||||
|
const int bh = blockIdx.y;
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int v_idx = v_tile * BV + tid;
|
||||||
|
|
||||||
|
if (v_idx >= v_dim)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const float *q_bh = q + bh * seq_len * k_dim;
|
||||||
|
const float *k_bh = k + bh * seq_len * k_dim;
|
||||||
|
const float *v_bh = v + bh * seq_len * v_dim;
|
||||||
|
const float *g_bh = g + bh * seq_len;
|
||||||
|
const float *beta_bh = beta + bh * seq_len;
|
||||||
|
float *state_bh = state + bh * k_dim * v_dim;
|
||||||
|
float *out_bh = output + bh * seq_len * v_dim;
|
||||||
|
|
||||||
|
extern __shared__ float shared[];
|
||||||
|
float *k_buf = shared;
|
||||||
|
float *q_buf = shared + k_dim;
|
||||||
|
|
||||||
|
float s[MAX_K];
|
||||||
|
for (int j = 0; j < k_dim; j++) {
|
||||||
|
s[j] = state_bh[j * v_dim + v_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int t = 0; t < seq_len; t++) {
|
||||||
|
for (int j = tid; j < k_dim; j += BV) {
|
||||||
|
k_buf[j] = k_bh[t * k_dim + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float decay = expf(g_bh[t]);
|
||||||
|
float beta_t = beta_bh[t];
|
||||||
|
float v_t = v_bh[t * v_dim + v_idx];
|
||||||
|
|
||||||
|
float kv_mem = 0.0f;
|
||||||
|
for (int j = 0; j < k_dim; j++) {
|
||||||
|
s[j] *= decay;
|
||||||
|
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
|
||||||
|
}
|
||||||
|
|
||||||
|
float delta = (v_t - kv_mem) * beta_t;
|
||||||
|
|
||||||
|
for (int j = tid; j < k_dim; j += BV) {
|
||||||
|
q_buf[j] = q_bh[t * k_dim + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float y_t = 0.0f;
|
||||||
|
for (int j = 0; j < k_dim; j++) {
|
||||||
|
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
|
||||||
|
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
out_bh[t * v_dim + v_idx] = y_t;
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int j = 0; j < k_dim; j++) {
|
||||||
|
state_bh[j * v_dim + v_idx] = s[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void gated_delta_rule_recurrence(const float *q, const float *k,
|
||||||
|
const float *v, const float *g,
|
||||||
|
const float *beta, float *state,
|
||||||
|
float *output, int bh, int seq_len,
|
||||||
|
int k_dim, int v_dim,
|
||||||
|
int64_t stream) {
|
||||||
|
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
|
||||||
|
if (k_dim == 128) {
|
||||||
|
// Fast path for Qwen3-Next (k_dim=128)
|
||||||
|
constexpr int BK = 128;
|
||||||
|
constexpr int BV = 64;
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
|
||||||
|
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
|
||||||
|
v_dim);
|
||||||
|
} else if (k_dim == 64) {
|
||||||
|
// Fast path for models with k_dim=64
|
||||||
|
constexpr int BK = 64;
|
||||||
|
constexpr int BV = 64;
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
|
||||||
|
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
|
||||||
|
v_dim);
|
||||||
|
} else {
|
||||||
|
// Fallback for other k_dim values (runtime loop, still V-tiled)
|
||||||
|
constexpr int BV = 64;
|
||||||
|
constexpr int MAX_K = 256;
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
size_t smem = 2 * k_dim * sizeof(float);
|
||||||
|
gated_delta_rule_recurrence_kernel_fallback<BV, MAX_K>
|
||||||
|
<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
|
||||||
|
seq_len, k_dim, v_dim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 1b: chunked_gated_delta_rule_recurrence (prefill optimization)
|
||||||
|
//
|
||||||
|
// Processes prefill tokens in BT-token chunks instead of one at a time.
|
||||||
|
// Within each chunk: parallel prefix sum of g, cooperative kk_dot computation,
|
||||||
|
// forward substitution (triangular solve), output computation, and state
|
||||||
|
// update.
|
||||||
|
//
|
||||||
|
// Same thread model as Kernel 1: one block per (v_tile, batch*head),
|
||||||
|
// one thread per V-column. Each thread owns BK registers of state.
|
||||||
|
//
|
||||||
|
// Shared memory holds:
|
||||||
|
// k_chunk[BT * BK] -- key vectors for current chunk
|
||||||
|
// kk_dot[BT * BT] -- dot(k[i], k[j]) lower-triangular matrix
|
||||||
|
// gcum[BT] -- cumulative sum of g within chunk
|
||||||
|
// beta_s[BT] -- beta values for chunk
|
||||||
|
// q_buf[BK] -- q vector (loaded one row at a time)
|
||||||
|
//
|
||||||
|
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
|
||||||
|
// state: [BH, K, V] (in/out) output: [BH, S, V]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
template <int BT, int BK, int BV>
|
||||||
|
__global__ void
|
||||||
|
chunked_gated_delta_rule_kernel(const float *__restrict__ q, // [BH, S, K]
|
||||||
|
const float *__restrict__ k, // [BH, S, K]
|
||||||
|
const float *__restrict__ v, // [BH, S, V]
|
||||||
|
const float *__restrict__ g, // [BH, S]
|
||||||
|
const float *__restrict__ beta, // [BH, S]
|
||||||
|
float *__restrict__ state, // [BH, K, V]
|
||||||
|
float *__restrict__ output, // [BH, S, V]
|
||||||
|
int seq_len, int v_dim) {
|
||||||
|
|
||||||
|
const int v_tile = blockIdx.x;
|
||||||
|
const int bh = blockIdx.y;
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int v_idx = v_tile * BV + tid;
|
||||||
|
|
||||||
|
if (v_idx >= v_dim)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const int num_chunks = (seq_len + BT - 1) / BT;
|
||||||
|
|
||||||
|
// Pointers for this (batch, head)
|
||||||
|
const float *q_bh = q + bh * seq_len * BK;
|
||||||
|
const float *k_bh = k + bh * seq_len * BK;
|
||||||
|
const float *v_bh = v + bh * seq_len * v_dim;
|
||||||
|
const float *g_bh = g + bh * seq_len;
|
||||||
|
const float *beta_bh = beta + bh * seq_len;
|
||||||
|
float *state_bh = state + bh * BK * v_dim;
|
||||||
|
float *out_bh = output + bh * seq_len * v_dim;
|
||||||
|
|
||||||
|
// Dynamic shared memory layout
|
||||||
|
extern __shared__ float smem[];
|
||||||
|
float *k_chunk = smem; // [BT * BK]
|
||||||
|
float *kk_dot = smem + BT * BK; // [BT * BT]
|
||||||
|
float *gcum = smem + BT * BK + BT * BT; // [BT]
|
||||||
|
float *beta_s = gcum + BT; // [BT]
|
||||||
|
float *q_buf = beta_s + BT; // [BK]
|
||||||
|
|
||||||
|
// Load state column into registers
|
||||||
|
float s[BK];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
s[j] = state_bh[j * v_dim + v_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per-thread register array for corrected deltas
|
||||||
|
float delta[BT];
|
||||||
|
|
||||||
|
for (int c = 0; c < num_chunks; c++) {
|
||||||
|
const int chunk_start = c * BT;
|
||||||
|
const int chunk_len = min(BT, seq_len - chunk_start);
|
||||||
|
|
||||||
|
// === Phase 1: Cooperative load of k, beta, g into shared memory ===
|
||||||
|
for (int t = 0; t < chunk_len; t++) {
|
||||||
|
for (int j = tid; j < BK; j += BV) {
|
||||||
|
k_chunk[t * BK + j] = k_bh[(chunk_start + t) * BK + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (tid < chunk_len) {
|
||||||
|
beta_s[tid] = beta_bh[chunk_start + tid];
|
||||||
|
gcum[tid] = g_bh[chunk_start + tid];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// === Phase 1b: Parallel prefix sum of g (Hillis-Steele) ===
|
||||||
|
for (int stride = 1; stride < BT; stride <<= 1) {
|
||||||
|
float prev = 0.0f;
|
||||||
|
if (tid < chunk_len && (int)tid >= stride)
|
||||||
|
prev = gcum[tid - stride];
|
||||||
|
__syncthreads();
|
||||||
|
if (tid < chunk_len && (int)tid >= stride)
|
||||||
|
gcum[tid] += prev;
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// === Phase 2: Compute kk_dot[i][j] = dot(k[i], k[j]) for j < i ===
|
||||||
|
// Only lower-triangular entries needed (strictly lower)
|
||||||
|
for (int idx = tid; idx < chunk_len * chunk_len; idx += BV) {
|
||||||
|
int i = idx / chunk_len;
|
||||||
|
int j = idx % chunk_len;
|
||||||
|
if (j < i) {
|
||||||
|
float dot = 0.0f;
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
dot = __fmaf_rn(k_chunk[i * BK + d], k_chunk[j * BK + d], dot);
|
||||||
|
}
|
||||||
|
kk_dot[i * BT + j] = dot;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// === Phase 3: Forward substitution (per V-column, in registers) ===
|
||||||
|
// Computes corrected delta values via triangular solve
|
||||||
|
for (int i = 0; i < chunk_len; i++) {
|
||||||
|
float v_i = v_bh[(chunk_start + i) * v_dim + v_idx];
|
||||||
|
float decay_i = expf(gcum[i]);
|
||||||
|
float beta_i = beta_s[i];
|
||||||
|
|
||||||
|
// Inter-chunk contribution: state @ k[i] with decay
|
||||||
|
float kv_mem = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
kv_mem = __fmaf_rn(s[d] * decay_i, k_chunk[i * BK + d], kv_mem);
|
||||||
|
}
|
||||||
|
|
||||||
|
float rhs = beta_i * (v_i - kv_mem);
|
||||||
|
|
||||||
|
// Subtract lower-triangular contributions (intra-chunk)
|
||||||
|
for (int j = 0; j < i; j++) {
|
||||||
|
float a_ij = beta_i * kk_dot[i * BT + j] * expf(gcum[i] - gcum[j]);
|
||||||
|
rhs -= a_ij * delta[j];
|
||||||
|
}
|
||||||
|
delta[i] = rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
// === Phase 4: Output computation (per V-column) ===
|
||||||
|
for (int i = 0; i < chunk_len; i++) {
|
||||||
|
// Cooperatively load q[i] into shared
|
||||||
|
for (int j = tid; j < BK; j += BV) {
|
||||||
|
q_buf[j] = q_bh[(chunk_start + i) * BK + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float decay_i = expf(gcum[i]);
|
||||||
|
|
||||||
|
// Inter-chunk contribution: q[i] @ (state * decay)
|
||||||
|
float o_val = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
o_val = __fmaf_rn(q_buf[d], s[d] * decay_i, o_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intra-chunk contribution: sum_{j<=i} dot(q[i], k[j]) * delta[j] *
|
||||||
|
// exp(gcum[i] - gcum[j])
|
||||||
|
for (int j = 0; j <= i; j++) {
|
||||||
|
float qk_dot = 0.0f;
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
qk_dot = __fmaf_rn(q_buf[d], k_chunk[j * BK + d], qk_dot);
|
||||||
|
}
|
||||||
|
o_val += qk_dot * delta[j] * expf(gcum[i] - gcum[j]);
|
||||||
|
}
|
||||||
|
|
||||||
|
out_bh[(chunk_start + i) * v_dim + v_idx] = o_val;
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// === Phase 5: State update for next chunk ===
|
||||||
|
float g_total = gcum[chunk_len - 1];
|
||||||
|
#pragma unroll
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
float s_new = s[d] * expf(g_total);
|
||||||
|
for (int t = 0; t < chunk_len; t++) {
|
||||||
|
s_new += k_chunk[t * BK + d] * delta[t] * expf(g_total - gcum[t]);
|
||||||
|
}
|
||||||
|
s[d] = s_new;
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write final state back
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
state_bh[j * v_dim + v_idx] = s[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void chunked_gated_delta_rule_recurrence(
|
||||||
|
const float *q, const float *k, const float *v, const float *g,
|
||||||
|
const float *beta, float *state, float *output, int bh, int seq_len,
|
||||||
|
int k_dim, int v_dim, int64_t stream) {
|
||||||
|
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
|
||||||
|
if (k_dim == 128) {
|
||||||
|
constexpr int BT = 64;
|
||||||
|
constexpr int BK = 128;
|
||||||
|
constexpr int BV = 64;
|
||||||
|
// Shared memory: BT*BK + BT*BT + BT + BT + BK floats
|
||||||
|
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
|
||||||
|
|
||||||
|
// Request extended shared memory
|
||||||
|
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
|
||||||
|
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||||
|
smem);
|
||||||
|
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
|
||||||
|
seq_len, v_dim);
|
||||||
|
} else if (k_dim == 64) {
|
||||||
|
constexpr int BT = 64;
|
||||||
|
constexpr int BK = 64;
|
||||||
|
constexpr int BV = 64;
|
||||||
|
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
|
||||||
|
|
||||||
|
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
|
||||||
|
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||||
|
smem);
|
||||||
|
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
|
||||||
|
seq_len, v_dim);
|
||||||
|
} else {
|
||||||
|
// Fallback: use the sequential kernel for unsupported k_dim
|
||||||
|
gated_delta_rule_recurrence(q, k, v, g, beta, state, output, bh, seq_len,
|
||||||
|
k_dim, v_dim, stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 2a: causal_conv1d_update (decode path, single step)
|
||||||
|
//
|
||||||
|
// Each thread handles one channel: shift conv_state left by 1,
|
||||||
|
// insert new value, dot product with weight, apply SiLU.
|
||||||
|
//
|
||||||
|
// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
|
||||||
|
// conv_state: [B, conv_dim, kernel_size] (in/out)
|
||||||
|
// output: [B, conv_dim, 1]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void causal_conv1d_update_kernel(
|
||||||
|
const T *__restrict__ x, // [B, conv_dim, 1]
|
||||||
|
const T *__restrict__ weight, // [conv_dim, kernel_size]
|
||||||
|
T *__restrict__ conv_state, // [B, conv_dim, kernel_size]
|
||||||
|
T *__restrict__ output, // [B, conv_dim, 1]
|
||||||
|
int batch_size, int conv_dim, int kernel_size) {
|
||||||
|
|
||||||
|
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const int b = blockIdx.y;
|
||||||
|
|
||||||
|
if (ch >= conv_dim || b >= batch_size)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Pointer to this batch/channel's conv state
|
||||||
|
T *cs = conv_state + (b * conv_dim + ch) * kernel_size;
|
||||||
|
const T *w = weight + ch * kernel_size;
|
||||||
|
|
||||||
|
// Shift state left by 1
|
||||||
|
for (int i = 0; i < kernel_size - 1; i++) {
|
||||||
|
cs[i] = cs[i + 1];
|
||||||
|
}
|
||||||
|
// Insert new value
|
||||||
|
cs[kernel_size - 1] = x[b * conv_dim + ch];
|
||||||
|
|
||||||
|
// Dot product with weight
|
||||||
|
float acc = 0.0f;
|
||||||
|
for (int i = 0; i < kernel_size; i++) {
|
||||||
|
acc += (float)cs[i] * (float)w[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// SiLU activation: x * sigmoid(x)
|
||||||
|
float sig = 1.0f / (1.0f + expf(-acc));
|
||||||
|
float result = acc * sig;
|
||||||
|
|
||||||
|
output[b * conv_dim + ch] = (T)result;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void causal_conv1d_update(const void *x, const void *weight,
|
||||||
|
void *conv_state, void *output,
|
||||||
|
int batch_size, int conv_dim,
|
||||||
|
int kernel_size, int dtype,
|
||||||
|
int64_t stream) {
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
dim3 block(256);
|
||||||
|
dim3 grid((conv_dim + 255) / 256, batch_size);
|
||||||
|
|
||||||
|
if (dtype == 0) {
|
||||||
|
// f16
|
||||||
|
causal_conv1d_update_kernel<__half><<<grid, block, 0, custream>>>(
|
||||||
|
(const __half *)x, (const __half *)weight, (__half *)conv_state,
|
||||||
|
(__half *)output, batch_size, conv_dim, kernel_size);
|
||||||
|
} else {
|
||||||
|
// bf16
|
||||||
|
causal_conv1d_update_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
|
||||||
|
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
|
||||||
|
(__nv_bfloat16 *)conv_state, (__nv_bfloat16 *)output, batch_size,
|
||||||
|
conv_dim, kernel_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 2b: causal_conv1d_full (prefill path)
|
||||||
|
//
|
||||||
|
// Each thread handles one (channel, position): causal window with
|
||||||
|
// zero-padding, dot product with weight, SiLU.
|
||||||
|
// A second pass writes the conv_state from the last kernel_size positions.
|
||||||
|
//
|
||||||
|
// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
|
||||||
|
// conv_state_out: [B, conv_dim, kernel_size] output: [B, conv_dim, S]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void causal_conv1d_full_kernel(
|
||||||
|
const T *__restrict__ x, // [B, conv_dim, S]
|
||||||
|
const T *__restrict__ weight, // [conv_dim, kernel_size]
|
||||||
|
T *__restrict__ output, // [B, conv_dim, S]
|
||||||
|
int batch_size, int conv_dim, int seq_len, int kernel_size) {
|
||||||
|
|
||||||
|
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const int pos = blockIdx.y;
|
||||||
|
const int b = blockIdx.z;
|
||||||
|
|
||||||
|
if (ch >= conv_dim || pos >= seq_len || b >= batch_size)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
|
||||||
|
const T *w = weight + ch * kernel_size;
|
||||||
|
|
||||||
|
// Causal convolution: sum over kernel_size window ending at pos
|
||||||
|
float acc = 0.0f;
|
||||||
|
for (int i = 0; i < kernel_size; i++) {
|
||||||
|
int src_pos = pos - (kernel_size - 1) + i;
|
||||||
|
float x_val = (src_pos >= 0) ? (float)x_bch[src_pos] : 0.0f;
|
||||||
|
acc += x_val * (float)w[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// SiLU
|
||||||
|
float sig = 1.0f / (1.0f + expf(-acc));
|
||||||
|
float result = acc * sig;
|
||||||
|
|
||||||
|
output[(b * conv_dim + ch) * seq_len + pos] = (T)result;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void save_conv_state_kernel(
|
||||||
|
const T *__restrict__ x, // [B, conv_dim, S]
|
||||||
|
T *__restrict__ conv_state_out, // [B, conv_dim, kernel_size]
|
||||||
|
int batch_size, int conv_dim, int seq_len, int kernel_size) {
|
||||||
|
|
||||||
|
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const int b = blockIdx.y;
|
||||||
|
|
||||||
|
if (ch >= conv_dim || b >= batch_size)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
|
||||||
|
T *cs = conv_state_out + (b * conv_dim + ch) * kernel_size;
|
||||||
|
|
||||||
|
// Save last kernel_size positions (zero-pad if seq_len < kernel_size)
|
||||||
|
int pad = kernel_size - seq_len;
|
||||||
|
for (int i = 0; i < kernel_size; i++) {
|
||||||
|
if (i < pad) {
|
||||||
|
cs[i] = (T)0.0f;
|
||||||
|
} else {
|
||||||
|
cs[i] = x_bch[seq_len - kernel_size + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void causal_conv1d_full(const void *x, const void *weight,
|
||||||
|
void *conv_state_out, void *output,
|
||||||
|
int batch_size, int conv_dim, int seq_len,
|
||||||
|
int kernel_size, int dtype, int64_t stream) {
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
|
||||||
|
// Main convolution kernel
|
||||||
|
dim3 block(256);
|
||||||
|
dim3 grid((conv_dim + 255) / 256, seq_len, batch_size);
|
||||||
|
|
||||||
|
if (dtype == 0) {
|
||||||
|
causal_conv1d_full_kernel<__half><<<grid, block, 0, custream>>>(
|
||||||
|
(const __half *)x, (const __half *)weight, (__half *)output, batch_size,
|
||||||
|
conv_dim, seq_len, kernel_size);
|
||||||
|
// Save conv state
|
||||||
|
dim3 grid2((conv_dim + 255) / 256, batch_size);
|
||||||
|
save_conv_state_kernel<__half><<<grid2, block, 0, custream>>>(
|
||||||
|
(const __half *)x, (__half *)conv_state_out, batch_size, conv_dim,
|
||||||
|
seq_len, kernel_size);
|
||||||
|
} else {
|
||||||
|
causal_conv1d_full_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
|
||||||
|
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
|
||||||
|
(__nv_bfloat16 *)output, batch_size, conv_dim, seq_len, kernel_size);
|
||||||
|
dim3 grid2((conv_dim + 255) / 256, batch_size);
|
||||||
|
save_conv_state_kernel<__nv_bfloat16><<<grid2, block, 0, custream>>>(
|
||||||
|
(const __nv_bfloat16 *)x, (__nv_bfloat16 *)conv_state_out, batch_size,
|
||||||
|
conv_dim, seq_len, kernel_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 3: fused_gdn_gating
|
||||||
|
//
|
||||||
|
// Fuses: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
|
||||||
|
// a_log and dt_bias are per-head (broadcast over batch*seq).
|
||||||
|
//
|
||||||
|
// b, a: [total] a_log, dt_bias: [num_heads]
|
||||||
|
// beta_out, g_out: [total]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void
|
||||||
|
fused_gdn_gating_kernel(const T *__restrict__ b, // [total]
|
||||||
|
const T *__restrict__ a, // [total]
|
||||||
|
const float *__restrict__ a_log, // [num_heads]
|
||||||
|
const float *__restrict__ dt_bias, // [num_heads]
|
||||||
|
T *__restrict__ beta_out, // [total]
|
||||||
|
T *__restrict__ g_out, // [total]
|
||||||
|
int total_elements, int num_heads) {
|
||||||
|
|
||||||
|
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx >= total_elements)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Head index: elements are laid out as [..., num_heads]
|
||||||
|
int head_idx = idx % num_heads;
|
||||||
|
|
||||||
|
// beta = sigmoid(b)
|
||||||
|
float b_val = (float)b[idx];
|
||||||
|
float beta = 1.0f / (1.0f + expf(-b_val));
|
||||||
|
|
||||||
|
// g = -exp(a_log) * softplus(a + dt_bias)
|
||||||
|
float a_val = (float)a[idx];
|
||||||
|
float a_log_val = a_log[head_idx];
|
||||||
|
float dt_bias_val = dt_bias[head_idx];
|
||||||
|
|
||||||
|
float sp_input = a_val + dt_bias_val;
|
||||||
|
float softplus_val = logf(1.0f + expf(sp_input));
|
||||||
|
float g_val = -expf(a_log_val) * softplus_val;
|
||||||
|
|
||||||
|
beta_out[idx] = (T)beta;
|
||||||
|
g_out[idx] = (T)g_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void fused_gdn_gating(const void *b, const void *a,
|
||||||
|
const float *a_log, const float *dt_bias,
|
||||||
|
void *beta_out, void *g_out,
|
||||||
|
int total_elements, int num_heads, int dtype,
|
||||||
|
int64_t stream) {
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
dim3 block(256);
|
||||||
|
dim3 grid((total_elements + 255) / 256);
|
||||||
|
|
||||||
|
if (dtype == 0) {
|
||||||
|
fused_gdn_gating_kernel<__half><<<grid, block, 0, custream>>>(
|
||||||
|
(const __half *)b, (const __half *)a, a_log, dt_bias,
|
||||||
|
(__half *)beta_out, (__half *)g_out, total_elements, num_heads);
|
||||||
|
} else {
|
||||||
|
fused_gdn_gating_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
|
||||||
|
(const __nv_bfloat16 *)b, (const __nv_bfloat16 *)a, a_log, dt_bias,
|
||||||
|
(__nv_bfloat16 *)beta_out, (__nv_bfloat16 *)g_out, total_elements,
|
||||||
|
num_heads);
|
||||||
|
}
|
||||||
|
}
|
||||||
486
crates/neuron/src/cuda/gdn.rs
Normal file
486
crates/neuron/src/cuda/gdn.rs
Normal file
@@ -0,0 +1,486 @@
|
|||||||
|
//! Rust wrappers around the Gated DeltaNet CUDA kernels in `gdn.cu`.
|
||||||
|
//!
|
||||||
|
//! Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
|
||||||
|
//! Upstream path: `mistralrs-core/src/cuda/gdn.rs`. The only edits in
|
||||||
|
//! this file are this header comment — the FFI path module name is
|
||||||
|
//! `crate::cuda::ffi`, identical to upstream's layout.
|
||||||
|
|
||||||
|
#![allow(clippy::cast_possible_truncation)]
|
||||||
|
|
||||||
|
use candle_core::{Result, Tensor};
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use candle_core::DType;
|
||||||
|
|
||||||
|
/// CUDA-accelerated gated delta rule recurrence.
|
||||||
|
///
|
||||||
|
/// Inputs (all contiguous, f32):
|
||||||
|
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
|
||||||
|
/// state: [BH, K, V] (mutated in place)
|
||||||
|
///
|
||||||
|
/// Returns: output [BH, S, V]
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn gated_delta_rule_recurrence_cuda(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
state: &mut Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||||
|
use candle_core as candle;
|
||||||
|
|
||||||
|
let (bh, seq_len, k_dim) = q.dims3()?;
|
||||||
|
let v_dim = v.dim(2)?;
|
||||||
|
|
||||||
|
let dev = q.device().as_cuda_device()?;
|
||||||
|
|
||||||
|
let (q_s, q_l) = q.storage_and_layout();
|
||||||
|
let q_s = match &*q_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("q must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let q_offset = q_l.start_offset();
|
||||||
|
|
||||||
|
let (k_s, k_l) = k.storage_and_layout();
|
||||||
|
let k_s = match &*k_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("k must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let k_offset = k_l.start_offset();
|
||||||
|
|
||||||
|
let (v_s, v_l) = v.storage_and_layout();
|
||||||
|
let v_s = match &*v_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("v must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let v_offset = v_l.start_offset();
|
||||||
|
|
||||||
|
let (g_s, g_l) = g.storage_and_layout();
|
||||||
|
let g_s = match &*g_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("g must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let g_offset = g_l.start_offset();
|
||||||
|
|
||||||
|
let (beta_s, beta_l) = beta.storage_and_layout();
|
||||||
|
let beta_s = match &*beta_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("beta must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let beta_offset = beta_l.start_offset();
|
||||||
|
|
||||||
|
let (state_s, state_l) = state.storage_and_layout();
|
||||||
|
let state_s = match &*state_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("state must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let state_offset = state_l.start_offset();
|
||||||
|
|
||||||
|
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::gated_delta_rule_recurrence(
|
||||||
|
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
|
||||||
|
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
|
||||||
|
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
|
||||||
|
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
|
||||||
|
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
|
||||||
|
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
|
||||||
|
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
|
||||||
|
bh as i32,
|
||||||
|
seq_len as i32,
|
||||||
|
k_dim as i32,
|
||||||
|
v_dim as i32,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The kernel wrote state in-place via the raw pointer; rewrap
|
||||||
|
// (state tensor's underlying CudaSlice was modified directly)
|
||||||
|
|
||||||
|
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||||
|
Ok(Tensor::from((
|
||||||
|
candle::Storage::Cuda(output_storage),
|
||||||
|
(bh, seq_len, v_dim),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn gated_delta_rule_recurrence_cuda(
|
||||||
|
_q: &Tensor,
|
||||||
|
_k: &Tensor,
|
||||||
|
_v: &Tensor,
|
||||||
|
_g: &Tensor,
|
||||||
|
_beta: &Tensor,
|
||||||
|
_state: &mut Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
candle_core::bail!("gated_delta_rule_recurrence_cuda requires the cuda feature")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CUDA-accelerated chunked gated delta rule recurrence (prefill optimization).
|
||||||
|
///
|
||||||
|
/// Processes prefill tokens in 64-token chunks instead of one at a time.
|
||||||
|
/// Same interface as `gated_delta_rule_recurrence_cuda`.
|
||||||
|
///
|
||||||
|
/// Inputs (all contiguous, f32):
|
||||||
|
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
|
||||||
|
/// state: [BH, K, V] (mutated in place)
|
||||||
|
///
|
||||||
|
/// Returns: output [BH, S, V]
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn chunked_gated_delta_rule_recurrence_cuda(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
state: &mut Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||||
|
use candle_core as candle;
|
||||||
|
|
||||||
|
let (bh, seq_len, k_dim) = q.dims3()?;
|
||||||
|
let v_dim = v.dim(2)?;
|
||||||
|
|
||||||
|
let dev = q.device().as_cuda_device()?;
|
||||||
|
|
||||||
|
let (q_s, q_l) = q.storage_and_layout();
|
||||||
|
let q_s = match &*q_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("q must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let q_offset = q_l.start_offset();
|
||||||
|
|
||||||
|
let (k_s, k_l) = k.storage_and_layout();
|
||||||
|
let k_s = match &*k_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("k must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let k_offset = k_l.start_offset();
|
||||||
|
|
||||||
|
let (v_s, v_l) = v.storage_and_layout();
|
||||||
|
let v_s = match &*v_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("v must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let v_offset = v_l.start_offset();
|
||||||
|
|
||||||
|
let (g_s, g_l) = g.storage_and_layout();
|
||||||
|
let g_s = match &*g_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("g must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let g_offset = g_l.start_offset();
|
||||||
|
|
||||||
|
let (beta_s, beta_l) = beta.storage_and_layout();
|
||||||
|
let beta_s = match &*beta_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("beta must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let beta_offset = beta_l.start_offset();
|
||||||
|
|
||||||
|
let (state_s, state_l) = state.storage_and_layout();
|
||||||
|
let state_s = match &*state_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("state must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let state_offset = state_l.start_offset();
|
||||||
|
|
||||||
|
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::chunked_gated_delta_rule_recurrence(
|
||||||
|
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
|
||||||
|
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
|
||||||
|
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
|
||||||
|
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
|
||||||
|
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
|
||||||
|
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
|
||||||
|
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
|
||||||
|
bh as i32,
|
||||||
|
seq_len as i32,
|
||||||
|
k_dim as i32,
|
||||||
|
v_dim as i32,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||||
|
Ok(Tensor::from((
|
||||||
|
candle::Storage::Cuda(output_storage),
|
||||||
|
(bh, seq_len, v_dim),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn chunked_gated_delta_rule_recurrence_cuda(
|
||||||
|
_q: &Tensor,
|
||||||
|
_k: &Tensor,
|
||||||
|
_v: &Tensor,
|
||||||
|
_g: &Tensor,
|
||||||
|
_beta: &Tensor,
|
||||||
|
_state: &mut Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
candle_core::bail!("chunked_gated_delta_rule_recurrence_cuda requires the cuda feature")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CUDA-accelerated causal conv1d (both update and full paths).
|
||||||
|
///
|
||||||
|
/// For update (is_update=true):
|
||||||
|
/// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
|
||||||
|
/// conv_state: [B, conv_dim, kernel_size] (mutated in place for update)
|
||||||
|
/// Returns: (output [B, conv_dim, 1], updated conv_state)
|
||||||
|
///
|
||||||
|
/// For full (is_update=false):
|
||||||
|
/// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
|
||||||
|
/// Returns: (output [B, conv_dim, S], new conv_state [B, conv_dim, kernel_size])
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn causal_conv1d_cuda(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: &Tensor,
|
||||||
|
kernel_size: usize,
|
||||||
|
is_update: bool,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||||
|
use candle_core as candle;
|
||||||
|
use core::ffi::c_void;
|
||||||
|
fn cuda_fwd<
|
||||||
|
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
|
||||||
|
>(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: &Tensor,
|
||||||
|
kernel_size: usize,
|
||||||
|
is_update: bool,
|
||||||
|
dtype_code: i32,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let dev = x.device().as_cuda_device()?;
|
||||||
|
let (batch_size, conv_dim, seq_len) = x.dims3()?;
|
||||||
|
|
||||||
|
let (x_s, x_l) = x.storage_and_layout();
|
||||||
|
let x_s = match &*x_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("x must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let x_offset = x_l.start_offset();
|
||||||
|
|
||||||
|
let (w_s, w_l) = weight.storage_and_layout();
|
||||||
|
let w_s = match &*w_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("weight must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let w_offset = w_l.start_offset();
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||||
|
|
||||||
|
if is_update {
|
||||||
|
// Clone conv_state so the kernel can mutate it in place
|
||||||
|
let conv_state_new = conv_state.clone();
|
||||||
|
|
||||||
|
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim) }?;
|
||||||
|
|
||||||
|
// Scope the borrow of conv_state_new so we can move it later
|
||||||
|
{
|
||||||
|
let (cs_s, cs_l) = conv_state_new.storage_and_layout();
|
||||||
|
let cs_s = match &*cs_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("conv_state must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let cs_offset = cs_l.start_offset();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::causal_conv1d_update(
|
||||||
|
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
|
||||||
|
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
|
||||||
|
cs_s.slice(cs_offset..).device_ptr(cs_s.stream()).0 as *mut c_void,
|
||||||
|
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
|
||||||
|
batch_size as i32,
|
||||||
|
conv_dim as i32,
|
||||||
|
kernel_size as i32,
|
||||||
|
dtype_code,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||||
|
let output = Tensor::from((
|
||||||
|
candle::Storage::Cuda(output_storage),
|
||||||
|
(batch_size, conv_dim, 1usize),
|
||||||
|
));
|
||||||
|
|
||||||
|
Ok((output, conv_state_new))
|
||||||
|
} else {
|
||||||
|
// Full path: allocate new conv_state and output
|
||||||
|
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * seq_len) }?;
|
||||||
|
let cs_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * kernel_size) }?;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::causal_conv1d_full(
|
||||||
|
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
|
||||||
|
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
|
||||||
|
cs_buf.device_ptr(cs_buf.stream()).0 as *mut c_void,
|
||||||
|
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
|
||||||
|
batch_size as i32,
|
||||||
|
conv_dim as i32,
|
||||||
|
seq_len as i32,
|
||||||
|
kernel_size as i32,
|
||||||
|
dtype_code,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||||
|
let output = Tensor::from((
|
||||||
|
candle::Storage::Cuda(output_storage),
|
||||||
|
(batch_size, conv_dim, seq_len),
|
||||||
|
));
|
||||||
|
|
||||||
|
let cs_storage = candle::CudaStorage::wrap_cuda_slice(cs_buf, dev.clone());
|
||||||
|
let new_conv_state = Tensor::from((
|
||||||
|
candle::Storage::Cuda(cs_storage),
|
||||||
|
(batch_size, conv_dim, kernel_size),
|
||||||
|
));
|
||||||
|
|
||||||
|
Ok((output, new_conv_state))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match x.dtype() {
|
||||||
|
DType::F16 => cuda_fwd::<half::f16>(x, weight, conv_state, kernel_size, is_update, 0),
|
||||||
|
DType::BF16 => cuda_fwd::<half::bf16>(x, weight, conv_state, kernel_size, is_update, 1),
|
||||||
|
other => candle_core::bail!("causal_conv1d_cuda only supports f16/bf16, got {:?}", other),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn causal_conv1d_cuda(
|
||||||
|
_x: &Tensor,
|
||||||
|
_weight: &Tensor,
|
||||||
|
_conv_state: &Tensor,
|
||||||
|
_kernel_size: usize,
|
||||||
|
_is_update: bool,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
candle_core::bail!("causal_conv1d_cuda requires the cuda feature")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CUDA-accelerated fused GDN gating computation.
|
||||||
|
///
|
||||||
|
/// Computes: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
|
||||||
|
///
|
||||||
|
/// b, a: [total_elements] in f16/bf16
|
||||||
|
/// a_log, dt_bias: [num_heads] in f32
|
||||||
|
///
|
||||||
|
/// Returns: (beta, g) in original dtype
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn fused_gdn_gating_cuda(
|
||||||
|
b: &Tensor,
|
||||||
|
a: &Tensor,
|
||||||
|
a_log: &Tensor,
|
||||||
|
dt_bias: &Tensor,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||||
|
use candle_core as candle;
|
||||||
|
use core::ffi::c_void;
|
||||||
|
|
||||||
|
fn cuda_fwd<
|
||||||
|
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
|
||||||
|
>(
|
||||||
|
b: &Tensor,
|
||||||
|
a: &Tensor,
|
||||||
|
a_log: &Tensor,
|
||||||
|
dt_bias: &Tensor,
|
||||||
|
dtype_code: i32,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let total_elements = b.elem_count();
|
||||||
|
let num_heads = a_log.elem_count();
|
||||||
|
let shape = b.shape().clone();
|
||||||
|
let dev = b.device().as_cuda_device()?;
|
||||||
|
|
||||||
|
let (b_s, b_l) = b.storage_and_layout();
|
||||||
|
let b_s = match &*b_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("b must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let b_offset = b_l.start_offset();
|
||||||
|
|
||||||
|
let (a_s, a_l) = a.storage_and_layout();
|
||||||
|
let a_s = match &*a_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("a must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let a_offset = a_l.start_offset();
|
||||||
|
|
||||||
|
let (alog_s, alog_l) = a_log.storage_and_layout();
|
||||||
|
let alog_s = match &*alog_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("a_log must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let alog_offset = alog_l.start_offset();
|
||||||
|
|
||||||
|
let (dtb_s, dtb_l) = dt_bias.storage_and_layout();
|
||||||
|
let dtb_s = match &*dtb_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("dt_bias must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let dtb_offset = dtb_l.start_offset();
|
||||||
|
|
||||||
|
let beta_buf = unsafe { dev.alloc::<T>(total_elements) }?;
|
||||||
|
let g_buf = unsafe { dev.alloc::<T>(total_elements) }?;
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::fused_gdn_gating(
|
||||||
|
b_s.slice(b_offset..).device_ptr(b_s.stream()).0 as *const c_void,
|
||||||
|
a_s.slice(a_offset..).device_ptr(a_s.stream()).0 as *const c_void,
|
||||||
|
alog_s.slice(alog_offset..).device_ptr(alog_s.stream()).0 as *const f32,
|
||||||
|
dtb_s.slice(dtb_offset..).device_ptr(dtb_s.stream()).0 as *const f32,
|
||||||
|
beta_buf.device_ptr(beta_buf.stream()).0 as *mut c_void,
|
||||||
|
g_buf.device_ptr(g_buf.stream()).0 as *mut c_void,
|
||||||
|
total_elements as i32,
|
||||||
|
num_heads as i32,
|
||||||
|
dtype_code,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let beta_storage = candle::CudaStorage::wrap_cuda_slice(beta_buf, dev.clone());
|
||||||
|
let beta = Tensor::from((candle::Storage::Cuda(beta_storage), shape.clone()));
|
||||||
|
|
||||||
|
let g_storage = candle::CudaStorage::wrap_cuda_slice(g_buf, dev.clone());
|
||||||
|
let g = Tensor::from((candle::Storage::Cuda(g_storage), shape));
|
||||||
|
|
||||||
|
Ok((beta, g))
|
||||||
|
}
|
||||||
|
|
||||||
|
match b.dtype() {
|
||||||
|
DType::F16 => cuda_fwd::<half::f16>(b, a, a_log, dt_bias, 0),
|
||||||
|
DType::BF16 => cuda_fwd::<half::bf16>(b, a, a_log, dt_bias, 1),
|
||||||
|
other => candle_core::bail!(
|
||||||
|
"fused_gdn_gating_cuda only supports f16/bf16, got {:?}",
|
||||||
|
other
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn fused_gdn_gating_cuda(
|
||||||
|
_b: &Tensor,
|
||||||
|
_a: &Tensor,
|
||||||
|
_a_log: &Tensor,
|
||||||
|
_dt_bias: &Tensor,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
candle_core::bail!("fused_gdn_gating_cuda requires the cuda feature")
|
||||||
|
}
|
||||||
15
crates/neuron/src/cuda/mod.rs
Normal file
15
crates/neuron/src/cuda/mod.rs
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//! CUDA kernels and their Rust wrappers.
|
||||||
|
//!
|
||||||
|
//! Currently scoped to what we need for Qwen3-Next (`qwen3_5`)
|
||||||
|
//! inference performance — the Gated DeltaNet kernels ported from
|
||||||
|
//! `EricLBuehler/mistral.rs` (MIT). Each kernel lives in a `.cu`
|
||||||
|
//! file alongside this module; `build.rs` compiles them all into a
|
||||||
|
//! static lib via `cudaforge` and links it under the `cuda` feature.
|
||||||
|
//!
|
||||||
|
//! When we absorb more upstream kernels (MoE GEMM, top-k, Mamba SSM,
|
||||||
|
//! etc.) they land here in their own `.cu` + `.rs` pairs.
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub mod ffi;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub mod gdn;
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
pub mod api;
|
pub mod api;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
|
pub mod cuda;
|
||||||
pub mod discovery;
|
pub mod discovery;
|
||||||
pub mod harness;
|
pub mod harness;
|
||||||
pub mod health;
|
pub mod health;
|
||||||
|
|||||||
Reference in New Issue
Block a user