feat(stage-8c): TP-aware Qwen3-Next (tp_qwen3_5)
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 36s
CI / Format (push) Successful in 39s
CI / Clippy (push) Successful in 2m13s
build-prerelease / Build neuron-blackwell (push) Successful in 3m37s
CI / Test (push) Successful in 4m49s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m26s
build-prerelease / Build neuron-ampere (push) Successful in 5m18s
build-prerelease / Package cortex RPM (push) Successful in 7m6s
build-prerelease / Build neuron-ada (push) Successful in 5m13s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m2s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m55s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 5m39s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s

Adds `harness/tp/tp_qwen3_5.rs` — the tensor-parallel variant of the
Qwen3-Next architecture — plus the dispatch wiring needed to route a
load through it on both the leader and the workers.

Architecture pieces (all per-rank; they follow the `tp_qwen3.rs`
patterns for the full-attention layers, plus a new pattern for
linear attention):

- TpQwen3_5GatedDeltaNet: V-head-dim sharded. `num_v_heads / world_size`
  V-heads per rank, `num_k_heads / world_size` K-heads. `in_proj_z`,
  `in_proj_b`, `in_proj_a`, `A_log`, `dt_bias` shard uniformly along
  the V-head dim. `out_proj` is row-parallel + AllReduce (the only
  collective inside the block). The recurrent state shards 1:1 with
  V-heads — no cross-rank sync inside the delta-rule loop.

  `in_proj_qkv` and `conv1d.weight` are FUSED tensors with three
  regions along dim 0 (`[first key_dim, second key_dim, value_dim]`).
  Standard uniform-slicing doesn't align with the head boundaries —
  rank 0 would end up with `[first half of K_0, full K_1, first half
  of V]`. New `load_fused_qkv_slice_{2d,3d}` helpers load the full
  tensor, narrow per-region per-rank, and `Tensor::cat` the three
  slices into a per-rank fused weight. This costs a transient peak of
  one full tensor per layer during construction; steady-state memory
  is per-rank once the full tensor is dropped.

- TpQwen3_5Attention: column-parallel `q_proj` (the widened
  `2 * num_heads * head_dim` output, including the gate half — shards
  along the head axis so both query AND gate halves stay consistent
  per rank), `k_proj`, `v_proj`; row-parallel `o_proj` with AllReduce.
  Otherwise mirrors `tp_qwen3.rs`'s attention.

- TpQwen3_5MLP, TpQwen3_5DecoderLayer (dispatches on layer_types),
  TpQwen3_5Model (with `model.language_model.*` prefix), and
  TpQwen3_5ForCausalLM (with tied or separate `lm_head` at top level).
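
The fused-region slicing described for `in_proj_qkv` and `conv1d.weight`
can be sketched on plain row indices. This is a minimal illustration, not
the commit's `load_fused_qkv_slice_{2d,3d}` code: the function names, the
`head_k_dim` / `head_v_dim` parameters, and the use of a `Vec` in place of
a tensor are all assumptions made for the example.

```rust
use std::ops::Range;

/// Row ranges along dim 0 that one rank takes from a fused
/// `[key_dim, key_dim, value_dim]` tensor.
/// Hypothetical helper: names and parameters are illustrative only.
fn fused_qkv_ranges(
    num_k_heads: usize,
    head_k_dim: usize,
    num_v_heads: usize,
    head_v_dim: usize,
    rank: usize,
    world_size: usize,
) -> Option<[Range<usize>; 3]> {
    if world_size == 0 || rank >= world_size {
        return None;
    }
    // Whole heads must divide evenly across ranks.
    if num_k_heads % world_size != 0 || num_v_heads % world_size != 0 {
        return None;
    }
    let key_dim = num_k_heads * head_k_dim;
    let k_per_rank = (num_k_heads / world_size) * head_k_dim;
    let v_per_rank = (num_v_heads / world_size) * head_v_dim;

    // Each region is sliced independently, starting at its own offset.
    let first = rank * k_per_rank;
    let second = key_dim + rank * k_per_rank;
    let third = 2 * key_dim + rank * v_per_rank;
    Some([
        first..first + k_per_rank,
        second..second + k_per_rank,
        third..third + v_per_rank,
    ])
}

/// Gather the three per-rank regions and concatenate them in order,
/// standing in for the narrow + cat step on a real tensor.
fn slice_fused_rows<T: Clone>(rows: &[T], ranges: &[Range<usize>; 3]) -> Vec<T> {
    ranges
        .iter()
        .flat_map(|r| rows[r.clone()].iter().cloned())
        .collect()
}
```

With two K-heads and four V-heads of width 2 on two ranks, rank 0 takes
rows 0..2, 4..6 and 8..12 of the 16 fused rows, whereas a uniform split
would have handed it rows 0..8 (both key regions and none of the value
region).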

Dispatch wiring:

- New `tp::TpLeaderModel` enum holds either Qwen3 or Qwen3_5 variant.
  `WorkerPool::load_dense_shard` now dispatches on `model_type` from
  the config JSON and returns `Arc<Mutex<TpLeaderModel>>`. The two
  downstream methods (`generate_step`, `clear_kv_cache`) take this
  enum; the inner `forward` and `clear_kv_cache` calls dispatch through
  the enum's public methods. Adding another TP architecture later is
  one more enum variant plus its match arms.

- Worker side gets a parallel `WorkerModel` enum + dispatch in
  `handle_load_dense_shard`, branching on the same `model_type`.

- Harness gate `TP_SUPPORTED_MODEL_TYPES` now `["qwen3", "qwen3_5"]`.
  `TpLoadedModel.leader_model` retyped to the enum.
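
The dispatch shape above can be sketched without CUDA or candle. The stub
structs, the `LeaderModel` name, and the token-counting `forward` are
assumptions made for illustration; the real variants wrap
`TpQwen3ForCausalLM` and `TpQwen3_5ForCausalLM` and return tensors.

```rust
/// Stand-ins for the real TP models, which hold weights and an NCCL comm.
struct Qwen3Stub {
    cached: usize,
}
struct Qwen3NextStub {
    cached: usize,
}

/// Hypothetical mirror of the leader-side enum: one variant per architecture.
enum LeaderModel {
    Qwen3(Qwen3Stub),
    Qwen3_5(Qwen3NextStub),
}

impl LeaderModel {
    /// Pick the variant from the `model_type` string found in the config JSON.
    fn load(model_type: &str) -> Result<Self, String> {
        match model_type {
            "qwen3" => Ok(LeaderModel::Qwen3(Qwen3Stub { cached: 0 })),
            "qwen3_5" => Ok(LeaderModel::Qwen3_5(Qwen3NextStub { cached: 0 })),
            other => Err(format!(
                "unsupported model_type '{}' (supported: qwen3, qwen3_5)",
                other
            )),
        }
    }

    /// Stub forward: appends tokens to the "cache" and returns its length.
    fn forward(&mut self, tokens: &[u32]) -> usize {
        match self {
            LeaderModel::Qwen3(m) => {
                m.cached += tokens.len();
                m.cached
            }
            LeaderModel::Qwen3_5(m) => {
                m.cached += tokens.len();
                m.cached
            }
        }
    }

    /// Reset per-variant state; callers never see the concrete type.
    fn clear_kv_cache(&mut self) {
        match self {
            LeaderModel::Qwen3(m) => m.cached = 0,
            LeaderModel::Qwen3_5(m) => m.cached = 0,
        }
    }
}
```

Callers hold only `LeaderModel`, so a third architecture would touch
`load` and each `match`, and the compiler flags any arm that is missed.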

Helpers in `arch/qwen3_5/linear_attn.rs`:
- `softplus` and `repeat_interleave` made `pub(crate)` so the TP
  module reuses them rather than duplicating.
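
For readers unfamiliar with the two shared helpers, here is a scalar and
slice sketch of what functions with these names conventionally compute.
It is an assumption that the crate's versions match these definitions;
they operate on candle tensors, not on `f32` and `Vec` as below.

```rust
/// softplus(x) = ln(1 + e^x), written in the overflow-safe form
/// max(x, 0) + ln(1 + e^{-|x|}).
fn softplus(x: f32) -> f32 {
    x.max(0.0) + (-x.abs()).exp().ln_1p()
}

/// Repeat each element `repeats` times in place:
/// [a, b] with repeats = 2 becomes [a, a, b, b].
fn repeat_interleave<T: Clone>(xs: &[T], repeats: usize) -> Vec<T> {
    xs.iter()
        .flat_map(|x| std::iter::repeat(x.clone()).take(repeats))
        .collect()
}
```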

Reuses unchanged: `Qwen3_5RmsNorm` (replicated weight), the gated
`Qwen3_5RmsNormGated` tail, `l2norm`, the `RotaryEmbedding` (partial
RoPE with `partial_rotary_factor` already correct).

CPU build + clippy + 32 lib tests pass; `cargo clippy --features cuda`
also clean inside the patched runner container.

One in-flight risk to call out: tensor names. For full-attention
layers the per-layer prefix is `model.language_model.layers.<i>.self_attn.*`
and for linear-attention layers `model.language_model.layers.<i>.linear_attn.*`
— the same as the single-GPU path. lm_head sits at the top level (not
under `language_model`) — consistent with the single-GPU path that
validated against Qwen3.5-0.8B.
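
The naming scheme above can be written down as a small function, which
makes the expected prefixes easy to compare against a checkpoint's tensor
list. `LayerKind`, `layer_prefix`, and `LM_HEAD_PREFIX` are hypothetical
names for this sketch, not identifiers from the commit.

```rust
/// Which attention block a decoder layer uses (illustrative enum).
#[derive(Clone, Copy, Debug, PartialEq)]
enum LayerKind {
    FullAttention,
    LinearAttention,
}

/// Per-layer tensor prefix, following the names quoted in the message.
fn layer_prefix(index: usize, kind: LayerKind) -> String {
    let block = match kind {
        LayerKind::FullAttention => "self_attn",
        LayerKind::LinearAttention => "linear_attn",
    };
    format!("model.language_model.layers.{}.{}", index, block)
}

/// The LM head sits at the top level, outside `model.language_model`.
const LM_HEAD_PREFIX: &str = "lm_head";
```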

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 22:02:42 +03:00
parent 495d3f7c05
commit 95dc8745eb
5 changed files with 1305 additions and 63 deletions


@@ -22,6 +22,7 @@ pub mod nccl_state;
pub mod rpc;
pub mod tp_linear;
pub mod tp_qwen3;
pub mod tp_qwen3_5;
pub mod worker;
use anyhow::{Context, Result};
@@ -32,6 +33,49 @@ use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use rpc::{WorkerRequest, WorkerResponse};
/// Leader-side handle for any TP-loaded model. The pool's
/// `load_dense_shard` dispatches on `config.json#/model_type` to build
/// the right variant; downstream callers (the harness's
/// `chat_completion_tp` path, `generate_step`, `clear_kv_cache`,
/// `unload_model`) all hold this enum and let the variant dispatch
/// determine the concrete forward.
///
/// Variants gated on `cuda` because the underlying TP models hold
/// `Arc<cudarc::nccl::Comm>` references — irrelevant on CPU builds.
#[cfg(feature = "cuda")]
pub enum TpLeaderModel {
Qwen3(tp_qwen3::TpQwen3ForCausalLM),
Qwen3_5(tp_qwen3_5::TpQwen3_5ForCausalLM),
}
#[cfg(feature = "cuda")]
impl TpLeaderModel {
pub fn forward(
&mut self,
input: &candle_core::Tensor,
offset: usize,
) -> candle_core::Result<candle_core::Tensor> {
match self {
TpLeaderModel::Qwen3(m) => m.forward(input, offset),
TpLeaderModel::Qwen3_5(m) => m.forward(input, offset),
}
}
pub fn clear_kv_cache(&mut self) {
match self {
TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
TpLeaderModel::Qwen3_5(m) => m.clear_kv_cache(),
}
}
pub fn device(&self) -> &candle_core::Device {
match self {
TpLeaderModel::Qwen3(m) => m.device(),
TpLeaderModel::Qwen3_5(m) => m.device(),
}
}
}
/// One worker subprocess plus its bidirectional stdio handles.
struct Worker {
rank: u32,
@@ -363,7 +407,7 @@ impl WorkerPool {
safetensors_paths: &[std::path::PathBuf],
leader_device: &candle_core::Device,
dtype: candle_core::DType,
) -> Result<std::sync::Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>> {
) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
use candle_nn::var_builder::ShardedSafeTensors;
use std::sync::Arc;
use tokio::sync::Mutex;
@@ -396,36 +440,56 @@ impl WorkerPool {
.await?;
}
// 2. Build rank 0's shard on the leader. ShardedVarBuilder reads
// only the rank's slice from safetensors — no full-tensor
// materialisation. Runs in spawn_blocking because the
// file-mmap + slice + copy-to-device work is synchronous.
let cfg: super::tp::tp_qwen3::Config =
serde_json::from_str(config_json).context("parse Qwen3 Config JSON for leader load")?;
// 2. Build rank 0's shard on the leader. Dispatch on model_type
// — for `qwen3` we build a `TpQwen3ForCausalLM`, for
// `qwen3_5` (Qwen3-Next, Qwen3.6's architecture) we build
// `TpQwen3_5ForCausalLM`. Both end up wrapped in the
// `TpLeaderModel` enum so downstream callers don't care.
let model_type = serde_json::from_str::<serde_json::Value>(config_json)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let paths_for_leader: Vec<std::path::PathBuf> = safetensors_paths.to_vec();
let device_for_leader = leader_device.clone();
let comm_for_leader = leader_comm;
let model_id_for_log = model_id.to_string();
let leader_model = tokio::task::spawn_blocking(
move || -> Result<super::tp::tp_qwen3::TpQwen3ForCausalLM> {
// SAFETY: same invariant as the single-GPU dense path —
// the HF cache files are treated as immutable while the
// mmap is held.
let vb = unsafe {
ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
.context("build ShardedVarBuilder over safetensors")?
};
let model = super::tp::tp_qwen3::TpQwen3ForCausalLM::load(
&cfg,
&vb,
0,
world_size,
comm_for_leader.into_inner(),
)?;
tracing::info!(rank = 0, model = %model_id_for_log, "loaded TP shard (leader)");
Ok(model)
},
)
let config_json_for_leader = config_json.to_string();
let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
// SAFETY: same invariant as the single-GPU dense path —
// the HF cache files are treated as immutable while the
// mmap is held.
let vb = unsafe {
ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
.context("build ShardedVarBuilder over safetensors")?
};
let comm = comm_for_leader.into_inner();
let loaded = match model_type.as_str() {
"qwen3" => {
let cfg: super::tp::tp_qwen3::Config = serde_json::from_str(&config_json_for_leader)
.context("parse Qwen3 Config JSON for leader load")?;
TpLeaderModel::Qwen3(super::tp::tp_qwen3::TpQwen3ForCausalLM::load(
&cfg, &vb, 0, world_size, comm,
)?)
}
"qwen3_5" => {
let cfg: super::tp::tp_qwen3_5::Config =
serde_json::from_str(&config_json_for_leader)
.context("parse Qwen3-Next Config JSON for leader load")?;
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
cfg, &vb, 0, world_size, comm,
)?)
}
other => anyhow::bail!(
"TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)"
),
};
tracing::info!(rank = 0, model = %model_id_for_log, model_type = %model_type, "loaded TP shard (leader)");
Ok(loaded)
})
.await
.context("leader load task panicked")??;
@@ -463,7 +527,7 @@ impl WorkerPool {
pub async fn generate_step(
&mut self,
model_id: &str,
leader_model: std::sync::Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>,
leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
tokens: Vec<u32>,
offset: usize,
) -> Result<candle_core::Tensor> {
@@ -516,9 +580,7 @@ impl WorkerPool {
pub async fn clear_kv_cache(
&mut self,
model_id: &str,
#[cfg(feature = "cuda")] leader_model: std::sync::Arc<
tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>,
>,
#[cfg(feature = "cuda")] leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
) -> Result<()> {
for w in &mut self.workers {
w.send_only(&WorkerRequest::ClearKvCache {