From c6022aa6b969a583bc5a8215fa3c72ad669ace87 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Wed, 20 May 2026 08:36:22 +0300 Subject: [PATCH] feat(stage-8b): Llama + Qwen3 MoE families on the candle harness MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Broadens the single-GPU dense and quantized paths to cover two more architecture families already shipped by candle-transformers: Llama and Qwen3 MoE. TP for these is a separate stage (each family would need its own tp_*.rs mirroring tp_qwen3.rs). `ModelArch` gains four variants: - LlamaDense (boxed — wraps Llama + an inline Cache + the config it takes to rebuild the cache, since candle_transformers::models::llama::Cache has no reset) - LlamaQuantized (candle_transformers::models::quantized_llama) - Qwen3MoeDense (candle_transformers::models::qwen3_moe::ModelForCausalLM) - Qwen3MoeQuantized (candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE — takes an explicit compute dtype; F16 by default for best consumer-GPU throughput) The dispatch is method-based now: - `ModelArch::forward(&mut self, input, offset) -> Result<Tensor>` with a shared `squeeze_to_vocab` normalising shape differences (qwen3 returns [B,1,V]; quantized_qwen3 returns [B,V]; new families may differ again — the helper handles all of them). - `ModelArch::clear_kv_cache(&mut self) -> Result<()>`. Llama needs a Cache rebuild because its Cache has no in-place reset; the new `LlamaDense` wrapper holds the bits needed to do it. `run_inference` / `run_inference_streaming` collapse to a single dispatch path: no more per-variant match arms in the hot loop, and new architectures pick up streaming + non-streaming for free with zero changes outside `ModelArch`. DENSE_SUPPORTED_MODEL_TYPES is now ["llama", "qwen3", "qwen3_moe"]. GGUF arch switch grows "qwen3moe" + "llama" branches (qwen3moe with no underscore matches llama.cpp's general.architecture convention). Stage 8a's diagnostic auto-reports the new supported set. 
The `LlamaDense` variant is boxed because the wrapper's inline Cache + Config makes it 544 bytes vs ~300 for everything else (clippy::large_enum_variant). Verified: cargo test --workspace passes 66 tests; cargo clippy CPU and `--features cuda` both clean (the cuda check ran inside the locally-built `neuron-build-local` container with the math_functions.h patch applied). Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/neuron/src/harness/candle.rs | 314 +++++++++++++++++++--------- 1 file changed, 211 insertions(+), 103 deletions(-) diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index ea860ac..e8cf95d 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -14,8 +14,12 @@ use candle_core::quantized::gguf_file; use candle_core::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::generation::{LogitsProcessor, Sampling}; +use candle_transformers::models::llama as llama_dense; +use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaWeights; use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights; +use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE; use candle_transformers::models::qwen3 as qwen3_dense; +use candle_transformers::models::qwen3_moe as qwen3_moe_dense; use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec}; use cortex_core::openai::{ ChatCompletionChoice, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, @@ -100,21 +104,109 @@ pub struct TpLoadedModel { pub leader_model: Arc>, } -/// Architecture-specific weights. +/// Architecture-specific weights. Each variant covers one (family, +/// source-format) pair; the dense variants take the safetensors path +/// and the `Quantized*` variants take the GGUF path. /// -/// - `Qwen3Quantized` — GGUF source, pre-quantized. Single-GPU only; -/// TP sharding pre-quantized super-blocks is intractable. 
Stays the -/// default for small models loaded via `Qwen/Qwen3-*-GGUF` and -/// `unsloth/Qwen3-*-GGUF` repos. -/// - `Qwen3Dense` — bf16 safetensors source. The path that supports -/// TP (Stage 7b-ii+) because slicing dense weights by row/column -/// under safetensors is mechanical. Used when `ModelSpec.quant` is -/// None; intended target for Qwen3.6-27B etc. -/// -/// Stage 8 broadens this to additional families. +/// TP currently only works through `Qwen3Dense` (see `tp_qwen3.rs`); +/// every other variant is single-GPU. Quantized variants can't shard +/// across GPUs at all — slicing GGUF super-blocks is intractable — +/// and the new dense families (Llama, Qwen3 MoE) lack their own +/// TP-aware modules yet. pub enum ModelArch { + // Qwen3 family Qwen3Quantized(QuantizedQwen3Weights), Qwen3Dense(qwen3_dense::ModelForCausalLM), + Qwen3MoeQuantized(GGUFQWenMoE), + Qwen3MoeDense(qwen3_moe_dense::ModelForCausalLM), + + // Llama family (covers Llama 1/2/3/3.1/3.3). Boxed because the + // wrapper carries an inline Cache + Config — without indirection + // the enum's `LlamaDense` variant is several hundred bytes larger + // than the others (clippy::large_enum_variant). + LlamaQuantized(QuantizedLlamaWeights), + LlamaDense(Box), +} + +impl ModelArch { + /// One forward step on this arch with the rank-1 vocab logits + /// extracted. Hides per-family shape differences (some return + /// `[B, V]`, others `[B, 1, V]`) — every caller gets a `[V]` + /// tensor ready for sampling. 
+ pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result { + let raw = match self { + ModelArch::Qwen3Quantized(m) => m.forward(input, offset)?, + ModelArch::Qwen3Dense(m) => m.forward(input, offset)?, + ModelArch::Qwen3MoeQuantized(m) => m.forward(input, offset)?, + ModelArch::Qwen3MoeDense(m) => m.forward(input, offset)?, + ModelArch::LlamaQuantized(m) => m.forward(input, offset)?, + ModelArch::LlamaDense(m) => m.forward(input, offset)?, + }; + squeeze_to_vocab(&raw) + } + + /// Reset the KV cache before each new request so we don't attend + /// over a previous request's tokens. Some architectures have an + /// in-place reset; Llama needs a Cache rebuild (held inline in + /// the wrapper). + pub fn clear_kv_cache(&mut self) -> Result<()> { + match self { + ModelArch::Qwen3Quantized(_) => Ok(()), /* keeps cache by design; + * forward() handles offset */ + ModelArch::Qwen3Dense(m) => { + m.clear_kv_cache(); + Ok(()) + } + ModelArch::Qwen3MoeQuantized(_) => Ok(()), + ModelArch::Qwen3MoeDense(m) => { + m.clear_kv_cache(); + Ok(()) + } + ModelArch::LlamaQuantized(_) => Ok(()), + ModelArch::LlamaDense(m) => m.clear_kv_cache(), + } + } +} + +/// Squeeze any leading singleton dims off the logits tensor so the +/// caller gets a rank-1 `[vocab_size]` slice ready for sampling. Bails +/// on a non-singleton leading dim (would mean a batched forward, which +/// no caller emits today). +fn squeeze_to_vocab(t: &Tensor) -> Result { + let mut t = t.clone(); + while t.dims().len() > 1 { + if t.dims()[0] != 1 { + anyhow::bail!( + "logits expected to start with a singleton dim, got shape {:?}", + t.dims() + ); + } + t = t.squeeze(0)?; + } + Ok(t) +} + +/// Llama dense wrapper. Bundles candle's `Llama` model with its +/// externally-managed `Cache` plus enough config to rebuild the +/// cache on `clear_kv_cache` (Llama's Cache doesn't expose a reset). 
+pub struct LlamaDense { + model: llama_dense::Llama, + cache: llama_dense::Cache, + config: llama_dense::Config, + dtype: DType, + device: Device, +} + +impl LlamaDense { + pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result { + Ok(self.model.forward(input, offset, &mut self.cache)?) + } + + pub fn clear_kv_cache(&mut self) -> Result<()> { + self.cache = llama_dense::Cache::new(true, self.dtype, &self.config, &self.device) + .context("rebuild Llama Cache for new request")?; + Ok(()) + } } /// Repetition penalty applied to recently-generated tokens before @@ -130,11 +222,10 @@ const REPEAT_LAST_N: usize = 64; /// Architectures the dense safetensors path can construct. Keep /// alphabetical; one entry per supported `config.json#/model_type` -/// value. New entries land alongside a new `ModelArch` variant + a new -/// dispatch branch in `load_arch_dense` / `run_inference` / -/// `run_inference_streaming` (plus, for TP, a parallel pattern in -/// `tp_qwen3.rs`). -const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["qwen3"]; +/// value. New entries land alongside a new `ModelArch` variant + a +/// dispatch branch in `load_arch_dense` (plus, for TP, a parallel +/// pattern in `tp_qwen3.rs`). +const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["llama", "qwen3", "qwen3_moe"]; /// Pre-flight check the operator's `config.json` against the set of /// architectures the dense path actually knows how to build. Surfaces @@ -362,6 +453,9 @@ impl CandleHarness { .unwrap_or_default(); tracing::info!(architecture = %architecture, "GGUF architecture"); + // The `general.architecture` GGUF metadata key follows + // llama.cpp conventions (lowercase, no underscores in some + // cases) — `qwen3moe`, not `qwen3_moe`. 
match architecture.as_str() { "qwen3" => { let weights = @@ -369,8 +463,25 @@ impl CandleHarness { .map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?; Ok(ModelArch::Qwen3Quantized(weights)) } + "qwen3moe" => { + // GGUFQWenMoE takes an explicit compute dtype + // alongside the device — F16 matches the GGUF + // weights' typical accumulation precision and + // gives the best tokens/sec on consumer cards. + let weights = + GGUFQWenMoE::from_gguf(content, &mut file, &device_for_load, DType::F16) + .map_err(|e| anyhow::anyhow!("from_gguf qwen3_moe: {e}"))?; + Ok(ModelArch::Qwen3MoeQuantized(weights)) + } + "llama" => { + let weights = + QuantizedLlamaWeights::from_gguf(content, &mut file, &device_for_load) + .map_err(|e| anyhow::anyhow!("from_gguf llama: {e}"))?; + Ok(ModelArch::LlamaQuantized(weights)) + } other => anyhow::bail!( - "unsupported GGUF architecture '{other}'; quantized path only supports qwen3" + "unsupported GGUF architecture '{other}'; quantized path supports \ + qwen3, qwen3moe, llama" ), } }) @@ -396,32 +507,84 @@ impl CandleHarness { let model_id_for_log = spec.model_id.clone(); let arch = tokio::task::spawn_blocking(move || -> Result { - tracing::info!( - model = %model_id_for_log, - shards = safetensors_paths.len(), - "loading dense Qwen3 from safetensors" - ); let cfg_text = std::fs::read_to_string(&config_path).context("read config.json")?; check_dense_config_supported(&cfg_text, &model_id_for_log)?; - let cfg: qwen3_dense::Config = - serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?; + // Peek at model_type to choose the family before the + // typed deserialize — each family has its own Config. 
+ let model_type = serde_json::from_str::(&cfg_text) + .ok() + .as_ref() + .and_then(|v| v.get("model_type")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + tracing::info!( + model = %model_id_for_log, + model_type = %model_type, + shards = safetensors_paths.len(), + "loading dense model from safetensors" + ); - // bf16 is the canonical Qwen3 distribution dtype. CUDA - // devices on Ada+ support it; Ampere also supports bf16 - // natively. CPU candle handles bf16 via emulation. + // bf16 is the canonical distribution dtype for Qwen3 / + // Llama 3 / Qwen3 MoE. CUDA on Ada+ has hardware bf16; + // Ampere has it too. CPU emulates. let dtype = DType::BF16; // SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files; - // mutation of the underlying files by another process while - // we hold the mapping is UB. We trust that nothing else on - // the host modifies the HF cache files during a model's - // lifetime (hf-hub itself is immutable-by-design). + // mutation by another process while we hold the mapping is + // UB. We trust the HF cache is immutable-by-design. let vb = unsafe { VarBuilder::from_mmaped_safetensors(&safetensors_paths, dtype, &device_for_load) .context("build VarBuilder over safetensors")? 
}; - let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb) - .map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?; - Ok(ModelArch::Qwen3Dense(model)) + + match model_type.as_str() { + "qwen3" => { + let cfg: qwen3_dense::Config = + serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?; + let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb) + .map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?; + Ok(ModelArch::Qwen3Dense(model)) + } + "qwen3_moe" => { + let cfg: qwen3_moe_dense::Config = + serde_json::from_str(&cfg_text).context("parse Qwen3 MoE config.json")?; + let model = qwen3_moe_dense::ModelForCausalLM::new(&cfg, vb) + .map_err(|e| anyhow::anyhow!("build Qwen3 MoE dense model: {e}"))?; + Ok(ModelArch::Qwen3MoeDense(model)) + } + "llama" => { + let cfg: llama_dense::LlamaConfig = + serde_json::from_str(&cfg_text).context("parse Llama config.json")?; + // Llama has multiple sub-variants (Llama 1 has no + // GQA; Llama 3 does). `LlamaConfig::into_config` + // resolves the right shape; the `use_flash_attn` + // arg defaults to false — the flash kernel is a + // separate feature flag and uses extra VRAM. + let config = cfg.into_config(false); + let cache = llama_dense::Cache::new(true, dtype, &config, &device_for_load) + .context("build Llama Cache")?; + let model = llama_dense::Llama::load(vb, &config) + .map_err(|e| anyhow::anyhow!("build Llama dense model: {e}"))?; + Ok(ModelArch::LlamaDense(Box::new(LlamaDense { + model, + cache, + config, + dtype, + device: device_for_load, + }))) + } + other => { + // Defensive: `check_dense_config_supported` already + // gated on the supported set, so this branch is + // unreachable unless that list and the match here + // drift apart. 
+ anyhow::bail!( + "unrouted supported model_type '{other}' — \ + DENSE_SUPPORTED_MODEL_TYPES and load_arch_dense \ + must stay in sync" + ) + } + } }) .await .context("blocking dense load task panicked")??; @@ -1375,25 +1538,10 @@ fn run_inference( let mut generated: Vec = Vec::new(); - let mut next_token = match arch { - ModelArch::Qwen3Quantized(model) => { - model.clear_kv_cache(); - let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?; - let logits = model.forward(&input, 0)?; - let logits = logits.squeeze(0)?; - sample_with_penalty(&logits, &generated, &mut logits_processor)? - } - ModelArch::Qwen3Dense(model) => { - model.clear_kv_cache(); - let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?; - // qwen3::ModelForCausalLM::forward returns [B, 1, V] — - // no final squeeze on the dense path, unlike the quantized - // variant which returns [B, V]. Strip both batch and seq - // dims to get the rank-1 logits LogitsProcessor expects. - let logits = model.forward(&input, 0)?.squeeze(0)?.squeeze(0)?; - sample_with_penalty(&logits, &generated, &mut logits_processor)? - } - }; + arch.clear_kv_cache()?; + let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?; + let logits = arch.forward(&input, 0)?; + let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?; if Some(next_token) == eos_id { return Ok((generated, "stop".into())); @@ -1401,23 +1549,9 @@ fn run_inference( generated.push(next_token); for index in 0..max_new.saturating_sub(1) { - next_token = match arch { - ModelArch::Qwen3Quantized(model) => { - let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?; - let logits = model.forward(&input, prompt_tokens.len() + index)?; - let logits = logits.squeeze(0)?; - sample_with_penalty(&logits, &generated, &mut logits_processor)? - } - ModelArch::Qwen3Dense(model) => { - let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?; - // Dense returns [B, 1, V]; strip both leading dims. 
- let logits = model - .forward(&input, prompt_tokens.len() + index)? - .squeeze(0)? - .squeeze(0)?; - sample_with_penalty(&logits, &generated, &mut logits_processor)? - } - }; + let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?; + let logits = arch.forward(&input, prompt_tokens.len() + index)?; + next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?; if Some(next_token) == eos_id { return Ok((generated, "stop".into())); } @@ -1465,22 +1599,10 @@ fn run_inference_streaming( let mut decoded_prefix = String::new(); let mut finish_reason = "length".to_string(); - let mut next_token = match arch { - ModelArch::Qwen3Quantized(model) => { - model.clear_kv_cache(); - let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?; - let logits = model.forward(&input, 0)?; - let logits = logits.squeeze(0)?; - sample_with_penalty(&logits, &all_tokens, &mut logits_processor)? - } - ModelArch::Qwen3Dense(model) => { - model.clear_kv_cache(); - let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?; - let logits = model.forward(&input, 0)?; - let logits = logits.squeeze(0)?; - sample_with_penalty(&logits, &all_tokens, &mut logits_processor)? - } - }; + arch.clear_kv_cache()?; + let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?; + let logits = arch.forward(&input, 0)?; + let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?; let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result { let full = tokenizer @@ -1521,23 +1643,9 @@ fn run_inference_streaming( } for index in 0..max_new.saturating_sub(1) { - next_token = match arch { - ModelArch::Qwen3Quantized(model) => { - let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?; - let logits = model.forward(&input, prompt_tokens.len() + index)?; - let logits = logits.squeeze(0)?; - sample_with_penalty(&logits, &all_tokens, &mut logits_processor)? 
- } - ModelArch::Qwen3Dense(model) => { - let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?; - // Dense returns [B, 1, V]; strip both leading dims. - let logits = model - .forward(&input, prompt_tokens.len() + index)? - .squeeze(0)? - .squeeze(0)?; - sample_with_penalty(&logits, &all_tokens, &mut logits_processor)? - } - }; + let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?; + let logits = arch.forward(&input, prompt_tokens.len() + index)?; + next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?; if Some(next_token) == eos_id { finish_reason = "stop".into(); break;