feat(stage-8c): full-attention layer + decoder + Model + ForCausalLM for qwen3_5
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 39s
CI / Clippy (push) Successful in 2m19s
CI / Test (push) Successful in 4m50s
build-prerelease / Build cortex binary (push) Successful in 4m21s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m41s
build-prerelease / Package cortex RPM (push) Successful in 1m27s
build-prerelease / Build neuron-ampere (push) Successful in 4m58s
build-prerelease / Build neuron-ada (push) Successful in 5m8s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m52s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 58s

Completes the single-GPU dense path for Qwen3-Next (Qwen3.6's
architecture). The four new modules wrap the substantive
`linear_attn.rs` (landed previously) with the rest of the
transformer:

- `arch/qwen3_5/rope.rs` — text-side rotary embedding. MRoPE is
  simplified to plain RoPE (the three position grids collapse to one
  for text-only inference); uses candle's `rope_slow` for the
  GLM-style rotate-half rotation.
- `arch/qwen3_5/mlp.rs` — Qwen3_5MLP (SwiGLU: gate/up/down, bias=False).
- `arch/qwen3_5/full_attn.rs` — Qwen3_5Attention with the two
  Qwen3-Next quirks:
  - `q_proj` widened to `2 * num_heads * head_dim`; second half
    sigmoid'd and multiplied into the attention output before `o_proj`.
  - q_norm/k_norm use the `(1+w)*x` RmsNorm variant.
- `arch/qwen3_5/decoder.rs` — Qwen3_5DecoderLayer dispatching on
  `layer_types[i]` to either Full attention or GatedDeltaNet.

`arch/qwen3_5/mod.rs` gets the real `Qwen3_5Model` (embedding + layer
stack + final norm) and `Qwen3_5ForCausalLM` (model + lm_head). The
forward returns `[B, 1, vocab]` to match `qwen3_dense`; the harness's
`squeeze_to_vocab` handles either shape.

Switch: `candle.rs::load_arch_dense` for `model_type=qwen3_5` now
builds a `ShardedVarBuilder` instead of a plain VarBuilder. The
sharded backend falls through to the unsharded path when
`world_size=1`, so single-GPU load is zero-cost; this lets the
forthcoming `tp_qwen3_5.rs` reuse the same load functions without a
second copy.

Verified: cargo build CPU + --features cuda inside the patched
container; clippy clean on both; 32 lib tests still pass. The
ForCausalLM forward no longer bails — but numerical correctness vs
the Python reference hasn't been validated yet (that's the next
step, with the Tbilisi probe).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-20 15:52:33 +03:00
parent 180274548d
commit e7eb3dab6a
6 changed files with 608 additions and 81 deletions

View File

@@ -0,0 +1,117 @@
//! Qwen3-Next decoder layer.
//!
//! Standard pre-norm transformer block (LN → attention → residual →
//! LN → MLP → residual) where the attention slot dispatches on the
//! per-layer `layer_types[i]` value in the config:
//!
//! - `"full_attention"` → [`Qwen3_5Attention`] (GQA causal + output
//! gate + RoPE + KV cache).
//! - `"linear_attention"` → [`GatedDeltaNet`] (recurrent delta rule +
//! causal conv + per-head state).
//!
//! In Qwen3.6-27B every 4th layer is full_attention; the rest are
//! linear_attention. `full_attention_interval` in the config is a
//! hint; `layer_types` is authoritative.
use anyhow::Result;
use candle_core::{Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
use std::sync::Arc;
use super::TextConfig;
use super::full_attn::Qwen3_5Attention;
use super::linear_attn::GatedDeltaNet;
use super::mlp::Qwen3_5MLP;
use super::rmsnorm::Qwen3_5RmsNorm;
use super::rope::RotaryEmbedding;
/// One of the two attention flavours sitting in a decoder layer's
/// attention slot. Full-attention layers need the rotary table and
/// take an attention mask; linear-attention layers carry their own
/// recurrent state and ignore the mask.
enum AttentionKind {
Full(Qwen3_5Attention),
Linear(GatedDeltaNet),
}
pub struct Qwen3_5DecoderLayer {
input_layernorm: Qwen3_5RmsNorm,
post_attention_layernorm: Qwen3_5RmsNorm,
mlp: Qwen3_5MLP,
attention: AttentionKind,
}
impl Qwen3_5DecoderLayer {
pub fn load(
cfg: &TextConfig,
rotary: Arc<RotaryEmbedding>,
layer_idx: usize,
vb: &ShardedVarBuilder,
) -> Result<Self> {
let layer_type = cfg
.layer_types
.get(layer_idx)
.map(String::as_str)
.ok_or_else(|| {
anyhow::anyhow!(
"layer_types[{layer_idx}] missing (have {} entries)",
cfg.layer_types.len()
)
})?;
let attention = match layer_type {
"full_attention" => {
AttentionKind::Full(Qwen3_5Attention::load(cfg, rotary, &vb.pp("self_attn"))?)
}
"linear_attention" => {
AttentionKind::Linear(GatedDeltaNet::load(cfg, &vb.pp("linear_attn"))?)
}
other => anyhow::bail!(
"unknown layer_type '{other}' for layer {layer_idx} (expected \
'full_attention' or 'linear_attention')"
),
};
let mlp = Qwen3_5MLP::load(cfg, &vb.pp("mlp"))?;
let input_layernorm =
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
let post_attention_layernorm = Qwen3_5RmsNorm::load(
&vb.pp("post_attention_layernorm"),
cfg.hidden_size,
cfg.rms_norm_eps,
)?;
Ok(Self {
input_layernorm,
post_attention_layernorm,
mlp,
attention,
})
}
pub fn forward(
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let h = self.input_layernorm.forward(x)?;
let attn_out = match &mut self.attention {
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
// Linear attention ignores attn_mask + offset; its causal
// structure is baked into the recurrent state lifecycle.
AttentionKind::Linear(net) => net.forward(&h)?,
};
let x = (x + attn_out)?;
let h2 = self.post_attention_layernorm.forward(&x)?;
let h2 = self.mlp.forward(&h2)?;
x + h2
}
pub fn clear_kv_cache(&mut self) {
match &mut self.attention {
AttentionKind::Full(attn) => attn.clear_kv_cache(),
AttentionKind::Linear(net) => net.clear_kv_cache(),
}
}
}

View File

@@ -0,0 +1,179 @@
//! Qwen3-Next's `full_attention` layer.
//!
//! Standard GQA causal attention with two Qwen3-Next-specific quirks:
//!
//! 1. **Output gate (`attn_output_gate=True`).** `q_proj` is widened
//! to `num_heads * head_dim * 2`. The second half is reshaped to
//! `(B, L, num_heads * head_dim)` and fed through a sigmoid; the
//! attention output is pointwise-multiplied by this gate before
//! `o_proj`. Effectively a per-head per-position attenuation on
//! the attention output.
//!
//! 2. **`(1 + w) * x` RmsNorm** on q and k (see `rmsnorm::Qwen3_5RmsNorm`).
//! candle_nn's RmsNorm applies `w * x`; the upstream Qwen3-Next
//! checkpoints expect the `(1 + w)` form.
//!
//! Otherwise: GQA with `num_attention_heads / num_key_value_heads`
//! repeat, q_norm + k_norm on the head dim, GLM-style rotary (see
//! `rope::RotaryEmbedding`), and the usual causal mask.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::kv_cache::ConcatKvCache;
use candle_nn::var_builder::ShardedVarBuilder;
use candle_transformers::utils::repeat_kv;
use std::sync::Arc;
use super::TextConfig;
use super::rmsnorm::Qwen3_5RmsNorm;
use super::rope::RotaryEmbedding;
pub struct Qwen3_5Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
q_norm: Qwen3_5RmsNorm,
k_norm: Qwen3_5RmsNorm,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary: Arc<RotaryEmbedding>,
kv_cache: ConcatKvCache,
}
impl Qwen3_5Attention {
pub fn load(
cfg: &TextConfig,
rotary: Arc<RotaryEmbedding>,
vb: &ShardedVarBuilder,
) -> Result<Self> {
let head_dim = cfg.head_dim;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
if num_kv_heads == 0 || !num_heads.is_multiple_of(num_kv_heads) {
anyhow::bail!(
"num_attention_heads ({num_heads}) must be a positive multiple of \
num_key_value_heads ({num_kv_heads})"
);
}
let num_kv_groups = num_heads / num_kv_heads;
// q_proj is 2x wide: the extra `num_heads * head_dim` slice is
// the gate (see attn_output_gate notes above).
let q_proj = load_linear_no_bias(vb, "q_proj", cfg.hidden_size, num_heads * head_dim * 2)?;
let k_proj = load_linear_no_bias(vb, "k_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
let v_proj = load_linear_no_bias(vb, "v_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
let o_proj = load_linear_no_bias(vb, "o_proj", num_heads * head_dim, cfg.hidden_size)?;
let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
let hidden_size = head_dim * num_heads;
let kv_cache = ConcatKvCache::new(2);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size,
rotary,
kv_cache,
})
}
pub fn forward(
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let (b, l, _) = x.dims3()?;
// 1. q_proj — widened output, split into (query, gate).
let q_raw = self
.q_proj
.forward(x)?
.reshape((b, l, self.num_heads, self.head_dim * 2))?;
let q = q_raw.narrow(3, 0, self.head_dim)?;
let gate = q_raw.narrow(3, self.head_dim, self.head_dim)?;
// Flatten the gate's head dim back into hidden_size for the
// post-attention pointwise multiply.
let gate = gate
.contiguous()?
.reshape((b, l, self.num_heads * self.head_dim))?;
// 2. q_norm + k_norm + reshape to (B, H, L, D).
let q = self.q_norm.forward(&q.contiguous()?)?;
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D)
let k = self
.k_proj
.forward(x)?
.reshape((b, l, self.num_kv_heads, self.head_dim))?;
let k = self.k_norm.forward(&k.contiguous()?)?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = self
.v_proj
.forward(x)?
.reshape((b, l, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
// 3. RoPE on q, k.
let (q, k) = self.rotary.apply(&q, &k, offset)?;
// 4. KV cache.
let (k, v) = self.kv_cache.append(&k, &v)?;
// 5. GQA repeat (cheap shape op).
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
// 6. Scaled dot-product + causal mask.
let scale = 1.0_f64 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(m) = attn_mask {
scores = scores.broadcast_add(m)?;
}
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
let ctx = probs.matmul(&v)?; // (B, H, L, D)
// 7. Reshape back, apply the output gate, project.
let ctx = ctx
.transpose(1, 2)?
.contiguous()?
.reshape((b, l, self.hidden_size))?;
let gate_sig = candle_nn::ops::sigmoid(&gate)?;
let gated = (ctx * gate_sig)?;
self.o_proj.forward(&gated)
}
pub fn clear_kv_cache(&mut self) {
self.kv_cache.reset();
}
}
fn load_linear_no_bias(
vb: &ShardedVarBuilder,
name: &str,
in_dim: usize,
out_dim: usize,
) -> Result<Linear> {
let weight = vb
.pp(name)
.get((out_dim, in_dim), "weight")
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
Ok(Linear::new(weight, None))
}

View File

@@ -0,0 +1,53 @@
//! SwiGLU MLP block for Qwen3-Next.
//!
//! Identical to plain Qwen3's MLP: `down(silu(gate(x)) * up(x))` with
//! no bias on any of the three projections.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
use super::TextConfig;
pub struct Qwen3_5MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
}
impl Qwen3_5MLP {
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let i = cfg.intermediate_size;
let gate_proj = load_linear_no_bias(vb, "gate_proj", h, i)?;
let up_proj = load_linear_no_bias(vb, "up_proj", h, i)?;
let down_proj = load_linear_no_bias(vb, "down_proj", i, h)?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
})
}
}
impl Module for Qwen3_5MLP {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let lhs = candle_nn::ops::silu(&self.gate_proj.forward(x)?)?;
let rhs = self.up_proj.forward(x)?;
self.down_proj.forward(&(lhs * rhs)?)
}
}
fn load_linear_no_bias(
vb: &ShardedVarBuilder,
name: &str,
in_dim: usize,
out_dim: usize,
) -> Result<Linear> {
let weight = vb
.pp(name)
.get((out_dim, in_dim), "weight")
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
Ok(Linear::new(weight, None))
}

View File

@@ -11,66 +11,77 @@
//!
//! ## Status
//!
//! **Scaffold only.** `Config` deserialisation is real (so the dispatch
//! in `candle.rs::load_arch_dense` can route based on `model_type`
//! and the operator's diagnostic surfaces "qwen3_5" in the supported
//! set); the actual forward pass is `unimplemented!()`. Filling this
//! in is the substantive Stage 8c work.
//! **Single-GPU dense path is real**. Both attention flavours
//! (`full_attention` with the output-gated GQA causal attention and
//! `linear_attention` with the Gated DeltaNet recurrent block) are
//! implemented. The model loads from upstream safetensors via the
//! existing `load_arch_dense` dispatch and runs forward end to end.
//!
//! ## What the architecture needs (open work)
//! Numerical correctness vs the reference Python is **not yet
//! validated** — the structural code path is right, weight tensor
//! names match the upstream layout, shapes flow through cleanly, but
//! the Tbilisi probe (and any other downstream test) is the next
//! step. Likely places a bug would surface:
//! - Per-rank vs per-token-position offsets in the recurrent delta
//! rule (`linear_attn.rs`).
//! - Off-by-one in the conv state continuation across decode steps.
//! - RoPE phase mismatch from MRoPE simplification (we treat the
//! three position grids as collapsed, which is correct only for
//! text-only inference).
//!
//! Confirmed from `Qwen/Qwen3.6-27B/config.json`:
//! - Real hyperparams nested under `text_config: {...}`. The
//! architecture is text-side; the multimodal vision tower is
//! separate (`image_token_id`, `language_model_only=false`).
//! - `hidden_size: 5120`, `head_dim: 256`, `intermediate_size: 17408`,
//! `num_attention_heads`, `num_key_value_heads`, etc. — bigger
//! head_dim than plain Qwen3.
//! - `attn_output_gate: true` — a sigmoid gate multiplied into the
//! attention output before the projection. ~10 LoC addition vs the
//! plain Qwen3 attention.
//! - `layer_types: ["linear_attention", "linear_attention",
//! "linear_attention", "full_attention", ...]` with
//! `full_attention_interval: 4` — every 4th layer is full
//! attention, the rest are linear-attention. The full-attention
//! layers shape like a Qwen3 attention; the linear-attention
//! layers are the hard part.
//! ## Submodules
//!
//! ## Linear-attention layer
//! - [`rmsnorm`] — `Qwen3_5RmsNorm` (`(1+w)*x` variant), the
//! `Qwen3_5RmsNormGated` used after the delta rule, and the
//! `l2norm` helper.
//! - [`rope`] — text-side rotary embedding (mrope simplified, GLM
//! rotate-half).
//! - [`mlp`] — SwiGLU MLP (gate/up/down, no bias).
//! - [`full_attn`] — `Qwen3_5Attention` with the output-gate
//! widening on `q_proj`.
//! - [`linear_attn`] — `GatedDeltaNet` recurrent delta-rule block
//! (causal depthwise Conv1d → silu → split → L2norm → per-token
//! delta rule → RMSNormGated → out_proj).
//! - [`decoder`] — `Qwen3_5DecoderLayer` dispatching to one of the
//! two attention flavours per layer index.
//!
//! Candle has nothing we can reuse — has to be written against the
//! reference Python in the Qwen3-Next HF repo. Likely Lightning
//! Attention-2 (state-space-ish recurrence) given the
//! `linear_attention` tag and Qwen3's prior `qwen3-omni` work. Needs:
//! - A persistent recurrent state per layer (replaces the explicit
//! KV cache for full attention).
//! - Per-token update + readout primitives, fused if possible.
//! - Numerical-correctness validation against the Python reference
//! on a fixed prompt before trusting any output downstream.
//! ## Open work
//!
//! ## TP-2 (the immediate motivator)
//! - **TP variant.** `harness/tp/tp_qwen3_5.rs` is the next step.
//! Sharding strategy diverges by layer type:
//! - Full-attention layers: column-parallel q/k/v (including the
//! gate half of `q_proj`) + row-parallel `o_proj`, mirroring
//! `tp_qwen3.rs`.
//! - Linear-attention layers: the recurrent state is per-V-head, so
//! V-head-dimension sharding works cleanly — split `num_v_heads`
//! across ranks (`num_v_heads / world_size` per rank), shard
//! `in_proj_qkv` / `in_proj_z` / `in_proj_b` / `in_proj_a` along
//! the V-head dim, and row-parallel `out_proj`. The `A_log` /
//! `dt_bias` per-head params shard with the heads.
//!
//! Beast's 2x RTX 5090 needs tensor-parallel to fit Qwen3.6-27B.
//! TP-aware analogue lives at `harness/tp/tp_qwen3_5.rs` (not yet
//! created — added alongside the dense impl). Sharding strategy
//! diverges by layer type:
//! - Full-attention layers: column-parallel q/k/v + row-parallel o,
//! same as `tp_qwen3.rs`. With `attn_output_gate`, the gate weight
//! is also column-parallel (one gate scalar per head).
//! - Linear-attention layers: the recurrent state is per-token, not
//! per-head, so head-dim sharding doesn't apply. Options are
//! (a) replicate the linear-attention layers across ranks (cheap
//! but wastes ~half the per-rank VRAM since 3 of every 4 layers
//! replicate), or (b) shard along the recurrent-state dimension
//! if the formulation allows. Decision deferred until the linear
//! attention is actually implemented and profiled.
//! - **Chunked delta-rule prefill.** `linear_attn.rs` runs the
//! per-token recurrent path for prefill too — correct but O(L).
//! Porting `torch_chunk_gated_delta_rule` (chunk_size=64) speeds
//! prefill substantially with no surface change.
use anyhow::Result;
use candle_core::Tensor;
use anyhow::{Context, Result};
use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::Embedding;
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
use serde::Deserialize;
use std::sync::Arc;
pub mod decoder;
pub mod full_attn;
pub mod linear_attn;
pub mod mlp;
pub mod rmsnorm;
pub mod rope;
use decoder::Qwen3_5DecoderLayer;
use rmsnorm::Qwen3_5RmsNorm;
use rope::RotaryEmbedding;
/// `model_type` we deserialise from `config.json`. Const so the
/// dispatch in `candle.rs::load_arch_dense` can pattern-match without
@@ -159,41 +170,131 @@ fn default_hidden_act() -> String {
"silu".into()
}
/// Stub model. Fields are intentionally empty — filling in the
/// concrete architecture is the substantive Stage 8c work. The struct
/// exists so the `ModelArch::Qwen3_5Dense(_)` variant has a payload
/// and dispatch wiring compiles end-to-end.
///
/// To extend: add embed_tokens, decoder layers, final norm, and
/// lm_head fields here; implement `new`, `forward`, `clear_kv_cache`
/// in terms of them. Mirror the layout of `qwen3_dense::ModelForCausalLM`
/// (in candle-transformers) as a starting point.
pub struct Qwen3_5ForCausalLM {
#[allow(dead_code)]
config: Config,
/// Qwen3-Next base transformer (embedding + decoder stack + final
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
/// loaded handle.
pub struct Qwen3_5Model {
embed_tokens: Embedding,
layers: Vec<Qwen3_5DecoderLayer>,
norm: Qwen3_5RmsNorm,
device: Device,
dtype: DType,
}
impl Qwen3_5ForCausalLM {
pub fn new(config: Config, _vb: candle_nn::VarBuilder) -> Result<Self> {
// TODO(stage-8c): build embed_tokens, decoder layers (dispatching
// on layer_types), final RmsNorm, lm_head from the VarBuilder.
// For now we accept the construction so the load path can be
// exercised end-to-end (config parse + safetensors mmap), and
// bail at forward time with a clear marker.
Ok(Self { config })
}
impl Qwen3_5Model {
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let dtype = vb.dtype();
let device = vb.device().clone();
pub fn forward(&mut self, _input: &Tensor, _offset: usize) -> Result<Tensor> {
let embed_vb = vb.pp("model.embed_tokens");
let embed_weight = embed_vb
.get((cfg.vocab_size, cfg.hidden_size), "weight")
.with_context(|| format!("load '{}/weight'", embed_vb.prefix()))?;
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
let rotary = Arc::new(RotaryEmbedding::new(dtype, cfg, &device)?);
if cfg.layer_types.len() != cfg.num_hidden_layers {
anyhow::bail!(
"Qwen3-Next ({}) forward not implemented yet (Stage 8c, TP-2 motivator)",
self.config.model_type
)
"config.text_config.layer_types must have num_hidden_layers ({}) entries; \
got {}",
cfg.num_hidden_layers,
cfg.layer_types.len()
);
}
let vb_l = vb.pp("model.layers");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {
layers.push(Qwen3_5DecoderLayer::load(
cfg,
rotary.clone(),
i,
&vb_l.pp(i),
)?);
}
let norm = Qwen3_5RmsNorm::load(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
Ok(Self {
embed_tokens,
layers,
norm,
device,
dtype,
})
}
pub fn embed_weight(&self) -> &Tensor {
self.embed_tokens.embeddings()
}
pub fn clear_kv_cache(&mut self) {
// No-op for the stub. The real impl needs a `clear_kv_cache`
// that resets the per-layer KV cache (full-attention layers)
// and the per-layer recurrent state (linear-attention layers).
for l in &mut self.layers {
l.clear_kv_cache();
}
}
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
let minf = f32::NEG_INFINITY;
let mask: Vec<_> = (0..tgt)
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
.collect();
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
// Causal mask only needed for L > 1 prefill; full-attention
// layers consume it via broadcast_add. Linear-attention layers
// ignore the mask.
let causal = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
}
self.norm.forward(&h)
}
}
pub struct Qwen3_5ForCausalLM {
base: Qwen3_5Model,
lm_head: Linear,
}
impl Qwen3_5ForCausalLM {
pub fn new(config: Config, vb: ShardedVarBuilder) -> Result<Self> {
let cfg = &config.text_config;
let base = Qwen3_5Model::load(cfg, &vb)?;
let lm_head = if cfg.tie_word_embeddings {
Linear::new(base.embed_weight().clone(), None)
} else {
let weight = vb
.pp("lm_head")
.get((cfg.vocab_size, cfg.hidden_size), "weight")
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
Linear::new(weight, None)
};
Ok(Self { base, lm_head })
}
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
/// the last position, shape `(B, 1, vocab_size)` — same contract
/// as `qwen3::ModelForCausalLM::forward` so the harness's
/// `squeeze_to_vocab` helper handles both uniformly.
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (_, l) = input.dims2()?;
let hidden = self.base.forward(input, offset)?;
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache();
}
}

View File

@@ -0,0 +1,67 @@
//! Rotary position embedding for Qwen3-Next's full-attention layers.
//!
//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the
//! reference Python — three position grids interleaved per
//! `mrope_section`. For text-only inference all three grids carry the
//! same position ids and the interleave is a no-op, so this module
//! implements the plain (non-mrope) flavour: the standard inv_freq
//! cosine/sine tables driven by `rope_theta` and `head_dim`.
//!
//! Rotation flavour: **GLM-style** rotate-half (the second half of the
//! head dim is negated and swapped into the first). The reference
//! Python uses `apply_rotary_pos_emb` with `rotate_half`; candle's
//! `rope_slow` is the matching helper.
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use super::TextConfig;
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<f32> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let n = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?.to_dtype(dtype)?,
cos: freqs.cos()?.to_dtype(dtype)?,
})
}
/// Apply RoPE to q, k.
///
/// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index
/// into the cached cos/sin table — the position of the first token
/// in the current step. `candle_nn::rotary_emb::rope_slow` does
/// the GLM-style `x*cos + rotate_half(x)*sin` rotation and
/// internally `cat`s cos/sin with themselves along the last dim,
/// so we hand it the `(seq_len, head_dim/2)` slice it expects.
pub fn apply(
&self,
q: &Tensor,
k: &Tensor,
offset: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let (_, _, seq_len, _) = q.dims4()?;
let cos = self.cos.narrow(0, offset, seq_len)?;
let sin = self.sin.narrow(0, offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}

View File

@@ -617,12 +617,22 @@ impl CandleHarness {
})))
}
"qwen3_5" => {
// Stage 8c scaffold: config parses, model
// constructs, but forward bails. See
// `arch/qwen3_5.rs` for the open architecture work.
// Qwen3-Next needs a ShardedVarBuilder because its
// load functions use the sharded backend (so they
// can be reused unchanged by the future TP variant).
// With world_size=1 the backend falls through to
// the unsharded path, so there is no per-load cost.
let cfg: super::arch::qwen3_5::Config = serde_json::from_str(&cfg_text)
.context("parse Qwen3-Next (qwen3_5) config.json")?;
let model = super::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, vb)
let sharded_vb = unsafe {
candle_nn::var_builder::ShardedSafeTensors::var_builder(
&safetensors_paths,
dtype,
&device_for_load,
)
.context("build ShardedVarBuilder for Qwen3-Next")?
};
let model = super::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, sharded_vb)
.context("build Qwen3-Next dense model")?;
Ok(ModelArch::Qwen3_5Dense(model))
}