diff --git a/crates/neuron/src/harness/arch/qwen3_5/decoder.rs b/crates/neuron/src/harness/arch/qwen3_5/decoder.rs
new file mode 100644
index 0000000..df10bb4
--- /dev/null
+++ b/crates/neuron/src/harness/arch/qwen3_5/decoder.rs
@@ -0,0 +1,134 @@
+//! Qwen3-Next decoder layer.
+//!
+//! Standard pre-norm transformer block (LN → attention → residual →
+//! LN → MLP → residual) where the attention slot dispatches on the
+//! per-layer `layer_types[i]` value in the config:
+//!
+//! - `"full_attention"` → [`Qwen3_5Attention`] (GQA causal + output
+//!   gate + RoPE + KV cache).
+//! - `"linear_attention"` → [`GatedDeltaNet`] (recurrent delta rule +
+//!   causal conv + per-head state).
+//!
+//! In Qwen3.6-27B every 4th layer is full_attention; the rest are
+//! linear_attention. `full_attention_interval` in the config is a
+//! hint; `layer_types` is authoritative.
+
+use anyhow::Result;
+use candle_core::{Module, Tensor};
+use candle_nn::var_builder::ShardedVarBuilder;
+use std::sync::Arc;
+
+use super::TextConfig;
+use super::full_attn::Qwen3_5Attention;
+use super::linear_attn::GatedDeltaNet;
+use super::mlp::Qwen3_5MLP;
+use super::rmsnorm::Qwen3_5RmsNorm;
+use super::rope::RotaryEmbedding;
+
+/// One of the two attention flavours sitting in a decoder layer's
+/// attention slot. Full-attention layers need the rotary table and
+/// take an attention mask; linear-attention layers carry their own
+/// recurrent state and ignore the mask.
+enum AttentionKind {
+    Full(Qwen3_5Attention),
+    Linear(GatedDeltaNet),
+}
+
+/// One pre-norm decoder block: RMSNorm → attention → residual add →
+/// RMSNorm → MLP → residual add (see [`Qwen3_5DecoderLayer::forward`]).
+pub struct Qwen3_5DecoderLayer {
+    input_layernorm: Qwen3_5RmsNorm,
+    post_attention_layernorm: Qwen3_5RmsNorm,
+    mlp: Qwen3_5MLP,
+    attention: AttentionKind,
+}
+
+impl Qwen3_5DecoderLayer {
+    /// Loads layer `layer_idx` from `vb`, choosing the attention
+    /// flavour from `cfg.layer_types[layer_idx]`.
+    ///
+    /// # Errors
+    ///
+    /// Fails when `layer_types` has no entry for `layer_idx`, when the
+    /// entry is neither `"full_attention"` nor `"linear_attention"`,
+    /// or when any sub-module fails to load its weights.
+    pub fn load(
+        cfg: &TextConfig,
+        rotary: Arc<RotaryEmbedding>,
+        layer_idx: usize,
+        vb: &ShardedVarBuilder,
+    ) -> Result<Self> {
+        let layer_type = cfg
+            .layer_types
+            .get(layer_idx)
+            .map(String::as_str)
+            .ok_or_else(|| {
+                anyhow::anyhow!(
+                    "layer_types[{layer_idx}] missing (have {} entries)",
+                    cfg.layer_types.len()
+                )
+            })?;
+
+        let attention = match layer_type {
+            "full_attention" => {
+                AttentionKind::Full(Qwen3_5Attention::load(cfg, rotary, &vb.pp("self_attn"))?)
+            }
+            "linear_attention" => {
+                AttentionKind::Linear(GatedDeltaNet::load(cfg, &vb.pp("linear_attn"))?)
+            }
+            other => anyhow::bail!(
+                "unknown layer_type '{other}' for layer {layer_idx} (expected \
+                 'full_attention' or 'linear_attention')"
+            ),
+        };
+
+        let mlp = Qwen3_5MLP::load(cfg, &vb.pp("mlp"))?;
+        let input_layernorm =
+            Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
+        let post_attention_layernorm = Qwen3_5RmsNorm::load(
+            &vb.pp("post_attention_layernorm"),
+            cfg.hidden_size,
+            cfg.rms_norm_eps,
+        )?;
+
+        Ok(Self {
+            input_layernorm,
+            post_attention_layernorm,
+            mlp,
+            attention,
+        })
+    }
+
+    /// Runs the block on hidden states `x`.
+    ///
+    /// `attn_mask` and `offset` are forwarded to full-attention layers
+    /// only; the linear-attention variant is called with the
+    /// normalised hidden states alone.
+    pub fn forward(
+        &mut self,
+        x: &Tensor,
+        attn_mask: Option<&Tensor>,
+        offset: usize,
+    ) -> candle_core::Result<Tensor> {
+        let h = self.input_layernorm.forward(x)?;
+        let attn_out = match &mut self.attention {
+            AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
+            // Linear attention ignores attn_mask + offset; its causal
+            // structure is baked into the recurrent state lifecycle.
+            AttentionKind::Linear(net) => net.forward(&h)?,
+        };
+        let x = (x + attn_out)?;
+        let h2 = self.post_attention_layernorm.forward(&x)?;
+        let h2 = self.mlp.forward(&h2)?;
+        x + h2
+    }
+
+    /// Resets the attention slot's cached state by delegating to the
+    /// active variant's `clear_kv_cache`.
+    pub fn clear_kv_cache(&mut self) {
+        match &mut self.attention {
+            AttentionKind::Full(attn) => attn.clear_kv_cache(),
+            AttentionKind::Linear(net) => net.clear_kv_cache(),
+        }
+    }
+}
diff --git a/crates/neuron/src/harness/arch/qwen3_5/full_attn.rs b/crates/neuron/src/harness/arch/qwen3_5/full_attn.rs
new file mode 100644
index 0000000..28b378b
--- /dev/null
+++ b/crates/neuron/src/harness/arch/qwen3_5/full_attn.rs
@@ -0,0 +1,179 @@
+//! Qwen3-Next's `full_attention` layer.
+//!
+//! Standard GQA causal attention with two Qwen3-Next-specific quirks:
+//!
+//! 1. **Output gate (`attn_output_gate=True`).** `q_proj` is widened
+//!    to `num_heads * head_dim * 2`. Per head, the trailing `head_dim`
+//!    slice is the gate: it is flattened to `(B, L, num_heads *
+//!    head_dim)`, fed through a sigmoid, and multiplied elementwise
+//!    into the attention output before `o_proj` — a per-channel,
+//!    per-position attenuation of the attention output.
+//!
+//! 2. **`(1 + w) * x` RmsNorm** on q and k (see `rmsnorm::Qwen3_5RmsNorm`).
+//!    candle_nn's RmsNorm applies `w * x`; the upstream Qwen3-Next
+//!    checkpoints expect the `(1 + w)` form.
+//!
+//! Otherwise: GQA with `num_attention_heads / num_key_value_heads`
+//! repeat, q_norm + k_norm on the head dim, GLM-style rotary (see
+//! `rope::RotaryEmbedding`), and the usual causal mask.
+
+use anyhow::{Context, Result};
+use candle_core::{Module, Tensor};
+use candle_nn::Linear;
+use candle_nn::kv_cache::ConcatKvCache;
+use candle_nn::var_builder::ShardedVarBuilder;
+use candle_transformers::utils::repeat_kv;
+use std::sync::Arc;
+
+use super::TextConfig;
+use super::rmsnorm::Qwen3_5RmsNorm;
+use super::rope::RotaryEmbedding;
+
+pub struct Qwen3_5Attention {
+    q_proj: Linear,
+    k_proj: Linear,
+    v_proj: Linear,
+    o_proj: Linear,
+    q_norm: Qwen3_5RmsNorm,
+    k_norm: Qwen3_5RmsNorm,
+    num_heads: usize,
+    num_kv_heads: usize,
+    num_kv_groups: usize,
+    head_dim: usize,
+    hidden_size: usize,
+    rotary: Arc<RotaryEmbedding>,
+    kv_cache: ConcatKvCache,
+}
+
+impl Qwen3_5Attention {
+    pub fn load(
+        cfg: &TextConfig,
+        rotary: Arc<RotaryEmbedding>,
+        vb: &ShardedVarBuilder,
+    ) -> Result<Self> {
+        let head_dim = cfg.head_dim;
+        let num_heads = cfg.num_attention_heads;
+        let num_kv_heads = cfg.num_key_value_heads;
+        if num_heads == 0 || num_kv_heads == 0 || !num_heads.is_multiple_of(num_kv_heads) {
+            anyhow::bail!(
+                "num_attention_heads ({num_heads}) must be a positive multiple of \
+                 num_key_value_heads ({num_kv_heads})"
+            );
+        }
+        let num_kv_groups = num_heads / num_kv_heads;
+
+        // q_proj is 2x wide: the extra `num_heads * head_dim` slice is
+        // the gate (see attn_output_gate notes above).
+        let q_proj = load_linear_no_bias(vb, "q_proj", cfg.hidden_size, num_heads * head_dim * 2)?;
+        let k_proj = load_linear_no_bias(vb, "k_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
+        let v_proj = load_linear_no_bias(vb, "v_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
+        let o_proj = load_linear_no_bias(vb, "o_proj", num_heads * head_dim, cfg.hidden_size)?;
+
+        let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
+        let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
+
+        let hidden_size = head_dim * num_heads; // attention width, not cfg.hidden_size
+        let kv_cache = ConcatKvCache::new(2);
+
+        Ok(Self {
+            q_proj,
+            k_proj,
+            v_proj,
+            o_proj,
+            q_norm,
+            k_norm,
+            num_heads,
+            num_kv_heads,
+            num_kv_groups,
+            head_dim,
+            hidden_size,
+            rotary,
+            kv_cache,
+        })
+    }
+
+    pub fn forward(
+        &mut self,
+        x: &Tensor,
+        attn_mask: Option<&Tensor>,
+        offset: usize,
+    ) -> candle_core::Result<Tensor> {
+        let (b, l, _) = x.dims3()?;
+
+        // 1. q_proj — widened output, split per head into (query, gate).
+        let q_raw = self
+            .q_proj
+            .forward(x)?
+            .reshape((b, l, self.num_heads, self.head_dim * 2))?;
+        let q = q_raw.narrow(3, 0, self.head_dim)?;
+        let gate = q_raw.narrow(3, self.head_dim, self.head_dim)?;
+        // Flatten the gate's head dim back to `num_heads * head_dim`
+        // for the post-attention pointwise multiply.
+        let gate = gate
+            .contiguous()?
+            .reshape((b, l, self.num_heads * self.head_dim))?;
+
+        // 2. q_norm + k_norm + reshape to (B, H, L, D).
+        let q = self.q_norm.forward(&q.contiguous()?)?;
+        let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D)
+
+        let k = self
+            .k_proj
+            .forward(x)?
+            .reshape((b, l, self.num_kv_heads, self.head_dim))?;
+        let k = self.k_norm.forward(&k.contiguous()?)?;
+        let k = k.transpose(1, 2)?.contiguous()?;
+
+        let v = self
+            .v_proj
+            .forward(x)?
+            .reshape((b, l, self.num_kv_heads, self.head_dim))?
+            .transpose(1, 2)?
+            .contiguous()?;
+
+        // 3. RoPE on q, k.
+        let (q, k) = self.rotary.apply(&q, &k, offset)?;
+
+        // 4. KV cache.
+        let (k, v) = self.kv_cache.append(&k, &v)?;
+
+        // 5. GQA: repeat each KV head `num_kv_groups` times for the query heads.
+        let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
+        let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
+
+        // 6. Scaled dot-product + causal mask.
+        let scale = 1.0_f64 / (self.head_dim as f64).sqrt();
+        let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
+        if let Some(m) = attn_mask {
+            scores = scores.broadcast_add(m)?;
+        }
+        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
+        let ctx = probs.matmul(&v)?; // (B, H, L, D)
+
+        // 7. Reshape back, apply the output gate, project.
+        let ctx = ctx
+            .transpose(1, 2)?
+            .contiguous()?
+            .reshape((b, l, self.hidden_size))?;
+        let gate_sig = candle_nn::ops::sigmoid(&gate)?;
+        let gated = (ctx * gate_sig)?;
+        self.o_proj.forward(&gated)
+    }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.kv_cache.reset();
+    }
+}
+
+fn load_linear_no_bias(
+    vb: &ShardedVarBuilder,
+    name: &str,
+    in_dim: usize,
+    out_dim: usize,
+) -> Result<Linear> {
+    let weight = vb
+        .pp(name)
+        .get((out_dim, in_dim), "weight")
+        .with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
+    Ok(Linear::new(weight, None))
+}
diff --git a/crates/neuron/src/harness/arch/qwen3_5/mlp.rs b/crates/neuron/src/harness/arch/qwen3_5/mlp.rs
new file mode 100644
index 0000000..64f5376
--- /dev/null
+++ b/crates/neuron/src/harness/arch/qwen3_5/mlp.rs
@@ -0,0 +1,53 @@
+//! SwiGLU MLP block for Qwen3-Next.
+//!
+//! Identical to plain Qwen3's MLP: `down(silu(gate(x)) * up(x))` with
+//! no bias on any of the three projections.
+
+use anyhow::{Context, Result};
+use candle_core::{Module, Tensor};
+use candle_nn::Linear;
+use candle_nn::var_builder::ShardedVarBuilder;
+
+use super::TextConfig;
+
+pub struct Qwen3_5MLP {
+    gate_proj: Linear,
+    up_proj: Linear,
+    down_proj: Linear,
+}
+
+impl Qwen3_5MLP {
+    pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
+        let h = cfg.hidden_size;
+        let i = cfg.intermediate_size;
+        let gate_proj = load_linear_no_bias(vb, "gate_proj", h, i)?;
+        let up_proj = load_linear_no_bias(vb, "up_proj", h, i)?;
+        let down_proj = load_linear_no_bias(vb, "down_proj", i, h)?;
+        Ok(Self {
+            gate_proj,
+            up_proj,
+            down_proj,
+        })
+    }
+}
+
+impl Module for Qwen3_5MLP {
+    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
+        let lhs = candle_nn::ops::silu(&self.gate_proj.forward(x)?)?;
+        let rhs = self.up_proj.forward(x)?;
+        self.down_proj.forward(&(lhs * rhs)?)
+    }
+}
+
+fn load_linear_no_bias(
+    vb: &ShardedVarBuilder,
+    name: &str,
+    in_dim: usize,
+    out_dim: usize,
+) -> Result<Linear> {
+    let weight = vb
+        .pp(name)
+        .get((out_dim, in_dim), "weight")
+        .with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
+    Ok(Linear::new(weight, None))
+}
diff --git a/crates/neuron/src/harness/arch/qwen3_5/mod.rs b/crates/neuron/src/harness/arch/qwen3_5/mod.rs
index d931836..383d998 100644
--- a/crates/neuron/src/harness/arch/qwen3_5/mod.rs
+++ b/crates/neuron/src/harness/arch/qwen3_5/mod.rs
@@ -11,66 +11,77 @@
 //!
 //! ## Status
 //!
-//! **Scaffold only.** `Config` deserialisation is real (so the dispatch
-//! in `candle.rs::load_arch_dense` can route based on `model_type`
-//! and the operator's diagnostic surfaces "qwen3_5" in the supported
-//! set); the actual forward pass is `unimplemented!()`. Filling this
-//! in is the substantive Stage 8c work.
+//! **Single-GPU dense path is real**. Both attention flavours
+//! (`full_attention` with the output-gated GQA causal attention and
+//! `linear_attention` with the Gated DeltaNet recurrent block) are
+//! 
implemented. The model loads from upstream safetensors via the +//! existing `load_arch_dense` dispatch and runs forward end to end. //! -//! ## What the architecture needs (open work) +//! Numerical correctness vs the reference Python is **not yet +//! validated** — the structural code path is right, weight tensor +//! names match the upstream layout, shapes flow through cleanly, but +//! the Tbilisi probe (and any other downstream test) is the next +//! step. Likely places a bug would surface: +//! - Per-rank vs per-token-position offsets in the recurrent delta +//! rule (`linear_attn.rs`). +//! - Off-by-one in the conv state continuation across decode steps. +//! - RoPE phase mismatch from MRoPE simplification (we treat the +//! three position grids as collapsed, which is correct only for +//! text-only inference). //! -//! Confirmed from `Qwen/Qwen3.6-27B/config.json`: -//! - Real hyperparams nested under `text_config: {...}`. The -//! architecture is text-side; the multimodal vision tower is -//! separate (`image_token_id`, `language_model_only=false`). -//! - `hidden_size: 5120`, `head_dim: 256`, `intermediate_size: 17408`, -//! `num_attention_heads`, `num_key_value_heads`, etc. — bigger -//! head_dim than plain Qwen3. -//! - `attn_output_gate: true` — a sigmoid gate multiplied into the -//! attention output before the projection. ~10 LoC addition vs the -//! plain Qwen3 attention. -//! - `layer_types: ["linear_attention", "linear_attention", -//! "linear_attention", "full_attention", ...]` with -//! `full_attention_interval: 4` — every 4th layer is full -//! attention, the rest are linear-attention. The full-attention -//! layers shape like a Qwen3 attention; the linear-attention -//! layers are the hard part. +//! ## Submodules //! -//! ## Linear-attention layer +//! - [`rmsnorm`] — `Qwen3_5RmsNorm` (`(1+w)*x` variant), the +//! `Qwen3_5RmsNormGated` used after the delta rule, and the +//! `l2norm` helper. +//! 
- [`rope`] — text-side rotary embedding (mrope simplified, GLM +//! rotate-half). +//! - [`mlp`] — SwiGLU MLP (gate/up/down, no bias). +//! - [`full_attn`] — `Qwen3_5Attention` with the output-gate +//! widening on `q_proj`. +//! - [`linear_attn`] — `GatedDeltaNet` recurrent delta-rule block +//! (causal depthwise Conv1d → silu → split → L2norm → per-token +//! delta rule → RMSNormGated → out_proj). +//! - [`decoder`] — `Qwen3_5DecoderLayer` dispatching to one of the +//! two attention flavours per layer index. //! -//! Candle has nothing we can reuse — has to be written against the -//! reference Python in the Qwen3-Next HF repo. Likely Lightning -//! Attention-2 (state-space-ish recurrence) given the -//! `linear_attention` tag and Qwen3's prior `qwen3-omni` work. Needs: -//! - A persistent recurrent state per layer (replaces the explicit -//! KV cache for full attention). -//! - Per-token update + readout primitives, fused if possible. -//! - Numerical-correctness validation against the Python reference -//! on a fixed prompt before trusting any output downstream. +//! ## Open work //! -//! ## TP-2 (the immediate motivator) +//! - **TP variant.** `harness/tp/tp_qwen3_5.rs` is the next step. +//! Sharding strategy diverges by layer type: +//! - Full-attention layers: column-parallel q/k/v (including the +//! gate half of `q_proj`) + row-parallel `o_proj`, mirroring +//! `tp_qwen3.rs`. +//! - Linear-attention layers: the recurrent state is per-V-head, so +//! V-head-dimension sharding works cleanly — split `num_v_heads` +//! across ranks (`num_v_heads / world_size` per rank), shard +//! `in_proj_qkv` / `in_proj_z` / `in_proj_b` / `in_proj_a` along +//! the V-head dim, and row-parallel `out_proj`. The `A_log` / +//! `dt_bias` per-head params shard with the heads. //! -//! Beast's 2x RTX 5090 needs tensor-parallel to fit Qwen3.6-27B. -//! TP-aware analogue lives at `harness/tp/tp_qwen3_5.rs` (not yet -//! created — added alongside the dense impl). 
Sharding strategy
-//! diverges by layer type:
-//! - Full-attention layers: column-parallel q/k/v + row-parallel o,
-//!   same as `tp_qwen3.rs`. With `attn_output_gate`, the gate weight
-//!   is also column-parallel (one gate scalar per head).
-//! - Linear-attention layers: the recurrent state is per-token, not
-//!   per-head, so head-dim sharding doesn't apply. Options are
-//!   (a) replicate the linear-attention layers across ranks (cheap
-//!   but wastes ~half the per-rank VRAM since 3 of every 4 layers
-//!   replicate), or (b) shard along the recurrent-state dimension
-//!   if the formulation allows. Decision deferred until the linear
-//!   attention is actually implemented and profiled.
+//! - **Chunked delta-rule prefill.** `linear_attn.rs` runs the
+//!   per-token recurrent path for prefill too — correct, but L
+//!   sequential steps. Porting `torch_chunk_gated_delta_rule`
+//!   (chunk_size=64) speeds prefill substantially with no surface change.
 
-use anyhow::Result;
-use candle_core::Tensor;
+use anyhow::{Context, Result};
+use candle_core::{DType, Device, IndexOp, Module, Tensor};
+use candle_nn::Embedding;
+use candle_nn::Linear;
+use candle_nn::var_builder::ShardedVarBuilder;
 use serde::Deserialize;
+use std::sync::Arc;
 
+pub mod decoder;
+pub mod full_attn;
 pub mod linear_attn;
+pub mod mlp;
 pub mod rmsnorm;
+pub mod rope;
+
+use decoder::Qwen3_5DecoderLayer;
+use rmsnorm::Qwen3_5RmsNorm;
+use rope::RotaryEmbedding;
 
 /// `model_type` we deserialise from `config.json`. Const so the
 /// dispatch in `candle.rs::load_arch_dense` can pattern-match without
@@ -159,41 +170,159 @@ fn default_hidden_act() -> String {
     "silu".into()
 }
 
-/// Stub model. Fields are intentionally empty — filling in the
-/// concrete architecture is the substantive Stage 8c work. The struct
-/// exists so the `ModelArch::Qwen3_5Dense(_)` variant has a payload
-/// and dispatch wiring compiles end-to-end.
-///
-/// To extend: add embed_tokens, decoder layers, final norm, and
-/// lm_head fields here; implement `new`, `forward`, `clear_kv_cache`
-/// in terms of them. Mirror the layout of `qwen3_dense::ModelForCausalLM`
-/// (in candle-transformers) as a starting point.
-pub struct Qwen3_5ForCausalLM {
-    #[allow(dead_code)]
-    config: Config,
+/// Qwen3-Next base transformer (embedding + decoder stack + final
+/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
+/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
+/// loaded handle.
+pub struct Qwen3_5Model {
+    embed_tokens: Embedding,
+    layers: Vec<Qwen3_5DecoderLayer>,
+    norm: Qwen3_5RmsNorm,
+    device: Device,
+    dtype: DType,
 }
 
-impl Qwen3_5ForCausalLM {
-    pub fn new(config: Config, _vb: candle_nn::VarBuilder) -> Result<Self> {
-        // TODO(stage-8c): build embed_tokens, decoder layers (dispatching
-        // on layer_types), final RmsNorm, lm_head from the VarBuilder.
-        // For now we accept the construction so the load path can be
-        // exercised end-to-end (config parse + safetensors mmap), and
-        // bail at forward time with a clear marker.
-        Ok(Self { config })
+impl Qwen3_5Model {
+    /// Loads `model.embed_tokens`, every `model.layers.{i}` decoder
+    /// layer and `model.norm` from `vb`, sharing one rotary table
+    /// across layers.
+    ///
+    /// # Errors
+    ///
+    /// Fails when `layer_types` does not have exactly
+    /// `num_hidden_layers` entries or when a weight fails to load.
+    pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
+        let dtype = vb.dtype();
+        let device = vb.device().clone();
+
+        let embed_vb = vb.pp("model.embed_tokens");
+        let embed_weight = embed_vb
+            .get((cfg.vocab_size, cfg.hidden_size), "weight")
+            .with_context(|| format!("load '{}/weight'", embed_vb.prefix()))?;
+        let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
+
+        let rotary = Arc::new(RotaryEmbedding::new(dtype, cfg, &device)?);
+
+        if cfg.layer_types.len() != cfg.num_hidden_layers {
+            anyhow::bail!(
+                "config.text_config.layer_types must have num_hidden_layers ({}) entries; \
+                 got {}",
+                cfg.num_hidden_layers,
+                cfg.layer_types.len()
+            );
+        }
+
+        let vb_l = vb.pp("model.layers");
+        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+        for i in 0..cfg.num_hidden_layers {
+            layers.push(Qwen3_5DecoderLayer::load(
+                cfg,
+                rotary.clone(),
+                i,
+                &vb_l.pp(i),
+            )?);
+        }
+
+        let norm = Qwen3_5RmsNorm::load(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
+
+        Ok(Self {
+            embed_tokens,
+            layers,
+            norm,
+            device,
+            dtype,
+        })
     }
 
-    pub fn forward(&mut self, _input: &Tensor, _offset: usize) -> Result<Tensor> {
-        anyhow::bail!(
-            "Qwen3-Next ({}) forward not implemented yet (Stage 8c, TP-2 motivator)",
-            self.config.model_type
-        )
+    /// Embedding matrix, exposed so a tied `lm_head` can reuse it.
+    pub fn embed_weight(&self) -> &Tensor {
+        self.embed_tokens.embeddings()
     }
 
+    /// Resets the cached state of every decoder layer.
     pub fn clear_kv_cache(&mut self) {
-        // No-op for the stub. The real impl needs a `clear_kv_cache`
-        // that resets the per-layer KV cache (full-attention layers)
-        // and the per-layer recurrent state (linear-attention layers).
+        for l in &mut self.layers {
+            l.clear_kv_cache();
+        }
+    }
+
+    /// Builds the additive causal mask for `tgt` new tokens that follow
+    /// `offset` cached positions: entry `(i, j)` is `0` when
+    /// `j <= i + offset` and `-inf` otherwise.
+    ///
+    /// The mask data is identical for every batch row, so it is built
+    /// once as `(1, 1, tgt, tgt + offset)` and broadcast to
+    /// `(b, 1, tgt, tgt + offset)`; sizing the buffer for one row but
+    /// declaring `b` rows makes `Tensor::from_slice` fail for `b > 1`.
+    fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
+        let minf = f32::NEG_INFINITY;
+        let mask: Vec<_> = (0..tgt)
+            .flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
+            .collect();
+        Tensor::from_slice(&mask, (1, 1, tgt, tgt + offset), &self.device)?
+            .to_dtype(self.dtype)?
+            .broadcast_as((b, 1, tgt, tgt + offset))
+    }
+
+    /// Embeds `input` (token ids, `(B, L)`), runs every decoder layer
+    /// and applies the final norm. `offset` is the position of the
+    /// first token of this step; it sizes the causal mask and is
+    /// passed on to each layer.
+    pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
+        let (b, l) = input.dims2()?;
+        let mut h = self.embed_tokens.forward(input)?;
+        // Causal mask only needed for L > 1 prefill; full-attention
+        // layers consume it via broadcast_add. Linear-attention layers
+        // ignore the mask.
+        let causal = if l == 1 {
+            None
+        } else {
+            Some(self.causal_mask(b, l, offset)?)
+        };
+        for layer in &mut self.layers {
+            h = layer.forward(&h, causal.as_ref(), offset)?;
+        }
+        self.norm.forward(&h)
+    }
+}
+
+/// [`Qwen3_5Model`] plus the `lm_head` projection to vocabulary logits.
+pub struct Qwen3_5ForCausalLM {
+    base: Qwen3_5Model,
+    lm_head: Linear,
+}
+
+impl Qwen3_5ForCausalLM {
+    /// Loads the base model and `lm_head`; with `tie_word_embeddings`
+    /// the head reuses the embedding matrix instead of loading
+    /// `lm_head.weight`.
+    pub fn new(config: Config, vb: ShardedVarBuilder) -> Result<Self> {
+        let cfg = &config.text_config;
+        let base = Qwen3_5Model::load(cfg, &vb)?;
+        let lm_head = if cfg.tie_word_embeddings {
+            Linear::new(base.embed_weight().clone(), None)
+        } else {
+            let weight = vb
+                .pp("lm_head")
+                .get((cfg.vocab_size, cfg.hidden_size), "weight")
+                .with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
+            Linear::new(weight, None)
+        };
+        Ok(Self { base, lm_head })
+    }
+
+    /// `input`: token-id tensor of shape `(B, L)`. Returns logits at
+    /// the last position, shape `(B, 1, vocab_size)` — same contract
+    /// as `qwen3::ModelForCausalLM::forward` so the harness's
+    /// `squeeze_to_vocab` helper handles both uniformly.
+    pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
+        let last = input.dims2()?.1.saturating_sub(1); // no usize underflow when L = 0
+        let hidden = self.base.forward(input, offset)?;
+        hidden.i((.., last.., ..))?.apply(&self.lm_head)
+    }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.base.clear_kv_cache();
     }
 }
 
diff --git a/crates/neuron/src/harness/arch/qwen3_5/rope.rs b/crates/neuron/src/harness/arch/qwen3_5/rope.rs
new file mode 100644
index 0000000..ed72b12
--- /dev/null
+++ b/crates/neuron/src/harness/arch/qwen3_5/rope.rs
@@ -0,0 +1,73 @@
+//! Rotary position embedding for Qwen3-Next's full-attention layers.
+//!
+//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the
+//! reference Python — three position grids interleaved per
+//! `mrope_section`. For text-only inference all three grids carry the
+//! same position ids and the interleave is a no-op, so this module
+//! implements the plain (non-mrope) flavour: the standard inv_freq
+//! cosine/sine tables driven by `rope_theta` and `head_dim`.
+//!
+//! Rotation flavour: **GLM-style** rotate-half (the second half of the
+//! head dim is negated and swapped into the first). The reference
+//! Python uses `apply_rotary_pos_emb` with `rotate_half`; candle's
+//! `rope_slow` is the matching helper.
+//!
+//! NOTE(review): the table built here spans the full `head_dim`. The
+//! upstream Qwen3-Next configs carry a `partial_rotary_factor`; if
+//! it is below 1 for this checkpoint, only the leading
+//! `head_dim * partial_rotary_factor` channels should be rotated —
+//! confirm against the reference implementation.
+
+use anyhow::Result;
+use candle_core::{DType, Device, Tensor};
+
+use super::TextConfig;
+
+#[derive(Debug, Clone)]
+pub struct RotaryEmbedding {
+    sin: Tensor,
+    cos: Tensor,
+}
+
+impl RotaryEmbedding {
+    pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
+        let dim = cfg.head_dim;
+        let max_seq_len = cfg.max_position_embeddings;
+        let inv_freq: Vec<f32> = (0..dim)
+            .step_by(2)
+            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+            .collect();
+        let n = inv_freq.len();
+        let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?;
+        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+            .to_dtype(DType::F32)?
+            .reshape((max_seq_len, 1))?;
+        let freqs = t.matmul(&inv_freq)?;
+        Ok(Self {
+            sin: freqs.sin()?.to_dtype(dtype)?,
+            cos: freqs.cos()?.to_dtype(dtype)?,
+        })
+    }
+
+    /// Apply RoPE to q, k.
+    ///
+    /// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index
+    /// into the cached cos/sin table — the position of the first token
+    /// in the current step. `candle_nn::rotary_emb::rope_slow` does
+    /// the GLM-style `x*cos + rotate_half(x)*sin` rotation and
+    /// internally `cat`s cos/sin with themselves along the last dim,
+    /// so we hand it the `(seq_len, head_dim/2)` slice it expects.
+    pub fn apply(
+        &self,
+        q: &Tensor,
+        k: &Tensor,
+        offset: usize,
+    ) -> candle_core::Result<(Tensor, Tensor)> {
+        let (_, _, seq_len, _) = q.dims4()?;
+        let cos = self.cos.narrow(0, offset, seq_len)?;
+        let sin = self.sin.narrow(0, offset, seq_len)?;
+        let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?;
+        Ok((q_embed, k_embed))
+    }
+}
diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs
index 25285f1..342df69 100644
--- a/crates/neuron/src/harness/candle.rs
+++ b/crates/neuron/src/harness/candle.rs
@@ -617,12 +617,26 @@ impl CandleHarness {
                 })))
             }
             "qwen3_5" => {
-                // Stage 8c scaffold: config parses, model
-                // constructs, but forward bails. See
-                // `arch/qwen3_5.rs` for the open architecture work.
+                // Qwen3-Next needs a ShardedVarBuilder because its
+                // load functions use the sharded backend (so they
+                // can be reused unchanged by the future TP variant).
+                // With world_size=1 the backend falls through to
+                // the unsharded path, so there is no per-load cost.
                 let cfg: super::arch::qwen3_5::Config = serde_json::from_str(&cfg_text)
                     .context("parse Qwen3-Next (qwen3_5) config.json")?;
-                let model = super::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, vb)
+                // SAFETY: `var_builder` is unsafe because it memory-maps
+                // the safetensors files. NOTE(review): this relies on the
+                // files not being truncated or rewritten while the model
+                // is alive — confirm the harness guarantees that.
+                let sharded_vb = unsafe {
+                    candle_nn::var_builder::ShardedSafeTensors::var_builder(
+                        &safetensors_paths,
+                        dtype,
+                        &device_for_load,
+                    )
+                }
+                .context("build ShardedVarBuilder for Qwen3-Next")?;
+                let model = super::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, sharded_vb)
                     .context("build Qwen3-Next dense model")?;
                 Ok(ModelArch::Qwen3_5Dense(model))
             }