feat(stage-8b): Llama + Qwen3 MoE families on the candle harness
All checks were successful
CI / Format (push) Successful in 31s
build-prerelease / Resolve version stamps (push) Successful in 36s
CI / Clippy (push) Successful in 2m6s
build-prerelease / Build neuron-blackwell (push) Successful in 3m50s
build-prerelease / Build cortex binary (push) Successful in 4m54s
CI / Test (push) Successful in 4m58s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 4m43s
build-prerelease / Build neuron-ada (push) Successful in 5m8s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m52s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m50s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m0s
Broadens the single-GPU dense and quantized paths to cover two more
model families, Llama and Qwen3 MoE, both already shipped by
candle-transformers. TP for these is a separate stage (each family
would need its own tp_*.rs mirroring tp_qwen3.rs).

`ModelArch` gains four variants:
- LlamaDense (boxed; wraps Llama + an inline Cache + the config it
  takes to rebuild the cache, since candle::llama::Cache has no reset)
- LlamaQuantized (candle_transformers::models::quantized_llama)
- Qwen3MoeDense (candle::models::qwen3_moe::ModelForCausalLM)
- Qwen3MoeQuantized (candle::models::quantized_qwen3_moe::GGUFQWenMoE;
  takes an explicit compute dtype, F16 by default for best
  consumer-GPU throughput)

The dispatch is method-based now:
- `ModelArch::forward(&mut self, input, offset) -> Result<Tensor>`
  with a shared `squeeze_to_vocab` normalising shape differences
  (qwen3 returns [B,1,V]; quantized_qwen3 returns [B,V]; new families
  may differ again, and the helper strips any run of leading
  singleton dims).
- `ModelArch::clear_kv_cache(&mut self) -> Result<()>`. Llama needs a
  Cache rebuild because its Cache has no in-place reset; the new
  `LlamaDense` wrapper holds the bits needed to do it.

`run_inference` / `run_inference_streaming` collapse to a single
dispatch path: no more per-variant match arms in the hot loop, and new
architectures pick up streaming + non-streaming for free with zero
changes outside `ModelArch`.

DENSE_SUPPORTED_MODEL_TYPES is now ["llama", "qwen3", "qwen3_moe"].
GGUF arch switch grows "qwen3moe" + "llama" branches (qwen3moe with no
underscore matches llama.cpp's general.architecture convention).
Stage 8a's diagnostic auto-reports the new supported set.

The `LlamaDense` variant is boxed because the wrapper's inline Cache +
Config makes it 544 bytes vs ~300 for everything else
(clippy::large_enum_variant).

Verified: cargo test --workspace passes 66 tests; cargo clippy CPU and
`--features cuda` both clean (the cuda check ran inside the
locally-built `neuron-build-local` container with the math_functions.h
patch applied).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
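Review note on the boxing rationale: the effect of `Box` on an enum's footprint can be checked with `std::mem::size_of`. The sketch below is only an illustration. `SmallModel` and `LargeModel` are hypothetical stand-ins with made-up sizes, not candle's real types, so the byte counts will not match the 544 / ~300 figures quoted in the message.

```rust
use std::mem::size_of;

// Hypothetical payloads; the sizes are arbitrary and chosen only to
// make one variant much larger than the other.
#[allow(dead_code)]
pub struct SmallModel([u8; 64]);
#[allow(dead_code)]
pub struct LargeModel([u8; 512]);

// Every value of this enum is as large as its largest variant, which
// is the situation clippy::large_enum_variant warns about.
#[allow(dead_code)]
pub enum Unboxed {
    Small(SmallModel),
    Large(LargeModel),
}

// Boxing the large payload leaves only a pointer inline, so the enum
// shrinks to roughly the size of the next-largest variant.
#[allow(dead_code)]
pub enum Boxed {
    Small(SmallModel),
    Large(Box<LargeModel>),
}

/// Returns (size of `Unboxed`, size of `Boxed`) in bytes.
pub fn enum_sizes() -> (usize, usize) {
    (size_of::<Unboxed>(), size_of::<Boxed>())
}
```

The trade-off is one extra heap allocation and pointer hop for the boxed variant, which is negligible next to a forward pass.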
@@ -14,8 +14,12 @@ use candle_core::quantized::gguf_file;
 use candle_core::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::{LogitsProcessor, Sampling};
+use candle_transformers::models::llama as llama_dense;
+use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaWeights;
 use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
+use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
 use candle_transformers::models::qwen3 as qwen3_dense;
+use candle_transformers::models::qwen3_moe as qwen3_moe_dense;
 use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec};
 use cortex_core::openai::{
     ChatCompletionChoice, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse,
@@ -100,21 +104,109 @@ pub struct TpLoadedModel {
     pub leader_model: Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>,
 }
 
-/// Architecture-specific weights.
+/// Architecture-specific weights. Each variant covers one (family,
+/// source-format) pair; the dense variants take the safetensors path
+/// and the `Quantized*` variants take the GGUF path.
 ///
-/// - `Qwen3Quantized` — GGUF source, pre-quantized. Single-GPU only;
-/// TP sharding pre-quantized super-blocks is intractable. Stays the
-/// default for small models loaded via `Qwen/Qwen3-*-GGUF` and
-/// `unsloth/Qwen3-*-GGUF` repos.
-/// - `Qwen3Dense` — bf16 safetensors source. The path that supports
-/// TP (Stage 7b-ii+) because slicing dense weights by row/column
-/// under safetensors is mechanical. Used when `ModelSpec.quant` is
-/// None; intended target for Qwen3.6-27B etc.
-///
-/// Stage 8 broadens this to additional families.
+/// TP currently only works through `Qwen3Dense` (see `tp_qwen3.rs`);
+/// every other variant is single-GPU. Quantized variants can't shard
+/// across GPUs at all — slicing GGUF super-blocks is intractable —
+/// and the new dense families (Llama, Qwen3 MoE) lack their own
+/// TP-aware modules yet.
 pub enum ModelArch {
+    // Qwen3 family
     Qwen3Quantized(QuantizedQwen3Weights),
     Qwen3Dense(qwen3_dense::ModelForCausalLM),
+    Qwen3MoeQuantized(GGUFQWenMoE),
+    Qwen3MoeDense(qwen3_moe_dense::ModelForCausalLM),
+
+    // Llama family (covers Llama 1/2/3/3.1/3.3). Boxed because the
+    // wrapper carries an inline Cache + Config — without indirection
+    // the enum's `LlamaDense` variant is several hundred bytes larger
+    // than the others (clippy::large_enum_variant).
+    LlamaQuantized(QuantizedLlamaWeights),
+    LlamaDense(Box<LlamaDense>),
+}
+
+impl ModelArch {
+    /// One forward step on this arch with the rank-1 vocab logits
+    /// extracted. Hides per-family shape differences (some return
+    /// `[B, V]`, others `[B, 1, V]`) — every caller gets a `[V]`
+    /// tensor ready for sampling.
+    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
+        let raw = match self {
+            ModelArch::Qwen3Quantized(m) => m.forward(input, offset)?,
+            ModelArch::Qwen3Dense(m) => m.forward(input, offset)?,
+            ModelArch::Qwen3MoeQuantized(m) => m.forward(input, offset)?,
+            ModelArch::Qwen3MoeDense(m) => m.forward(input, offset)?,
+            ModelArch::LlamaQuantized(m) => m.forward(input, offset)?,
+            ModelArch::LlamaDense(m) => m.forward(input, offset)?,
+        };
+        squeeze_to_vocab(&raw)
+    }
+
+    /// Reset the KV cache before each new request so we don't attend
+    /// over a previous request's tokens. Some architectures have an
+    /// in-place reset; Llama needs a Cache rebuild (held inline in
+    /// the wrapper).
+    pub fn clear_kv_cache(&mut self) -> Result<()> {
+        match self {
+            ModelArch::Qwen3Quantized(_) => Ok(()), /* keeps cache by design;
+                                                     * forward() handles offset */
+            ModelArch::Qwen3Dense(m) => {
+                m.clear_kv_cache();
+                Ok(())
+            }
+            ModelArch::Qwen3MoeQuantized(_) => Ok(()),
+            ModelArch::Qwen3MoeDense(m) => {
+                m.clear_kv_cache();
+                Ok(())
+            }
+            ModelArch::LlamaQuantized(_) => Ok(()),
+            ModelArch::LlamaDense(m) => m.clear_kv_cache(),
+        }
+    }
+}
+
+/// Squeeze any leading singleton dims off the logits tensor so the
+/// caller gets a rank-1 `[vocab_size]` slice ready for sampling. Bails
+/// on a non-singleton leading dim (would mean a batched forward, which
+/// no caller emits today).
+fn squeeze_to_vocab(t: &Tensor) -> Result<Tensor> {
+    let mut t = t.clone();
+    while t.dims().len() > 1 {
+        if t.dims()[0] != 1 {
+            anyhow::bail!(
+                "logits expected to start with a singleton dim, got shape {:?}",
+                t.dims()
+            );
+        }
+        t = t.squeeze(0)?;
+    }
+    Ok(t)
+}
+
+/// Llama dense wrapper. Bundles candle's `Llama` model with its
+/// externally-managed `Cache` plus enough config to rebuild the
+/// cache on `clear_kv_cache` (Llama's Cache doesn't expose a reset).
+pub struct LlamaDense {
+    model: llama_dense::Llama,
+    cache: llama_dense::Cache,
+    config: llama_dense::Config,
+    dtype: DType,
+    device: Device,
+}
+
+impl LlamaDense {
+    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
+        Ok(self.model.forward(input, offset, &mut self.cache)?)
+    }
+
+    pub fn clear_kv_cache(&mut self) -> Result<()> {
+        self.cache = llama_dense::Cache::new(true, self.dtype, &self.config, &self.device)
+            .context("rebuild Llama Cache for new request")?;
+        Ok(())
+    }
 }
 
 /// Repetition penalty applied to recently-generated tokens before
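Review note on `squeeze_to_vocab`: the shape rule it applies can be exercised without candle by working on the dimension list alone. This is a minimal sketch under that assumption. `squeeze_shape_to_vocab` is a hypothetical helper that is not part of the commit, and it treats an empty shape as an error, a case the real function does not single out.

```rust
/// Shape-only model of the squeeze rule: strip leading dims of size 1
/// until one dim is left, and fail if a leading dim is larger than 1
/// (that would indicate a batched or multi-position forward).
/// Returns the remaining vocab dimension.
pub fn squeeze_shape_to_vocab(shape: &[usize]) -> Result<usize, String> {
    let mut dims = shape;
    while dims.len() > 1 {
        if dims[0] != 1 {
            return Err(format!(
                "logits expected to start with a singleton dim, got shape {:?}",
                dims
            ));
        }
        // Equivalent of `t.squeeze(0)` on the shape: drop the first dim.
        dims = &dims[1..];
    }
    // Assumption of this sketch only: a rank-0 shape is rejected.
    dims.first()
        .copied()
        .ok_or_else(|| "logits shape has no dims".to_string())
}
```

Both `[1, 1, V]` and `[1, V]` reduce to `V`, while `[2, 1, V]` or `[1, 2, V]` is rejected instead of being silently flattened.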
@@ -130,11 +222,10 @@ const REPEAT_LAST_N: usize = 64;
 
 /// Architectures the dense safetensors path can construct. Keep
 /// alphabetical; one entry per supported `config.json#/model_type`
-/// value. New entries land alongside a new `ModelArch` variant + a new
-/// dispatch branch in `load_arch_dense` / `run_inference` /
-/// `run_inference_streaming` (plus, for TP, a parallel pattern in
-/// `tp_qwen3.rs`).
-const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["qwen3"];
+/// value. New entries land alongside a new `ModelArch` variant + a
+/// dispatch branch in `load_arch_dense` (plus, for TP, a parallel
+/// pattern in `tp_qwen3.rs`).
+const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["llama", "qwen3", "qwen3_moe"];
 
 /// Pre-flight check the operator's `config.json` against the set of
 /// architectures the dense path actually knows how to build. Surfaces
@@ -362,6 +453,9 @@ impl CandleHarness {
                 .unwrap_or_default();
             tracing::info!(architecture = %architecture, "GGUF architecture");
 
+            // The `general.architecture` GGUF metadata key follows
+            // llama.cpp conventions (lowercase, no underscores in some
+            // cases) — `qwen3moe`, not `qwen3_moe`.
             match architecture.as_str() {
                 "qwen3" => {
                     let weights =
@@ -369,8 +463,25 @@ impl CandleHarness {
                             .map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?;
                     Ok(ModelArch::Qwen3Quantized(weights))
                 }
+                "qwen3moe" => {
+                    // GGUFQWenMoE takes an explicit compute dtype
+                    // alongside the device — F16 matches the GGUF
+                    // weights' typical accumulation precision and
+                    // gives the best tokens/sec on consumer cards.
+                    let weights =
+                        GGUFQWenMoE::from_gguf(content, &mut file, &device_for_load, DType::F16)
+                            .map_err(|e| anyhow::anyhow!("from_gguf qwen3_moe: {e}"))?;
+                    Ok(ModelArch::Qwen3MoeQuantized(weights))
+                }
+                "llama" => {
+                    let weights =
+                        QuantizedLlamaWeights::from_gguf(content, &mut file, &device_for_load)
+                            .map_err(|e| anyhow::anyhow!("from_gguf llama: {e}"))?;
+                    Ok(ModelArch::LlamaQuantized(weights))
+                }
                 other => anyhow::bail!(
-                    "unsupported GGUF architecture '{other}'; quantized path only supports qwen3"
+                    "unsupported GGUF architecture '{other}'; quantized path supports \
+                     qwen3, qwen3moe, llama"
                 ),
             }
         })
@@ -396,32 +507,84 @@ impl CandleHarness {
         let model_id_for_log = spec.model_id.clone();
 
         let arch = tokio::task::spawn_blocking(move || -> Result<ModelArch> {
-            tracing::info!(
-                model = %model_id_for_log,
-                shards = safetensors_paths.len(),
-                "loading dense Qwen3 from safetensors"
-            );
             let cfg_text = std::fs::read_to_string(&config_path).context("read config.json")?;
             check_dense_config_supported(&cfg_text, &model_id_for_log)?;
-            let cfg: qwen3_dense::Config =
-                serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
+            // Peek at model_type to choose the family before the
+            // typed deserialize — each family has its own Config.
+            let model_type = serde_json::from_str::<serde_json::Value>(&cfg_text)
+                .ok()
+                .as_ref()
+                .and_then(|v| v.get("model_type"))
+                .and_then(|v| v.as_str())
+                .unwrap_or("")
+                .to_string();
+            tracing::info!(
+                model = %model_id_for_log,
+                model_type = %model_type,
+                shards = safetensors_paths.len(),
+                "loading dense model from safetensors"
+            );
 
-            // bf16 is the canonical Qwen3 distribution dtype. CUDA
-            // devices on Ada+ support it; Ampere also supports bf16
-            // natively. CPU candle handles bf16 via emulation.
+            // bf16 is the canonical distribution dtype for Qwen3 /
+            // Llama 3 / Qwen3 MoE. CUDA on Ada+ has hardware bf16;
+            // Ampere has it too. CPU emulates.
             let dtype = DType::BF16;
             // SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files;
-            // mutation of the underlying files by another process while
-            // we hold the mapping is UB. We trust that nothing else on
-            // the host modifies the HF cache files during a model's
-            // lifetime (hf-hub itself is immutable-by-design).
+            // mutation by another process while we hold the mapping is
+            // UB. We trust the HF cache is immutable-by-design.
             let vb = unsafe {
                 VarBuilder::from_mmaped_safetensors(&safetensors_paths, dtype, &device_for_load)
                     .context("build VarBuilder over safetensors")?
             };
+
+            match model_type.as_str() {
+                "qwen3" => {
+                    let cfg: qwen3_dense::Config =
+                        serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
                     let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
                         .map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
                     Ok(ModelArch::Qwen3Dense(model))
+                }
+                "qwen3_moe" => {
+                    let cfg: qwen3_moe_dense::Config =
+                        serde_json::from_str(&cfg_text).context("parse Qwen3 MoE config.json")?;
+                    let model = qwen3_moe_dense::ModelForCausalLM::new(&cfg, vb)
+                        .map_err(|e| anyhow::anyhow!("build Qwen3 MoE dense model: {e}"))?;
+                    Ok(ModelArch::Qwen3MoeDense(model))
+                }
+                "llama" => {
+                    let cfg: llama_dense::LlamaConfig =
+                        serde_json::from_str(&cfg_text).context("parse Llama config.json")?;
+                    // Llama has multiple sub-variants (Llama 1 has no
+                    // GQA; Llama 3 does). `LlamaConfig::into_config`
+                    // resolves the right shape; the `use_flash_attn`
+                    // arg defaults to false — the flash kernel is a
+                    // separate feature flag and uses extra VRAM.
+                    let config = cfg.into_config(false);
+                    let cache = llama_dense::Cache::new(true, dtype, &config, &device_for_load)
+                        .context("build Llama Cache")?;
+                    let model = llama_dense::Llama::load(vb, &config)
+                        .map_err(|e| anyhow::anyhow!("build Llama dense model: {e}"))?;
+                    Ok(ModelArch::LlamaDense(Box::new(LlamaDense {
+                        model,
+                        cache,
+                        config,
+                        dtype,
+                        device: device_for_load,
+                    })))
+                }
+                other => {
+                    // Defensive: `check_dense_config_supported` already
+                    // gated on the supported set, so this branch is
+                    // unreachable unless that list and the match here
+                    // drift apart.
+                    anyhow::bail!(
+                        "unrouted supported model_type '{other}' — \
+                         DENSE_SUPPORTED_MODEL_TYPES and load_arch_dense \
+                         must stay in sync"
+                    )
+                }
+            }
         })
         .await
         .context("blocking dense load task panicked")??;
@@ -1375,25 +1538,10 @@ fn run_inference(
 
     let mut generated: Vec<u32> = Vec::new();
 
-    let mut next_token = match arch {
-        ModelArch::Qwen3Quantized(model) => {
-            model.clear_kv_cache();
+    arch.clear_kv_cache()?;
     let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
-            let logits = model.forward(&input, 0)?;
-            let logits = logits.squeeze(0)?;
-            sample_with_penalty(&logits, &generated, &mut logits_processor)?
-        }
-        ModelArch::Qwen3Dense(model) => {
-            model.clear_kv_cache();
-            let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
-            // qwen3::ModelForCausalLM::forward returns [B, 1, V] —
-            // no final squeeze on the dense path, unlike the quantized
-            // variant which returns [B, V]. Strip both batch and seq
-            // dims to get the rank-1 logits LogitsProcessor expects.
-            let logits = model.forward(&input, 0)?.squeeze(0)?.squeeze(0)?;
-            sample_with_penalty(&logits, &generated, &mut logits_processor)?
-        }
-    };
+    let logits = arch.forward(&input, 0)?;
+    let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;
 
     if Some(next_token) == eos_id {
         return Ok((generated, "stop".into()));
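Review note on `arch.clear_kv_cache()?`: the rebuild-instead-of-reset pattern used by `LlamaDense` can be shown with plain types. This is a sketch under stated assumptions. `CacheConfig`, `Cache`, and `CachedModel` are hypothetical stand-ins, and the "cache" here only stores token ids per layer instead of key/value tensors.

```rust
/// Hypothetical config: just enough to construct a fresh cache.
#[derive(Clone)]
pub struct CacheConfig {
    pub num_layers: usize,
}

/// Hypothetical cache with a constructor but, like the Llama cache
/// described in the commit, no reset method.
pub struct Cache {
    layers: Vec<Vec<u32>>,
}

impl Cache {
    pub fn new(cfg: &CacheConfig) -> Self {
        Cache { layers: vec![Vec::new(); cfg.num_layers] }
    }
}

/// Wrapper that keeps the config next to the cache so it can rebuild.
pub struct CachedModel {
    cache: Cache,
    config: CacheConfig,
}

impl CachedModel {
    pub fn new(config: CacheConfig) -> Self {
        CachedModel { cache: Cache::new(&config), config }
    }

    /// Appends the tokens to every layer and returns how many
    /// positions are now cached (the length the next step attends over).
    pub fn forward(&mut self, tokens: &[u32]) -> usize {
        for layer in &mut self.cache.layers {
            layer.extend_from_slice(tokens);
        }
        self.cache.layers.first().map_or(0, |l| l.len())
    }

    /// No in-place reset exists, so replace the cache with a new one.
    pub fn clear_kv_cache(&mut self) {
        self.cache = Cache::new(&self.config);
    }
}
```

Skipping the clear would let the second request attend over the first request's positions, which is the failure the doc comment on `ModelArch::clear_kv_cache` describes.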
@@ -1401,23 +1549,9 @@ fn run_inference(
     generated.push(next_token);
 
     for index in 0..max_new.saturating_sub(1) {
-        next_token = match arch {
-            ModelArch::Qwen3Quantized(model) => {
         let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
-                let logits = model.forward(&input, prompt_tokens.len() + index)?;
-                let logits = logits.squeeze(0)?;
-                sample_with_penalty(&logits, &generated, &mut logits_processor)?
-            }
-            ModelArch::Qwen3Dense(model) => {
-                let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
-                // Dense returns [B, 1, V]; strip both leading dims.
-                let logits = model
-                    .forward(&input, prompt_tokens.len() + index)?
-                    .squeeze(0)?
-                    .squeeze(0)?;
-                sample_with_penalty(&logits, &generated, &mut logits_processor)?
-            }
-        };
+        let logits = arch.forward(&input, prompt_tokens.len() + index)?;
+        next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;
         if Some(next_token) == eos_id {
             return Ok((generated, "stop".into()));
         }
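Review note on the offset arithmetic: the prompt is fed once at offset 0, then each decode step feeds one token at offset `prompt_tokens.len() + index`. The sketch below replays that loop against a mock so the offsets can be asserted. `RecordingModel` and `generate` are hypothetical. The fake "sampling" rule is arbitrary, and pushing the token after the EOS check inside the loop is an assumption, since the hunk does not show that part.

```rust
/// Hypothetical stand-in for `ModelArch` that records every call as
/// (number of input tokens, offset).
pub struct RecordingModel {
    pub calls: Vec<(usize, usize)>,
}

impl RecordingModel {
    /// Fake forward + sampling: the "next token" is offset + 100, an
    /// arbitrary rule that keeps the run deterministic.
    fn forward(&mut self, input: &[u32], offset: usize) -> u32 {
        self.calls.push((input.len(), offset));
        offset as u32 + 100
    }
}

/// Mirrors the loop structure of the diff: one prompt step, then up
/// to `max_new - 1` single-token decode steps.
pub fn generate(
    model: &mut RecordingModel,
    prompt: &[u32],
    max_new: usize,
    eos: Option<u32>,
) -> Vec<u32> {
    let mut generated = Vec::new();
    // Prompt step: whole prompt, offset 0.
    let mut next = model.forward(prompt, 0);
    if Some(next) == eos {
        return generated;
    }
    generated.push(next);
    for index in 0..max_new.saturating_sub(1) {
        // Decode step: one token, positioned right after what is cached.
        next = model.forward(&[next], prompt.len() + index);
        if Some(next) == eos {
            return generated;
        }
        generated.push(next);
    }
    generated
}
```

Because every family now goes through the same `forward(input, offset)` call, this offset rule lives in one place for both the streaming and non-streaming loops.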
@@ -1465,22 +1599,10 @@ fn run_inference_streaming(
     let mut decoded_prefix = String::new();
     let mut finish_reason = "length".to_string();
 
-    let mut next_token = match arch {
-        ModelArch::Qwen3Quantized(model) => {
-            model.clear_kv_cache();
+    arch.clear_kv_cache()?;
     let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
-            let logits = model.forward(&input, 0)?;
-            let logits = logits.squeeze(0)?;
-            sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
-        }
-        ModelArch::Qwen3Dense(model) => {
-            model.clear_kv_cache();
-            let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
-            let logits = model.forward(&input, 0)?;
-            let logits = logits.squeeze(0)?;
-            sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
-        }
-    };
+    let logits = arch.forward(&input, 0)?;
+    let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
 
     let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result<bool> {
         let full = tokenizer
@@ -1521,23 +1643,9 @@ fn run_inference_streaming(
     }
 
     for index in 0..max_new.saturating_sub(1) {
-        next_token = match arch {
-            ModelArch::Qwen3Quantized(model) => {
         let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
-                let logits = model.forward(&input, prompt_tokens.len() + index)?;
-                let logits = logits.squeeze(0)?;
-                sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
-            }
-            ModelArch::Qwen3Dense(model) => {
-                let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
-                // Dense returns [B, 1, V]; strip both leading dims.
-                let logits = model
-                    .forward(&input, prompt_tokens.len() + index)?
-                    .squeeze(0)?
-                    .squeeze(0)?;
-                sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
-            }
-        };
+        let logits = arch.forward(&input, prompt_tokens.len() + index)?;
+        next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
         if Some(next_token) == eos_id {
             finish_reason = "stop".into();
             break;