Files
cortex/crates/neuron/src/harness/candle.rs
rob thijssen 4aa71902d0
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 31s
CI / Format (push) Successful in 36s
CI / Clippy (push) Successful in 2m7s
CI / Test (push) Successful in 4m21s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m47s
build-prerelease / Build neuron-ampere (push) Successful in 5m17s
build-prerelease / Build neuron-ada (push) Successful in 5m14s
build-prerelease / Build cortex binary (push) Successful in 18m31s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m57s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m7s
feat(stage-8e-2): plumb quant config from ModelSpec to TP load path
- LoadDenseShard RPC gains an optional `quant` string field.
- WorkerPool::load_dense_shard takes a `quant: Option<String>`,
  passes it via the RPC to workers and via parse_quant_string to
  the leader's local load.
- The Qwen3-Next TP load chain (ForCausalLM → Model → DecoderLayer
  → Attention / GatedDeltaNet / MLP) takes `quant: Option<GgmlDType>`
  end-to-end, calling Column/RowParallelLinear::load_with_quant.
- The fused in_proj_qkv inside TpQwen3_5GatedDeltaNet is now a
  MaybeQuantLinear so it also picks up quantization.
- parse_quant_string accepts q4_0/q4_1/q5_0/q5_1/q8_0/q8_1, q2k..q8k
  (with or without underscore), and f16/bf16/f32. Empty / None means
  no quantization.

Callers from candle.rs forward spec.quant through pool.load_dense_shard.
This means a `quant = "q5k"` in models.toml now flows end-to-end to a
QTensor-backed QMatMul for every per-rank linear in the Qwen3-Next
TP path. Leaves lm_head and the small replicated bias/log tensors in
their loaded dtype (Stage 8e-3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 18:03:36 +03:00

1929 lines
75 KiB
Rust

//! Candle harness — in-process inference using huggingface/candle.
//!
//! This is the sole `Harness` implementation. Inference runs inside
//! the neuron process; there is no external subprocess.
//!
//! - Stage 2 wired GGUF (Qwen3 only) load/unload via `quantized_qwen3`.
//! - Stage 3 (this) adds `chat_completion` — a non-streaming OpenAI
//! compatible chat completion routed to the loaded model's forward
//! pass on a per-model serialised generation loop.
use anyhow::{Context, Result};
use async_trait::async_trait;
use candle_core::quantized::gguf_file;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama as llama_dense;
use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaWeights;
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
use candle_transformers::models::qwen3 as qwen3_dense;
use candle_transformers::models::qwen3_moe as qwen3_moe_dense;
use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec};
use cortex_core::openai::{
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse,
ChatMessage, ChunkChoice, MessageContent, Usage,
};
use serde_json::json;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokenizers::Tokenizer;
use tokio::sync::{Mutex, RwLock, mpsc};
/// In-process candle harness. Owns the loaded model registry.
pub struct CandleHarness {
/// Registry of loaded models keyed by model id, shared behind an async
/// `RwLock` so lookups can proceed concurrently.
models: Arc<RwLock<HashMap<String, LoadedHandle>>>,
/// Effective HuggingFace cache directory handed to the hf-hub API
/// builder; `None` leaves hf-hub on its own default.
hf_cache: Option<PathBuf>,
/// Stored verbatim from the constructor.
/// NOTE(review): presumably the URL advertised as this harness's
/// inference endpoint — confirm against the `Harness` impl.
bind_url: String,
}
/// One entry in the harness's loaded-model registry. Single-GPU loads
/// land in `Single`; loads with `tensor_parallel > 1` land in `Tp`.
/// The two variants share the same `model_id` key in the map, so
/// `list_models`, `unload_model`, and `inference_endpoint` can walk
/// them uniformly without branching the storage layout.
///
/// `Clone` is cheap: both variants hold `Arc<_>` and cloning just bumps
/// the refcount.
#[derive(Clone)]
pub enum LoadedHandle {
/// A model driven entirely in-process through a `LoadedModel`.
Single(Arc<LoadedModel>),
/// A tensor-parallel model; only compiled in with the `cuda` feature.
#[cfg(feature = "cuda")]
Tp(Arc<TpLoadedModel>),
}
impl LoadedHandle {
pub fn model_id(&self) -> &str {
match self {
LoadedHandle::Single(m) => &m.model_id,
#[cfg(feature = "cuda")]
LoadedHandle::Tp(m) => &m.model_id,
}
}
pub fn devices(&self) -> Vec<u32> {
match self {
LoadedHandle::Single(m) => m.devices.clone(),
#[cfg(feature = "cuda")]
LoadedHandle::Tp(m) => m.devices.clone(),
}
}
}
/// A loaded model with its tokenizer, device placement, and architecture-
/// specific weights. The `arch` is `Arc<Mutex<>>` so the lock guard can be
/// moved into `spawn_blocking` for synchronous candle forward passes.
pub struct LoadedModel {
/// Key this model is registered under in the harness registry.
pub model_id: String,
/// Architecture-specific weights; locked for the duration of one
/// generation so requests against the same model are serialised.
pub arch: Arc<Mutex<ModelArch>>,
/// Tokenizer used to encode prompts and decode generated ids.
pub tokenizer: Tokenizer,
/// Candle device the inference loop runs on.
pub device: Device,
/// Quantisation label, if any.
/// NOTE(review): presumably copied from `ModelSpec::quant` at load
/// time — the construction site is outside this view.
pub quant: Option<String>,
/// Device indices reported through `LoadedHandle::devices`.
pub devices: Vec<u32>,
}
/// Tensor-parallel loaded model. Holds the leader's rank-0 shard
/// (which the inference loop drives via spawn_blocking) and the
/// `WorkerPool` (which drives every non-zero rank over the RPC
/// channel). Both are behind tokio Mutexes so concurrent inference
/// requests against the same model are serialised; concurrent loads
/// for *different* models would each have their own pool.
#[cfg(feature = "cuda")]
pub struct TpLoadedModel {
/// Key this model is registered under in the harness registry.
pub model_id: String,
/// Tokenizer used to encode prompts and decode generated ids.
pub tokenizer: Tokenizer,
/// Device indices reported through `LoadedHandle::devices`.
pub devices: Vec<u32>,
/// Gate for the worker pool: the pool's RPC stream isn't safe to use
/// concurrently and the leader shard's KV cache mutates with every
/// step.
///
/// NOTE(review): an earlier version of this comment said "the same
/// Mutex covers both", but `pool` and `leader_model` are two
/// independent mutexes. Callers presumably have to take both (in a
/// consistent order) for the whole request — confirm in the TP
/// inference path.
pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
/// Leader (rank-0) shard, behind its own mutex and `Arc` so a guard
/// can be taken from a blocking task.
pub leader_model: Arc<tokio::sync::Mutex<super::tp::TpLeaderModel>>,
}
/// Architecture-specific weights. Each variant covers one (family,
/// source-format) pair; the dense variants take the safetensors path
/// and the `Quantized*` variants take the GGUF path.
///
/// TP currently only works through `Qwen3Dense` (see `tp_qwen3.rs`);
/// every other variant is single-GPU. Quantized variants can't shard
/// across GPUs at all — slicing GGUF super-blocks is intractable —
/// and the new dense families (Llama, Qwen3 MoE) lack their own
/// TP-aware modules yet.
///
/// NOTE(review): `TP_SUPPORTED_MODEL_TYPES` in this file also lists
/// `"qwen3_5"`, so the "only `Qwen3Dense`" statement above may predate
/// the Qwen3-Next TP work — confirm and update.
pub enum ModelArch {
// Qwen3 family
Qwen3Quantized(QuantizedQwen3Weights),
Qwen3Dense(qwen3_dense::ModelForCausalLM),
Qwen3MoeQuantized(GGUFQWenMoE),
Qwen3MoeDense(qwen3_moe_dense::ModelForCausalLM),
// Llama family (covers Llama 1/2/3/3.1/3.3). Boxed because the
// wrapper carries an inline Cache + Config — without indirection
// the enum's `LlamaDense` variant is several hundred bytes larger
// than the others (clippy::large_enum_variant).
LlamaQuantized(QuantizedLlamaWeights),
LlamaDense(Box<LlamaDense>),
// Qwen3-Next family (model_type "qwen3_5") — Qwen3.6's
// architecture. Stage 8c scaffolding only: dispatch + config parse
// are real; forward bails "not implemented yet". See
// `arch/qwen3_5.rs` for the open architecture work.
// NOTE(review): the "scaffolding only" remark may be stale — verify
// against the current state of `arch/qwen3_5.rs`.
Qwen3_5Dense(super::arch::qwen3_5::Qwen3_5ForCausalLM),
}
impl ModelArch {
    /// Run a single forward step and hand back rank-1 vocab logits.
    ///
    /// Families disagree on the raw output shape (`[B, V]` versus
    /// `[B, 1, V]`); the result is funnelled through `squeeze_to_vocab`
    /// so callers always receive a `[V]` tensor ready for sampling.
    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
        let logits = match self {
            Self::Qwen3Quantized(weights) => weights.forward(input, offset)?,
            Self::Qwen3Dense(model) => model.forward(input, offset)?,
            Self::Qwen3MoeQuantized(weights) => weights.forward(input, offset)?,
            Self::Qwen3MoeDense(model) => model.forward(input, offset)?,
            Self::LlamaQuantized(weights) => weights.forward(input, offset)?,
            Self::LlamaDense(wrapper) => wrapper.forward(input, offset)?,
            Self::Qwen3_5Dense(model) => model.forward(input, offset)?,
        };
        squeeze_to_vocab(&logits)
    }

    /// Drop cached attention state so a new request does not attend over
    /// the previous request's tokens.
    ///
    /// The quantized variants are intentionally left untouched here (the
    /// design relies on their `forward` handling the offset). The dense
    /// Qwen families reset in place, and Llama rebuilds the `Cache` held
    /// in its wrapper, which is the only fallible path.
    pub fn clear_kv_cache(&mut self) -> Result<()> {
        match self {
            Self::Qwen3Quantized(_) | Self::Qwen3MoeQuantized(_) | Self::LlamaQuantized(_) => {}
            Self::Qwen3Dense(model) => {
                model.clear_kv_cache();
            }
            Self::Qwen3MoeDense(model) => {
                model.clear_kv_cache();
            }
            Self::Qwen3_5Dense(model) => {
                model.clear_kv_cache();
            }
            Self::LlamaDense(wrapper) => wrapper.clear_kv_cache()?,
        }
        Ok(())
    }
}
/// Peel leading singleton dimensions off a logits tensor until it is
/// rank-1 (`[vocab_size]`), ready for sampling.
///
/// Fails as soon as the current leading dimension is not 1 — that would
/// indicate a batched forward, which no caller produces today. The error
/// reports the shape at the point of failure (after any squeezes already
/// applied).
fn squeeze_to_vocab(t: &Tensor) -> Result<Tensor> {
    let mut logits = t.clone();
    loop {
        if logits.dims().len() <= 1 {
            return Ok(logits);
        }
        if logits.dims()[0] != 1 {
            anyhow::bail!(
                "logits expected to start with a singleton dim, got shape {:?}",
                logits.dims()
            );
        }
        logits = logits.squeeze(0)?;
    }
}
/// Llama dense wrapper. Bundles candle's `Llama` model with its
/// externally-managed `Cache` plus enough config to rebuild the
/// cache on `clear_kv_cache` (Llama's Cache doesn't expose a reset).
pub struct LlamaDense {
/// The candle Llama model whose `forward` takes the cache explicitly.
model: llama_dense::Llama,
/// KV cache passed mutably into every forward call; replaced wholesale
/// by `clear_kv_cache`.
cache: llama_dense::Cache,
/// Resolved Llama config, kept so a fresh `Cache` can be built.
config: llama_dense::Config,
/// Dtype the cache is (re)built with.
dtype: DType,
/// Device the cache is (re)built on.
device: Device,
}
impl LlamaDense {
    /// Forward `input` at position `offset`, threading the wrapper's own
    /// cache into candle's Llama model.
    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
        let logits = self.model.forward(input, offset, &mut self.cache)?;
        Ok(logits)
    }

    /// Replace the KV cache with a freshly built one using the stored
    /// dtype, config and device.
    pub fn clear_kv_cache(&mut self) -> Result<()> {
        let fresh = llama_dense::Cache::new(true, self.dtype, &self.config, &self.device)
            .context("rebuild Llama Cache for new request")?;
        self.cache = fresh;
        Ok(())
    }
}
/// Repetition penalty applied to recently-generated tokens before
/// sampling. 1.0 disables it; >1.0 makes recently-emitted tokens less
/// likely. mistral.rs and llama.cpp default to 1.1, which is enough to
/// stop small quantized models from degenerating into "Wait, no, no..."
/// loops without distorting normal output.
///
/// Consumed by `sample_with_penalty`, which skips the penalty step
/// entirely when this value is (within `f32::EPSILON` of) 1.0.
const REPEAT_PENALTY: f32 = 1.1;
/// Number of recently-generated tokens to feed into the repetition
/// penalty. Matches the candle quantized-qwen3 example default.
///
/// `sample_with_penalty` penalises only the trailing `REPEAT_LAST_N`
/// entries of the token history.
const REPEAT_LAST_N: usize = 64;
/// Architectures the dense safetensors path can construct. Keep
/// alphabetical; one entry per supported `config.json#/model_type`
/// value. New entries land alongside a new `ModelArch` variant + a
/// dispatch branch in `load_arch_dense` (plus, for TP, a parallel
/// pattern in `tp_qwen3.rs`).
///
/// `check_dense_config_supported` gates on this list, and
/// `load_arch_dense` bails with an "unrouted" error if an entry here has
/// no matching dispatch arm.
const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["llama", "qwen3", "qwen3_5", "qwen3_moe"];
/// Validate that `config.json` names an architecture the dense path can
/// build, before any typed deserialisation is attempted.
///
/// Each family has its own hyperparameter names, so feeding the wrong
/// family's JSON to a typed `Config` fails with an opaque "missing
/// field" error. Checking `model_type` up front turns that into a single
/// readable message.
///
/// # Errors
///
/// - the text is not valid JSON;
/// - `model_type` is absent, not a string, or empty;
/// - `model_type` is not in `DENSE_SUPPORTED_MODEL_TYPES` — the message
///   then also includes the config's `architectures` list and points at
///   the files to extend.
fn check_dense_config_supported(config_json: &str, model_id: &str) -> Result<()> {
    let parsed: serde_json::Value = serde_json::from_str(config_json)
        .with_context(|| format!("parse config.json for '{model_id}' as JSON"))?;

    let model_type = match parsed.get("model_type").and_then(|value| value.as_str()) {
        Some(name) if !name.is_empty() => name,
        _ => anyhow::bail!(
            "config.json for '{model_id}' is missing `model_type`; the dense \
            path needs it to gate architecture support (supported: {:?})",
            DENSE_SUPPORTED_MODEL_TYPES
        ),
    };

    if !DENSE_SUPPORTED_MODEL_TYPES.contains(&model_type) {
        // `architectures` is what `transformers` keys on; echoing it makes
        // the error self-contained. Non-string entries are dropped.
        let architectures: Vec<String> = parsed
            .get("architectures")
            .and_then(|value| value.as_array())
            .into_iter()
            .flatten()
            .filter_map(|entry| entry.as_str())
            .map(String::from)
            .collect();
        anyhow::bail!(
            "unsupported model_type '{model_type}' for '{model_id}' \
            (architectures={architectures:?}); the dense path supports {:?}. \
            Add a `ModelArch` variant + load/forward branches in \
            crates/neuron/src/harness/candle.rs (and the TP analogue in \
            tp_qwen3.rs) to extend coverage.",
            DENSE_SUPPORTED_MODEL_TYPES
        );
    }

    Ok(())
}
/// Architectures the TP path can actually load and run. A subset of
/// `DENSE_SUPPORTED_MODEL_TYPES` — the single-GPU path supports more
/// families than the TP path because each TP-aware module is a real
/// chunk of work (`tp_qwen3.rs` is the only one shipped today).
///
/// NOTE(review): the list also contains `"qwen3_5"`, so the "only one
/// shipped" remark above may be out of date — confirm against the
/// modules under `harness/tp/`.
#[cfg(feature = "cuda")]
const TP_SUPPORTED_MODEL_TYPES: &[&str] = &["qwen3", "qwen3_5"];
/// Tensor-parallel analogue of `check_dense_config_supported`.
///
/// The TP path handles a narrower set of architectures than the
/// single-GPU dense path, so `load_tp` is gated on
/// `TP_SUPPORTED_MODEL_TYPES`. Rejecting an unsupported family here —
/// before workers are spawned and the NCCL handshake is paid for — is
/// far friendlier than a per-rank deserialisation failure later.
///
/// A missing or non-string `model_type` is treated as the empty string
/// and therefore rejected as unsupported.
///
/// # Errors
///
/// Fails when the text is not valid JSON or when `model_type` is not in
/// `TP_SUPPORTED_MODEL_TYPES`.
#[cfg(feature = "cuda")]
fn check_tp_arch_supported(config_json: &str, model_id: &str) -> Result<()> {
    let parsed: serde_json::Value = serde_json::from_str(config_json)
        .with_context(|| format!("parse config.json for '{model_id}' as JSON"))?;
    let model_type = parsed
        .get("model_type")
        .and_then(|value| value.as_str())
        .unwrap_or_default();

    if !TP_SUPPORTED_MODEL_TYPES.contains(&model_type) {
        anyhow::bail!(
            "tensor_parallel requested for '{model_id}' (model_type='{model_type}') but \
            the TP path supports only {TP_SUPPORTED_MODEL_TYPES:?}. Adding a new \
            TP-aware architecture needs a `harness/tp/tp_<family>.rs` module mirroring \
            `tp_qwen3.rs` (sharded linears, AllReduce, per-rank head counts) and a \
            dispatch in `WorkerPool::load_dense_shard`. For models that fit on one \
            GPU, drop `tensor_parallel` to use the single-GPU dense path."
        );
    }

    Ok(())
}
/// Resolve the effective HuggingFace cache directory for the candle
/// harness. Precedence (first hit wins):
///
/// 1. Explicit `hf_cache` from `[harness.candle]` in `neuron.toml`.
///    Operator's wishes always win.
/// 2. `HF_HUB_CACHE` env var. The Python `huggingface_hub` library
///    points at the cache root directly with this var; the Rust
///    `hf-hub` crate doesn't read it natively, so we bridge here.
///    Honouring it lets a neuron host share a cache directory with
///    Python tooling and other harnesses without per-tool config.
/// 3. `HF_HOME` env var. Canonical HuggingFace base directory; the
///    cache lives at `$HF_HOME/hub`. Hf-hub respects this on its own,
///    but we resolve it here too so the resulting path shows up in
///    logs alongside the explicit/HF_HUB_CACHE cases.
/// 4. `None`. Falls through to `hf-hub`'s default
///    (`~/.cache/huggingface/hub`).
///
/// Environment variables that are unset or empty are skipped. They are
/// read with `var_os`, so a path that is not valid UTF-8 (legal on Unix)
/// is honoured rather than silently ignored.
fn resolve_hf_cache(explicit: Option<PathBuf>) -> Option<PathBuf> {
    // Unset and empty are treated alike: neither names a usable directory.
    let env_path = |key: &str| -> Option<PathBuf> {
        std::env::var_os(key)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    };

    explicit
        .or_else(|| env_path("HF_HUB_CACHE"))
        .or_else(|| env_path("HF_HOME").map(|home| home.join("hub")))
}
/// Sample the next token id from `logits`, applying the repetition
/// penalty first when it is enabled.
///
/// Shared by the prefill and generation-loop call sites so both sample
/// identically. The penalty is skipped when `REPEAT_PENALTY` is
/// effectively 1.0 or when there is no history yet; otherwise only the
/// trailing `REPEAT_LAST_N` entries of `history` are penalised.
fn sample_with_penalty(
    logits: &Tensor,
    history: &[u32],
    logits_processor: &mut LogitsProcessor,
) -> Result<u32> {
    let penalty_disabled = (REPEAT_PENALTY - 1.0).abs() < f32::EPSILON;
    let adjusted = if penalty_disabled || history.is_empty() {
        logits.clone()
    } else {
        let window_start = history.len().saturating_sub(REPEAT_LAST_N);
        let recent = &history[window_start..];
        candle_transformers::utils::apply_repeat_penalty(logits, REPEAT_PENALTY, recent)?
    };
    let token = logits_processor.sample(&adjusted)?;
    Ok(token)
}
impl CandleHarness {
/// Build a harness with an empty model registry.
///
/// `hf_cache` is run through `resolve_hf_cache`; when a directory
/// results, it is logged once so operators can see which cache is used.
pub fn new(bind_url: String, hf_cache: Option<PathBuf>) -> Self {
    let resolved_cache = resolve_hf_cache(hf_cache);
    if let Some(dir) = resolved_cache.as_deref() {
        tracing::info!(path = %dir.display(), "candle harness using HuggingFace cache");
    }
    let registry = HashMap::new();
    Self {
        models: Arc::new(RwLock::new(registry)),
        hf_cache: resolved_cache,
        bind_url,
    }
}
/// Choose the candle `Device` for a load.
///
/// Only the first requested index is considered (0 when the slice is
/// empty). With the `cuda` feature, that CUDA device is tried first; if
/// it cannot be initialised a warning is logged and the CPU is used.
/// Without the feature the CPU is always returned.
fn pick_device(devices: &[u32]) -> Result<Device> {
    // Leading underscore: the binding is unused in non-`cuda` builds.
    let _ordinal = devices.first().map_or(0, |&first| first as usize);
    #[cfg(feature = "cuda")]
    {
        match Device::new_cuda(_ordinal) {
            Err(e) => tracing::warn!(
                device = _ordinal,
                error = %e,
                "CUDA device unavailable, falling back to CPU"
            ),
            Ok(cuda) => return Ok(cuda),
        }
    }
    Ok(Device::Cpu)
}
/// Construct an hf-hub API client, pointing it at the harness's
/// `hf_cache` directory when one is configured.
fn hf_api(&self) -> Result<hf_hub::api::tokio::Api> {
    let base = hf_hub::api::tokio::ApiBuilder::new();
    let configured = match self.hf_cache.as_ref() {
        Some(dir) => base.with_cache_dir(dir.clone()),
        None => base,
    };
    configured.build().context("build hf-hub API")
}
/// Resolve a dense (bf16/fp16 safetensors) model to its local file
/// paths.
///
/// Handles both sharded repos (`model.safetensors.index.json` plus
/// several `model-*.safetensors`) and the single-file layout
/// (`model.safetensors`). Returns the safetensors paths in
/// arbitrary order — `VarBuilder` unifies them into one tensor view.
async fn resolve_dense_files(
&self,
spec: &ModelSpec,
) -> Result<(PathBuf, PathBuf, Vec<PathBuf>)> {
let api = self.hf_api()?;
let repo = api.model(spec.model_id.clone());
let config_path = repo
.get("config.json")
.await
.with_context(|| format!("fetch config.json from {}", spec.model_id))?;
let tokenizer_path = repo
.get("tokenizer.json")
.await
.with_context(|| format!("fetch tokenizer.json from {}", spec.model_id))?;
// Prefer the sharded layout (most HF dense models > 5B ship it).
let safetensors_paths = match repo.get("model.safetensors.index.json").await {
Ok(index_path) => {
let index_text = std::fs::read_to_string(&index_path)
.context("read model.safetensors.index.json")?;
let index: serde_json::Value = serde_json::from_str(&index_text)
.context("parse model.safetensors.index.json")?;
let weight_map = index
.get("weight_map")
.and_then(|v| v.as_object())
.ok_or_else(|| {
anyhow::anyhow!("safetensors index missing weight_map object")
})?;
let unique: std::collections::BTreeSet<String> = weight_map
.values()
.filter_map(|v| v.as_str().map(String::from))
.collect();
let mut paths = Vec::with_capacity(unique.len());
for fname in unique {
let p = repo
.get(&fname)
.await
.with_context(|| format!("fetch sharded safetensors {fname}"))?;
paths.push(p);
}
paths
}
Err(_) => {
// Single-file fallback.
let p = repo
.get("model.safetensors")
.await
.context("fetch model.safetensors (single-file layout)")?;
vec![p]
}
};
Ok((config_path, tokenizer_path, safetensors_paths))
}
/// Resolve + load a GGUF (pre-quantized) model. The architecture is
/// read from the file's `general.architecture` metadata and dispatched
/// to the matching quantized loader (`qwen3`, `qwen3moe`, or `llama`);
/// any other value is rejected. Returns the tokenizer.json path so the
/// caller can construct the Tokenizer uniformly across source formats.
///
/// File parsing and weight loading are synchronous, so they run inside
/// `spawn_blocking`; a panic in that task surfaces as an error here.
async fn load_arch_gguf(
&self,
spec: &ModelSpec,
device: &Device,
) -> Result<(PathBuf, ModelArch)> {
let (gguf_path, tokenizer_path) = self.resolve_files(spec).await?;
// Owned copies to move into the blocking closure.
let device_for_load = device.clone();
let gguf_path_for_load = gguf_path.clone();
let model_id_for_log = spec.model_id.clone();
let arch = tokio::task::spawn_blocking(move || -> Result<ModelArch> {
tracing::info!(model = %model_id_for_log, path = ?gguf_path_for_load, "loading GGUF");
// The same file handle is used first for the header/metadata read
// and then by `from_gguf` for the tensor data.
let mut file = std::fs::File::open(&gguf_path_for_load).context("open GGUF file")?;
let content = gguf_file::Content::read(&mut file)
.map_err(|e| anyhow::anyhow!("parse GGUF: {e}"))?;
// A missing or non-string architecture becomes "" and falls into
// the "unsupported" arm below.
let architecture = content
.metadata
.get("general.architecture")
.and_then(|v| v.to_string().ok().cloned())
.unwrap_or_default();
tracing::info!(architecture = %architecture, "GGUF architecture");
// The `general.architecture` GGUF metadata key follows
// llama.cpp conventions (lowercase, no underscores in some
// cases) — `qwen3moe`, not `qwen3_moe`.
match architecture.as_str() {
"qwen3" => {
let weights =
QuantizedQwen3Weights::from_gguf(content, &mut file, &device_for_load)
.map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?;
Ok(ModelArch::Qwen3Quantized(weights))
}
"qwen3moe" => {
// GGUFQWenMoE takes an explicit compute dtype
// alongside the device — F16 matches the GGUF
// weights' typical accumulation precision and
// gives the best tokens/sec on consumer cards.
let weights =
GGUFQWenMoE::from_gguf(content, &mut file, &device_for_load, DType::F16)
.map_err(|e| anyhow::anyhow!("from_gguf qwen3_moe: {e}"))?;
Ok(ModelArch::Qwen3MoeQuantized(weights))
}
"llama" => {
let weights =
QuantizedLlamaWeights::from_gguf(content, &mut file, &device_for_load)
.map_err(|e| anyhow::anyhow!("from_gguf llama: {e}"))?;
Ok(ModelArch::LlamaQuantized(weights))
}
other => anyhow::bail!(
"unsupported GGUF architecture '{other}'; quantized path supports \
qwen3, qwen3moe, llama"
),
}
})
.await
.context("blocking GGUF load task panicked")??;
Ok((tokenizer_path, arch))
}
/// Resolve + load a dense model from safetensors, dispatching on
/// `config.json#/model_type` to one of the supported families (`qwen3`,
/// `qwen3_moe`, `llama`, `qwen3_5`). Builds a VarBuilder over the
/// mmap'd safetensors files and always requests bf16, the distribution
/// dtype for the supported families.
///
/// NOTE(review): an earlier version of this comment promised a fall
/// back to f16 when the device lacks bf16 support; no such fallback
/// exists in this function — `dtype` is unconditionally `DType::BF16`.
///
/// Returns the tokenizer.json path alongside the constructed
/// `ModelArch`. The load runs inside `spawn_blocking`; a panic in that
/// task surfaces as an error here.
async fn load_arch_dense(
&self,
spec: &ModelSpec,
device: &Device,
) -> Result<(PathBuf, ModelArch)> {
let (config_path, tokenizer_path, safetensors_paths) =
self.resolve_dense_files(spec).await?;
// Owned copies to move into the blocking closure.
let device_for_load = device.clone();
let model_id_for_log = spec.model_id.clone();
let arch = tokio::task::spawn_blocking(move || -> Result<ModelArch> {
let cfg_text = std::fs::read_to_string(&config_path).context("read config.json")?;
// Reject unsupported / missing model_type with a readable error
// before any typed Config deserialisation is attempted.
check_dense_config_supported(&cfg_text, &model_id_for_log)?;
// Peek at model_type to choose the family before the
// typed deserialize — each family has its own Config.
// (The check above already proved the JSON parses and that
// model_type is a supported, non-empty string, so the "" default
// is not expected to be hit.)
let model_type = serde_json::from_str::<serde_json::Value>(&cfg_text)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
tracing::info!(
model = %model_id_for_log,
model_type = %model_type,
shards = safetensors_paths.len(),
"loading dense model from safetensors"
);
// bf16 is the canonical distribution dtype for Qwen3 /
// Llama 3 / Qwen3 MoE. CUDA on Ada+ has hardware bf16;
// Ampere has it too. CPU emulates.
let dtype = DType::BF16;
// SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files;
// mutation by another process while we hold the mapping is
// UB. We trust the HF cache is immutable-by-design.
// Note: this builder is created for every family, but the `qwen3_5`
// arm below does not use it and builds its own sharded builder.
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&safetensors_paths, dtype, &device_for_load)
.context("build VarBuilder over safetensors")?
};
match model_type.as_str() {
"qwen3" => {
let cfg: qwen3_dense::Config =
serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
.map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
Ok(ModelArch::Qwen3Dense(model))
}
"qwen3_moe" => {
let cfg: qwen3_moe_dense::Config =
serde_json::from_str(&cfg_text).context("parse Qwen3 MoE config.json")?;
let model = qwen3_moe_dense::ModelForCausalLM::new(&cfg, vb)
.map_err(|e| anyhow::anyhow!("build Qwen3 MoE dense model: {e}"))?;
Ok(ModelArch::Qwen3MoeDense(model))
}
"llama" => {
let cfg: llama_dense::LlamaConfig =
serde_json::from_str(&cfg_text).context("parse Llama config.json")?;
// Llama has multiple sub-variants (Llama 1 has no
// GQA; Llama 3 does). `LlamaConfig::into_config`
// resolves the right shape; the `use_flash_attn`
// arg defaults to false — the flash kernel is a
// separate feature flag and uses extra VRAM.
let config = cfg.into_config(false);
// The cache lives outside the model, so it is built here and
// stored next to it in the `LlamaDense` wrapper.
let cache = llama_dense::Cache::new(true, dtype, &config, &device_for_load)
.context("build Llama Cache")?;
let model = llama_dense::Llama::load(vb, &config)
.map_err(|e| anyhow::anyhow!("build Llama dense model: {e}"))?;
Ok(ModelArch::LlamaDense(Box::new(LlamaDense {
model,
cache,
config,
dtype,
device: device_for_load,
})))
}
"qwen3_5" => {
// Qwen3-Next needs a ShardedVarBuilder because its
// load functions use the sharded backend (so they
// can be reused unchanged by the future TP variant).
// With world_size=1 the backend falls through to
// the unsharded path, so there is no per-load cost.
let cfg: super::arch::qwen3_5::Config = serde_json::from_str(&cfg_text)
.context("parse Qwen3-Next (qwen3_5) config.json")?;
// SAFETY: NOTE(review) — presumably the same mmap contract as
// the `VarBuilder` above (the files must not be mutated while
// mapped); confirm against `ShardedSafeTensors::var_builder`.
let sharded_vb = unsafe {
candle_nn::var_builder::ShardedSafeTensors::var_builder(
&safetensors_paths,
dtype,
&device_for_load,
)
.context("build ShardedVarBuilder for Qwen3-Next")?
};
let model = super::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, sharded_vb)
.context("build Qwen3-Next dense model")?;
Ok(ModelArch::Qwen3_5Dense(model))
}
other => {
// Defensive: `check_dense_config_supported` already
// gated on the supported set, so this branch is
// unreachable unless that list and the match here
// drift apart.
anyhow::bail!(
"unrouted supported model_type '{other}' — \
DENSE_SUPPORTED_MODEL_TYPES and load_arch_dense \
must stay in sync"
)
}
}
})
.await
.context("blocking dense load task panicked")??;
Ok((tokenizer_path, arch))
}
/// Locate (and download on first use) the GGUF weights and the
/// tokenizer.json for `spec` via hf-hub; later calls hit the cache.
///
/// The GGUF file is the first repo sibling whose name ends in `.gguf`
/// and — when `spec.quant` is non-empty — contains the quant string,
/// both compared case-insensitively. The tokenizer is taken from the
/// base repo obtained by stripping a trailing `-GGUF` / `-gguf` from
/// the model id, because GGUF-only repos usually do not ship
/// tokenizer.json; ids without that suffix use the same repo.
async fn resolve_files(&self, spec: &ModelSpec) -> Result<(PathBuf, PathBuf)> {
    let api = self.hf_api()?;
    let repo = api.model(spec.model_id.clone());
    let info = repo
        .info()
        .await
        .with_context(|| format!("fetch HF repo info for {}", spec.model_id))?;

    // Empty quant means "take the first GGUF in the listing".
    let wanted_quant = spec.quant.as_deref().unwrap_or("").to_lowercase();
    let mut selected: Option<String> = None;
    for sibling in &info.siblings {
        let name = sibling.rfilename.as_str();
        let lowered = name.to_lowercase();
        if !lowered.ends_with(".gguf") {
            continue;
        }
        if wanted_quant.is_empty() || lowered.contains(&wanted_quant) {
            selected = Some(name.to_string());
            break;
        }
    }
    let gguf_filename = selected.ok_or_else(|| {
        anyhow::anyhow!(
            "no GGUF file matching quant {:?} in repo {}",
            spec.quant,
            spec.model_id
        )
    })?;

    tracing::info!(
        model = %spec.model_id,
        file = %gguf_filename,
        "resolving GGUF (may be cached)"
    );
    let gguf_path = repo
        .get(&gguf_filename)
        .await
        .with_context(|| format!("fetch GGUF {gguf_filename}"))?;

    // GGUF-only repos (unsloth/Qwen3-*-GGUF, Qwen/Qwen3-*-GGUF, ...)
    // carry the weights but not the tokenizer; that lives in the base
    // repo, whose id is the model id minus the GGUF suffix.
    let tokenizer_repo_id = ["-GGUF", "-gguf"]
        .iter()
        .find_map(|suffix| spec.model_id.strip_suffix(*suffix))
        .unwrap_or(spec.model_id.as_str())
        .to_string();
    let tokenizer_repo = if tokenizer_repo_id != spec.model_id {
        tracing::debug!(
            from = %spec.model_id,
            to = %tokenizer_repo_id,
            "tokenizer.json sourced from base repo (GGUF suffix stripped)"
        );
        api.model(tokenizer_repo_id.clone())
    } else {
        repo
    };
    let tokenizer_path = tokenizer_repo
        .get("tokenizer.json")
        .await
        .with_context(|| format!("fetch tokenizer.json from {tokenizer_repo_id}"))?;

    Ok((gguf_path, tokenizer_path))
}
/// Run a non-streaming chat completion against a loaded model.
///
/// Looks the model up by `request.model`, formats and tokenizes the
/// messages, runs the blocking generation loop on a `spawn_blocking`
/// thread while holding the model's `arch` lock, then decodes the
/// generated ids into a single assistant message with token usage.
/// Tensor-parallel models are delegated to `chat_completion_tp`.
///
/// Defaults when the request omits them: temperature 0.7, at most 512
/// new tokens.
///
/// # Errors
///
/// Returns a typed `InferenceError` when the model isn't loaded so the
/// handler can map to an appropriate HTTP status without string-matching.
/// Tokenization, inference (including a panicked inference task) and
/// detokenization failures are wrapped in `InferenceError::Other`.
pub async fn chat_completion(
&self,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, InferenceError> {
// Clone the handle out inside a scope so the registry read lock is
// released before any inference work starts.
let handle = {
let models = self.models.read().await;
models.get(&request.model).cloned()
};
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
// The match is technically infallible without `cuda` (only Single
// exists), but the cfg-gated Tp arm makes this the right shape
// under both feature flags.
#[allow(clippy::infallible_destructuring_match)]
let loaded = match handle {
LoadedHandle::Single(m) => m,
#[cfg(feature = "cuda")]
LoadedHandle::Tp(m) => {
return self.chat_completion_tp(m, request).await;
}
};
// NOTE(review): every single-GPU model — including the Llama
// variants of `ModelArch` — is prompted through `format_qwen3_prompt`
// and, below, stops only on `<|im_end|>` / `<|endoftext|>`. Confirm
// this is intended for non-Qwen tokenizers (when neither token
// exists, `eos_id` is `None`).
let prompt = format_qwen3_prompt(&request.messages);
let encoding = loaded
.tokenizer
.encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
// Captured now because `prompt_tokens` is moved into the blocking task.
let prompt_len = prompt_tokens.len();
let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(512) as usize;
// Presumably a time-derived value, so each request samples with a
// different seed — see `unix_subsec_nanos`.
let seed = unix_subsec_nanos();
let eos_id = loaded
.tokenizer
.token_to_id("<|im_end|>")
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
let arch_arc = Arc::clone(&loaded.arch);
let device = loaded.device.clone();
let model_id = request.model.clone();
// The candle forward passes are synchronous, so generation runs on a
// blocking thread; `blocking_lock` serialises requests per model.
let (generated_ids, finish_reason) =
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
let mut guard = arch_arc.blocking_lock();
run_inference(
&mut guard,
&device,
&prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
)
})
.await
.map_err(|e| InferenceError::Other(anyhow::anyhow!("inference task panicked: {e}")))?
.map_err(InferenceError::Other)?;
// `true` asks the tokenizer to skip special tokens when decoding.
let completion_text = loaded
.tokenizer
.decode(&generated_ids, true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
let usage = Usage {
prompt_tokens: prompt_len as u64,
completion_tokens: generated_ids.len() as u64,
total_tokens: (prompt_len + generated_ids.len()) as u64,
};
Ok(ChatCompletionResponse {
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
object: "chat.completion".into(),
created: unix_now_secs(),
model: model_id,
choices: vec![ChatCompletionChoice {
index: 0,
message: ChatMessage {
role: "assistant".into(),
content: MessageContent::Text(completion_text),
extra: serde_json::Value::Object(Default::default()),
},
finish_reason: Some(finish_reason),
extra: serde_json::Value::Object(Default::default()),
}],
usage: Some(usage),
extra: serde_json::Value::Object(Default::default()),
})
}
/// Run a streaming chat completion against a loaded model.
///
/// Returns an `mpsc::Receiver` that yields `ChatCompletionChunk`s in
/// OpenAI SSE format. The first chunk carries the assistant role;
/// subsequent chunks carry incremental `content` deltas; the final
/// chunk carries `finish_reason`. The handler is responsible for
/// wrapping these into an SSE response and appending the `[DONE]`
/// terminator.
///
/// Token-by-token decoding tracks the cumulative decoded prefix so
/// BPE byte-fallback boundaries don't split a UTF-8 char across
/// chunks.
pub async fn chat_completion_stream(
&self,
request: ChatCompletionRequest,
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
let handle = {
let models = self.models.read().await;
models.get(&request.model).cloned()
};
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
// The match is technically infallible without `cuda` (only Single
// exists), but the cfg-gated Tp arm makes this the right shape
// under both feature flags.
#[allow(clippy::infallible_destructuring_match)]
let loaded = match handle {
LoadedHandle::Single(m) => m,
#[cfg(feature = "cuda")]
LoadedHandle::Tp(m) => {
return self.chat_completion_tp_stream(m, request).await;
}
};
let prompt = format_qwen3_prompt(&request.messages);
let encoding = loaded
.tokenizer
.encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(512) as usize;
let seed = unix_subsec_nanos();
let eos_id = loaded
.tokenizer
.token_to_id("<|im_end|>")
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
let arch_arc = Arc::clone(&loaded.arch);
let device = loaded.device.clone();
let tokenizer = loaded.tokenizer.clone();
let model_id = request.model.clone();
let id = format!("chatcmpl-{:x}", unix_subsec_nanos());
let created = unix_now_secs();
// Bounded channel so the producer (blocking inference) is back-
// pressured by the consumer (SSE writer). 32 is generous —
// tokens arrive one at a time and the SSE writer is async.
let (tx, rx) = mpsc::channel::<ChatCompletionChunk>(32);
// Lead chunk: announce the assistant role per OpenAI streaming
// conventions. Tools that auto-detect a streaming reply expect
// this before any content delta.
let role_chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".into(),
created,
model: model_id.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: json!({"role": "assistant"}),
finish_reason: None,
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
// If sending the role chunk fails the receiver is already gone;
// bail before kicking off the heavy blocking work.
tx.send(role_chunk)
.await
.map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?;
tokio::task::spawn_blocking(move || {
let mut guard = arch_arc.blocking_lock();
if let Err(e) = run_inference_streaming(
&mut guard,
&device,
&tokenizer,
&prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
&id,
created,
&model_id,
&tx,
) {
tracing::warn!(model = %model_id, error = %e, "streaming inference failed");
}
});
Ok(rx)
}
}
#[async_trait]
impl Harness for CandleHarness {
/// Static identifier of this harness; always the literal `"candle"`.
fn name(&self) -> &str {
"candle"
}
/// Report harness health.
///
/// This implementation performs no probing: it unconditionally reports
/// `running: true` and does not track uptime (`uptime_secs` is `None`).
async fn health(&self) -> HarnessHealth {
HarnessHealth {
name: "candle".into(),
running: true,
uptime_secs: None,
}
}
/// Snapshot of every model currently in the registry.
///
/// Each entry is reported with harness `"candle"` and status
/// `"loaded"`; VRAM usage is not measured (`vram_used_mb` is `None`).
/// The registry read lock is held while the list is built.
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
    let registry = self.models.read().await;
    let mut listing = Vec::with_capacity(registry.len());
    for handle in registry.values() {
        listing.push(ModelInfo {
            id: handle.model_id().into(),
            harness: "candle".into(),
            status: "loaded".into(),
            devices: handle.devices(),
            vram_used_mb: None,
        });
    }
    Ok(listing)
}
/// Load the model described by `spec` and register it under
/// `spec.model_id`.
///
/// Requests with `tensor_parallel > 1` are routed to `load_tp` (only
/// available when built with the `cuda` feature). Everything else is
/// loaded onto a single device chosen from `spec.devices` (default
/// `[0]`): `spec.quant = Some(_)` selects the GGUF loader, `None` the
/// dense safetensors loader.
///
/// # Errors
///
/// Fails if `spec.harness` is not `"candle"`, if the model id is
/// already registered, if TP is requested without CUDA support, or if
/// device selection, weight loading, or tokenizer loading fails.
async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
    if spec.harness != "candle" {
        anyhow::bail!("expected harness=candle, got harness={}", spec.harness);
    }
    if self.models.read().await.contains_key(&spec.model_id) {
        anyhow::bail!("model '{}' already loaded", spec.model_id);
    }

    let tp_size = spec.tensor_parallel.unwrap_or(1);
    if tp_size > 1 {
        #[cfg(feature = "cuda")]
        {
            return self.load_tp(spec, tp_size).await;
        }
        #[cfg(not(feature = "cuda"))]
        {
            anyhow::bail!(
                "tensor_parallel={tp_size} requested for '{}': this neuron \
                 binary was built without --features cuda; TP requires CUDA + NCCL",
                spec.model_id
            );
        }
    }

    let devices = match &spec.devices {
        Some(requested) => requested.clone(),
        None => vec![0],
    };
    let device = Self::pick_device(&devices)?;

    // Source format is inferred from `spec.quant`: a value means the
    // operator picked a pre-quantized GGUF, absence means dense
    // safetensors (bf16/fp16).
    let (tokenizer_path, arch) = match spec.quant {
        Some(_) => self.load_arch_gguf(spec, &device).await?,
        None => self.load_arch_dense(spec, &device).await?,
    };

    let tokenizer = Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;

    let entry = LoadedHandle::Single(Arc::new(LoadedModel {
        model_id: spec.model_id.clone(),
        arch: Arc::new(Mutex::new(arch)),
        tokenizer,
        device,
        quant: spec.quant.clone(),
        devices,
    }));
    self.models
        .write()
        .await
        .insert(spec.model_id.clone(), entry);

    tracing::info!(model = %spec.model_id, "model loaded");
    Ok(())
}
/// Remove `model_id` from the registry and release its resources.
///
/// # Errors
///
/// * The model is not registered.
/// * (TP only) another clone of the TP handle's `Arc` is still alive;
///   the handle is put back into the registry before the error is
///   returned.
///
/// Failures of the TP `unload_model` / `shutdown` calls on the pool are
/// logged as warnings and do NOT fail the unload.
async fn unload_model(&self, model_id: &str) -> Result<()> {
// Take the entry out inside a short scope so the registry write lock
// is not held across the TP pool calls below.
let removed = {
let mut models = self.models.write().await;
models.remove(model_id)
};
let Some(handle) = removed else {
anyhow::bail!("model '{model_id}' not loaded");
};
// Single-GPU drops are immediate — the LoadedModel goes out of
// scope with the Arc and candle frees VRAM. TP unloads also
// need to tell every worker to drop its shard before the pool
// itself is dropped (otherwise the workers keep their shards
// around until Shutdown, which is wasteful and would surface
// as VRAM not freed promptly).
match handle {
LoadedHandle::Single(_) => {}
#[cfg(feature = "cuda")]
LoadedHandle::Tp(tp) => {
// Try to recover the inner TpLoadedModel so we can move
// the pool and shut it down. If anyone else still holds
// a clone of the Arc (shouldn't happen — the only owners
// are the registry and any in-flight chat_completion),
// bail with a clear marker rather than silently leaking.
let tp = match Arc::try_unwrap(tp) {
Ok(t) => t,
Err(arc) => {
// Reinsert so we don't leave the registry in an
// inconsistent state.
let mut models = self.models.write().await;
models.insert(model_id.into(), LoadedHandle::Tp(arc));
anyhow::bail!("cannot unload '{model_id}': inference still in flight");
}
};
// We are the sole owner now, so the pool can be taken out of its
// mutex by value without locking.
let mut pool = tp.pool.into_inner();
// Best effort: a failed unload call must not prevent the shutdown
// attempt that follows.
if let Err(e) = pool.unload_model(model_id).await {
tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed");
}
if let Err(e) = pool.shutdown().await {
tracing::warn!(model = %model_id, error = %e, "TP pool shutdown failed");
}
}
}
tracing::info!(model = %model_id, "model unloaded");
Ok(())
}
/// URL at which `model_id` can be queried.
///
/// Returns a clone of this harness's `bind_url` when the model is in
/// the registry, `None` otherwise.
async fn inference_endpoint(&self, model_id: &str) -> Option<String> {
    if self.models.read().await.contains_key(model_id) {
        Some(self.bind_url.clone())
    } else {
        None
    }
}
}
impl CandleHarness {
/// Tensor-parallel load. Resolves dense safetensors via hf-hub the
/// same way the single-GPU dense path does, spins up a TP worker
/// pool sized to `tp_size`, runs the NCCL handshake, then has
/// every rank load its shard of the model.
///
/// `spec.devices` carries the per-rank CUDA device indices (one
/// entry per rank, in rank order); defaults to `0..tp_size`.
///
/// On success the model is registered under `spec.model_id` as a
/// `LoadedHandle::Tp`.
///
/// # Errors
///
/// Fails when the number of `devices` differs from `tp_size`, when
/// `spec.quant` is set, when the config is rejected by
/// `check_dense_config_supported` / `check_tp_arch_supported`, or when
/// any of file resolution, worker spawn, NCCL init, shard load, or
/// tokenizer load fails.
#[cfg(feature = "cuda")]
async fn load_tp(&self, spec: &ModelSpec, tp_size: u32) -> Result<()> {
use std::sync::Arc as StdArc;
use tokio::sync::Mutex as TMutex;
// Default per-rank device assignment: 0, 1, ..., tp_size - 1.
let devices = spec
.devices
.clone()
.unwrap_or_else(|| (0..tp_size).collect());
if devices.len() as u32 != tp_size {
anyhow::bail!(
"tensor_parallel={tp_size} requires {tp_size} entries in devices, got {}",
devices.len()
);
}
// Any `quant` value is rejected up front on this path.
// NOTE(review): this guard means the `spec.quant.clone()` forwarded to
// `load_dense_shard` in step 5 is always `None` — confirm whether the
// guard is meant to stay now that the TP load chain accepts a quant
// setting.
if spec.quant.is_some() {
anyhow::bail!(
"tensor_parallel={tp_size} with quant={:?}: GGUF quantized models \
are not supported in the TP path; use a dense safetensors source",
spec.quant
);
}
// 1. Resolve config + tokenizer + safetensors via hf-hub.
let (config_path, tokenizer_path, safetensors_paths) =
self.resolve_dense_files(spec).await?;
// NOTE(review): synchronous file read inside an async fn; the file is
// presumably small (config.json), so this is likely acceptable.
let config_json = std::fs::read_to_string(&config_path).context("read config.json")?;
// Reject unsupported architectures *before* spawning the worker
// pool and fanning out NCCL — otherwise we'd burn the pool
// lifecycle on a load that's guaranteed to fail at deserialise
// time inside every rank.
check_dense_config_supported(&config_json, &spec.model_id)?;
// The TP path knows how to ship and reconstruct a Qwen3 dense
// shard (`tp_qwen3.rs`). Other architectures may pass the
// single-GPU `check_dense_config_supported` check above but
// have no TP-aware module — bail with a clear marker pointing
// at the file the implementer needs to add. This keeps an
// operator who sets `tensor_parallel=2` on a Llama model from
// silently routing through `pool.load_dense_shard` (which
// assumes Qwen3 config shape on the worker side) and producing
// a confusing config-parse failure inside every rank.
check_tp_arch_supported(&config_json, &spec.model_id)?;
// 2. Spawn the worker pool. Rank 0 stays in-process; ranks
// 1..tp_size are subprocesses, one per device after the
// leader's own.
let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
let mut pool = super::tp::WorkerPool::spawn(&exe, tp_size, &devices).await?;
// 3. NCCL handshake across all ranks.
// `devices` has exactly `tp_size` entries (checked above) and the caller
// in this file only routes here for tp_size > 1, so index 0 exists.
let leader_device_idx = devices[0];
pool.init_nccl(leader_device_idx).await?;
// 4. Pick the leader's candle Device (same index as init_nccl).
let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize)
.context("Device::new_cuda for TP leader")?;
// 5. Load this rank's shard on every rank.
let leader_model = pool
.load_dense_shard(
&spec.model_id,
&config_json,
&safetensors_paths,
&leader_device,
candle_core::DType::BF16,
spec.quant.clone(),
)
.await?;
// 6. Tokenizer (same as single-GPU path).
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
// The pool is wrapped in an async mutex so inference requests can
// serialise on it.
let tp_loaded = StdArc::new(TpLoadedModel {
model_id: spec.model_id.clone(),
tokenizer,
devices: devices.clone(),
pool: TMutex::new(pool),
leader_model,
});
let mut models = self.models.write().await;
models.insert(spec.model_id.clone(), LoadedHandle::Tp(tp_loaded));
tracing::info!(
model = %spec.model_id,
tp_size,
?devices,
"TP model loaded"
);
Ok(())
}
/// Non-streaming chat completion against a tensor-parallel model.
///
/// The request body runs in its own `tokio::spawn`ed task
/// (`chat_completion_tp_inner`). This method only awaits the join
/// handle, so if the HTTP layer drops this future (client timeout,
/// disconnect, ...) the detached task still runs to completion rather
/// than being cancelled in the middle of a `pool.generate_step`.
///
/// # Errors
///
/// Propagates whatever the inner task returns. A task that panicked or
/// was cancelled is reported as `InferenceError::Other`.
#[cfg(feature = "cuda")]
async fn chat_completion_tp(
    &self,
    tp: Arc<TpLoadedModel>,
    request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, InferenceError> {
    tokio::spawn(chat_completion_tp_inner(tp, request))
        .await
        .unwrap_or_else(|join_err| {
            Err(InferenceError::Other(anyhow::anyhow!(
                "TP inference task panicked or was cancelled: {join_err}"
            )))
        })
}
/// Streaming counterpart to `chat_completion_tp`. Same per-step
/// orchestration (clear cache, prefill, sample, decode loop) but
/// emits one `ChatCompletionChunk` per token over an mpsc channel
/// so the handler can write an SSE stream.
///
/// Unlike the single-GPU streaming path (which runs the candle
/// forward inside `spawn_blocking` and uses `blocking_send`), the
/// TP loop is itself async — every `pool.generate_step` awaits the
/// leader's spawn_blocking forward plus every worker's recv_only.
/// So we `tokio::spawn` the orchestration task and use plain
/// `Sender::send`.
///
/// # Errors
///
/// Returns `InferenceError::Other` if tokenization fails or the role
/// chunk cannot be sent. Failures inside the spawned task (cache clear,
/// prefill, sampling, decode) are only logged: the task returns early,
/// the sender is dropped, and the receiver sees the channel close
/// without a final `finish_reason` chunk.
#[cfg(feature = "cuda")]
async fn chat_completion_tp_stream(
&self,
tp: Arc<TpLoadedModel>,
request: ChatCompletionRequest,
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
let prompt = format_qwen3_prompt(&request.messages);
let encoding = tp
.tokenizer
.encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
let prompt_len = prompt_tokens.len();
// Sampling parameters, with defaults for anything the request omits.
let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(512) as usize;
let seed = unix_subsec_nanos();
// Prefer `<|im_end|>` as the stop token; fall back to `<|endoftext|>`.
// If the tokenizer knows neither, generation only stops on `max_new`.
let eos_id = tp
.tokenizer
.token_to_id("<|im_end|>")
.or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
let model_id = request.model.clone();
let id = format!("chatcmpl-{:x}", unix_subsec_nanos());
let created = unix_now_secs();
let tokenizer = tp.tokenizer.clone();
// Bounded channel — back-pressures the producer when the SSE
// writer is slow.
let (tx, rx) = mpsc::channel::<ChatCompletionChunk>(32);
// Role chunk first, before kicking off the heavy work — if the
// receiver is gone by now there's no point starting inference.
let role_chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".into(),
created,
model: model_id.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: json!({"role": "assistant"}),
finish_reason: None,
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
tx.send(role_chunk)
.await
.map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?;
// The orchestration task. Holds the pool lock for the lifetime
// of this inference; concurrent requests against the same TP
// model serialise behind it.
let tp_for_task = Arc::clone(&tp);
tokio::spawn(async move {
let mut pool = tp_for_task.pool.lock().await;
let leader_arc = tp_for_task.leader_model.clone();
// Start from an empty KV cache on every rank before prefill.
if let Err(e) = pool.clear_kv_cache(&model_id, leader_arc.clone()).await {
tracing::warn!(model = %model_id, error = %e, "TP stream: clear_kv_cache failed");
return;
}
// Non-positive temperature selects greedy decoding; otherwise sample,
// with nucleus (top-p) filtering when `top_p` was supplied.
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
// `all_tokens` holds generated tokens only (never the prompt);
// `decoded_prefix` is the text already sent to the client.
let mut all_tokens: Vec<u32> = Vec::new();
let mut decoded_prefix = String::new();
// Stays "length" unless an EOS token is sampled.
let mut finish_reason = "length".to_string();
// Prefill — every rank embeds the prompt, offset = 0.
let logits = match pool
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
.await
{
Ok(l) => l,
Err(e) => {
tracing::warn!(model = %model_id, error = %e, "TP stream: prefill failed");
return;
}
};
let mut next_token = match sample_with_penalty(
&logits,
&all_tokens,
&mut logits_processor,
) {
Ok(t) => t,
Err(e) => {
tracing::warn!(model = %model_id, error = %e, "TP stream: prefill sample failed");
return;
}
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
} else {
all_tokens.push(next_token);
// `emit_chunk` returns false when the receiver hung up (or decoding
// failed); stop generating in that case.
if !emit_chunk(
&all_tokens,
&mut decoded_prefix,
&tokenizer,
&tx,
&id,
created,
&model_id,
)
.await
{
return;
}
// The prefill already produced one token, so at most `max_new - 1`
// further decode steps run. Note that with `max_new == 0` the
// prefill token above is still emitted.
for index in 0..max_new.saturating_sub(1) {
// Feed only the last sampled token; the final argument advances by
// one per step starting at `prompt_len` (presumably the sequence
// position offset, matching the `0` used for prefill — confirm
// against `generate_step`).
let logits = match pool
.generate_step(
&model_id,
leader_arc.clone(),
vec![next_token],
prompt_len + index,
)
.await
{
Ok(l) => l,
Err(e) => {
tracing::warn!(
model = %model_id,
error = %e,
"TP stream: decode step failed"
);
return;
}
};
next_token =
match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
tracing::warn!(
model = %model_id,
error = %e,
"TP stream: decode sample failed"
);
return;
}
};
// The EOS token itself is never pushed or emitted.
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
}
all_tokens.push(next_token);
if !emit_chunk(
&all_tokens,
&mut decoded_prefix,
&tokenizer,
&tx,
&id,
created,
&model_id,
)
.await
{
return;
}
}
}
// Final chunk carrying finish_reason.
let final_chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".into(),
created,
model: model_id.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: serde_json::Value::Object(Default::default()),
finish_reason: Some(finish_reason),
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
// A send failure here only means the client is already gone.
let _ = tx.send(final_chunk).await;
});
Ok(rx)
}
}
/// Body of the TP non-streaming chat completion, hoisted out of
/// `CandleHarness::chat_completion_tp` so it can run inside
/// `tokio::spawn` (which requires a `'static` future) and survive
/// HTTP-layer cancellation.
///
/// Tracing strategy: `info` for request entry/exit so journalctl
/// always shows when an inference started and finished; `debug` for
/// per-step timing so an operator running with `RUST_LOG=debug` sees
/// where the request actually spends its time without needing to
/// instrument the model code.
///
/// # Errors
///
/// Tokenize / detokenize failures and any error from the pool
/// (`clear_kv_cache`, `generate_step`) or from sampling are returned
/// as `InferenceError::Other`.
#[cfg(feature = "cuda")]
async fn chat_completion_tp_inner(
tp: Arc<TpLoadedModel>,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, InferenceError> {
let req_start = std::time::Instant::now();
let model_id = request.model.clone();
let prompt = format_qwen3_prompt(&request.messages);
let encoding = tp
.tokenizer
.encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
let prompt_len = prompt_tokens.len();
// Sampling parameters, with defaults for anything the request omits.
let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(512) as usize;
let seed = unix_subsec_nanos();
// Prefer `<|im_end|>` as the stop token; fall back to `<|endoftext|>`.
// If the tokenizer knows neither, generation only stops on `max_new`.
let eos_id = tp
.tokenizer
.token_to_id("<|im_end|>")
.or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
tracing::info!(
model = %model_id,
prompt_len,
max_new,
temperature,
?top_p,
?eos_id,
"TP chat_completion: starting"
);
// Acquire the pool lock for the duration of the request. The
// leader_model's own Mutex is acquired step-by-step inside
// pool.generate_step (so spawn_blocking can grab it without
// holding the pool lock across the blocking_lock call).
let lock_start = std::time::Instant::now();
let mut pool = tp.pool.lock().await;
tracing::debug!(
model = %model_id,
elapsed_ms = lock_start.elapsed().as_millis(),
"TP chat_completion: pool lock acquired"
);
let leader_arc = tp.leader_model.clone();
// Reset every rank's KV cache so this request doesn't attend
// over the previous request's tokens.
let clear_start = std::time::Instant::now();
pool.clear_kv_cache(&model_id, leader_arc.clone())
.await
.map_err(InferenceError::Other)?;
tracing::debug!(
model = %model_id,
elapsed_ms = clear_start.elapsed().as_millis(),
"TP chat_completion: kv cache cleared"
);
// Non-positive temperature selects greedy decoding; otherwise sample,
// with nucleus (top-p) filtering when `top_p` was supplied.
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
// Generated tokens only (never the prompt, never the EOS token).
let mut generated: Vec<u32> = Vec::new();
// Stays "length" unless an EOS token is sampled.
let mut finish_reason = "length".to_string();
// Prefill: every rank embeds the whole prompt, offset = 0.
let prefill_start = std::time::Instant::now();
let logits = pool
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
.await
.map_err(InferenceError::Other)?;
tracing::info!(
model = %model_id,
prompt_len,
elapsed_ms = prefill_start.elapsed().as_millis(),
"TP chat_completion: prefill complete"
);
let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
.map_err(InferenceError::Other)?;
if Some(next_token) == eos_id {
finish_reason = "stop".into();
} else {
generated.push(next_token);
let decode_start = std::time::Instant::now();
// The prefill already produced one token, so at most `max_new - 1`
// further decode steps run. Note that with `max_new == 0` the
// prefill token above is still kept.
for index in 0..max_new.saturating_sub(1) {
let step_start = std::time::Instant::now();
// Feed only the last sampled token; the final argument advances by
// one per step starting at `prompt_len` (presumably the sequence
// position offset, matching the `0` used for prefill — confirm
// against `generate_step`).
let logits = pool
.generate_step(
&model_id,
leader_arc.clone(),
vec![next_token],
prompt_len + index,
)
.await
.map_err(InferenceError::Other)?;
next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
.map_err(InferenceError::Other)?;
tracing::trace!(
model = %model_id,
step = index,
next_token,
step_ms = step_start.elapsed().as_millis(),
"TP chat_completion: decode step"
);
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
}
generated.push(next_token);
}
tracing::info!(
model = %model_id,
generated = generated.len(),
elapsed_ms = decode_start.elapsed().as_millis(),
"TP chat_completion: decode complete"
);
}
// Release the pool before detokenizing so the next queued request can
// start while this one builds its response.
drop(pool);
let completion_text = tp
.tokenizer
.decode(&generated, true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
let usage = Usage {
prompt_tokens: prompt_len as u64,
completion_tokens: generated.len() as u64,
total_tokens: (prompt_len + generated.len()) as u64,
};
tracing::info!(
model = %model_id,
prompt_tokens = prompt_len,
completion_tokens = generated.len(),
finish_reason = %finish_reason,
total_ms = req_start.elapsed().as_millis(),
"TP chat_completion: done"
);
Ok(ChatCompletionResponse {
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
object: "chat.completion".into(),
created: unix_now_secs(),
model: model_id,
choices: vec![ChatCompletionChoice {
index: 0,
message: ChatMessage {
role: "assistant".into(),
content: MessageContent::Text(completion_text),
extra: serde_json::Value::Object(Default::default()),
},
finish_reason: Some(finish_reason),
extra: serde_json::Value::Object(Default::default()),
}],
usage: Some(usage),
extra: serde_json::Value::Object(Default::default()),
})
}
/// Decode the cumulative token list, emit the delta (substring appended
/// since the last chunk) as a `chat.completion.chunk`. Returns `false`
/// if the receiver has hung up — the caller should bail.
#[cfg(feature = "cuda")]
async fn emit_chunk(
all_tokens: &[u32],
decoded_prefix: &mut String,
tokenizer: &Tokenizer,
tx: &mpsc::Sender<ChatCompletionChunk>,
id: &str,
created: u64,
model_id: &str,
) -> bool {
let full = match tokenizer.decode(all_tokens, true) {
Ok(s) => s,
Err(e) => {
tracing::warn!(error = %e, "TP stream: decode failed");
return false;
}
};
if full.len() > decoded_prefix.len() {
let delta = full[decoded_prefix.len()..].to_string();
*decoded_prefix = full;
let chunk = ChatCompletionChunk {
id: id.into(),
object: "chat.completion.chunk".into(),
created,
model: model_id.into(),
choices: vec![ChunkChoice {
index: 0,
delta: json!({ "content": delta }),
finish_reason: None,
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
if tx.send(chunk).await.is_err() {
return false;
}
}
true
}
/// Errors returned by `CandleHarness::chat_completion`. The
/// `ModelNotLoaded` variant lets the HTTP handler map cleanly to 404
/// without string-matching on anyhow messages.
#[derive(Debug, thiserror::Error)]
pub enum InferenceError {
/// The requested model id is not present in this harness's registry.
/// The payload is the model id as given in the request.
#[error("model '{0}' not loaded on this neuron")]
ModelNotLoaded(String),
/// Any other failure (tokenization, inference, channel send, ...).
/// `transparent` forwards `Display` and `source` to the wrapped
/// `anyhow::Error`; `#[from]` lets `?` convert it automatically.
#[error(transparent)]
Other(#[from] anyhow::Error),
}
/// Render `messages` with the Qwen3 chat template:
///
/// ```text
/// <|im_start|>{role}\n{content}<|im_end|>\n
/// ...
/// <|im_start|>assistant\n
/// ```
///
/// Every message becomes one `<|im_start|>…<|im_end|>` turn, and the
/// output always ends with an open assistant turn so the model starts
/// generating its reply. For multi-part content only the parts that
/// carry a string `"text"` field contribute; they are concatenated
/// without a separator and every other part is ignored.
fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
    let mut rendered = String::new();
    for message in messages {
        rendered.push_str("<|im_start|>");
        rendered.push_str(&message.role);
        rendered.push('\n');
        match &message.content {
            MessageContent::Text(text) => rendered.push_str(text),
            MessageContent::Parts(parts) => {
                for part in parts {
                    if let Some(text) = part.get("text").and_then(|v| v.as_str()) {
                        rendered.push_str(text);
                    }
                }
            }
        }
        rendered.push_str("<|im_end|>\n");
    }
    rendered.push_str("<|im_start|>assistant\n");
    rendered
}
/// Generate a completion for `prompt_tokens` in one blocking call.
///
/// Clears the KV cache, runs one prefill forward pass over the whole
/// prompt, then feeds tokens back one at a time.
///
/// Returns the generated tokens (prompt and EOS excluded) together
/// with the finish reason: `"stop"` when `eos_id` was sampled,
/// `"length"` when the token budget ran out. The prefill always
/// contributes one sampled token, so `max_new == 0` behaves like
/// `max_new == 1`.
///
/// A non-positive `temperature` selects greedy decoding; otherwise
/// tokens are sampled, with nucleus filtering when `top_p` is given.
#[allow(clippy::too_many_arguments)]
fn run_inference(
    arch: &mut ModelArch,
    device: &Device,
    prompt_tokens: &[u32],
    max_new: usize,
    temperature: f64,
    top_p: Option<f64>,
    seed: u64,
    eos_id: Option<u32>,
) -> Result<(Vec<u32>, String)> {
    let sampling = match top_p {
        _ if temperature <= 0.0 => Sampling::ArgMax,
        Some(p) => Sampling::TopP { p, temperature },
        None => Sampling::All { temperature },
    };
    let mut sampler = LogitsProcessor::from_sampling(seed, sampling);
    let mut output: Vec<u32> = Vec::new();

    arch.clear_kv_cache()?;
    let prefill = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
    let mut logits = arch.forward(&prefill, 0)?;

    // Number of single-token forward passes allowed after the prefill.
    let decode_budget = max_new.saturating_sub(1);
    let mut step = 0;
    loop {
        let token = sample_with_penalty(&logits, &output, &mut sampler)?;
        if Some(token) == eos_id {
            return Ok((output, "stop".into()));
        }
        output.push(token);
        if step == decode_budget {
            return Ok((output, "length".into()));
        }
        let input = Tensor::new(&[token], device)?.unsqueeze(0)?;
        logits = arch.forward(&input, prompt_tokens.len() + step)?;
        step += 1;
    }
}
/// Streaming counterpart to `run_inference`. Emits chunks via `tx` as
/// tokens are generated and exits on EOS, max_new, or receiver drop.
///
/// Detokenization tracks the cumulative decoded prefix so each chunk's
/// `content` delta is the substring appended since the last chunk —
/// safe across BPE byte-fallback boundaries.
#[allow(clippy::too_many_arguments)]
fn run_inference_streaming(
arch: &mut ModelArch,
device: &Device,
tokenizer: &Tokenizer,
prompt_tokens: &[u32],
max_new: usize,
temperature: f64,
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
id: &str,
created: u64,
model_id: &str,
tx: &mpsc::Sender<ChatCompletionChunk>,
) -> Result<()> {
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut all_tokens: Vec<u32> = Vec::new();
let mut decoded_prefix = String::new();
let mut finish_reason = "length".to_string();
arch.clear_kv_cache()?;
let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
let logits = arch.forward(&input, 0)?;
let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result<bool> {
let full = tokenizer
.decode(all_tokens, true)
.map_err(|e| anyhow::anyhow!("decode: {e}"))?;
if full.len() > decoded_prefix.len() {
let delta = full[decoded_prefix.len()..].to_string();
*decoded_prefix = full;
let chunk = ChatCompletionChunk {
id: id.into(),
object: "chat.completion.chunk".into(),
created,
model: model_id.into(),
choices: vec![ChunkChoice {
index: 0,
delta: json!({ "content": delta }),
finish_reason: None,
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
// blocking_send returns Err if the consumer hung up — signal
// the caller to stop generating.
if tx.blocking_send(chunk).is_err() {
return Ok(false);
}
}
Ok(true)
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
} else {
all_tokens.push(next_token);
if !emit_token(&all_tokens, &mut decoded_prefix)? {
return Ok(());
}
for index in 0..max_new.saturating_sub(1) {
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
let logits = arch.forward(&input, prompt_tokens.len() + index)?;
next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
}
all_tokens.push(next_token);
if !emit_token(&all_tokens, &mut decoded_prefix)? {
return Ok(());
}
}
}
let final_chunk = ChatCompletionChunk {
id: id.into(),
object: "chat.completion.chunk".into(),
created,
model: model_id.into(),
choices: vec![ChunkChoice {
index: 0,
delta: serde_json::Value::Object(Default::default()),
finish_reason: Some(finish_reason),
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
let _ = tx.blocking_send(final_chunk);
Ok(())
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn unix_now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Current wall-clock time as nanoseconds since the Unix epoch.
///
/// Despite the name this is the FULL nanosecond count, not just the
/// sub-second part. The `u128` from `Duration::as_nanos` is truncated
/// to its low 64 bits. Returns 0 if the system clock reports a time
/// before the epoch.
fn unix_subsec_nanos() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    }
}
// Unit tests for the `check_dense_config_supported` dispatch gate:
// which `model_type` values in a config.json are accepted for the dense
// load path, and what the rejection messages contain.
#[cfg(test)]
mod tests {
use super::*;
// A plain Qwen3 config must be accepted.
#[test]
fn check_dense_config_accepts_qwen3() {
let cfg = r#"{
"model_type": "qwen3",
"vocab_size": 151936,
"architectures": ["Qwen3ForCausalLM"]
}"#;
check_dense_config_supported(cfg, "Qwen/Qwen3-1.7B").expect("qwen3 should pass");
}
// An unknown `model_type` must be rejected, and the error must name the
// rejected type, echo the model id, and list the supported set.
#[test]
fn check_dense_config_rejects_unsupported_arch_with_clear_message() {
// Use a deliberately-fake model_type so this test stays
// meaningful as the supported set grows. (qwen3_5 was the
// motivating real example but now lives in the supported set
// as a Stage 8c scaffold.)
let cfg = r#"{
"model_type": "fictional_arch_99",
"architectures": ["FictionalArch99ForCausalLM"]
}"#;
let err = check_dense_config_supported(cfg, "Fake/Model-99")
.expect_err("fictional_arch_99 should be rejected");
let msg = format!("{err}");
assert!(
msg.contains("unsupported model_type 'fictional_arch_99'"),
"message should name the rejected type: {msg}"
);
assert!(
msg.contains("Fake/Model-99"),
"message should echo the model id: {msg}"
);
assert!(
msg.contains("qwen3"),
"message should list the supported set: {msg}"
);
}
// `qwen3_5` must pass the gate even though its forward pass is not
// implemented yet.
#[test]
fn check_dense_config_accepts_qwen3_5() {
// Sanity: Stage 8c scaffold means qwen3_5 deserialises into the
// supported set. Forward still bails (covered by tests on the
// architecture module itself), but the dispatch gate must let
// it through.
let cfg = r#"{
"model_type": "qwen3_5",
"architectures": ["Qwen3_5ForConditionalGeneration"],
"text_config": {"hidden_size": 5120}
}"#;
check_dense_config_supported(cfg, "Qwen/Qwen3.6-27B")
.expect("qwen3_5 should be in the supported set as of Stage 8c scaffold");
}
// A config without `model_type` must be rejected with a message that
// names the missing field.
#[test]
fn check_dense_config_rejects_missing_model_type() {
let cfg = r#"{ "vocab_size": 1234 }"#;
let err = check_dense_config_supported(cfg, "anon/no-type")
.expect_err("missing model_type should be rejected");
assert!(
format!("{err}").contains("missing `model_type`"),
"message should call out the missing field"
);
}
// Malformed JSON must be rejected. The alternate format (`{err:#}`)
// prints the whole anyhow context chain, which is where "config.json"
// is expected to appear.
#[test]
fn check_dense_config_rejects_invalid_json() {
let err = check_dense_config_supported("not json", "anon/bad-json")
.expect_err("malformed JSON should be rejected");
assert!(
format!("{err:#}").contains("config.json"),
"message should mention config.json"
);
}
}