Stage 7b-ii: ColumnParallel + RowParallel sharded linear primitives
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 30s
CI / Format (push) Successful in 31s
CI / Clippy (push) Failing after 49s
build-prerelease / Build neuron-blackwell (push) Failing after 3m29s
build-prerelease / Build cortex binary (push) Successful in 4m41s
CI / Test (push) Successful in 5m6s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m20s
build-prerelease / Build neuron-ampere (push) Failing after 5m1s
build-prerelease / Build neuron-ada (push) Failing after 4m53s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped

Adds harness/tp/sharded_linear.rs with ShardedLinear — a Megatron-LM
style sharded wrapper over candle_nn::Linear. Two constructors:

- load_column: splits the output dimension. Each rank holds rows
  [r*out/N .. (r+1)*out/N] of the weight, plus its slice of the bias.
  Forward = local matmul; output is naturally sharded; downstream
  consumer either accepts the shard (next layer is column-parallel)
  or merges via all-gather later.
- load_row: splits the input dimension. Each rank holds cols
  [r*in/N .. (r+1)*in/N] of the weight; bias lives only on rank 0
  so the post-all_reduce sum carries it exactly once. Forward
  produces a partial output that the caller reduces via NCCL.

Both constructors bail with a clear error when divisibility doesn't
hold — the precondition mistral.rs's first qwen3-next-tp commit
made explicit. The path included in the error is the VarBuilder
prefix, so the operator sees exactly which projection failed
("column-parallel 'model.layers.0.self_attn.q_proj': out_features=...").

5 unit tests on CPU verify the math against an unsharded reference:
- column shard produces the expected slice of the full matmul
- row partials sum to the unsharded result
- row bias appears only on rank 0
- divisibility violations bail (column + row)

forward_with_comm() is stubbed for row-parallel (CUDA-only) — wiring
the actual cudarc::nccl all_reduce against candle's Tensor lands in
7b-iii alongside the model assembly, where the model holds the Comm
in scope. ColumnParallel's forward_with_comm just delegates to the
local matmul (no collective needed).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-19 17:07:19 +03:00
parent 05e15f3597
commit 93421f48e2
2 changed files with 406 additions and 0 deletions

View File

@@ -19,6 +19,7 @@
pub mod nccl_state;
pub mod rpc;
pub mod sharded_linear;
pub mod worker;
use anyhow::{Context, Result};

View File

@@ -0,0 +1,405 @@
//! Tensor-parallel linear layers over `candle_nn::Linear`.
//!
//! Two sharding strategies, both following the Megatron-LM convention
//! that's also what mistral.rs uses for vanilla Qwen3:
//!
//! - [`ColumnParallelLinear`] — splits the **output** dimension. Each
//! rank holds `out_features / world_size` rows of the weight matrix.
//! The forward pass is a plain local matmul; the output is *sharded*
//! (each rank produces a slice of the output vector). Used for
//! `q_proj` / `k_proj` / `v_proj` (sharding by head) and the FFN's
//! `gate_proj` / `up_proj`.
//!
//! - [`RowParallelLinear`] — splits the **input** dimension. Each
//! rank holds `in_features / world_size` columns of the weight
//! matrix and consumes a sharded input from upstream. Each rank's
//! local matmul produces a *partial* output; an `all_reduce(Sum)`
//! across ranks recovers the full activation. Used for `o_proj`
//! (after attention) and `down_proj` (after the FFN).
//!
//! Stage 7b-ii (this commit): the layers, sharded loading, local
//! forward. The `all_reduce` collective lives in `forward_with_comm`
//! and is wired up in 7b-iii when the full TP-aware Qwen3 model is
//! assembled with an NCCL Comm in scope. Tests here exercise only
//! the local (no-NCCL) math against an unsharded reference.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::{Linear, VarBuilder};
/// Direction of the parallelism split — selects which axis of the
/// weight matrix the rank's local slice is taken from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardKind {
/// Split the output dimension: rank `r` holds rows
/// `[r * out/N .. (r+1) * out/N]` of the weight matrix. The
/// downstream consumer either accepts a sharded activation
/// (the next layer is also column-parallel) or merges via
/// all-gather.
Column,
/// Split the input dimension: rank `r` holds columns
/// `[r * in/N .. (r+1) * in/N]`. The forward pass produces a
/// partial output; an `all_reduce(Sum)` across ranks yields the
/// full activation.
Row,
}
/// A linear layer whose weights have been sharded across NCCL ranks.
///
/// Holds a standard `candle_nn::Linear` constructed from the local
/// slice. The collective op (only meaningful for `Row`) is invoked
/// by [`forward_with_comm`] — the trait `Module::forward` does just
/// the local matmul, so callers that want correct semantics on a
/// Row-parallel layer must drive the collective themselves.
#[derive(Debug)]
pub struct ShardedLinear {
inner: Linear,
kind: ShardKind,
rank: u32,
world_size: u32,
/// Captured for diagnostics ("rank 3 layer says X but should say Y").
/// `out_features` reflects the **logical** size (pre-shard) so the
/// caller can validate against the model config without doing the
/// arithmetic itself.
logical_out_features: usize,
logical_in_features: usize,
}
impl ShardedLinear {
/// Load a column-parallel slice from a `VarBuilder`. Reads the
/// full weight (and bias, if any) from the safetensors and
/// narrows on dim 0 to the rank's slice. The bias is sharded the
/// same way (each rank holds its own bias slice).
///
/// Bails if `out_features` is not divisible by `world_size` — the
/// same divisibility precondition mistral.rs's PR #2054-era code
/// added an explicit guard for after the first TP shard attempt.
pub fn load_column(
vb: &VarBuilder,
in_features: usize,
out_features: usize,
has_bias: bool,
rank: u32,
world_size: u32,
) -> Result<Self> {
let path = vb.prefix();
if !out_features.is_multiple_of(world_size as usize) {
anyhow::bail!(
"column-parallel '{path}': out_features={out_features} \
not divisible by world_size={world_size}"
);
}
let shard = out_features / world_size as usize;
let start = rank as usize * shard;
let full_w = vb
.get((out_features, in_features), "weight")
.with_context(|| format!("load weight for column-parallel '{path}'"))?;
let weight = full_w
.narrow(0, start, shard)
.with_context(|| format!("narrow weight rows for column-parallel '{path}'"))?
.contiguous()
.with_context(|| format!("contiguous weight for column-parallel '{path}'"))?;
// Drop the full tensor as soon as we have the shard so peak
// host RAM during load tracks shard-size, not full-size, once
// all narrows complete (Rust's drop semantics handle this
// because `full_w` goes out of scope here).
drop(full_w);
let bias = if has_bias {
let full_b = vb
.get(out_features, "bias")
.with_context(|| format!("load bias for column-parallel '{path}'"))?;
let b = full_b
.narrow(0, start, shard)
.with_context(|| format!("narrow bias for column-parallel '{path}'"))?
.contiguous()
.with_context(|| format!("contiguous bias for column-parallel '{path}'"))?;
Some(b)
} else {
None
};
Ok(Self {
inner: Linear::new(weight, bias),
kind: ShardKind::Column,
rank,
world_size,
logical_out_features: out_features,
logical_in_features: in_features,
})
}
/// Load a row-parallel slice from a `VarBuilder`. Reads the full
/// weight and narrows on dim 1 to the rank's column slice. The
/// bias, if any, lives **only on rank 0** — every other rank
/// holds `None`. This keeps the post-`all_reduce` semantics
/// correct: each rank contributes its partial sum without the
/// bias, then rank 0's bias (added in `forward_with_comm`) lands
/// on the result exactly once.
pub fn load_row(
vb: &VarBuilder,
in_features: usize,
out_features: usize,
has_bias: bool,
rank: u32,
world_size: u32,
) -> Result<Self> {
let path = vb.prefix();
if !in_features.is_multiple_of(world_size as usize) {
anyhow::bail!(
"row-parallel '{path}': in_features={in_features} \
not divisible by world_size={world_size}"
);
}
let shard = in_features / world_size as usize;
let start = rank as usize * shard;
let full_w = vb
.get((out_features, in_features), "weight")
.with_context(|| format!("load weight for row-parallel '{path}'"))?;
let weight = full_w
.narrow(1, start, shard)
.with_context(|| format!("narrow weight cols for row-parallel '{path}'"))?
.contiguous()
.with_context(|| format!("contiguous weight for row-parallel '{path}'"))?;
drop(full_w);
let bias = if has_bias && rank == 0 {
let b = vb
.get(out_features, "bias")
.with_context(|| format!("load bias for row-parallel '{path}'"))?;
Some(b)
} else {
None
};
Ok(Self {
inner: Linear::new(weight, bias),
kind: ShardKind::Row,
rank,
world_size,
logical_out_features: out_features,
logical_in_features: in_features,
})
}
pub fn kind(&self) -> ShardKind {
self.kind
}
pub fn rank(&self) -> u32 {
self.rank
}
pub fn world_size(&self) -> u32 {
self.world_size
}
pub fn logical_in_features(&self) -> usize {
self.logical_in_features
}
pub fn logical_out_features(&self) -> usize {
self.logical_out_features
}
}
impl Module for ShardedLinear {
/// Local matmul only. For `Row`-parallel layers, the output is a
/// *partial sum* — call [`Self::forward_with_comm`] to get the
/// reduced result. Implementing `Module` lets a `ShardedLinear`
/// be drop-in for any `Module`-shaped consumer that doesn't need
/// the reduce step (column-parallel layers; tests).
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
self.inner.forward(x)
}
}
#[cfg(feature = "cuda")]
impl ShardedLinear {
/// Forward pass that issues an `all_reduce(Sum)` for row-parallel
/// layers. Column-parallel layers just delegate to the local
/// matmul (their output is naturally sharded; the next consumer
/// will either gather or accept the shard).
pub fn forward_with_comm(&self, x: &Tensor, comm: &cudarc::nccl::Comm) -> Result<Tensor> {
let local = self
.inner
.forward(x)
.map_err(|e| anyhow::anyhow!("local matmul: {e}"))?;
match self.kind {
ShardKind::Column => Ok(local),
ShardKind::Row => {
// TODO Stage 7b-iii: wrap `local`'s CudaSlice with a
// matching output buffer, call comm.all_reduce(Sum),
// return the result. The cudarc::nccl all_reduce
// signature takes `&S: DevicePtr<T>` + `&mut R: DevicePtrMut<T>`,
// both backed by `CudaSlice<T>`. candle stores its
// Tensor data behind its own slab — extracting the
// underlying CudaSlice safely is a separate piece of
// plumbing best landed alongside the model assembly,
// so this body is a placeholder.
let _ = comm;
anyhow::bail!(
"ShardedLinear::forward_with_comm row-parallel reduce \
lands in Stage 7b-iii alongside the model assembly; \
7b-ii ships only the local matmul"
);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::{DType, Device, Tensor};
use candle_nn::var_builder::VarBuilderArgs;
use std::collections::HashMap;
/// Build a VarBuilder over an in-memory map of tensors. Used by
/// the tests to fake a safetensors source without touching disk.
fn vb_from_map(tensors: HashMap<String, Tensor>, device: &Device) -> VarBuilder<'static> {
VarBuilderArgs::from_tensors(tensors, DType::F32, device)
}
/// World_size=2 column-parallel split of a 4x3 weight. Each rank's
/// local matmul on the same input should be 2 rows of the
/// reference (full) matmul.
#[test]
fn column_parallel_shards_output_correctly() {
let device = Device::Cpu;
// weight (out=4, in=3): rows are easy to identify by value.
let w = Tensor::from_slice(
&[
1f32, 2., 3., // row 0
4., 5., 6., // row 1
7., 8., 9., // row 2
10., 11., 12., // row 3
],
(4, 3),
&device,
)
.unwrap();
let mut tensors = HashMap::new();
tensors.insert("foo.weight".into(), w.clone());
let vb_root = vb_from_map(tensors, &device);
let vb_foo = vb_root.pp("foo");
// rank 0 of world_size 2 gets rows 0..2.
let r0 = ShardedLinear::load_column(&vb_foo, 3, 4, false, 0, 2).unwrap();
// rank 1 gets rows 2..4.
let r1 = ShardedLinear::load_column(&vb_foo, 3, 4, false, 1, 2).unwrap();
let x = Tensor::from_slice(&[1f32, 0., 0.], (1, 3), &device).unwrap();
let y0 = r0.forward(&x).unwrap().to_vec2::<f32>().unwrap();
let y1 = r1.forward(&x).unwrap().to_vec2::<f32>().unwrap();
// Full reference: x @ w.T → [1, 4, 7, 10]. Rank 0 owns [1, 4],
// rank 1 owns [7, 10].
assert_eq!(y0, vec![vec![1.0, 4.0]]);
assert_eq!(y1, vec![vec![7.0, 10.0]]);
}
/// World_size=2 row-parallel split of a 4x4 weight. Each rank's
/// local matmul on its half of the input should be a partial sum;
/// summing the two partials should equal the unsharded reference.
#[test]
fn row_parallel_partials_sum_to_full() {
let device = Device::Cpu;
// weight (out=4, in=4): use distinct values per column so the
// partial sums are obviously different.
let w = Tensor::from_slice(
&[
1f32, 2., 3., 4., // row 0
5., 6., 7., 8., // row 1
9., 10., 11., 12., // row 2
13., 14., 15., 16., // row 3
],
(4, 4),
&device,
)
.unwrap();
let mut tensors = HashMap::new();
tensors.insert("bar.weight".into(), w.clone());
let vb_root = vb_from_map(tensors, &device);
let vb_bar = vb_root.pp("bar");
let r0 = ShardedLinear::load_row(&vb_bar, 4, 4, false, 0, 2).unwrap();
let r1 = ShardedLinear::load_row(&vb_bar, 4, 4, false, 1, 2).unwrap();
// x split: rank 0 takes x[..2], rank 1 takes x[2..].
let x_full = Tensor::from_slice(&[1f32, 1., 1., 1.], (1, 4), &device).unwrap();
let x0 = x_full.narrow(1, 0, 2).unwrap();
let x1 = x_full.narrow(1, 2, 2).unwrap();
let y0 = r0.forward(&x0).unwrap();
let y1 = r1.forward(&x1).unwrap();
let summed = (y0 + y1).unwrap().to_vec2::<f32>().unwrap();
// Reference: x_full @ w.T = [1+2+3+4, 5+6+7+8, 9+10+11+12, 13+14+15+16]
// = [10, 26, 42, 58].
assert_eq!(summed, vec![vec![10.0, 26.0, 42.0, 58.0]]);
}
/// Row-parallel bias lives only on rank 0; other ranks have None.
/// (Verifies the rank-0-only bias contract.)
#[test]
fn row_parallel_bias_only_on_rank_zero() {
let device = Device::Cpu;
let w = Tensor::zeros((4, 4), DType::F32, &device).unwrap();
let b = Tensor::from_slice(&[1f32, 1., 1., 1.], 4, &device).unwrap();
let mut tensors = HashMap::new();
tensors.insert("baz.weight".into(), w);
tensors.insert("baz.bias".into(), b);
let vb_root = vb_from_map(tensors, &device);
let vb_baz = vb_root.pp("baz");
let r0 = ShardedLinear::load_row(&vb_baz, 4, 4, true, 0, 2).unwrap();
let r1 = ShardedLinear::load_row(&vb_baz, 4, 4, true, 1, 2).unwrap();
// We can't introspect the Linear's bias from the public API,
// but we can run forward of zero-weight rank 1 and confirm
// the output is zero (no bias added on non-zero ranks).
let x = Tensor::ones((1, 2), DType::F32, &device).unwrap();
let y1 = r1.forward(&x).unwrap().to_vec2::<f32>().unwrap();
assert_eq!(y1, vec![vec![0.0, 0.0, 0.0, 0.0]]);
let y0 = r0.forward(&x).unwrap().to_vec2::<f32>().unwrap();
// Rank 0 weight is zero but bias is [1,1,1,1] → output should be [1,1,1,1].
assert_eq!(y0, vec![vec![1.0, 1.0, 1.0, 1.0]]);
}
#[test]
fn column_parallel_rejects_non_divisible_out_features() {
let device = Device::Cpu;
let w = Tensor::zeros((5, 3), DType::F32, &device).unwrap();
let mut tensors = HashMap::new();
tensors.insert("nope.weight".into(), w);
let vb_root = vb_from_map(tensors, &device);
let vb_nope = vb_root.pp("nope");
let err = ShardedLinear::load_column(&vb_nope, 3, 5, false, 0, 2).unwrap_err();
let msg = format!("{err:#}");
assert!(
msg.contains("not divisible by world_size"),
"expected divisibility error, got: {msg}"
);
}
#[test]
fn row_parallel_rejects_non_divisible_in_features() {
let device = Device::Cpu;
let w = Tensor::zeros((4, 5), DType::F32, &device).unwrap();
let mut tensors = HashMap::new();
tensors.insert("nope.weight".into(), w);
let vb_root = vb_from_map(tensors, &device);
let vb_nope = vb_root.pp("nope");
let err = ShardedLinear::load_row(&vb_nope, 5, 4, false, 0, 2).unwrap_err();
let msg = format!("{err:#}");
assert!(
msg.contains("not divisible by world_size"),
"expected divisibility error, got: {msg}"
);
}
}