diff --git a/Cargo.lock b/Cargo.lock index 215be42..f8965d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2113,6 +2113,7 @@ dependencies = [ "candle-transformers", "clap", "cortex-core", + "cudaforge", "cudarc 0.19.7", "figment", "futures", diff --git a/crates/neuron/Cargo.toml b/crates/neuron/Cargo.toml index 1a4475a..b653623 100644 --- a/crates/neuron/Cargo.toml +++ b/crates/neuron/Cargo.toml @@ -25,6 +25,7 @@ cuda = [ "candle-transformers/cuda", "dep:cudarc", "dep:half", + "dep:cudaforge", ] # Use cuDNN for convolution / attention kernels. Requires CUDA. cudnn = [ @@ -79,3 +80,13 @@ hf-hub = { version = "0.4", features = ["tokio"] } [dev-dependencies] tokio = { workspace = true, features = ["test-util"] } reqwest.workspace = true + +[build-dependencies] +# Used by `build.rs` to compile `src/cuda/*.cu` into `libneuroncuda.a` +# under the `cuda` feature. Matches mistralrs's upstream build setup +# (their `mistralrs-core/build.rs` uses the same constructor). +cudaforge = { version = "0.1", optional = true } + +[package.metadata.docs.rs] +# Skip the CUDA path on docs.rs (it lacks nvcc). +no-default-features = true diff --git a/crates/neuron/build.rs b/crates/neuron/build.rs new file mode 100644 index 0000000..d57070e --- /dev/null +++ b/crates/neuron/build.rs @@ -0,0 +1,66 @@ +//! Build script: compile the CUDA kernels in `src/cuda/*.cu` into a +//! static library and link it under the `cuda` feature. +//! +//! Patterned on `EricLBuehler/mistral.rs::mistralrs-core/build.rs` — +//! same `cudaforge::KernelBuilder` invocation, same NVCC flag set. + +fn main() { + #[cfg(feature = "cuda")] + { + use std::path::PathBuf; + println!("cargo:rerun-if-changed=build.rs"); + println!("cargo:rerun-if-changed=src/cuda/"); + + let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); + + let mut builder = cudaforge::KernelBuilder::new() + .source_glob("src/cuda/*.cu") + .out_dir(&build_dir) + .arg("-std=c++17") + .arg("-O3") + .arg("-U__CUDA_NO_HALF_OPERATORS__") + .arg("-U__CUDA_NO_HALF_CONVERSIONS__") + .arg("-U__CUDA_NO_HALF2_OPERATORS__") + .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + .arg("--expt-relaxed-constexpr") + .arg("--expt-extended-lambda") + .arg("--use_fast_math") + .arg("--compiler-options") + .arg("-fPIC"); + + // sm_<80 doesn't have bf16 intrinsics for WMMA — gate the + // bf16-only kernels off in that case. (Mirrors upstream.) + if let Some(compute_cap) = builder.get_compute_cap() + && compute_cap < 80 + { + builder = builder.arg("-DNO_BF16_KERNEL"); + } + + let target = std::env::var("TARGET").unwrap(); + let out_file = if target.contains("msvc") { + build_dir.join("neuroncuda.lib") + } else { + build_dir.join("libneuroncuda.a") + }; + + builder + .build_lib(out_file) + .expect("neuron cuda build failed"); + println!("cargo:rustc-link-search={}", build_dir.display()); + println!("cargo:rustc-link-lib=neuroncuda"); + println!("cargo:rustc-link-lib=dylib=cudart"); + + if target.contains("msvc") { + // No extra runtime library needed. + } else if target.contains("apple") + || target.contains("freebsd") + || target.contains("openbsd") + { + println!("cargo:rustc-link-lib=dylib=c++"); + } else if target.contains("android") { + println!("cargo:rustc-link-lib=dylib=c++_shared"); + } else { + println!("cargo:rustc-link-lib=dylib=stdc++"); + } + } +} diff --git a/crates/neuron/src/cuda/ffi.rs b/crates/neuron/src/cuda/ffi.rs new file mode 100644 index 0000000..eab76a9 --- /dev/null +++ b/crates/neuron/src/cuda/ffi.rs @@ -0,0 +1,84 @@ +//! FFI declarations for the CUDA kernels in `gdn.cu`. +//! +//! Subset of `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/ffi.rs` +//! covering only the Gated DeltaNet kernels we currently use. Other +//! kernels in the upstream file (MoE GEMM, top-k, Mamba selective +//! scan, etc.) would land here too as we absorb them. +//! +//! All function declarations are MIT-licensed from upstream and +//! unchanged apart from this header. + +use std::ffi::c_void; + +#[allow(dead_code)] +unsafe extern "C" { + // GDN (Gated Delta Net) kernels for qwen3_5 / Qwen3-Next. + pub(crate) fn gated_delta_rule_recurrence( + q: *const f32, + k: *const f32, + v: *const f32, + g: *const f32, + beta: *const f32, + state: *mut f32, + output: *mut f32, + bh: i32, + seq_len: i32, + k_dim: i32, + v_dim: i32, + stream: i64, + ); + + /// Chunked GDN recurrence for prefill (processes tokens in BT=64 chunks). + pub(crate) fn chunked_gated_delta_rule_recurrence( + q: *const f32, + k: *const f32, + v: *const f32, + g: *const f32, + beta: *const f32, + state: *mut f32, + output: *mut f32, + bh: i32, + seq_len: i32, + k_dim: i32, + v_dim: i32, + stream: i64, + ); + + pub(crate) fn causal_conv1d_update( + x: *const c_void, + weight: *const c_void, + conv_state: *mut c_void, + output: *mut c_void, + batch_size: i32, + conv_dim: i32, + kernel_size: i32, + dtype: i32, + stream: i64, + ); + + pub(crate) fn causal_conv1d_full( + x: *const c_void, + weight: *const c_void, + conv_state_out: *mut c_void, + output: *mut c_void, + batch_size: i32, + conv_dim: i32, + seq_len: i32, + kernel_size: i32, + dtype: i32, + stream: i64, + ); + + pub(crate) fn fused_gdn_gating( + b: *const c_void, + a: *const c_void, + a_log: *const f32, + dt_bias: *const f32, + beta_out: *mut c_void, + g_out: *mut c_void, + total_elements: i32, + num_heads: i32, + dtype: i32, + stream: i64, + ); +} diff --git a/crates/neuron/src/cuda/gdn.cu b/crates/neuron/src/cuda/gdn.cu new file mode 100644 index 0000000..f929cf1 --- /dev/null +++ b/crates/neuron/src/cuda/gdn.cu @@ -0,0 +1,711 @@ +// Gated DeltaNet CUDA kernels for Qwen3-Next (`model_type = "qwen3_5"`). +// +// Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms. +// Upstream path: mistralrs-core/src/cuda/gdn.cu. Local edits in this +// file are limited to this banner; the kernels are unchanged so a +// diff against upstream stays minimal. +// +// Five kernels exposed via `extern "C"` shims at the bottom: +// - gated_delta_rule_recurrence (per-token decode) +// - chunked_gated_delta_rule_recurrence (BT=64 chunked prefill) +// - causal_conv1d_update (single-token conv decode) +// - causal_conv1d_full (multi-token conv prefill) +// - fused_gdn_gating (beta = sigmoid(b); +// g = -exp(A_log) * softplus(a + dt_bias)) + +#include "cuda_bf16.h" +#include "cuda_fp16.h" +#include +#include +#include + +// ============================================================================ +// Kernel 1: gated_delta_rule_recurrence (optimized) +// +// V-tiled recurrence with compile-time K dimension for register residency. +// Grid: (ceil(V/BV), B*H), Block: (BV,). Each thread owns BK registers of +// state. Shared memory holds k_buf and q_buf (2*BK floats). +// +// Optimizations over naive version: +// - Template BK -> float s[BK] lives in true registers (1 cycle vs ~30) +// - #pragma unroll on all k-loops -> full ILP +// - Fused decay+kv_mem pass and fused state_update+output pass +// - __fmaf_rn intrinsics for guaranteed fused multiply-add +// - BV=64 threads -> 2 warps, 6 blocks/SM on Ampere +// +// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S] +// state: [BH, K, V] (in/out) output: [BH, S, V] +// ============================================================================ + +// Optimized kernel: BK known at compile time -> registers + full unrolling +template +__global__ void gated_delta_rule_recurrence_kernel_tiled( + const float *__restrict__ q, // [BH, S, K] + const float *__restrict__ k, // [BH, S, K] + const float *__restrict__ v, // [BH, S, V] + const float *__restrict__ g, // [BH, S] + const float *__restrict__ beta, // [BH, S] + float *__restrict__ state, // [BH, K, V] + float *__restrict__ output, // [BH, S, V] + int seq_len, int v_dim) { + + const int v_tile = blockIdx.x; // which V-tile + const int bh = blockIdx.y; // batch*head index + const int tid = threadIdx.x; // thread within tile [0, BV) + const int v_idx = v_tile * BV + tid; // global V index + + if (v_idx >= v_dim) + return; + + // Pointers for this (batch, head) + const float *q_bh = q + bh * seq_len * BK; + const float *k_bh = k + bh * seq_len * BK; + const float *v_bh = v + bh * seq_len * v_dim; + const float *g_bh = g + bh * seq_len; + const float *beta_bh = beta + bh * seq_len; + float *state_bh = state + bh * BK * v_dim; + float *out_bh = output + bh * seq_len * v_dim; + + // Shared memory: k_buf[BK] + q_buf[BK] + __shared__ float k_buf[BK]; + __shared__ float q_buf[BK]; + + // Load state column into registers — BK is compile-time, so this is + // a true register array (not spilled to local memory) + float s[BK]; +#pragma unroll + for (int j = 0; j < BK; j++) { + s[j] = state_bh[j * v_dim + v_idx]; + } + + for (int t = 0; t < seq_len; t++) { +// Collaboratively load k_t into shared memory +// BK / BV loads per thread (e.g. 128/64 = 2) +#pragma unroll + for (int j = tid; j < BK; j += BV) { + k_buf[j] = k_bh[t * BK + j]; + } + __syncthreads(); + + // Load scalars for this timestep + float decay = expf(g_bh[t]); + float beta_t = beta_bh[t]; + float v_t = v_bh[t * v_dim + v_idx]; + + // Fused pass 1: decay state + compute kv_mem + float kv_mem = 0.0f; +#pragma unroll + for (int j = 0; j < BK; j++) { + s[j] *= decay; + kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem); + } + + // Delta rule + float delta = (v_t - kv_mem) * beta_t; + +// Collaboratively load q_t into shared memory +#pragma unroll + for (int j = tid; j < BK; j += BV) { + q_buf[j] = q_bh[t * BK + j]; + } + __syncthreads(); + + // Fused pass 2: update state + compute output + float y_t = 0.0f; +#pragma unroll + for (int j = 0; j < BK; j++) { + s[j] = __fmaf_rn(k_buf[j], delta, s[j]); + y_t = __fmaf_rn(s[j], q_buf[j], y_t); + } + + out_bh[t * v_dim + v_idx] = y_t; + + __syncthreads(); + } + +// Write state back +#pragma unroll + for (int j = 0; j < BK; j++) { + state_bh[j * v_dim + v_idx] = s[j]; + } +} + +// Fallback kernel: runtime k_dim, still V-tiled for occupancy +template +__global__ void gated_delta_rule_recurrence_kernel_fallback( + const float *__restrict__ q, const float *__restrict__ k, + const float *__restrict__ v, const float *__restrict__ g, + const float *__restrict__ beta, float *__restrict__ state, + float *__restrict__ output, int seq_len, int k_dim, int v_dim) { + + const int v_tile = blockIdx.x; + const int bh = blockIdx.y; + const int tid = threadIdx.x; + const int v_idx = v_tile * BV + tid; + + if (v_idx >= v_dim) + return; + + const float *q_bh = q + bh * seq_len * k_dim; + const float *k_bh = k + bh * seq_len * k_dim; + const float *v_bh = v + bh * seq_len * v_dim; + const float *g_bh = g + bh * seq_len; + const float *beta_bh = beta + bh * seq_len; + float *state_bh = state + bh * k_dim * v_dim; + float *out_bh = output + bh * seq_len * v_dim; + + extern __shared__ float shared[]; + float *k_buf = shared; + float *q_buf = shared + k_dim; + + float s[MAX_K]; + for (int j = 0; j < k_dim; j++) { + s[j] = state_bh[j * v_dim + v_idx]; + } + + for (int t = 0; t < seq_len; t++) { + for (int j = tid; j < k_dim; j += BV) { + k_buf[j] = k_bh[t * k_dim + j]; + } + __syncthreads(); + + float decay = expf(g_bh[t]); + float beta_t = beta_bh[t]; + float v_t = v_bh[t * v_dim + v_idx]; + + float kv_mem = 0.0f; + for (int j = 0; j < k_dim; j++) { + s[j] *= decay; + kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem); + } + + float delta = (v_t - kv_mem) * beta_t; + + for (int j = tid; j < k_dim; j += BV) { + q_buf[j] = q_bh[t * k_dim + j]; + } + __syncthreads(); + + float y_t = 0.0f; + for (int j = 0; j < k_dim; j++) { + s[j] = __fmaf_rn(k_buf[j], delta, s[j]); + y_t = __fmaf_rn(s[j], q_buf[j], y_t); + } + + out_bh[t * v_dim + v_idx] = y_t; + + __syncthreads(); + } + + for (int j = 0; j < k_dim; j++) { + state_bh[j * v_dim + v_idx] = s[j]; + } +} + +extern "C" void gated_delta_rule_recurrence(const float *q, const float *k, + const float *v, const float *g, + const float *beta, float *state, + float *output, int bh, int seq_len, + int k_dim, int v_dim, + int64_t stream) { + + const cudaStream_t custream = (cudaStream_t)stream; + + if (k_dim == 128) { + // Fast path for Qwen3-Next (k_dim=128) + constexpr int BK = 128; + constexpr int BV = 64; + dim3 grid((v_dim + BV - 1) / BV, bh); + dim3 block(BV); + gated_delta_rule_recurrence_kernel_tiled + <<>>(q, k, v, g, beta, state, output, seq_len, + v_dim); + } else if (k_dim == 64) { + // Fast path for models with k_dim=64 + constexpr int BK = 64; + constexpr int BV = 64; + dim3 grid((v_dim + BV - 1) / BV, bh); + dim3 block(BV); + gated_delta_rule_recurrence_kernel_tiled + <<>>(q, k, v, g, beta, state, output, seq_len, + v_dim); + } else { + // Fallback for other k_dim values (runtime loop, still V-tiled) + constexpr int BV = 64; + constexpr int MAX_K = 256; + dim3 grid((v_dim + BV - 1) / BV, bh); + dim3 block(BV); + size_t smem = 2 * k_dim * sizeof(float); + gated_delta_rule_recurrence_kernel_fallback + <<>>(q, k, v, g, beta, state, output, + seq_len, k_dim, v_dim); + } +} + +// ============================================================================ +// Kernel 1b: chunked_gated_delta_rule_recurrence (prefill optimization) +// +// Processes prefill tokens in BT-token chunks instead of one at a time. +// Within each chunk: parallel prefix sum of g, cooperative kk_dot computation, +// forward substitution (triangular solve), output computation, and state +// update. +// +// Same thread model as Kernel 1: one block per (v_tile, batch*head), +// one thread per V-column. Each thread owns BK registers of state. +// +// Shared memory holds: +// k_chunk[BT * BK] -- key vectors for current chunk +// kk_dot[BT * BT] -- dot(k[i], k[j]) lower-triangular matrix +// gcum[BT] -- cumulative sum of g within chunk +// beta_s[BT] -- beta values for chunk +// q_buf[BK] -- q vector (loaded one row at a time) +// +// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S] +// state: [BH, K, V] (in/out) output: [BH, S, V] +// ============================================================================ + +template +__global__ void +chunked_gated_delta_rule_kernel(const float *__restrict__ q, // [BH, S, K] + const float *__restrict__ k, // [BH, S, K] + const float *__restrict__ v, // [BH, S, V] + const float *__restrict__ g, // [BH, S] + const float *__restrict__ beta, // [BH, S] + float *__restrict__ state, // [BH, K, V] + float *__restrict__ output, // [BH, S, V] + int seq_len, int v_dim) { + + const int v_tile = blockIdx.x; + const int bh = blockIdx.y; + const int tid = threadIdx.x; + const int v_idx = v_tile * BV + tid; + + if (v_idx >= v_dim) + return; + + const int num_chunks = (seq_len + BT - 1) / BT; + + // Pointers for this (batch, head) + const float *q_bh = q + bh * seq_len * BK; + const float *k_bh = k + bh * seq_len * BK; + const float *v_bh = v + bh * seq_len * v_dim; + const float *g_bh = g + bh * seq_len; + const float *beta_bh = beta + bh * seq_len; + float *state_bh = state + bh * BK * v_dim; + float *out_bh = output + bh * seq_len * v_dim; + + // Dynamic shared memory layout + extern __shared__ float smem[]; + float *k_chunk = smem; // [BT * BK] + float *kk_dot = smem + BT * BK; // [BT * BT] + float *gcum = smem + BT * BK + BT * BT; // [BT] + float *beta_s = gcum + BT; // [BT] + float *q_buf = beta_s + BT; // [BK] + + // Load state column into registers + float s[BK]; +#pragma unroll + for (int j = 0; j < BK; j++) { + s[j] = state_bh[j * v_dim + v_idx]; + } + + // Per-thread register array for corrected deltas + float delta[BT]; + + for (int c = 0; c < num_chunks; c++) { + const int chunk_start = c * BT; + const int chunk_len = min(BT, seq_len - chunk_start); + + // === Phase 1: Cooperative load of k, beta, g into shared memory === + for (int t = 0; t < chunk_len; t++) { + for (int j = tid; j < BK; j += BV) { + k_chunk[t * BK + j] = k_bh[(chunk_start + t) * BK + j]; + } + } + if (tid < chunk_len) { + beta_s[tid] = beta_bh[chunk_start + tid]; + gcum[tid] = g_bh[chunk_start + tid]; + } + __syncthreads(); + + // === Phase 1b: Parallel prefix sum of g (Hillis-Steele) === + for (int stride = 1; stride < BT; stride <<= 1) { + float prev = 0.0f; + if (tid < chunk_len && (int)tid >= stride) + prev = gcum[tid - stride]; + __syncthreads(); + if (tid < chunk_len && (int)tid >= stride) + gcum[tid] += prev; + __syncthreads(); + } + + // === Phase 2: Compute kk_dot[i][j] = dot(k[i], k[j]) for j < i === + // Only lower-triangular entries needed (strictly lower) + for (int idx = tid; idx < chunk_len * chunk_len; idx += BV) { + int i = idx / chunk_len; + int j = idx % chunk_len; + if (j < i) { + float dot = 0.0f; + for (int d = 0; d < BK; d++) { + dot = __fmaf_rn(k_chunk[i * BK + d], k_chunk[j * BK + d], dot); + } + kk_dot[i * BT + j] = dot; + } + } + __syncthreads(); + + // === Phase 3: Forward substitution (per V-column, in registers) === + // Computes corrected delta values via triangular solve + for (int i = 0; i < chunk_len; i++) { + float v_i = v_bh[(chunk_start + i) * v_dim + v_idx]; + float decay_i = expf(gcum[i]); + float beta_i = beta_s[i]; + + // Inter-chunk contribution: state @ k[i] with decay + float kv_mem = 0.0f; +#pragma unroll + for (int d = 0; d < BK; d++) { + kv_mem = __fmaf_rn(s[d] * decay_i, k_chunk[i * BK + d], kv_mem); + } + + float rhs = beta_i * (v_i - kv_mem); + + // Subtract lower-triangular contributions (intra-chunk) + for (int j = 0; j < i; j++) { + float a_ij = beta_i * kk_dot[i * BT + j] * expf(gcum[i] - gcum[j]); + rhs -= a_ij * delta[j]; + } + delta[i] = rhs; + } + + // === Phase 4: Output computation (per V-column) === + for (int i = 0; i < chunk_len; i++) { + // Cooperatively load q[i] into shared + for (int j = tid; j < BK; j += BV) { + q_buf[j] = q_bh[(chunk_start + i) * BK + j]; + } + __syncthreads(); + + float decay_i = expf(gcum[i]); + + // Inter-chunk contribution: q[i] @ (state * decay) + float o_val = 0.0f; +#pragma unroll + for (int d = 0; d < BK; d++) { + o_val = __fmaf_rn(q_buf[d], s[d] * decay_i, o_val); + } + + // Intra-chunk contribution: sum_{j<=i} dot(q[i], k[j]) * delta[j] * + // exp(gcum[i] - gcum[j]) + for (int j = 0; j <= i; j++) { + float qk_dot = 0.0f; + for (int d = 0; d < BK; d++) { + qk_dot = __fmaf_rn(q_buf[d], k_chunk[j * BK + d], qk_dot); + } + o_val += qk_dot * delta[j] * expf(gcum[i] - gcum[j]); + } + + out_bh[(chunk_start + i) * v_dim + v_idx] = o_val; + __syncthreads(); + } + + // === Phase 5: State update for next chunk === + float g_total = gcum[chunk_len - 1]; +#pragma unroll + for (int d = 0; d < BK; d++) { + float s_new = s[d] * expf(g_total); + for (int t = 0; t < chunk_len; t++) { + s_new += k_chunk[t * BK + d] * delta[t] * expf(g_total - gcum[t]); + } + s[d] = s_new; + } + + __syncthreads(); + } + + // Write final state back +#pragma unroll + for (int j = 0; j < BK; j++) { + state_bh[j * v_dim + v_idx] = s[j]; + } +} + +extern "C" void chunked_gated_delta_rule_recurrence( + const float *q, const float *k, const float *v, const float *g, + const float *beta, float *state, float *output, int bh, int seq_len, + int k_dim, int v_dim, int64_t stream) { + + const cudaStream_t custream = (cudaStream_t)stream; + + if (k_dim == 128) { + constexpr int BT = 64; + constexpr int BK = 128; + constexpr int BV = 64; + // Shared memory: BT*BK + BT*BT + BT + BT + BK floats + size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float); + + // Request extended shared memory + auto kernel = chunked_gated_delta_rule_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem); + + dim3 grid((v_dim + BV - 1) / BV, bh); + dim3 block(BV); + kernel<<>>(q, k, v, g, beta, state, output, + seq_len, v_dim); + } else if (k_dim == 64) { + constexpr int BT = 64; + constexpr int BK = 64; + constexpr int BV = 64; + size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float); + + auto kernel = chunked_gated_delta_rule_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem); + + dim3 grid((v_dim + BV - 1) / BV, bh); + dim3 block(BV); + kernel<<>>(q, k, v, g, beta, state, output, + seq_len, v_dim); + } else { + // Fallback: use the sequential kernel for unsupported k_dim + gated_delta_rule_recurrence(q, k, v, g, beta, state, output, bh, seq_len, + k_dim, v_dim, stream); + } +} + +// ============================================================================ +// Kernel 2a: causal_conv1d_update (decode path, single step) +// +// Each thread handles one channel: shift conv_state left by 1, +// insert new value, dot product with weight, apply SiLU. +// +// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size] +// conv_state: [B, conv_dim, kernel_size] (in/out) +// output: [B, conv_dim, 1] +// ============================================================================ + +template +__global__ void causal_conv1d_update_kernel( + const T *__restrict__ x, // [B, conv_dim, 1] + const T *__restrict__ weight, // [conv_dim, kernel_size] + T *__restrict__ conv_state, // [B, conv_dim, kernel_size] + T *__restrict__ output, // [B, conv_dim, 1] + int batch_size, int conv_dim, int kernel_size) { + + const int ch = blockIdx.x * blockDim.x + threadIdx.x; + const int b = blockIdx.y; + + if (ch >= conv_dim || b >= batch_size) + return; + + // Pointer to this batch/channel's conv state + T *cs = conv_state + (b * conv_dim + ch) * kernel_size; + const T *w = weight + ch * kernel_size; + + // Shift state left by 1 + for (int i = 0; i < kernel_size - 1; i++) { + cs[i] = cs[i + 1]; + } + // Insert new value + cs[kernel_size - 1] = x[b * conv_dim + ch]; + + // Dot product with weight + float acc = 0.0f; + for (int i = 0; i < kernel_size; i++) { + acc += (float)cs[i] * (float)w[i]; + } + + // SiLU activation: x * sigmoid(x) + float sig = 1.0f / (1.0f + expf(-acc)); + float result = acc * sig; + + output[b * conv_dim + ch] = (T)result; +} + +extern "C" void causal_conv1d_update(const void *x, const void *weight, + void *conv_state, void *output, + int batch_size, int conv_dim, + int kernel_size, int dtype, + int64_t stream) { + const cudaStream_t custream = (cudaStream_t)stream; + dim3 block(256); + dim3 grid((conv_dim + 255) / 256, batch_size); + + if (dtype == 0) { + // f16 + causal_conv1d_update_kernel<__half><<>>( + (const __half *)x, (const __half *)weight, (__half *)conv_state, + (__half *)output, batch_size, conv_dim, kernel_size); + } else { + // bf16 + causal_conv1d_update_kernel<__nv_bfloat16><<>>( + (const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight, + (__nv_bfloat16 *)conv_state, (__nv_bfloat16 *)output, batch_size, + conv_dim, kernel_size); + } +} + +// ============================================================================ +// Kernel 2b: causal_conv1d_full (prefill path) +// +// Each thread handles one (channel, position): causal window with +// zero-padding, dot product with weight, SiLU. +// A second pass writes the conv_state from the last kernel_size positions. +// +// x: [B, conv_dim, S] weight: [conv_dim, kernel_size] +// conv_state_out: [B, conv_dim, kernel_size] output: [B, conv_dim, S] +// ============================================================================ + +template +__global__ void causal_conv1d_full_kernel( + const T *__restrict__ x, // [B, conv_dim, S] + const T *__restrict__ weight, // [conv_dim, kernel_size] + T *__restrict__ output, // [B, conv_dim, S] + int batch_size, int conv_dim, int seq_len, int kernel_size) { + + const int ch = blockIdx.x * blockDim.x + threadIdx.x; + const int pos = blockIdx.y; + const int b = blockIdx.z; + + if (ch >= conv_dim || pos >= seq_len || b >= batch_size) + return; + + const T *x_bch = x + (b * conv_dim + ch) * seq_len; + const T *w = weight + ch * kernel_size; + + // Causal convolution: sum over kernel_size window ending at pos + float acc = 0.0f; + for (int i = 0; i < kernel_size; i++) { + int src_pos = pos - (kernel_size - 1) + i; + float x_val = (src_pos >= 0) ? (float)x_bch[src_pos] : 0.0f; + acc += x_val * (float)w[i]; + } + + // SiLU + float sig = 1.0f / (1.0f + expf(-acc)); + float result = acc * sig; + + output[(b * conv_dim + ch) * seq_len + pos] = (T)result; +} + +template +__global__ void save_conv_state_kernel( + const T *__restrict__ x, // [B, conv_dim, S] + T *__restrict__ conv_state_out, // [B, conv_dim, kernel_size] + int batch_size, int conv_dim, int seq_len, int kernel_size) { + + const int ch = blockIdx.x * blockDim.x + threadIdx.x; + const int b = blockIdx.y; + + if (ch >= conv_dim || b >= batch_size) + return; + + const T *x_bch = x + (b * conv_dim + ch) * seq_len; + T *cs = conv_state_out + (b * conv_dim + ch) * kernel_size; + + // Save last kernel_size positions (zero-pad if seq_len < kernel_size) + int pad = kernel_size - seq_len; + for (int i = 0; i < kernel_size; i++) { + if (i < pad) { + cs[i] = (T)0.0f; + } else { + cs[i] = x_bch[seq_len - kernel_size + i]; + } + } +} + +extern "C" void causal_conv1d_full(const void *x, const void *weight, + void *conv_state_out, void *output, + int batch_size, int conv_dim, int seq_len, + int kernel_size, int dtype, int64_t stream) { + const cudaStream_t custream = (cudaStream_t)stream; + + // Main convolution kernel + dim3 block(256); + dim3 grid((conv_dim + 255) / 256, seq_len, batch_size); + + if (dtype == 0) { + causal_conv1d_full_kernel<__half><<>>( + (const __half *)x, (const __half *)weight, (__half *)output, batch_size, + conv_dim, seq_len, kernel_size); + // Save conv state + dim3 grid2((conv_dim + 255) / 256, batch_size); + save_conv_state_kernel<__half><<>>( + (const __half *)x, (__half *)conv_state_out, batch_size, conv_dim, + seq_len, kernel_size); + } else { + causal_conv1d_full_kernel<__nv_bfloat16><<>>( + (const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight, + (__nv_bfloat16 *)output, batch_size, conv_dim, seq_len, kernel_size); + dim3 grid2((conv_dim + 255) / 256, batch_size); + save_conv_state_kernel<__nv_bfloat16><<>>( + (const __nv_bfloat16 *)x, (__nv_bfloat16 *)conv_state_out, batch_size, + conv_dim, seq_len, kernel_size); + } +} + +// ============================================================================ +// Kernel 3: fused_gdn_gating +// +// Fuses: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias) +// a_log and dt_bias are per-head (broadcast over batch*seq). +// +// b, a: [total] a_log, dt_bias: [num_heads] +// beta_out, g_out: [total] +// ============================================================================ + +template +__global__ void +fused_gdn_gating_kernel(const T *__restrict__ b, // [total] + const T *__restrict__ a, // [total] + const float *__restrict__ a_log, // [num_heads] + const float *__restrict__ dt_bias, // [num_heads] + T *__restrict__ beta_out, // [total] + T *__restrict__ g_out, // [total] + int total_elements, int num_heads) { + + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_elements) + return; + + // Head index: elements are laid out as [..., num_heads] + int head_idx = idx % num_heads; + + // beta = sigmoid(b) + float b_val = (float)b[idx]; + float beta = 1.0f / (1.0f + expf(-b_val)); + + // g = -exp(a_log) * softplus(a + dt_bias) + float a_val = (float)a[idx]; + float a_log_val = a_log[head_idx]; + float dt_bias_val = dt_bias[head_idx]; + + float sp_input = a_val + dt_bias_val; + float softplus_val = logf(1.0f + expf(sp_input)); + float g_val = -expf(a_log_val) * softplus_val; + + beta_out[idx] = (T)beta; + g_out[idx] = (T)g_val; +} + +extern "C" void fused_gdn_gating(const void *b, const void *a, + const float *a_log, const float *dt_bias, + void *beta_out, void *g_out, + int total_elements, int num_heads, int dtype, + int64_t stream) { + const cudaStream_t custream = (cudaStream_t)stream; + dim3 block(256); + dim3 grid((total_elements + 255) / 256); + + if (dtype == 0) { + fused_gdn_gating_kernel<__half><<>>( + (const __half *)b, (const __half *)a, a_log, dt_bias, + (__half *)beta_out, (__half *)g_out, total_elements, num_heads); + } else { + fused_gdn_gating_kernel<__nv_bfloat16><<>>( + (const __nv_bfloat16 *)b, (const __nv_bfloat16 *)a, a_log, dt_bias, + (__nv_bfloat16 *)beta_out, (__nv_bfloat16 *)g_out, total_elements, + num_heads); + } +} diff --git a/crates/neuron/src/cuda/gdn.rs b/crates/neuron/src/cuda/gdn.rs new file mode 100644 index 0000000..b8c5ce6 --- /dev/null +++ b/crates/neuron/src/cuda/gdn.rs @@ -0,0 +1,486 @@ +//! Rust wrappers around the Gated DeltaNet CUDA kernels in `gdn.cu`. +//! +//! Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms. +//! Upstream path: `mistralrs-core/src/cuda/gdn.rs`. The only edits in +//! this file are this header comment — the FFI path module name is +//! `crate::cuda::ffi`, identical to upstream's layout. + +#![allow(clippy::cast_possible_truncation)] + +use candle_core::{Result, Tensor}; + +#[cfg(feature = "cuda")] +use candle_core::DType; + +/// CUDA-accelerated gated delta rule recurrence. +/// +/// Inputs (all contiguous, f32): +/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S] +/// state: [BH, K, V] (mutated in place) +/// +/// Returns: output [BH, S, V] +#[cfg(feature = "cuda")] +pub fn gated_delta_rule_recurrence_cuda( + q: &Tensor, + k: &Tensor, + v: &Tensor, + g: &Tensor, + beta: &Tensor, + state: &mut Tensor, +) -> Result { + use candle::cuda_backend::cudarc::driver::DevicePtr; + use candle_core as candle; + + let (bh, seq_len, k_dim) = q.dims3()?; + let v_dim = v.dim(2)?; + + let dev = q.device().as_cuda_device()?; + + let (q_s, q_l) = q.storage_and_layout(); + let q_s = match &*q_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("q must be a cuda tensor"), + }; + let q_offset = q_l.start_offset(); + + let (k_s, k_l) = k.storage_and_layout(); + let k_s = match &*k_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("k must be a cuda tensor"), + }; + let k_offset = k_l.start_offset(); + + let (v_s, v_l) = v.storage_and_layout(); + let v_s = match &*v_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("v must be a cuda tensor"), + }; + let v_offset = v_l.start_offset(); + + let (g_s, g_l) = g.storage_and_layout(); + let g_s = match &*g_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("g must be a cuda tensor"), + }; + let g_offset = g_l.start_offset(); + + let (beta_s, beta_l) = beta.storage_and_layout(); + let beta_s = match &*beta_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("beta must be a cuda tensor"), + }; + let beta_offset = beta_l.start_offset(); + + let (state_s, state_l) = state.storage_and_layout(); + let state_s = match &*state_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("state must be a cuda tensor"), + }; + let state_offset = state_l.start_offset(); + + let output_buf = unsafe { dev.alloc::(bh * seq_len * v_dim) }?; + + let stream = dev.cuda_stream().cu_stream() as i64; + + unsafe { + crate::cuda::ffi::gated_delta_rule_recurrence( + q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32, + k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32, + v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32, + g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32, + beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32, + state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32, + output_buf.device_ptr(output_buf.stream()).0 as *mut f32, + bh as i32, + seq_len as i32, + k_dim as i32, + v_dim as i32, + stream, + ); + } + + // The kernel wrote state in-place via the raw pointer; rewrap + // (state tensor's underlying CudaSlice was modified directly) + + let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone()); + Ok(Tensor::from(( + candle::Storage::Cuda(output_storage), + (bh, seq_len, v_dim), + ))) +} + +#[cfg(not(feature = "cuda"))] +#[allow(unused)] +pub fn gated_delta_rule_recurrence_cuda( + _q: &Tensor, + _k: &Tensor, + _v: &Tensor, + _g: &Tensor, + _beta: &Tensor, + _state: &mut Tensor, +) -> Result { + candle_core::bail!("gated_delta_rule_recurrence_cuda requires the cuda feature") +} + +/// CUDA-accelerated chunked gated delta rule recurrence (prefill optimization). +/// +/// Processes prefill tokens in 64-token chunks instead of one at a time. +/// Same interface as `gated_delta_rule_recurrence_cuda`. +/// +/// Inputs (all contiguous, f32): +/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S] +/// state: [BH, K, V] (mutated in place) +/// +/// Returns: output [BH, S, V] +#[cfg(feature = "cuda")] +pub fn chunked_gated_delta_rule_recurrence_cuda( + q: &Tensor, + k: &Tensor, + v: &Tensor, + g: &Tensor, + beta: &Tensor, + state: &mut Tensor, +) -> Result { + use candle::cuda_backend::cudarc::driver::DevicePtr; + use candle_core as candle; + + let (bh, seq_len, k_dim) = q.dims3()?; + let v_dim = v.dim(2)?; + + let dev = q.device().as_cuda_device()?; + + let (q_s, q_l) = q.storage_and_layout(); + let q_s = match &*q_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("q must be a cuda tensor"), + }; + let q_offset = q_l.start_offset(); + + let (k_s, k_l) = k.storage_and_layout(); + let k_s = match &*k_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("k must be a cuda tensor"), + }; + let k_offset = k_l.start_offset(); + + let (v_s, v_l) = v.storage_and_layout(); + let v_s = match &*v_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("v must be a cuda tensor"), + }; + let v_offset = v_l.start_offset(); + + let (g_s, g_l) = g.storage_and_layout(); + let g_s = match &*g_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("g must be a cuda tensor"), + }; + let g_offset = g_l.start_offset(); + + let (beta_s, beta_l) = beta.storage_and_layout(); + let beta_s = match &*beta_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("beta must be a cuda tensor"), + }; + let beta_offset = beta_l.start_offset(); + + let (state_s, state_l) = state.storage_and_layout(); + let state_s = match &*state_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("state must be a cuda tensor"), + }; + let state_offset = state_l.start_offset(); + + let output_buf = unsafe { dev.alloc::(bh * seq_len * v_dim) }?; + + let stream = dev.cuda_stream().cu_stream() as i64; + + unsafe { + crate::cuda::ffi::chunked_gated_delta_rule_recurrence( + q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32, + k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32, + v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32, + g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32, + beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32, + state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32, + output_buf.device_ptr(output_buf.stream()).0 as *mut f32, + bh as i32, + seq_len as i32, + k_dim as i32, + v_dim as i32, + stream, + ); + } + + let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone()); + Ok(Tensor::from(( + candle::Storage::Cuda(output_storage), + (bh, seq_len, v_dim), + ))) +} + +#[cfg(not(feature = "cuda"))] +#[allow(unused)] +pub fn chunked_gated_delta_rule_recurrence_cuda( + _q: &Tensor, + _k: &Tensor, + _v: &Tensor, + _g: &Tensor, + _beta: &Tensor, + _state: &mut Tensor, +) -> Result { + candle_core::bail!("chunked_gated_delta_rule_recurrence_cuda requires the cuda feature") +} + +/// CUDA-accelerated causal conv1d (both update and full paths). +/// +/// For update (is_update=true): +/// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size] +/// conv_state: [B, conv_dim, kernel_size] (mutated in place for update) +/// Returns: (output [B, conv_dim, 1], updated conv_state) +/// +/// For full (is_update=false): +/// x: [B, conv_dim, S] weight: [conv_dim, kernel_size] +/// Returns: (output [B, conv_dim, S], new conv_state [B, conv_dim, kernel_size]) +#[cfg(feature = "cuda")] +pub fn causal_conv1d_cuda( + x: &Tensor, + weight: &Tensor, + conv_state: &Tensor, + kernel_size: usize, + is_update: bool, +) -> Result<(Tensor, Tensor)> { + use candle::cuda_backend::cudarc::driver::DevicePtr; + use candle_core as candle; + use core::ffi::c_void; + fn cuda_fwd< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + >( + x: &Tensor, + weight: &Tensor, + conv_state: &Tensor, + kernel_size: usize, + is_update: bool, + dtype_code: i32, + ) -> Result<(Tensor, Tensor)> { + let dev = x.device().as_cuda_device()?; + let (batch_size, conv_dim, seq_len) = x.dims3()?; + + let (x_s, x_l) = x.storage_and_layout(); + let x_s = match &*x_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("x must be a cuda tensor"), + }; + let x_offset = x_l.start_offset(); + + let (w_s, w_l) = weight.storage_and_layout(); + let w_s = match &*w_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("weight must be a cuda tensor"), + }; + let w_offset = w_l.start_offset(); + + let stream = dev.cuda_stream().cu_stream() as i64; + + if is_update { + // Clone conv_state so the kernel can mutate it in place + let conv_state_new = conv_state.clone(); + + let output_buf = unsafe { dev.alloc::(batch_size * conv_dim) }?; + + // Scope the borrow of conv_state_new so we can move it later + { + let (cs_s, cs_l) = conv_state_new.storage_and_layout(); + let cs_s = match &*cs_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("conv_state must be a cuda tensor"), + }; + let cs_offset = cs_l.start_offset(); + + unsafe { + crate::cuda::ffi::causal_conv1d_update( + x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void, + w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void, + cs_s.slice(cs_offset..).device_ptr(cs_s.stream()).0 as *mut c_void, + output_buf.device_ptr(output_buf.stream()).0 as *mut c_void, + batch_size as i32, + conv_dim as i32, + kernel_size as i32, + dtype_code, + stream, + ); + } + } + + let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone()); + let output = Tensor::from(( + candle::Storage::Cuda(output_storage), + (batch_size, conv_dim, 1usize), + )); + + Ok((output, conv_state_new)) + } else { + // Full path: allocate new conv_state and output + let output_buf = unsafe { dev.alloc::(batch_size * conv_dim * seq_len) }?; + let cs_buf = unsafe { dev.alloc::(batch_size * conv_dim * kernel_size) }?; + + unsafe { + crate::cuda::ffi::causal_conv1d_full( + x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void, + w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void, + cs_buf.device_ptr(cs_buf.stream()).0 as *mut c_void, + output_buf.device_ptr(output_buf.stream()).0 as *mut c_void, + batch_size as i32, + conv_dim as i32, + seq_len as i32, + kernel_size as i32, + dtype_code, + stream, + ); + } + + let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone()); + let output = Tensor::from(( + candle::Storage::Cuda(output_storage), + (batch_size, conv_dim, seq_len), + )); + + let cs_storage = candle::CudaStorage::wrap_cuda_slice(cs_buf, dev.clone()); + let new_conv_state = Tensor::from(( + candle::Storage::Cuda(cs_storage), + (batch_size, conv_dim, kernel_size), + )); + + Ok((output, new_conv_state)) + } + } + + match x.dtype() { + DType::F16 => cuda_fwd::(x, weight, conv_state, kernel_size, is_update, 0), + DType::BF16 => cuda_fwd::(x, weight, conv_state, kernel_size, is_update, 1), + other => candle_core::bail!("causal_conv1d_cuda only supports f16/bf16, got {:?}", other), + } +} + +#[cfg(not(feature = "cuda"))] +#[allow(unused)] +pub fn causal_conv1d_cuda( + _x: &Tensor, + _weight: &Tensor, + _conv_state: &Tensor, + _kernel_size: usize, + _is_update: bool, +) -> Result<(Tensor, Tensor)> { + candle_core::bail!("causal_conv1d_cuda requires the cuda feature") +} + +/// CUDA-accelerated fused GDN gating computation. +/// +/// Computes: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias) +/// +/// b, a: [total_elements] in f16/bf16 +/// a_log, dt_bias: [num_heads] in f32 +/// +/// Returns: (beta, g) in original dtype +#[cfg(feature = "cuda")] +pub fn fused_gdn_gating_cuda( + b: &Tensor, + a: &Tensor, + a_log: &Tensor, + dt_bias: &Tensor, +) -> Result<(Tensor, Tensor)> { + use candle::cuda_backend::cudarc::driver::DevicePtr; + use candle_core as candle; + use core::ffi::c_void; + + fn cuda_fwd< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + >( + b: &Tensor, + a: &Tensor, + a_log: &Tensor, + dt_bias: &Tensor, + dtype_code: i32, + ) -> Result<(Tensor, Tensor)> { + let total_elements = b.elem_count(); + let num_heads = a_log.elem_count(); + let shape = b.shape().clone(); + let dev = b.device().as_cuda_device()?; + + let (b_s, b_l) = b.storage_and_layout(); + let b_s = match &*b_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("b must be a cuda tensor"), + }; + let b_offset = b_l.start_offset(); + + let (a_s, a_l) = a.storage_and_layout(); + let a_s = match &*a_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("a must be a cuda tensor"), + }; + let a_offset = a_l.start_offset(); + + let (alog_s, alog_l) = a_log.storage_and_layout(); + let alog_s = match &*alog_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("a_log must be a cuda tensor"), + }; + let alog_offset = alog_l.start_offset(); + + let (dtb_s, dtb_l) = dt_bias.storage_and_layout(); + let dtb_s = match &*dtb_s { + candle::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => candle::bail!("dt_bias must be a cuda tensor"), + }; + let dtb_offset = dtb_l.start_offset(); + + let beta_buf = unsafe { dev.alloc::(total_elements) }?; + let g_buf = unsafe { dev.alloc::(total_elements) }?; + + let stream = dev.cuda_stream().cu_stream() as i64; + + unsafe { + crate::cuda::ffi::fused_gdn_gating( + b_s.slice(b_offset..).device_ptr(b_s.stream()).0 as *const c_void, + a_s.slice(a_offset..).device_ptr(a_s.stream()).0 as *const c_void, + alog_s.slice(alog_offset..).device_ptr(alog_s.stream()).0 as *const f32, + dtb_s.slice(dtb_offset..).device_ptr(dtb_s.stream()).0 as *const f32, + beta_buf.device_ptr(beta_buf.stream()).0 as *mut c_void, + g_buf.device_ptr(g_buf.stream()).0 as *mut c_void, + total_elements as i32, + num_heads as i32, + dtype_code, + stream, + ); + } + + let beta_storage = candle::CudaStorage::wrap_cuda_slice(beta_buf, dev.clone()); + let beta = Tensor::from((candle::Storage::Cuda(beta_storage), shape.clone())); + + let g_storage = candle::CudaStorage::wrap_cuda_slice(g_buf, dev.clone()); + let g = Tensor::from((candle::Storage::Cuda(g_storage), shape)); + + Ok((beta, g)) + } + + match b.dtype() { + DType::F16 => cuda_fwd::(b, a, a_log, dt_bias, 0), + DType::BF16 => cuda_fwd::(b, a, a_log, dt_bias, 1), + other => candle_core::bail!( + "fused_gdn_gating_cuda only supports f16/bf16, got {:?}", + other + ), + } +} + +#[cfg(not(feature = "cuda"))] +#[allow(unused)] +pub fn fused_gdn_gating_cuda( + _b: &Tensor, + _a: &Tensor, + _a_log: &Tensor, + _dt_bias: &Tensor, +) -> Result<(Tensor, Tensor)> { + candle_core::bail!("fused_gdn_gating_cuda requires the cuda feature") +} diff --git a/crates/neuron/src/cuda/mod.rs b/crates/neuron/src/cuda/mod.rs new file mode 100644 index 0000000..c1e77c6 --- /dev/null +++ b/crates/neuron/src/cuda/mod.rs @@ -0,0 +1,15 @@ +//! CUDA kernels and their Rust wrappers. +//! +//! Currently scoped to what we need for Qwen3-Next (`qwen3_5`) +//! inference performance — the Gated DeltaNet kernels ported from +//! `EricLBuehler/mistral.rs` (MIT). Each kernel lives in a `.cu` +//! file alongside this module; `build.rs` compiles them all into a +//! static lib via `cudaforge` and links it under the `cuda` feature. +//! +//! When we absorb more upstream kernels (MoE GEMM, top-k, Mamba SSM, +//! etc.) they land here in their own `.cu` + `.rs` pairs. + +#[cfg(feature = "cuda")] +pub mod ffi; +#[cfg(feature = "cuda")] +pub mod gdn; diff --git a/crates/neuron/src/lib.rs b/crates/neuron/src/lib.rs index 5c72182..ac659ea 100644 --- a/crates/neuron/src/lib.rs +++ b/crates/neuron/src/lib.rs @@ -1,5 +1,6 @@ pub mod api; pub mod config; +pub mod cuda; pub mod discovery; pub mod harness; pub mod health;