feat(tp): TP-aware Qwen3 dense model (Stage 7b-iii 2/2)

Mirrors candle_transformers::models::qwen3 structurally with column-
parallel q/k/v + gate/up projections, row-parallel o + down projections,
and replicated embedding/norms/lm_head. Per-rank head counts come from
dividing num_attention_heads / num_key_value_heads by world_size at load
time; intermediate_size split likewise. Load bails on any non-divisible
shape — the safetensors slice would lose data otherwise.

KV cache holds the rank-local slice since K/V come out of column-parallel
projections; no cache resharding across ranks. Causal mask is computed
on rank 0 shape and broadcasts over the head dim so per-rank H differs
without rework.

Replicated tensors (embedding, all RmsNorms, untied lm_head) load via
vb.get(shape, name), which uses the default Shard { world_size: 1 } and
falls through to the unsharded backend path on ShardedSafeTensors.

The cuda / non-cuda load splits track the existing tp_linear pattern:
RowParallelLinear takes an Arc<Comm> only under cuda, and the higher-
level composers (TpQwen3MLP, TpQwen3Attention, TpDecoderLayer,
TpQwen3Model, TpQwen3ForCausalLM) thread it through accordingly.

7b-iv wires RPC + dispatch in CandleHarness::load_model.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-19 18:24:20 +03:00
parent 8d3194f992
commit 46527d7804
2 changed files with 606 additions and 0 deletions

View File

@@ -21,6 +21,7 @@ pub mod all_reduce;
pub mod nccl_state;
pub mod rpc;
pub mod tp_linear;
pub mod tp_qwen3;
pub mod worker;
use anyhow::{Context, Result};