feat(stage-8b): Llama + Qwen3 MoE families on the candle harness
All checks were successful
CI / Format (push) Successful in 31s
build-prerelease / Resolve version stamps (push) Successful in 36s
CI / Clippy (push) Successful in 2m6s
build-prerelease / Build neuron-blackwell (push) Successful in 3m50s
build-prerelease / Build cortex binary (push) Successful in 4m54s
CI / Test (push) Successful in 4m58s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 4m43s
build-prerelease / Build neuron-ada (push) Successful in 5m8s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m52s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m50s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m0s
Broadens the single-GPU dense and quantized paths to cover three non-Qwen3 architectures already shipped by candle-transformers. TP for these is a separate stage (each family would need its own tp_*.rs mirroring tp_qwen3.rs).

`ModelArch` gains four variants:
- LlamaDense (boxed; wraps Llama + an inline Cache + the config it takes to rebuild the cache, since candle::llama::Cache has no reset)
- LlamaQuantized (candle_transformers::models::quantized_llama)
- Qwen3MoeDense (candle::models::qwen3_moe::ModelForCausalLM)
- Qwen3MoeQuantized (candle::models::quantized_qwen3_moe::GGUFQWenMoE; takes an explicit compute dtype, F16 by default for best consumer-GPU throughput)

The dispatch is method-based now:
- `ModelArch::forward(&mut self, input, offset) -> Result<Tensor>` with a shared `squeeze_to_vocab` normalising shape differences (qwen3 returns [B,1,V]; quantized_qwen3 returns [B,V]; new families may differ again, and the helper handles all of them).
- `ModelArch::clear_kv_cache(&mut self) -> Result<()>`. Llama needs a Cache rebuild because its Cache has no in-place reset; the new `LlamaDense` wrapper holds the bits needed to do it.

`run_inference` / `run_inference_streaming` collapse to a single dispatch path: no more per-variant match arms in the hot loop, and new architectures pick up streaming + non-streaming for free with zero changes outside `ModelArch`.

DENSE_SUPPORTED_MODEL_TYPES is now ["llama", "qwen3", "qwen3_moe"]. GGUF arch switch grows "qwen3moe" + "llama" branches (qwen3moe with no underscore matches llama.cpp's general.architecture convention). Stage 8a's diagnostic auto-reports the new supported set.

The `LlamaDense` variant is boxed because the wrapper's inline Cache + Config makes it 544 bytes vs ~300 for everything else (clippy::large_enum_variant).

Verified: cargo test --workspace passes 66 tests; cargo clippy CPU and `--features cuda` both clean (the cuda check ran inside the locally-built `neuron-build-local` container with the math_functions.h patch applied).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
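
The shape normalisation that the message attributes to `squeeze_to_vocab` can be illustrated without candle. The following is a minimal sketch, assuming shapes are modelled as plain `usize` slices; `squeeze_leading_singletons` is a hypothetical stand-in for illustration, not the function in this commit.

```rust
/// Hypothetical stand-in for `squeeze_to_vocab`: it works on a shape
/// (a list of dims) rather than on a candle `Tensor`.
fn squeeze_leading_singletons(dims: &[usize]) -> Result<Vec<usize>, String> {
    let mut dims = dims.to_vec();
    // Peel leading dims until only the vocab axis is left.
    while dims.len() > 1 {
        if dims[0] != 1 {
            // A non-singleton leading dim would mean a batched forward.
            return Err(format!(
                "logits expected to start with a singleton dim, got shape {:?}",
                dims
            ));
        }
        dims.remove(0);
    }
    Ok(dims)
}
```

Under these assumptions a `[1, 1, V]` shape and a `[1, V]` shape both reduce to `[V]`, while a `[2, V]` shape is rejected, which is the behaviour the message describes for the dense and quantized Qwen3 outputs.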
@@ -14,8 +14,12 @@ use candle_core::quantized::gguf_file;
 use candle_core::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::{LogitsProcessor, Sampling};
+use candle_transformers::models::llama as llama_dense;
+use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaWeights;
 use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
+use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
 use candle_transformers::models::qwen3 as qwen3_dense;
+use candle_transformers::models::qwen3_moe as qwen3_moe_dense;
 use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec};
 use cortex_core::openai::{
     ChatCompletionChoice, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse,
@@ -100,21 +104,109 @@ pub struct TpLoadedModel {
     pub leader_model: Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>,
 }
 
-/// Architecture-specific weights.
+/// Architecture-specific weights. Each variant covers one (family,
+/// source-format) pair; the dense variants take the safetensors path
+/// and the `Quantized*` variants take the GGUF path.
 ///
-/// - `Qwen3Quantized` — GGUF source, pre-quantized. Single-GPU only;
-/// TP sharding pre-quantized super-blocks is intractable. Stays the
-/// default for small models loaded via `Qwen/Qwen3-*-GGUF` and
-/// `unsloth/Qwen3-*-GGUF` repos.
-/// - `Qwen3Dense` — bf16 safetensors source. The path that supports
-/// TP (Stage 7b-ii+) because slicing dense weights by row/column
-/// under safetensors is mechanical. Used when `ModelSpec.quant` is
-/// None; intended target for Qwen3.6-27B etc.
-///
-/// Stage 8 broadens this to additional families.
+/// TP currently only works through `Qwen3Dense` (see `tp_qwen3.rs`);
+/// every other variant is single-GPU. Quantized variants can't shard
+/// across GPUs at all — slicing GGUF super-blocks is intractable —
+/// and the new dense families (Llama, Qwen3 MoE) lack their own
+/// TP-aware modules yet.
 pub enum ModelArch {
+    // Qwen3 family
     Qwen3Quantized(QuantizedQwen3Weights),
     Qwen3Dense(qwen3_dense::ModelForCausalLM),
+    Qwen3MoeQuantized(GGUFQWenMoE),
+    Qwen3MoeDense(qwen3_moe_dense::ModelForCausalLM),
+
+    // Llama family (covers Llama 1/2/3/3.1/3.3). Boxed because the
+    // wrapper carries an inline Cache + Config — without indirection
+    // the enum's `LlamaDense` variant is several hundred bytes larger
+    // than the others (clippy::large_enum_variant).
+    LlamaQuantized(QuantizedLlamaWeights),
+    LlamaDense(Box<LlamaDense>),
 }
 
+impl ModelArch {
+    /// One forward step on this arch with the rank-1 vocab logits
+    /// extracted. Hides per-family shape differences (some return
+    /// `[B, V]`, others `[B, 1, V]`) — every caller gets a `[V]`
+    /// tensor ready for sampling.
+    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
+        let raw = match self {
+            ModelArch::Qwen3Quantized(m) => m.forward(input, offset)?,
+            ModelArch::Qwen3Dense(m) => m.forward(input, offset)?,
+            ModelArch::Qwen3MoeQuantized(m) => m.forward(input, offset)?,
+            ModelArch::Qwen3MoeDense(m) => m.forward(input, offset)?,
+            ModelArch::LlamaQuantized(m) => m.forward(input, offset)?,
+            ModelArch::LlamaDense(m) => m.forward(input, offset)?,
+        };
+        squeeze_to_vocab(&raw)
+    }
+
+    /// Reset the KV cache before each new request so we don't attend
+    /// over a previous request's tokens. Some architectures have an
+    /// in-place reset; Llama needs a Cache rebuild (held inline in
+    /// the wrapper).
+    pub fn clear_kv_cache(&mut self) -> Result<()> {
+        match self {
+            ModelArch::Qwen3Quantized(_) => Ok(()), /* keeps cache by design;
+                                                     * forward() handles offset */
+            ModelArch::Qwen3Dense(m) => {
+                m.clear_kv_cache();
+                Ok(())
+            }
+            ModelArch::Qwen3MoeQuantized(_) => Ok(()),
+            ModelArch::Qwen3MoeDense(m) => {
+                m.clear_kv_cache();
+                Ok(())
+            }
+            ModelArch::LlamaQuantized(_) => Ok(()),
+            ModelArch::LlamaDense(m) => m.clear_kv_cache(),
+        }
+    }
+}
+
+/// Squeeze any leading singleton dims off the logits tensor so the
+/// caller gets a rank-1 `[vocab_size]` slice ready for sampling. Bails
+/// on a non-singleton leading dim (would mean a batched forward, which
+/// no caller emits today).
+fn squeeze_to_vocab(t: &Tensor) -> Result<Tensor> {
+    let mut t = t.clone();
+    while t.dims().len() > 1 {
+        if t.dims()[0] != 1 {
+            anyhow::bail!(
+                "logits expected to start with a singleton dim, got shape {:?}",
+                t.dims()
+            );
+        }
+        t = t.squeeze(0)?;
+    }
+    Ok(t)
+}
+
+/// Llama dense wrapper. Bundles candle's `Llama` model with its
+/// externally-managed `Cache` plus enough config to rebuild the
+/// cache on `clear_kv_cache` (Llama's Cache doesn't expose a reset).
+pub struct LlamaDense {
+    model: llama_dense::Llama,
+    cache: llama_dense::Cache,
+    config: llama_dense::Config,
+    dtype: DType,
+    device: Device,
+}
+
+impl LlamaDense {
+    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
+        Ok(self.model.forward(input, offset, &mut self.cache)?)
+    }
+
+    pub fn clear_kv_cache(&mut self) -> Result<()> {
+        self.cache = llama_dense::Cache::new(true, self.dtype, &self.config, &self.device)
+            .context("rebuild Llama Cache for new request")?;
+        Ok(())
+    }
+}
+
 /// Repetition penalty applied to recently-generated tokens before
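
The boxing rationale for `LlamaDense` can be checked with a small standalone sketch. The payload sizes below (544 and 300 bytes) are hypothetical placeholders chosen to echo the figures in the commit message; they are not measured from candle, and `BigWrapper`, `SmallModel`, `Unboxed`, `Boxed` and `sizes` are illustrative names only.

```rust
use std::mem::size_of;

/// Hypothetical large payload standing in for the Llama wrapper.
#[allow(dead_code)]
struct BigWrapper {
    bytes: [u8; 544],
}

/// Hypothetical payload standing in for the other variants.
#[allow(dead_code)]
struct SmallModel {
    bytes: [u8; 300],
}

/// Every value of this enum is at least as large as its biggest variant.
#[allow(dead_code)]
enum Unboxed {
    Small(SmallModel),
    Big(BigWrapper),
}

/// Boxing the big variant stores only a pointer inline, so the enum
/// is sized by the next-largest variant instead.
#[allow(dead_code)]
enum Boxed {
    Small(SmallModel),
    Big(Box<BigWrapper>),
}

/// Returns (unboxed size, boxed size) in bytes.
fn sizes() -> (usize, usize) {
    (size_of::<Unboxed>(), size_of::<Boxed>())
}
```

Calling `sizes()` should show the boxed enum to be smaller than the unboxed one. The trade-off is one extra heap allocation and pointer hop for the boxed variant, which is usually negligible next to a forward pass.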
@@ -130,11 +222,10 @@ const REPEAT_LAST_N: usize = 64;
 
 /// Architectures the dense safetensors path can construct. Keep
 /// alphabetical; one entry per supported `config.json#/model_type`
-/// value. New entries land alongside a new `ModelArch` variant + a new
-/// dispatch branch in `load_arch_dense` / `run_inference` /
-/// `run_inference_streaming` (plus, for TP, a parallel pattern in
-/// `tp_qwen3.rs`).
-const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["qwen3"];
+/// value. New entries land alongside a new `ModelArch` variant + a
+/// dispatch branch in `load_arch_dense` (plus, for TP, a parallel
+/// pattern in `tp_qwen3.rs`).
+const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["llama", "qwen3", "qwen3_moe"];
 
 /// Pre-flight check the operator's `config.json` against the set of
 /// architectures the dense path actually knows how to build. Surfaces
@@ -362,6 +453,9 @@ impl CandleHarness {
                 .unwrap_or_default();
             tracing::info!(architecture = %architecture, "GGUF architecture");
 
+            // The `general.architecture` GGUF metadata key follows
+            // llama.cpp conventions (lowercase, no underscores in some
+            // cases) — `qwen3moe`, not `qwen3_moe`.
             match architecture.as_str() {
                 "qwen3" => {
                     let weights =
@@ -369,8 +463,25 @@ impl CandleHarness {
                             .map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?;
                     Ok(ModelArch::Qwen3Quantized(weights))
                 }
+                "qwen3moe" => {
+                    // GGUFQWenMoE takes an explicit compute dtype
+                    // alongside the device — F16 matches the GGUF
+                    // weights' typical accumulation precision and
+                    // gives the best tokens/sec on consumer cards.
+                    let weights =
+                        GGUFQWenMoE::from_gguf(content, &mut file, &device_for_load, DType::F16)
+                            .map_err(|e| anyhow::anyhow!("from_gguf qwen3_moe: {e}"))?;
+                    Ok(ModelArch::Qwen3MoeQuantized(weights))
+                }
+                "llama" => {
+                    let weights =
+                        QuantizedLlamaWeights::from_gguf(content, &mut file, &device_for_load)
+                            .map_err(|e| anyhow::anyhow!("from_gguf llama: {e}"))?;
+                    Ok(ModelArch::LlamaQuantized(weights))
+                }
                 other => anyhow::bail!(
-                    "unsupported GGUF architecture '{other}'; quantized path only supports qwen3"
+                    "unsupported GGUF architecture '{other}'; quantized path supports \
+                     qwen3, qwen3moe, llama"
                 ),
             }
         })
@@ -396,32 +507,84 @@ impl CandleHarness {
         let model_id_for_log = spec.model_id.clone();
 
         let arch = tokio::task::spawn_blocking(move || -> Result<ModelArch> {
-            tracing::info!(
-                model = %model_id_for_log,
-                shards = safetensors_paths.len(),
-                "loading dense Qwen3 from safetensors"
-            );
             let cfg_text = std::fs::read_to_string(&config_path).context("read config.json")?;
             check_dense_config_supported(&cfg_text, &model_id_for_log)?;
-            let cfg: qwen3_dense::Config =
-                serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
+            // Peek at model_type to choose the family before the
+            // typed deserialize — each family has its own Config.
+            let model_type = serde_json::from_str::<serde_json::Value>(&cfg_text)
+                .ok()
+                .as_ref()
+                .and_then(|v| v.get("model_type"))
+                .and_then(|v| v.as_str())
+                .unwrap_or("")
+                .to_string();
+            tracing::info!(
+                model = %model_id_for_log,
+                model_type = %model_type,
+                shards = safetensors_paths.len(),
+                "loading dense model from safetensors"
+            );
 
-            // bf16 is the canonical Qwen3 distribution dtype. CUDA
-            // devices on Ada+ support it; Ampere also supports bf16
-            // natively. CPU candle handles bf16 via emulation.
+            // bf16 is the canonical distribution dtype for Qwen3 /
+            // Llama 3 / Qwen3 MoE. CUDA on Ada+ has hardware bf16;
+            // Ampere has it too. CPU emulates.
             let dtype = DType::BF16;
             // SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files;
-            // mutation of the underlying files by another process while
-            // we hold the mapping is UB. We trust that nothing else on
-            // the host modifies the HF cache files during a model's
-            // lifetime (hf-hub itself is immutable-by-design).
+            // mutation by another process while we hold the mapping is
+            // UB. We trust the HF cache is immutable-by-design.
             let vb = unsafe {
                 VarBuilder::from_mmaped_safetensors(&safetensors_paths, dtype, &device_for_load)
                     .context("build VarBuilder over safetensors")?
             };
+
+            match model_type.as_str() {
+                "qwen3" => {
+                    let cfg: qwen3_dense::Config =
+                        serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
                     let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
                         .map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
                     Ok(ModelArch::Qwen3Dense(model))
+                }
+                "qwen3_moe" => {
+                    let cfg: qwen3_moe_dense::Config =
+                        serde_json::from_str(&cfg_text).context("parse Qwen3 MoE config.json")?;
+                    let model = qwen3_moe_dense::ModelForCausalLM::new(&cfg, vb)
+                        .map_err(|e| anyhow::anyhow!("build Qwen3 MoE dense model: {e}"))?;
+                    Ok(ModelArch::Qwen3MoeDense(model))
+                }
+                "llama" => {
+                    let cfg: llama_dense::LlamaConfig =
+                        serde_json::from_str(&cfg_text).context("parse Llama config.json")?;
+                    // Llama has multiple sub-variants (Llama 1 has no
+                    // GQA; Llama 3 does). `LlamaConfig::into_config`
+                    // resolves the right shape; the `use_flash_attn`
+                    // arg defaults to false — the flash kernel is a
+                    // separate feature flag and uses extra VRAM.
+                    let config = cfg.into_config(false);
+                    let cache = llama_dense::Cache::new(true, dtype, &config, &device_for_load)
+                        .context("build Llama Cache")?;
+                    let model = llama_dense::Llama::load(vb, &config)
+                        .map_err(|e| anyhow::anyhow!("build Llama dense model: {e}"))?;
+                    Ok(ModelArch::LlamaDense(Box::new(LlamaDense {
+                        model,
+                        cache,
+                        config,
+                        dtype,
+                        device: device_for_load,
+                    })))
+                }
+                other => {
+                    // Defensive: `check_dense_config_supported` already
+                    // gated on the supported set, so this branch is
+                    // unreachable unless that list and the match here
+                    // drift apart.
+                    anyhow::bail!(
+                        "unrouted supported model_type '{other}' — \
+                         DENSE_SUPPORTED_MODEL_TYPES and load_arch_dense \
+                         must stay in sync"
+                    )
+                }
+            }
         })
         .await
         .context("blocking dense load task panicked")??;
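
The defensive `other` branch above exists because the supported list and the match can drift apart. One way to guard against that is a check that every listed type routes somewhere. This is a minimal sketch under stated assumptions: `Family`, `route` and `unrouted_supported_types` are hypothetical names for illustration, and only the constant's contents come from this commit.

```rust
/// Same contents as the constant in this commit.
const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["llama", "qwen3", "qwen3_moe"];

/// Hypothetical tag for the family a `model_type` string routes to.
#[derive(Debug, PartialEq)]
enum Family {
    Llama,
    Qwen3,
    Qwen3Moe,
}

/// Hypothetical stand-in for the match in `load_arch_dense`.
fn route(model_type: &str) -> Result<Family, String> {
    match model_type {
        "llama" => Ok(Family::Llama),
        "qwen3" => Ok(Family::Qwen3),
        "qwen3_moe" => Ok(Family::Qwen3Moe),
        other => Err(format!("unrouted model_type '{}'", other)),
    }
}

/// Lists every supported type that the match fails to route.
/// An empty result means the list and the match are in sync.
fn unrouted_supported_types() -> Vec<&'static str> {
    DENSE_SUPPORTED_MODEL_TYPES
        .iter()
        .copied()
        .filter(|t| route(t).is_err())
        .collect()
}
```

Asserting that `unrouted_supported_types()` is empty in a unit test would turn a silent drift into a test failure, rather than a runtime error at model load.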
@@ -1375,25 +1538,10 @@ fn run_inference(
 
     let mut generated: Vec<u32> = Vec::new();
 
-    let mut next_token = match arch {
-        ModelArch::Qwen3Quantized(model) => {
-            model.clear_kv_cache();
+    arch.clear_kv_cache()?;
     let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
-            let logits = model.forward(&input, 0)?;
-            let logits = logits.squeeze(0)?;
-            sample_with_penalty(&logits, &generated, &mut logits_processor)?
-        }
-        ModelArch::Qwen3Dense(model) => {
-            model.clear_kv_cache();
-            let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
-            // qwen3::ModelForCausalLM::forward returns [B, 1, V] —
-            // no final squeeze on the dense path, unlike the quantized
-            // variant which returns [B, V]. Strip both batch and seq
-            // dims to get the rank-1 logits LogitsProcessor expects.
-            let logits = model.forward(&input, 0)?.squeeze(0)?.squeeze(0)?;
-            sample_with_penalty(&logits, &generated, &mut logits_processor)?
-        }
-    };
+    let logits = arch.forward(&input, 0)?;
+    let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;
 
     if Some(next_token) == eos_id {
         return Ok((generated, "stop".into()));
@@ -1401,23 +1549,9 @@ fn run_inference(
     generated.push(next_token);
 
     for index in 0..max_new.saturating_sub(1) {
-        next_token = match arch {
-            ModelArch::Qwen3Quantized(model) => {
         let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
-                let logits = model.forward(&input, prompt_tokens.len() + index)?;
-                let logits = logits.squeeze(0)?;
-                sample_with_penalty(&logits, &generated, &mut logits_processor)?
-            }
-            ModelArch::Qwen3Dense(model) => {
-                let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
-                // Dense returns [B, 1, V]; strip both leading dims.
-                let logits = model
-                    .forward(&input, prompt_tokens.len() + index)?
-                    .squeeze(0)?
-                    .squeeze(0)?;
-                sample_with_penalty(&logits, &generated, &mut logits_processor)?
-            }
-        };
+        let logits = arch.forward(&input, prompt_tokens.len() + index)?;
+        next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;
         if Some(next_token) == eos_id {
             return Ok((generated, "stop".into()));
         }
@@ -1465,22 +1599,10 @@ fn run_inference_streaming(
     let mut decoded_prefix = String::new();
     let mut finish_reason = "length".to_string();
 
-    let mut next_token = match arch {
-        ModelArch::Qwen3Quantized(model) => {
-            model.clear_kv_cache();
+    arch.clear_kv_cache()?;
     let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
-            let logits = model.forward(&input, 0)?;
-            let logits = logits.squeeze(0)?;
-            sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
-        }
-        ModelArch::Qwen3Dense(model) => {
-            model.clear_kv_cache();
-            let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
-            let logits = model.forward(&input, 0)?;
-            let logits = logits.squeeze(0)?;
-            sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
-        }
-    };
+    let logits = arch.forward(&input, 0)?;
+    let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
 
     let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result<bool> {
         let full = tokenizer
@@ -1521,23 +1643,9 @@ fn run_inference_streaming(
     }
 
     for index in 0..max_new.saturating_sub(1) {
-        next_token = match arch {
-            ModelArch::Qwen3Quantized(model) => {
         let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
-                let logits = model.forward(&input, prompt_tokens.len() + index)?;
-                let logits = logits.squeeze(0)?;
-                sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
-            }
-            ModelArch::Qwen3Dense(model) => {
-                let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
-                // Dense returns [B, 1, V]; strip both leading dims.
-                let logits = model
-                    .forward(&input, prompt_tokens.len() + index)?
-                    .squeeze(0)?
-                    .squeeze(0)?;
-                sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
-            }
-        };
+        let logits = arch.forward(&input, prompt_tokens.len() + index)?;
+        next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
         if Some(next_token) == eos_id {
             finish_reason = "stop".into();
             break;
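
The unified decode loop in the hunks above passes offset 0 for the prompt and `prompt_tokens.len() + index` for each later step. The following is a minimal sketch of that offset arithmetic, assuming a toy model in place of candle. The `Arch` trait, the `Recorder` type and the `generate` function are hypothetical illustrations, and sampling is replaced by a trivial "last token plus one" rule.

```rust
/// Hypothetical minimal interface mirroring the two `ModelArch` methods;
/// the real methods take candle tensors and return `Result`.
trait Arch {
    fn clear_kv_cache(&mut self);
    fn forward(&mut self, tokens: &[u32], offset: usize) -> u32;
}

/// Toy model that records each offset it sees and always "predicts"
/// the last input token plus one.
struct Recorder {
    offsets: Vec<usize>,
}

impl Arch for Recorder {
    fn clear_kv_cache(&mut self) {
        self.offsets.clear();
    }

    fn forward(&mut self, tokens: &[u32], offset: usize) -> u32 {
        self.offsets.push(offset);
        tokens.last().map_or(0, |t| t + 1)
    }
}

/// One dispatch path for every architecture: reset, prefill, then decode.
fn generate<A: Arch>(arch: &mut A, prompt: &[u32], max_new: usize, eos: Option<u32>) -> Vec<u32> {
    let mut generated = Vec::new();
    arch.clear_kv_cache();
    // Prefill: the whole prompt goes in at offset 0.
    let mut next = arch.forward(prompt, 0);
    if Some(next) == eos {
        return generated;
    }
    generated.push(next);
    // Decode: one token per step, positioned after everything seen so far.
    for index in 0..max_new.saturating_sub(1) {
        next = arch.forward(&[next], prompt.len() + index);
        if Some(next) == eos {
            return generated;
        }
        generated.push(next);
    }
    generated
}
```

With a three-token prompt and `max_new` of 4, the toy model sees offsets 0, 3, 4 and 5. Because the loop only talks to the trait, a new architecture needs no change here, which is the property the commit message claims for `ModelArch`.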