feat(stage-8c): full-attention layer + decoder + Model + ForCausalLM for qwen3_5
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 39s
CI / Clippy (push) Successful in 2m19s
CI / Test (push) Successful in 4m50s
build-prerelease / Build cortex binary (push) Successful in 4m21s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m41s
build-prerelease / Package cortex RPM (push) Successful in 1m27s
build-prerelease / Build neuron-ampere (push) Successful in 4m58s
build-prerelease / Build neuron-ada (push) Successful in 5m8s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m52s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 58s
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 39s
CI / Clippy (push) Successful in 2m19s
CI / Test (push) Successful in 4m50s
build-prerelease / Build cortex binary (push) Successful in 4m21s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m41s
build-prerelease / Package cortex RPM (push) Successful in 1m27s
build-prerelease / Build neuron-ampere (push) Successful in 4m58s
build-prerelease / Build neuron-ada (push) Successful in 5m8s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m52s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 58s
Completes the single-GPU dense path for Qwen3-Next (Qwen3.6's
architecture). The four new modules wrap the substantive
`linear_attn.rs` (landed previously) with the rest of the
transformer:
- `arch/qwen3_5/rope.rs` — text-side rotary embedding. MRoPE is
simplified to plain RoPE (the three position grids collapse to one
for text-only inference); uses candle's `rope_slow` for the
GLM-style rotate-half rotation.
- `arch/qwen3_5/mlp.rs` — Qwen3_5MLP (SwiGLU: gate/up/down, bias=False).
- `arch/qwen3_5/full_attn.rs` — Qwen3_5Attention with the two
Qwen3-Next quirks:
- `q_proj` widened to `2 * num_heads * head_dim`; second half
sigmoid'd and multiplied into the attention output before `o_proj`.
- q_norm/k_norm use the `(1+w)*x` RmsNorm variant.
- `arch/qwen3_5/decoder.rs` — Qwen3_5DecoderLayer dispatching on
`layer_types[i]` to either Full attention or GatedDeltaNet.
`arch/qwen3_5/mod.rs` gets the real `Qwen3_5Model` (embedding + layer
stack + final norm) and `Qwen3_5ForCausalLM` (model + lm_head). The
forward returns `[B, 1, vocab]` to match `qwen3_dense`; the harness's
`squeeze_to_vocab` handles either shape.
Switch: `candle.rs::load_arch_dense` for `model_type=qwen3_5` now
builds a `ShardedVarBuilder` instead of a plain VarBuilder. The
sharded backend falls through to the unsharded path when
`world_size=1`, so single-GPU load is zero-cost; this lets the
forthcoming `tp_qwen3_5.rs` reuse the same load functions without a
second copy.
Verified: cargo build CPU + --features cuda inside the patched
container; clippy clean on both; 32 lib tests still pass. The
ForCausalLM forward no longer bails — but numerical correctness vs
the Python reference hasn't been validated yet (that's the next
step, with the Tbilisi probe).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
117
crates/neuron/src/harness/arch/qwen3_5/decoder.rs
Normal file
117
crates/neuron/src/harness/arch/qwen3_5/decoder.rs
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
//! Qwen3-Next decoder layer.
|
||||||
|
//!
|
||||||
|
//! Standard pre-norm transformer block (LN → attention → residual →
|
||||||
|
//! LN → MLP → residual) where the attention slot dispatches on the
|
||||||
|
//! per-layer `layer_types[i]` value in the config:
|
||||||
|
//!
|
||||||
|
//! - `"full_attention"` → [`Qwen3_5Attention`] (GQA causal + output
|
||||||
|
//! gate + RoPE + KV cache).
|
||||||
|
//! - `"linear_attention"` → [`GatedDeltaNet`] (recurrent delta rule +
|
||||||
|
//! causal conv + per-head state).
|
||||||
|
//!
|
||||||
|
//! In Qwen3.6-27B every 4th layer is full_attention; the rest are
|
||||||
|
//! linear_attention. `full_attention_interval` in the config is a
|
||||||
|
//! hint; `layer_types` is authoritative.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use super::TextConfig;
|
||||||
|
use super::full_attn::Qwen3_5Attention;
|
||||||
|
use super::linear_attn::GatedDeltaNet;
|
||||||
|
use super::mlp::Qwen3_5MLP;
|
||||||
|
use super::rmsnorm::Qwen3_5RmsNorm;
|
||||||
|
use super::rope::RotaryEmbedding;
|
||||||
|
|
||||||
|
/// One of the two attention flavours sitting in a decoder layer's
|
||||||
|
/// attention slot. Full-attention layers need the rotary table and
|
||||||
|
/// take an attention mask; linear-attention layers carry their own
|
||||||
|
/// recurrent state and ignore the mask.
|
||||||
|
enum AttentionKind {
|
||||||
|
Full(Qwen3_5Attention),
|
||||||
|
Linear(GatedDeltaNet),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Qwen3_5DecoderLayer {
|
||||||
|
input_layernorm: Qwen3_5RmsNorm,
|
||||||
|
post_attention_layernorm: Qwen3_5RmsNorm,
|
||||||
|
mlp: Qwen3_5MLP,
|
||||||
|
attention: AttentionKind,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5DecoderLayer {
|
||||||
|
pub fn load(
|
||||||
|
cfg: &TextConfig,
|
||||||
|
rotary: Arc<RotaryEmbedding>,
|
||||||
|
layer_idx: usize,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let layer_type = cfg
|
||||||
|
.layer_types
|
||||||
|
.get(layer_idx)
|
||||||
|
.map(String::as_str)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"layer_types[{layer_idx}] missing (have {} entries)",
|
||||||
|
cfg.layer_types.len()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let attention = match layer_type {
|
||||||
|
"full_attention" => {
|
||||||
|
AttentionKind::Full(Qwen3_5Attention::load(cfg, rotary, &vb.pp("self_attn"))?)
|
||||||
|
}
|
||||||
|
"linear_attention" => {
|
||||||
|
AttentionKind::Linear(GatedDeltaNet::load(cfg, &vb.pp("linear_attn"))?)
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"unknown layer_type '{other}' for layer {layer_idx} (expected \
|
||||||
|
'full_attention' or 'linear_attention')"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mlp = Qwen3_5MLP::load(cfg, &vb.pp("mlp"))?;
|
||||||
|
let input_layernorm =
|
||||||
|
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
let post_attention_layernorm = Qwen3_5RmsNorm::load(
|
||||||
|
&vb.pp("post_attention_layernorm"),
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
input_layernorm,
|
||||||
|
post_attention_layernorm,
|
||||||
|
mlp,
|
||||||
|
attention,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let h = self.input_layernorm.forward(x)?;
|
||||||
|
let attn_out = match &mut self.attention {
|
||||||
|
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
|
||||||
|
// Linear attention ignores attn_mask + offset; its causal
|
||||||
|
// structure is baked into the recurrent state lifecycle.
|
||||||
|
AttentionKind::Linear(net) => net.forward(&h)?,
|
||||||
|
};
|
||||||
|
let x = (x + attn_out)?;
|
||||||
|
let h2 = self.post_attention_layernorm.forward(&x)?;
|
||||||
|
let h2 = self.mlp.forward(&h2)?;
|
||||||
|
x + h2
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
match &mut self.attention {
|
||||||
|
AttentionKind::Full(attn) => attn.clear_kv_cache(),
|
||||||
|
AttentionKind::Linear(net) => net.clear_kv_cache(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
179
crates/neuron/src/harness/arch/qwen3_5/full_attn.rs
Normal file
179
crates/neuron/src/harness/arch/qwen3_5/full_attn.rs
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
//! Qwen3-Next's `full_attention` layer.
|
||||||
|
//!
|
||||||
|
//! Standard GQA causal attention with two Qwen3-Next-specific quirks:
|
||||||
|
//!
|
||||||
|
//! 1. **Output gate (`attn_output_gate=True`).** `q_proj` is widened
|
||||||
|
//! to `num_heads * head_dim * 2`. The second half is reshaped to
|
||||||
|
//! `(B, L, num_heads * head_dim)` and fed through a sigmoid; the
|
||||||
|
//! attention output is pointwise-multiplied by this gate before
|
||||||
|
//! `o_proj`. Effectively a per-head per-position attenuation on
|
||||||
|
//! the attention output.
|
||||||
|
//!
|
||||||
|
//! 2. **`(1 + w) * x` RmsNorm** on q and k (see `rmsnorm::Qwen3_5RmsNorm`).
|
||||||
|
//! candle_nn's RmsNorm applies `w * x`; the upstream Qwen3-Next
|
||||||
|
//! checkpoints expect the `(1 + w)` form.
|
||||||
|
//!
|
||||||
|
//! Otherwise: GQA with `num_attention_heads / num_key_value_heads`
|
||||||
|
//! repeat, q_norm + k_norm on the head dim, GLM-style rotary (see
|
||||||
|
//! `rope::RotaryEmbedding`), and the usual causal mask.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::kv_cache::ConcatKvCache;
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
use candle_transformers::utils::repeat_kv;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use super::TextConfig;
|
||||||
|
use super::rmsnorm::Qwen3_5RmsNorm;
|
||||||
|
use super::rope::RotaryEmbedding;
|
||||||
|
|
||||||
|
pub struct Qwen3_5Attention {
|
||||||
|
q_proj: Linear,
|
||||||
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
o_proj: Linear,
|
||||||
|
q_norm: Qwen3_5RmsNorm,
|
||||||
|
k_norm: Qwen3_5RmsNorm,
|
||||||
|
num_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
hidden_size: usize,
|
||||||
|
rotary: Arc<RotaryEmbedding>,
|
||||||
|
kv_cache: ConcatKvCache,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5Attention {
|
||||||
|
pub fn load(
|
||||||
|
cfg: &TextConfig,
|
||||||
|
rotary: Arc<RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let head_dim = cfg.head_dim;
|
||||||
|
let num_heads = cfg.num_attention_heads;
|
||||||
|
let num_kv_heads = cfg.num_key_value_heads;
|
||||||
|
if num_kv_heads == 0 || !num_heads.is_multiple_of(num_kv_heads) {
|
||||||
|
anyhow::bail!(
|
||||||
|
"num_attention_heads ({num_heads}) must be a positive multiple of \
|
||||||
|
num_key_value_heads ({num_kv_heads})"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let num_kv_groups = num_heads / num_kv_heads;
|
||||||
|
|
||||||
|
// q_proj is 2x wide: the extra `num_heads * head_dim` slice is
|
||||||
|
// the gate (see attn_output_gate notes above).
|
||||||
|
let q_proj = load_linear_no_bias(vb, "q_proj", cfg.hidden_size, num_heads * head_dim * 2)?;
|
||||||
|
let k_proj = load_linear_no_bias(vb, "k_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
|
||||||
|
let v_proj = load_linear_no_bias(vb, "v_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
|
||||||
|
let o_proj = load_linear_no_bias(vb, "o_proj", num_heads * head_dim, cfg.hidden_size)?;
|
||||||
|
|
||||||
|
let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
|
let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
let hidden_size = head_dim * num_heads;
|
||||||
|
let kv_cache = ConcatKvCache::new(2);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
q_norm,
|
||||||
|
k_norm,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
num_kv_groups,
|
||||||
|
head_dim,
|
||||||
|
hidden_size,
|
||||||
|
rotary,
|
||||||
|
kv_cache,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let (b, l, _) = x.dims3()?;
|
||||||
|
|
||||||
|
// 1. q_proj — widened output, split into (query, gate).
|
||||||
|
let q_raw = self
|
||||||
|
.q_proj
|
||||||
|
.forward(x)?
|
||||||
|
.reshape((b, l, self.num_heads, self.head_dim * 2))?;
|
||||||
|
let q = q_raw.narrow(3, 0, self.head_dim)?;
|
||||||
|
let gate = q_raw.narrow(3, self.head_dim, self.head_dim)?;
|
||||||
|
// Flatten the gate's head dim back into hidden_size for the
|
||||||
|
// post-attention pointwise multiply.
|
||||||
|
let gate = gate
|
||||||
|
.contiguous()?
|
||||||
|
.reshape((b, l, self.num_heads * self.head_dim))?;
|
||||||
|
|
||||||
|
// 2. q_norm + k_norm + reshape to (B, H, L, D).
|
||||||
|
let q = self.q_norm.forward(&q.contiguous()?)?;
|
||||||
|
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D)
|
||||||
|
|
||||||
|
let k = self
|
||||||
|
.k_proj
|
||||||
|
.forward(x)?
|
||||||
|
.reshape((b, l, self.num_kv_heads, self.head_dim))?;
|
||||||
|
let k = self.k_norm.forward(&k.contiguous()?)?;
|
||||||
|
let k = k.transpose(1, 2)?.contiguous()?;
|
||||||
|
|
||||||
|
let v = self
|
||||||
|
.v_proj
|
||||||
|
.forward(x)?
|
||||||
|
.reshape((b, l, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
|
||||||
|
// 3. RoPE on q, k.
|
||||||
|
let (q, k) = self.rotary.apply(&q, &k, offset)?;
|
||||||
|
|
||||||
|
// 4. KV cache.
|
||||||
|
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||||
|
|
||||||
|
// 5. GQA repeat (cheap shape op).
|
||||||
|
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
|
||||||
|
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
|
||||||
|
|
||||||
|
// 6. Scaled dot-product + causal mask.
|
||||||
|
let scale = 1.0_f64 / (self.head_dim as f64).sqrt();
|
||||||
|
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||||
|
if let Some(m) = attn_mask {
|
||||||
|
scores = scores.broadcast_add(m)?;
|
||||||
|
}
|
||||||
|
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||||
|
let ctx = probs.matmul(&v)?; // (B, H, L, D)
|
||||||
|
|
||||||
|
// 7. Reshape back, apply the output gate, project.
|
||||||
|
let ctx = ctx
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?
|
||||||
|
.reshape((b, l, self.hidden_size))?;
|
||||||
|
let gate_sig = candle_nn::ops::sigmoid(&gate)?;
|
||||||
|
let gated = (ctx * gate_sig)?;
|
||||||
|
self.o_proj.forward(&gated)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_linear_no_bias(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
name: &str,
|
||||||
|
in_dim: usize,
|
||||||
|
out_dim: usize,
|
||||||
|
) -> Result<Linear> {
|
||||||
|
let weight = vb
|
||||||
|
.pp(name)
|
||||||
|
.get((out_dim, in_dim), "weight")
|
||||||
|
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
|
||||||
|
Ok(Linear::new(weight, None))
|
||||||
|
}
|
||||||
53
crates/neuron/src/harness/arch/qwen3_5/mlp.rs
Normal file
53
crates/neuron/src/harness/arch/qwen3_5/mlp.rs
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
//! SwiGLU MLP block for Qwen3-Next.
|
||||||
|
//!
|
||||||
|
//! Identical to plain Qwen3's MLP: `down(silu(gate(x)) * up(x))` with
|
||||||
|
//! no bias on any of the three projections.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
|
||||||
|
use super::TextConfig;
|
||||||
|
|
||||||
|
pub struct Qwen3_5MLP {
|
||||||
|
gate_proj: Linear,
|
||||||
|
up_proj: Linear,
|
||||||
|
down_proj: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5MLP {
|
||||||
|
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let h = cfg.hidden_size;
|
||||||
|
let i = cfg.intermediate_size;
|
||||||
|
let gate_proj = load_linear_no_bias(vb, "gate_proj", h, i)?;
|
||||||
|
let up_proj = load_linear_no_bias(vb, "up_proj", h, i)?;
|
||||||
|
let down_proj = load_linear_no_bias(vb, "down_proj", i, h)?;
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj,
|
||||||
|
up_proj,
|
||||||
|
down_proj,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Qwen3_5MLP {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let lhs = candle_nn::ops::silu(&self.gate_proj.forward(x)?)?;
|
||||||
|
let rhs = self.up_proj.forward(x)?;
|
||||||
|
self.down_proj.forward(&(lhs * rhs)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_linear_no_bias(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
name: &str,
|
||||||
|
in_dim: usize,
|
||||||
|
out_dim: usize,
|
||||||
|
) -> Result<Linear> {
|
||||||
|
let weight = vb
|
||||||
|
.pp(name)
|
||||||
|
.get((out_dim, in_dim), "weight")
|
||||||
|
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
|
||||||
|
Ok(Linear::new(weight, None))
|
||||||
|
}
|
||||||
@@ -11,66 +11,77 @@
|
|||||||
//!
|
//!
|
||||||
//! ## Status
|
//! ## Status
|
||||||
//!
|
//!
|
||||||
//! **Scaffold only.** `Config` deserialisation is real (so the dispatch
|
//! **Single-GPU dense path is real**. Both attention flavours
|
||||||
//! in `candle.rs::load_arch_dense` can route based on `model_type`
|
//! (`full_attention` with the output-gated GQA causal attention and
|
||||||
//! and the operator's diagnostic surfaces "qwen3_5" in the supported
|
//! `linear_attention` with the Gated DeltaNet recurrent block) are
|
||||||
//! set); the actual forward pass is `unimplemented!()`. Filling this
|
//! implemented. The model loads from upstream safetensors via the
|
||||||
//! in is the substantive Stage 8c work.
|
//! existing `load_arch_dense` dispatch and runs forward end to end.
|
||||||
//!
|
//!
|
||||||
//! ## What the architecture needs (open work)
|
//! Numerical correctness vs the reference Python is **not yet
|
||||||
|
//! validated** — the structural code path is right, weight tensor
|
||||||
|
//! names match the upstream layout, shapes flow through cleanly, but
|
||||||
|
//! the Tbilisi probe (and any other downstream test) is the next
|
||||||
|
//! step. Likely places a bug would surface:
|
||||||
|
//! - Per-rank vs per-token-position offsets in the recurrent delta
|
||||||
|
//! rule (`linear_attn.rs`).
|
||||||
|
//! - Off-by-one in the conv state continuation across decode steps.
|
||||||
|
//! - RoPE phase mismatch from MRoPE simplification (we treat the
|
||||||
|
//! three position grids as collapsed, which is correct only for
|
||||||
|
//! text-only inference).
|
||||||
//!
|
//!
|
||||||
//! Confirmed from `Qwen/Qwen3.6-27B/config.json`:
|
//! ## Submodules
|
||||||
//! - Real hyperparams nested under `text_config: {...}`. The
|
|
||||||
//! architecture is text-side; the multimodal vision tower is
|
|
||||||
//! separate (`image_token_id`, `language_model_only=false`).
|
|
||||||
//! - `hidden_size: 5120`, `head_dim: 256`, `intermediate_size: 17408`,
|
|
||||||
//! `num_attention_heads`, `num_key_value_heads`, etc. — bigger
|
|
||||||
//! head_dim than plain Qwen3.
|
|
||||||
//! - `attn_output_gate: true` — a sigmoid gate multiplied into the
|
|
||||||
//! attention output before the projection. ~10 LoC addition vs the
|
|
||||||
//! plain Qwen3 attention.
|
|
||||||
//! - `layer_types: ["linear_attention", "linear_attention",
|
|
||||||
//! "linear_attention", "full_attention", ...]` with
|
|
||||||
//! `full_attention_interval: 4` — every 4th layer is full
|
|
||||||
//! attention, the rest are linear-attention. The full-attention
|
|
||||||
//! layers shape like a Qwen3 attention; the linear-attention
|
|
||||||
//! layers are the hard part.
|
|
||||||
//!
|
//!
|
||||||
//! ## Linear-attention layer
|
//! - [`rmsnorm`] — `Qwen3_5RmsNorm` (`(1+w)*x` variant), the
|
||||||
|
//! `Qwen3_5RmsNormGated` used after the delta rule, and the
|
||||||
|
//! `l2norm` helper.
|
||||||
|
//! - [`rope`] — text-side rotary embedding (mrope simplified, GLM
|
||||||
|
//! rotate-half).
|
||||||
|
//! - [`mlp`] — SwiGLU MLP (gate/up/down, no bias).
|
||||||
|
//! - [`full_attn`] — `Qwen3_5Attention` with the output-gate
|
||||||
|
//! widening on `q_proj`.
|
||||||
|
//! - [`linear_attn`] — `GatedDeltaNet` recurrent delta-rule block
|
||||||
|
//! (causal depthwise Conv1d → silu → split → L2norm → per-token
|
||||||
|
//! delta rule → RMSNormGated → out_proj).
|
||||||
|
//! - [`decoder`] — `Qwen3_5DecoderLayer` dispatching to one of the
|
||||||
|
//! two attention flavours per layer index.
|
||||||
//!
|
//!
|
||||||
//! Candle has nothing we can reuse — has to be written against the
|
//! ## Open work
|
||||||
//! reference Python in the Qwen3-Next HF repo. Likely Lightning
|
|
||||||
//! Attention-2 (state-space-ish recurrence) given the
|
|
||||||
//! `linear_attention` tag and Qwen3's prior `qwen3-omni` work. Needs:
|
|
||||||
//! - A persistent recurrent state per layer (replaces the explicit
|
|
||||||
//! KV cache for full attention).
|
|
||||||
//! - Per-token update + readout primitives, fused if possible.
|
|
||||||
//! - Numerical-correctness validation against the Python reference
|
|
||||||
//! on a fixed prompt before trusting any output downstream.
|
|
||||||
//!
|
//!
|
||||||
//! ## TP-2 (the immediate motivator)
|
//! - **TP variant.** `harness/tp/tp_qwen3_5.rs` is the next step.
|
||||||
|
//! Sharding strategy diverges by layer type:
|
||||||
|
//! - Full-attention layers: column-parallel q/k/v (including the
|
||||||
|
//! gate half of `q_proj`) + row-parallel `o_proj`, mirroring
|
||||||
|
//! `tp_qwen3.rs`.
|
||||||
|
//! - Linear-attention layers: the recurrent state is per-V-head, so
|
||||||
|
//! V-head-dimension sharding works cleanly — split `num_v_heads`
|
||||||
|
//! across ranks (`num_v_heads / world_size` per rank), shard
|
||||||
|
//! `in_proj_qkv` / `in_proj_z` / `in_proj_b` / `in_proj_a` along
|
||||||
|
//! the V-head dim, and row-parallel `out_proj`. The `A_log` /
|
||||||
|
//! `dt_bias` per-head params shard with the heads.
|
||||||
//!
|
//!
|
||||||
//! Beast's 2x RTX 5090 needs tensor-parallel to fit Qwen3.6-27B.
|
//! - **Chunked delta-rule prefill.** `linear_attn.rs` runs the
|
||||||
//! TP-aware analogue lives at `harness/tp/tp_qwen3_5.rs` (not yet
|
//! per-token recurrent path for prefill too — correct but O(L).
|
||||||
//! created — added alongside the dense impl). Sharding strategy
|
//! Porting `torch_chunk_gated_delta_rule` (chunk_size=64) speeds
|
||||||
//! diverges by layer type:
|
//! prefill substantially with no surface change.
|
||||||
//! - Full-attention layers: column-parallel q/k/v + row-parallel o,
|
|
||||||
//! same as `tp_qwen3.rs`. With `attn_output_gate`, the gate weight
|
|
||||||
//! is also column-parallel (one gate scalar per head).
|
|
||||||
//! - Linear-attention layers: the recurrent state is per-token, not
|
|
||||||
//! per-head, so head-dim sharding doesn't apply. Options are
|
|
||||||
//! (a) replicate the linear-attention layers across ranks (cheap
|
|
||||||
//! but wastes ~half the per-rank VRAM since 3 of every 4 layers
|
|
||||||
//! replicate), or (b) shard along the recurrent-state dimension
|
|
||||||
//! if the formulation allows. Decision deferred until the linear
|
|
||||||
//! attention is actually implemented and profiled.
|
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::{Context, Result};
|
||||||
use candle_core::Tensor;
|
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
||||||
|
use candle_nn::Embedding;
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
pub mod decoder;
|
||||||
|
pub mod full_attn;
|
||||||
pub mod linear_attn;
|
pub mod linear_attn;
|
||||||
|
pub mod mlp;
|
||||||
pub mod rmsnorm;
|
pub mod rmsnorm;
|
||||||
|
pub mod rope;
|
||||||
|
|
||||||
|
use decoder::Qwen3_5DecoderLayer;
|
||||||
|
use rmsnorm::Qwen3_5RmsNorm;
|
||||||
|
use rope::RotaryEmbedding;
|
||||||
|
|
||||||
/// `model_type` we deserialise from `config.json`. Const so the
|
/// `model_type` we deserialise from `config.json`. Const so the
|
||||||
/// dispatch in `candle.rs::load_arch_dense` can pattern-match without
|
/// dispatch in `candle.rs::load_arch_dense` can pattern-match without
|
||||||
@@ -159,41 +170,131 @@ fn default_hidden_act() -> String {
|
|||||||
"silu".into()
|
"silu".into()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stub model. Fields are intentionally empty — filling in the
|
/// Qwen3-Next base transformer (embedding + decoder stack + final
|
||||||
/// concrete architecture is the substantive Stage 8c work. The struct
|
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
|
||||||
/// exists so the `ModelArch::Qwen3_5Dense(_)` variant has a payload
|
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
|
||||||
/// and dispatch wiring compiles end-to-end.
|
/// loaded handle.
|
||||||
///
|
pub struct Qwen3_5Model {
|
||||||
/// To extend: add embed_tokens, decoder layers, final norm, and
|
embed_tokens: Embedding,
|
||||||
/// lm_head fields here; implement `new`, `forward`, `clear_kv_cache`
|
layers: Vec<Qwen3_5DecoderLayer>,
|
||||||
/// in terms of them. Mirror the layout of `qwen3_dense::ModelForCausalLM`
|
norm: Qwen3_5RmsNorm,
|
||||||
/// (in candle-transformers) as a starting point.
|
device: Device,
|
||||||
pub struct Qwen3_5ForCausalLM {
|
dtype: DType,
|
||||||
#[allow(dead_code)]
|
|
||||||
config: Config,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Qwen3_5ForCausalLM {
|
impl Qwen3_5Model {
|
||||||
pub fn new(config: Config, _vb: candle_nn::VarBuilder) -> Result<Self> {
|
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||||
// TODO(stage-8c): build embed_tokens, decoder layers (dispatching
|
let dtype = vb.dtype();
|
||||||
// on layer_types), final RmsNorm, lm_head from the VarBuilder.
|
let device = vb.device().clone();
|
||||||
// For now we accept the construction so the load path can be
|
|
||||||
// exercised end-to-end (config parse + safetensors mmap), and
|
let embed_vb = vb.pp("model.embed_tokens");
|
||||||
// bail at forward time with a clear marker.
|
let embed_weight = embed_vb
|
||||||
Ok(Self { config })
|
.get((cfg.vocab_size, cfg.hidden_size), "weight")
|
||||||
|
.with_context(|| format!("load '{}/weight'", embed_vb.prefix()))?;
|
||||||
|
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
|
||||||
|
|
||||||
|
let rotary = Arc::new(RotaryEmbedding::new(dtype, cfg, &device)?);
|
||||||
|
|
||||||
|
if cfg.layer_types.len() != cfg.num_hidden_layers {
|
||||||
|
anyhow::bail!(
|
||||||
|
"config.text_config.layer_types must have num_hidden_layers ({}) entries; \
|
||||||
|
got {}",
|
||||||
|
cfg.num_hidden_layers,
|
||||||
|
cfg.layer_types.len()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, _input: &Tensor, _offset: usize) -> Result<Tensor> {
|
let vb_l = vb.pp("model.layers");
|
||||||
anyhow::bail!(
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
"Qwen3-Next ({}) forward not implemented yet (Stage 8c, TP-2 motivator)",
|
for i in 0..cfg.num_hidden_layers {
|
||||||
self.config.model_type
|
layers.push(Qwen3_5DecoderLayer::load(
|
||||||
)
|
cfg,
|
||||||
|
rotary.clone(),
|
||||||
|
i,
|
||||||
|
&vb_l.pp(i),
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
let norm = Qwen3_5RmsNorm::load(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_weight(&self) -> &Tensor {
|
||||||
|
self.embed_tokens.embeddings()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn clear_kv_cache(&mut self) {
|
pub fn clear_kv_cache(&mut self) {
|
||||||
// No-op for the stub. The real impl needs a `clear_kv_cache`
|
for l in &mut self.layers {
|
||||||
// that resets the per-layer KV cache (full-attention layers)
|
l.clear_kv_cache();
|
||||||
// and the per-layer recurrent state (linear-attention layers).
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let minf = f32::NEG_INFINITY;
|
||||||
|
let mask: Vec<_> = (0..tgt)
|
||||||
|
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let (b, l) = input.dims2()?;
|
||||||
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
|
// Causal mask only needed for L > 1 prefill; full-attention
|
||||||
|
// layers consume it via broadcast_add. Linear-attention layers
|
||||||
|
// ignore the mask.
|
||||||
|
let causal = if l == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.causal_mask(b, l, offset)?)
|
||||||
|
};
|
||||||
|
for layer in &mut self.layers {
|
||||||
|
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||||
|
}
|
||||||
|
self.norm.forward(&h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Qwen3_5ForCausalLM {
|
||||||
|
base: Qwen3_5Model,
|
||||||
|
lm_head: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5ForCausalLM {
|
||||||
|
pub fn new(config: Config, vb: ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let cfg = &config.text_config;
|
||||||
|
let base = Qwen3_5Model::load(cfg, &vb)?;
|
||||||
|
let lm_head = if cfg.tie_word_embeddings {
|
||||||
|
Linear::new(base.embed_weight().clone(), None)
|
||||||
|
} else {
|
||||||
|
let weight = vb
|
||||||
|
.pp("lm_head")
|
||||||
|
.get((cfg.vocab_size, cfg.hidden_size), "weight")
|
||||||
|
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
|
||||||
|
Linear::new(weight, None)
|
||||||
|
};
|
||||||
|
Ok(Self { base, lm_head })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
|
||||||
|
/// the last position, shape `(B, 1, vocab_size)` — same contract
|
||||||
|
/// as `qwen3::ModelForCausalLM::forward` so the harness's
|
||||||
|
/// `squeeze_to_vocab` helper handles both uniformly.
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let (_, l) = input.dims2()?;
|
||||||
|
let hidden = self.base.forward(input, offset)?;
|
||||||
|
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.base.clear_kv_cache();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
67
crates/neuron/src/harness/arch/qwen3_5/rope.rs
Normal file
67
crates/neuron/src/harness/arch/qwen3_5/rope.rs
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
//! Rotary position embedding for Qwen3-Next's full-attention layers.
|
||||||
|
//!
|
||||||
|
//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the
|
||||||
|
//! reference Python — three position grids interleaved per
|
||||||
|
//! `mrope_section`. For text-only inference all three grids carry the
|
||||||
|
//! same position ids and the interleave is a no-op, so this module
|
||||||
|
//! implements the plain (non-mrope) flavour: the standard inv_freq
|
||||||
|
//! cosine/sine tables driven by `rope_theta` and `head_dim`.
|
||||||
|
//!
|
||||||
|
//! Rotation flavour: **GLM-style** rotate-half (the second half of the
|
||||||
|
//! head dim is negated and swapped into the first). The reference
|
||||||
|
//! Python uses `apply_rotary_pos_emb` with `rotate_half`; candle's
|
||||||
|
//! `rope_slow` is the matching helper.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use candle_core::{DType, Device, Tensor};
|
||||||
|
|
||||||
|
use super::TextConfig;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
|
||||||
|
let dim = cfg.head_dim;
|
||||||
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
|
let inv_freq: Vec<f32> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||||
|
.collect();
|
||||||
|
let n = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
Ok(Self {
|
||||||
|
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||||
|
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply RoPE to q, k.
|
||||||
|
///
|
||||||
|
/// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index
|
||||||
|
/// into the cached cos/sin table — the position of the first token
|
||||||
|
/// in the current step. `candle_nn::rotary_emb::rope_slow` does
|
||||||
|
/// the GLM-style `x*cos + rotate_half(x)*sin` rotation and
|
||||||
|
/// internally `cat`s cos/sin with themselves along the last dim,
|
||||||
|
/// so we hand it the `(seq_len, head_dim/2)` slice it expects.
|
||||||
|
pub fn apply(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let (_, _, seq_len, _) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, offset, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, offset, seq_len)?;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -617,12 +617,22 @@ impl CandleHarness {
|
|||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
"qwen3_5" => {
|
"qwen3_5" => {
|
||||||
// Stage 8c scaffold: config parses, model
|
// Qwen3-Next needs a ShardedVarBuilder because its
|
||||||
// constructs, but forward bails. See
|
// load functions use the sharded backend (so they
|
||||||
// `arch/qwen3_5.rs` for the open architecture work.
|
// can be reused unchanged by the future TP variant).
|
||||||
|
// With world_size=1 the backend falls through to
|
||||||
|
// the unsharded path, so there is no per-load cost.
|
||||||
let cfg: super::arch::qwen3_5::Config = serde_json::from_str(&cfg_text)
|
let cfg: super::arch::qwen3_5::Config = serde_json::from_str(&cfg_text)
|
||||||
.context("parse Qwen3-Next (qwen3_5) config.json")?;
|
.context("parse Qwen3-Next (qwen3_5) config.json")?;
|
||||||
let model = super::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, vb)
|
let sharded_vb = unsafe {
|
||||||
|
candle_nn::var_builder::ShardedSafeTensors::var_builder(
|
||||||
|
&safetensors_paths,
|
||||||
|
dtype,
|
||||||
|
&device_for_load,
|
||||||
|
)
|
||||||
|
.context("build ShardedVarBuilder for Qwen3-Next")?
|
||||||
|
};
|
||||||
|
let model = super::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, sharded_vb)
|
||||||
.context("build Qwen3-Next dense model")?;
|
.context("build Qwen3-Next dense model")?;
|
||||||
Ok(ModelArch::Qwen3_5Dense(model))
|
Ok(ModelArch::Qwen3_5Dense(model))
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user