feat(stage-8c): scaffold qwen3_5 (Qwen3.6) — dispatch + stubs + TP gate
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 30s
CI / Format (push) Successful in 38s
CI / Clippy (push) Successful in 2m14s
CI / Test (push) Successful in 4m29s
build-prerelease / Build neuron-blackwell (push) Successful in 3m39s
build-prerelease / Build cortex binary (push) Successful in 4m17s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m31s
build-prerelease / Build neuron-ampere (push) Successful in 5m13s
build-prerelease / Build neuron-ada (push) Successful in 5m1s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m6s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m50s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m14s
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 30s
CI / Format (push) Successful in 38s
CI / Clippy (push) Successful in 2m14s
CI / Test (push) Successful in 4m29s
build-prerelease / Build neuron-blackwell (push) Successful in 3m39s
build-prerelease / Build cortex binary (push) Successful in 4m17s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m31s
build-prerelease / Build neuron-ampere (push) Successful in 5m13s
build-prerelease / Build neuron-ada (push) Successful in 5m1s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m6s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m50s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m14s
Lays the wiring for the top-priority TP-2 target without doing the substantive architecture work yet. After this commit, attempting to load a Qwen3.6 (`model_type = "qwen3_5"`) model: - Passes config.json parse — the real upstream shape (text_config wrapper, layer_types, attn_output_gate, head_dim=256, etc.) round- trips through a typed Config (unit test included). - Constructs a placeholder Qwen3_5ForCausalLM, attaches it to a ModelArch::Qwen3_5Dense variant, registers it in the loaded set. - Fails on the first inference forward with a clear "Qwen3-Next forward not implemented yet (Stage 8c, TP-2 motivator)" — the point where the real architecture work begins. New layout: - `harness/arch/` for custom architectures candle-transformers doesn't ship. Each architecture is one module: Config + ForCausalLM + impl. - `harness/arch/qwen3_5.rs` — the scaffold. Heavy doc comments on the open work: layer_types dispatch (full_attention vs linear_attention, the latter being the hard part with no candle precedent), attn_output_gate, text_config nesting, recurrent state lifecycle. - DENSE_SUPPORTED_MODEL_TYPES adds "qwen3_5"; load_arch_dense gains a branch that constructs the stub. TP-side gate: - New `check_tp_arch_supported`: even though Llama / Qwen3 MoE pass the single-GPU dense check (DENSE_SUPPORTED_MODEL_TYPES), the worker pool's `load_dense_shard` reconstructs the config as Qwen3 on every rank — silently misrouting a non-Qwen3 dense load through it would surface as a cryptic per-rank deserialise error. - TP_SUPPORTED_MODEL_TYPES = ["qwen3"] (cuda-gated). Anything else bails *before* the worker pool spawns and NCCL handshake costs are paid, with a marker pointing at the `tp_<family>.rs` module a contributor would need to add. qwen3_5 specifically lands here until its architecture is real. The naming choice: keep "qwen3_5" from the model's own config.json rather than mistralrs's "qwen3_next" — the latter ages poorly the moment Qwen ship another architecture revision. Unit tests: 2 new for qwen3_5 (config deserialise + dispatch gate); the previously-rejecting test for qwen3_5 swapped to a fictional arch so it stays meaningful as the supported set grows. 26 lib tests pass; cargo clippy CPU + --features cuda both clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
23
crates/neuron/src/harness/arch/mod.rs
Normal file
23
crates/neuron/src/harness/arch/mod.rs
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
//! Custom architecture implementations.
|
||||||
|
//!
|
||||||
|
//! When candle-transformers ships a model family unchanged
|
||||||
|
//! (`models::llama`, `models::qwen3`, `models::qwen3_moe`, etc.), the
|
||||||
|
//! handler in `harness/candle.rs` just wraps the upstream type in a
|
||||||
|
//! `ModelArch` variant.
|
||||||
|
//!
|
||||||
|
//! When candle has nothing for the architecture and we have to write
|
||||||
|
//! it from scratch — Qwen3-Next / Qwen3.6 (`qwen3_5`) being the
|
||||||
|
//! motivating example — the implementation lands here, one file per
|
||||||
|
//! architecture.
|
||||||
|
//!
|
||||||
|
//! Each architecture module is expected to expose:
|
||||||
|
//! - A `Config` type deserialised from the model's `config.json`
|
||||||
|
//! (some architectures nest the real hyperparams under `text_config`,
|
||||||
|
//! in which case the module owns the unwrapping).
|
||||||
|
//! - A `ForCausalLM` struct with `new`, `forward(&mut self, x, offset)
|
||||||
|
//! -> Result<Tensor>`, and `clear_kv_cache(&mut self)`.
|
||||||
|
//!
|
||||||
|
//! TP-aware analogues live in `harness/tp/tp_<family>.rs` and follow
|
||||||
|
//! the pattern set by `tp_qwen3.rs`.
|
||||||
|
|
||||||
|
pub mod qwen3_5;
|
||||||
207
crates/neuron/src/harness/arch/qwen3_5.rs
Normal file
207
crates/neuron/src/harness/arch/qwen3_5.rs
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
//! Qwen3-Next (`model_type = "qwen3_5"`) architecture — Qwen3.6's
|
||||||
|
//! upstream architecture revision.
|
||||||
|
//!
|
||||||
|
//! ## Naming
|
||||||
|
//!
|
||||||
|
//! The model release this targets is `Qwen/Qwen3.6-*` but the
|
||||||
|
//! architecture name in HuggingFace's `config.json` is `qwen3_5`.
|
||||||
|
//! mistralrs calls the same architecture `qwen3_next`; that label
|
||||||
|
//! ages poorly the next time Qwen ship a new arch, so we key on the
|
||||||
|
//! canonical `qwen3_5` from the model's own config.
|
||||||
|
//!
|
||||||
|
//! ## Status
|
||||||
|
//!
|
||||||
|
//! **Scaffold only.** `Config` deserialisation is real (so the dispatch
|
||||||
|
//! in `candle.rs::load_arch_dense` can route based on `model_type`
|
||||||
|
//! and the operator's diagnostic surfaces "qwen3_5" in the supported
|
||||||
|
//! set); the actual forward pass is `unimplemented!()`. Filling this
|
||||||
|
//! in is the substantive Stage 8c work.
|
||||||
|
//!
|
||||||
|
//! ## What the architecture needs (open work)
|
||||||
|
//!
|
||||||
|
//! Confirmed from `Qwen/Qwen3.6-27B/config.json`:
|
||||||
|
//! - Real hyperparams nested under `text_config: {...}`. The
|
||||||
|
//! architecture is text-side; the multimodal vision tower is
|
||||||
|
//! separate (`image_token_id`, `language_model_only=false`).
|
||||||
|
//! - `hidden_size: 5120`, `head_dim: 256`, `intermediate_size: 17408`,
|
||||||
|
//! `num_attention_heads`, `num_key_value_heads`, etc. — bigger
|
||||||
|
//! head_dim than plain Qwen3.
|
||||||
|
//! - `attn_output_gate: true` — a sigmoid gate multiplied into the
|
||||||
|
//! attention output before the projection. ~10 LoC addition vs the
|
||||||
|
//! plain Qwen3 attention.
|
||||||
|
//! - `layer_types: ["linear_attention", "linear_attention",
|
||||||
|
//! "linear_attention", "full_attention", ...]` with
|
||||||
|
//! `full_attention_interval: 4` — every 4th layer is full
|
||||||
|
//! attention, the rest are linear-attention. The full-attention
|
||||||
|
//! layers shape like a Qwen3 attention; the linear-attention
|
||||||
|
//! layers are the hard part.
|
||||||
|
//!
|
||||||
|
//! ## Linear-attention layer
|
||||||
|
//!
|
||||||
|
//! Candle has nothing we can reuse — has to be written against the
|
||||||
|
//! reference Python in the Qwen3-Next HF repo. Likely Lightning
|
||||||
|
//! Attention-2 (state-space-ish recurrence) given the
|
||||||
|
//! `linear_attention` tag and Qwen3's prior `qwen3-omni` work. Needs:
|
||||||
|
//! - A persistent recurrent state per layer (replaces the explicit
|
||||||
|
//! KV cache for full attention).
|
||||||
|
//! - Per-token update + readout primitives, fused if possible.
|
||||||
|
//! - Numerical-correctness validation against the Python reference
|
||||||
|
//! on a fixed prompt before trusting any output downstream.
|
||||||
|
//!
|
||||||
|
//! ## TP-2 (the immediate motivator)
|
||||||
|
//!
|
||||||
|
//! Beast's 2x RTX 5090 needs tensor-parallel to fit Qwen3.6-27B.
|
||||||
|
//! TP-aware analogue lives at `harness/tp/tp_qwen3_5.rs` (not yet
|
||||||
|
//! created — added alongside the dense impl). Sharding strategy
|
||||||
|
//! diverges by layer type:
|
||||||
|
//! - Full-attention layers: column-parallel q/k/v + row-parallel o,
|
||||||
|
//! same as `tp_qwen3.rs`. With `attn_output_gate`, the gate weight
|
||||||
|
//! is also column-parallel (one gate scalar per head).
|
||||||
|
//! - Linear-attention layers: the recurrent state is per-token, not
|
||||||
|
//! per-head, so head-dim sharding doesn't apply. Options are
|
||||||
|
//! (a) replicate the linear-attention layers across ranks (cheap
|
||||||
|
//! but wastes ~half the per-rank VRAM since 3 of every 4 layers
|
||||||
|
//! replicate), or (b) shard along the recurrent-state dimension
|
||||||
|
//! if the formulation allows. Decision deferred until the linear
|
||||||
|
//! attention is actually implemented and profiled.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use candle_core::Tensor;
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
/// `model_type` we deserialise from `config.json`. Const so the
|
||||||
|
/// dispatch in `candle.rs::load_arch_dense` can pattern-match without
|
||||||
|
/// magic strings.
|
||||||
|
pub const MODEL_TYPE: &str = "qwen3_5";
|
||||||
|
|
||||||
|
/// Top-level shape of Qwen3-Next's `config.json`. The real
|
||||||
|
/// hyperparameters live in `text_config`; the rest is multimodal /
|
||||||
|
/// tokeniser glue we don't need for the language-model forward.
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
/// Always `"qwen3_5"` for this architecture. Kept on the struct
|
||||||
|
/// so the (eventual) dispatch / logging code can show it without
|
||||||
|
/// re-parsing the JSON.
|
||||||
|
pub model_type: String,
|
||||||
|
/// The text-side hyperparameters. Everything we actually need.
|
||||||
|
pub text_config: TextConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
|
||||||
|
/// but with the extras Qwen3-Next adds (`attn_output_gate`,
|
||||||
|
/// `layer_types`, `full_attention_interval`, larger `head_dim`).
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct TextConfig {
|
||||||
|
pub vocab_size: usize,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
pub num_key_value_heads: usize,
|
||||||
|
pub head_dim: usize,
|
||||||
|
pub max_position_embeddings: usize,
|
||||||
|
pub rope_theta: f64,
|
||||||
|
pub rms_norm_eps: f64,
|
||||||
|
#[serde(default)]
|
||||||
|
pub tie_word_embeddings: bool,
|
||||||
|
|
||||||
|
/// New in Qwen3-Next: a sigmoid gate multiplied into the attention
|
||||||
|
/// output before the o_proj. The Python reference applies it
|
||||||
|
/// pointwise after softmax+matmul.
|
||||||
|
#[serde(default)]
|
||||||
|
pub attn_output_gate: bool,
|
||||||
|
|
||||||
|
/// One entry per decoder layer; values are `"full_attention"` or
|
||||||
|
/// `"linear_attention"`. Length must equal `num_hidden_layers`.
|
||||||
|
/// `full_attention_interval` is a derived hint (every 4th layer
|
||||||
|
/// by default) — `layer_types` is authoritative.
|
||||||
|
#[serde(default)]
|
||||||
|
pub layer_types: Vec<String>,
|
||||||
|
|
||||||
|
/// Hint for the layer-type pattern (defaults to 4). Kept for
|
||||||
|
/// logging / validation; the forward dispatches on `layer_types`.
|
||||||
|
#[serde(default)]
|
||||||
|
pub full_attention_interval: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stub model. Fields are intentionally empty — filling in the
|
||||||
|
/// concrete architecture is the substantive Stage 8c work. The struct
|
||||||
|
/// exists so the `ModelArch::Qwen3_5Dense(_)` variant has a payload
|
||||||
|
/// and dispatch wiring compiles end-to-end.
|
||||||
|
///
|
||||||
|
/// To extend: add embed_tokens, decoder layers, final norm, and
|
||||||
|
/// lm_head fields here; implement `new`, `forward`, `clear_kv_cache`
|
||||||
|
/// in terms of them. Mirror the layout of `qwen3_dense::ModelForCausalLM`
|
||||||
|
/// (in candle-transformers) as a starting point.
|
||||||
|
pub struct Qwen3_5ForCausalLM {
|
||||||
|
#[allow(dead_code)]
|
||||||
|
config: Config,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5ForCausalLM {
|
||||||
|
pub fn new(config: Config, _vb: candle_nn::VarBuilder) -> Result<Self> {
|
||||||
|
// TODO(stage-8c): build embed_tokens, decoder layers (dispatching
|
||||||
|
// on layer_types), final RmsNorm, lm_head from the VarBuilder.
|
||||||
|
// For now we accept the construction so the load path can be
|
||||||
|
// exercised end-to-end (config parse + safetensors mmap), and
|
||||||
|
// bail at forward time with a clear marker.
|
||||||
|
Ok(Self { config })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, _input: &Tensor, _offset: usize) -> Result<Tensor> {
|
||||||
|
anyhow::bail!(
|
||||||
|
"Qwen3-Next ({}) forward not implemented yet (Stage 8c, TP-2 motivator)",
|
||||||
|
self.config.model_type
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
// No-op for the stub. The real impl needs a `clear_kv_cache`
|
||||||
|
// that resets the per-layer KV cache (full-attention layers)
|
||||||
|
// and the per-layer recurrent state (linear-attention layers).
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
/// Confirms we can deserialise the real upstream config shape.
|
||||||
|
/// Sample taken from `Qwen/Qwen3.6-27B/config.json`, trimmed to
|
||||||
|
/// the fields the architecture cares about.
|
||||||
|
#[test]
|
||||||
|
fn config_deserialises_the_real_qwen3_6_shape() {
|
||||||
|
let raw = r#"{
|
||||||
|
"architectures": ["Qwen3_5ForConditionalGeneration"],
|
||||||
|
"model_type": "qwen3_5",
|
||||||
|
"image_token_id": 248056,
|
||||||
|
"language_model_only": false,
|
||||||
|
"text_config": {
|
||||||
|
"vocab_size": 248064,
|
||||||
|
"hidden_size": 5120,
|
||||||
|
"intermediate_size": 17408,
|
||||||
|
"num_hidden_layers": 64,
|
||||||
|
"num_attention_heads": 64,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"head_dim": 256,
|
||||||
|
"max_position_embeddings": 32768,
|
||||||
|
"rope_theta": 5000000.0,
|
||||||
|
"rms_norm_eps": 1e-6,
|
||||||
|
"tie_word_embeddings": false,
|
||||||
|
"attn_output_gate": true,
|
||||||
|
"full_attention_interval": 4,
|
||||||
|
"layer_types": [
|
||||||
|
"linear_attention", "linear_attention",
|
||||||
|
"linear_attention", "full_attention"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}"#;
|
||||||
|
let cfg: Config = serde_json::from_str(raw).expect("parse Qwen3.6 config");
|
||||||
|
assert_eq!(cfg.model_type, "qwen3_5");
|
||||||
|
assert_eq!(cfg.text_config.hidden_size, 5120);
|
||||||
|
assert_eq!(cfg.text_config.head_dim, 256);
|
||||||
|
assert!(cfg.text_config.attn_output_gate);
|
||||||
|
assert_eq!(cfg.text_config.full_attention_interval, Some(4));
|
||||||
|
assert_eq!(cfg.text_config.layer_types.len(), 4);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -126,6 +126,12 @@ pub enum ModelArch {
|
|||||||
// than the others (clippy::large_enum_variant).
|
// than the others (clippy::large_enum_variant).
|
||||||
LlamaQuantized(QuantizedLlamaWeights),
|
LlamaQuantized(QuantizedLlamaWeights),
|
||||||
LlamaDense(Box<LlamaDense>),
|
LlamaDense(Box<LlamaDense>),
|
||||||
|
|
||||||
|
// Qwen3-Next family (model_type "qwen3_5") — Qwen3.6's
|
||||||
|
// architecture. Stage 8c scaffolding only: dispatch + config parse
|
||||||
|
// are real; forward bails "not implemented yet". See
|
||||||
|
// `arch/qwen3_5.rs` for the open architecture work.
|
||||||
|
Qwen3_5Dense(super::arch::qwen3_5::Qwen3_5ForCausalLM),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ModelArch {
|
impl ModelArch {
|
||||||
@@ -141,6 +147,7 @@ impl ModelArch {
|
|||||||
ModelArch::Qwen3MoeDense(m) => m.forward(input, offset)?,
|
ModelArch::Qwen3MoeDense(m) => m.forward(input, offset)?,
|
||||||
ModelArch::LlamaQuantized(m) => m.forward(input, offset)?,
|
ModelArch::LlamaQuantized(m) => m.forward(input, offset)?,
|
||||||
ModelArch::LlamaDense(m) => m.forward(input, offset)?,
|
ModelArch::LlamaDense(m) => m.forward(input, offset)?,
|
||||||
|
ModelArch::Qwen3_5Dense(m) => m.forward(input, offset)?,
|
||||||
};
|
};
|
||||||
squeeze_to_vocab(&raw)
|
squeeze_to_vocab(&raw)
|
||||||
}
|
}
|
||||||
@@ -164,6 +171,10 @@ impl ModelArch {
|
|||||||
}
|
}
|
||||||
ModelArch::LlamaQuantized(_) => Ok(()),
|
ModelArch::LlamaQuantized(_) => Ok(()),
|
||||||
ModelArch::LlamaDense(m) => m.clear_kv_cache(),
|
ModelArch::LlamaDense(m) => m.clear_kv_cache(),
|
||||||
|
ModelArch::Qwen3_5Dense(m) => {
|
||||||
|
m.clear_kv_cache();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -225,7 +236,7 @@ const REPEAT_LAST_N: usize = 64;
|
|||||||
/// value. New entries land alongside a new `ModelArch` variant + a
|
/// value. New entries land alongside a new `ModelArch` variant + a
|
||||||
/// dispatch branch in `load_arch_dense` (plus, for TP, a parallel
|
/// dispatch branch in `load_arch_dense` (plus, for TP, a parallel
|
||||||
/// pattern in `tp_qwen3.rs`).
|
/// pattern in `tp_qwen3.rs`).
|
||||||
const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["llama", "qwen3", "qwen3_moe"];
|
const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["llama", "qwen3", "qwen3_5", "qwen3_moe"];
|
||||||
|
|
||||||
/// Pre-flight check the operator's `config.json` against the set of
|
/// Pre-flight check the operator's `config.json` against the set of
|
||||||
/// architectures the dense path actually knows how to build. Surfaces
|
/// architectures the dense path actually knows how to build. Surfaces
|
||||||
@@ -275,6 +286,38 @@ fn check_dense_config_supported(config_json: &str, model_id: &str) -> Result<()>
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Architectures the TP path can actually load and run. A subset of
|
||||||
|
/// `DENSE_SUPPORTED_MODEL_TYPES` — the single-GPU path supports more
|
||||||
|
/// families than the TP path because each TP-aware module is a real
|
||||||
|
/// chunk of work (`tp_qwen3.rs` is the only one shipped today).
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
const TP_SUPPORTED_MODEL_TYPES: &[&str] = &["qwen3"];
|
||||||
|
|
||||||
|
/// TP-side counterpart to `check_dense_config_supported`. Gates the
|
||||||
|
/// `load_tp` path on a narrower architecture set: even though the
|
||||||
|
/// single-GPU dense path knows how to build a Llama model, the worker
|
||||||
|
/// pool's `load_dense_shard` reconstructs the config as Qwen3 — there
|
||||||
|
/// is no `tp_llama.rs` yet. Surfacing this as a config-time error
|
||||||
|
/// (before we spawn workers and burn NCCL handshake cost) is much
|
||||||
|
/// kinder than the inevitable per-rank deserialise failure.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn check_tp_arch_supported(config_json: &str, model_id: &str) -> Result<()> {
|
||||||
|
let v: serde_json::Value = serde_json::from_str(config_json)
|
||||||
|
.with_context(|| format!("parse config.json for '{model_id}' as JSON"))?;
|
||||||
|
let model_type = v.get("model_type").and_then(|x| x.as_str()).unwrap_or("");
|
||||||
|
if TP_SUPPORTED_MODEL_TYPES.contains(&model_type) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
anyhow::bail!(
|
||||||
|
"tensor_parallel requested for '{model_id}' (model_type='{model_type}') but \
|
||||||
|
the TP path supports only {TP_SUPPORTED_MODEL_TYPES:?}. Adding a new \
|
||||||
|
TP-aware architecture needs a `harness/tp/tp_<family>.rs` module mirroring \
|
||||||
|
`tp_qwen3.rs` (sharded linears, AllReduce, per-rank head counts) and a \
|
||||||
|
dispatch in `WorkerPool::load_dense_shard`. For models that fit on one \
|
||||||
|
GPU, drop `tensor_parallel` to use the single-GPU dense path."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/// Resolve the effective HuggingFace cache directory for the candle
|
/// Resolve the effective HuggingFace cache directory for the candle
|
||||||
/// harness. Precedence (first hit wins):
|
/// harness. Precedence (first hit wins):
|
||||||
///
|
///
|
||||||
@@ -573,6 +616,16 @@ impl CandleHarness {
|
|||||||
device: device_for_load,
|
device: device_for_load,
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
"qwen3_5" => {
|
||||||
|
// Stage 8c scaffold: config parses, model
|
||||||
|
// constructs, but forward bails. See
|
||||||
|
// `arch/qwen3_5.rs` for the open architecture work.
|
||||||
|
let cfg: super::arch::qwen3_5::Config = serde_json::from_str(&cfg_text)
|
||||||
|
.context("parse Qwen3-Next (qwen3_5) config.json")?;
|
||||||
|
let model = super::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, vb)
|
||||||
|
.context("build Qwen3-Next dense model")?;
|
||||||
|
Ok(ModelArch::Qwen3_5Dense(model))
|
||||||
|
}
|
||||||
other => {
|
other => {
|
||||||
// Defensive: `check_dense_config_supported` already
|
// Defensive: `check_dense_config_supported` already
|
||||||
// gated on the supported set, so this branch is
|
// gated on the supported set, so this branch is
|
||||||
@@ -1045,6 +1098,16 @@ impl CandleHarness {
|
|||||||
// lifecycle on a load that's guaranteed to fail at deserialise
|
// lifecycle on a load that's guaranteed to fail at deserialise
|
||||||
// time inside every rank.
|
// time inside every rank.
|
||||||
check_dense_config_supported(&config_json, &spec.model_id)?;
|
check_dense_config_supported(&config_json, &spec.model_id)?;
|
||||||
|
// The TP path knows how to ship and reconstruct a Qwen3 dense
|
||||||
|
// shard (`tp_qwen3.rs`). Other architectures may pass the
|
||||||
|
// single-GPU `check_dense_config_supported` check above but
|
||||||
|
// have no TP-aware module — bail with a clear marker pointing
|
||||||
|
// at the file the implementer needs to add. This keeps an
|
||||||
|
// operator who sets `tensor_parallel=2` on a Llama model from
|
||||||
|
// silently routing through `pool.load_dense_shard` (which
|
||||||
|
// assumes Qwen3 config shape on the worker side) and producing
|
||||||
|
// a confusing config-parse failure inside every rank.
|
||||||
|
check_tp_arch_supported(&config_json, &spec.model_id)?;
|
||||||
|
|
||||||
// 2. Spawn the worker pool. Rank 0 stays in-process; ranks
|
// 2. Spawn the worker pool. Rank 0 stays in-process; ranks
|
||||||
// 1..tp_size are subprocesses, one per device after the
|
// 1..tp_size are subprocesses, one per device after the
|
||||||
@@ -1704,22 +1767,24 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn check_dense_config_rejects_qwen3_5_with_clear_message() {
|
fn check_dense_config_rejects_unsupported_arch_with_clear_message() {
|
||||||
|
// Use a deliberately-fake model_type so this test stays
|
||||||
|
// meaningful as the supported set grows. (qwen3_5 was the
|
||||||
|
// motivating real example but now lives in the supported set
|
||||||
|
// as a Stage 8c scaffold.)
|
||||||
let cfg = r#"{
|
let cfg = r#"{
|
||||||
"model_type": "qwen3_5",
|
"model_type": "fictional_arch_99",
|
||||||
"architectures": ["Qwen3_5ForConditionalGeneration"],
|
"architectures": ["FictionalArch99ForCausalLM"]
|
||||||
"image_token_id": 248056,
|
|
||||||
"text_config": {"hidden_size": 5120}
|
|
||||||
}"#;
|
}"#;
|
||||||
let err = check_dense_config_supported(cfg, "Qwen/Qwen3.6-27B")
|
let err = check_dense_config_supported(cfg, "Fake/Model-99")
|
||||||
.expect_err("qwen3_5 should be rejected");
|
.expect_err("fictional_arch_99 should be rejected");
|
||||||
let msg = format!("{err}");
|
let msg = format!("{err}");
|
||||||
assert!(
|
assert!(
|
||||||
msg.contains("unsupported model_type 'qwen3_5'"),
|
msg.contains("unsupported model_type 'fictional_arch_99'"),
|
||||||
"message should name the rejected type: {msg}"
|
"message should name the rejected type: {msg}"
|
||||||
);
|
);
|
||||||
assert!(
|
assert!(
|
||||||
msg.contains("Qwen/Qwen3.6-27B"),
|
msg.contains("Fake/Model-99"),
|
||||||
"message should echo the model id: {msg}"
|
"message should echo the model id: {msg}"
|
||||||
);
|
);
|
||||||
assert!(
|
assert!(
|
||||||
@@ -1728,6 +1793,21 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn check_dense_config_accepts_qwen3_5() {
|
||||||
|
// Sanity: Stage 8c scaffold means qwen3_5 deserialises into the
|
||||||
|
// supported set. Forward still bails (covered by tests on the
|
||||||
|
// architecture module itself), but the dispatch gate must let
|
||||||
|
// it through.
|
||||||
|
let cfg = r#"{
|
||||||
|
"model_type": "qwen3_5",
|
||||||
|
"architectures": ["Qwen3_5ForConditionalGeneration"],
|
||||||
|
"text_config": {"hidden_size": 5120}
|
||||||
|
}"#;
|
||||||
|
check_dense_config_supported(cfg, "Qwen/Qwen3.6-27B")
|
||||||
|
.expect("qwen3_5 should be in the supported set as of Stage 8c scaffold");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn check_dense_config_rejects_missing_model_type() {
|
fn check_dense_config_rejects_missing_model_type() {
|
||||||
let cfg = r#"{ "vocab_size": 1234 }"#;
|
let cfg = r#"{ "vocab_size": 1234 }"#;
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
//! Harness registry — maps harness names to trait implementations.
|
//! Harness registry — maps harness names to trait implementations.
|
||||||
|
|
||||||
|
pub mod arch;
|
||||||
pub mod candle;
|
pub mod candle;
|
||||||
pub mod tp;
|
pub mod tp;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user