Stage 7b-iii (1/2): AllReduce CustomOp + ShardedVarBuilder-backed TP linears
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 38s
CI / Clippy (push) Successful in 2m16s
build-prerelease / Build neuron-blackwell (push) Failing after 3m19s
CI / Test (push) Successful in 4m26s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m22s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Failing after 4m58s
build-prerelease / Build neuron-ada (push) Failing after 4m53s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 38s
CI / Clippy (push) Successful in 2m16s
build-prerelease / Build neuron-blackwell (push) Failing after 3m19s
CI / Test (push) Successful in 4m26s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m22s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Failing after 4m58s
build-prerelease / Build neuron-ada (push) Failing after 4m53s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Ports the canonical
candle-examples/examples/llama_multiprocess/model.rs pattern into
the harness. Two new files, one deletion:
- harness/tp/all_reduce.rs — AllReduce wraps Arc<cudarc::nccl::Comm>
and implements candle's CustomOp1 trait. cuda_fwd extracts the
rank's CudaSlice<dtype> from a CudaStorage, asserts the input is
contiguous (a strided activation hitting all_reduce is almost
always a model construction bug), allocates an output CudaSlice
on the same device, calls Comm::all_reduce(Sum), and wraps the
result back as a CudaStorage. Handles BF16, F16, F32. NcclError
surfaces via {e:?} (no Display impl in cudarc 0.19.x). Send/Sync
hand-impl'd with the same NCCL-thread-safety caveat candle's
example documents.
- harness/tp/tp_linear.rs — ColumnParallelLinear and
RowParallelLinear, both built on candle's ShardedVarBuilder +
Shard hints. `vb.get_with_hints((), "weight", shard(dim, rank, ws))`
reads JUST the rank's slice from the safetensors view; no full-
tensor host materialisation. ColumnParallel.forward is a plain
local matmul (output is naturally sharded). RowParallel.forward =
local matmul + apply_op1_no_bwd(&self.all_reduce). On CPU /
world_size == 1, the AllReduce is skipped and the partial output
is returned as-is. Both layers are no-bias — every Qwen3-family
target sets attention_bias=false; bias-aware sharding is a
future-model concern.
- Deletes harness/tp/sharded_linear.rs from 7b-ii. That commit's
hand-rolled "load full + narrow" approach was useful exploration
but candle's ShardedVarBuilder does the same work without
materialising the full tensor on host. The 5 unit tests there
verified the slicing math against an unsharded reference; that
math now lives inside candle and is covered by candle's own tests.
Next (7b-iii 2/2): TpQwen3Attention + TpQwen3MLP composing the
column/row pair, then a TpQwen3Model that runs the full forward.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
121
crates/neuron/src/harness/tp/all_reduce.rs
Normal file
121
crates/neuron/src/harness/tp/all_reduce.rs
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
//! `AllReduce` as a candle `CustomOp1` — the bridge between candle's
|
||||||
|
//! `Tensor` graph and `cudarc::nccl::Comm::all_reduce`.
|
||||||
|
//!
|
||||||
|
//! Ported from the canonical
|
||||||
|
//! `candle-examples/examples/llama_multiprocess/model.rs` pattern.
|
||||||
|
//! Row-parallel layers apply this op after their local matmul to sum
|
||||||
|
//! partial outputs across NCCL ranks.
|
||||||
|
//!
|
||||||
|
//! Available only under `--features cuda`; on CPU builds this module
|
||||||
|
//! is empty and row-parallel layers degenerate to local matmul only
|
||||||
|
//! (useful for compile-checking the model code; correctness requires
|
||||||
|
//! cuda).
|
||||||
|
//!
|
||||||
|
//! Thread-safety caveat: NCCL communicators are technically only
|
||||||
|
//! safe to use from a single thread at a time
|
||||||
|
//! (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html).
|
||||||
|
//! We hold the `AllReduce` behind an `Arc<Comm>` and only issue ops
|
||||||
|
//! against it from the dedicated `spawn_blocking` thread the inference
|
||||||
|
//! pipeline already uses for candle's forward passes.
|
||||||
|
|
||||||
|
#![cfg(feature = "cuda")]
|
||||||
|
|
||||||
|
use candle_core::cuda_backend::WrapErr;
|
||||||
|
use candle_core::{CpuStorage, CudaStorage, CustomOp1, DType, Layout, Result, Shape};
|
||||||
|
use cudarc::nccl::{Comm, ReduceOp};
|
||||||
|
use half::{bf16, f16};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Wraps an NCCL `Comm` so it can be plugged into a candle forward
|
||||||
|
/// graph as a custom op. Each row-parallel layer holds one of these.
|
||||||
|
pub struct AllReduce {
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAFETY: `Comm` contains a raw `ncclComm_t` pointer; NCCL's docs note
|
||||||
|
// that issuing ops against one comm from multiple threads concurrently
|
||||||
|
// is unsafe. We serialise via the single spawn_blocking thread that
|
||||||
|
// drives the model's forward pass. The Send/Sync impl is necessary
|
||||||
|
// because candle's CustomOp1 trait bounds require it; the correctness
|
||||||
|
// invariant is enforced at the call site, not the type level.
|
||||||
|
unsafe impl Send for AllReduce {}
|
||||||
|
unsafe impl Sync for AllReduce {}
|
||||||
|
|
||||||
|
impl AllReduce {
|
||||||
|
pub fn new(comm: Arc<Comm>) -> Self {
|
||||||
|
Self { comm }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn comm(&self) -> &Arc<Comm> {
|
||||||
|
&self.comm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CustomOp1 for AllReduce {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"neuron.tp.all_reduce"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||||
|
candle_core::bail!("AllReduce custom-op invoked on CPU storage; TP requires CUDA")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cuda_fwd(&self, s: &CudaStorage, l: &Layout) -> Result<(CudaStorage, Shape)> {
|
||||||
|
use cudarc::driver::DeviceSlice;
|
||||||
|
|
||||||
|
// Reject non-contiguous inputs explicitly — copying them
|
||||||
|
// server-side would mask shape bugs (a TP layer feeding a
|
||||||
|
// strided activation into all_reduce is almost certainly a
|
||||||
|
// model construction error).
|
||||||
|
fn require_contiguous<T: cudarc::driver::DeviceRepr>(
|
||||||
|
slice: &cudarc::driver::CudaSlice<T>,
|
||||||
|
l: &Layout,
|
||||||
|
) -> Result<()> {
|
||||||
|
match l.contiguous_offsets() {
|
||||||
|
Some((0, n)) if n == slice.len() => Ok(()),
|
||||||
|
_ => candle_core::bail!(
|
||||||
|
"AllReduce input is non-contiguous: layout={:?}, slice_len={}",
|
||||||
|
l,
|
||||||
|
slice.len()
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let elem_count = l.shape().elem_count();
|
||||||
|
let dev = s.device().clone();
|
||||||
|
|
||||||
|
let out = match s.dtype() {
|
||||||
|
DType::BF16 => {
|
||||||
|
let src = s.as_cuda_slice::<bf16>()?;
|
||||||
|
require_contiguous(src, l)?;
|
||||||
|
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
|
||||||
|
self.comm
|
||||||
|
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce bf16: {e:?}")))?;
|
||||||
|
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let src = s.as_cuda_slice::<f16>()?;
|
||||||
|
require_contiguous(src, l)?;
|
||||||
|
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||||
|
self.comm
|
||||||
|
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f16: {e:?}")))?;
|
||||||
|
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let src = s.as_cuda_slice::<f32>()?;
|
||||||
|
require_contiguous(src, l)?;
|
||||||
|
let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||||
|
self.comm
|
||||||
|
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f32: {e:?}")))?;
|
||||||
|
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||||
|
}
|
||||||
|
dtype => candle_core::bail!(
|
||||||
|
"AllReduce: unsupported dtype {dtype:?}; TP path expects bf16/f16/f32"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
Ok((out, l.shape().clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -17,9 +17,10 @@
|
|||||||
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
|
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
|
||||||
//! - **7c:** crash detection, streaming SSE, graceful unload.
|
//! - **7c:** crash detection, streaming SSE, graceful unload.
|
||||||
|
|
||||||
|
pub mod all_reduce;
|
||||||
pub mod nccl_state;
|
pub mod nccl_state;
|
||||||
pub mod rpc;
|
pub mod rpc;
|
||||||
pub mod sharded_linear;
|
pub mod tp_linear;
|
||||||
pub mod worker;
|
pub mod worker;
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
|
|||||||
@@ -1,405 +0,0 @@
|
|||||||
//! Tensor-parallel linear layers over `candle_nn::Linear`.
|
|
||||||
//!
|
|
||||||
//! Two sharding strategies, both following the Megatron-LM convention
|
|
||||||
//! that's also what mistral.rs uses for vanilla Qwen3:
|
|
||||||
//!
|
|
||||||
//! - [`ColumnParallelLinear`] — splits the **output** dimension. Each
|
|
||||||
//! rank holds `out_features / world_size` rows of the weight matrix.
|
|
||||||
//! The forward pass is a plain local matmul; the output is *sharded*
|
|
||||||
//! (each rank produces a slice of the output vector). Used for
|
|
||||||
//! `q_proj` / `k_proj` / `v_proj` (sharding by head) and the FFN's
|
|
||||||
//! `gate_proj` / `up_proj`.
|
|
||||||
//!
|
|
||||||
//! - [`RowParallelLinear`] — splits the **input** dimension. Each
|
|
||||||
//! rank holds `in_features / world_size` columns of the weight
|
|
||||||
//! matrix and consumes a sharded input from upstream. Each rank's
|
|
||||||
//! local matmul produces a *partial* output; an `all_reduce(Sum)`
|
|
||||||
//! across ranks recovers the full activation. Used for `o_proj`
|
|
||||||
//! (after attention) and `down_proj` (after the FFN).
|
|
||||||
//!
|
|
||||||
//! Stage 7b-ii (this commit): the layers, sharded loading, local
|
|
||||||
//! forward. The `all_reduce` collective lives in `forward_with_comm`
|
|
||||||
//! and is wired up in 7b-iii when the full TP-aware Qwen3 model is
|
|
||||||
//! assembled with an NCCL Comm in scope. Tests here exercise only
|
|
||||||
//! the local (no-NCCL) math against an unsharded reference.
|
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
|
||||||
use candle_core::{Module, Tensor};
|
|
||||||
use candle_nn::{Linear, VarBuilder};
|
|
||||||
|
|
||||||
/// Direction of the parallelism split — selects which axis of the
|
|
||||||
/// weight matrix the rank's local slice is taken from.
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
pub enum ShardKind {
|
|
||||||
/// Split the output dimension: rank `r` holds rows
|
|
||||||
/// `[r * out/N .. (r+1) * out/N]` of the weight matrix. The
|
|
||||||
/// downstream consumer either accepts a sharded activation
|
|
||||||
/// (the next layer is also column-parallel) or merges via
|
|
||||||
/// all-gather.
|
|
||||||
Column,
|
|
||||||
/// Split the input dimension: rank `r` holds columns
|
|
||||||
/// `[r * in/N .. (r+1) * in/N]`. The forward pass produces a
|
|
||||||
/// partial output; an `all_reduce(Sum)` across ranks yields the
|
|
||||||
/// full activation.
|
|
||||||
Row,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A linear layer whose weights have been sharded across NCCL ranks.
|
|
||||||
///
|
|
||||||
/// Holds a standard `candle_nn::Linear` constructed from the local
|
|
||||||
/// slice. The collective op (only meaningful for `Row`) is invoked
|
|
||||||
/// by [`forward_with_comm`] — the trait `Module::forward` does just
|
|
||||||
/// the local matmul, so callers that want correct semantics on a
|
|
||||||
/// Row-parallel layer must drive the collective themselves.
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ShardedLinear {
|
|
||||||
inner: Linear,
|
|
||||||
kind: ShardKind,
|
|
||||||
rank: u32,
|
|
||||||
world_size: u32,
|
|
||||||
/// Captured for diagnostics ("rank 3 layer says X but should say Y").
|
|
||||||
/// `out_features` reflects the **logical** size (pre-shard) so the
|
|
||||||
/// caller can validate against the model config without doing the
|
|
||||||
/// arithmetic itself.
|
|
||||||
logical_out_features: usize,
|
|
||||||
logical_in_features: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ShardedLinear {
|
|
||||||
/// Load a column-parallel slice from a `VarBuilder`. Reads the
|
|
||||||
/// full weight (and bias, if any) from the safetensors and
|
|
||||||
/// narrows on dim 0 to the rank's slice. The bias is sharded the
|
|
||||||
/// same way (each rank holds its own bias slice).
|
|
||||||
///
|
|
||||||
/// Bails if `out_features` is not divisible by `world_size` — the
|
|
||||||
/// same divisibility precondition mistral.rs's PR #2054-era code
|
|
||||||
/// added an explicit guard for after the first TP shard attempt.
|
|
||||||
pub fn load_column(
|
|
||||||
vb: &VarBuilder,
|
|
||||||
in_features: usize,
|
|
||||||
out_features: usize,
|
|
||||||
has_bias: bool,
|
|
||||||
rank: u32,
|
|
||||||
world_size: u32,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let path = vb.prefix();
|
|
||||||
if !out_features.is_multiple_of(world_size as usize) {
|
|
||||||
anyhow::bail!(
|
|
||||||
"column-parallel '{path}': out_features={out_features} \
|
|
||||||
not divisible by world_size={world_size}"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
let shard = out_features / world_size as usize;
|
|
||||||
let start = rank as usize * shard;
|
|
||||||
|
|
||||||
let full_w = vb
|
|
||||||
.get((out_features, in_features), "weight")
|
|
||||||
.with_context(|| format!("load weight for column-parallel '{path}'"))?;
|
|
||||||
let weight = full_w
|
|
||||||
.narrow(0, start, shard)
|
|
||||||
.with_context(|| format!("narrow weight rows for column-parallel '{path}'"))?
|
|
||||||
.contiguous()
|
|
||||||
.with_context(|| format!("contiguous weight for column-parallel '{path}'"))?;
|
|
||||||
// Drop the full tensor as soon as we have the shard so peak
|
|
||||||
// host RAM during load tracks shard-size, not full-size, once
|
|
||||||
// all narrows complete (Rust's drop semantics handle this
|
|
||||||
// because `full_w` goes out of scope here).
|
|
||||||
drop(full_w);
|
|
||||||
|
|
||||||
let bias = if has_bias {
|
|
||||||
let full_b = vb
|
|
||||||
.get(out_features, "bias")
|
|
||||||
.with_context(|| format!("load bias for column-parallel '{path}'"))?;
|
|
||||||
let b = full_b
|
|
||||||
.narrow(0, start, shard)
|
|
||||||
.with_context(|| format!("narrow bias for column-parallel '{path}'"))?
|
|
||||||
.contiguous()
|
|
||||||
.with_context(|| format!("contiguous bias for column-parallel '{path}'"))?;
|
|
||||||
Some(b)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
inner: Linear::new(weight, bias),
|
|
||||||
kind: ShardKind::Column,
|
|
||||||
rank,
|
|
||||||
world_size,
|
|
||||||
logical_out_features: out_features,
|
|
||||||
logical_in_features: in_features,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Load a row-parallel slice from a `VarBuilder`. Reads the full
|
|
||||||
/// weight and narrows on dim 1 to the rank's column slice. The
|
|
||||||
/// bias, if any, lives **only on rank 0** — every other rank
|
|
||||||
/// holds `None`. This keeps the post-`all_reduce` semantics
|
|
||||||
/// correct: each rank contributes its partial sum without the
|
|
||||||
/// bias, then rank 0's bias (added in `forward_with_comm`) lands
|
|
||||||
/// on the result exactly once.
|
|
||||||
pub fn load_row(
|
|
||||||
vb: &VarBuilder,
|
|
||||||
in_features: usize,
|
|
||||||
out_features: usize,
|
|
||||||
has_bias: bool,
|
|
||||||
rank: u32,
|
|
||||||
world_size: u32,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let path = vb.prefix();
|
|
||||||
if !in_features.is_multiple_of(world_size as usize) {
|
|
||||||
anyhow::bail!(
|
|
||||||
"row-parallel '{path}': in_features={in_features} \
|
|
||||||
not divisible by world_size={world_size}"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
let shard = in_features / world_size as usize;
|
|
||||||
let start = rank as usize * shard;
|
|
||||||
|
|
||||||
let full_w = vb
|
|
||||||
.get((out_features, in_features), "weight")
|
|
||||||
.with_context(|| format!("load weight for row-parallel '{path}'"))?;
|
|
||||||
let weight = full_w
|
|
||||||
.narrow(1, start, shard)
|
|
||||||
.with_context(|| format!("narrow weight cols for row-parallel '{path}'"))?
|
|
||||||
.contiguous()
|
|
||||||
.with_context(|| format!("contiguous weight for row-parallel '{path}'"))?;
|
|
||||||
drop(full_w);
|
|
||||||
|
|
||||||
let bias = if has_bias && rank == 0 {
|
|
||||||
let b = vb
|
|
||||||
.get(out_features, "bias")
|
|
||||||
.with_context(|| format!("load bias for row-parallel '{path}'"))?;
|
|
||||||
Some(b)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
inner: Linear::new(weight, bias),
|
|
||||||
kind: ShardKind::Row,
|
|
||||||
rank,
|
|
||||||
world_size,
|
|
||||||
logical_out_features: out_features,
|
|
||||||
logical_in_features: in_features,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn kind(&self) -> ShardKind {
|
|
||||||
self.kind
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn rank(&self) -> u32 {
|
|
||||||
self.rank
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn world_size(&self) -> u32 {
|
|
||||||
self.world_size
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn logical_in_features(&self) -> usize {
|
|
||||||
self.logical_in_features
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn logical_out_features(&self) -> usize {
|
|
||||||
self.logical_out_features
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for ShardedLinear {
|
|
||||||
/// Local matmul only. For `Row`-parallel layers, the output is a
|
|
||||||
/// *partial sum* — call [`Self::forward_with_comm`] to get the
|
|
||||||
/// reduced result. Implementing `Module` lets a `ShardedLinear`
|
|
||||||
/// be drop-in for any `Module`-shaped consumer that doesn't need
|
|
||||||
/// the reduce step (column-parallel layers; tests).
|
|
||||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
|
||||||
self.inner.forward(x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
|
||||||
impl ShardedLinear {
|
|
||||||
/// Forward pass that issues an `all_reduce(Sum)` for row-parallel
|
|
||||||
/// layers. Column-parallel layers just delegate to the local
|
|
||||||
/// matmul (their output is naturally sharded; the next consumer
|
|
||||||
/// will either gather or accept the shard).
|
|
||||||
pub fn forward_with_comm(&self, x: &Tensor, comm: &cudarc::nccl::Comm) -> Result<Tensor> {
|
|
||||||
let local = self
|
|
||||||
.inner
|
|
||||||
.forward(x)
|
|
||||||
.map_err(|e| anyhow::anyhow!("local matmul: {e}"))?;
|
|
||||||
match self.kind {
|
|
||||||
ShardKind::Column => Ok(local),
|
|
||||||
ShardKind::Row => {
|
|
||||||
// TODO Stage 7b-iii: wrap `local`'s CudaSlice with a
|
|
||||||
// matching output buffer, call comm.all_reduce(Sum),
|
|
||||||
// return the result. The cudarc::nccl all_reduce
|
|
||||||
// signature takes `&S: DevicePtr<T>` + `&mut R: DevicePtrMut<T>`,
|
|
||||||
// both backed by `CudaSlice<T>`. candle stores its
|
|
||||||
// Tensor data behind its own slab — extracting the
|
|
||||||
// underlying CudaSlice safely is a separate piece of
|
|
||||||
// plumbing best landed alongside the model assembly,
|
|
||||||
// so this body is a placeholder.
|
|
||||||
let _ = comm;
|
|
||||||
anyhow::bail!(
|
|
||||||
"ShardedLinear::forward_with_comm row-parallel reduce \
|
|
||||||
lands in Stage 7b-iii alongside the model assembly; \
|
|
||||||
7b-ii ships only the local matmul"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use candle_core::{DType, Device, Tensor};
|
|
||||||
use candle_nn::var_builder::VarBuilderArgs;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
/// Build a VarBuilder over an in-memory map of tensors. Used by
|
|
||||||
/// the tests to fake a safetensors source without touching disk.
|
|
||||||
fn vb_from_map(tensors: HashMap<String, Tensor>, device: &Device) -> VarBuilder<'static> {
|
|
||||||
VarBuilderArgs::from_tensors(tensors, DType::F32, device)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// World_size=2 column-parallel split of a 4x3 weight. Each rank's
|
|
||||||
/// local matmul on the same input should be 2 rows of the
|
|
||||||
/// reference (full) matmul.
|
|
||||||
#[test]
|
|
||||||
fn column_parallel_shards_output_correctly() {
|
|
||||||
let device = Device::Cpu;
|
|
||||||
// weight (out=4, in=3): rows are easy to identify by value.
|
|
||||||
let w = Tensor::from_slice(
|
|
||||||
&[
|
|
||||||
1f32, 2., 3., // row 0
|
|
||||||
4., 5., 6., // row 1
|
|
||||||
7., 8., 9., // row 2
|
|
||||||
10., 11., 12., // row 3
|
|
||||||
],
|
|
||||||
(4, 3),
|
|
||||||
&device,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let mut tensors = HashMap::new();
|
|
||||||
tensors.insert("foo.weight".into(), w.clone());
|
|
||||||
let vb_root = vb_from_map(tensors, &device);
|
|
||||||
let vb_foo = vb_root.pp("foo");
|
|
||||||
|
|
||||||
// rank 0 of world_size 2 gets rows 0..2.
|
|
||||||
let r0 = ShardedLinear::load_column(&vb_foo, 3, 4, false, 0, 2).unwrap();
|
|
||||||
// rank 1 gets rows 2..4.
|
|
||||||
let r1 = ShardedLinear::load_column(&vb_foo, 3, 4, false, 1, 2).unwrap();
|
|
||||||
|
|
||||||
let x = Tensor::from_slice(&[1f32, 0., 0.], (1, 3), &device).unwrap();
|
|
||||||
let y0 = r0.forward(&x).unwrap().to_vec2::<f32>().unwrap();
|
|
||||||
let y1 = r1.forward(&x).unwrap().to_vec2::<f32>().unwrap();
|
|
||||||
// Full reference: x @ w.T → [1, 4, 7, 10]. Rank 0 owns [1, 4],
|
|
||||||
// rank 1 owns [7, 10].
|
|
||||||
assert_eq!(y0, vec![vec![1.0, 4.0]]);
|
|
||||||
assert_eq!(y1, vec![vec![7.0, 10.0]]);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// World_size=2 row-parallel split of a 4x4 weight. Each rank's
|
|
||||||
/// local matmul on its half of the input should be a partial sum;
|
|
||||||
/// summing the two partials should equal the unsharded reference.
|
|
||||||
#[test]
|
|
||||||
fn row_parallel_partials_sum_to_full() {
|
|
||||||
let device = Device::Cpu;
|
|
||||||
// weight (out=4, in=4): use distinct values per column so the
|
|
||||||
// partial sums are obviously different.
|
|
||||||
let w = Tensor::from_slice(
|
|
||||||
&[
|
|
||||||
1f32, 2., 3., 4., // row 0
|
|
||||||
5., 6., 7., 8., // row 1
|
|
||||||
9., 10., 11., 12., // row 2
|
|
||||||
13., 14., 15., 16., // row 3
|
|
||||||
],
|
|
||||||
(4, 4),
|
|
||||||
&device,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let mut tensors = HashMap::new();
|
|
||||||
tensors.insert("bar.weight".into(), w.clone());
|
|
||||||
let vb_root = vb_from_map(tensors, &device);
|
|
||||||
let vb_bar = vb_root.pp("bar");
|
|
||||||
|
|
||||||
let r0 = ShardedLinear::load_row(&vb_bar, 4, 4, false, 0, 2).unwrap();
|
|
||||||
let r1 = ShardedLinear::load_row(&vb_bar, 4, 4, false, 1, 2).unwrap();
|
|
||||||
|
|
||||||
// x split: rank 0 takes x[..2], rank 1 takes x[2..].
|
|
||||||
let x_full = Tensor::from_slice(&[1f32, 1., 1., 1.], (1, 4), &device).unwrap();
|
|
||||||
let x0 = x_full.narrow(1, 0, 2).unwrap();
|
|
||||||
let x1 = x_full.narrow(1, 2, 2).unwrap();
|
|
||||||
|
|
||||||
let y0 = r0.forward(&x0).unwrap();
|
|
||||||
let y1 = r1.forward(&x1).unwrap();
|
|
||||||
let summed = (y0 + y1).unwrap().to_vec2::<f32>().unwrap();
|
|
||||||
|
|
||||||
// Reference: x_full @ w.T = [1+2+3+4, 5+6+7+8, 9+10+11+12, 13+14+15+16]
|
|
||||||
// = [10, 26, 42, 58].
|
|
||||||
assert_eq!(summed, vec![vec![10.0, 26.0, 42.0, 58.0]]);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Row-parallel bias lives only on rank 0; other ranks have None.
|
|
||||||
/// (Verifies the rank-0-only bias contract.)
|
|
||||||
#[test]
|
|
||||||
fn row_parallel_bias_only_on_rank_zero() {
|
|
||||||
let device = Device::Cpu;
|
|
||||||
let w = Tensor::zeros((4, 4), DType::F32, &device).unwrap();
|
|
||||||
let b = Tensor::from_slice(&[1f32, 1., 1., 1.], 4, &device).unwrap();
|
|
||||||
let mut tensors = HashMap::new();
|
|
||||||
tensors.insert("baz.weight".into(), w);
|
|
||||||
tensors.insert("baz.bias".into(), b);
|
|
||||||
let vb_root = vb_from_map(tensors, &device);
|
|
||||||
let vb_baz = vb_root.pp("baz");
|
|
||||||
|
|
||||||
let r0 = ShardedLinear::load_row(&vb_baz, 4, 4, true, 0, 2).unwrap();
|
|
||||||
let r1 = ShardedLinear::load_row(&vb_baz, 4, 4, true, 1, 2).unwrap();
|
|
||||||
|
|
||||||
// We can't introspect the Linear's bias from the public API,
|
|
||||||
// but we can run forward of zero-weight rank 1 and confirm
|
|
||||||
// the output is zero (no bias added on non-zero ranks).
|
|
||||||
let x = Tensor::ones((1, 2), DType::F32, &device).unwrap();
|
|
||||||
let y1 = r1.forward(&x).unwrap().to_vec2::<f32>().unwrap();
|
|
||||||
assert_eq!(y1, vec![vec![0.0, 0.0, 0.0, 0.0]]);
|
|
||||||
|
|
||||||
let y0 = r0.forward(&x).unwrap().to_vec2::<f32>().unwrap();
|
|
||||||
// Rank 0 weight is zero but bias is [1,1,1,1] → output should be [1,1,1,1].
|
|
||||||
assert_eq!(y0, vec![vec![1.0, 1.0, 1.0, 1.0]]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn column_parallel_rejects_non_divisible_out_features() {
|
|
||||||
let device = Device::Cpu;
|
|
||||||
let w = Tensor::zeros((5, 3), DType::F32, &device).unwrap();
|
|
||||||
let mut tensors = HashMap::new();
|
|
||||||
tensors.insert("nope.weight".into(), w);
|
|
||||||
let vb_root = vb_from_map(tensors, &device);
|
|
||||||
let vb_nope = vb_root.pp("nope");
|
|
||||||
|
|
||||||
let err = ShardedLinear::load_column(&vb_nope, 3, 5, false, 0, 2).unwrap_err();
|
|
||||||
let msg = format!("{err:#}");
|
|
||||||
assert!(
|
|
||||||
msg.contains("not divisible by world_size"),
|
|
||||||
"expected divisibility error, got: {msg}"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn row_parallel_rejects_non_divisible_in_features() {
|
|
||||||
let device = Device::Cpu;
|
|
||||||
let w = Tensor::zeros((4, 5), DType::F32, &device).unwrap();
|
|
||||||
let mut tensors = HashMap::new();
|
|
||||||
tensors.insert("nope.weight".into(), w);
|
|
||||||
let vb_root = vb_from_map(tensors, &device);
|
|
||||||
let vb_nope = vb_root.pp("nope");
|
|
||||||
|
|
||||||
let err = ShardedLinear::load_row(&vb_nope, 5, 4, false, 0, 2).unwrap_err();
|
|
||||||
let msg = format!("{err:#}");
|
|
||||||
assert!(
|
|
||||||
msg.contains("not divisible by world_size"),
|
|
||||||
"expected divisibility error, got: {msg}"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
134
crates/neuron/src/harness/tp/tp_linear.rs
Normal file
134
crates/neuron/src/harness/tp/tp_linear.rs
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
//! Tensor-parallel linear layers built on candle's `ShardedVarBuilder`
|
||||||
|
//! and `Shard` sharding hints.
|
||||||
|
//!
|
||||||
|
//! candle reads only the rank's slice of each weight tensor from
|
||||||
|
//! safetensors via `view.slice(start..stop)` — no full-tensor host
|
||||||
|
//! materialisation. That's a memory-efficiency win over hand-rolled
|
||||||
|
//! "load full + narrow" sharding (which the earlier
|
||||||
|
//! `sharded_linear.rs` exploration demonstrated but didn't pay for).
|
||||||
|
//!
|
||||||
|
//! Two layer types:
|
||||||
|
//!
|
||||||
|
//! - [`ColumnParallelLinear`] — output-sharded; forward is a plain
|
||||||
|
//! local matmul. The downstream consumer either accepts a sharded
|
||||||
|
//! activation (next layer is also column-parallel) or all-gathers.
|
||||||
|
//! - [`RowParallelLinear`] — input-sharded; forward = local matmul
|
||||||
|
//! then `AllReduce` `CustomOp1` to sum partials across ranks.
|
||||||
|
//!
|
||||||
|
//! Both assume **no bias** — every Qwen3-family weight layout we
|
||||||
|
//! actually target (Qwen3, Qwen3-Coder, Qwen3.6 base, etc.) sets
|
||||||
|
//! `attention_bias=false` and the MLP layers are no-bias. Adding bias
|
||||||
|
//! support is mechanical when a future model needs it; the design
|
||||||
|
//! choice would be: column-parallel shards the bias along dim 0;
|
||||||
|
//! row-parallel holds the bias only on rank 0 so the post-`AllReduce`
|
||||||
|
//! sum carries it exactly once.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::var_builder::{Shard, ShardedVarBuilder};
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use super::all_reduce::AllReduce;
|
||||||
|
|
||||||
|
/// Helper to build a [`Shard`] hint for a given dimension.
|
||||||
|
pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
|
||||||
|
Shard {
|
||||||
|
dim,
|
||||||
|
rank: rank as usize,
|
||||||
|
world_size: world_size as usize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Output-dim sharded linear (column-parallel). Holds a standard
|
||||||
|
/// `candle_nn::Linear` whose `weight` is the rank's slice of the full
|
||||||
|
/// `[out_features, in_features]` tensor along dim 0.
|
||||||
|
pub struct ColumnParallelLinear {
|
||||||
|
inner: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ColumnParallelLinear {
|
||||||
|
/// Load this rank's column-parallel slice from a
|
||||||
|
/// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
|
||||||
|
/// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
|
||||||
|
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get_with_hints((), "weight", shard(0, rank, world_size))
|
||||||
|
.with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
inner: Linear::new(weight, None),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for ColumnParallelLinear {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
self.inner.forward(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Input-dim sharded linear (row-parallel).
|
||||||
|
///
|
||||||
|
/// Holds a sharded `Linear` plus an `AllReduce` op the forward chains
|
||||||
|
/// after the local matmul to recover the full activation.
|
||||||
|
pub struct RowParallelLinear {
|
||||||
|
inner: Linear,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
all_reduce: AllReduce,
|
||||||
|
/// Whether the AllReduce should run. Column-parallel ↔ row-parallel
|
||||||
|
/// is a pair: the column output is sharded, the row input is
|
||||||
|
/// sharded, and the AllReduce gives back the full output. For
|
||||||
|
/// `world_size = 1` the AllReduce is a no-op so we skip it.
|
||||||
|
needs_reduce: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RowParallelLinear {
|
||||||
|
/// Load this rank's row-parallel slice from a `ShardedVarBuilder`.
|
||||||
|
///
|
||||||
|
/// Under `cuda`, `comm` is the NCCL communicator the row-parallel
|
||||||
|
/// `AllReduce` runs against. On CPU builds the parameter is
|
||||||
|
/// elided — forward returns the partial sum, which is the *wrong*
|
||||||
|
/// answer for inference but lets us compile-check the model.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: std::sync::Arc<cudarc::nccl::Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||||
|
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
inner: Linear::new(weight, None),
|
||||||
|
all_reduce: AllReduce::new(comm),
|
||||||
|
needs_reduce: world_size > 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||||
|
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
inner: Linear::new(weight, None),
|
||||||
|
needs_reduce: world_size > 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for RowParallelLinear {
|
||||||
|
/// Local matmul followed by an `AllReduce` (when `cuda` and
|
||||||
|
/// `world_size > 1`). On CPU or single-rank, returns the partial
|
||||||
|
/// output directly — which is *only* correct for `world_size == 1`.
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let local = self.inner.forward(x)?;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
if self.needs_reduce {
|
||||||
|
return local.apply_op1_no_bwd(&self.all_reduce);
|
||||||
|
}
|
||||||
|
let _ = self.needs_reduce;
|
||||||
|
Ok(local)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user