From 96d87552452e5991b3cbf5573184db01cf5621d1 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Tue, 19 May 2026 19:11:59 +0300 Subject: [PATCH] fix(tp): add half dep + drop double-wrapped .w() on CudaDevice::alloc MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two follow-up cuda-only fixes surfaced by `cargo build --features cuda` inside the cuda-13.0 runner container: 1. `half::{bf16, f16}` was an undeclared dep. Added `half = "2.5"` (matching candle-core's pinned major) under the cuda feature flag. 2. `dev.alloc::<T>(n)` already returns `candle_core::Result` (it calls `.w()` internally on the cudarc error). Calling `.w()?` on top of that needs `From<candle_core::Error> for CudaError`, which doesn't exist — collapse to `?`. Removed the now-unused `cuda_backend::WrapErr` import. Verified by `cargo build -p neuron --features cuda` and `cargo clippy -p neuron --all-targets --features cuda -- -D warnings` inside `git.lair.cafe/gongfoo/runner-cuda-13.0` with the local glibc/CUDA-13.0 math_functions.h noexcept patch. CPU clippy/tests stay green. Co-Authored-By: Claude Opus 4.7 (1M context) --- .gitignore | 1 + Cargo.lock | 1 + crates/neuron/Cargo.toml | 5 +++++ crates/neuron/src/harness/tp/all_reduce.rs | 7 +++---- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 2dcad04..f1ed483 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ .vscode/ cortex.toml doc/plan/* +/target-cuda/ diff --git a/Cargo.lock b/Cargo.lock index bb537cd..215be42 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2116,6 +2116,7 @@ dependencies = [ "cudarc 0.19.7", "figment", "futures", + "half", "hf-hub", "reqwest", "serde", diff --git a/crates/neuron/Cargo.toml b/crates/neuron/Cargo.toml index 98f15d9..1a4475a 100644 --- a/crates/neuron/Cargo.toml +++ b/crates/neuron/Cargo.toml @@ -24,6 +24,7 @@ cuda = [ "candle-nn/cuda", "candle-transformers/cuda", "dep:cudarc", + "dep:half", ] # Use cuDNN for convolution / attention kernels. 
Requires CUDA. cudnn = [ @@ -68,6 +69,10 @@ candle-transformers = "0.10.2" # TP worker pool can call cudarc::nccl::{Comm, Id} directly. Gated on # the `cuda` feature; same toolchain requirement as candle's CUDA path. cudarc = { version = "0.19", optional = true, default-features = false, features = ["nccl", "cuda-version-from-build-system"] } +# Used by the AllReduce CustomOp1 to type-dispatch on bf16/f16 candle +# storages. Matches candle-core's pinned major version to avoid double- +# compiling the `half` crate at conflicting versions. +half = { version = "2.5", optional = true } tokenizers = { version = "0.22", default-features = false, features = ["onig"] } hf-hub = { version = "0.4", features = ["tokio"] } diff --git a/crates/neuron/src/harness/tp/all_reduce.rs b/crates/neuron/src/harness/tp/all_reduce.rs index 4b1100d..76a7ade 100644 --- a/crates/neuron/src/harness/tp/all_reduce.rs +++ b/crates/neuron/src/harness/tp/all_reduce.rs @@ -21,7 +21,6 @@ #![cfg(feature = "cuda")] use candle_core::backend::BackendStorage; -use candle_core::cuda_backend::WrapErr; use candle_core::{CpuStorage, CudaStorage, CustomOp1, DType, Layout, Result, Shape}; use cudarc::nccl::{Comm, ReduceOp}; use half::{bf16, f16}; @@ -87,7 +86,7 @@ impl CustomOp1 for AllReduce { DType::BF16 => { let src = s.as_cuda_slice::()?; require_contiguous(src, l)?; - let mut dst = unsafe { dev.alloc::(elem_count) }.w()?; + let mut dst = unsafe { dev.alloc::(elem_count) }?; self.comm .all_reduce(src, &mut dst, &ReduceOp::Sum) .map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce bf16: {e:?}")))?; @@ -96,7 +95,7 @@ impl CustomOp1 for AllReduce { DType::F16 => { let src = s.as_cuda_slice::()?; require_contiguous(src, l)?; - let mut dst = unsafe { dev.alloc::(elem_count) }.w()?; + let mut dst = unsafe { dev.alloc::(elem_count) }?; self.comm .all_reduce(src, &mut dst, &ReduceOp::Sum) .map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f16: {e:?}")))?; @@ -105,7 +104,7 @@ impl CustomOp1 for 
AllReduce { DType::F32 => { let src = s.as_cuda_slice::()?; require_contiguous(src, l)?; - let mut dst = unsafe { dev.alloc::(elem_count) }.w()?; + let mut dst = unsafe { dev.alloc::(elem_count) }?; self.comm .all_reduce(src, &mut dst, &ReduceOp::Sum) .map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f32: {e:?}")))?;