diff --git a/crates/neuron/src/harness/tp/all_reduce.rs b/crates/neuron/src/harness/tp/all_reduce.rs
new file mode 100644
index 0000000..7aedd4d
--- /dev/null
+++ b/crates/neuron/src/harness/tp/all_reduce.rs
@@ -0,0 +1,130 @@
+//! `AllReduce` as a candle `CustomOp1` — the bridge between candle's
+//! `Tensor` graph and `cudarc::nccl::Comm::all_reduce`.
+//!
+//! Ported from the canonical
+//! `candle-examples/examples/llama_multiprocess/model.rs` pattern.
+//! Row-parallel layers apply this op after their local matmul to sum
+//! partial outputs across NCCL ranks.
+//!
+//! Available only under `--features cuda`; on CPU builds this module
+//! is empty and row-parallel layers degenerate to local matmul only
+//! (useful for compile-checking the model code; correctness requires
+//! cuda).
+//!
+//! Thread-safety caveat: NCCL communicators are technically only
+//! safe to use from a single thread at a time
+//! (<https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html>).
+//! `AllReduce` holds the `Comm` behind an `Arc`, and ops are only issued
+//! against it from the dedicated `spawn_blocking` thread the inference
+//! pipeline already uses for candle's forward passes.
+
+#![cfg(feature = "cuda")]
+
+use candle_core::cuda_backend::WrapErr;
+use candle_core::{CpuStorage, CudaStorage, CustomOp1, DType, Layout, Result, Shape};
+use cudarc::nccl::{Comm, ReduceOp};
+use half::{bf16, f16};
+use std::sync::Arc;
+
+/// Wraps an NCCL `Comm` so it can be plugged into a candle forward
+/// graph as a custom op. Each row-parallel layer holds one of these.
+pub struct AllReduce {
+    comm: Arc<Comm>,
+}
+
+// SAFETY: `Comm` contains a raw `ncclComm_t` pointer; NCCL's docs note
+// that issuing ops against one comm from multiple threads concurrently
+// is unsafe. We serialise via the single spawn_blocking thread that
+// drives the model's forward pass. The Send/Sync impl is necessary
+// because candle's CustomOp1 trait bounds require it; the correctness
+// invariant is enforced at the call site, not the type level.
+unsafe impl Send for AllReduce {}
+unsafe impl Sync for AllReduce {}
+
+impl AllReduce {
+    /// Wrap an NCCL communicator so it can be applied as a candle op.
+    pub fn new(comm: Arc<Comm>) -> Self {
+        Self { comm }
+    }
+
+    /// Borrow the shared communicator this op reduces over.
+    pub fn comm(&self) -> &Arc<Comm> {
+        &self.comm
+    }
+}
+
+impl CustomOp1 for AllReduce {
+    fn name(&self) -> &'static str {
+        "neuron.tp.all_reduce"
+    }
+
+    /// Always errors: this op has no CPU implementation.
+    fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
+        candle_core::bail!("AllReduce custom-op invoked on CPU storage; TP requires CUDA")
+    }
+
+    /// Sum the input across the ranks of `comm` into a freshly
+    /// allocated buffer of the same shape. Errors on non-contiguous
+    /// input and on dtypes other than bf16/f16/f32.
+    fn cuda_fwd(&self, s: &CudaStorage, l: &Layout) -> Result<(CudaStorage, Shape)> {
+        use cudarc::driver::DeviceSlice;
+
+        // Reject non-contiguous inputs explicitly — copying them
+        // server-side would mask shape bugs (a TP layer feeding a
+        // strided activation into all_reduce is almost certainly a
+        // model construction error).
+        fn require_contiguous<T>(
+            slice: &cudarc::driver::CudaSlice<T>,
+            l: &Layout,
+        ) -> Result<()> {
+            match l.contiguous_offsets() {
+                Some((0, n)) if n == slice.len() => Ok(()),
+                _ => candle_core::bail!(
+                    "AllReduce input is non-contiguous: layout={:?}, slice_len={}",
+                    l,
+                    slice.len()
+                ),
+            }
+        }
+
+        let elem_count = l.shape().elem_count();
+        let dev = s.device().clone();
+
+        // SAFETY (covers every `alloc` below): the buffer is returned
+        // uninitialised, and the `all_reduce` that follows fills all
+        // `elem_count` elements before `dst` is wrapped and returned.
+        let out = match s.dtype() {
+            DType::BF16 => {
+                let src = s.as_cuda_slice::<bf16>()?;
+                require_contiguous(src, l)?;
+                let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
+                self.comm
+                    .all_reduce(src, &mut dst, &ReduceOp::Sum)
+                    .map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce bf16: {e:?}")))?;
+                CudaStorage::wrap_cuda_slice(dst, dev)
+            }
+            DType::F16 => {
+                let src = s.as_cuda_slice::<f16>()?;
+                require_contiguous(src, l)?;
+                let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+                self.comm
+                    .all_reduce(src, &mut dst, &ReduceOp::Sum)
+                    .map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f16: {e:?}")))?;
+                CudaStorage::wrap_cuda_slice(dst, dev)
+            }
+            DType::F32 => {
+                let src = s.as_cuda_slice::<f32>()?;
+                require_contiguous(src, l)?;
+                let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
+                self.comm
+                    .all_reduce(src, &mut dst, &ReduceOp::Sum)
+                    .map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f32: {e:?}")))?;
+                CudaStorage::wrap_cuda_slice(dst, dev)
+            }
+            dtype => candle_core::bail!(
+                "AllReduce: unsupported dtype {dtype:?}; TP path expects bf16/f16/f32"
+            ),
+        };
+        Ok((out, l.shape().clone()))
+    }
+}
diff --git a/crates/neuron/src/harness/tp/mod.rs b/crates/neuron/src/harness/tp/mod.rs
index 50eef7d..450825d 100644
--- a/crates/neuron/src/harness/tp/mod.rs
+++ b/crates/neuron/src/harness/tp/mod.rs
@@ -17,9 +17,10 @@
 //! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
 //! - **7c:** crash detection, streaming SSE, graceful unload.
 
+pub mod all_reduce;
 pub mod nccl_state;
 pub mod rpc;
-pub mod sharded_linear;
+pub mod tp_linear;
 pub mod worker;
 
 use anyhow::{Context, Result};
diff --git a/crates/neuron/src/harness/tp/sharded_linear.rs b/crates/neuron/src/harness/tp/sharded_linear.rs deleted file mode 100644 index 70a1df2..0000000 --- a/crates/neuron/src/harness/tp/sharded_linear.rs +++ /dev/null @@ -1,405 +0,0 @@ -//! Tensor-parallel linear layers over `candle_nn::Linear`. -//! -//! Two sharding strategies, both following the Megatron-LM convention -//! that's also what mistral.rs uses for vanilla Qwen3: -//! -//! - [`ColumnParallelLinear`] — splits the **output** dimension. Each -//! rank holds `out_features / world_size` rows of the weight matrix. -//! The forward pass is a plain local matmul; the output is *sharded* -//! (each rank produces a slice of the output vector). Used for -//! `q_proj` / `k_proj` / `v_proj` (sharding by head) and the FFN's -//! `gate_proj` / `up_proj`. -//! -//! - [`RowParallelLinear`] — splits the **input** dimension. Each -//! rank holds `in_features / world_size` columns of the weight -//! 
matrix and consumes a sharded input from upstream. Each rank's -//! local matmul produces a *partial* output; an `all_reduce(Sum)` -//! across ranks recovers the full activation. Used for `o_proj` -//! (after attention) and `down_proj` (after the FFN). -//! -//! Stage 7b-ii (this commit): the layers, sharded loading, local -//! forward. The `all_reduce` collective lives in `forward_with_comm` -//! and is wired up in 7b-iii when the full TP-aware Qwen3 model is -//! assembled with an NCCL Comm in scope. Tests here exercise only -//! the local (no-NCCL) math against an unsharded reference. - -use anyhow::{Context, Result}; -use candle_core::{Module, Tensor}; -use candle_nn::{Linear, VarBuilder}; - -/// Direction of the parallelism split — selects which axis of the -/// weight matrix the rank's local slice is taken from. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ShardKind { - /// Split the output dimension: rank `r` holds rows - /// `[r * out/N .. (r+1) * out/N]` of the weight matrix. The - /// downstream consumer either accepts a sharded activation - /// (the next layer is also column-parallel) or merges via - /// all-gather. - Column, - /// Split the input dimension: rank `r` holds columns - /// `[r * in/N .. (r+1) * in/N]`. The forward pass produces a - /// partial output; an `all_reduce(Sum)` across ranks yields the - /// full activation. - Row, -} - -/// A linear layer whose weights have been sharded across NCCL ranks. -/// -/// Holds a standard `candle_nn::Linear` constructed from the local -/// slice. The collective op (only meaningful for `Row`) is invoked -/// by [`forward_with_comm`] — the trait `Module::forward` does just -/// the local matmul, so callers that want correct semantics on a -/// Row-parallel layer must drive the collective themselves. -#[derive(Debug)] -pub struct ShardedLinear { - inner: Linear, - kind: ShardKind, - rank: u32, - world_size: u32, - /// Captured for diagnostics ("rank 3 layer says X but should say Y"). 
- /// `out_features` reflects the **logical** size (pre-shard) so the - /// caller can validate against the model config without doing the - /// arithmetic itself. - logical_out_features: usize, - logical_in_features: usize, -} - -impl ShardedLinear { - /// Load a column-parallel slice from a `VarBuilder`. Reads the - /// full weight (and bias, if any) from the safetensors and - /// narrows on dim 0 to the rank's slice. The bias is sharded the - /// same way (each rank holds its own bias slice). - /// - /// Bails if `out_features` is not divisible by `world_size` — the - /// same divisibility precondition mistral.rs's PR #2054-era code - /// added an explicit guard for after the first TP shard attempt. - pub fn load_column( - vb: &VarBuilder, - in_features: usize, - out_features: usize, - has_bias: bool, - rank: u32, - world_size: u32, - ) -> Result { - let path = vb.prefix(); - if !out_features.is_multiple_of(world_size as usize) { - anyhow::bail!( - "column-parallel '{path}': out_features={out_features} \ - not divisible by world_size={world_size}" - ); - } - let shard = out_features / world_size as usize; - let start = rank as usize * shard; - - let full_w = vb - .get((out_features, in_features), "weight") - .with_context(|| format!("load weight for column-parallel '{path}'"))?; - let weight = full_w - .narrow(0, start, shard) - .with_context(|| format!("narrow weight rows for column-parallel '{path}'"))? - .contiguous() - .with_context(|| format!("contiguous weight for column-parallel '{path}'"))?; - // Drop the full tensor as soon as we have the shard so peak - // host RAM during load tracks shard-size, not full-size, once - // all narrows complete (Rust's drop semantics handle this - // because `full_w` goes out of scope here). 
- drop(full_w); - - let bias = if has_bias { - let full_b = vb - .get(out_features, "bias") - .with_context(|| format!("load bias for column-parallel '{path}'"))?; - let b = full_b - .narrow(0, start, shard) - .with_context(|| format!("narrow bias for column-parallel '{path}'"))? - .contiguous() - .with_context(|| format!("contiguous bias for column-parallel '{path}'"))?; - Some(b) - } else { - None - }; - - Ok(Self { - inner: Linear::new(weight, bias), - kind: ShardKind::Column, - rank, - world_size, - logical_out_features: out_features, - logical_in_features: in_features, - }) - } - - /// Load a row-parallel slice from a `VarBuilder`. Reads the full - /// weight and narrows on dim 1 to the rank's column slice. The - /// bias, if any, lives **only on rank 0** — every other rank - /// holds `None`. This keeps the post-`all_reduce` semantics - /// correct: each rank contributes its partial sum without the - /// bias, then rank 0's bias (added in `forward_with_comm`) lands - /// on the result exactly once. - pub fn load_row( - vb: &VarBuilder, - in_features: usize, - out_features: usize, - has_bias: bool, - rank: u32, - world_size: u32, - ) -> Result { - let path = vb.prefix(); - if !in_features.is_multiple_of(world_size as usize) { - anyhow::bail!( - "row-parallel '{path}': in_features={in_features} \ - not divisible by world_size={world_size}" - ); - } - let shard = in_features / world_size as usize; - let start = rank as usize * shard; - - let full_w = vb - .get((out_features, in_features), "weight") - .with_context(|| format!("load weight for row-parallel '{path}'"))?; - let weight = full_w - .narrow(1, start, shard) - .with_context(|| format!("narrow weight cols for row-parallel '{path}'"))? 
- .contiguous() - .with_context(|| format!("contiguous weight for row-parallel '{path}'"))?; - drop(full_w); - - let bias = if has_bias && rank == 0 { - let b = vb - .get(out_features, "bias") - .with_context(|| format!("load bias for row-parallel '{path}'"))?; - Some(b) - } else { - None - }; - - Ok(Self { - inner: Linear::new(weight, bias), - kind: ShardKind::Row, - rank, - world_size, - logical_out_features: out_features, - logical_in_features: in_features, - }) - } - - pub fn kind(&self) -> ShardKind { - self.kind - } - - pub fn rank(&self) -> u32 { - self.rank - } - - pub fn world_size(&self) -> u32 { - self.world_size - } - - pub fn logical_in_features(&self) -> usize { - self.logical_in_features - } - - pub fn logical_out_features(&self) -> usize { - self.logical_out_features - } -} - -impl Module for ShardedLinear { - /// Local matmul only. For `Row`-parallel layers, the output is a - /// *partial sum* — call [`Self::forward_with_comm`] to get the - /// reduced result. Implementing `Module` lets a `ShardedLinear` - /// be drop-in for any `Module`-shaped consumer that doesn't need - /// the reduce step (column-parallel layers; tests). - fn forward(&self, x: &Tensor) -> candle_core::Result { - self.inner.forward(x) - } -} - -#[cfg(feature = "cuda")] -impl ShardedLinear { - /// Forward pass that issues an `all_reduce(Sum)` for row-parallel - /// layers. Column-parallel layers just delegate to the local - /// matmul (their output is naturally sharded; the next consumer - /// will either gather or accept the shard). - pub fn forward_with_comm(&self, x: &Tensor, comm: &cudarc::nccl::Comm) -> Result { - let local = self - .inner - .forward(x) - .map_err(|e| anyhow::anyhow!("local matmul: {e}"))?; - match self.kind { - ShardKind::Column => Ok(local), - ShardKind::Row => { - // TODO Stage 7b-iii: wrap `local`'s CudaSlice with a - // matching output buffer, call comm.all_reduce(Sum), - // return the result. 
The cudarc::nccl all_reduce - // signature takes `&S: DevicePtr` + `&mut R: DevicePtrMut`, - // both backed by `CudaSlice`. candle stores its - // Tensor data behind its own slab — extracting the - // underlying CudaSlice safely is a separate piece of - // plumbing best landed alongside the model assembly, - // so this body is a placeholder. - let _ = comm; - anyhow::bail!( - "ShardedLinear::forward_with_comm row-parallel reduce \ - lands in Stage 7b-iii alongside the model assembly; \ - 7b-ii ships only the local matmul" - ); - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use candle_core::{DType, Device, Tensor}; - use candle_nn::var_builder::VarBuilderArgs; - use std::collections::HashMap; - - /// Build a VarBuilder over an in-memory map of tensors. Used by - /// the tests to fake a safetensors source without touching disk. - fn vb_from_map(tensors: HashMap, device: &Device) -> VarBuilder<'static> { - VarBuilderArgs::from_tensors(tensors, DType::F32, device) - } - - /// World_size=2 column-parallel split of a 4x3 weight. Each rank's - /// local matmul on the same input should be 2 rows of the - /// reference (full) matmul. - #[test] - fn column_parallel_shards_output_correctly() { - let device = Device::Cpu; - // weight (out=4, in=3): rows are easy to identify by value. - let w = Tensor::from_slice( - &[ - 1f32, 2., 3., // row 0 - 4., 5., 6., // row 1 - 7., 8., 9., // row 2 - 10., 11., 12., // row 3 - ], - (4, 3), - &device, - ) - .unwrap(); - let mut tensors = HashMap::new(); - tensors.insert("foo.weight".into(), w.clone()); - let vb_root = vb_from_map(tensors, &device); - let vb_foo = vb_root.pp("foo"); - - // rank 0 of world_size 2 gets rows 0..2. - let r0 = ShardedLinear::load_column(&vb_foo, 3, 4, false, 0, 2).unwrap(); - // rank 1 gets rows 2..4. 
- let r1 = ShardedLinear::load_column(&vb_foo, 3, 4, false, 1, 2).unwrap(); - - let x = Tensor::from_slice(&[1f32, 0., 0.], (1, 3), &device).unwrap(); - let y0 = r0.forward(&x).unwrap().to_vec2::().unwrap(); - let y1 = r1.forward(&x).unwrap().to_vec2::().unwrap(); - // Full reference: x @ w.T → [1, 4, 7, 10]. Rank 0 owns [1, 4], - // rank 1 owns [7, 10]. - assert_eq!(y0, vec![vec![1.0, 4.0]]); - assert_eq!(y1, vec![vec![7.0, 10.0]]); - } - - /// World_size=2 row-parallel split of a 4x4 weight. Each rank's - /// local matmul on its half of the input should be a partial sum; - /// summing the two partials should equal the unsharded reference. - #[test] - fn row_parallel_partials_sum_to_full() { - let device = Device::Cpu; - // weight (out=4, in=4): use distinct values per column so the - // partial sums are obviously different. - let w = Tensor::from_slice( - &[ - 1f32, 2., 3., 4., // row 0 - 5., 6., 7., 8., // row 1 - 9., 10., 11., 12., // row 2 - 13., 14., 15., 16., // row 3 - ], - (4, 4), - &device, - ) - .unwrap(); - let mut tensors = HashMap::new(); - tensors.insert("bar.weight".into(), w.clone()); - let vb_root = vb_from_map(tensors, &device); - let vb_bar = vb_root.pp("bar"); - - let r0 = ShardedLinear::load_row(&vb_bar, 4, 4, false, 0, 2).unwrap(); - let r1 = ShardedLinear::load_row(&vb_bar, 4, 4, false, 1, 2).unwrap(); - - // x split: rank 0 takes x[..2], rank 1 takes x[2..]. - let x_full = Tensor::from_slice(&[1f32, 1., 1., 1.], (1, 4), &device).unwrap(); - let x0 = x_full.narrow(1, 0, 2).unwrap(); - let x1 = x_full.narrow(1, 2, 2).unwrap(); - - let y0 = r0.forward(&x0).unwrap(); - let y1 = r1.forward(&x1).unwrap(); - let summed = (y0 + y1).unwrap().to_vec2::().unwrap(); - - // Reference: x_full @ w.T = [1+2+3+4, 5+6+7+8, 9+10+11+12, 13+14+15+16] - // = [10, 26, 42, 58]. - assert_eq!(summed, vec![vec![10.0, 26.0, 42.0, 58.0]]); - } - - /// Row-parallel bias lives only on rank 0; other ranks have None. - /// (Verifies the rank-0-only bias contract.) 
- #[test] - fn row_parallel_bias_only_on_rank_zero() { - let device = Device::Cpu; - let w = Tensor::zeros((4, 4), DType::F32, &device).unwrap(); - let b = Tensor::from_slice(&[1f32, 1., 1., 1.], 4, &device).unwrap(); - let mut tensors = HashMap::new(); - tensors.insert("baz.weight".into(), w); - tensors.insert("baz.bias".into(), b); - let vb_root = vb_from_map(tensors, &device); - let vb_baz = vb_root.pp("baz"); - - let r0 = ShardedLinear::load_row(&vb_baz, 4, 4, true, 0, 2).unwrap(); - let r1 = ShardedLinear::load_row(&vb_baz, 4, 4, true, 1, 2).unwrap(); - - // We can't introspect the Linear's bias from the public API, - // but we can run forward of zero-weight rank 1 and confirm - // the output is zero (no bias added on non-zero ranks). - let x = Tensor::ones((1, 2), DType::F32, &device).unwrap(); - let y1 = r1.forward(&x).unwrap().to_vec2::().unwrap(); - assert_eq!(y1, vec![vec![0.0, 0.0, 0.0, 0.0]]); - - let y0 = r0.forward(&x).unwrap().to_vec2::().unwrap(); - // Rank 0 weight is zero but bias is [1,1,1,1] → output should be [1,1,1,1]. 
- assert_eq!(y0, vec![vec![1.0, 1.0, 1.0, 1.0]]); - } - - #[test] - fn column_parallel_rejects_non_divisible_out_features() { - let device = Device::Cpu; - let w = Tensor::zeros((5, 3), DType::F32, &device).unwrap(); - let mut tensors = HashMap::new(); - tensors.insert("nope.weight".into(), w); - let vb_root = vb_from_map(tensors, &device); - let vb_nope = vb_root.pp("nope"); - - let err = ShardedLinear::load_column(&vb_nope, 3, 5, false, 0, 2).unwrap_err(); - let msg = format!("{err:#}"); - assert!( - msg.contains("not divisible by world_size"), - "expected divisibility error, got: {msg}" - ); - } - - #[test] - fn row_parallel_rejects_non_divisible_in_features() { - let device = Device::Cpu; - let w = Tensor::zeros((4, 5), DType::F32, &device).unwrap(); - let mut tensors = HashMap::new(); - tensors.insert("nope.weight".into(), w); - let vb_root = vb_from_map(tensors, &device); - let vb_nope = vb_root.pp("nope"); - - let err = ShardedLinear::load_row(&vb_nope, 5, 4, false, 0, 2).unwrap_err(); - let msg = format!("{err:#}"); - assert!( - msg.contains("not divisible by world_size"), - "expected divisibility error, got: {msg}" - ); - } -}
diff --git a/crates/neuron/src/harness/tp/tp_linear.rs b/crates/neuron/src/harness/tp/tp_linear.rs
new file mode 100644
index 0000000..324f2dd
--- /dev/null
+++ b/crates/neuron/src/harness/tp/tp_linear.rs
@@ -0,0 +1,135 @@
+//! Tensor-parallel linear layers built on candle's `ShardedVarBuilder`
+//! and `Shard` sharding hints.
+//!
+//! candle reads only the rank's slice of each weight tensor from
+//! safetensors via `view.slice(start..stop)` — no full-tensor host
+//! materialisation. That's a memory-efficiency win over hand-rolled
+//! "load full + narrow" sharding (the approach the earlier
+//! `sharded_linear.rs` exploration took).
+//!
+//! Two layer types:
+//!
+//! - [`ColumnParallelLinear`] — output-sharded; forward is a plain
+//!   local matmul. The downstream consumer either accepts a sharded
+//!   activation (next layer is also column-parallel) or all-gathers.
+//! - [`RowParallelLinear`] — input-sharded; forward = local matmul
+//!   then `AllReduce` `CustomOp1` to sum partials across ranks.
+//!
+//! Both assume **no bias** — every Qwen3-family weight layout we
+//! actually target (Qwen3, Qwen3-Coder, Qwen3.6 base, etc.) sets
+//! `attention_bias=false` and the MLP layers are no-bias. Adding bias
+//! support is mechanical when a future model needs it; the design
+//! choice would be: column-parallel shards the bias along dim 0;
+//! row-parallel holds the bias only on rank 0 so the post-`AllReduce`
+//! sum carries it exactly once.
+
+use anyhow::{Context, Result};
+use candle_core::{Module, Tensor};
+use candle_nn::Linear;
+use candle_nn::var_builder::{Shard, ShardedVarBuilder};
+
+#[cfg(feature = "cuda")]
+use super::all_reduce::AllReduce;
+
+/// Helper to build a [`Shard`] hint for a given dimension.
+pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
+    Shard {
+        dim,
+        rank: rank as usize,
+        world_size: world_size as usize,
+    }
+}
+
+/// Output-dim sharded linear (column-parallel). Holds a standard
+/// `candle_nn::Linear` whose `weight` is the rank's slice of the full
+/// `[out_features, in_features]` tensor along dim 0.
+pub struct ColumnParallelLinear {
+    inner: Linear,
+}
+
+impl ColumnParallelLinear {
+    /// Load this rank's column-parallel slice from a
+    /// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
+    /// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
+    pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
+        let weight = vb
+            .get_with_hints((), "weight", shard(0, rank, world_size))
+            .with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
+        Ok(Self {
+            inner: Linear::new(weight, None),
+        })
+    }
+}
+
+impl Module for ColumnParallelLinear {
+    /// Plain local matmul; the result is this rank's shard of the output.
+    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
+        self.inner.forward(x)
+    }
+}
+
+/// Input-dim sharded linear (row-parallel).
+///
+/// Holds a sharded `Linear` plus an `AllReduce` op the forward chains
+/// after the local matmul to recover the full activation.
+pub struct RowParallelLinear {
+    inner: Linear,
+    #[cfg(feature = "cuda")]
+    all_reduce: AllReduce,
+    /// Whether the AllReduce should run. Column-parallel ↔ row-parallel
+    /// is a pair: the column output is sharded, the row input is
+    /// sharded, and the AllReduce gives back the full output. For
+    /// `world_size == 1` the AllReduce is a no-op so we skip it.
+    needs_reduce: bool,
+}
+
+impl RowParallelLinear {
+    /// Load this rank's row-parallel slice from a `ShardedVarBuilder`.
+    ///
+    /// Under `cuda`, `comm` is the NCCL communicator the row-parallel
+    /// `AllReduce` runs against. On CPU builds the parameter is
+    /// elided — forward returns the partial sum, which is the *wrong*
+    /// answer for inference but lets us compile-check the model.
+    #[cfg(feature = "cuda")]
+    pub fn load(
+        vb: &ShardedVarBuilder,
+        rank: u32,
+        world_size: u32,
+        comm: std::sync::Arc<cudarc::nccl::Comm>,
+    ) -> Result<Self> {
+        let weight = vb
+            .get_with_hints((), "weight", shard(1, rank, world_size))
+            .with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
+        Ok(Self {
+            inner: Linear::new(weight, None),
+            all_reduce: AllReduce::new(comm),
+            needs_reduce: world_size > 1,
+        })
+    }
+
+    #[cfg(not(feature = "cuda"))]
+    pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
+        let weight = vb
+            .get_with_hints((), "weight", shard(1, rank, world_size))
+            .with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
+        Ok(Self {
+            inner: Linear::new(weight, None),
+            needs_reduce: world_size > 1,
+        })
+    }
+}
+
+impl Module for RowParallelLinear {
+    /// Local matmul followed by an `AllReduce` (when `cuda` and
+    /// `world_size > 1`). On CPU or single-rank, returns the partial
+    /// output directly — which is *only* correct for `world_size == 1`.
+    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
+        let local = self.inner.forward(x)?;
+        #[cfg(feature = "cuda")]
+        if self.needs_reduce {
+            return local.apply_op1_no_bwd(&self.all_reduce);
+        }
+        let _ = self.needs_reduce;
+        Ok(local)
+    }
+}