feat(stage-8b): Llama + Qwen3 MoE families on the candle harness

Broadens the single-GPU dense and quantized paths to cover two more
model families, Llama and Qwen3 MoE, both already shipped by
candle-transformers. TP for these is a separate stage (each family
would need its own tp_*.rs mirroring tp_qwen3.rs).

`ModelArch` gains four variants:
- LlamaDense (boxed; wraps Llama, an inline Cache, and the config
  needed to rebuild the cache, since
  candle_transformers::models::llama::Cache has no reset)
- LlamaQuantized
  (candle_transformers::models::quantized_llama::ModelWeights)
- Qwen3MoeDense
  (candle_transformers::models::qwen3_moe::ModelForCausalLM)
- Qwen3MoeQuantized
  (candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
  takes an explicit compute dtype, and the loader passes F16 for best
  consumer-GPU throughput)

The dispatch is method-based now:
- `ModelArch::forward(&mut self, input, offset) -> Result<Tensor>`
  with a shared `squeeze_to_vocab` normalising shape differences
  (qwen3 returns [B,1,V]; quantized_qwen3 returns [B,V]; new families
  may differ again — the helper handles all of them).
- `ModelArch::clear_kv_cache(&mut self) -> Result<()>`. Llama needs
  a Cache rebuild because its Cache has no in-place reset; the new
  `LlamaDense` wrapper holds the bits needed to do it.
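
To illustrate the shape normalisation described above, here is a minimal
sketch that works on plain shape vectors instead of candle tensors. The
function name `squeeze_shape_to_vocab` is hypothetical and exists only
for this example; the real helper operates on a `Tensor` and calls
`squeeze(0)`.

```rust
/// Shape-only stand-in for `squeeze_to_vocab` (hypothetical helper):
/// drops leading singleton dims until the shape is rank-1, and errors
/// on any non-singleton leading dim, which would mean a batched forward.
fn squeeze_shape_to_vocab(dims: &[usize]) -> Result<Vec<usize>, String> {
    let mut dims = dims.to_vec();
    while dims.len() > 1 {
        if dims[0] != 1 {
            return Err(format!(
                "logits expected to start with a singleton dim, got shape {:?}",
                dims
            ));
        }
        // The tensor version calls `squeeze(0)` at this point.
        dims.remove(0);
    }
    Ok(dims)
}
```

Under these assumptions, `[1, 1, V]`, `[1, V]` and `[V]` all normalise to
`[V]`, while a shape such as `[2, 1, V]` is rejected.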

`run_inference` / `run_inference_streaming` collapse to a single
dispatch path: no more per-variant match arms in the hot loop, and
new architectures pick up streaming + non-streaming without touching
either loop. They only need a `ModelArch` variant and a loader branch.
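
The collapsed loop can be sketched with toy models in place of the
candle ones. Everything here (`InPlaceModel`, `RebuildModel`, `Arch`,
`step`, `generate`) is a hypothetical stand-in; the fake "next token" is
just the cache length, chosen so that a stale cache shows up as an
error.

```rust
/// Toy model whose cache can be cleared in place (hypothetical).
struct InPlaceModel {
    cache: Vec<u32>,
}

/// Toy model whose cache must be rebuilt from stored settings
/// (hypothetical), loosely mirroring the `LlamaDense` wrapper.
struct RebuildModel {
    cache: Vec<u32>,
    capacity: usize,
}

enum Arch {
    InPlace(InPlaceModel),
    Rebuild(Box<RebuildModel>),
}

/// Append `input` to `cache` and return a fake "next token".
/// Fails when `offset` disagrees with the cache length (stale state).
fn step(cache: &mut Vec<u32>, input: &[u32], offset: usize) -> Result<u32, String> {
    if offset != cache.len() {
        return Err(format!("offset {} but cache holds {}", offset, cache.len()));
    }
    cache.extend_from_slice(input);
    Ok(cache.len() as u32)
}

impl Arch {
    fn forward(&mut self, input: &[u32], offset: usize) -> Result<u32, String> {
        match self {
            Arch::InPlace(m) => step(&mut m.cache, input, offset),
            Arch::Rebuild(m) => step(&mut m.cache, input, offset),
        }
    }

    fn clear_kv_cache(&mut self) -> Result<(), String> {
        match self {
            Arch::InPlace(m) => m.cache.clear(),
            Arch::Rebuild(m) => m.cache = Vec::with_capacity(m.capacity),
        }
        Ok(())
    }
}

/// The generation loop never matches on the variant: the prefill and
/// every decode step go through the two `Arch` methods.
fn generate(
    arch: &mut Arch,
    prompt: &[u32],
    max_new: usize,
    eos: Option<u32>,
) -> Result<(Vec<u32>, String), String> {
    let mut generated = Vec::new();
    arch.clear_kv_cache()?;
    let mut next = arch.forward(prompt, 0)?;
    if Some(next) == eos {
        return Ok((generated, "stop".into()));
    }
    generated.push(next);
    for index in 0..max_new.saturating_sub(1) {
        next = arch.forward(&[next], prompt.len() + index)?;
        if Some(next) == eos {
            return Ok((generated, "stop".into()));
        }
        generated.push(next);
    }
    Ok((generated, "length".into()))
}
```

In this sketch, adding a third toy family means one new enum variant and
one new arm in each of the two methods; `generate` stays unchanged.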

DENSE_SUPPORTED_MODEL_TYPES is now ["llama", "qwen3", "qwen3_moe"].
GGUF arch switch grows "qwen3moe" + "llama" branches (qwen3moe with
no underscore matches llama.cpp's general.architecture convention).
Stage 8a's diagnostic auto-reports the new supported set.
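
The two naming schemes are easy to mix up, so here is a small sketch of
the mapping. The `Family` enum and both function names are hypothetical
and not part of the harness; only the string values come from this
commit.

```rust
/// Hypothetical family tag used only for this illustration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Family {
    Llama,
    Qwen3,
    Qwen3Moe,
}

/// Dense path: keyed on `config.json#/model_type` (underscore form).
fn family_from_model_type(model_type: &str) -> Option<Family> {
    match model_type {
        "llama" => Some(Family::Llama),
        "qwen3" => Some(Family::Qwen3),
        "qwen3_moe" => Some(Family::Qwen3Moe),
        _ => None,
    }
}

/// Quantized path: keyed on the GGUF `general.architecture` value,
/// which follows the llama.cpp spelling (no underscore).
fn family_from_gguf_arch(architecture: &str) -> Option<Family> {
    match architecture {
        "llama" => Some(Family::Llama),
        "qwen3" => Some(Family::Qwen3),
        "qwen3moe" => Some(Family::Qwen3Moe),
        _ => None,
    }
}
```

Note that `"qwen3_moe"` is only valid on the dense side and `"qwen3moe"`
only on the GGUF side; each spelling is rejected by the other lookup.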

The `LlamaDense` variant is boxed because the wrapper's inline Cache
+ Config makes it 544 bytes vs ~300 for everything else
(clippy::large_enum_variant).
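
The effect of boxing can be checked with `std::mem::size_of`. The sketch
below uses placeholder payloads (`SmallWeights`, `LargeWrapper`) whose
byte counts are illustrative assumptions, not measurements of the real
candle types.

```rust
use std::mem::size_of;

/// Placeholder for an ordinary variant payload (size is an assumption).
#[allow(dead_code)]
struct SmallWeights([u8; 64]);

/// Placeholder for a wrapper with inline cache and config
/// (size is an assumption).
#[allow(dead_code)]
struct LargeWrapper([u8; 544]);

/// Every value of this enum is at least as large as `LargeWrapper`.
#[allow(dead_code)]
enum Inline {
    Small(SmallWeights),
    Large(LargeWrapper),
}

/// Boxing moves the large payload to the heap; the variant itself
/// only stores a pointer.
#[allow(dead_code)]
enum Boxed {
    Small(SmallWeights),
    Large(Box<LargeWrapper>),
}

/// Returns `(inline_size, boxed_size)` in bytes.
fn enum_sizes() -> (usize, usize) {
    (size_of::<Inline>(), size_of::<Boxed>())
}
```

An enum is sized for its largest variant, so the unboxed form pays the
large size even when it holds a small variant. The trade-off is one
extra pointer indirection when the boxed variant is used.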

Verified: cargo test --workspace passes 66 tests; cargo clippy CPU
and `--features cuda` both clean (the cuda check ran inside the
locally-built `neuron-build-local` container with the math_functions.h
patch applied).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 08:36:22 +03:00
parent 9e31d8deca
commit c6022aa6b9


@@ -14,8 +14,12 @@ use candle_core::quantized::gguf_file;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama as llama_dense;
use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaWeights;
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
use candle_transformers::models::qwen3 as qwen3_dense;
use candle_transformers::models::qwen3_moe as qwen3_moe_dense;
use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec};
use cortex_core::openai::{
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse,
@@ -100,21 +104,109 @@ pub struct TpLoadedModel {
pub leader_model: Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>,
}
/// Architecture-specific weights.
/// Architecture-specific weights. Each variant covers one (family,
/// source-format) pair; the dense variants take the safetensors path
/// and the `Quantized*` variants take the GGUF path.
///
/// - `Qwen3Quantized` — GGUF source, pre-quantized. Single-GPU only;
/// TP sharding pre-quantized super-blocks is intractable. Stays the
/// default for small models loaded via `Qwen/Qwen3-*-GGUF` and
/// `unsloth/Qwen3-*-GGUF` repos.
/// - `Qwen3Dense` — bf16 safetensors source. The path that supports
/// TP (Stage 7b-ii+) because slicing dense weights by row/column
/// under safetensors is mechanical. Used when `ModelSpec.quant` is
/// None; intended target for Qwen3.6-27B etc.
///
/// Stage 8 broadens this to additional families.
/// TP currently only works through `Qwen3Dense` (see `tp_qwen3.rs`);
/// every other variant is single-GPU. Quantized variants can't shard
/// across GPUs at all — slicing GGUF super-blocks is intractable —
/// and the new dense families (Llama, Qwen3 MoE) lack their own
/// TP-aware modules yet.
pub enum ModelArch {
// Qwen3 family
Qwen3Quantized(QuantizedQwen3Weights),
Qwen3Dense(qwen3_dense::ModelForCausalLM),
Qwen3MoeQuantized(GGUFQWenMoE),
Qwen3MoeDense(qwen3_moe_dense::ModelForCausalLM),
// Llama family (covers Llama 1/2/3/3.1/3.3). Boxed because the
// wrapper carries an inline Cache + Config — without indirection
// the enum's `LlamaDense` variant is several hundred bytes larger
// than the others (clippy::large_enum_variant).
LlamaQuantized(QuantizedLlamaWeights),
LlamaDense(Box<LlamaDense>),
}
impl ModelArch {
/// One forward step on this arch with the rank-1 vocab logits
/// extracted. Hides per-family shape differences (some return
/// `[B, V]`, others `[B, 1, V]`) — every caller gets a `[V]`
/// tensor ready for sampling.
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
let raw = match self {
ModelArch::Qwen3Quantized(m) => m.forward(input, offset)?,
ModelArch::Qwen3Dense(m) => m.forward(input, offset)?,
ModelArch::Qwen3MoeQuantized(m) => m.forward(input, offset)?,
ModelArch::Qwen3MoeDense(m) => m.forward(input, offset)?,
ModelArch::LlamaQuantized(m) => m.forward(input, offset)?,
ModelArch::LlamaDense(m) => m.forward(input, offset)?,
};
squeeze_to_vocab(&raw)
}
/// Reset the KV cache before each new request so we don't attend
/// over a previous request's tokens. Some architectures have an
/// in-place reset; Llama needs a Cache rebuild (held inline in
/// the wrapper).
pub fn clear_kv_cache(&mut self) -> Result<()> {
match self {
ModelArch::Qwen3Quantized(_) => Ok(()), /* keeps cache by design;
* forward() handles offset */
ModelArch::Qwen3Dense(m) => {
m.clear_kv_cache();
Ok(())
}
ModelArch::Qwen3MoeQuantized(_) => Ok(()),
ModelArch::Qwen3MoeDense(m) => {
m.clear_kv_cache();
Ok(())
}
ModelArch::LlamaQuantized(_) => Ok(()),
ModelArch::LlamaDense(m) => m.clear_kv_cache(),
}
}
}
/// Squeeze any leading singleton dims off the logits tensor so the
/// caller gets a rank-1 `[vocab_size]` slice ready for sampling. Bails
/// on a non-singleton leading dim (would mean a batched forward, which
/// no caller emits today).
fn squeeze_to_vocab(t: &Tensor) -> Result<Tensor> {
let mut t = t.clone();
while t.dims().len() > 1 {
if t.dims()[0] != 1 {
anyhow::bail!(
"logits expected to start with a singleton dim, got shape {:?}",
t.dims()
);
}
t = t.squeeze(0)?;
}
Ok(t)
}
/// Llama dense wrapper. Bundles candle's `Llama` model with its
/// externally-managed `Cache` plus enough config to rebuild the
/// cache on `clear_kv_cache` (Llama's Cache doesn't expose a reset).
pub struct LlamaDense {
model: llama_dense::Llama,
cache: llama_dense::Cache,
config: llama_dense::Config,
dtype: DType,
device: Device,
}
impl LlamaDense {
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
Ok(self.model.forward(input, offset, &mut self.cache)?)
}
pub fn clear_kv_cache(&mut self) -> Result<()> {
self.cache = llama_dense::Cache::new(true, self.dtype, &self.config, &self.device)
.context("rebuild Llama Cache for new request")?;
Ok(())
}
}
/// Repetition penalty applied to recently-generated tokens before
@@ -130,11 +222,10 @@ const REPEAT_LAST_N: usize = 64;
/// Architectures the dense safetensors path can construct. Keep
/// alphabetical; one entry per supported `config.json#/model_type`
/// value. New entries land alongside a new `ModelArch` variant + a new
/// dispatch branch in `load_arch_dense` / `run_inference` /
/// `run_inference_streaming` (plus, for TP, a parallel pattern in
/// `tp_qwen3.rs`).
const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["qwen3"];
/// value. New entries land alongside a new `ModelArch` variant + a
/// dispatch branch in `load_arch_dense` (plus, for TP, a parallel
/// pattern in `tp_qwen3.rs`).
const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["llama", "qwen3", "qwen3_moe"];
/// Pre-flight check the operator's `config.json` against the set of
/// architectures the dense path actually knows how to build. Surfaces
@@ -362,6 +453,9 @@ impl CandleHarness {
.unwrap_or_default();
tracing::info!(architecture = %architecture, "GGUF architecture");
// The `general.architecture` GGUF metadata key follows
// llama.cpp conventions (lowercase, no underscores in some
// cases) — `qwen3moe`, not `qwen3_moe`.
match architecture.as_str() {
"qwen3" => {
let weights =
@@ -369,8 +463,25 @@ impl CandleHarness {
.map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?;
Ok(ModelArch::Qwen3Quantized(weights))
}
"qwen3moe" => {
// GGUFQWenMoE takes an explicit compute dtype
// alongside the device — F16 matches the GGUF
// weights' typical accumulation precision and
// gives the best tokens/sec on consumer cards.
let weights =
GGUFQWenMoE::from_gguf(content, &mut file, &device_for_load, DType::F16)
.map_err(|e| anyhow::anyhow!("from_gguf qwen3_moe: {e}"))?;
Ok(ModelArch::Qwen3MoeQuantized(weights))
}
"llama" => {
let weights =
QuantizedLlamaWeights::from_gguf(content, &mut file, &device_for_load)
.map_err(|e| anyhow::anyhow!("from_gguf llama: {e}"))?;
Ok(ModelArch::LlamaQuantized(weights))
}
other => anyhow::bail!(
"unsupported GGUF architecture '{other}'; quantized path only supports qwen3"
"unsupported GGUF architecture '{other}'; quantized path supports \
qwen3, qwen3moe, llama"
),
}
})
@@ -396,32 +507,84 @@ impl CandleHarness {
let model_id_for_log = spec.model_id.clone();
let arch = tokio::task::spawn_blocking(move || -> Result<ModelArch> {
tracing::info!(
model = %model_id_for_log,
shards = safetensors_paths.len(),
"loading dense Qwen3 from safetensors"
);
let cfg_text = std::fs::read_to_string(&config_path).context("read config.json")?;
check_dense_config_supported(&cfg_text, &model_id_for_log)?;
let cfg: qwen3_dense::Config =
serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
// Peek at model_type to choose the family before the
// typed deserialize — each family has its own Config.
let model_type = serde_json::from_str::<serde_json::Value>(&cfg_text)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
tracing::info!(
model = %model_id_for_log,
model_type = %model_type,
shards = safetensors_paths.len(),
"loading dense model from safetensors"
);
// bf16 is the canonical Qwen3 distribution dtype. CUDA
// devices on Ada+ support it; Ampere also supports bf16
// natively. CPU candle handles bf16 via emulation.
// bf16 is the canonical distribution dtype for Qwen3 /
// Llama 3 / Qwen3 MoE. CUDA on Ada+ has hardware bf16;
// Ampere has it too. CPU emulates.
let dtype = DType::BF16;
// SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files;
// mutation of the underlying files by another process while
// we hold the mapping is UB. We trust that nothing else on
// the host modifies the HF cache files during a model's
// lifetime (hf-hub itself is immutable-by-design).
// mutation by another process while we hold the mapping is
// UB. We trust the HF cache is immutable-by-design.
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&safetensors_paths, dtype, &device_for_load)
.context("build VarBuilder over safetensors")?
};
let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
.map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
Ok(ModelArch::Qwen3Dense(model))
match model_type.as_str() {
"qwen3" => {
let cfg: qwen3_dense::Config =
serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
.map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
Ok(ModelArch::Qwen3Dense(model))
}
"qwen3_moe" => {
let cfg: qwen3_moe_dense::Config =
serde_json::from_str(&cfg_text).context("parse Qwen3 MoE config.json")?;
let model = qwen3_moe_dense::ModelForCausalLM::new(&cfg, vb)
.map_err(|e| anyhow::anyhow!("build Qwen3 MoE dense model: {e}"))?;
Ok(ModelArch::Qwen3MoeDense(model))
}
"llama" => {
let cfg: llama_dense::LlamaConfig =
serde_json::from_str(&cfg_text).context("parse Llama config.json")?;
// Llama has multiple sub-variants (Llama 1 has no
// GQA; Llama 3 does). `LlamaConfig::into_config`
// resolves the right shape; the `use_flash_attn`
// arg defaults to false — the flash kernel is a
// separate feature flag and uses extra VRAM.
let config = cfg.into_config(false);
let cache = llama_dense::Cache::new(true, dtype, &config, &device_for_load)
.context("build Llama Cache")?;
let model = llama_dense::Llama::load(vb, &config)
.map_err(|e| anyhow::anyhow!("build Llama dense model: {e}"))?;
Ok(ModelArch::LlamaDense(Box::new(LlamaDense {
model,
cache,
config,
dtype,
device: device_for_load,
})))
}
other => {
// Defensive: `check_dense_config_supported` already
// gated on the supported set, so this branch is
// unreachable unless that list and the match here
// drift apart.
anyhow::bail!(
"unrouted supported model_type '{other}' — \
DENSE_SUPPORTED_MODEL_TYPES and load_arch_dense \
must stay in sync"
)
}
}
})
.await
.context("blocking dense load task panicked")??;
@@ -1375,25 +1538,10 @@ fn run_inference(
let mut generated: Vec<u32> = Vec::new();
let mut next_token = match arch {
ModelArch::Qwen3Quantized(model) => {
model.clear_kv_cache();
let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
sample_with_penalty(&logits, &generated, &mut logits_processor)?
}
ModelArch::Qwen3Dense(model) => {
model.clear_kv_cache();
let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
// qwen3::ModelForCausalLM::forward returns [B, 1, V] —
// no final squeeze on the dense path, unlike the quantized
// variant which returns [B, V]. Strip both batch and seq
// dims to get the rank-1 logits LogitsProcessor expects.
let logits = model.forward(&input, 0)?.squeeze(0)?.squeeze(0)?;
sample_with_penalty(&logits, &generated, &mut logits_processor)?
}
};
arch.clear_kv_cache()?;
let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
let logits = arch.forward(&input, 0)?;
let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
@@ -1401,23 +1549,9 @@ fn run_inference(
generated.push(next_token);
for index in 0..max_new.saturating_sub(1) {
next_token = match arch {
ModelArch::Qwen3Quantized(model) => {
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
sample_with_penalty(&logits, &generated, &mut logits_processor)?
}
ModelArch::Qwen3Dense(model) => {
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
// Dense returns [B, 1, V]; strip both leading dims.
let logits = model
.forward(&input, prompt_tokens.len() + index)?
.squeeze(0)?
.squeeze(0)?;
sample_with_penalty(&logits, &generated, &mut logits_processor)?
}
};
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
let logits = arch.forward(&input, prompt_tokens.len() + index)?;
next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
}
@@ -1465,22 +1599,10 @@ fn run_inference_streaming(
let mut decoded_prefix = String::new();
let mut finish_reason = "length".to_string();
let mut next_token = match arch {
ModelArch::Qwen3Quantized(model) => {
model.clear_kv_cache();
let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
}
ModelArch::Qwen3Dense(model) => {
model.clear_kv_cache();
let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
}
};
arch.clear_kv_cache()?;
let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
let logits = arch.forward(&input, 0)?;
let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result<bool> {
let full = tokenizer
@@ -1521,23 +1643,9 @@ fn run_inference_streaming(
}
for index in 0..max_new.saturating_sub(1) {
next_token = match arch {
ModelArch::Qwen3Quantized(model) => {
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
}
ModelArch::Qwen3Dense(model) => {
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
// Dense returns [B, 1, V]; strip both leading dims.
let logits = model
.forward(&input, prompt_tokens.len() + index)?
.squeeze(0)?
.squeeze(0)?;
sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
}
};
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
let logits = arch.forward(&input, prompt_tokens.len() + index)?;
next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;