Stage 7b-iii (1/2): AllReduce CustomOp + ShardedVarBuilder-backed TP linears
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 38s
CI / Clippy (push) Successful in 2m16s
build-prerelease / Build neuron-blackwell (push) Failing after 3m19s
CI / Test (push) Successful in 4m26s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m22s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Failing after 4m58s
build-prerelease / Build neuron-ada (push) Failing after 4m53s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped

Ports the canonical
candle-examples/examples/llama_multiprocess/model.rs pattern into
the harness. Two new files, one deletion:

- harness/tp/all_reduce.rs — AllReduce wraps Arc<cudarc::nccl::Comm>
  and implements candle's CustomOp1 trait. cuda_fwd extracts the
  rank's CudaSlice<dtype> from a CudaStorage, asserts the input is
  contiguous (a strided activation hitting all_reduce is almost
  always a model construction bug), allocates an output CudaSlice
  on the same device, calls Comm::all_reduce(Sum), and wraps the
  result back as a CudaStorage. Handles BF16, F16, F32. NcclError
  surfaces via {e:?} (no Display impl in cudarc 0.19.x). Send/Sync
  hand-impl'd with the same NCCL-thread-safety caveat candle's
  example documents.

- harness/tp/tp_linear.rs — ColumnParallelLinear and
  RowParallelLinear, both built on candle's ShardedVarBuilder +
  Shard hints. `vb.get_with_hints((), "weight", shard(dim, rank, ws))`
  reads JUST the rank's slice from the safetensors view; no full-
  tensor host materialisation. ColumnParallel.forward is a plain
  local matmul (output is naturally sharded). RowParallel.forward =
  local matmul + apply_op1_no_bwd(&self.all_reduce). On CPU /
  world_size == 1, the AllReduce is skipped and the partial output
  is returned as-is. Both layers are no-bias — every Qwen3-family
  target sets attention_bias=false; bias-aware sharding is a
  future-model concern.

- Deletes harness/tp/sharded_linear.rs from 7b-ii. That commit's
  hand-rolled "load full + narrow" approach was useful exploration
  but candle's ShardedVarBuilder does the same work without
  materialising the full tensor on host. The 5 unit tests there
  verified the slicing math against an unsharded reference; that
  math now lives inside candle and is covered by candle's own tests.

Next (7b-iii 2/2): TpQwen3Attention + TpQwen3MLP composing the
column/row pair, then a TpQwen3Model that runs the full forward.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-19 18:14:54 +03:00
parent 5436af9c73
commit 8d3194f992
4 changed files with 257 additions and 406 deletions

View File

@@ -0,0 +1,121 @@
//! `AllReduce` as a candle `CustomOp1` — the bridge between candle's
//! `Tensor` graph and `cudarc::nccl::Comm::all_reduce`.
//!
//! Ported from the canonical
//! `candle-examples/examples/llama_multiprocess/model.rs` pattern.
//! Row-parallel layers apply this op after their local matmul to sum
//! partial outputs across NCCL ranks.
//!
//! Available only under `--features cuda`; on CPU builds this module
//! is empty and row-parallel layers degenerate to local matmul only
//! (useful for compile-checking the model code; correctness requires
//! cuda).
//!
//! Thread-safety caveat: NCCL communicators are technically only
//! safe to use from a single thread at a time
//! (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html).
//! We hold the `AllReduce` behind an `Arc<Comm>` and only issue ops
//! against it from the dedicated `spawn_blocking` thread the inference
//! pipeline already uses for candle's forward passes.
#![cfg(feature = "cuda")]
use candle_core::cuda_backend::WrapErr;
use candle_core::{CpuStorage, CudaStorage, CustomOp1, DType, Layout, Result, Shape};
use cudarc::nccl::{Comm, ReduceOp};
use half::{bf16, f16};
use std::sync::Arc;
/// Wraps an NCCL `Comm` so it can be plugged into a candle forward
/// graph as a custom op. Each row-parallel layer holds one of these.
pub struct AllReduce {
comm: Arc<Comm>,
}
// SAFETY: `Comm` contains a raw `ncclComm_t` pointer; NCCL's docs note
// that issuing ops against one comm from multiple threads concurrently
// is unsafe. We serialise via the single spawn_blocking thread that
// drives the model's forward pass. The Send/Sync impl is necessary
// because candle's CustomOp1 trait bounds require it; the correctness
// invariant is enforced at the call site, not the type level.
unsafe impl Send for AllReduce {}
unsafe impl Sync for AllReduce {}
impl AllReduce {
pub fn new(comm: Arc<Comm>) -> Self {
Self { comm }
}
pub fn comm(&self) -> &Arc<Comm> {
&self.comm
}
}
impl CustomOp1 for AllReduce {
fn name(&self) -> &'static str {
"neuron.tp.all_reduce"
}
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
candle_core::bail!("AllReduce custom-op invoked on CPU storage; TP requires CUDA")
}
fn cuda_fwd(&self, s: &CudaStorage, l: &Layout) -> Result<(CudaStorage, Shape)> {
use cudarc::driver::DeviceSlice;
// Reject non-contiguous inputs explicitly — copying them
// server-side would mask shape bugs (a TP layer feeding a
// strided activation into all_reduce is almost certainly a
// model construction error).
fn require_contiguous<T: cudarc::driver::DeviceRepr>(
slice: &cudarc::driver::CudaSlice<T>,
l: &Layout,
) -> Result<()> {
match l.contiguous_offsets() {
Some((0, n)) if n == slice.len() => Ok(()),
_ => candle_core::bail!(
"AllReduce input is non-contiguous: layout={:?}, slice_len={}",
l,
slice.len()
),
}
}
let elem_count = l.shape().elem_count();
let dev = s.device().clone();
let out = match s.dtype() {
DType::BF16 => {
let src = s.as_cuda_slice::<bf16>()?;
require_contiguous(src, l)?;
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
self.comm
.all_reduce(src, &mut dst, &ReduceOp::Sum)
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce bf16: {e:?}")))?;
CudaStorage::wrap_cuda_slice(dst, dev)
}
DType::F16 => {
let src = s.as_cuda_slice::<f16>()?;
require_contiguous(src, l)?;
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
self.comm
.all_reduce(src, &mut dst, &ReduceOp::Sum)
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f16: {e:?}")))?;
CudaStorage::wrap_cuda_slice(dst, dev)
}
DType::F32 => {
let src = s.as_cuda_slice::<f32>()?;
require_contiguous(src, l)?;
let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
self.comm
.all_reduce(src, &mut dst, &ReduceOp::Sum)
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f32: {e:?}")))?;
CudaStorage::wrap_cuda_slice(dst, dev)
}
dtype => candle_core::bail!(
"AllReduce: unsupported dtype {dtype:?}; TP path expects bf16/f16/f32"
),
};
Ok((out, l.shape().clone()))
}
}

View File

@@ -17,9 +17,10 @@
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
//! - **7c:** crash detection, streaming SSE, graceful unload.
pub mod all_reduce;
pub mod nccl_state;
pub mod rpc;
pub mod sharded_linear;
pub mod tp_linear;
pub mod worker;
use anyhow::{Context, Result};

View File

@@ -1,405 +0,0 @@
//! Tensor-parallel linear layers over `candle_nn::Linear`.
//!
//! Two sharding strategies, both following the Megatron-LM convention
//! that's also what mistral.rs uses for vanilla Qwen3:
//!
//! - [`ColumnParallelLinear`] — splits the **output** dimension. Each
//! rank holds `out_features / world_size` rows of the weight matrix.
//! The forward pass is a plain local matmul; the output is *sharded*
//! (each rank produces a slice of the output vector). Used for
//! `q_proj` / `k_proj` / `v_proj` (sharding by head) and the FFN's
//! `gate_proj` / `up_proj`.
//!
//! - [`RowParallelLinear`] — splits the **input** dimension. Each
//! rank holds `in_features / world_size` columns of the weight
//! matrix and consumes a sharded input from upstream. Each rank's
//! local matmul produces a *partial* output; an `all_reduce(Sum)`
//! across ranks recovers the full activation. Used for `o_proj`
//! (after attention) and `down_proj` (after the FFN).
//!
//! Stage 7b-ii (this commit): the layers, sharded loading, local
//! forward. The `all_reduce` collective lives in `forward_with_comm`
//! and is wired up in 7b-iii when the full TP-aware Qwen3 model is
//! assembled with an NCCL Comm in scope. Tests here exercise only
//! the local (no-NCCL) math against an unsharded reference.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::{Linear, VarBuilder};
/// Direction of the parallelism split — selects which axis of the
/// weight matrix the rank's local slice is taken from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardKind {
/// Split the output dimension: rank `r` holds rows
/// `[r * out/N .. (r+1) * out/N]` of the weight matrix. The
/// downstream consumer either accepts a sharded activation
/// (the next layer is also column-parallel) or merges via
/// all-gather.
Column,
/// Split the input dimension: rank `r` holds columns
/// `[r * in/N .. (r+1) * in/N]`. The forward pass produces a
/// partial output; an `all_reduce(Sum)` across ranks yields the
/// full activation.
Row,
}
/// A linear layer whose weights have been sharded across NCCL ranks.
///
/// Holds a standard `candle_nn::Linear` constructed from the local
/// slice. The collective op (only meaningful for `Row`) is invoked
/// by [`forward_with_comm`] — the trait `Module::forward` does just
/// the local matmul, so callers that want correct semantics on a
/// Row-parallel layer must drive the collective themselves.
#[derive(Debug)]
pub struct ShardedLinear {
inner: Linear,
kind: ShardKind,
rank: u32,
world_size: u32,
/// Captured for diagnostics ("rank 3 layer says X but should say Y").
/// `out_features` reflects the **logical** size (pre-shard) so the
/// caller can validate against the model config without doing the
/// arithmetic itself.
logical_out_features: usize,
logical_in_features: usize,
}
impl ShardedLinear {
/// Load a column-parallel slice from a `VarBuilder`. Reads the
/// full weight (and bias, if any) from the safetensors and
/// narrows on dim 0 to the rank's slice. The bias is sharded the
/// same way (each rank holds its own bias slice).
///
/// Bails if `out_features` is not divisible by `world_size` — the
/// same divisibility precondition mistral.rs's PR #2054-era code
/// added an explicit guard for after the first TP shard attempt.
pub fn load_column(
vb: &VarBuilder,
in_features: usize,
out_features: usize,
has_bias: bool,
rank: u32,
world_size: u32,
) -> Result<Self> {
let path = vb.prefix();
if !out_features.is_multiple_of(world_size as usize) {
anyhow::bail!(
"column-parallel '{path}': out_features={out_features} \
not divisible by world_size={world_size}"
);
}
let shard = out_features / world_size as usize;
let start = rank as usize * shard;
let full_w = vb
.get((out_features, in_features), "weight")
.with_context(|| format!("load weight for column-parallel '{path}'"))?;
let weight = full_w
.narrow(0, start, shard)
.with_context(|| format!("narrow weight rows for column-parallel '{path}'"))?
.contiguous()
.with_context(|| format!("contiguous weight for column-parallel '{path}'"))?;
// Drop the full tensor as soon as we have the shard so peak
// host RAM during load tracks shard-size, not full-size, once
// all narrows complete (Rust's drop semantics handle this
// because `full_w` goes out of scope here).
drop(full_w);
let bias = if has_bias {
let full_b = vb
.get(out_features, "bias")
.with_context(|| format!("load bias for column-parallel '{path}'"))?;
let b = full_b
.narrow(0, start, shard)
.with_context(|| format!("narrow bias for column-parallel '{path}'"))?
.contiguous()
.with_context(|| format!("contiguous bias for column-parallel '{path}'"))?;
Some(b)
} else {
None
};
Ok(Self {
inner: Linear::new(weight, bias),
kind: ShardKind::Column,
rank,
world_size,
logical_out_features: out_features,
logical_in_features: in_features,
})
}
/// Load a row-parallel slice from a `VarBuilder`. Reads the full
/// weight and narrows on dim 1 to the rank's column slice. The
/// bias, if any, lives **only on rank 0** — every other rank
/// holds `None`. This keeps the post-`all_reduce` semantics
/// correct: each rank contributes its partial sum without the
/// bias, then rank 0's bias (added in `forward_with_comm`) lands
/// on the result exactly once.
pub fn load_row(
vb: &VarBuilder,
in_features: usize,
out_features: usize,
has_bias: bool,
rank: u32,
world_size: u32,
) -> Result<Self> {
let path = vb.prefix();
if !in_features.is_multiple_of(world_size as usize) {
anyhow::bail!(
"row-parallel '{path}': in_features={in_features} \
not divisible by world_size={world_size}"
);
}
let shard = in_features / world_size as usize;
let start = rank as usize * shard;
let full_w = vb
.get((out_features, in_features), "weight")
.with_context(|| format!("load weight for row-parallel '{path}'"))?;
let weight = full_w
.narrow(1, start, shard)
.with_context(|| format!("narrow weight cols for row-parallel '{path}'"))?
.contiguous()
.with_context(|| format!("contiguous weight for row-parallel '{path}'"))?;
drop(full_w);
let bias = if has_bias && rank == 0 {
let b = vb
.get(out_features, "bias")
.with_context(|| format!("load bias for row-parallel '{path}'"))?;
Some(b)
} else {
None
};
Ok(Self {
inner: Linear::new(weight, bias),
kind: ShardKind::Row,
rank,
world_size,
logical_out_features: out_features,
logical_in_features: in_features,
})
}
pub fn kind(&self) -> ShardKind {
self.kind
}
pub fn rank(&self) -> u32 {
self.rank
}
pub fn world_size(&self) -> u32 {
self.world_size
}
pub fn logical_in_features(&self) -> usize {
self.logical_in_features
}
pub fn logical_out_features(&self) -> usize {
self.logical_out_features
}
}
impl Module for ShardedLinear {
/// Local matmul only. For `Row`-parallel layers, the output is a
/// *partial sum* — call [`Self::forward_with_comm`] to get the
/// reduced result. Implementing `Module` lets a `ShardedLinear`
/// be drop-in for any `Module`-shaped consumer that doesn't need
/// the reduce step (column-parallel layers; tests).
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
self.inner.forward(x)
}
}
#[cfg(feature = "cuda")]
impl ShardedLinear {
/// Forward pass that issues an `all_reduce(Sum)` for row-parallel
/// layers. Column-parallel layers just delegate to the local
/// matmul (their output is naturally sharded; the next consumer
/// will either gather or accept the shard).
pub fn forward_with_comm(&self, x: &Tensor, comm: &cudarc::nccl::Comm) -> Result<Tensor> {
let local = self
.inner
.forward(x)
.map_err(|e| anyhow::anyhow!("local matmul: {e}"))?;
match self.kind {
ShardKind::Column => Ok(local),
ShardKind::Row => {
// TODO Stage 7b-iii: wrap `local`'s CudaSlice with a
// matching output buffer, call comm.all_reduce(Sum),
// return the result. The cudarc::nccl all_reduce
// signature takes `&S: DevicePtr<T>` + `&mut R: DevicePtrMut<T>`,
// both backed by `CudaSlice<T>`. candle stores its
// Tensor data behind its own slab — extracting the
// underlying CudaSlice safely is a separate piece of
// plumbing best landed alongside the model assembly,
// so this body is a placeholder.
let _ = comm;
anyhow::bail!(
"ShardedLinear::forward_with_comm row-parallel reduce \
lands in Stage 7b-iii alongside the model assembly; \
7b-ii ships only the local matmul"
);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::{DType, Device, Tensor};
use candle_nn::var_builder::VarBuilderArgs;
use std::collections::HashMap;
/// Build a VarBuilder over an in-memory map of tensors. Used by
/// the tests to fake a safetensors source without touching disk.
fn vb_from_map(tensors: HashMap<String, Tensor>, device: &Device) -> VarBuilder<'static> {
VarBuilderArgs::from_tensors(tensors, DType::F32, device)
}
/// World_size=2 column-parallel split of a 4x3 weight. Each rank's
/// local matmul on the same input should be 2 rows of the
/// reference (full) matmul.
#[test]
fn column_parallel_shards_output_correctly() {
let device = Device::Cpu;
// weight (out=4, in=3): rows are easy to identify by value.
let w = Tensor::from_slice(
&[
1f32, 2., 3., // row 0
4., 5., 6., // row 1
7., 8., 9., // row 2
10., 11., 12., // row 3
],
(4, 3),
&device,
)
.unwrap();
let mut tensors = HashMap::new();
tensors.insert("foo.weight".into(), w.clone());
let vb_root = vb_from_map(tensors, &device);
let vb_foo = vb_root.pp("foo");
// rank 0 of world_size 2 gets rows 0..2.
let r0 = ShardedLinear::load_column(&vb_foo, 3, 4, false, 0, 2).unwrap();
// rank 1 gets rows 2..4.
let r1 = ShardedLinear::load_column(&vb_foo, 3, 4, false, 1, 2).unwrap();
let x = Tensor::from_slice(&[1f32, 0., 0.], (1, 3), &device).unwrap();
let y0 = r0.forward(&x).unwrap().to_vec2::<f32>().unwrap();
let y1 = r1.forward(&x).unwrap().to_vec2::<f32>().unwrap();
// Full reference: x @ w.T → [1, 4, 7, 10]. Rank 0 owns [1, 4],
// rank 1 owns [7, 10].
assert_eq!(y0, vec![vec![1.0, 4.0]]);
assert_eq!(y1, vec![vec![7.0, 10.0]]);
}
/// World_size=2 row-parallel split of a 4x4 weight. Each rank's
/// local matmul on its half of the input should be a partial sum;
/// summing the two partials should equal the unsharded reference.
#[test]
fn row_parallel_partials_sum_to_full() {
let device = Device::Cpu;
// weight (out=4, in=4): use distinct values per column so the
// partial sums are obviously different.
let w = Tensor::from_slice(
&[
1f32, 2., 3., 4., // row 0
5., 6., 7., 8., // row 1
9., 10., 11., 12., // row 2
13., 14., 15., 16., // row 3
],
(4, 4),
&device,
)
.unwrap();
let mut tensors = HashMap::new();
tensors.insert("bar.weight".into(), w.clone());
let vb_root = vb_from_map(tensors, &device);
let vb_bar = vb_root.pp("bar");
let r0 = ShardedLinear::load_row(&vb_bar, 4, 4, false, 0, 2).unwrap();
let r1 = ShardedLinear::load_row(&vb_bar, 4, 4, false, 1, 2).unwrap();
// x split: rank 0 takes x[..2], rank 1 takes x[2..].
let x_full = Tensor::from_slice(&[1f32, 1., 1., 1.], (1, 4), &device).unwrap();
let x0 = x_full.narrow(1, 0, 2).unwrap();
let x1 = x_full.narrow(1, 2, 2).unwrap();
let y0 = r0.forward(&x0).unwrap();
let y1 = r1.forward(&x1).unwrap();
let summed = (y0 + y1).unwrap().to_vec2::<f32>().unwrap();
// Reference: x_full @ w.T = [1+2+3+4, 5+6+7+8, 9+10+11+12, 13+14+15+16]
// = [10, 26, 42, 58].
assert_eq!(summed, vec![vec![10.0, 26.0, 42.0, 58.0]]);
}
/// Row-parallel bias lives only on rank 0; other ranks have None.
/// (Verifies the rank-0-only bias contract.)
#[test]
fn row_parallel_bias_only_on_rank_zero() {
let device = Device::Cpu;
let w = Tensor::zeros((4, 4), DType::F32, &device).unwrap();
let b = Tensor::from_slice(&[1f32, 1., 1., 1.], 4, &device).unwrap();
let mut tensors = HashMap::new();
tensors.insert("baz.weight".into(), w);
tensors.insert("baz.bias".into(), b);
let vb_root = vb_from_map(tensors, &device);
let vb_baz = vb_root.pp("baz");
let r0 = ShardedLinear::load_row(&vb_baz, 4, 4, true, 0, 2).unwrap();
let r1 = ShardedLinear::load_row(&vb_baz, 4, 4, true, 1, 2).unwrap();
// We can't introspect the Linear's bias from the public API,
// but we can run forward of zero-weight rank 1 and confirm
// the output is zero (no bias added on non-zero ranks).
let x = Tensor::ones((1, 2), DType::F32, &device).unwrap();
let y1 = r1.forward(&x).unwrap().to_vec2::<f32>().unwrap();
assert_eq!(y1, vec![vec![0.0, 0.0, 0.0, 0.0]]);
let y0 = r0.forward(&x).unwrap().to_vec2::<f32>().unwrap();
// Rank 0 weight is zero but bias is [1,1,1,1] → output should be [1,1,1,1].
assert_eq!(y0, vec![vec![1.0, 1.0, 1.0, 1.0]]);
}
#[test]
fn column_parallel_rejects_non_divisible_out_features() {
let device = Device::Cpu;
let w = Tensor::zeros((5, 3), DType::F32, &device).unwrap();
let mut tensors = HashMap::new();
tensors.insert("nope.weight".into(), w);
let vb_root = vb_from_map(tensors, &device);
let vb_nope = vb_root.pp("nope");
let err = ShardedLinear::load_column(&vb_nope, 3, 5, false, 0, 2).unwrap_err();
let msg = format!("{err:#}");
assert!(
msg.contains("not divisible by world_size"),
"expected divisibility error, got: {msg}"
);
}
#[test]
fn row_parallel_rejects_non_divisible_in_features() {
let device = Device::Cpu;
let w = Tensor::zeros((4, 5), DType::F32, &device).unwrap();
let mut tensors = HashMap::new();
tensors.insert("nope.weight".into(), w);
let vb_root = vb_from_map(tensors, &device);
let vb_nope = vb_root.pp("nope");
let err = ShardedLinear::load_row(&vb_nope, 5, 4, false, 0, 2).unwrap_err();
let msg = format!("{err:#}");
assert!(
msg.contains("not divisible by world_size"),
"expected divisibility error, got: {msg}"
);
}
}

View File

@@ -0,0 +1,134 @@
//! Tensor-parallel linear layers built on candle's `ShardedVarBuilder`
//! and `Shard` sharding hints.
//!
//! candle reads only the rank's slice of each weight tensor from
//! safetensors via `view.slice(start..stop)` — no full-tensor host
//! materialisation. That's a memory-efficiency win over hand-rolled
//! "load full + narrow" sharding (which the earlier
//! `sharded_linear.rs` exploration demonstrated but didn't pay for).
//!
//! Two layer types:
//!
//! - [`ColumnParallelLinear`] — output-sharded; forward is a plain
//! local matmul. The downstream consumer either accepts a sharded
//! activation (next layer is also column-parallel) or all-gathers.
//! - [`RowParallelLinear`] — input-sharded; forward = local matmul
//! then `AllReduce` `CustomOp1` to sum partials across ranks.
//!
//! Both assume **no bias** — every Qwen3-family weight layout we
//! actually target (Qwen3, Qwen3-Coder, Qwen3.6 base, etc.) sets
//! `attention_bias=false` and the MLP layers are no-bias. Adding bias
//! support is mechanical when a future model needs it; the design
//! choice would be: column-parallel shards the bias along dim 0;
//! row-parallel holds the bias only on rank 0 so the post-`AllReduce`
//! sum carries it exactly once.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::{Shard, ShardedVarBuilder};
#[cfg(feature = "cuda")]
use super::all_reduce::AllReduce;
/// Helper to build a [`Shard`] hint for a given dimension.
pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
Shard {
dim,
rank: rank as usize,
world_size: world_size as usize,
}
}
/// Output-dim sharded linear (column-parallel). Holds a standard
/// `candle_nn::Linear` whose `weight` is the rank's slice of the full
/// `[out_features, in_features]` tensor along dim 0.
pub struct ColumnParallelLinear {
inner: Linear,
}
impl ColumnParallelLinear {
/// Load this rank's column-parallel slice from a
/// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
/// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
let weight = vb
.get_with_hints((), "weight", shard(0, rank, world_size))
.with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
Ok(Self {
inner: Linear::new(weight, None),
})
}
}
impl Module for ColumnParallelLinear {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
self.inner.forward(x)
}
}
/// Input-dim sharded linear (row-parallel).
///
/// Holds a sharded `Linear` plus an `AllReduce` op the forward chains
/// after the local matmul to recover the full activation.
pub struct RowParallelLinear {
inner: Linear,
#[cfg(feature = "cuda")]
all_reduce: AllReduce,
/// Whether the AllReduce should run. Column-parallel ↔ row-parallel
/// is a pair: the column output is sharded, the row input is
/// sharded, and the AllReduce gives back the full output. For
/// `world_size = 1` the AllReduce is a no-op so we skip it.
needs_reduce: bool,
}
impl RowParallelLinear {
/// Load this rank's row-parallel slice from a `ShardedVarBuilder`.
///
/// Under `cuda`, `comm` is the NCCL communicator the row-parallel
/// `AllReduce` runs against. On CPU builds the parameter is
/// elided — forward returns the partial sum, which is the *wrong*
/// answer for inference but lets us compile-check the model.
#[cfg(feature = "cuda")]
pub fn load(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: std::sync::Arc<cudarc::nccl::Comm>,
) -> Result<Self> {
let weight = vb
.get_with_hints((), "weight", shard(1, rank, world_size))
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
Ok(Self {
inner: Linear::new(weight, None),
all_reduce: AllReduce::new(comm),
needs_reduce: world_size > 1,
})
}
#[cfg(not(feature = "cuda"))]
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
let weight = vb
.get_with_hints((), "weight", shard(1, rank, world_size))
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
Ok(Self {
inner: Linear::new(weight, None),
needs_reduce: world_size > 1,
})
}
}
impl Module for RowParallelLinear {
/// Local matmul followed by an `AllReduce` (when `cuda` and
/// `world_size > 1`). On CPU or single-rank, returns the partial
/// output directly — which is *only* correct for `world_size == 1`.
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let local = self.inner.forward(x)?;
#[cfg(feature = "cuda")]
if self.needs_reduce {
return local.apply_op1_no_bwd(&self.all_reduce);
}
let _ = self.needs_reduce;
Ok(local)
}
}