From 07c44d5db1f4d180934c0818e6998e35f6827abc Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Wed, 20 May 2026 16:18:52 +0300 Subject: [PATCH] fix(qwen3_5): nested rope_parameters + partial_rotary_factor=0.25 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two interlocked bugs surfaced while trying to load Qwen/Qwen3.5-0.8B (and the same applies to Qwen/Qwen3.6-27B): 1. Qwen3-Next config.json does NOT have a top-level `rope_theta`. It lives inside `rope_parameters: { rope_theta, partial_rotary_factor, rope_type, mrope_section, mrope_interleaved }`. Our TextConfig declared `rope_theta` as a non-optional top-level field, so the deserializer bailed with the misleading "missing field `rope_theta` at line 74 col 5". Replaced with a nested `RopeParameters` struct that mirrors the upstream shape. Defaults are conservative (rope_theta=10000, partial_rotary_factor=1.0) so a partial block (one that omits either key) degrades to standard full-rotation RoPE rather than failing. The `rope_parameters` field itself carries no default, so a config with no such block at all is still rejected with a "missing field" error. 2. `partial_rotary_factor: 0.25` means only `head_dim * 0.25 = 64` of the 256 head_dim values get RoPE applied — the rest pass through unchanged. Our RotaryEmbedding was building the inv_freq table for the full head_dim and rotating everything. Silently wrong for every full-attention layer. `RotaryEmbedding` now derives `rotary_dim` from `head_dim * partial_rotary_factor`, builds its cos/sin tables at that smaller size, and in `apply()` splits q/k into (rotate, pass) on the last dim, only `rope_slow`-rotates the rotate slice, and re-concatenates. Mirrors the reference Python's `apply_rotary_pos_emb` exactly for the non-trivial `partial_rotary_factor` case. Tests updated: config-deserialise fixture uses the real `rope_parameters` shape (matching the Qwen3.6-27B and Qwen3.5-0.8B configs). The linear-attention forward-smoke test was already using full rotation which still works; just shifted to the nested struct. 
After this, the load that previously failed at "parse Qwen3-Next (qwen3_5) config.json: missing field rope_theta" should reach the actual safetensors materialisation step. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/harness/arch/qwen3_5/linear_attn.rs | 8 ++- crates/neuron/src/harness/arch/qwen3_5/mod.rs | 52 +++++++++++++- .../neuron/src/harness/arch/qwen3_5/rope.rs | 69 ++++++++++++++++--- 3 files changed, 114 insertions(+), 15 deletions(-) diff --git a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs index ea9417b..3ce654d 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs @@ -60,6 +60,8 @@ use candle_core::{IndexOp, Module, Tensor}; use candle_nn::Linear; use candle_nn::var_builder::ShardedVarBuilder; +#[cfg(test)] +use super::RopeParameters; use super::TextConfig; use super::rmsnorm::{Qwen3_5RmsNormGated, l2norm}; @@ -475,7 +477,11 @@ mod tests { num_key_value_heads: 1, head_dim: 4, max_position_embeddings: 32, - rope_theta: 10000.0, + rope_parameters: RopeParameters { + rope_theta: 10000.0, + partial_rotary_factor: 1.0, + rope_type: None, + }, rms_norm_eps: 1e-6, tie_word_embeddings: false, attn_output_gate: true, diff --git a/crates/neuron/src/harness/arch/qwen3_5/mod.rs b/crates/neuron/src/harness/arch/qwen3_5/mod.rs index 383d998..467b7c0 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/mod.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/mod.rs @@ -114,7 +114,12 @@ pub struct TextConfig { pub num_key_value_heads: usize, pub head_dim: usize, pub max_position_embeddings: usize, - pub rope_theta: f64, + /// Nested RoPE settings. Qwen3-Next puts `rope_theta` and + /// `partial_rotary_factor` inside this block rather than at the + /// top level — important because the partial rotary means only + /// `head_dim * partial_rotary_factor` dims get RoPE applied (the + /// rest pass through unchanged). 
+ pub rope_parameters: RopeParameters, pub rms_norm_eps: f64, #[serde(default)] pub tie_word_embeddings: bool, @@ -170,6 +175,37 @@ fn default_hidden_act() -> String { "silu".into() } +/// Nested `rope_parameters` block from a Qwen3-Next `config.json`. +/// `mrope_section` and `mrope_interleaved` are accepted because serde +/// skips unknown keys unless `deny_unknown_fields` is set, but ignored — we treat +/// MRoPE as plain RoPE for text-only inference (the three position +/// grids carry identical ids when there's no vision input, so the +/// interleaving is a no-op). +#[derive(Debug, Clone, Deserialize)] +pub struct RopeParameters { + /// Base for the inverse-frequency computation. Qwen3.6: 10_000_000. + #[serde(default = "default_rope_theta")] + pub rope_theta: f64, + /// Fraction of `head_dim` that gets the rotation applied. The + /// remaining `head_dim * (1 - partial_rotary_factor)` dims pass + /// through unchanged. Qwen3.6 / Qwen3.5: 0.25. + #[serde(default = "default_partial_rotary_factor")] + pub partial_rotary_factor: f32, + /// `"default"` for the standard inv_freq RoPE; other values (e.g. + /// `"linear"`, `"dynamic"`) are upstream-supported but not yet + /// implemented here. + #[serde(default)] + pub rope_type: Option, +} + +fn default_rope_theta() -> f64 { + 10_000.0 +} + +fn default_partial_rotary_factor() -> f32 { + 1.0 +} + /// Qwen3-Next base transformer (embedding + decoder stack + final /// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can /// also build on it later — for now only `Qwen3_5ForCausalLM` is the @@ -304,7 +340,9 @@ mod tests { /// Confirms we can deserialise the real upstream config shape. /// Sample taken from `Qwen/Qwen3.6-27B/config.json`, trimmed to - /// the fields the architecture cares about. + /// the fields the architecture cares about. Note `rope_theta` and + /// `partial_rotary_factor` are nested under `rope_parameters` — + /// Qwen3-Next does NOT have a top-level `rope_theta`. 
#[test] fn config_deserialises_the_real_qwen3_6_shape() { let raw = r#"{ @@ -321,7 +359,13 @@ mod tests { "num_key_value_heads": 8, "head_dim": 256, "max_position_embeddings": 32768, - "rope_theta": 5000000.0, + "rope_parameters": { + "mrope_interleaved": true, + "mrope_section": [11, 11, 10], + "partial_rotary_factor": 0.25, + "rope_theta": 10000000, + "rope_type": "default" + }, "rms_norm_eps": 1e-6, "tie_word_embeddings": false, "attn_output_gate": true, @@ -339,5 +383,7 @@ mod tests { assert!(cfg.text_config.attn_output_gate); assert_eq!(cfg.text_config.full_attention_interval, Some(4)); assert_eq!(cfg.text_config.layer_types.len(), 4); + assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0); + assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6); } } diff --git a/crates/neuron/src/harness/arch/qwen3_5/rope.rs b/crates/neuron/src/harness/arch/qwen3_5/rope.rs index ed72b12..3bc9cfc 100644 --- a/crates/neuron/src/harness/arch/qwen3_5/rope.rs +++ b/crates/neuron/src/harness/arch/qwen3_5/rope.rs @@ -21,15 +21,36 @@ use super::TextConfig; pub struct RotaryEmbedding { sin: Tensor, cos: Tensor, + /// Number of dims at the head's leading edge that the rotation + /// covers. The remaining `head_dim - rotary_dim` dims pass through + /// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so + /// for `head_dim = 256` only 64 dims rotate. 
+ rotary_dim: usize, + head_dim: usize, } impl RotaryEmbedding { pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result { - let dim = cfg.head_dim; + let head_dim = cfg.head_dim; + let rope = &cfg.rope_parameters; + let rotary_dim = (head_dim as f32 * rope.partial_rotary_factor) as usize; + if !rotary_dim.is_multiple_of(2) { + anyhow::bail!( + "rotary_dim = head_dim * partial_rotary_factor = {head_dim} * {} = {rotary_dim} \ + must be even (cos/sin are paired)", + rope.partial_rotary_factor + ); + } + if rotary_dim == 0 { + anyhow::bail!( + "rotary_dim = 0 (partial_rotary_factor = {} too small)", + rope.partial_rotary_factor + ); + } let max_seq_len = cfg.max_position_embeddings; - let inv_freq: Vec = (0..dim) + let inv_freq: Vec = (0..rotary_dim) .step_by(2) - .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32) .collect(); let n = inv_freq.len(); let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?; @@ -40,6 +61,8 @@ impl RotaryEmbedding { Ok(Self { sin: freqs.sin()?.to_dtype(dtype)?, cos: freqs.cos()?.to_dtype(dtype)?, + rotary_dim, + head_dim, }) } @@ -47,21 +70,45 @@ impl RotaryEmbedding { /// /// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index /// into the cached cos/sin table — the position of the first token - /// in the current step. `candle_nn::rotary_emb::rope_slow` does - /// the GLM-style `x*cos + rotate_half(x)*sin` rotation and - /// internally `cat`s cos/sin with themselves along the last dim, - /// so we hand it the `(seq_len, head_dim/2)` slice it expects. + /// in the current step. + /// + /// When `rotary_dim < head_dim` the rotation is applied only to the + /// first `rotary_dim` dims of each head; the tail passes through + /// unchanged (matches the reference Python's + /// `apply_rotary_pos_emb` with non-trivial `partial_rotary_factor`). 
pub fn apply( &self, q: &Tensor, k: &Tensor, offset: usize, ) -> candle_core::Result<(Tensor, Tensor)> { - let (_, _, seq_len, _) = q.dims4()?; + let (_, _, seq_len, head_dim_in) = q.dims4()?; + debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch"); let cos = self.cos.narrow(0, offset, seq_len)?; let sin = self.sin.narrow(0, offset, seq_len)?; - let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?; - let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?; - Ok((q_embed, k_embed)) + if self.rotary_dim == self.head_dim { + // Full rotation. + let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } else { + // Partial rotation: narrow → rotate → cat the untouched tail. + let tail = self.head_dim - self.rotary_dim; + let q_rot = q + .narrow(candle_core::D::Minus1, 0, self.rotary_dim)? + .contiguous()?; + let q_pass = q.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?; + let k_rot = k + .narrow(candle_core::D::Minus1, 0, self.rotary_dim)? + .contiguous()?; + let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?; + let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, &cos, &sin)?; + let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, &cos, &sin)?; + let q_embed = + Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?; + let k_embed = + Tensor::cat(&[&k_rotated, &k_pass.contiguous()?], candle_core::D::Minus1)?; + Ok((q_embed, k_embed)) + } } }