All checks were successful
CI / Format (push) Successful in 31s
build-prerelease / Resolve version stamps (push) Successful in 36s
CI / Clippy (push) Successful in 2m6s
build-prerelease / Build neuron-blackwell (push) Successful in 3m50s
build-prerelease / Build cortex binary (push) Successful in 4m54s
CI / Test (push) Successful in 4m58s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 4m43s
build-prerelease / Build neuron-ada (push) Successful in 5m8s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m52s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m50s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m0s
Broadens the single-GPU dense and quantized paths to cover two more model families already shipped by candle-transformers: Llama and Qwen3 MoE. TP for these is a separate stage (each family would need its own tp_*.rs mirroring tp_qwen3.rs).

`ModelArch` gains four variants:
- LlamaDense (boxed; wraps Llama + an inline Cache + the config needed to rebuild that cache, since candle_transformers::models::llama::Cache has no reset)
- LlamaQuantized (candle_transformers::models::quantized_llama)
- Qwen3MoeDense (candle_transformers::models::qwen3_moe::ModelForCausalLM)
- Qwen3MoeQuantized (candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE; takes an explicit compute dtype, F16 by default for best consumer-GPU throughput)

The dispatch is method-based now:
- `ModelArch::forward(&mut self, input, offset) -> Result<Tensor>`, with a shared `squeeze_to_vocab` normalising shape differences (qwen3 returns [B,1,V]; quantized_qwen3 returns [B,V]; new families may differ again, and the helper handles all of them).
- `ModelArch::clear_kv_cache(&mut self) -> Result<()>`. Llama needs a Cache rebuild because its Cache has no in-place reset; the new `LlamaDense` wrapper holds the bits needed to do it.

`run_inference` / `run_inference_streaming` collapse to a single dispatch path. There are no more per-variant match arms in the hot loop, and new architectures get streaming and non-streaming support for free, with zero changes outside `ModelArch`.

DENSE_SUPPORTED_MODEL_TYPES is now ["llama", "qwen3", "qwen3_moe"]. The GGUF architecture switch gains "qwen3moe" and "llama" branches ("qwen3moe", with no underscore, matches llama.cpp's general.architecture convention). Stage 8a's diagnostic auto-reports the new supported set.

The `LlamaDense` variant is boxed because the wrapper's inline Cache + Config makes it 544 bytes, vs ~300 for every other variant (clippy::large_enum_variant).
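The method-based dispatch described above can be sketched without candle. This is a minimal, std-only illustration, not the harness code. It makes three assumptions:
- `Logits` is a hypothetical stand-in for candle's `Tensor` (a shape plus flat data).
- `ThreeDimModel` is a hypothetical family whose forward returns `[B, 1, V]`.
- `TwoDimModel` is a hypothetical family whose forward returns `[B, V]`.

Every variant's raw output goes through one shared squeeze helper, so the caller always sees `[V]`.

```rust
/// Hypothetical stand-in for a candle `Tensor`: a shape plus row-major data.
#[derive(Debug, Clone, PartialEq)]
struct Logits {
    dims: Vec<usize>,
    data: Vec<f32>,
}

/// Hypothetical family whose forward returns `[B, 1, V]` (like dense qwen3).
struct ThreeDimModel {
    vocab: usize,
}

/// Hypothetical family whose forward returns `[B, V]` (like quantized qwen3).
struct TwoDimModel {
    vocab: usize,
}

/// One variant per family; callers only ever use the methods below.
enum Arch {
    ThreeDim(ThreeDimModel),
    TwoDim(TwoDimModel),
}

impl Arch {
    /// One forward step. Every variant's raw output is normalised to `[V]`.
    fn forward(&mut self, token: u32) -> Result<Logits, String> {
        let raw = match self {
            Arch::ThreeDim(m) => Logits {
                dims: vec![1, 1, m.vocab],
                data: fake_logits(m.vocab, token),
            },
            Arch::TwoDim(m) => Logits {
                dims: vec![1, m.vocab],
                data: fake_logits(m.vocab, token),
            },
        };
        squeeze_to_vocab(raw)
    }
}

/// Deterministic placeholder logits so the sketch is reproducible.
fn fake_logits(vocab: usize, token: u32) -> Vec<f32> {
    (0..vocab).map(|i| i as f32 + token as f32).collect()
}

/// Drop leading singleton dims. A non-singleton leading dim means a
/// batched forward, which is rejected instead of silently flattened.
fn squeeze_to_vocab(mut t: Logits) -> Result<Logits, String> {
    while t.dims.len() > 1 {
        if t.dims[0] != 1 {
            return Err(format!(
                "logits expected to start with a singleton dim, got shape {:?}",
                t.dims
            ));
        }
        t.dims.remove(0);
    }
    Ok(t)
}

fn main() {
    let mut archs = vec![
        Arch::ThreeDim(ThreeDimModel { vocab: 4 }),
        Arch::TwoDim(TwoDimModel { vocab: 4 }),
    ];
    for arch in archs.iter_mut() {
        let logits = arch.forward(7).expect("forward should succeed");
        // Both families end up with the same rank-1 shape.
        println!("dims={:?} len={}", logits.dims, logits.data.len());
    }
    let batched = Logits { dims: vec![2, 4], data: vec![0.0; 8] };
    println!("batched rejected: {}", squeeze_to_vocab(batched).is_err());
}
```

Running it prints `dims=[4] len=4` for both variants and `batched rejected: true`. Adding a family in this shape means one new variant plus one match arm in `forward`. The generation loop that calls `forward` never changes.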
Verified: cargo test --workspace passes 66 tests; cargo clippy is clean both for the CPU build and with `--features cuda` (the cuda check ran inside the locally-built `neuron-build-local` container with the math_functions.h patch applied).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1752 lines
67 KiB
Rust
//! Candle harness: in-process inference using huggingface/candle.
//!
//! This is the sole `Harness` implementation. Inference runs inside
//! the neuron process; there is no external subprocess.
//!
//! - Stage 2 wired GGUF (Qwen3 only) load/unload via `quantized_qwen3`.
//! - Stage 3 added `chat_completion`: a non-streaming, OpenAI-compatible
//!   chat completion routed to the loaded model's forward pass on a
//!   per-model serialised generation loop.

use anyhow::{Context, Result};
use async_trait::async_trait;
use candle_core::quantized::gguf_file;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama as llama_dense;
use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaWeights;
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
use candle_transformers::models::qwen3 as qwen3_dense;
use candle_transformers::models::qwen3_moe as qwen3_moe_dense;
use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec};
use cortex_core::openai::{
    ChatCompletionChoice, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse,
    ChatMessage, ChunkChoice, MessageContent, Usage,
};
use serde_json::json;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokenizers::Tokenizer;
use tokio::sync::{Mutex, RwLock, mpsc};

/// In-process candle harness. Owns the loaded model registry.
pub struct CandleHarness {
    models: Arc<RwLock<HashMap<String, LoadedHandle>>>,
    hf_cache: Option<PathBuf>,
    bind_url: String,
}

/// One entry in the harness's loaded-model registry. Single-GPU loads
/// land in `Single`; loads with `tensor_parallel > 1` land in `Tp`.
/// The two variants share the same `model_id` key in the map, so
/// `list_models`, `unload_model`, and `inference_endpoint` can walk
/// them uniformly without branching the storage layout.
///
/// `Clone` is cheap: both variants hold `Arc<_>` and cloning just bumps
/// the refcount.
#[derive(Clone)]
pub enum LoadedHandle {
    Single(Arc<LoadedModel>),
    #[cfg(feature = "cuda")]
    Tp(Arc<TpLoadedModel>),
}

impl LoadedHandle {
    pub fn model_id(&self) -> &str {
        match self {
            LoadedHandle::Single(m) => &m.model_id,
            #[cfg(feature = "cuda")]
            LoadedHandle::Tp(m) => &m.model_id,
        }
    }

    pub fn devices(&self) -> Vec<u32> {
        match self {
            LoadedHandle::Single(m) => m.devices.clone(),
            #[cfg(feature = "cuda")]
            LoadedHandle::Tp(m) => m.devices.clone(),
        }
    }
}

/// A loaded model with its tokenizer, device placement, and architecture-
/// specific weights. The `arch` is `Arc<Mutex<>>` so a clone of the `Arc`
/// can be moved into `spawn_blocking` and locked there for synchronous
/// candle forward passes.
pub struct LoadedModel {
    pub model_id: String,
    pub arch: Arc<Mutex<ModelArch>>,
    pub tokenizer: Tokenizer,
    pub device: Device,
    pub quant: Option<String>,
    pub devices: Vec<u32>,
}

/// Tensor-parallel loaded model. Holds the leader's rank-0 shard
/// (which the inference loop drives via spawn_blocking) and the
/// `WorkerPool` (which drives every non-zero rank over the RPC
/// channel). Both are behind tokio Mutexes so concurrent inference
/// requests against the same model are serialised; concurrent loads
/// for *different* models would each have their own pool.
#[cfg(feature = "cuda")]
pub struct TpLoadedModel {
    pub model_id: String,
    pub tokenizer: Tokenizer,
    pub devices: Vec<u32>,
    /// One end-to-end gate: the pool's RPC stream isn't safe to use
    /// concurrently and the leader shard's KV cache mutates with every
    /// step. The same Mutex covers both for the simplest correctness
    /// story.
    pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
    pub leader_model: Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>,
}

/// Architecture-specific weights. Each variant covers one (family,
/// source-format) pair; the dense variants take the safetensors path
/// and the `Quantized*` variants take the GGUF path.
///
/// TP currently only works through `Qwen3Dense` (see `tp_qwen3.rs`);
/// every other variant is single-GPU. Quantized variants can't shard
/// across GPUs at all (slicing GGUF super-blocks is intractable), and
/// the new dense families (Llama, Qwen3 MoE) don't have TP-aware
/// modules yet.
pub enum ModelArch {
    // Qwen3 family
    Qwen3Quantized(QuantizedQwen3Weights),
    Qwen3Dense(qwen3_dense::ModelForCausalLM),
    Qwen3MoeQuantized(GGUFQWenMoE),
    Qwen3MoeDense(qwen3_moe_dense::ModelForCausalLM),

    // Llama family (covers Llama 1/2/3/3.1/3.3). `LlamaDense` is boxed
    // because the wrapper carries an inline Cache + Config; without the
    // indirection that variant is about 544 bytes versus ~300 for the
    // others (clippy::large_enum_variant).
    LlamaQuantized(QuantizedLlamaWeights),
    LlamaDense(Box<LlamaDense>),
}

impl ModelArch {
    /// One forward step on this arch with the rank-1 vocab logits
    /// extracted. Hides per-family shape differences (some return
    /// `[B, V]`, others `[B, 1, V]`); every caller gets a `[V]`
    /// tensor ready for sampling.
    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
        let raw = match self {
            ModelArch::Qwen3Quantized(m) => m.forward(input, offset)?,
            ModelArch::Qwen3Dense(m) => m.forward(input, offset)?,
            ModelArch::Qwen3MoeQuantized(m) => m.forward(input, offset)?,
            ModelArch::Qwen3MoeDense(m) => m.forward(input, offset)?,
            ModelArch::LlamaQuantized(m) => m.forward(input, offset)?,
            ModelArch::LlamaDense(m) => m.forward(input, offset)?,
        };
        squeeze_to_vocab(&raw)
    }

    /// Reset the KV cache before each new request so we don't attend
    /// over a previous request's tokens. Some architectures have an
    /// in-place reset; Llama needs a Cache rebuild (held inline in
    /// the wrapper).
    pub fn clear_kv_cache(&mut self) -> Result<()> {
        match self {
            // Keeps its cache by design; forward() handles the offset.
            ModelArch::Qwen3Quantized(_) => Ok(()),
            ModelArch::Qwen3Dense(m) => {
                m.clear_kv_cache();
                Ok(())
            }
            ModelArch::Qwen3MoeQuantized(_) => Ok(()),
            ModelArch::Qwen3MoeDense(m) => {
                m.clear_kv_cache();
                Ok(())
            }
            ModelArch::LlamaQuantized(_) => Ok(()),
            ModelArch::LlamaDense(m) => m.clear_kv_cache(),
        }
    }
}

/// Squeeze any leading singleton dims off the logits tensor so the
/// caller gets a rank-1 `[vocab_size]` slice ready for sampling. Bails
/// on a non-singleton leading dim (would mean a batched forward, which
/// no caller emits today).
fn squeeze_to_vocab(t: &Tensor) -> Result<Tensor> {
    let mut t = t.clone();
    while t.dims().len() > 1 {
        if t.dims()[0] != 1 {
            anyhow::bail!(
                "logits expected to start with a singleton dim, got shape {:?}",
                t.dims()
            );
        }
        t = t.squeeze(0)?;
    }
    Ok(t)
}

/// Llama dense wrapper. Bundles candle's `Llama` model with its
/// externally-managed `Cache` plus enough config to rebuild the
/// cache on `clear_kv_cache` (Llama's Cache doesn't expose a reset).
pub struct LlamaDense {
    model: llama_dense::Llama,
    cache: llama_dense::Cache,
    config: llama_dense::Config,
    dtype: DType,
    device: Device,
}

impl LlamaDense {
    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
        Ok(self.model.forward(input, offset, &mut self.cache)?)
    }

    pub fn clear_kv_cache(&mut self) -> Result<()> {
        self.cache = llama_dense::Cache::new(true, self.dtype, &self.config, &self.device)
            .context("rebuild Llama Cache for new request")?;
        Ok(())
    }
}

/// Repetition penalty applied to recently-generated tokens before
/// sampling. 1.0 disables it; >1.0 makes recently-emitted tokens less
/// likely. mistral.rs and llama.cpp default to 1.1, which is enough to
/// stop small quantized models from degenerating into "Wait, no, no..."
/// loops without distorting normal output.
const REPEAT_PENALTY: f32 = 1.1;

/// Number of recently-generated tokens to feed into the repetition
/// penalty. Matches the candle quantized-qwen3 example default.
const REPEAT_LAST_N: usize = 64;

/// Architectures the dense safetensors path can construct. Keep
/// alphabetical; one entry per supported `config.json#/model_type`
/// value. New entries land alongside a new `ModelArch` variant + a
/// dispatch branch in `load_arch_dense` (plus, for TP, a parallel
/// pattern in `tp_qwen3.rs`).
const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["llama", "qwen3", "qwen3_moe"];

/// Pre-flight check the operator's `config.json` against the set of
/// architectures the dense path actually knows how to build. Surfaces
/// architecture mismatches as a single clean error before the serde
/// deserializer trips on missing fields. That happens because every
/// architecture has different hyperparameter names, so when the JSON
/// is e.g. Qwen3.6 wrapped under `text_config: {...}`, candle's
/// `qwen3::Config` finds none of its expected top-level fields and
/// fails with a cryptic `missing field 'vocab_size' at line N col 1`.
///
/// The error message names the model_type we saw and the supported set,
/// and points at the files an operator (or future contributor) needs
/// to touch to grow the supported set.
fn check_dense_config_supported(config_json: &str, model_id: &str) -> Result<()> {
    let v: serde_json::Value = serde_json::from_str(config_json)
        .with_context(|| format!("parse config.json for '{model_id}' as JSON"))?;
    let model_type = v.get("model_type").and_then(|x| x.as_str()).unwrap_or("");
    if model_type.is_empty() {
        anyhow::bail!(
            "config.json for '{model_id}' is missing `model_type`; the dense \
             path needs it to gate architecture support (supported: {:?})",
            DENSE_SUPPORTED_MODEL_TYPES
        );
    }
    if DENSE_SUPPORTED_MODEL_TYPES.contains(&model_type) {
        return Ok(());
    }
    // Bonus context: the model usually also lists architectures, which
    // is what `transformers` keys on. Including it makes the error
    // self-contained.
    let architectures = v
        .get("architectures")
        .and_then(|x| x.as_array())
        .map(|a| {
            a.iter()
                .filter_map(|v| v.as_str().map(String::from))
                .collect::<Vec<_>>()
        })
        .unwrap_or_default();
    anyhow::bail!(
        "unsupported model_type '{model_type}' for '{model_id}' \
         (architectures={architectures:?}); the dense path supports {:?}. \
         Add a `ModelArch` variant + load/forward branches in \
         crates/neuron/src/harness/candle.rs (and the TP analogue in \
         tp_qwen3.rs) to extend coverage.",
        DENSE_SUPPORTED_MODEL_TYPES
    );
}

/// Resolve the effective HuggingFace cache directory for the candle
/// harness. Precedence (first hit wins):
///
/// 1. Explicit `hf_cache` from `[harness.candle]` in `neuron.toml`.
///    Operator's wishes always win.
/// 2. `HF_HUB_CACHE` env var. The Python `huggingface_hub` library
///    points at the cache root directly with this var; the Rust
///    `hf-hub` crate doesn't read it natively, so we bridge here.
///    Honouring it lets a neuron host share a cache directory with
///    Python tooling and other harnesses without per-tool config.
/// 3. `HF_HOME` env var. Canonical HuggingFace base directory; the
///    cache lives at `$HF_HOME/hub`. `hf-hub` respects this on its own,
///    but we resolve it here too so the resulting path shows up in
///    logs alongside the explicit/HF_HUB_CACHE cases.
/// 4. `None`. Falls through to `hf-hub`'s default
///    (`~/.cache/huggingface/hub`).
fn resolve_hf_cache(explicit: Option<PathBuf>) -> Option<PathBuf> {
    if let Some(p) = explicit {
        return Some(p);
    }
    if let Ok(v) = std::env::var("HF_HUB_CACHE")
        && !v.is_empty()
    {
        return Some(PathBuf::from(v));
    }
    if let Ok(v) = std::env::var("HF_HOME")
        && !v.is_empty()
    {
        return Some(PathBuf::from(v).join("hub"));
    }
    None
}

/// Apply the repetition penalty (if any) to the prediction logits and
/// then sample. Centralises the prefill / generation-loop call sites
/// so they share identical sampling behaviour.
fn sample_with_penalty(
    logits: &Tensor,
    history: &[u32],
    logits_processor: &mut LogitsProcessor,
) -> Result<u32> {
    let penalised = if (REPEAT_PENALTY - 1.0).abs() < f32::EPSILON || history.is_empty() {
        logits.clone()
    } else {
        let start = history.len().saturating_sub(REPEAT_LAST_N);
        candle_transformers::utils::apply_repeat_penalty(logits, REPEAT_PENALTY, &history[start..])?
    };
    Ok(logits_processor.sample(&penalised)?)
}

impl CandleHarness {
    pub fn new(bind_url: String, hf_cache: Option<PathBuf>) -> Self {
        let hf_cache = resolve_hf_cache(hf_cache);
        if let Some(p) = &hf_cache {
            tracing::info!(path = %p.display(), "candle harness using HuggingFace cache");
        }
        Self {
            models: Arc::new(RwLock::new(HashMap::new())),
            hf_cache,
            bind_url,
        }
    }

    /// Pick a candle `Device` for the first requested device index
    /// (index 0 if none are given). Without the `cuda` feature, or if
    /// CUDA initialisation fails, falls back to CPU.
    fn pick_device(devices: &[u32]) -> Result<Device> {
        let _idx = devices.first().copied().unwrap_or(0) as usize;
        #[cfg(feature = "cuda")]
        {
            match Device::new_cuda(_idx) {
                Ok(d) => return Ok(d),
                Err(e) => tracing::warn!(
                    device = _idx,
                    error = %e,
                    "CUDA device unavailable, falling back to CPU"
                ),
            }
        }
        Ok(Device::Cpu)
    }

    /// Build an hf-hub API client pre-configured with the harness's
    /// `hf_cache` (when one is set).
    fn hf_api(&self) -> Result<hf_hub::api::tokio::Api> {
        let mut builder = hf_hub::api::tokio::ApiBuilder::new();
        if let Some(cache) = &self.hf_cache {
            builder = builder.with_cache_dir(cache.clone());
        }
        builder.build().context("build hf-hub API")
    }

    /// Resolve a dense (bf16/fp16 safetensors) model to its local file
    /// paths, returned as `(config.json, tokenizer.json, safetensors files)`.
    ///
    /// Handles both sharded repos (`model.safetensors.index.json` plus
    /// several `model-*.safetensors`) and the single-file layout
    /// (`model.safetensors`). Callers must not rely on the order of the
    /// safetensors paths; `VarBuilder` unifies them into one tensor view.
    async fn resolve_dense_files(
        &self,
        spec: &ModelSpec,
    ) -> Result<(PathBuf, PathBuf, Vec<PathBuf>)> {
        let api = self.hf_api()?;
        let repo = api.model(spec.model_id.clone());

        let config_path = repo
            .get("config.json")
            .await
            .with_context(|| format!("fetch config.json from {}", spec.model_id))?;
        let tokenizer_path = repo
            .get("tokenizer.json")
            .await
            .with_context(|| format!("fetch tokenizer.json from {}", spec.model_id))?;

        // Prefer the sharded layout (most HF dense models > 5B ship it).
        let safetensors_paths = match repo.get("model.safetensors.index.json").await {
            Ok(index_path) => {
                let index_text = std::fs::read_to_string(&index_path)
                    .context("read model.safetensors.index.json")?;
                let index: serde_json::Value = serde_json::from_str(&index_text)
                    .context("parse model.safetensors.index.json")?;
                let weight_map = index
                    .get("weight_map")
                    .and_then(|v| v.as_object())
                    .ok_or_else(|| {
                        anyhow::anyhow!("safetensors index missing weight_map object")
                    })?;
                let unique: std::collections::BTreeSet<String> = weight_map
                    .values()
                    .filter_map(|v| v.as_str().map(String::from))
                    .collect();
                let mut paths = Vec::with_capacity(unique.len());
                for fname in unique {
                    let p = repo
                        .get(&fname)
                        .await
                        .with_context(|| format!("fetch sharded safetensors {fname}"))?;
                    paths.push(p);
                }
                paths
            }
            Err(_) => {
                // Single-file fallback.
                let p = repo
                    .get("model.safetensors")
                    .await
                    .context("fetch model.safetensors (single-file layout)")?;
                vec![p]
            }
        };
        Ok((config_path, tokenizer_path, safetensors_paths))
    }

    /// Resolve + load a pre-quantized GGUF model, dispatching on its
    /// `general.architecture` metadata (qwen3, qwen3moe, or llama).
    /// Returns the tokenizer.json path so the caller can construct the
    /// Tokenizer uniformly across source formats.
    async fn load_arch_gguf(
        &self,
        spec: &ModelSpec,
        device: &Device,
    ) -> Result<(PathBuf, ModelArch)> {
        let (gguf_path, tokenizer_path) = self.resolve_files(spec).await?;
        let device_for_load = device.clone();
        let gguf_path_for_load = gguf_path.clone();
        let model_id_for_log = spec.model_id.clone();
        let arch = tokio::task::spawn_blocking(move || -> Result<ModelArch> {
            tracing::info!(model = %model_id_for_log, path = ?gguf_path_for_load, "loading GGUF");
            let mut file = std::fs::File::open(&gguf_path_for_load).context("open GGUF file")?;
            let content = gguf_file::Content::read(&mut file)
                .map_err(|e| anyhow::anyhow!("parse GGUF: {e}"))?;

            let architecture = content
                .metadata
                .get("general.architecture")
                .and_then(|v| v.to_string().ok().cloned())
                .unwrap_or_default();
            tracing::info!(architecture = %architecture, "GGUF architecture");

            // The `general.architecture` GGUF metadata key follows
            // llama.cpp conventions (lowercase, no underscores in some
            // cases): `qwen3moe`, not `qwen3_moe`.
            match architecture.as_str() {
                "qwen3" => {
                    let weights =
                        QuantizedQwen3Weights::from_gguf(content, &mut file, &device_for_load)
                            .map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?;
                    Ok(ModelArch::Qwen3Quantized(weights))
                }
                "qwen3moe" => {
                    // GGUFQWenMoE takes an explicit compute dtype
                    // alongside the device. F16 matches the GGUF
                    // weights' typical accumulation precision and
                    // gives the best tokens/sec on consumer cards.
                    let weights =
                        GGUFQWenMoE::from_gguf(content, &mut file, &device_for_load, DType::F16)
                            .map_err(|e| anyhow::anyhow!("from_gguf qwen3_moe: {e}"))?;
                    Ok(ModelArch::Qwen3MoeQuantized(weights))
                }
                "llama" => {
                    let weights =
                        QuantizedLlamaWeights::from_gguf(content, &mut file, &device_for_load)
                            .map_err(|e| anyhow::anyhow!("from_gguf llama: {e}"))?;
                    Ok(ModelArch::LlamaQuantized(weights))
                }
                other => anyhow::bail!(
                    "unsupported GGUF architecture '{other}'; quantized path supports \
                     qwen3, qwen3moe, llama"
                ),
            }
        })
        .await
        .context("blocking GGUF load task panicked")??;
        Ok((tokenizer_path, arch))
    }

    /// Resolve + load a dense model from safetensors, dispatching on
    /// `config.json#/model_type` (qwen3, qwen3_moe, or llama). Builds a
    /// VarBuilder over the mmap'd safetensors files. The dtype is fixed
    /// at bf16, the HF distribution dtype for these families; there is
    /// currently no f16 fallback for devices without bf16 support.
    async fn load_arch_dense(
        &self,
        spec: &ModelSpec,
        device: &Device,
    ) -> Result<(PathBuf, ModelArch)> {
        let (config_path, tokenizer_path, safetensors_paths) =
            self.resolve_dense_files(spec).await?;
        let device_for_load = device.clone();
        let model_id_for_log = spec.model_id.clone();

        let arch = tokio::task::spawn_blocking(move || -> Result<ModelArch> {
            let cfg_text = std::fs::read_to_string(&config_path).context("read config.json")?;
            check_dense_config_supported(&cfg_text, &model_id_for_log)?;
            // Peek at model_type to choose the family before the
            // typed deserialize; each family has its own Config.
            let model_type = serde_json::from_str::<serde_json::Value>(&cfg_text)
                .ok()
                .as_ref()
                .and_then(|v| v.get("model_type"))
                .and_then(|v| v.as_str())
                .unwrap_or("")
                .to_string();
            tracing::info!(
                model = %model_id_for_log,
                model_type = %model_type,
                shards = safetensors_paths.len(),
                "loading dense model from safetensors"
            );

            // bf16 is the canonical distribution dtype for Qwen3 /
            // Llama 3 / Qwen3 MoE. CUDA GPUs from Ampere onward
            // (including Ada) have hardware bf16; CPU emulates it.
            let dtype = DType::BF16;
            // SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files;
            // mutation by another process while we hold the mapping is
            // UB. We trust the HF cache is immutable-by-design.
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&safetensors_paths, dtype, &device_for_load)
                    .context("build VarBuilder over safetensors")?
            };

            match model_type.as_str() {
                "qwen3" => {
                    let cfg: qwen3_dense::Config =
                        serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
                    let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
                        .map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
                    Ok(ModelArch::Qwen3Dense(model))
                }
                "qwen3_moe" => {
                    let cfg: qwen3_moe_dense::Config =
                        serde_json::from_str(&cfg_text).context("parse Qwen3 MoE config.json")?;
                    let model = qwen3_moe_dense::ModelForCausalLM::new(&cfg, vb)
                        .map_err(|e| anyhow::anyhow!("build Qwen3 MoE dense model: {e}"))?;
                    Ok(ModelArch::Qwen3MoeDense(model))
                }
                "llama" => {
                    let cfg: llama_dense::LlamaConfig =
                        serde_json::from_str(&cfg_text).context("parse Llama config.json")?;
                    // Llama has multiple sub-variants (Llama 1 has no
                    // GQA; Llama 3 does). `LlamaConfig::into_config`
                    // resolves the right shape. We pass
                    // `use_flash_attn = false`: the flash kernel is a
                    // separate feature flag and uses extra VRAM.
                    let config = cfg.into_config(false);
                    let cache = llama_dense::Cache::new(true, dtype, &config, &device_for_load)
                        .context("build Llama Cache")?;
                    let model = llama_dense::Llama::load(vb, &config)
                        .map_err(|e| anyhow::anyhow!("build Llama dense model: {e}"))?;
                    Ok(ModelArch::LlamaDense(Box::new(LlamaDense {
                        model,
                        cache,
                        config,
                        dtype,
                        device: device_for_load,
                    })))
                }
                other => {
                    // Defensive: `check_dense_config_supported` already
                    // gated on the supported set, so this branch is
                    // unreachable unless that list and the match here
                    // drift apart.
                    anyhow::bail!(
                        "unrouted supported model_type '{other}': \
                         DENSE_SUPPORTED_MODEL_TYPES and load_arch_dense \
                         must stay in sync"
                    )
                }
            }
        })
        .await
        .context("blocking dense load task panicked")??;
        Ok((tokenizer_path, arch))
    }

    /// Resolve a model spec to local GGUF and tokenizer file paths via
    /// hf-hub. Downloads on first use; subsequent calls are cached.
    async fn resolve_files(&self, spec: &ModelSpec) -> Result<(PathBuf, PathBuf)> {
        let api = self.hf_api()?;
        let repo = api.model(spec.model_id.clone());

        let info = repo
            .info()
            .await
            .with_context(|| format!("fetch HF repo info for {}", spec.model_id))?;

        let quant = spec.quant.as_deref().unwrap_or("");
        let quant_lc = quant.to_lowercase();
        let gguf_filename = info
            .siblings
            .iter()
            .map(|s| s.rfilename.as_str())
            .filter(|name| name.to_lowercase().ends_with(".gguf"))
            .find(|name| quant_lc.is_empty() || name.to_lowercase().contains(&quant_lc))
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "no GGUF file matching quant {:?} in repo {}",
                    spec.quant,
                    spec.model_id
                )
            })?
            .to_string();

        tracing::info!(
            model = %spec.model_id,
            file = %gguf_filename,
            "resolving GGUF (may be cached)"
        );
        let gguf_path = repo
            .get(&gguf_filename)
            .await
            .with_context(|| format!("fetch GGUF {gguf_filename}"))?;

        // GGUF-only HF repos (unsloth/Qwen3-*-GGUF, Qwen/Qwen3-*-GGUF,
        // etc.) ship the .gguf file but not tokenizer.json; the
        // tokenizer.json lives in the base non-GGUF repo. Derive the
        // base repo id by stripping a `-GGUF` / `-gguf` suffix; if
        // there's no such suffix the same repo is used (works for
        // non-GGUF model_ids).
        let tokenizer_repo_id = spec
            .model_id
            .strip_suffix("-GGUF")
            .or_else(|| spec.model_id.strip_suffix("-gguf"))
            .unwrap_or(spec.model_id.as_str())
            .to_string();
        let tokenizer_repo = if tokenizer_repo_id == spec.model_id {
            repo
        } else {
            tracing::debug!(
                from = %spec.model_id,
                to = %tokenizer_repo_id,
                "tokenizer.json sourced from base repo (GGUF suffix stripped)"
            );
            api.model(tokenizer_repo_id.clone())
        };
        let tokenizer_path = tokenizer_repo
            .get("tokenizer.json")
            .await
            .with_context(|| format!("fetch tokenizer.json from {tokenizer_repo_id}"))?;
        Ok((gguf_path, tokenizer_path))
    }

    /// Run a non-streaming chat completion against a loaded model.
    ///
    /// Returns a typed `InferenceError` when the model isn't loaded so the
    /// handler can map to an appropriate HTTP status without string-matching.
    pub async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, InferenceError> {
        let handle = {
            let models = self.models.read().await;
            models.get(&request.model).cloned()
        };
        let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
        // The match is technically infallible without `cuda` (only Single
        // exists), but the cfg-gated Tp arm makes this the right shape
        // under both feature flags.
        #[allow(clippy::infallible_destructuring_match)]
        let loaded = match handle {
            LoadedHandle::Single(m) => m,
            #[cfg(feature = "cuda")]
            LoadedHandle::Tp(m) => {
                return self.chat_completion_tp(m, request).await;
            }
        };

        let prompt = format_qwen3_prompt(&request.messages);

        let encoding = loaded
            .tokenizer
            .encode(prompt.as_str(), true)
            .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
        let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
        let prompt_len = prompt_tokens.len();

        let temperature = request.temperature.unwrap_or(0.7);
        let top_p = request.top_p;
        let max_new = request.max_tokens.unwrap_or(512) as usize;
        let seed = unix_subsec_nanos();

        // End-of-turn token: Qwen3 uses `<|im_end|>` (falling back to
        // `<|endoftext|>`); Llama 3 tokenizers use `<|eot_id|>` instead.
        let eos_id = loaded
            .tokenizer
            .token_to_id("<|im_end|>")
            .or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"))
            .or_else(|| loaded.tokenizer.token_to_id("<|eot_id|>"));

        let arch_arc = Arc::clone(&loaded.arch);
        let device = loaded.device.clone();
        let model_id = request.model.clone();

        let (generated_ids, finish_reason) =
            tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
                let mut guard = arch_arc.blocking_lock();
                run_inference(
                    &mut guard,
                    &device,
                    &prompt_tokens,
                    max_new,
                    temperature,
                    top_p,
                    seed,
                    eos_id,
                )
            })
            .await
            .map_err(|e| InferenceError::Other(anyhow::anyhow!("inference task panicked: {e}")))?
            .map_err(InferenceError::Other)?;

        let completion_text = loaded
            .tokenizer
            .decode(&generated_ids, true)
            .map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;

        let usage = Usage {
            prompt_tokens: prompt_len as u64,
            completion_tokens: generated_ids.len() as u64,
            total_tokens: (prompt_len + generated_ids.len()) as u64,
        };

        Ok(ChatCompletionResponse {
            id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
            object: "chat.completion".into(),
            created: unix_now_secs(),
            model: model_id,
            choices: vec![ChatCompletionChoice {
                index: 0,
                message: ChatMessage {
                    role: "assistant".into(),
                    content: MessageContent::Text(completion_text),
                    extra: serde_json::Value::Object(Default::default()),
                },
                finish_reason: Some(finish_reason),
                extra: serde_json::Value::Object(Default::default()),
            }],
            usage: Some(usage),
            extra: serde_json::Value::Object(Default::default()),
        })
    }

    /// Run a streaming chat completion against a loaded model.
    ///
    /// Returns an `mpsc::Receiver` that yields `ChatCompletionChunk`s in
    /// OpenAI SSE format. The first chunk carries the assistant role;
    /// subsequent chunks carry incremental `content` deltas; the final
    /// chunk carries `finish_reason`. The handler is responsible for
    /// wrapping these into an SSE response and appending the `[DONE]`
    /// terminator.
    ///
    /// Token-by-token decoding tracks the cumulative decoded prefix so
    /// BPE byte-fallback boundaries don't split a UTF-8 char across
    /// chunks.
    pub async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
        let handle = {
            let models = self.models.read().await;
            models.get(&request.model).cloned()
        };
        let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
        // The match is technically infallible without `cuda` (only Single
        // exists), but the cfg-gated Tp arm makes this the right shape
        // under both feature flags.
        #[allow(clippy::infallible_destructuring_match)]
        let loaded = match handle {
            LoadedHandle::Single(m) => m,
            #[cfg(feature = "cuda")]
            LoadedHandle::Tp(m) => {
                return self.chat_completion_tp_stream(m, request).await;
            }
        };

        let prompt = format_qwen3_prompt(&request.messages);
        let encoding = loaded
            .tokenizer
            .encode(prompt.as_str(), true)
            .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
        let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();

        let temperature = request.temperature.unwrap_or(0.7);
        let top_p = request.top_p;
        let max_new = request.max_tokens.unwrap_or(512) as usize;
        let seed = unix_subsec_nanos();

        // End-of-turn token: Qwen3 uses `<|im_end|>` (falling back to
        // `<|endoftext|>`); Llama 3 tokenizers use `<|eot_id|>` instead.
        let eos_id = loaded
            .tokenizer
            .token_to_id("<|im_end|>")
            .or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"))
            .or_else(|| loaded.tokenizer.token_to_id("<|eot_id|>"));

        let arch_arc = Arc::clone(&loaded.arch);
        let device = loaded.device.clone();
        let tokenizer = loaded.tokenizer.clone();
        let model_id = request.model.clone();
        let id = format!("chatcmpl-{:x}", unix_subsec_nanos());
        let created = unix_now_secs();

// Bounded channel so the producer (blocking inference) is back-
|
|
// pressured by the consumer (SSE writer). 32 is generous —
|
|
// tokens arrive one at a time and the SSE writer is async.
|
|
let (tx, rx) = mpsc::channel::<ChatCompletionChunk>(32);
|
|
|
|
// Lead chunk: announce the assistant role per OpenAI streaming
|
|
// conventions. Tools that auto-detect a streaming reply expect
|
|
// this before any content delta.
|
|
let role_chunk = ChatCompletionChunk {
|
|
id: id.clone(),
|
|
object: "chat.completion.chunk".into(),
|
|
created,
|
|
model: model_id.clone(),
|
|
choices: vec![ChunkChoice {
|
|
index: 0,
|
|
delta: json!({"role": "assistant"}),
|
|
finish_reason: None,
|
|
extra: serde_json::Value::Object(Default::default()),
|
|
}],
|
|
usage: None,
|
|
extra: serde_json::Value::Object(Default::default()),
|
|
};
|
|
// If sending the role chunk fails the receiver is already gone;
|
|
// bail before kicking off the heavy blocking work.
|
|
tx.send(role_chunk)
|
|
.await
|
|
.map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?;
|
|
|
|
tokio::task::spawn_blocking(move || {
|
|
let mut guard = arch_arc.blocking_lock();
|
|
if let Err(e) = run_inference_streaming(
|
|
&mut guard,
|
|
&device,
|
|
&tokenizer,
|
|
&prompt_tokens,
|
|
max_new,
|
|
temperature,
|
|
top_p,
|
|
seed,
|
|
eos_id,
|
|
&id,
|
|
created,
|
|
&model_id,
|
|
&tx,
|
|
) {
|
|
tracing::warn!(model = %model_id, error = %e, "streaming inference failed");
|
|
}
|
|
});
|
|
|
|
Ok(rx)
|
|
}
|
|
}
|
|
|
|
#[async_trait]
impl Harness for CandleHarness {
    fn name(&self) -> &str {
        "candle"
    }

    async fn health(&self) -> HarnessHealth {
        HarnessHealth {
            name: "candle".into(),
            running: true,
            uptime_secs: None,
        }
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        let models = self.models.read().await;
        Ok(models
            .values()
            .map(|h| ModelInfo {
                id: h.model_id().into(),
                harness: "candle".into(),
                status: "loaded".into(),
                devices: h.devices(),
                vram_used_mb: None,
            })
            .collect())
    }

    async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
        if spec.harness != "candle" {
            anyhow::bail!("expected harness=candle, got harness={}", spec.harness);
        }

        {
            let models = self.models.read().await;
            if models.contains_key(&spec.model_id) {
                anyhow::bail!("model '{}' already loaded", spec.model_id);
            }
        }

        let tp_size = spec.tensor_parallel.unwrap_or(1);
        if tp_size > 1 {
            #[cfg(feature = "cuda")]
            {
                return self.load_tp(spec, tp_size).await;
            }
            #[cfg(not(feature = "cuda"))]
            {
                anyhow::bail!(
                    "tensor_parallel={tp_size} requested for '{}': this neuron \
                     binary was built without --features cuda; TP requires CUDA + NCCL",
                    spec.model_id
                );
            }
        }

        let devices = spec.devices.clone().unwrap_or_else(|| vec![0]);
        let device = Self::pick_device(&devices)?;

        // Dispatch by source format: GGUF (pre-quantized, single-GPU-only
        // path) vs safetensors dense (bf16/fp16; the path that grows TP
        // support). `spec.quant` is the signal: `Some` means the operator
        // picked a quantized GGUF; `None` means dense.
        let (tokenizer_path, arch) = if spec.quant.is_some() {
            self.load_arch_gguf(spec, &device).await?
        } else {
            self.load_arch_dense(spec, &device).await?
        };

        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;

        let loaded = Arc::new(LoadedModel {
            model_id: spec.model_id.clone(),
            arch: Arc::new(Mutex::new(arch)),
            tokenizer,
            device,
            quant: spec.quant.clone(),
            devices,
        });

        let mut models = self.models.write().await;
        // Re-check under the write lock: a concurrent load of the same id
        // may have finished while this one was reading weights. The copy
        // loaded here is dropped on bail, which frees its VRAM.
        if models.contains_key(&spec.model_id) {
            anyhow::bail!("model '{}' already loaded", spec.model_id);
        }
        models.insert(spec.model_id.clone(), LoadedHandle::Single(loaded));
        tracing::info!(model = %spec.model_id, "model loaded");
        Ok(())
    }

    async fn unload_model(&self, model_id: &str) -> Result<()> {
        let removed = {
            let mut models = self.models.write().await;
            models.remove(model_id)
        };
        let Some(handle) = removed else {
            anyhow::bail!("model '{model_id}' not loaded");
        };
        // Single-GPU unloads need no extra work: candle frees the VRAM once
        // the last `Arc<LoadedModel>` clone is dropped (right away, unless a
        // request still holds one). TP unloads also need to tell every
        // worker to drop its shard before the pool itself is dropped;
        // otherwise the workers keep their shards until Shutdown and the
        // VRAM is not freed promptly.
        match handle {
            LoadedHandle::Single(_) => {}
            #[cfg(feature = "cuda")]
            LoadedHandle::Tp(tp) => {
                // Try to recover the inner TpLoadedModel so we can move
                // the pool out and shut it down. The registry's clone was
                // just removed, so any clone still alive belongs to an
                // in-flight chat_completion or streaming task; in that case
                // bail with a clear error rather than silently leaking.
                let tp = match Arc::try_unwrap(tp) {
                    Ok(t) => t,
                    Err(arc) => {
                        // Reinsert so we don't leave the registry in an
                        // inconsistent state.
                        let mut models = self.models.write().await;
                        models.insert(model_id.into(), LoadedHandle::Tp(arc));
                        anyhow::bail!("cannot unload '{model_id}': inference still in flight");
                    }
                };
                let mut pool = tp.pool.into_inner();
                if let Err(e) = pool.unload_model(model_id).await {
                    tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed");
                }
                if let Err(e) = pool.shutdown().await {
                    tracing::warn!(model = %model_id, error = %e, "TP pool shutdown failed");
                }
            }
        }
        tracing::info!(model = %model_id, "model unloaded");
        Ok(())
    }

    async fn inference_endpoint(&self, model_id: &str) -> Option<String> {
        let models = self.models.read().await;
        models.contains_key(model_id).then(|| self.bind_url.clone())
    }
}

impl CandleHarness {
    /// Tensor-parallel load. Resolves dense safetensors via hf-hub the
    /// same way the single-GPU dense path does, spins up a TP worker
    /// pool sized to `tp_size`, runs the NCCL handshake, then has
    /// every rank load its shard of the model.
    ///
    /// `spec.devices` carries the per-rank CUDA device indices (one
    /// entry per rank, in rank order); defaults to `0..tp_size`.
    #[cfg(feature = "cuda")]
    async fn load_tp(&self, spec: &ModelSpec, tp_size: u32) -> Result<()> {
        use std::sync::Arc as StdArc;
        use tokio::sync::Mutex as TMutex;

        // Default per-rank device assignment: 0, 1, ..., tp_size - 1.
        let devices = spec
            .devices
            .clone()
            .unwrap_or_else(|| (0..tp_size).collect());
        if devices.len() as u32 != tp_size {
            anyhow::bail!(
                "tensor_parallel={tp_size} requires {tp_size} entries in devices, got {}",
                devices.len()
            );
        }
        if spec.quant.is_some() {
            anyhow::bail!(
                "tensor_parallel={tp_size} with quant={:?}: GGUF quantized models \
                 are not supported in the TP path; use a dense safetensors source",
                spec.quant
            );
        }

        // 1. Resolve config + tokenizer + safetensors via hf-hub.
        let (config_path, tokenizer_path, safetensors_paths) =
            self.resolve_dense_files(spec).await?;
        let config_json = std::fs::read_to_string(&config_path).context("read config.json")?;
        // Reject unsupported architectures *before* spawning the worker
        // pool and fanning out NCCL; otherwise we'd burn the pool
        // lifecycle on a load that's guaranteed to fail at deserialise
        // time inside every rank.
        check_dense_config_supported(&config_json, &spec.model_id)?;

        // 2. Spawn the worker pool. Rank 0 stays in-process; ranks
        //    1..tp_size are subprocesses, one per device after the
        //    leader's own.
        let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
        let mut pool = super::tp::WorkerPool::spawn(&exe, tp_size, &devices).await?;

        // 3. NCCL handshake across all ranks.
        let leader_device_idx = devices[0];
        pool.init_nccl(leader_device_idx).await?;

        // 4. Pick the leader's candle Device (same index as init_nccl).
        let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize)
            .context("Device::new_cuda for TP leader")?;

        // 5. Have every rank load its own shard; the leader's shard is
        //    returned here.
        let leader_model = pool
            .load_dense_shard(
                &spec.model_id,
                &config_json,
                &safetensors_paths,
                &leader_device,
                candle_core::DType::BF16,
            )
            .await?;

        // 6. Tokenizer (same as single-GPU path).
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;

        let tp_loaded = StdArc::new(TpLoadedModel {
            model_id: spec.model_id.clone(),
            tokenizer,
            devices: devices.clone(),
            pool: TMutex::new(pool),
            leader_model,
        });

        let mut models = self.models.write().await;
        models.insert(spec.model_id.clone(), LoadedHandle::Tp(tp_loaded));
        tracing::info!(
            model = %spec.model_id,
            tp_size,
            ?devices,
            "TP model loaded"
        );
        Ok(())
    }

    /// Non-streaming chat completion against a TP model. Pattern mirrors
    /// the single-GPU `run_inference`: tokenize, prefill, sample, decode
    /// loop, detokenize. Each forward step fans out to every rank via
    /// the WorkerPool and uses the leader's last-position logits to
    /// sample.
    #[cfg(feature = "cuda")]
    async fn chat_completion_tp(
        &self,
        tp: Arc<TpLoadedModel>,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, InferenceError> {
        let prompt = format_qwen3_prompt(&request.messages);
        let encoding = tp
            .tokenizer
            .encode(prompt.as_str(), true)
            .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
        let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
        let prompt_len = prompt_tokens.len();

        let temperature = request.temperature.unwrap_or(0.7);
        let top_p = request.top_p;
        let max_new = request.max_tokens.unwrap_or(512) as usize;
        let seed = unix_subsec_nanos();

        let eos_id = tp
            .tokenizer
            .token_to_id("<|im_end|>")
            .or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));

        let model_id = request.model.clone();

        // Acquire the pool lock for the duration of the request. The
        // leader_model's own Mutex is acquired step-by-step inside
        // pool.generate_step (so spawn_blocking can grab it without
        // holding the pool lock across the blocking_lock call).
        let mut pool = tp.pool.lock().await;
        let leader_arc = tp.leader_model.clone();

        // Reset every rank's KV cache so this request doesn't attend
        // over the previous request's tokens.
        pool.clear_kv_cache(&model_id, leader_arc.clone())
            .await
            .map_err(InferenceError::Other)?;

        let mut logits_processor = {
            let sampling = if temperature <= 0.0 {
                Sampling::ArgMax
            } else {
                match top_p {
                    Some(p) => Sampling::TopP { p, temperature },
                    None => Sampling::All { temperature },
                }
            };
            LogitsProcessor::from_sampling(seed, sampling)
        };

        let mut generated: Vec<u32> = Vec::new();
        let mut finish_reason = "length".to_string();

        // Prefill: every rank embeds the whole prompt, offset = 0.
        let logits = pool
            .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
            .await
            .map_err(InferenceError::Other)?;
        let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
            .map_err(InferenceError::Other)?;

        if Some(next_token) == eos_id {
            finish_reason = "stop".into();
        } else {
            generated.push(next_token);
            for index in 0..max_new.saturating_sub(1) {
                let logits = pool
                    .generate_step(
                        &model_id,
                        leader_arc.clone(),
                        vec![next_token],
                        prompt_len + index,
                    )
                    .await
                    .map_err(InferenceError::Other)?;
                next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
                    .map_err(InferenceError::Other)?;
                if Some(next_token) == eos_id {
                    finish_reason = "stop".into();
                    break;
                }
                generated.push(next_token);
            }
        }
        drop(pool);

        let completion_text = tp
            .tokenizer
            .decode(&generated, true)
            .map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;

        let usage = Usage {
            prompt_tokens: prompt_len as u64,
            completion_tokens: generated.len() as u64,
            total_tokens: (prompt_len + generated.len()) as u64,
        };

        Ok(ChatCompletionResponse {
            id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
            object: "chat.completion".into(),
            created: unix_now_secs(),
            model: model_id,
            choices: vec![ChatCompletionChoice {
                index: 0,
                message: ChatMessage {
                    role: "assistant".into(),
                    content: MessageContent::Text(completion_text),
                    extra: serde_json::Value::Object(Default::default()),
                },
                finish_reason: Some(finish_reason),
                extra: serde_json::Value::Object(Default::default()),
            }],
            usage: Some(usage),
            extra: serde_json::Value::Object(Default::default()),
        })
    }

    /// Streaming counterpart to `chat_completion_tp`. Same per-step
    /// orchestration (clear cache, prefill, sample, decode loop) but
    /// emits one `ChatCompletionChunk` per token over an mpsc channel
    /// so the handler can write an SSE stream.
    ///
    /// Unlike the single-GPU streaming path (which runs the candle
    /// forward inside `spawn_blocking` and uses `blocking_send`), the
    /// TP loop is itself async: every `pool.generate_step` awaits the
    /// leader's spawn_blocking forward plus every worker's recv_only.
    /// So we `tokio::spawn` the orchestration task and use plain
    /// `Sender::send`.
    #[cfg(feature = "cuda")]
    async fn chat_completion_tp_stream(
        &self,
        tp: Arc<TpLoadedModel>,
        request: ChatCompletionRequest,
    ) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
        let prompt = format_qwen3_prompt(&request.messages);
        let encoding = tp
            .tokenizer
            .encode(prompt.as_str(), true)
            .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
        let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
        let prompt_len = prompt_tokens.len();

        let temperature = request.temperature.unwrap_or(0.7);
        let top_p = request.top_p;
        let max_new = request.max_tokens.unwrap_or(512) as usize;
        let seed = unix_subsec_nanos();

        let eos_id = tp
            .tokenizer
            .token_to_id("<|im_end|>")
            .or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));

        let model_id = request.model.clone();
        let id = format!("chatcmpl-{:x}", unix_subsec_nanos());
        let created = unix_now_secs();
        let tokenizer = tp.tokenizer.clone();

        // Bounded channel: back-pressures the producer when the SSE
        // writer is slow.
        let (tx, rx) = mpsc::channel::<ChatCompletionChunk>(32);

        // Role chunk first, before kicking off the heavy work: if the
        // receiver is already gone, there's no point starting inference.
        let role_chunk = ChatCompletionChunk {
            id: id.clone(),
            object: "chat.completion.chunk".into(),
            created,
            model: model_id.clone(),
            choices: vec![ChunkChoice {
                index: 0,
                delta: json!({"role": "assistant"}),
                finish_reason: None,
                extra: serde_json::Value::Object(Default::default()),
            }],
            usage: None,
            extra: serde_json::Value::Object(Default::default()),
        };
        tx.send(role_chunk)
            .await
            .map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?;

        // The orchestration task. Holds the pool lock for the lifetime
        // of this inference; concurrent requests against the same TP
        // model serialise behind it.
        let tp_for_task = Arc::clone(&tp);
        tokio::spawn(async move {
            let mut pool = tp_for_task.pool.lock().await;
            let leader_arc = tp_for_task.leader_model.clone();

            if let Err(e) = pool.clear_kv_cache(&model_id, leader_arc.clone()).await {
                tracing::warn!(model = %model_id, error = %e, "TP stream: clear_kv_cache failed");
                return;
            }

            let mut logits_processor = {
                let sampling = if temperature <= 0.0 {
                    Sampling::ArgMax
                } else {
                    match top_p {
                        Some(p) => Sampling::TopP { p, temperature },
                        None => Sampling::All { temperature },
                    }
                };
                LogitsProcessor::from_sampling(seed, sampling)
            };

            let mut all_tokens: Vec<u32> = Vec::new();
            let mut decoded_prefix = String::new();
            let mut finish_reason = "length".to_string();

            // Prefill: every rank embeds the prompt, offset = 0.
            let logits = match pool
                .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
                .await
            {
                Ok(l) => l,
                Err(e) => {
                    tracing::warn!(model = %model_id, error = %e, "TP stream: prefill failed");
                    return;
                }
            };
            let mut next_token = match sample_with_penalty(
                &logits,
                &all_tokens,
                &mut logits_processor,
            ) {
                Ok(t) => t,
                Err(e) => {
                    tracing::warn!(model = %model_id, error = %e, "TP stream: prefill sample failed");
                    return;
                }
            };

            if Some(next_token) == eos_id {
                finish_reason = "stop".into();
            } else {
                all_tokens.push(next_token);
                if !emit_chunk(
                    &all_tokens,
                    &mut decoded_prefix,
                    &tokenizer,
                    &tx,
                    &id,
                    created,
                    &model_id,
                )
                .await
                {
                    return;
                }

                for index in 0..max_new.saturating_sub(1) {
                    let logits = match pool
                        .generate_step(
                            &model_id,
                            leader_arc.clone(),
                            vec![next_token],
                            prompt_len + index,
                        )
                        .await
                    {
                        Ok(l) => l,
                        Err(e) => {
                            tracing::warn!(
                                model = %model_id,
                                error = %e,
                                "TP stream: decode step failed"
                            );
                            return;
                        }
                    };
                    next_token =
                        match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
                            Ok(t) => t,
                            Err(e) => {
                                tracing::warn!(
                                    model = %model_id,
                                    error = %e,
                                    "TP stream: decode sample failed"
                                );
                                return;
                            }
                        };
                    if Some(next_token) == eos_id {
                        finish_reason = "stop".into();
                        break;
                    }
                    all_tokens.push(next_token);
                    if !emit_chunk(
                        &all_tokens,
                        &mut decoded_prefix,
                        &tokenizer,
                        &tx,
                        &id,
                        created,
                        &model_id,
                    )
                    .await
                    {
                        return;
                    }
                }
            }

            // Final chunk carrying finish_reason.
            let final_chunk = ChatCompletionChunk {
                id: id.clone(),
                object: "chat.completion.chunk".into(),
                created,
                model: model_id.clone(),
                choices: vec![ChunkChoice {
                    index: 0,
                    delta: serde_json::Value::Object(Default::default()),
                    finish_reason: Some(finish_reason),
                    extra: serde_json::Value::Object(Default::default()),
                }],
                usage: None,
                extra: serde_json::Value::Object(Default::default()),
            };
            let _ = tx.send(final_chunk).await;
        });

        Ok(rx)
    }
}

/// Decode the cumulative token list, emit the delta (substring appended
/// since the last chunk) as a `chat.completion.chunk`. Returns `false`
/// if decoding fails or the receiver has hung up; the caller should bail.
#[cfg(feature = "cuda")]
async fn emit_chunk(
    all_tokens: &[u32],
    decoded_prefix: &mut String,
    tokenizer: &Tokenizer,
    tx: &mpsc::Sender<ChatCompletionChunk>,
    id: &str,
    created: u64,
    model_id: &str,
) -> bool {
    let full = match tokenizer.decode(all_tokens, true) {
        Ok(s) => s,
        Err(e) => {
            tracing::warn!(error = %e, "TP stream: decode failed");
            return false;
        }
    };
    // A multi-byte character split across tokens decodes as a trailing
    // U+FFFD until its last byte arrives; completing it can move the byte
    // offset `decoded_prefix.len()` inside the character. Hold output back
    // until then, and use `get` so a misaligned offset skips a step
    // instead of panicking. Anything still held back when generation
    // stops is dropped.
    if full.ends_with('\u{FFFD}') {
        return true;
    }
    let delta = match full.get(decoded_prefix.len()..) {
        Some(d) if !d.is_empty() => d.to_string(),
        _ => return true,
    };
    *decoded_prefix = full;
    let chunk = ChatCompletionChunk {
        id: id.into(),
        object: "chat.completion.chunk".into(),
        created,
        model: model_id.into(),
        choices: vec![ChunkChoice {
            index: 0,
            delta: json!({ "content": delta }),
            finish_reason: None,
            extra: serde_json::Value::Object(Default::default()),
        }],
        usage: None,
        extra: serde_json::Value::Object(Default::default()),
    };
    tx.send(chunk).await.is_ok()
}

/// Errors returned by `CandleHarness::chat_completion` and
/// `chat_completion_stream`. The
/// `ModelNotLoaded` variant lets the HTTP handler map cleanly to 404
/// without string-matching on anyhow messages.
#[derive(Debug, thiserror::Error)]
pub enum InferenceError {
    #[error("model '{0}' not loaded on this neuron")]
    ModelNotLoaded(String),
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

/// Apply the Qwen3 chat template:
///
/// ```text
/// <|im_start|>{role}\n{content}<|im_end|>\n
/// ...
/// <|im_start|>assistant\n
/// ```
///
/// The trailing `<|im_start|>assistant\n` cues the model to begin a turn.
/// For multi-part content only the `text` fields are kept (concatenated);
/// non-text parts such as vision blocks are dropped. Full multimodal
/// handling is out of scope for Stage 3.
fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
    let mut prompt = String::new();
    for msg in messages {
        let content = match &msg.content {
            MessageContent::Text(s) => s.clone(),
            MessageContent::Parts(parts) => parts
                .iter()
                .filter_map(|p| p.get("text").and_then(|v| v.as_str()))
                .collect::<Vec<_>>()
                .join(""),
        };
        prompt.push_str("<|im_start|>");
        prompt.push_str(&msg.role);
        prompt.push('\n');
        prompt.push_str(&content);
        prompt.push_str("<|im_end|>\n");
    }
    prompt.push_str("<|im_start|>assistant\n");
    prompt
}

/// Single-GPU, non-streaming generation: clear the KV cache, prefill the
/// prompt at offset 0, then decode one token per step until EOS or
/// `max_new` tokens. Returns the generated ids (EOS excluded) and the
/// OpenAI finish_reason (`"stop"` or `"length"`).
#[allow(clippy::too_many_arguments)]
fn run_inference(
    arch: &mut ModelArch,
    device: &Device,
    prompt_tokens: &[u32],
    max_new: usize,
    temperature: f64,
    top_p: Option<f64>,
    seed: u64,
    eos_id: Option<u32>,
) -> Result<(Vec<u32>, String)> {
    let mut logits_processor = {
        let sampling = if temperature <= 0.0 {
            Sampling::ArgMax
        } else {
            match top_p {
                Some(p) => Sampling::TopP { p, temperature },
                None => Sampling::All { temperature },
            }
        };
        LogitsProcessor::from_sampling(seed, sampling)
    };

    let mut generated: Vec<u32> = Vec::new();

    arch.clear_kv_cache()?;
    let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
    let logits = arch.forward(&input, 0)?;
    let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;

    if Some(next_token) == eos_id {
        return Ok((generated, "stop".into()));
    }
    generated.push(next_token);

    for index in 0..max_new.saturating_sub(1) {
        let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
        let logits = arch.forward(&input, prompt_tokens.len() + index)?;
        next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;
        if Some(next_token) == eos_id {
            return Ok((generated, "stop".into()));
        }
        generated.push(next_token);
    }

    Ok((generated, "length".into()))
}

/// Streaming counterpart to `run_inference`. Emits chunks via `tx` as
/// tokens are generated and exits on EOS, `max_new`, or receiver drop.
///
/// Detokenization tracks the cumulative decoded prefix so each chunk's
/// `content` delta is the substring appended since the last chunk. Output
/// is held back while the decoded text ends in U+FFFD (a multi-byte
/// character still split across tokens), so a delta never cuts a
/// character in half.
#[allow(clippy::too_many_arguments)]
fn run_inference_streaming(
    arch: &mut ModelArch,
    device: &Device,
    tokenizer: &Tokenizer,
    prompt_tokens: &[u32],
    max_new: usize,
    temperature: f64,
    top_p: Option<f64>,
    seed: u64,
    eos_id: Option<u32>,
    id: &str,
    created: u64,
    model_id: &str,
    tx: &mpsc::Sender<ChatCompletionChunk>,
) -> Result<()> {
    let mut logits_processor = {
        let sampling = if temperature <= 0.0 {
            Sampling::ArgMax
        } else {
            match top_p {
                Some(p) => Sampling::TopP { p, temperature },
                None => Sampling::All { temperature },
            }
        };
        LogitsProcessor::from_sampling(seed, sampling)
    };

    let mut all_tokens: Vec<u32> = Vec::new();
    let mut decoded_prefix = String::new();
    let mut finish_reason = "length".to_string();

    arch.clear_kv_cache()?;
    let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
    let logits = arch.forward(&input, 0)?;
    let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;

    let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result<bool> {
        let full = tokenizer
            .decode(all_tokens, true)
            .map_err(|e| anyhow::anyhow!("decode: {e}"))?;
        // Completing a split character replaces the trailing U+FFFD, which
        // can move the byte offset `decoded_prefix.len()` inside it. Wait
        // for the character to complete, and use `get` so a misaligned
        // offset skips a step instead of panicking. Anything still held
        // back when generation stops is dropped.
        if full.ends_with('\u{FFFD}') {
            return Ok(true);
        }
        let delta = match full.get(decoded_prefix.len()..) {
            Some(d) if !d.is_empty() => d.to_string(),
            _ => return Ok(true),
        };
        *decoded_prefix = full;
        let chunk = ChatCompletionChunk {
            id: id.into(),
            object: "chat.completion.chunk".into(),
            created,
            model: model_id.into(),
            choices: vec![ChunkChoice {
                index: 0,
                delta: json!({ "content": delta }),
                finish_reason: None,
                extra: serde_json::Value::Object(Default::default()),
            }],
            usage: None,
            extra: serde_json::Value::Object(Default::default()),
        };
        // blocking_send returns Err if the consumer hung up; signal the
        // caller to stop generating.
        Ok(tx.blocking_send(chunk).is_ok())
    };

    if Some(next_token) == eos_id {
        finish_reason = "stop".into();
    } else {
        all_tokens.push(next_token);
        if !emit_token(&all_tokens, &mut decoded_prefix)? {
            return Ok(());
        }

        for index in 0..max_new.saturating_sub(1) {
            let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
            let logits = arch.forward(&input, prompt_tokens.len() + index)?;
            next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
            if Some(next_token) == eos_id {
                finish_reason = "stop".into();
                break;
            }
            all_tokens.push(next_token);
            if !emit_token(&all_tokens, &mut decoded_prefix)? {
                return Ok(());
            }
        }
    }

    let final_chunk = ChatCompletionChunk {
        id: id.into(),
        object: "chat.completion.chunk".into(),
        created,
        model: model_id.into(),
        choices: vec![ChunkChoice {
            index: 0,
            delta: serde_json::Value::Object(Default::default()),
            finish_reason: Some(finish_reason),
            extra: serde_json::Value::Object(Default::default()),
        }],
        usage: None,
        extra: serde_json::Value::Object(Default::default()),
    };
    let _ = tx.blocking_send(final_chunk);
    Ok(())
}

fn unix_now_secs() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0)
}

/// Nanoseconds since the Unix epoch, truncated to `u64`: the whole
/// timestamp, not just the sub-second part the name suggests. Used as a
/// per-request sampling seed and as the hex suffix of completion ids.
fn unix_subsec_nanos() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos() as u64)
        .unwrap_or(0)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn check_dense_config_accepts_qwen3() {
        let cfg = r#"{
            "model_type": "qwen3",
            "vocab_size": 151936,
            "architectures": ["Qwen3ForCausalLM"]
        }"#;
        check_dense_config_supported(cfg, "Qwen/Qwen3-1.7B").expect("qwen3 should pass");
    }
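
    // Sketch of a companion to the qwen3 case for the other dense types
    // listed in DENSE_SUPPORTED_MODEL_TYPES ("llama", "qwen3_moe"). It
    // assumes check_dense_config_supported accepts on `model_type` alone;
    // the `architectures` entries are the usual Hugging Face class names,
    // included only to mirror the qwen3 fixture, and the ids are dummies.
    #[test]
    fn check_dense_config_accepts_llama_and_qwen3_moe() {
        let cases = [
            (
                r#"{"model_type": "llama", "architectures": ["LlamaForCausalLM"]}"#,
                "anon/llama",
            ),
            (
                r#"{"model_type": "qwen3_moe", "architectures": ["Qwen3MoeForCausalLM"]}"#,
                "anon/qwen3-moe",
            ),
        ];
        for (cfg, model_id) in cases {
            if let Err(e) = check_dense_config_supported(cfg, model_id) {
                panic!("{model_id} should pass: {e:#}");
            }
        }
    }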

    #[test]
    fn check_dense_config_rejects_qwen3_5_with_clear_message() {
        let cfg = r#"{
            "model_type": "qwen3_5",
            "architectures": ["Qwen3_5ForConditionalGeneration"],
            "image_token_id": 248056,
            "text_config": {"hidden_size": 5120}
        }"#;
        let err = check_dense_config_supported(cfg, "Qwen/Qwen3.6-27B")
            .expect_err("qwen3_5 should be rejected");
        let msg = format!("{err}");
        assert!(
            msg.contains("unsupported model_type 'qwen3_5'"),
            "message should name the rejected type: {msg}"
        );
        assert!(
            msg.contains("Qwen/Qwen3.6-27B"),
            "message should echo the model id: {msg}"
        );
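        // "qwen3" alone would also match the rejected "qwen3_5", so check
        // for the other supported types instead.
        for supported in ["llama", "qwen3_moe"] {
            assert!(
                msg.contains(supported),
                "message should list the supported set (missing {supported}): {msg}"
            );
        }
    }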

    #[test]
    fn check_dense_config_rejects_missing_model_type() {
        let cfg = r#"{ "vocab_size": 1234 }"#;
        let err = check_dense_config_supported(cfg, "anon/no-type")
            .expect_err("missing model_type should be rejected");
        assert!(
            format!("{err}").contains("missing `model_type`"),
            "message should call out the missing field"
        );
    }
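
    // Sketch: pins the Display text of each InferenceError variant.
    // ModelNotLoaded renders through its #[error] format string;
    // Other is #[error(transparent)], so it forwards the wrapped anyhow
    // message unchanged. The model id and message here are dummies.
    #[test]
    fn inference_error_display_text() {
        let not_loaded = InferenceError::ModelNotLoaded("anon/missing".into());
        assert_eq!(
            not_loaded.to_string(),
            "model 'anon/missing' not loaded on this neuron"
        );
        let other: InferenceError = anyhow::anyhow!("tokenize: boom").into();
        assert_eq!(other.to_string(), "tokenize: boom");
    }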

    #[test]
    fn check_dense_config_rejects_invalid_json() {
        let err = check_dense_config_supported("not json", "anon/bad-json")
            .expect_err("malformed JSON should be rejected");
        assert!(
            format!("{err:#}").contains("config.json"),
            "message should mention config.json"
        );
    }
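
    // Sketch: pins the ChatML layout documented on format_qwen3_prompt,
    // including the trailing assistant cue. Assumes ChatMessage has
    // exactly the role/content/extra fields used by the response
    // builders in this file.
    #[test]
    fn format_qwen3_prompt_renders_chatml() {
        let msg = |role: &str, text: &str| ChatMessage {
            role: role.into(),
            content: MessageContent::Text(text.into()),
            extra: serde_json::Value::Object(Default::default()),
        };
        let prompt = format_qwen3_prompt(&[msg("system", "Be brief."), msg("user", "Hi")]);
        assert_eq!(
            prompt,
            "<|im_start|>system\nBe brief.<|im_end|>\n\
             <|im_start|>user\nHi<|im_end|>\n\
             <|im_start|>assistant\n"
        );
        // With no messages, only the assistant cue is emitted.
        assert_eq!(format_qwen3_prompt(&[]), "<|im_start|>assistant\n");
    }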
}