Stage 7b-iii (1/2): AllReduce CustomOp + ShardedVarBuilder-backed TP linears
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 38s
CI / Clippy (push) Successful in 2m16s
build-prerelease / Build neuron-blackwell (push) Failing after 3m19s
CI / Test (push) Successful in 4m26s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m22s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Failing after 4m58s
build-prerelease / Build neuron-ada (push) Failing after 4m53s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 38s
CI / Clippy (push) Successful in 2m16s
build-prerelease / Build neuron-blackwell (push) Failing after 3m19s
CI / Test (push) Successful in 4m26s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m22s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Failing after 4m58s
build-prerelease / Build neuron-ada (push) Failing after 4m53s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Ports the canonical
candle-examples/examples/llama_multiprocess/model.rs pattern into
the harness. Two new files, one deletion:
- harness/tp/all_reduce.rs — AllReduce wraps Arc<cudarc::nccl::Comm>
and implements candle's CustomOp1 trait. cuda_fwd extracts the
rank's CudaSlice<dtype> from a CudaStorage, asserts the input is
contiguous (a strided activation hitting all_reduce is almost
always a model construction bug), allocates an output CudaSlice
on the same device, calls Comm::all_reduce(Sum), and wraps the
result back as a CudaStorage. Handles BF16, F16, F32. NcclError
surfaces via {e:?} (no Display impl in cudarc 0.19.x). Send/Sync
hand-impl'd with the same NCCL-thread-safety caveat candle's
example documents.
- harness/tp/tp_linear.rs — ColumnParallelLinear and
RowParallelLinear, both built on candle's ShardedVarBuilder +
Shard hints. `vb.get_with_hints((), "weight", shard(dim, rank, ws))`
reads JUST the rank's slice from the safetensors view; no full-
tensor host materialisation. ColumnParallel.forward is a plain
local matmul (output is naturally sharded). RowParallel.forward =
local matmul + apply_op1_no_bwd(&self.all_reduce). On CPU /
world_size == 1, the AllReduce is skipped and the partial output
is returned as-is. Both layers are no-bias — every Qwen3-family
target sets attention_bias=false; bias-aware sharding is a
future-model concern.
- Deletes harness/tp/sharded_linear.rs from 7b-ii. That commit's
hand-rolled "load full + narrow" approach was useful exploration
but candle's ShardedVarBuilder does the same work without
materialising the full tensor on host. The 5 unit tests there
verified the slicing math against an unsharded reference; that
math now lives inside candle and is covered by candle's own tests.
Next (7b-iii 2/2): TpQwen3Attention + TpQwen3MLP composing the
column/row pair, then a TpQwen3Model that runs the full forward.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
121
crates/neuron/src/harness/tp/all_reduce.rs
Normal file
121
crates/neuron/src/harness/tp/all_reduce.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
//! `AllReduce` as a candle `CustomOp1` — the bridge between candle's
|
||||
//! `Tensor` graph and `cudarc::nccl::Comm::all_reduce`.
|
||||
//!
|
||||
//! Ported from the canonical
|
||||
//! `candle-examples/examples/llama_multiprocess/model.rs` pattern.
|
||||
//! Row-parallel layers apply this op after their local matmul to sum
|
||||
//! partial outputs across NCCL ranks.
|
||||
//!
|
||||
//! Available only under `--features cuda`; on CPU builds this module
|
||||
//! is empty and row-parallel layers degenerate to local matmul only
|
||||
//! (useful for compile-checking the model code; correctness requires
|
||||
//! cuda).
|
||||
//!
|
||||
//! Thread-safety caveat: NCCL communicators are technically only
|
||||
//! safe to use from a single thread at a time
|
||||
//! (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html).
|
||||
//! We hold the `AllReduce` behind an `Arc<Comm>` and only issue ops
|
||||
//! against it from the dedicated `spawn_blocking` thread the inference
|
||||
//! pipeline already uses for candle's forward passes.
|
||||
|
||||
#![cfg(feature = "cuda")]
|
||||
|
||||
use candle_core::cuda_backend::WrapErr;
|
||||
use candle_core::{CpuStorage, CudaStorage, CustomOp1, DType, Layout, Result, Shape};
|
||||
use cudarc::nccl::{Comm, ReduceOp};
|
||||
use half::{bf16, f16};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Wraps an NCCL `Comm` so it can be plugged into a candle forward
|
||||
/// graph as a custom op. Each row-parallel layer holds one of these.
|
||||
pub struct AllReduce {
|
||||
comm: Arc<Comm>,
|
||||
}
|
||||
|
||||
// SAFETY: `Comm` contains a raw `ncclComm_t` pointer; NCCL's docs note
|
||||
// that issuing ops against one comm from multiple threads concurrently
|
||||
// is unsafe. We serialise via the single spawn_blocking thread that
|
||||
// drives the model's forward pass. The Send/Sync impl is necessary
|
||||
// because candle's CustomOp1 trait bounds require it; the correctness
|
||||
// invariant is enforced at the call site, not the type level.
|
||||
unsafe impl Send for AllReduce {}
|
||||
unsafe impl Sync for AllReduce {}
|
||||
|
||||
impl AllReduce {
|
||||
pub fn new(comm: Arc<Comm>) -> Self {
|
||||
Self { comm }
|
||||
}
|
||||
|
||||
pub fn comm(&self) -> &Arc<Comm> {
|
||||
&self.comm
|
||||
}
|
||||
}
|
||||
|
||||
impl CustomOp1 for AllReduce {
|
||||
fn name(&self) -> &'static str {
|
||||
"neuron.tp.all_reduce"
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
candle_core::bail!("AllReduce custom-op invoked on CPU storage; TP requires CUDA")
|
||||
}
|
||||
|
||||
fn cuda_fwd(&self, s: &CudaStorage, l: &Layout) -> Result<(CudaStorage, Shape)> {
|
||||
use cudarc::driver::DeviceSlice;
|
||||
|
||||
// Reject non-contiguous inputs explicitly — copying them
|
||||
// server-side would mask shape bugs (a TP layer feeding a
|
||||
// strided activation into all_reduce is almost certainly a
|
||||
// model construction error).
|
||||
fn require_contiguous<T: cudarc::driver::DeviceRepr>(
|
||||
slice: &cudarc::driver::CudaSlice<T>,
|
||||
l: &Layout,
|
||||
) -> Result<()> {
|
||||
match l.contiguous_offsets() {
|
||||
Some((0, n)) if n == slice.len() => Ok(()),
|
||||
_ => candle_core::bail!(
|
||||
"AllReduce input is non-contiguous: layout={:?}, slice_len={}",
|
||||
l,
|
||||
slice.len()
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
let elem_count = l.shape().elem_count();
|
||||
let dev = s.device().clone();
|
||||
|
||||
let out = match s.dtype() {
|
||||
DType::BF16 => {
|
||||
let src = s.as_cuda_slice::<bf16>()?;
|
||||
require_contiguous(src, l)?;
|
||||
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
|
||||
self.comm
|
||||
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce bf16: {e:?}")))?;
|
||||
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||
}
|
||||
DType::F16 => {
|
||||
let src = s.as_cuda_slice::<f16>()?;
|
||||
require_contiguous(src, l)?;
|
||||
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||
self.comm
|
||||
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f16: {e:?}")))?;
|
||||
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||
}
|
||||
DType::F32 => {
|
||||
let src = s.as_cuda_slice::<f32>()?;
|
||||
require_contiguous(src, l)?;
|
||||
let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||
self.comm
|
||||
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f32: {e:?}")))?;
|
||||
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||
}
|
||||
dtype => candle_core::bail!(
|
||||
"AllReduce: unsupported dtype {dtype:?}; TP path expects bf16/f16/f32"
|
||||
),
|
||||
};
|
||||
Ok((out, l.shape().clone()))
|
||||
}
|
||||
}
|
||||
@@ -17,9 +17,10 @@
|
||||
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
|
||||
//! - **7c:** crash detection, streaming SSE, graceful unload.
|
||||
|
||||
pub mod all_reduce;
|
||||
pub mod nccl_state;
|
||||
pub mod rpc;
|
||||
pub mod sharded_linear;
|
||||
pub mod tp_linear;
|
||||
pub mod worker;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
|
||||
@@ -1,405 +0,0 @@
|
||||
//! Tensor-parallel linear layers over `candle_nn::Linear`.
|
||||
//!
|
||||
//! Two sharding strategies, both following the Megatron-LM convention
|
||||
//! that's also what mistral.rs uses for vanilla Qwen3:
|
||||
//!
|
||||
//! - [`ColumnParallelLinear`] — splits the **output** dimension. Each
|
||||
//! rank holds `out_features / world_size` rows of the weight matrix.
|
||||
//! The forward pass is a plain local matmul; the output is *sharded*
|
||||
//! (each rank produces a slice of the output vector). Used for
|
||||
//! `q_proj` / `k_proj` / `v_proj` (sharding by head) and the FFN's
|
||||
//! `gate_proj` / `up_proj`.
|
||||
//!
|
||||
//! - [`RowParallelLinear`] — splits the **input** dimension. Each
|
||||
//! rank holds `in_features / world_size` columns of the weight
|
||||
//! matrix and consumes a sharded input from upstream. Each rank's
|
||||
//! local matmul produces a *partial* output; an `all_reduce(Sum)`
|
||||
//! across ranks recovers the full activation. Used for `o_proj`
|
||||
//! (after attention) and `down_proj` (after the FFN).
|
||||
//!
|
||||
//! Stage 7b-ii (this commit): the layers, sharded loading, local
|
||||
//! forward. The `all_reduce` collective lives in `forward_with_comm`
|
||||
//! and is wired up in 7b-iii when the full TP-aware Qwen3 model is
|
||||
//! assembled with an NCCL Comm in scope. Tests here exercise only
|
||||
//! the local (no-NCCL) math against an unsharded reference.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{Module, Tensor};
|
||||
use candle_nn::{Linear, VarBuilder};
|
||||
|
||||
/// Direction of the parallelism split — selects which axis of the
|
||||
/// weight matrix the rank's local slice is taken from.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ShardKind {
|
||||
/// Split the output dimension: rank `r` holds rows
|
||||
/// `[r * out/N .. (r+1) * out/N]` of the weight matrix. The
|
||||
/// downstream consumer either accepts a sharded activation
|
||||
/// (the next layer is also column-parallel) or merges via
|
||||
/// all-gather.
|
||||
Column,
|
||||
/// Split the input dimension: rank `r` holds columns
|
||||
/// `[r * in/N .. (r+1) * in/N]`. The forward pass produces a
|
||||
/// partial output; an `all_reduce(Sum)` across ranks yields the
|
||||
/// full activation.
|
||||
Row,
|
||||
}
|
||||
|
||||
/// A linear layer whose weights have been sharded across NCCL ranks.
|
||||
///
|
||||
/// Holds a standard `candle_nn::Linear` constructed from the local
|
||||
/// slice. The collective op (only meaningful for `Row`) is invoked
|
||||
/// by [`forward_with_comm`] — the trait `Module::forward` does just
|
||||
/// the local matmul, so callers that want correct semantics on a
|
||||
/// Row-parallel layer must drive the collective themselves.
|
||||
#[derive(Debug)]
|
||||
pub struct ShardedLinear {
|
||||
inner: Linear,
|
||||
kind: ShardKind,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
/// Captured for diagnostics ("rank 3 layer says X but should say Y").
|
||||
/// `out_features` reflects the **logical** size (pre-shard) so the
|
||||
/// caller can validate against the model config without doing the
|
||||
/// arithmetic itself.
|
||||
logical_out_features: usize,
|
||||
logical_in_features: usize,
|
||||
}
|
||||
|
||||
impl ShardedLinear {
|
||||
/// Load a column-parallel slice from a `VarBuilder`. Reads the
|
||||
/// full weight (and bias, if any) from the safetensors and
|
||||
/// narrows on dim 0 to the rank's slice. The bias is sharded the
|
||||
/// same way (each rank holds its own bias slice).
|
||||
///
|
||||
/// Bails if `out_features` is not divisible by `world_size` — the
|
||||
/// same divisibility precondition mistral.rs's PR #2054-era code
|
||||
/// added an explicit guard for after the first TP shard attempt.
|
||||
pub fn load_column(
|
||||
vb: &VarBuilder,
|
||||
in_features: usize,
|
||||
out_features: usize,
|
||||
has_bias: bool,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
) -> Result<Self> {
|
||||
let path = vb.prefix();
|
||||
if !out_features.is_multiple_of(world_size as usize) {
|
||||
anyhow::bail!(
|
||||
"column-parallel '{path}': out_features={out_features} \
|
||||
not divisible by world_size={world_size}"
|
||||
);
|
||||
}
|
||||
let shard = out_features / world_size as usize;
|
||||
let start = rank as usize * shard;
|
||||
|
||||
let full_w = vb
|
||||
.get((out_features, in_features), "weight")
|
||||
.with_context(|| format!("load weight for column-parallel '{path}'"))?;
|
||||
let weight = full_w
|
||||
.narrow(0, start, shard)
|
||||
.with_context(|| format!("narrow weight rows for column-parallel '{path}'"))?
|
||||
.contiguous()
|
||||
.with_context(|| format!("contiguous weight for column-parallel '{path}'"))?;
|
||||
// Drop the full tensor as soon as we have the shard so peak
|
||||
// host RAM during load tracks shard-size, not full-size, once
|
||||
// all narrows complete (Rust's drop semantics handle this
|
||||
// because `full_w` goes out of scope here).
|
||||
drop(full_w);
|
||||
|
||||
let bias = if has_bias {
|
||||
let full_b = vb
|
||||
.get(out_features, "bias")
|
||||
.with_context(|| format!("load bias for column-parallel '{path}'"))?;
|
||||
let b = full_b
|
||||
.narrow(0, start, shard)
|
||||
.with_context(|| format!("narrow bias for column-parallel '{path}'"))?
|
||||
.contiguous()
|
||||
.with_context(|| format!("contiguous bias for column-parallel '{path}'"))?;
|
||||
Some(b)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
inner: Linear::new(weight, bias),
|
||||
kind: ShardKind::Column,
|
||||
rank,
|
||||
world_size,
|
||||
logical_out_features: out_features,
|
||||
logical_in_features: in_features,
|
||||
})
|
||||
}
|
||||
|
||||
/// Load a row-parallel slice from a `VarBuilder`. Reads the full
|
||||
/// weight and narrows on dim 1 to the rank's column slice. The
|
||||
/// bias, if any, lives **only on rank 0** — every other rank
|
||||
/// holds `None`. This keeps the post-`all_reduce` semantics
|
||||
/// correct: each rank contributes its partial sum without the
|
||||
/// bias, then rank 0's bias (added in `forward_with_comm`) lands
|
||||
/// on the result exactly once.
|
||||
pub fn load_row(
|
||||
vb: &VarBuilder,
|
||||
in_features: usize,
|
||||
out_features: usize,
|
||||
has_bias: bool,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
) -> Result<Self> {
|
||||
let path = vb.prefix();
|
||||
if !in_features.is_multiple_of(world_size as usize) {
|
||||
anyhow::bail!(
|
||||
"row-parallel '{path}': in_features={in_features} \
|
||||
not divisible by world_size={world_size}"
|
||||
);
|
||||
}
|
||||
let shard = in_features / world_size as usize;
|
||||
let start = rank as usize * shard;
|
||||
|
||||
let full_w = vb
|
||||
.get((out_features, in_features), "weight")
|
||||
.with_context(|| format!("load weight for row-parallel '{path}'"))?;
|
||||
let weight = full_w
|
||||
.narrow(1, start, shard)
|
||||
.with_context(|| format!("narrow weight cols for row-parallel '{path}'"))?
|
||||
.contiguous()
|
||||
.with_context(|| format!("contiguous weight for row-parallel '{path}'"))?;
|
||||
drop(full_w);
|
||||
|
||||
let bias = if has_bias && rank == 0 {
|
||||
let b = vb
|
||||
.get(out_features, "bias")
|
||||
.with_context(|| format!("load bias for row-parallel '{path}'"))?;
|
||||
Some(b)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
inner: Linear::new(weight, bias),
|
||||
kind: ShardKind::Row,
|
||||
rank,
|
||||
world_size,
|
||||
logical_out_features: out_features,
|
||||
logical_in_features: in_features,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn kind(&self) -> ShardKind {
|
||||
self.kind
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> u32 {
|
||||
self.rank
|
||||
}
|
||||
|
||||
pub fn world_size(&self) -> u32 {
|
||||
self.world_size
|
||||
}
|
||||
|
||||
pub fn logical_in_features(&self) -> usize {
|
||||
self.logical_in_features
|
||||
}
|
||||
|
||||
pub fn logical_out_features(&self) -> usize {
|
||||
self.logical_out_features
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ShardedLinear {
|
||||
/// Local matmul only. For `Row`-parallel layers, the output is a
|
||||
/// *partial sum* — call [`Self::forward_with_comm`] to get the
|
||||
/// reduced result. Implementing `Module` lets a `ShardedLinear`
|
||||
/// be drop-in for any `Module`-shaped consumer that doesn't need
|
||||
/// the reduce step (column-parallel layers; tests).
|
||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
impl ShardedLinear {
|
||||
/// Forward pass that issues an `all_reduce(Sum)` for row-parallel
|
||||
/// layers. Column-parallel layers just delegate to the local
|
||||
/// matmul (their output is naturally sharded; the next consumer
|
||||
/// will either gather or accept the shard).
|
||||
pub fn forward_with_comm(&self, x: &Tensor, comm: &cudarc::nccl::Comm) -> Result<Tensor> {
|
||||
let local = self
|
||||
.inner
|
||||
.forward(x)
|
||||
.map_err(|e| anyhow::anyhow!("local matmul: {e}"))?;
|
||||
match self.kind {
|
||||
ShardKind::Column => Ok(local),
|
||||
ShardKind::Row => {
|
||||
// TODO Stage 7b-iii: wrap `local`'s CudaSlice with a
|
||||
// matching output buffer, call comm.all_reduce(Sum),
|
||||
// return the result. The cudarc::nccl all_reduce
|
||||
// signature takes `&S: DevicePtr<T>` + `&mut R: DevicePtrMut<T>`,
|
||||
// both backed by `CudaSlice<T>`. candle stores its
|
||||
// Tensor data behind its own slab — extracting the
|
||||
// underlying CudaSlice safely is a separate piece of
|
||||
// plumbing best landed alongside the model assembly,
|
||||
// so this body is a placeholder.
|
||||
let _ = comm;
|
||||
anyhow::bail!(
|
||||
"ShardedLinear::forward_with_comm row-parallel reduce \
|
||||
lands in Stage 7b-iii alongside the model assembly; \
|
||||
7b-ii ships only the local matmul"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use candle_nn::var_builder::VarBuilderArgs;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Build a VarBuilder over an in-memory map of tensors. Used by
|
||||
/// the tests to fake a safetensors source without touching disk.
|
||||
fn vb_from_map(tensors: HashMap<String, Tensor>, device: &Device) -> VarBuilder<'static> {
|
||||
VarBuilderArgs::from_tensors(tensors, DType::F32, device)
|
||||
}
|
||||
|
||||
/// World_size=2 column-parallel split of a 4x3 weight. Each rank's
|
||||
/// local matmul on the same input should be 2 rows of the
|
||||
/// reference (full) matmul.
|
||||
#[test]
|
||||
fn column_parallel_shards_output_correctly() {
|
||||
let device = Device::Cpu;
|
||||
// weight (out=4, in=3): rows are easy to identify by value.
|
||||
let w = Tensor::from_slice(
|
||||
&[
|
||||
1f32, 2., 3., // row 0
|
||||
4., 5., 6., // row 1
|
||||
7., 8., 9., // row 2
|
||||
10., 11., 12., // row 3
|
||||
],
|
||||
(4, 3),
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
let mut tensors = HashMap::new();
|
||||
tensors.insert("foo.weight".into(), w.clone());
|
||||
let vb_root = vb_from_map(tensors, &device);
|
||||
let vb_foo = vb_root.pp("foo");
|
||||
|
||||
// rank 0 of world_size 2 gets rows 0..2.
|
||||
let r0 = ShardedLinear::load_column(&vb_foo, 3, 4, false, 0, 2).unwrap();
|
||||
// rank 1 gets rows 2..4.
|
||||
let r1 = ShardedLinear::load_column(&vb_foo, 3, 4, false, 1, 2).unwrap();
|
||||
|
||||
let x = Tensor::from_slice(&[1f32, 0., 0.], (1, 3), &device).unwrap();
|
||||
let y0 = r0.forward(&x).unwrap().to_vec2::<f32>().unwrap();
|
||||
let y1 = r1.forward(&x).unwrap().to_vec2::<f32>().unwrap();
|
||||
// Full reference: x @ w.T → [1, 4, 7, 10]. Rank 0 owns [1, 4],
|
||||
// rank 1 owns [7, 10].
|
||||
assert_eq!(y0, vec![vec![1.0, 4.0]]);
|
||||
assert_eq!(y1, vec![vec![7.0, 10.0]]);
|
||||
}
|
||||
|
||||
/// World_size=2 row-parallel split of a 4x4 weight. Each rank's
|
||||
/// local matmul on its half of the input should be a partial sum;
|
||||
/// summing the two partials should equal the unsharded reference.
|
||||
#[test]
|
||||
fn row_parallel_partials_sum_to_full() {
|
||||
let device = Device::Cpu;
|
||||
// weight (out=4, in=4): use distinct values per column so the
|
||||
// partial sums are obviously different.
|
||||
let w = Tensor::from_slice(
|
||||
&[
|
||||
1f32, 2., 3., 4., // row 0
|
||||
5., 6., 7., 8., // row 1
|
||||
9., 10., 11., 12., // row 2
|
||||
13., 14., 15., 16., // row 3
|
||||
],
|
||||
(4, 4),
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
let mut tensors = HashMap::new();
|
||||
tensors.insert("bar.weight".into(), w.clone());
|
||||
let vb_root = vb_from_map(tensors, &device);
|
||||
let vb_bar = vb_root.pp("bar");
|
||||
|
||||
let r0 = ShardedLinear::load_row(&vb_bar, 4, 4, false, 0, 2).unwrap();
|
||||
let r1 = ShardedLinear::load_row(&vb_bar, 4, 4, false, 1, 2).unwrap();
|
||||
|
||||
// x split: rank 0 takes x[..2], rank 1 takes x[2..].
|
||||
let x_full = Tensor::from_slice(&[1f32, 1., 1., 1.], (1, 4), &device).unwrap();
|
||||
let x0 = x_full.narrow(1, 0, 2).unwrap();
|
||||
let x1 = x_full.narrow(1, 2, 2).unwrap();
|
||||
|
||||
let y0 = r0.forward(&x0).unwrap();
|
||||
let y1 = r1.forward(&x1).unwrap();
|
||||
let summed = (y0 + y1).unwrap().to_vec2::<f32>().unwrap();
|
||||
|
||||
// Reference: x_full @ w.T = [1+2+3+4, 5+6+7+8, 9+10+11+12, 13+14+15+16]
|
||||
// = [10, 26, 42, 58].
|
||||
assert_eq!(summed, vec![vec![10.0, 26.0, 42.0, 58.0]]);
|
||||
}
|
||||
|
||||
/// Row-parallel bias lives only on rank 0; other ranks have None.
|
||||
/// (Verifies the rank-0-only bias contract.)
|
||||
#[test]
|
||||
fn row_parallel_bias_only_on_rank_zero() {
|
||||
let device = Device::Cpu;
|
||||
let w = Tensor::zeros((4, 4), DType::F32, &device).unwrap();
|
||||
let b = Tensor::from_slice(&[1f32, 1., 1., 1.], 4, &device).unwrap();
|
||||
let mut tensors = HashMap::new();
|
||||
tensors.insert("baz.weight".into(), w);
|
||||
tensors.insert("baz.bias".into(), b);
|
||||
let vb_root = vb_from_map(tensors, &device);
|
||||
let vb_baz = vb_root.pp("baz");
|
||||
|
||||
let r0 = ShardedLinear::load_row(&vb_baz, 4, 4, true, 0, 2).unwrap();
|
||||
let r1 = ShardedLinear::load_row(&vb_baz, 4, 4, true, 1, 2).unwrap();
|
||||
|
||||
// We can't introspect the Linear's bias from the public API,
|
||||
// but we can run forward of zero-weight rank 1 and confirm
|
||||
// the output is zero (no bias added on non-zero ranks).
|
||||
let x = Tensor::ones((1, 2), DType::F32, &device).unwrap();
|
||||
let y1 = r1.forward(&x).unwrap().to_vec2::<f32>().unwrap();
|
||||
assert_eq!(y1, vec![vec![0.0, 0.0, 0.0, 0.0]]);
|
||||
|
||||
let y0 = r0.forward(&x).unwrap().to_vec2::<f32>().unwrap();
|
||||
// Rank 0 weight is zero but bias is [1,1,1,1] → output should be [1,1,1,1].
|
||||
assert_eq!(y0, vec![vec![1.0, 1.0, 1.0, 1.0]]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn column_parallel_rejects_non_divisible_out_features() {
|
||||
let device = Device::Cpu;
|
||||
let w = Tensor::zeros((5, 3), DType::F32, &device).unwrap();
|
||||
let mut tensors = HashMap::new();
|
||||
tensors.insert("nope.weight".into(), w);
|
||||
let vb_root = vb_from_map(tensors, &device);
|
||||
let vb_nope = vb_root.pp("nope");
|
||||
|
||||
let err = ShardedLinear::load_column(&vb_nope, 3, 5, false, 0, 2).unwrap_err();
|
||||
let msg = format!("{err:#}");
|
||||
assert!(
|
||||
msg.contains("not divisible by world_size"),
|
||||
"expected divisibility error, got: {msg}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn row_parallel_rejects_non_divisible_in_features() {
|
||||
let device = Device::Cpu;
|
||||
let w = Tensor::zeros((4, 5), DType::F32, &device).unwrap();
|
||||
let mut tensors = HashMap::new();
|
||||
tensors.insert("nope.weight".into(), w);
|
||||
let vb_root = vb_from_map(tensors, &device);
|
||||
let vb_nope = vb_root.pp("nope");
|
||||
|
||||
let err = ShardedLinear::load_row(&vb_nope, 5, 4, false, 0, 2).unwrap_err();
|
||||
let msg = format!("{err:#}");
|
||||
assert!(
|
||||
msg.contains("not divisible by world_size"),
|
||||
"expected divisibility error, got: {msg}"
|
||||
);
|
||||
}
|
||||
}
|
||||
134
crates/neuron/src/harness/tp/tp_linear.rs
Normal file
134
crates/neuron/src/harness/tp/tp_linear.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
//! Tensor-parallel linear layers built on candle's `ShardedVarBuilder`
|
||||
//! and `Shard` sharding hints.
|
||||
//!
|
||||
//! candle reads only the rank's slice of each weight tensor from
|
||||
//! safetensors via `view.slice(start..stop)` — no full-tensor host
|
||||
//! materialisation. That's a memory-efficiency win over hand-rolled
|
||||
//! "load full + narrow" sharding (which the earlier
|
||||
//! `sharded_linear.rs` exploration demonstrated but didn't pay for).
|
||||
//!
|
||||
//! Two layer types:
|
||||
//!
|
||||
//! - [`ColumnParallelLinear`] — output-sharded; forward is a plain
|
||||
//! local matmul. The downstream consumer either accepts a sharded
|
||||
//! activation (next layer is also column-parallel) or all-gathers.
|
||||
//! - [`RowParallelLinear`] — input-sharded; forward = local matmul
|
||||
//! then `AllReduce` `CustomOp1` to sum partials across ranks.
|
||||
//!
|
||||
//! Both assume **no bias** — every Qwen3-family weight layout we
|
||||
//! actually target (Qwen3, Qwen3-Coder, Qwen3.6 base, etc.) sets
|
||||
//! `attention_bias=false` and the MLP layers are no-bias. Adding bias
|
||||
//! support is mechanical when a future model needs it; the design
|
||||
//! choice would be: column-parallel shards the bias along dim 0;
|
||||
//! row-parallel holds the bias only on rank 0 so the post-`AllReduce`
|
||||
//! sum carries it exactly once.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{Module, Tensor};
|
||||
use candle_nn::Linear;
|
||||
use candle_nn::var_builder::{Shard, ShardedVarBuilder};
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use super::all_reduce::AllReduce;
|
||||
|
||||
/// Helper to build a [`Shard`] hint for a given dimension.
|
||||
pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
|
||||
Shard {
|
||||
dim,
|
||||
rank: rank as usize,
|
||||
world_size: world_size as usize,
|
||||
}
|
||||
}
|
||||
|
||||
/// Output-dim sharded linear (column-parallel). Holds a standard
|
||||
/// `candle_nn::Linear` whose `weight` is the rank's slice of the full
|
||||
/// `[out_features, in_features]` tensor along dim 0.
|
||||
pub struct ColumnParallelLinear {
|
||||
inner: Linear,
|
||||
}
|
||||
|
||||
impl ColumnParallelLinear {
|
||||
/// Load this rank's column-parallel slice from a
|
||||
/// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
|
||||
/// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
|
||||
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||
let weight = vb
|
||||
.get_with_hints((), "weight", shard(0, rank, world_size))
|
||||
.with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
|
||||
Ok(Self {
|
||||
inner: Linear::new(weight, None),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ColumnParallelLinear {
|
||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
/// Input-dim sharded linear (row-parallel).
|
||||
///
|
||||
/// Holds a sharded `Linear` plus an `AllReduce` op the forward chains
|
||||
/// after the local matmul to recover the full activation.
|
||||
pub struct RowParallelLinear {
|
||||
inner: Linear,
|
||||
#[cfg(feature = "cuda")]
|
||||
all_reduce: AllReduce,
|
||||
/// Whether the AllReduce should run. Column-parallel ↔ row-parallel
|
||||
/// is a pair: the column output is sharded, the row input is
|
||||
/// sharded, and the AllReduce gives back the full output. For
|
||||
/// `world_size = 1` the AllReduce is a no-op so we skip it.
|
||||
needs_reduce: bool,
|
||||
}
|
||||
|
||||
impl RowParallelLinear {
|
||||
/// Load this rank's row-parallel slice from a `ShardedVarBuilder`.
|
||||
///
|
||||
/// Under `cuda`, `comm` is the NCCL communicator the row-parallel
|
||||
/// `AllReduce` runs against. On CPU builds the parameter is
|
||||
/// elided — forward returns the partial sum, which is the *wrong*
|
||||
/// answer for inference but lets us compile-check the model.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn load(
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: std::sync::Arc<cudarc::nccl::Comm>,
|
||||
) -> Result<Self> {
|
||||
let weight = vb
|
||||
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||
Ok(Self {
|
||||
inner: Linear::new(weight, None),
|
||||
all_reduce: AllReduce::new(comm),
|
||||
needs_reduce: world_size > 1,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||
let weight = vb
|
||||
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||
Ok(Self {
|
||||
inner: Linear::new(weight, None),
|
||||
needs_reduce: world_size > 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for RowParallelLinear {
|
||||
/// Local matmul followed by an `AllReduce` (when `cuda` and
|
||||
/// `world_size > 1`). On CPU or single-rank, returns the partial
|
||||
/// output directly — which is *only* correct for `world_size == 1`.
|
||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
let local = self.inner.forward(x)?;
|
||||
#[cfg(feature = "cuda")]
|
||||
if self.needs_reduce {
|
||||
return local.apply_op1_no_bwd(&self.all_reduce);
|
||||
}
|
||||
let _ = self.needs_reduce;
|
||||
Ok(local)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user