feat(tp): TP-aware Qwen3 dense model (Stage 7b-iii 2/2)

Mirrors candle_transformers::models::qwen3 structurally with column-
parallel q/k/v + gate/up projections, row-parallel o + down projections,
and replicated embedding/norms/lm_head. Per-rank head counts come from
dividing num_attention_heads / num_key_value_heads by world_size at load
time; intermediate_size is split likewise. Load bails on any
non-divisible shape, since the safetensors slice would lose data
otherwise.
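
The divisibility rule above can be sketched in plain Rust. This is only an
illustration, not the code in this commit: LocalShape and local_shape are
hypothetical names, and errors are plain strings rather than anyhow errors.

```rust
/// Hypothetical per-rank shape summary; the field names are illustrative.
#[derive(Debug, PartialEq)]
struct LocalShape {
    num_heads: usize,
    num_kv_heads: usize,
    intermediate_size: usize,
}

/// Divide each sharded dimension by `world_size`, failing on any remainder
/// instead of silently truncating the slice.
fn local_shape(
    num_attention_heads: usize,
    num_key_value_heads: usize,
    intermediate_size: usize,
    world_size: usize,
) -> Result<LocalShape, String> {
    if world_size == 0 {
        return Err("world_size must be at least 1".to_string());
    }
    // Shared check: an exact quotient, or an error naming the offending field.
    let split = |name: &str, value: usize| -> Result<usize, String> {
        if value % world_size != 0 {
            Err(format!(
                "{name}={value} is not divisible by world_size={world_size}"
            ))
        } else {
            Ok(value / world_size)
        }
    };
    Ok(LocalShape {
        num_heads: split("num_attention_heads", num_attention_heads)?,
        num_kv_heads: split("num_key_value_heads", num_key_value_heads)?,
        intermediate_size: split("intermediate_size", intermediate_size)?,
    })
}
```

With illustrative sizes of 32 heads, 8 KV heads and intermediate_size 12288,
world_size 4 gives 8 / 2 / 3072 per rank, while world_size 3 is rejected.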

KV cache holds the rank-local slice since K/V come out of column-parallel
projections, so no cache resharding across ranks is needed. Causal mask
is computed on rank 0 shape and broadcasts over the head dim, so a
different per-rank H needs no rework.
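
Why a head-independent mask survives a change in per-rank H can be shown
with nested Vecs standing in for tensors. This is a sketch under
assumptions: causal_mask and apply_mask are hypothetical helpers, the mask
covers the no-cache-offset case only, and the real model uses candle
tensors with broadcasting rather than explicit loops.

```rust
/// Build a [t, t] causal mask: 0.0 where key j <= query i, -inf otherwise.
/// Note that the mask has no head dimension at all.
fn causal_mask(t: usize) -> Vec<Vec<f32>> {
    (0..t)
        .map(|i| {
            (0..t)
                .map(|j| if j <= i { 0.0 } else { f32::NEG_INFINITY })
                .collect()
        })
        .collect()
}

/// Add the same [t, t] mask to every head of scores shaped [h][t][t].
/// The loop over heads plays the role of broadcasting over the head dim.
fn apply_mask(scores: &mut [Vec<Vec<f32>>], mask: &[Vec<f32>]) {
    for head in scores.iter_mut() {
        for (row, mask_row) in head.iter_mut().zip(mask) {
            for (s, m) in row.iter_mut().zip(mask_row) {
                *s += *m;
            }
        }
    }
}
```

The same mask value can be applied to scores with 2 heads or 8 heads; only
the sequence length has to match.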

Replicated tensors (embedding, all RmsNorms, untied lm_head) load via
vb.get(shape, name), which uses the default Shard { world_size: 1 } and
falls through to the unsharded backend path on ShardedSafeTensors.
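
The fall-through can be modelled with a simplified stand-in. The Shard
struct and load_rows below are hypothetical and are not candle's types or
the ShardedSafeTensors backend; they only show that a default hint with
world_size 1 yields the whole tensor while any other hint yields one
rank's slice.

```rust
/// Simplified stand-in for a shard hint; not candle's actual type.
#[derive(Clone, Copy)]
struct Shard {
    rank: usize,
    world_size: usize,
}

impl Default for Shard {
    /// The default hint describes a single-rank world, i.e. no sharding.
    fn default() -> Self {
        Shard { rank: 0, world_size: 1 }
    }
}

/// Return this rank's rows of a row-major [rows, cols] tensor.
/// Assumes `rows` is divisible by `world_size` (checked elsewhere).
fn load_rows(data: &[f32], rows: usize, cols: usize, shard: Shard) -> Vec<f32> {
    if shard.world_size == 1 {
        // Unsharded path: every rank gets the full (replicated) tensor.
        return data.to_vec();
    }
    let per_rank = rows / shard.world_size;
    let start = shard.rank * per_rank * cols;
    data[start..start + per_rank * cols].to_vec()
}
```

Passing Shard::default() for embeddings and norms returns the full data on
every rank, whereas an explicit hint with world_size 2 returns half the
rows per rank.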

The cuda / non-cuda load splits track the existing tp_linear pattern:
RowParallelLinear takes an Arc<Comm> only under cuda, and the higher-
level composers (TpQwen3MLP, TpQwen3Attention, TpDecoderLayer,
TpQwen3Model, TpQwen3ForCausalLM) thread it through accordingly.

7b-iv wires RPC + dispatch in CandleHarness::load_model.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -21,6 +21,7 @@ pub mod all_reduce;
 pub mod nccl_state;
 pub mod rpc;
 pub mod tp_linear;
+pub mod tp_qwen3;
 pub mod worker;

 use anyhow::{Context, Result};