feat(stage-8c): TP-aware Qwen3-Next (tp_qwen3_5)
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 36s
CI / Format (push) Successful in 39s
CI / Clippy (push) Successful in 2m13s
build-prerelease / Build neuron-blackwell (push) Successful in 3m37s
CI / Test (push) Successful in 4m49s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m26s
build-prerelease / Build neuron-ampere (push) Successful in 5m18s
build-prerelease / Package cortex RPM (push) Successful in 7m6s
build-prerelease / Build neuron-ada (push) Successful in 5m13s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m2s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m55s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 5m39s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
Adds `harness/tp/tp_qwen3_5.rs` — the tensor-parallel variant of the
Qwen3-Next architecture — plus the dispatch wiring needed to route a
load through it on both the leader and the workers.

Architecture pieces (all per-rank; the full-attention layers follow the
`tp_qwen3.rs` patterns, the linear-attention layers use a new one):

- TpQwen3_5GatedDeltaNet: sharded along the V-head dim, with
  `num_v_heads / world_size` V-heads and `num_k_heads / world_size`
  K-heads per rank. `in_proj_z`, `in_proj_b`, `in_proj_a`, `A_log`, and
  `dt_bias` shard uniformly along the V-head dim. `out_proj` is
  row-parallel + AllReduce (the only collective inside the block). The
  recurrent state shards 1:1 with V-heads, so the delta-rule loop needs
  no cross-rank sync.

  `in_proj_qkv` and `conv1d.weight` are FUSED tensors with three
  regions along dim 0 (`[first key_dim, second key_dim, value_dim]`).
  Standard uniform slicing doesn't align with the head boundaries:
  rank 0 would end up with `[first half of K_0, full K_1, first half
  of V]`. New `load_fused_qkv_slice_{2d,3d}` helpers load the full
  tensor, narrow per-region per-rank, and `Tensor::cat` the three
  slices into a per-rank fused weight. This costs a transient peak of
  one full tensor per layer during construction; net memory is properly
  per-rank once the full tensor is dropped.
- TpQwen3_5Attention: column-parallel `q_proj` (the widened
  `2 * num_heads * head_dim` output, including the gate half; it shards
  along the head axis so both the query AND gate halves stay consistent
  per rank), `k_proj`, and `v_proj`; row-parallel `o_proj` with
  AllReduce. Otherwise mirrors the attention in `tp_qwen3.rs`.
- TpQwen3_5MLP, TpQwen3_5DecoderLayer (dispatches on layer_types),
  TpQwen3_5Model (with `model.language_model.*` prefix), and
  TpQwen3_5ForCausalLM (with tied or separate `lm_head` at top level).
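
The region-aware slicing described for the fused tensors can be sketched on plain row indices. This is a minimal illustration, not the commit's `load_fused_qkv_slice_{2d,3d}` helpers: the function names `fused_region_ranges` and `slice_fused_rows` are hypothetical, rows stand in for slices along dim 0, and it assumes `key_dim` and `value_dim` divide evenly by `world_size`.

```rust
use std::ops::Range;

/// Row ranges along dim 0 that `rank` takes from a fused tensor laid out as
/// [key_dim | key_dim | value_dim]. Hypothetical helper for illustration only.
fn fused_region_ranges(
    key_dim: usize,
    value_dim: usize,
    rank: usize,
    world_size: usize,
) -> Option<[Range<usize>; 3]> {
    if world_size == 0 || rank >= world_size {
        return None;
    }
    // Assumption: both region sizes divide evenly across ranks.
    if key_dim % world_size != 0 || value_dim % world_size != 0 {
        return None;
    }
    let k = key_dim / world_size;
    let v = value_dim / world_size;
    let first = rank * k; // offset inside the first key_dim region
    let second = key_dim + rank * k; // offset inside the second key_dim region
    let third = 2 * key_dim + rank * v; // offset inside the value_dim region
    Some([first..first + k, second..second + k, third..third + v])
}

/// Narrow each region for `rank` and concatenate the three slices, the
/// row-level analogue of narrowing per region and then concatenating.
fn slice_fused_rows<T: Clone>(
    rows: &[T],
    key_dim: usize,
    value_dim: usize,
    rank: usize,
    world_size: usize,
) -> Option<Vec<T>> {
    if rows.len() != 2 * key_dim + value_dim {
        return None;
    }
    let ranges = fused_region_ranges(key_dim, value_dim, rank, world_size)?;
    let mut out = Vec::new();
    for r in ranges.iter() {
        out.extend_from_slice(&rows[r.clone()]);
    }
    Some(out)
}
```

With `key_dim = 4`, `value_dim = 8`, and two ranks, rank 0 takes rows 0..2, 4..6, and 8..12, so each rank ends up with a matching share of all three regions rather than a contiguous half of the fused tensor.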

Dispatch wiring:

- New `tp::TpLeaderModel` enum holds either the Qwen3 or the Qwen3_5
  variant. `WorkerPool::load_dense_shard` now dispatches on `model_type`
  from the config JSON and returns `Arc<Mutex<TpLeaderModel>>`. The two
  downstream methods (`generate_step`, `clear_kv_cache`) thread this
  enum through; the inner `forward` and `clear_kv_cache` dispatch
  happens via the enum's pub methods. Adding another TP architecture
  later is one more enum variant plus match arms.
- Worker side gets a parallel `WorkerModel` enum and dispatch in
  `handle_load_dense_shard`, branching on the same `model_type`.
- Harness gate `TP_SUPPORTED_MODEL_TYPES` is now `["qwen3", "qwen3_5"]`.
  `TpLoadedModel.leader_model` is retyped to the enum.
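
The enum-dispatch shape described above can be sketched without candle or NCCL. Everything here is a stand-in under stated assumptions: `Qwen3Stub`, `Qwen3_5Stub`, and `LeaderModel` are hypothetical names, and the stub `forward` only counts cached tokens instead of running a model.

```rust
/// Stand-ins for the two TP models. The real types hold sharded weights and
/// an NCCL communicator; these stubs only track how many tokens are cached.
struct Qwen3Stub {
    cached: usize,
}

struct Qwen3_5Stub {
    cached: usize,
}

/// Hypothetical analogue of the leader-side enum: one variant per architecture.
enum LeaderModel {
    Qwen3(Qwen3Stub),
    Qwen3_5(Qwen3_5Stub),
}

impl LeaderModel {
    /// Pick the variant from the `model_type` string, rejecting unknown types.
    fn load(model_type: &str) -> Result<Self, String> {
        match model_type {
            "qwen3" => Ok(LeaderModel::Qwen3(Qwen3Stub { cached: 0 })),
            "qwen3_5" => Ok(LeaderModel::Qwen3_5(Qwen3_5Stub { cached: 0 })),
            other => Err(format!(
                "unsupported model_type '{}' (supported: qwen3, qwen3_5)",
                other
            )),
        }
    }

    /// Stub forward: appends the tokens to the cache and returns its length.
    fn forward(&mut self, tokens: &[u32]) -> usize {
        match self {
            LeaderModel::Qwen3(m) => {
                m.cached += tokens.len();
                m.cached
            }
            LeaderModel::Qwen3_5(m) => {
                m.cached += tokens.len();
                m.cached
            }
        }
    }

    /// Reset the per-variant cache.
    fn clear_kv_cache(&mut self) {
        match self {
            LeaderModel::Qwen3(m) => m.cached = 0,
            LeaderModel::Qwen3_5(m) => m.cached = 0,
        }
    }
}
```

Callers hold only `LeaderModel`, so supporting a further architecture in this sketch means adding one variant and one arm to each `match`, which is the property the commit message relies on.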

Helpers in `arch/qwen3_5/linear_attn.rs`:

- `softplus` and `repeat_interleave` made `pub(crate)` so the TP
  module reuses them rather than duplicating.

Reuses unchanged: `Qwen3_5RmsNorm` (replicated weight), the gated
`Qwen3_5RmsNormGated` tail, `l2norm`, the `RotaryEmbedding` (partial
RoPE with `partial_rotary_factor` already correct).

CPU build + clippy + 32 lib tests pass; `cargo clippy --features cuda`
also clean inside the patched runner container.

One in-flight risk to call out: tensor names. The per-layer prefix is
`model.language_model.layers.<i>.self_attn.*` for full-attention layers
and `model.language_model.layers.<i>.linear_attn.*` for linear-attention
layers, the same as the single-GPU path. `lm_head` sits at the top level
(not under `language_model`), consistent with the single-GPU path that
was validated against Qwen3.5-0.8B.
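
The naming convention above can be written down as a small prefix builder. This is only a sketch of the layout stated in this message; `LayerKind`, `attn_prefix`, and `lm_head_prefix` are hypothetical names, not functions from the commit.

```rust
/// Which attention flavour a decoder layer uses (illustrative names).
#[derive(Clone, Copy, Debug, PartialEq)]
enum LayerKind {
    FullAttention,
    LinearAttention,
}

/// Per-layer attention prefix under `model.language_model.layers.<i>`.
fn attn_prefix(layer_idx: usize, kind: LayerKind) -> String {
    let block = match kind {
        LayerKind::FullAttention => "self_attn",
        LayerKind::LinearAttention => "linear_attn",
    };
    format!("model.language_model.layers.{}.{}", layer_idx, block)
}

/// The LM head lives at the top level, not under `model.language_model`.
fn lm_head_prefix() -> &'static str {
    "lm_head"
}
```

Keeping the prefixes in one place like this makes a mismatch with the checkpoint's tensor names show up as a single failing lookup rather than scattered string literals.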
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -22,6 +22,7 @@ pub mod nccl_state;
 pub mod rpc;
 pub mod tp_linear;
 pub mod tp_qwen3;
+pub mod tp_qwen3_5;
 pub mod worker;
 
 use anyhow::{Context, Result};
@@ -32,6 +33,49 @@ use tokio::process::{Child, ChildStdin, ChildStdout, Command};
 
 use rpc::{WorkerRequest, WorkerResponse};
 
+/// Leader-side handle for any TP-loaded model. The pool's
+/// `load_dense_shard` dispatches on `config.json#/model_type` to build
+/// the right variant; downstream callers (the harness's
+/// `chat_completion_tp` path, `generate_step`, `clear_kv_cache`,
+/// `unload_model`) all hold this enum and let the variant dispatch
+/// determine the concrete forward.
+///
+/// Variants gated on `cuda` because the underlying TP models hold
+/// `Arc<cudarc::nccl::Comm>` references — irrelevant on CPU builds.
+#[cfg(feature = "cuda")]
+pub enum TpLeaderModel {
+    Qwen3(tp_qwen3::TpQwen3ForCausalLM),
+    Qwen3_5(tp_qwen3_5::TpQwen3_5ForCausalLM),
+}
+
+#[cfg(feature = "cuda")]
+impl TpLeaderModel {
+    pub fn forward(
+        &mut self,
+        input: &candle_core::Tensor,
+        offset: usize,
+    ) -> candle_core::Result<candle_core::Tensor> {
+        match self {
+            TpLeaderModel::Qwen3(m) => m.forward(input, offset),
+            TpLeaderModel::Qwen3_5(m) => m.forward(input, offset),
+        }
+    }
+
+    pub fn clear_kv_cache(&mut self) {
+        match self {
+            TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
+            TpLeaderModel::Qwen3_5(m) => m.clear_kv_cache(),
+        }
+    }
+
+    pub fn device(&self) -> &candle_core::Device {
+        match self {
+            TpLeaderModel::Qwen3(m) => m.device(),
+            TpLeaderModel::Qwen3_5(m) => m.device(),
+        }
+    }
+}
+
 /// One worker subprocess plus its bidirectional stdio handles.
 struct Worker {
     rank: u32,
@@ -363,7 +407,7 @@ impl WorkerPool {
         safetensors_paths: &[std::path::PathBuf],
         leader_device: &candle_core::Device,
         dtype: candle_core::DType,
-    ) -> Result<std::sync::Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>> {
+    ) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
         use candle_nn::var_builder::ShardedSafeTensors;
         use std::sync::Arc;
         use tokio::sync::Mutex;
@@ -396,36 +440,56 @@ impl WorkerPool {
                 .await?;
         }
 
-        // 2. Build rank 0's shard on the leader. ShardedVarBuilder reads
-        //    only the rank's slice from safetensors — no full-tensor
-        //    materialisation. Runs in spawn_blocking because the
-        //    file-mmap + slice + copy-to-device work is synchronous.
-        let cfg: super::tp::tp_qwen3::Config =
-            serde_json::from_str(config_json).context("parse Qwen3 Config JSON for leader load")?;
+        // 2. Build rank 0's shard on the leader. Dispatch on model_type
+        //    — for `qwen3` we build a `TpQwen3ForCausalLM`, for
+        //    `qwen3_5` (Qwen3-Next, Qwen3.6's architecture) we build
+        //    `TpQwen3_5ForCausalLM`. Both end up wrapped in the
+        //    `TpLeaderModel` enum so downstream callers don't care.
+        let model_type = serde_json::from_str::<serde_json::Value>(config_json)
+            .ok()
+            .as_ref()
+            .and_then(|v| v.get("model_type"))
+            .and_then(|v| v.as_str())
+            .unwrap_or("")
+            .to_string();
         let paths_for_leader: Vec<std::path::PathBuf> = safetensors_paths.to_vec();
         let device_for_leader = leader_device.clone();
         let comm_for_leader = leader_comm;
         let model_id_for_log = model_id.to_string();
-        let leader_model = tokio::task::spawn_blocking(
-            move || -> Result<super::tp::tp_qwen3::TpQwen3ForCausalLM> {
-                // SAFETY: same invariant as the single-GPU dense path —
-                // the HF cache files are treated as immutable while the
-                // mmap is held.
-                let vb = unsafe {
-                    ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
-                        .context("build ShardedVarBuilder over safetensors")?
-                };
-                let model = super::tp::tp_qwen3::TpQwen3ForCausalLM::load(
-                    &cfg,
-                    &vb,
-                    0,
-                    world_size,
-                    comm_for_leader.into_inner(),
-                )?;
-                tracing::info!(rank = 0, model = %model_id_for_log, "loaded TP shard (leader)");
-                Ok(model)
-            },
-        )
+        let config_json_for_leader = config_json.to_string();
+
+        let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
+            // SAFETY: same invariant as the single-GPU dense path —
+            // the HF cache files are treated as immutable while the
+            // mmap is held.
+            let vb = unsafe {
+                ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
+                    .context("build ShardedVarBuilder over safetensors")?
+            };
+            let comm = comm_for_leader.into_inner();
+            let loaded = match model_type.as_str() {
+                "qwen3" => {
+                    let cfg: super::tp::tp_qwen3::Config = serde_json::from_str(&config_json_for_leader)
+                        .context("parse Qwen3 Config JSON for leader load")?;
+                    TpLeaderModel::Qwen3(super::tp::tp_qwen3::TpQwen3ForCausalLM::load(
+                        &cfg, &vb, 0, world_size, comm,
+                    )?)
+                }
+                "qwen3_5" => {
+                    let cfg: super::tp::tp_qwen3_5::Config =
+                        serde_json::from_str(&config_json_for_leader)
+                            .context("parse Qwen3-Next Config JSON for leader load")?;
+                    TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
+                        cfg, &vb, 0, world_size, comm,
+                    )?)
+                }
+                other => anyhow::bail!(
+                    "TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)"
+                ),
+            };
+            tracing::info!(rank = 0, model = %model_id_for_log, model_type = %model_type, "loaded TP shard (leader)");
+            Ok(loaded)
+        })
         .await
         .context("leader load task panicked")??;
 
@@ -463,7 +527,7 @@ impl WorkerPool {
     pub async fn generate_step(
         &mut self,
         model_id: &str,
-        leader_model: std::sync::Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>,
+        leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
         tokens: Vec<u32>,
         offset: usize,
     ) -> Result<candle_core::Tensor> {
@@ -516,9 +580,7 @@ impl WorkerPool {
     pub async fn clear_kv_cache(
         &mut self,
         model_id: &str,
-        #[cfg(feature = "cuda")] leader_model: std::sync::Arc<
-            tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>,
-        >,
+        #[cfg(feature = "cuda")] leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
     ) -> Result<()> {
         for w in &mut self.workers {
             w.send_only(&WorkerRequest::ClearKvCache {