diff --git a/.gitignore b/.gitignore index 2dcad04..f1ed483 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ .vscode/ cortex.toml doc/plan/* +/target-cuda/ diff --git a/Cargo.lock b/Cargo.lock index bb537cd..215be42 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2116,6 +2116,7 @@ dependencies = [ "cudarc 0.19.7", "figment", "futures", + "half", "hf-hub", "reqwest", "serde", diff --git a/crates/neuron/Cargo.toml b/crates/neuron/Cargo.toml index 98f15d9..1a4475a 100644 --- a/crates/neuron/Cargo.toml +++ b/crates/neuron/Cargo.toml @@ -24,6 +24,7 @@ cuda = [ "candle-nn/cuda", "candle-transformers/cuda", "dep:cudarc", + "dep:half", ] # Use cuDNN for convolution / attention kernels. Requires CUDA. cudnn = [ @@ -68,6 +69,10 @@ candle-transformers = "0.10.2" # TP worker pool can call cudarc::nccl::{Comm, Id} directly. Gated on # the `cuda` feature; same toolchain requirement as candle's CUDA path. cudarc = { version = "0.19", optional = true, default-features = false, features = ["nccl", "cuda-version-from-build-system"] } +# Used by the AllReduce CustomOp1 to type-dispatch on bf16/f16 candle +# storages. Matches candle-core's pinned major version to avoid double- +# compiling the `half` crate at conflicting versions. 
+half = { version = "2.5", optional = true } tokenizers = { version = "0.22", default-features = false, features = ["onig"] } hf-hub = { version = "0.4", features = ["tokio"] } diff --git a/crates/neuron/src/harness/tp/all_reduce.rs b/crates/neuron/src/harness/tp/all_reduce.rs index 4b1100d..76a7ade 100644 --- a/crates/neuron/src/harness/tp/all_reduce.rs +++ b/crates/neuron/src/harness/tp/all_reduce.rs @@ -21,7 +21,6 @@ #![cfg(feature = "cuda")] use candle_core::backend::BackendStorage; -use candle_core::cuda_backend::WrapErr; use candle_core::{CpuStorage, CudaStorage, CustomOp1, DType, Layout, Result, Shape}; use cudarc::nccl::{Comm, ReduceOp}; use half::{bf16, f16}; @@ -87,7 +86,7 @@ impl CustomOp1 for AllReduce { DType::BF16 => { let src = s.as_cuda_slice::<bf16>()?; require_contiguous(src, l)?; - let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?; + let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?; self.comm .all_reduce(src, &mut dst, &ReduceOp::Sum) .map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce bf16: {e:?}")))?; @@ -96,7 +95,7 @@ impl CustomOp1 for AllReduce { DType::F16 => { let src = s.as_cuda_slice::<f16>()?; require_contiguous(src, l)?; - let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?; + let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?; self.comm .all_reduce(src, &mut dst, &ReduceOp::Sum) .map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f16: {e:?}")))?; @@ -105,7 +104,7 @@ impl CustomOp1 for AllReduce { DType::F32 => { let src = s.as_cuda_slice::<f32>()?; require_contiguous(src, l)?; - let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?; + let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?; self.comm .all_reduce(src, &mut dst, &ReduceOp::Sum) .map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f32: {e:?}")))?;