Files
cortex/crates/neuron/src/harness/candle.rs
rob thijssen 1e13889392
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 34s
CI / Format (push) Successful in 36s
CI / Clippy (push) Successful in 2m15s
CI / Test (push) Successful in 5m9s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 5m1s
build-prerelease / Package cortex RPM (push) Successful in 1m20s
build-prerelease / Build neuron-blackwell (push) Successful in 11m7s
build-prerelease / Build neuron-ampere (push) Successful in 12m16s
build-prerelease / Build neuron-ada (push) Successful in 12m30s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m56s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m47s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m3s
feat(neuron): chunked prefill + VRAM/prompt-length pre-flight checks
Prevents the OOM-during-prefill → poisoned-context → 5-minute-reload
cycle observed on beast under agent-zero workloads. Three changes,
all keyed off env-driven knobs so an operator can tune without a
rebuild:

1. Chunked prefill (NEURON_PREFILL_CHUNK_TOKENS, default 512). The
   initial forward is split into N-token windows, each with a
   monotonically growing offset. KV cache accumulates across chunks
   exactly as it would under one big prefill; only the final chunk's
   logits are kept for sampling. Activation memory now scales with
   chunk size instead of prompt length, so a 13 k-token prompt stops
   holding tens of GB of intermediate activations live at once.

   Wired into all six prefill call sites:
   - run_inference / run_inference_streaming (CPU path)
   - run_inference_via_worker / stream_inference_via_worker (CUDA
     single-GPU through device worker)
   - chat_completion_tp_inner / chat_completion_tp_stream (TP via
     WorkerPool)

   Three helpers — chunked_prefill_local, chunked_prefill_via_worker,
   chunked_prefill_tp — own the loop shape so the chunking semantics
   stay identical across paths. Per-chunk debug log shows progress.

2. Max prompt length (NEURON_MAX_PROMPT_TOKENS, default 16384).
   Requests above the cap return a structured 400 with
   `code: prompt_too_long` rather than going through the prefill and
   discovering the limit by OOMing partway through. New
   InferenceError::PromptTooLong variant.

3. Minimum free VRAM gate (NEURON_MIN_FREE_VRAM_MB, default 1500).
   If `vram_free_mb` is below the threshold at request start (e.g.
   another concurrent request is mid-prefill), reject with a clean
   503 + `code: insufficient_vram` rather than starting work that
   will OOM. New InferenceError::InsufficientVram variant. CPU loads
   (vram=0 sentinel) skip this check.

All three gates fire BEFORE any device work, so a rejected request
costs ~one tokenisation pass and never touches the worker thread —
poison cascades from rejected work are now impossible.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-27 13:46:54 +03:00

3207 lines
128 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! Candle harness — in-process inference using huggingface/candle.
//!
//! This is the sole `Harness` implementation. Inference runs inside
//! the neuron process; there is no external subprocess.
//!
//! - Stage 2 wired GGUF (Qwen3 only) load/unload via `quantized_qwen3`.
//! - Stage 3 (this) adds `chat_completion` — a non-streaming OpenAI
//! compatible chat completion routed to the loaded model's forward
//! pass on a per-model serialised generation loop.
use anyhow::{Context, Result};
use async_trait::async_trait;
use candle_core::quantized::gguf_file;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama as llama_dense;
use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaWeights;
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
use candle_transformers::models::qwen3 as qwen3_dense;
use candle_transformers::models::qwen3_moe as qwen3_moe_dense;
use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec};
use cortex_core::openai::{
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse,
ChatMessage, ChunkChoice, MessageContent, Usage,
};
use serde_json::json;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
#[cfg(feature = "cuda")]
use std::time::Duration;
use std::time::{SystemTime, UNIX_EPOCH};
use tokenizers::Tokenizer;
use tokio::sync::{Mutex, RwLock, mpsc};
use tracing::Instrument;
/// In-process candle harness. Owns the loaded model registry.
pub struct CandleHarness {
models: Arc<RwLock<HashMap<String, LoadedHandle>>>,
hf_cache: Option<PathBuf>,
bind_url: String,
/// One worker thread per CUDA device index that owns its
/// `CudaContext` for the daemon's lifetime. Populated lazily by
/// `ensure_device_worker()` when a model is loaded onto a CUDA
/// device. CPU `Device::Cpu` loads don't get an entry; they have
/// no context to own. Unused on the no-cuda build (the harness
/// can still load on CPU for tests, just without worker threads).
#[allow(dead_code)]
device_workers: Arc<RwLock<HashMap<u32, Arc<super::device_worker::DeviceWorkerHandle>>>>,
}
/// One entry in the harness's loaded-model registry. Single-GPU loads
/// land in `Single`; loads with `tensor_parallel > 1` land in `Tp`.
/// The two variants share the same `model_id` key in the map, so
/// `list_models`, `unload_model`, and `inference_endpoint` can walk
/// them uniformly without branching the storage layout.
///
/// `Clone` is cheap: both variants hold `Arc<_>` and cloning just bumps
/// the refcount.
#[derive(Clone)]
pub enum LoadedHandle {
Single(Arc<LoadedModel>),
#[cfg(feature = "cuda")]
Tp(Arc<TpLoadedModel>),
}
impl LoadedHandle {
pub fn model_id(&self) -> &str {
match self {
LoadedHandle::Single(m) => &m.model_id,
#[cfg(feature = "cuda")]
LoadedHandle::Tp(m) => &m.model_id,
}
}
pub fn devices(&self) -> Vec<u32> {
match self {
LoadedHandle::Single(m) => m.devices.clone(),
#[cfg(feature = "cuda")]
LoadedHandle::Tp(m) => m.devices.clone(),
}
}
/// True if an earlier inference left the device context in an
/// unrecoverable state. Surfaced in `/models` so cortex (and an
/// operator running `curl beast:13131/models`) can see at a glance
/// that the model needs unload+reload.
pub fn is_poisoned(&self) -> bool {
match self {
LoadedHandle::Single(m) => m.poisoned.load(Ordering::Acquire),
#[cfg(feature = "cuda")]
LoadedHandle::Tp(m) => m.poisoned.load(Ordering::Acquire),
}
}
}
/// A loaded model with its tokenizer, device placement, and architecture-
/// specific weights. The `arch` is `Arc<Mutex<>>` so the lock guard can be
/// moved into `spawn_blocking` for synchronous candle forward passes.
pub struct LoadedModel {
pub model_id: String,
/// Local (async-side) handle to the model architecture. `Some`
/// only when the model loaded onto the CPU device (no CUDA
/// available); the inference path then takes this mutex via
/// `spawn_blocking` and runs candle ops on the CPU backend.
/// `None` when the model loaded onto a CUDA device — in that case
/// the architecture lives in the worker thread's slab and is
/// addressed via [`Self::arch_handle`].
pub arch: Option<Arc<Mutex<ModelArch>>>,
pub tokenizer: Tokenizer,
pub device: Device,
pub quant: Option<String>,
pub devices: Vec<u32>,
/// Set to `true` after any forward / kv-cache call fails. A CUDA
/// driver error (OOM, illegal address) leaves the device's context
/// in an unrecoverable state — subsequent kernels can hang, return
/// garbage, or hit another illegal address. The harness refuses
/// further inference against a poisoned model and reports a clear
/// error so an operator knows to unload+reload to recover. See
/// the 2026-05-26 beast incident where a 14k-token prefill OOM
/// silently turned every subsequent request into a stuck wait.
pub poisoned: AtomicBool,
/// Handle to the per-device CUDA worker thread for this model's
/// device. `None` for CPU loads (no context to own). VRAM queries
/// and — for CUDA loads — forward / kv-cache / drop ops route
/// through this handle so the device's CUDA context stays bound
/// to one OS thread for the daemon's lifetime.
pub worker: Option<Arc<super::device_worker::DeviceWorkerHandle>>,
/// Index into the worker's `ModelArch` slab. `Some` iff the model
/// loaded onto a CUDA device and was successfully transferred to
/// the worker; in that case [`Self::arch`] is `None`. The two
/// fields are mutually exclusive.
pub arch_handle: Option<super::device_worker::ArchHandle>,
}
impl LoadedModel {
/// Free / total VRAM on this model's device in MiB. Routes the
/// query through the device worker thread (where the CUDA context
/// is already bound) rather than rebinding on whatever tokio
/// thread the caller happens to be on. Returns `(0, 0)` on CPU
/// loads, or if the worker is gone / poisoned / the cudarc call
/// itself failed — same sentinel the previous `device_vram_mb`
/// helper returned, so log field values stay comparable.
pub async fn query_vram(&self) -> (u64, u64) {
match &self.worker {
Some(w) => w.query_vram().await.unwrap_or((0, 0)),
None => (0, 0),
}
}
}
/// Tensor-parallel loaded model. Holds the leader's rank-0 shard
/// (which the inference loop drives via spawn_blocking) and the
/// `WorkerPool` (which drives every non-zero rank over the RPC
/// channel). Both are behind tokio Mutexes so concurrent inference
/// requests against the same model are serialised; concurrent loads
/// for *different* models would each have their own pool.
#[cfg(feature = "cuda")]
pub struct TpLoadedModel {
pub model_id: String,
pub tokenizer: Tokenizer,
pub devices: Vec<u32>,
/// One end-to-end gate: the pool's RPC stream to the subprocess
/// workers isn't safe to use concurrently. After Phase 3 the
/// leader's `TpLeaderModel` lives in the worker thread's slab,
/// so this Mutex no longer covers the leader's KV cache; it just
/// serialises subprocess RPC traffic on the pool's
/// `Vec<Worker>` channels.
pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
/// Handle into the leader device worker's TP slab. The boxed
/// `TpLeaderModel` (with its embedded `Arc<Comm>` clones and
/// per-rank CUDA tensors) lives on the worker thread; we hold an
/// opaque index. Forward / clear_kv / unload all route through
/// `Job::Tp*` against this handle.
pub leader_handle: super::device_worker::TpHandle,
/// Candle device for rank 0. Mirrors what
/// `TpLeaderModel::device()` would return, kept on the struct so
/// the request path can name the device without an RPC.
pub leader_device: Device,
/// Same poisoning gate as [`LoadedModel::poisoned`]. A TP forward
/// failure (CUDA OOM on any rank, NCCL desync, illegal address) is
/// terminal: the leader's and workers' CUDA contexts cannot be
/// reliably reset without restarting the worker subprocesses.
pub poisoned: AtomicBool,
/// Worker thread for the leader's CUDA device. Owns the leader's
/// `CudaContext`, `NcclState`, and the boxed `TpLeaderModel`
/// referenced by `leader_handle`.
pub worker: Arc<super::device_worker::DeviceWorkerHandle>,
}
#[cfg(feature = "cuda")]
impl TpLoadedModel {
/// Free / total VRAM on the leader's device in MiB. See
/// [`LoadedModel::query_vram`] for rationale and sentinel
/// semantics — same pattern, TP just always has a worker because
/// the harness rejects TP without CUDA at load time.
pub async fn query_vram(&self) -> (u64, u64) {
self.worker.query_vram().await.unwrap_or((0, 0))
}
}
/// Architecture-specific weights. Each variant covers one (family,
/// source-format) pair; the dense variants take the safetensors path
/// and the `Quantized*` variants take the GGUF path.
///
/// TP currently only works through `Qwen3Dense` (see `tp_qwen3.rs`);
/// every other variant is single-GPU. Quantized variants can't shard
/// across GPUs at all — slicing GGUF super-blocks is intractable —
/// and the new dense families (Llama, Qwen3 MoE) lack their own
/// TP-aware modules yet.
pub enum ModelArch {
// Qwen3 family
Qwen3Quantized(QuantizedQwen3Weights),
Qwen3Dense(qwen3_dense::ModelForCausalLM),
Qwen3MoeQuantized(GGUFQWenMoE),
Qwen3MoeDense(qwen3_moe_dense::ModelForCausalLM),
// Llama family (covers Llama 1/2/3/3.1/3.3). Boxed because the
// wrapper carries an inline Cache + Config — without indirection
// the enum's `LlamaDense` variant is several hundred bytes larger
// than the others (clippy::large_enum_variant).
LlamaQuantized(QuantizedLlamaWeights),
LlamaDense(Box<LlamaDense>),
// Qwen3-Next family (model_type "qwen3_5") — Qwen3.6's
// architecture. Stage 8c scaffolding only: dispatch + config parse
// are real; forward bails "not implemented yet". See
// `arch/qwen3_5.rs` for the open architecture work.
Qwen3_5Dense(super::arch::qwen3_5::Qwen3_5ForCausalLM),
}
impl ModelArch {
/// One forward step on this arch with the rank-1 vocab logits
/// extracted. Hides per-family shape differences (some return
/// `[B, V]`, others `[B, 1, V]`) — every caller gets a `[V]`
/// tensor ready for sampling.
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
let raw = match self {
ModelArch::Qwen3Quantized(m) => m.forward(input, offset)?,
ModelArch::Qwen3Dense(m) => m.forward(input, offset)?,
ModelArch::Qwen3MoeQuantized(m) => m.forward(input, offset)?,
ModelArch::Qwen3MoeDense(m) => m.forward(input, offset)?,
ModelArch::LlamaQuantized(m) => m.forward(input, offset)?,
ModelArch::LlamaDense(m) => m.forward(input, offset)?,
ModelArch::Qwen3_5Dense(m) => m.forward(input, offset)?,
};
squeeze_to_vocab(&raw)
}
/// Reset the KV cache before each new request so we don't attend
/// over a previous request's tokens. Some architectures have an
/// in-place reset; Llama needs a Cache rebuild (held inline in
/// the wrapper).
pub fn clear_kv_cache(&mut self) -> Result<()> {
match self {
ModelArch::Qwen3Quantized(_) => Ok(()), /* keeps cache by design;
* forward() handles offset */
ModelArch::Qwen3Dense(m) => {
m.clear_kv_cache();
Ok(())
}
ModelArch::Qwen3MoeQuantized(_) => Ok(()),
ModelArch::Qwen3MoeDense(m) => {
m.clear_kv_cache();
Ok(())
}
ModelArch::LlamaQuantized(_) => Ok(()),
ModelArch::LlamaDense(m) => m.clear_kv_cache(),
ModelArch::Qwen3_5Dense(m) => {
m.clear_kv_cache();
Ok(())
}
}
}
}
/// Squeeze any leading singleton dims off the logits tensor so the
/// caller gets a rank-1 `[vocab_size]` slice ready for sampling. Bails
/// on a non-singleton leading dim (would mean a batched forward, which
/// no caller emits today).
fn squeeze_to_vocab(t: &Tensor) -> Result<Tensor> {
let mut t = t.clone();
while t.dims().len() > 1 {
if t.dims()[0] != 1 {
anyhow::bail!(
"logits expected to start with a singleton dim, got shape {:?}",
t.dims()
);
}
t = t.squeeze(0)?;
}
Ok(t)
}
/// Llama dense wrapper. Bundles candle's `Llama` model with its
/// externally-managed `Cache` plus enough config to rebuild the
/// cache on `clear_kv_cache` (Llama's Cache doesn't expose a reset).
pub struct LlamaDense {
model: llama_dense::Llama,
cache: llama_dense::Cache,
config: llama_dense::Config,
dtype: DType,
device: Device,
}
impl LlamaDense {
/// Constructor used by the dispatch-side loader. Keeps the field
/// names private while letting the worker thread build a
/// `LlamaDense` from already-loaded weights without going through
/// async candle code.
pub(crate) fn from_parts(
model: llama_dense::Llama,
cache: llama_dense::Cache,
config: llama_dense::Config,
dtype: DType,
device: Device,
) -> Self {
Self {
model,
cache,
config,
dtype,
device,
}
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
Ok(self.model.forward(input, offset, &mut self.cache)?)
}
pub fn clear_kv_cache(&mut self) -> Result<()> {
self.cache = llama_dense::Cache::new(true, self.dtype, &self.config, &self.device)
.context("rebuild Llama Cache for new request")?;
Ok(())
}
}
/// Repetition penalty applied to recently-generated tokens before
/// sampling. 1.0 disables it; >1.0 makes recently-emitted tokens less
/// likely. mistral.rs and llama.cpp default to 1.1, which is enough to
/// stop small quantized models from degenerating into "Wait, no, no..."
/// loops without distorting normal output.
const REPEAT_PENALTY: f32 = 1.1;
/// Number of recently-generated tokens to feed into the repetition
/// penalty. Matches the candle quantized-qwen3 example default.
const REPEAT_LAST_N: usize = 64;
/// Architectures the dense safetensors path can construct. Keep
/// alphabetical; one entry per supported `config.json#/model_type`
/// value. New entries land alongside a new `ModelArch` variant + a
/// dispatch branch in `load_arch_dense` (plus, for TP, a parallel
/// pattern in `tp_qwen3.rs`).
const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["llama", "qwen3", "qwen3_5", "qwen3_moe"];
/// Pre-flight check the operator's `config.json` against the set of
/// architectures the dense path actually knows how to build. Surfaces
/// architecture mismatches as a single clean error before the serde
/// deserializer trips on missing fields — the latter happens because
/// every architecture has different hyperparameter names, so when the
/// JSON is e.g. Qwen3.6 wrapped under `text_config: {...}`, candle's
/// `qwen3::Config` finds none of its expected top-level fields and
/// fails with a cryptic `missing field 'vocab_size' at line N col 1`.
///
/// The result message names the model_type we saw, the supported set,
/// and points at the files an operator (or future contributor) needs
/// to touch to grow the supported set.
pub(crate) fn check_dense_config_supported(config_json: &str, model_id: &str) -> Result<()> {
let v: serde_json::Value = serde_json::from_str(config_json)
.with_context(|| format!("parse config.json for '{model_id}' as JSON"))?;
let model_type = v.get("model_type").and_then(|x| x.as_str()).unwrap_or("");
if model_type.is_empty() {
anyhow::bail!(
"config.json for '{model_id}' is missing `model_type`; the dense \
path needs it to gate architecture support (supported: {:?})",
DENSE_SUPPORTED_MODEL_TYPES
);
}
if DENSE_SUPPORTED_MODEL_TYPES.contains(&model_type) {
return Ok(());
}
// Bonus context: the model usually also lists architectures, which
// is what `transformers` keys on. Including it makes the error
// self-contained.
let architectures = v
.get("architectures")
.and_then(|x| x.as_array())
.map(|a| {
a.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect::<Vec<_>>()
})
.unwrap_or_default();
anyhow::bail!(
"unsupported model_type '{model_type}' for '{model_id}' \
(architectures={architectures:?}); the dense path supports {:?}. \
Add a `ModelArch` variant + load/forward branches in \
crates/neuron/src/harness/candle.rs (and the TP analogue in \
tp_qwen3.rs) to extend coverage.",
DENSE_SUPPORTED_MODEL_TYPES
);
}
/// Architectures the TP path can actually load and run. A subset of
/// `DENSE_SUPPORTED_MODEL_TYPES` — the single-GPU path supports more
/// families than the TP path because each TP-aware module is a real
/// chunk of work (`tp_qwen3.rs` is the only one shipped today).
#[cfg(feature = "cuda")]
const TP_SUPPORTED_MODEL_TYPES: &[&str] = &["qwen3", "qwen3_5"];
/// TP-side counterpart to `check_dense_config_supported`. Gates the
/// `load_tp` path on a narrower architecture set: even though the
/// single-GPU dense path knows how to build a Llama model, the worker
/// pool's `load_dense_shard` reconstructs the config as Qwen3 — there
/// is no `tp_llama.rs` yet. Surfacing this as a config-time error
/// (before we spawn workers and burn NCCL handshake cost) is much
/// kinder than the inevitable per-rank deserialise failure.
#[cfg(feature = "cuda")]
fn check_tp_arch_supported(config_json: &str, model_id: &str) -> Result<()> {
let v: serde_json::Value = serde_json::from_str(config_json)
.with_context(|| format!("parse config.json for '{model_id}' as JSON"))?;
let model_type = v.get("model_type").and_then(|x| x.as_str()).unwrap_or("");
if TP_SUPPORTED_MODEL_TYPES.contains(&model_type) {
return Ok(());
}
anyhow::bail!(
"tensor_parallel requested for '{model_id}' (model_type='{model_type}') but \
the TP path supports only {TP_SUPPORTED_MODEL_TYPES:?}. Adding a new \
TP-aware architecture needs a `harness/tp/tp_<family>.rs` module mirroring \
`tp_qwen3.rs` (sharded linears, AllReduce, per-rank head counts) and a \
dispatch in `WorkerPool::load_dense_shard`. For models that fit on one \
GPU, drop `tensor_parallel` to use the single-GPU dense path."
)
}
/// Resolve the effective HuggingFace cache directory for the candle
/// harness. Precedence (first hit wins):
///
/// 1. Explicit `hf_cache` from `[harness.candle]` in `neuron.toml`.
/// Operator's wishes always win.
/// 2. `HF_HUB_CACHE` env var. The Python `huggingface_hub` library
/// points at the cache root directly with this var; the Rust
/// `hf-hub` crate doesn't read it natively, so we bridge here.
/// Honouring it lets a neuron host share a cache directory with
/// Python tooling and other harnesses without per-tool config.
/// 3. `HF_HOME` env var. Canonical HuggingFace base directory; the
/// cache lives at `$HF_HOME/hub`. Hf-hub respects this on its own,
/// but we resolve it here too so the resulting path shows up in
/// logs alongside the explicit/HF_HUB_CACHE cases.
/// 4. `None`. Falls through to `hf-hub`'s default
/// (`~/.cache/huggingface/hub`).
fn resolve_hf_cache(explicit: Option<PathBuf>) -> Option<PathBuf> {
if let Some(p) = explicit {
return Some(p);
}
if let Ok(v) = std::env::var("HF_HUB_CACHE")
&& !v.is_empty()
{
return Some(PathBuf::from(v));
}
if let Ok(v) = std::env::var("HF_HOME")
&& !v.is_empty()
{
return Some(PathBuf::from(v).join("hub"));
}
None
}
/// Summary stats over a 1-D logits tensor, used for the failure log
/// when sampling rejects the distribution. Gathers nan/inf/negative
/// counts and finite min/max/mean — enough to distinguish a NaN
/// cascade (all-NaN, typical of softmax overflow propagating) from
/// an Inf at a single position (numerical edge case) from negative
/// weights (different bug entirely).
///
/// Computed only on the failure path, so the to_vec1 copy cost is
/// paid at most once per poisoned model.
#[derive(Debug)]
#[allow(dead_code)]
struct LogitsHealth {
len: usize,
nan: usize,
pos_inf: usize,
neg_inf: usize,
neg: usize,
finite_min: Option<f32>,
finite_max: Option<f32>,
finite_mean: Option<f32>,
}
#[allow(dead_code)]
fn logits_health(t: &Tensor) -> LogitsHealth {
let values: Vec<f32> = match t
.to_dtype(candle_core::DType::F32)
.and_then(|t| t.flatten_all())
.and_then(|t| t.to_vec1::<f32>())
{
Ok(v) => v,
Err(_) => {
return LogitsHealth {
len: 0,
nan: 0,
pos_inf: 0,
neg_inf: 0,
neg: 0,
finite_min: None,
finite_max: None,
finite_mean: None,
};
}
};
logits_health_slice(&values)
}
/// Same diagnostic as [`logits_health`] but operates directly on a
/// `[f32]` slice. Used by the worker-routed inference paths where the
/// device → host copy has already happened on the worker thread and
/// the async caller has the values in hand. Avoids the round-trip of
/// rebuilding a Tensor just to call to_vec1 again.
#[allow(dead_code)]
fn logits_health_slice(values: &[f32]) -> LogitsHealth {
let mut nan = 0usize;
let mut pos_inf = 0usize;
let mut neg_inf = 0usize;
let mut neg = 0usize;
let mut finite_min = f32::INFINITY;
let mut finite_max = f32::NEG_INFINITY;
let mut finite_sum = 0.0_f64;
let mut finite_count = 0usize;
for &v in values {
if v.is_nan() {
nan += 1;
} else if v == f32::INFINITY {
pos_inf += 1;
} else if v == f32::NEG_INFINITY {
neg_inf += 1;
} else {
if v < 0.0 {
neg += 1;
}
if v < finite_min {
finite_min = v;
}
if v > finite_max {
finite_max = v;
}
finite_sum += v as f64;
finite_count += 1;
}
}
let finite_mean = if finite_count > 0 {
Some((finite_sum / finite_count as f64) as f32)
} else {
None
};
LogitsHealth {
len: values.len(),
nan,
pos_inf,
neg_inf,
neg,
finite_min: (finite_count > 0).then_some(finite_min),
finite_max: (finite_count > 0).then_some(finite_max),
finite_mean,
}
}
/// Build the InferenceError reported to a client when their request
/// hits a model that's been marked poisoned by an earlier driver
/// failure. The message names the model and the recovery procedure so
/// the operator doesn't have to chase the original failure to know
/// what to do.
fn poisoned_error(model_id: &str) -> InferenceError {
InferenceError::Other(anyhow::anyhow!(
"model '{model_id}' is in a poisoned state \
(an earlier inference hit a CUDA driver error and the device \
context cannot be safely reused); unload and reload the model \
to recover"
))
}
/// Free/total VRAM on the candle `Device` in MiB. Returns `(0, 0)` if
/// the query fails or the device is the CPU fallback so logging never
/// crashes the request path. Mirrors the existing helper in
/// `tp_qwen3_5.rs`; kept separate to avoid coupling the inference path
/// to the TP-specific module.
#[cfg(feature = "cuda")]
fn device_vram_mb(device: &Device) -> (u64, u64) {
use candle_core::cuda::cudarc::driver::result;
use candle_core::cuda_backend::WrapErr;
let Device::Cuda(dev) = device else {
return (0, 0);
};
let Ok(()) = dev.cuda_stream().context().bind_to_thread().w() else {
return (0, 0);
};
match result::mem_get_info() {
Ok((free, total)) => (
(free / (1024 * 1024)) as u64,
(total / (1024 * 1024)) as u64,
),
Err(_) => (0, 0),
}
}
#[cfg(not(feature = "cuda"))]
#[allow(dead_code)]
fn device_vram_mb(_device: &Device) -> (u64, u64) {
(0, 0)
}
/// A short hex tag used to group every log line emitted on behalf of
/// one chat-completion request. Six hex digits is unique enough across
/// a 4-hour journal window (24 bits ≈ 16M values, while a busy neuron
/// sees ~10³ requests/hour) and fits cleanly inside `req_id=…` in the
/// fmt subscriber's span-prefix output.
fn new_req_id() -> String {
format!("{:06x}", unix_subsec_nanos() & 0xFFFFFF)
}
/// Read a positive `usize` from `name` in the process env, falling back
/// to `default` if unset or unparseable. Used for runtime tuning knobs
/// that we want operators to be able to adjust without a recompile.
fn env_usize(name: &str, default: usize) -> usize {
std::env::var(name)
.ok()
.and_then(|s| s.parse().ok())
.filter(|v: &usize| *v > 0)
.unwrap_or(default)
}
/// Same as [`env_usize`] but for `u64`.
fn env_u64(name: &str, default: u64) -> u64 {
std::env::var(name)
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
/// Prefill chunk size in tokens. The initial forward over a long prompt
/// is split into windows of this many tokens, each with a monotonically
/// growing offset, so activation memory is bounded by chunk × layers ×
/// hidden instead of prompt × layers × hidden. The default (512) keeps
/// activation peaks under ~1 GiB on a 27B Qwen-class model while
/// keeping the per-step overhead negligible vs. one big prefill.
fn prefill_chunk_tokens() -> usize {
env_usize("NEURON_PREFILL_CHUNK_TOKENS", 512)
}
/// Maximum allowed prompt length, in tokens. Requests above this are
/// rejected with [`InferenceError::PromptTooLong`] before any device
/// work — this is the explicit upper bound on context size, separate
/// from the model's `max_position_embeddings` (which can be much
/// larger than what fits in VRAM in practice).
fn max_prompt_tokens() -> usize {
env_usize("NEURON_MAX_PROMPT_TOKENS", 16384)
}
/// Minimum free VRAM (MiB) required to even attempt a prefill. Requests
/// below this are rejected with [`InferenceError::InsufficientVram`]
/// before any device work. Acts as a backstop when concurrent requests
/// have eaten the headroom; intentionally conservative — a request
/// that gets past this can still OOM, but the rejection is a clean 503
/// rather than a poisoned context.
fn min_free_vram_mb() -> u64 {
env_u64("NEURON_MIN_FREE_VRAM_MB", 1500)
}
/// Pre-flight check: reject the request if the prompt exceeds the
/// configured max, or if there isn't enough free VRAM to safely start a
/// prefill. Called from every chat_completion entry point right after
/// the VRAM query. A `prompt_len == 0` is accepted (some clients send
/// empty inputs to probe the endpoint); the prefill loop handles it.
fn validate_request(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> {
let max = max_prompt_tokens();
if prompt_len > max {
return Err(InferenceError::PromptTooLong { prompt_len, max });
}
// VRAM check is skipped on CPU loads (vram_free_mb == 0 sentinel)
// because the (0, 0) reply from `query_vram` is also what a missing
// worker returns. The CPU path has no per-GPU memory limit anyway —
// host RAM is bounded by the OOM killer, not this check.
let min = min_free_vram_mb();
if vram_free_mb != 0 && vram_free_mb < min {
return Err(InferenceError::InsufficientVram {
free_mb: vram_free_mb,
required_mb: min,
});
}
Ok(())
}
/// Threshold above which `pool.lock().await` blocking is interesting
/// enough to warn about. Healthy concurrent requests serialise behind
/// the pool in single-digit ms — anything past 2 seconds is either a
/// huge in-flight prompt or, more often, a stuck request holding the
/// lock against a poisoned CUDA context. See the 2026-05-26 4-hour
/// silence on beast where dozens of requests piled up invisibly here.
#[cfg(feature = "cuda")]
const POOL_LOCK_WARN_THRESHOLD: Duration = Duration::from_secs(2);
/// Acquire the TP pool lock, emitting a warn-level breadcrumb if the
/// wait exceeds [`POOL_LOCK_WARN_THRESHOLD`]. Wrapped in a helper so
/// the warn happens at the call site — the request whose lock-wait is
/// slow is the one that knows its prompt_len and other context.
#[cfg(feature = "cuda")]
async fn acquire_pool_lock<'a>(
pool: &'a tokio::sync::Mutex<super::tp::WorkerPool>,
model_id: &str,
) -> tokio::sync::MutexGuard<'a, super::tp::WorkerPool> {
let start = std::time::Instant::now();
// Tick once at the threshold so a stuck request shows up in
// journalctl even while it's still waiting. Without this the wait
// looks like silence in the log right up until the lock is freed.
tokio::pin! {
let lock = pool.lock();
}
loop {
tokio::select! {
guard = &mut lock => {
let elapsed = start.elapsed();
if elapsed >= POOL_LOCK_WARN_THRESHOLD {
tracing::warn!(
model = %model_id,
waited_ms = elapsed.as_millis(),
"TP chat_completion: pool lock acquired after long wait"
);
}
return guard;
}
_ = tokio::time::sleep(POOL_LOCK_WARN_THRESHOLD) => {
tracing::warn!(
model = %model_id,
waited_ms = start.elapsed().as_millis(),
"TP chat_completion: still waiting on pool lock"
);
}
}
}
}
/// Apply the repetition penalty (if any) to the prediction logits and
/// then sample. Centralises the prefill / generation-loop call sites
/// so they share identical sampling behaviour.
fn sample_with_penalty(
logits: &Tensor,
history: &[u32],
logits_processor: &mut LogitsProcessor,
) -> Result<u32> {
let penalised = if (REPEAT_PENALTY - 1.0).abs() < f32::EPSILON || history.is_empty() {
logits.clone()
} else {
let start = history.len().saturating_sub(REPEAT_LAST_N);
candle_transformers::utils::apply_repeat_penalty(logits, REPEAT_PENALTY, &history[start..])?
};
Ok(logits_processor.sample(&penalised)?)
}
/// Chunked prefill against an in-process [`ModelArch`]. Splits
/// `prompt_tokens` into [`prefill_chunk_tokens()`]-sized windows, runs
/// each through `arch.forward(chunk, offset)` with a monotonically
/// growing offset, and returns the last chunk's logits ready for
/// sampling. Bounds activation memory to O(chunk × layers × hidden)
/// instead of O(prompt × layers × hidden); the KV cache grows
/// monotonically so the model sees the full prompt at the final chunk.
fn chunked_prefill_local(
arch: &mut ModelArch,
device: &Device,
prompt_tokens: &[u32],
) -> Result<Tensor> {
let prompt_len = prompt_tokens.len();
if prompt_len == 0 {
anyhow::bail!("chunked_prefill_local: empty prompt");
}
let chunk_size = prefill_chunk_tokens();
let mut offset = 0;
let mut last_logits: Option<Tensor> = None;
while offset < prompt_len {
let end = (offset + chunk_size).min(prompt_len);
let chunk = &prompt_tokens[offset..end];
let input = Tensor::new(chunk, device)?.unsqueeze(0)?;
let logits = arch.forward(&input, offset)?;
if end == prompt_len {
last_logits = Some(logits);
}
offset = end;
}
last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_local: no chunks produced"))
}
/// Chunked prefill via the per-device worker. Same shape as
/// [`chunked_prefill_local`] but the forward runs on the worker thread
/// and replies with a CPU-side `Vec<f32>` of logits at the final
/// chunk's last position. Tensors never escape the worker.
#[cfg(feature = "cuda")]
async fn chunked_prefill_via_worker(
worker: &super::device_worker::DeviceWorkerHandle,
handle: super::device_worker::ArchHandle,
prompt_tokens: &[u32],
) -> Result<Vec<f32>> {
let prompt_len = prompt_tokens.len();
if prompt_len == 0 {
anyhow::bail!("chunked_prefill_via_worker: empty prompt");
}
let chunk_size = prefill_chunk_tokens();
let mut offset = 0;
let mut last_logits: Option<Vec<f32>> = None;
let total_chunks = prompt_len.div_ceil(chunk_size);
let mut chunk_idx = 0_usize;
while offset < prompt_len {
let end = (offset + chunk_size).min(prompt_len);
let chunk = prompt_tokens[offset..end].to_vec();
let chunk_len = chunk.len();
let step_start = std::time::Instant::now();
let logits = worker
.forward_logits(handle, chunk, offset)
.await
.map_err(|e| anyhow::anyhow!("prefill chunk {chunk_idx}/{total_chunks}: {e}"))?;
tracing::debug!(
chunk_idx,
total_chunks,
chunk_len,
offset,
elapsed_ms = step_start.elapsed().as_millis(),
"chunked prefill (worker): chunk done"
);
if end == prompt_len {
last_logits = Some(logits);
}
offset = end;
chunk_idx += 1;
}
last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_via_worker: no chunks produced"))
}
/// Chunked prefill via the TP `WorkerPool`. Same shape as
/// [`chunked_prefill_via_worker`] but the forward fans out to every
/// rank via `pool.generate_step`. Returns the leader's CPU-side
/// `Vec<f32>` of logits at the final chunk's last position.
#[cfg(feature = "cuda")]
async fn chunked_prefill_tp(
pool: &mut super::tp::WorkerPool,
model_id: &str,
leader_handle: super::device_worker::TpHandle,
prompt_tokens: &[u32],
) -> Result<Vec<f32>> {
let prompt_len = prompt_tokens.len();
if prompt_len == 0 {
anyhow::bail!("chunked_prefill_tp: empty prompt");
}
let chunk_size = prefill_chunk_tokens();
let mut offset = 0;
let mut last_logits: Option<Vec<f32>> = None;
let total_chunks = prompt_len.div_ceil(chunk_size);
let mut chunk_idx = 0_usize;
while offset < prompt_len {
let end = (offset + chunk_size).min(prompt_len);
let chunk = prompt_tokens[offset..end].to_vec();
let chunk_len = chunk.len();
let step_start = std::time::Instant::now();
let logits = pool
.generate_step(model_id, leader_handle, chunk, offset)
.await
.map_err(|e| anyhow::anyhow!("TP prefill chunk {chunk_idx}/{total_chunks}: {e}"))?;
tracing::debug!(
chunk_idx,
total_chunks,
chunk_len,
offset,
elapsed_ms = step_start.elapsed().as_millis(),
"chunked prefill (TP): chunk done"
);
if end == prompt_len {
last_logits = Some(logits);
}
offset = end;
chunk_idx += 1;
}
last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_tp: no chunks produced"))
}
impl CandleHarness {
pub fn new(bind_url: String, hf_cache: Option<PathBuf>) -> Self {
let hf_cache = resolve_hf_cache(hf_cache);
if let Some(p) = &hf_cache {
tracing::info!(path = %p.display(), "candle harness using HuggingFace cache");
}
Self {
models: Arc::new(RwLock::new(HashMap::new())),
hf_cache,
bind_url,
device_workers: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Pick a candle `Device` for the requested indices. Without the
/// `cuda` feature, or if CUDA initialisation fails, falls back to CPU.
fn pick_device(devices: &[u32]) -> Result<Device> {
let _idx = devices.first().copied().unwrap_or(0) as usize;
#[cfg(feature = "cuda")]
{
match Device::new_cuda(_idx) {
Ok(d) => return Ok(d),
Err(e) => tracing::warn!(
device = _idx,
error = %e,
"CUDA device unavailable, falling back to CPU"
),
}
}
Ok(Device::Cpu)
}
/// Return the worker handle for `device_index`, spawning it on
/// first request. The handle is cached on `self` so subsequent
/// loads against the same device share the same thread. Used to
/// populate `LoadedModel::worker` and `TpLoadedModel::worker` at
/// load time; in later refactor phases the worker also owns the
/// `ModelArch` and `TpLeaderModel` slabs.
#[allow(dead_code)]
async fn ensure_device_worker(
&self,
device_index: u32,
) -> Result<Arc<super::device_worker::DeviceWorkerHandle>> {
{
let workers = self.device_workers.read().await;
if let Some(w) = workers.get(&device_index) {
return Ok(Arc::clone(w));
}
}
// Write-lock acquired separately so the read path stays cheap.
// The `get` is repeated under the write lock to handle the
// race where two loads against a fresh device land here at
// once — the second caller sees the first's insertion and
// skips the second spawn.
let mut workers = self.device_workers.write().await;
if let Some(w) = workers.get(&device_index) {
return Ok(Arc::clone(w));
}
let handle = super::device_worker::DeviceWorkerHandle::spawn(device_index)
.with_context(|| format!("spawn device worker for cuda:{device_index}"))?;
workers.insert(device_index, Arc::clone(&handle));
tracing::info!(device_index, "spawned device worker");
Ok(handle)
}
/// Build an hf-hub API client pre-configured with the harness's
/// `hf_cache` (when one is set).
fn hf_api(&self) -> Result<hf_hub::api::tokio::Api> {
let mut builder = hf_hub::api::tokio::ApiBuilder::new();
if let Some(cache) = &self.hf_cache {
builder = builder.with_cache_dir(cache.clone());
}
builder.build().context("build hf-hub API")
}
/// Resolve a dense (bf16/fp16 safetensors) model to its local file
/// paths.
///
/// Handles both sharded repos (`model.safetensors.index.json` plus
/// several `model-*.safetensors`) and the single-file layout
/// (`model.safetensors`). Returns the safetensors paths in
/// arbitrary order — `VarBuilder` unifies them into one tensor view.
async fn resolve_dense_files(
&self,
spec: &ModelSpec,
) -> Result<(PathBuf, PathBuf, Vec<PathBuf>)> {
let api = self.hf_api()?;
let repo = api.model(spec.model_id.clone());
let config_path = repo
.get("config.json")
.await
.with_context(|| format!("fetch config.json from {}", spec.model_id))?;
let tokenizer_path = repo
.get("tokenizer.json")
.await
.with_context(|| format!("fetch tokenizer.json from {}", spec.model_id))?;
// Prefer the sharded layout (most HF dense models > 5B ship it).
let safetensors_paths = match repo.get("model.safetensors.index.json").await {
Ok(index_path) => {
let index_text = std::fs::read_to_string(&index_path)
.context("read model.safetensors.index.json")?;
let index: serde_json::Value = serde_json::from_str(&index_text)
.context("parse model.safetensors.index.json")?;
let weight_map = index
.get("weight_map")
.and_then(|v| v.as_object())
.ok_or_else(|| {
anyhow::anyhow!("safetensors index missing weight_map object")
})?;
let unique: std::collections::BTreeSet<String> = weight_map
.values()
.filter_map(|v| v.as_str().map(String::from))
.collect();
let mut paths = Vec::with_capacity(unique.len());
for fname in unique {
let p = repo
.get(&fname)
.await
.with_context(|| format!("fetch sharded safetensors {fname}"))?;
paths.push(p);
}
paths
}
Err(_) => {
// Single-file fallback.
let p = repo
.get("model.safetensors")
.await
.context("fetch model.safetensors (single-file layout)")?;
vec![p]
}
};
Ok((config_path, tokenizer_path, safetensors_paths))
}
/// Resolve + load a GGUF (pre-quantized) Qwen3. Returns the
/// tokenizer.json path so the caller can construct the Tokenizer
/// uniformly across source formats.
async fn load_arch_gguf(
&self,
spec: &ModelSpec,
device: &Device,
) -> Result<(PathBuf, ModelArch)> {
let (gguf_path, tokenizer_path) = self.resolve_files(spec).await?;
let device_for_load = device.clone();
let gguf_path_for_load = gguf_path.clone();
let model_id_for_log = spec.model_id.clone();
let arch = tokio::task::spawn_blocking(move || -> Result<ModelArch> {
tracing::info!(model = %model_id_for_log, path = ?gguf_path_for_load, "loading GGUF");
let mut file = std::fs::File::open(&gguf_path_for_load).context("open GGUF file")?;
let content = gguf_file::Content::read(&mut file)
.map_err(|e| anyhow::anyhow!("parse GGUF: {e}"))?;
let architecture = content
.metadata
.get("general.architecture")
.and_then(|v| v.to_string().ok().cloned())
.unwrap_or_default();
tracing::info!(architecture = %architecture, "GGUF architecture");
// The `general.architecture` GGUF metadata key follows
// llama.cpp conventions (lowercase, no underscores in some
// cases) — `qwen3moe`, not `qwen3_moe`.
match architecture.as_str() {
"qwen3" => {
let weights =
QuantizedQwen3Weights::from_gguf(content, &mut file, &device_for_load)
.map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?;
Ok(ModelArch::Qwen3Quantized(weights))
}
"qwen3moe" => {
// GGUFQWenMoE takes an explicit compute dtype
// alongside the device — F16 matches the GGUF
// weights' typical accumulation precision and
// gives the best tokens/sec on consumer cards.
let weights =
GGUFQWenMoE::from_gguf(content, &mut file, &device_for_load, DType::F16)
.map_err(|e| anyhow::anyhow!("from_gguf qwen3_moe: {e}"))?;
Ok(ModelArch::Qwen3MoeQuantized(weights))
}
"llama" => {
let weights =
QuantizedLlamaWeights::from_gguf(content, &mut file, &device_for_load)
.map_err(|e| anyhow::anyhow!("from_gguf llama: {e}"))?;
Ok(ModelArch::LlamaQuantized(weights))
}
other => anyhow::bail!(
"unsupported GGUF architecture '{other}'; quantized path supports \
qwen3, qwen3moe, llama"
),
}
})
.await
.context("blocking GGUF load task panicked")??;
Ok((tokenizer_path, arch))
}
/// Resolve + load a dense Qwen3 from safetensors. Uses
/// `candle-transformers::models::qwen3::ModelForCausalLM` and
/// builds a VarBuilder over the mmap'd safetensors files. dtype
/// is bf16 by default to match the HF distribution dtype for
/// recent Qwen3 family models; fall back to f16 if the device
/// doesn't support bf16.
async fn load_arch_dense(
&self,
spec: &ModelSpec,
device: &Device,
) -> Result<(PathBuf, ModelArch)> {
let (config_path, tokenizer_path, safetensors_paths) =
self.resolve_dense_files(spec).await?;
let device_for_load = device.clone();
let model_id_for_log = spec.model_id.clone();
let arch = tokio::task::spawn_blocking(move || -> Result<ModelArch> {
let cfg_text = std::fs::read_to_string(&config_path).context("read config.json")?;
check_dense_config_supported(&cfg_text, &model_id_for_log)?;
// Peek at model_type to choose the family before the
// typed deserialize — each family has its own Config.
let model_type = serde_json::from_str::<serde_json::Value>(&cfg_text)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
tracing::info!(
model = %model_id_for_log,
model_type = %model_type,
shards = safetensors_paths.len(),
"loading dense model from safetensors"
);
// bf16 is the canonical distribution dtype for Qwen3 /
// Llama 3 / Qwen3 MoE. CUDA on Ada+ has hardware bf16;
// Ampere has it too. CPU emulates.
let dtype = DType::BF16;
// SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files;
// mutation by another process while we hold the mapping is
// UB. We trust the HF cache is immutable-by-design.
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&safetensors_paths, dtype, &device_for_load)
.context("build VarBuilder over safetensors")?
};
match model_type.as_str() {
"qwen3" => {
let cfg: qwen3_dense::Config =
serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
.map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
Ok(ModelArch::Qwen3Dense(model))
}
"qwen3_moe" => {
let cfg: qwen3_moe_dense::Config =
serde_json::from_str(&cfg_text).context("parse Qwen3 MoE config.json")?;
let model = qwen3_moe_dense::ModelForCausalLM::new(&cfg, vb)
.map_err(|e| anyhow::anyhow!("build Qwen3 MoE dense model: {e}"))?;
Ok(ModelArch::Qwen3MoeDense(model))
}
"llama" => {
let cfg: llama_dense::LlamaConfig =
serde_json::from_str(&cfg_text).context("parse Llama config.json")?;
// Llama has multiple sub-variants (Llama 1 has no
// GQA; Llama 3 does). `LlamaConfig::into_config`
// resolves the right shape; the `use_flash_attn`
// arg defaults to false — the flash kernel is a
// separate feature flag and uses extra VRAM.
let config = cfg.into_config(false);
let cache = llama_dense::Cache::new(true, dtype, &config, &device_for_load)
.context("build Llama Cache")?;
let model = llama_dense::Llama::load(vb, &config)
.map_err(|e| anyhow::anyhow!("build Llama dense model: {e}"))?;
Ok(ModelArch::LlamaDense(Box::new(LlamaDense {
model,
cache,
config,
dtype,
device: device_for_load,
})))
}
"qwen3_5" => {
// Qwen3-Next needs a ShardedVarBuilder because its
// load functions use the sharded backend (so they
// can be reused unchanged by the future TP variant).
// With world_size=1 the backend falls through to
// the unsharded path, so there is no per-load cost.
let cfg: super::arch::qwen3_5::Config = serde_json::from_str(&cfg_text)
.context("parse Qwen3-Next (qwen3_5) config.json")?;
let sharded_vb = unsafe {
candle_nn::var_builder::ShardedSafeTensors::var_builder(
&safetensors_paths,
dtype,
&device_for_load,
)
.context("build ShardedVarBuilder for Qwen3-Next")?
};
let model = super::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, sharded_vb)
.context("build Qwen3-Next dense model")?;
Ok(ModelArch::Qwen3_5Dense(model))
}
other => {
// Defensive: `check_dense_config_supported` already
// gated on the supported set, so this branch is
// unreachable unless that list and the match here
// drift apart.
anyhow::bail!(
"unrouted supported model_type '{other}' — \
DENSE_SUPPORTED_MODEL_TYPES and load_arch_dense \
must stay in sync"
)
}
}
})
.await
.context("blocking dense load task panicked")??;
Ok((tokenizer_path, arch))
}
/// Resolve a model spec to local GGUF and tokenizer file paths via
/// hf-hub. Downloads on first use; subsequent calls are cached.
async fn resolve_files(&self, spec: &ModelSpec) -> Result<(PathBuf, PathBuf)> {
let api = self.hf_api()?;
let repo = api.model(spec.model_id.clone());
let info = repo
.info()
.await
.with_context(|| format!("fetch HF repo info for {}", spec.model_id))?;
let quant = spec.quant.as_deref().unwrap_or("");
let quant_lc = quant.to_lowercase();
let gguf_filename = info
.siblings
.iter()
.map(|s| s.rfilename.as_str())
.filter(|name| name.to_lowercase().ends_with(".gguf"))
.find(|name| quant_lc.is_empty() || name.to_lowercase().contains(&quant_lc))
.ok_or_else(|| {
anyhow::anyhow!(
"no GGUF file matching quant {:?} in repo {}",
spec.quant,
spec.model_id
)
})?
.to_string();
tracing::info!(
model = %spec.model_id,
file = %gguf_filename,
"resolving GGUF (may be cached)"
);
let gguf_path = repo
.get(&gguf_filename)
.await
.with_context(|| format!("fetch GGUF {gguf_filename}"))?;
// GGUF-only HF repos (unsloth/Qwen3-*-GGUF, Qwen/Qwen3-*-GGUF,
// etc.) ship the .gguf file but not tokenizer.json — the
// tokenizer.json lives in the base non-GGUF repo. Derive the
// base repo id by stripping a `-GGUF` / `-gguf` suffix; if
// there's no such suffix the same repo is used (works for
// non-GGUF model_ids).
let tokenizer_repo_id = spec
.model_id
.strip_suffix("-GGUF")
.or_else(|| spec.model_id.strip_suffix("-gguf"))
.unwrap_or(spec.model_id.as_str())
.to_string();
let tokenizer_repo = if tokenizer_repo_id == spec.model_id {
repo
} else {
tracing::debug!(
from = %spec.model_id,
to = %tokenizer_repo_id,
"tokenizer.json sourced from base repo (GGUF suffix stripped)"
);
api.model(tokenizer_repo_id.clone())
};
let tokenizer_path = tokenizer_repo
.get("tokenizer.json")
.await
.with_context(|| format!("fetch tokenizer.json from {tokenizer_repo_id}"))?;
Ok((gguf_path, tokenizer_path))
}
/// Run a non-streaming chat completion against a loaded model.
///
/// Returns a typed `InferenceError` when the model isn't loaded so the
/// handler can map to an appropriate HTTP status without string-matching.
pub async fn chat_completion(
&self,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, InferenceError> {
let handle = {
let models = self.models.read().await;
models.get(&request.model).cloned()
};
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
// The match is technically infallible without `cuda` (only Single
// exists), but the cfg-gated Tp arm makes this the right shape
// under both feature flags.
#[allow(clippy::infallible_destructuring_match)]
let loaded = match handle {
LoadedHandle::Single(m) => m,
#[cfg(feature = "cuda")]
LoadedHandle::Tp(m) => {
return self.chat_completion_tp(m, request).await;
}
};
// Span every line of this request with a short req_id +
// model so `grep req_id=…` over the journal can reconstruct
// one request even when dozens overlap. Add a terminal log
// line on both success and failure — the single-GPU path
// used to log nothing on either side, so a failing request
// looked exactly like an idle neuron.
let req_id = new_req_id();
let model_id = request.model.clone();
let span = tracing::info_span!("chat", req_id = %req_id, model = %model_id);
let req_start = std::time::Instant::now();
// Refuse the request up front if a prior inference poisoned
// the device context — otherwise we hand the doomed forward
// off to spawn_blocking and stall waiting for CUDA to fail.
if loaded.poisoned.load(Ordering::Acquire) {
let _g = span.enter();
tracing::warn!("chat_completion: refusing request, model poisoned");
return Err(poisoned_error(&model_id));
}
let result = async {
let prompt = format_qwen3_prompt(&request.messages);
let encoding = loaded
.tokenizer
.encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
let prompt_len = prompt_tokens.len();
let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(512) as usize;
let seed = unix_subsec_nanos();
let eos_id = loaded
.tokenizer
.token_to_id("<|im_end|>")
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
let (vram_free_mb, vram_total_mb) = loaded.query_vram().await;
tracing::info!(
prompt_len,
max_new,
temperature,
?top_p,
?eos_id,
vram_free_mb,
vram_total_mb,
"chat_completion: starting"
);
validate_request(prompt_len, vram_free_mb)?;
// Routing: CUDA loads go through the per-device worker
// thread (introduced in Phase 1; forward/clear added in
// Phase 2). CPU loads keep the existing spawn_blocking
// path because there's no context to own and the channel
// round-trip would only add latency. The two arms produce
// the same `(Vec<u32>, String)` shape so the rest of the
// path is shared.
let (generated_ids, finish_reason) = if let (Some(worker), Some(handle)) =
(loaded.worker.as_ref(), loaded.arch_handle)
{
// Worker path (CUDA).
#[cfg(feature = "cuda")]
{
match run_inference_via_worker(
worker,
handle,
&prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
)
.await
{
Ok(v) => v,
Err(e) => {
loaded.poisoned.store(true, Ordering::Release);
return Err(InferenceError::Other(e));
}
}
}
#[cfg(not(feature = "cuda"))]
{
// Can't happen: `loaded.worker` is only Some on
// CUDA builds. The dead branch keeps the no-cuda
// build well-typed.
let _ = (worker, handle);
unreachable!("worker handle present without cuda feature");
}
} else if let Some(arch_arc) = loaded.arch.clone() {
// CPU path: existing spawn_blocking on the local
// Arc<Mutex<ModelArch>>.
let device = loaded.device.clone();
let inference_result =
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
let mut guard = arch_arc.blocking_lock();
run_inference(
&mut guard,
&device,
&prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
)
})
.await;
// Any failure inside the spawn_blocking touched CUDA via
// candle's forward / cache code, so we treat it as a
// device-poisoning event. The terminal log at the bottom
// of the wrapper reports the error; this flag stops the
// NEXT request from going down the same path.
match inference_result {
Ok(Ok(v)) => v,
Ok(Err(e)) => {
loaded.poisoned.store(true, Ordering::Release);
return Err(InferenceError::Other(e));
}
Err(e) => {
loaded.poisoned.store(true, Ordering::Release);
return Err(InferenceError::Other(anyhow::anyhow!(
"inference task panicked: {e}"
)));
}
}
} else {
// LoadedModel invariant: exactly one of `worker` /
// `arch` is Some. Reaching here is a construction bug.
return Err(InferenceError::Other(anyhow::anyhow!(
"LoadedModel has neither worker handle nor local arch — load-path bug"
)));
};
let completion_text = loaded
.tokenizer
.decode(&generated_ids, true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
let usage = Usage {
prompt_tokens: prompt_len as u64,
completion_tokens: generated_ids.len() as u64,
total_tokens: (prompt_len + generated_ids.len()) as u64,
};
tracing::info!(
prompt_tokens = prompt_len,
completion_tokens = generated_ids.len(),
finish_reason = %finish_reason,
total_ms = req_start.elapsed().as_millis(),
"chat_completion: done"
);
Ok::<_, InferenceError>(ChatCompletionResponse {
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
object: "chat.completion".into(),
created: unix_now_secs(),
model: request.model.clone(),
choices: vec![ChatCompletionChoice {
index: 0,
message: ChatMessage {
role: "assistant".into(),
content: MessageContent::Text(completion_text),
extra: serde_json::Value::Object(Default::default()),
},
finish_reason: Some(finish_reason),
extra: serde_json::Value::Object(Default::default()),
}],
usage: Some(usage),
extra: serde_json::Value::Object(Default::default()),
})
}
.instrument(span.clone())
.await;
if let Err(ref e) = result {
let _g = span.enter();
tracing::error!(
error = %format!("{e:#}"),
total_ms = req_start.elapsed().as_millis(),
"chat_completion: failed"
);
}
result
}
/// Run a streaming chat completion against a loaded model.
///
/// Returns an `mpsc::Receiver` that yields `ChatCompletionChunk`s in
/// OpenAI SSE format. The first chunk carries the assistant role;
/// subsequent chunks carry incremental `content` deltas; the final
/// chunk carries `finish_reason`. The handler is responsible for
/// wrapping these into an SSE response and appending the `[DONE]`
/// terminator.
///
/// Token-by-token decoding tracks the cumulative decoded prefix so
/// BPE byte-fallback boundaries don't split a UTF-8 char across
/// chunks.
pub async fn chat_completion_stream(
&self,
request: ChatCompletionRequest,
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
let handle = {
let models = self.models.read().await;
models.get(&request.model).cloned()
};
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
// The match is technically infallible without `cuda` (only Single
// exists), but the cfg-gated Tp arm makes this the right shape
// under both feature flags.
#[allow(clippy::infallible_destructuring_match)]
let loaded = match handle {
LoadedHandle::Single(m) => m,
#[cfg(feature = "cuda")]
LoadedHandle::Tp(m) => {
return self.chat_completion_tp_stream(m, request).await;
}
};
let prompt = format_qwen3_prompt(&request.messages);
let encoding = loaded
.tokenizer
.encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(512) as usize;
let seed = unix_subsec_nanos();
let eos_id = loaded
.tokenizer
.token_to_id("<|im_end|>")
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
let device = loaded.device.clone();
let tokenizer = loaded.tokenizer.clone();
let model_id = request.model.clone();
let id = format!("chatcmpl-{:x}", unix_subsec_nanos());
let created = unix_now_secs();
// Bounded channel so the producer (blocking inference) is back-
// pressured by the consumer (SSE writer). 32 is generous —
// tokens arrive one at a time and the SSE writer is async.
let (tx, rx) = mpsc::channel::<ChatCompletionChunk>(32);
// Lead chunk: announce the assistant role per OpenAI streaming
// conventions. Tools that auto-detect a streaming reply expect
// this before any content delta.
let role_chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".into(),
created,
model: model_id.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: json!({"role": "assistant"}),
finish_reason: None,
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
// Refuse if the model is already poisoned. No point opening
// an SSE stream just to send the role chunk and then bail.
if loaded.poisoned.load(Ordering::Acquire) {
return Err(poisoned_error(&model_id));
}
// If sending the role chunk fails the receiver is already gone;
// bail before kicking off the heavy blocking work.
tx.send(role_chunk)
.await
.map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?;
// Span context — spawn_blocking detaches from the async
// executor so we capture the span explicitly and re-enter it
// inside the closure to keep the req_id on every emitted line.
let req_id = new_req_id();
let span = tracing::info_span!("chat_stream", req_id = %req_id, model = %model_id);
let prompt_len = prompt_tokens.len();
let req_start = std::time::Instant::now();
// Cloned `Arc<LoadedModel>` so the spawned task can mark the
// model poisoned if its forward fails.
let loaded_for_task = Arc::clone(&loaded);
let span_for_starting = span.clone();
let span_for_task = span.clone();
// Query VRAM before entering the span so we don't await inside
// an entered guard (Span::enter creates a synchronous guard
// that can't span await points). The span gets entered in a
// separate scope below purely for the log emission.
let (vram_free_mb, vram_total_mb) = loaded.query_vram().await;
{
let _g = span_for_starting.enter();
tracing::info!(
prompt_len,
max_new,
temperature,
?top_p,
?eos_id,
vram_free_mb,
vram_total_mb,
"chat_completion (stream): starting"
);
}
validate_request(prompt_len, vram_free_mb)?;
// Routing parallel to the non-streaming chat_completion: CUDA
// goes through the worker (async task), CPU keeps the
// spawn_blocking + Arc<Mutex<ModelArch>> path.
if let (Some(worker), Some(handle)) = (loaded.worker.clone(), loaded.arch_handle) {
#[cfg(feature = "cuda")]
{
let prompt_tokens = prompt_tokens.clone();
tokio::spawn(
async move {
match stream_inference_via_worker(
worker,
handle,
tokenizer,
prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
id,
created,
model_id,
tx,
)
.await
{
Ok(_finish_reason) => tracing::info!(
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): done"
),
Err(e) => {
loaded_for_task.poisoned.store(true, Ordering::Release);
tracing::error!(
error = %format!("{e:#}"),
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): failed, model marked poisoned"
);
}
}
}
.instrument(span_for_task),
);
}
#[cfg(not(feature = "cuda"))]
{
let _ = (worker, handle, span_for_task);
unreachable!("worker handle present without cuda feature");
}
} else if let Some(arch_arc) = loaded.arch.clone() {
tokio::task::spawn_blocking(move || {
let _g = span_for_task.enter();
let mut guard = arch_arc.blocking_lock();
match run_inference_streaming(
&mut guard,
&device,
&tokenizer,
&prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
&id,
created,
&model_id,
&tx,
) {
Ok(()) => tracing::info!(
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): done"
),
Err(e) => {
loaded_for_task.poisoned.store(true, Ordering::Release);
tracing::error!(
error = %format!("{e:#}"),
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): failed, model marked poisoned"
);
}
}
});
} else {
return Err(InferenceError::Other(anyhow::anyhow!(
"LoadedModel has neither worker handle nor local arch — load-path bug"
)));
}
Ok(rx)
}
}
#[async_trait]
impl Harness for CandleHarness {
fn name(&self) -> &str {
"candle"
}
async fn health(&self) -> HarnessHealth {
HarnessHealth {
name: "candle".into(),
running: true,
uptime_secs: None,
}
}
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
let models = self.models.read().await;
Ok(models
.values()
.map(|h| ModelInfo {
id: h.model_id().into(),
harness: "candle".into(),
status: if h.is_poisoned() {
"poisoned".into()
} else {
"loaded".into()
},
devices: h.devices(),
vram_used_mb: None,
})
.collect())
}
async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
if spec.harness != "candle" {
anyhow::bail!("expected harness=candle, got harness={}", spec.harness);
}
{
let models = self.models.read().await;
if models.contains_key(&spec.model_id) {
anyhow::bail!("model '{}' already loaded", spec.model_id);
}
}
let tp_size = spec.tensor_parallel.unwrap_or(1);
if tp_size > 1 {
#[cfg(feature = "cuda")]
{
return self.load_tp(spec, tp_size).await;
}
#[cfg(not(feature = "cuda"))]
{
anyhow::bail!(
"tensor_parallel={tp_size} requested for '{}': this neuron \
binary was built without --features cuda; TP requires CUDA + NCCL",
spec.model_id
);
}
}
let devices = spec.devices.clone().unwrap_or_else(|| vec![0]);
let device = Self::pick_device(&devices)?;
// Phase 4: load directly on the worker thread for CUDA;
// legacy spawn_blocking + Arc<Mutex<>> only for CPU. Resolve
// hf-hub paths up front (always async), then either dispatch
// a load Job (CUDA) or call the legacy local loader (CPU).
let worker: Option<Arc<super::device_worker::DeviceWorkerHandle>> = match &device {
#[cfg(feature = "cuda")]
Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?),
_ => None,
};
let (tokenizer_path, arch_local, arch_handle) = if let Some(w) = &worker {
// CUDA path: resolve, then load in the worker.
if spec.quant.is_some() {
let (gguf_path, tokenizer_path) = self.resolve_files(spec).await?;
let handle = w
.load_gguf(gguf_path, spec.model_id.clone())
.await
.map_err(|e| anyhow::anyhow!("worker load_gguf: {e}"))?;
(tokenizer_path, None, Some(handle))
} else {
let (config_path, tokenizer_path, safetensors_paths) =
self.resolve_dense_files(spec).await?;
let handle = w
.load_dense(config_path, safetensors_paths, spec.model_id.clone())
.await
.map_err(|e| anyhow::anyhow!("worker load_dense: {e}"))?;
(tokenizer_path, None, Some(handle))
}
} else {
// CPU path: legacy spawn_blocking + Arc<Mutex<ModelArch>>.
let (tokenizer_path, arch) = if spec.quant.is_some() {
self.load_arch_gguf(spec, &device).await?
} else {
self.load_arch_dense(spec, &device).await?
};
(tokenizer_path, Some(Arc::new(Mutex::new(arch))), None)
};
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
let loaded = Arc::new(LoadedModel {
model_id: spec.model_id.clone(),
arch: arch_local,
tokenizer,
device,
quant: spec.quant.clone(),
devices,
poisoned: AtomicBool::new(false),
worker,
arch_handle,
});
let mut models = self.models.write().await;
models.insert(spec.model_id.clone(), LoadedHandle::Single(loaded));
tracing::info!(model = %spec.model_id, "model loaded");
Ok(())
}
async fn unload_model(&self, model_id: &str) -> Result<()> {
let removed = {
let mut models = self.models.write().await;
models.remove(model_id)
};
let Some(handle) = removed else {
anyhow::bail!("model '{model_id}' not loaded");
};
// Single-GPU drops are immediate — the LoadedModel goes out of
// scope with the Arc and candle frees VRAM. CUDA loads also
// ship a `Job::DropArch` to the device worker so the boxed
// `ModelArch` releases its CUDA allocations on the right
// thread (with the bound context); without that, the Drop
// would run on whatever tokio thread happens to be holding
// the last `Arc<LoadedModel>` clone when this fn returns.
// TP unloads further coordinate the subprocess pool below.
match handle {
LoadedHandle::Single(single) => {
if let (Some(worker), Some(arch_handle)) =
(single.worker.as_ref(), single.arch_handle)
&& let Err(e) = worker.drop_arch(arch_handle).await
{
tracing::warn!(
model = %model_id,
error = %e,
"single-GPU unload: DropArch RPC failed (model state may leak in worker slab)"
);
}
}
#[cfg(feature = "cuda")]
LoadedHandle::Tp(tp) => {
// Try to recover the inner TpLoadedModel so we can move
// the pool and shut it down. If anyone else still holds
// a clone of the Arc (shouldn't happen — the only owners
// are the registry and any in-flight chat_completion),
// bail with a clear marker rather than silently leaking.
let tp = match Arc::try_unwrap(tp) {
Ok(t) => t,
Err(arc) => {
// Reinsert so we don't leave the registry in an
// inconsistent state.
let mut models = self.models.write().await;
models.insert(model_id.into(), LoadedHandle::Tp(arc));
anyhow::bail!("cannot unload '{model_id}': inference still in flight");
}
};
// Drop the leader's TpLeaderModel on the device worker
// thread (CUDA tensors and Arc<Comm> clones release on
// the same OS thread that allocated them).
if let Err(e) = tp.worker.drop_tp(tp.leader_handle).await {
tracing::warn!(
model = %model_id,
error = %e,
"TP unload: DropTp RPC failed (leader model may leak in worker slab)"
);
}
let mut pool = tp.pool.into_inner();
if let Err(e) = pool.unload_model(model_id).await {
tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed");
}
if let Err(e) = pool.shutdown().await {
tracing::warn!(model = %model_id, error = %e, "TP pool shutdown failed");
}
}
}
tracing::info!(model = %model_id, "model unloaded");
Ok(())
}
async fn inference_endpoint(&self, model_id: &str) -> Option<String> {
let models = self.models.read().await;
models.contains_key(model_id).then(|| self.bind_url.clone())
}
}
impl CandleHarness {
/// Tensor-parallel load. Resolves dense safetensors via hf-hub the
/// same way the single-GPU dense path does, spins up a TP worker
/// pool sized to `tp_size`, runs the NCCL handshake, then has
/// every rank load its shard of the model.
///
/// `spec.devices` carries the per-rank CUDA device indices (one
/// entry per rank, in rank order); defaults to `0..tp_size`.
#[cfg(feature = "cuda")]
async fn load_tp(&self, spec: &ModelSpec, tp_size: u32) -> Result<()> {
use std::sync::Arc as StdArc;
use tokio::sync::Mutex as TMutex;
// Default per-rank device assignment: 0, 1, ..., tp_size - 1.
let devices = spec
.devices
.clone()
.unwrap_or_else(|| (0..tp_size).collect());
if devices.len() as u32 != tp_size {
anyhow::bail!(
"tensor_parallel={tp_size} requires {tp_size} entries in devices, got {}",
devices.len()
);
}
// `quant` on the TP path now means in-situ quantization (ISQ):
// load safetensors, quantize the per-rank shard to the named
// GgmlDType at load time. The worker's parse_quant_string
// accepts the same names (q5k, q8_0, etc.) as the single-GPU
// path. GGUF-source-file models still aren't TP-loadable, but
// resolve_dense_files only looks for safetensors so that path
// errors out cleanly later if no safetensors are present.
// 1. Resolve config + tokenizer + safetensors via hf-hub.
let (config_path, tokenizer_path, safetensors_paths) =
self.resolve_dense_files(spec).await?;
let config_json = std::fs::read_to_string(&config_path).context("read config.json")?;
// Reject unsupported architectures *before* spawning the worker
// pool and fanning out NCCL — otherwise we'd burn the pool
// lifecycle on a load that's guaranteed to fail at deserialise
// time inside every rank.
check_dense_config_supported(&config_json, &spec.model_id)?;
// The TP path knows how to ship and reconstruct a Qwen3 dense
// shard (`tp_qwen3.rs`). Other architectures may pass the
// single-GPU `check_dense_config_supported` check above but
// have no TP-aware module — bail with a clear marker pointing
// at the file the implementer needs to add. This keeps an
// operator who sets `tensor_parallel=2` on a Llama model from
// silently routing through `pool.load_dense_shard` (which
// assumes Qwen3 config shape on the worker side) and producing
// a confusing config-parse failure inside every rank.
check_tp_arch_supported(&config_json, &spec.model_id)?;
// 2. Spawn the worker pool. Rank 0 stays in-process; ranks
// 1..tp_size are subprocesses, one per device after the
// leader's own. The leader's device worker thread is
// spawned (or reused) here and passed into the pool so
// `init_nccl`, the load, every TP forward, and KV-cache
// clears all dispatch from the same OS thread.
let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
let leader_worker = self.ensure_device_worker(devices[0]).await?;
let mut pool =
super::tp::WorkerPool::spawn(&exe, tp_size, &devices, leader_worker.clone()).await?;
// 3. NCCL handshake across all ranks.
let leader_device_idx = devices[0];
pool.init_nccl(leader_device_idx).await?;
// 4. Pick the leader's candle Device (same index as init_nccl).
let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize)
.context("Device::new_cuda for TP leader")?;
// 5. Load this rank's shard on every rank. After Phase 3
// `load_dense_shard` transfers the freshly-built
// `TpLeaderModel` into the device worker's TP slab and
// returns the resulting handle.
let leader_handle = pool
.load_dense_shard(
&spec.model_id,
&config_json,
&safetensors_paths,
&leader_device,
candle_core::DType::BF16,
spec.quant.clone(),
)
.await?;
// 6. Tokenizer (same as single-GPU path).
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
let tp_loaded = StdArc::new(TpLoadedModel {
model_id: spec.model_id.clone(),
tokenizer,
devices: devices.clone(),
pool: TMutex::new(pool),
leader_handle,
leader_device: leader_device.clone(),
poisoned: AtomicBool::new(false),
// Same `leader_worker` we passed into the pool above —
// single `Arc` shared between WorkerPool and
// TpLoadedModel so they reference the same thread.
worker: leader_worker,
});
let mut models = self.models.write().await;
models.insert(spec.model_id.clone(), LoadedHandle::Tp(tp_loaded));
tracing::info!(
model = %spec.model_id,
tp_size,
?devices,
"TP model loaded"
);
Ok(())
}
/// Non-streaming chat completion against a TP model.
///
/// The actual work runs inside a `tokio::spawn`'d task so the HTTP
/// client disconnecting (curl timeout, browser nav-away, etc.)
/// can't cancel the future mid-`pool.generate_step` and leave the
/// worker subprocesses mid-RPC. If the spawned task is dropped,
/// it still runs to completion and finishes draining the pool —
/// the next inference request finds a clean pool. The HTTP layer
/// just gives up on the response.
///
/// Every step also emits `info`/`debug` tracing so journalctl
/// shows where time went without needing to surface internals in
/// the HTTP error response.
#[cfg(feature = "cuda")]
async fn chat_completion_tp(
&self,
tp: Arc<TpLoadedModel>,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, InferenceError> {
// Tag every line of this request with a short req_id so a
// grep over journalctl reconstructs one request even when
// dozens are queued and interleaved. The span prefix is added
// by the fmt subscriber to every event emitted within the
// instrumented future, including events from `WorkerPool::*`
// since those run on the leader's task.
let req_id = new_req_id();
let model_id = request.model.clone();
let span = tracing::info_span!("tp_chat", req_id = %req_id, model = %model_id);
let req_start = std::time::Instant::now();
if tp.poisoned.load(Ordering::Acquire) {
let _g = span.enter();
tracing::warn!("TP chat_completion: refusing request, model poisoned");
return Err(poisoned_error(&model_id));
}
let tp_for_marker = Arc::clone(&tp);
let handle = tokio::spawn(chat_completion_tp_inner(tp, request).instrument(span.clone()));
let result = match handle.await {
Ok(r) => r,
Err(join_err) => Err(InferenceError::Other(anyhow::anyhow!(
"TP inference task panicked or was cancelled: {join_err}"
))),
};
if let Err(ref e) = result {
// Mark poisoned: a failure inside the spawned task either
// hit a CUDA/NCCL driver error directly or surfaced as a
// task panic. Both cases leave the worker subprocesses in
// an unknown state — refuse subsequent requests until an
// operator unload+reloads. This is the gate that turned
// the 2026-05-26 silent-hang into a clean 5xx.
tp_for_marker.poisoned.store(true, Ordering::Release);
let _g = span.enter();
tracing::error!(
error = %format!("{e:#}"),
total_ms = req_start.elapsed().as_millis(),
"TP chat_completion: failed, model marked poisoned"
);
}
result
}
/// Streaming counterpart to `chat_completion_tp`. Same per-step
/// orchestration (clear cache, prefill, sample, decode loop) but
/// emits one `ChatCompletionChunk` per token over an mpsc channel
/// so the handler can write an SSE stream.
///
/// Unlike the single-GPU streaming path (which runs the candle
/// forward inside `spawn_blocking` and uses `blocking_send`), the
/// TP loop is itself async — every `pool.generate_step` awaits the
/// leader's spawn_blocking forward plus every worker's recv_only.
/// So we `tokio::spawn` the orchestration task and use plain
/// `Sender::send`.
#[cfg(feature = "cuda")]
async fn chat_completion_tp_stream(
&self,
tp: Arc<TpLoadedModel>,
request: ChatCompletionRequest,
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
if tp.poisoned.load(Ordering::Acquire) {
return Err(poisoned_error(&request.model));
}
let prompt = format_qwen3_prompt(&request.messages);
let encoding = tp
.tokenizer
.encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
let prompt_len = prompt_tokens.len();
let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(512) as usize;
let seed = unix_subsec_nanos();
let eos_id = tp
.tokenizer
.token_to_id("<|im_end|>")
.or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
let model_id = request.model.clone();
let id = format!("chatcmpl-{:x}", unix_subsec_nanos());
let created = unix_now_secs();
let tokenizer = tp.tokenizer.clone();
// Bounded channel — back-pressures the producer when the SSE
// writer is slow.
let (tx, rx) = mpsc::channel::<ChatCompletionChunk>(32);
// Role chunk first, before kicking off the heavy work — if the
// receiver is gone by now there's no point starting inference.
let role_chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".into(),
created,
model: model_id.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: json!({"role": "assistant"}),
finish_reason: None,
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
tx.send(role_chunk)
.await
.map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?;
// The orchestration task. Holds the pool lock for the lifetime
// of this inference; concurrent requests against the same TP
// model serialise behind it.
//
// Tagged with the same req_id span as the non-streaming path
// so the journal can be reconstructed regardless of which API
// surface the client hit.
let req_id = new_req_id();
let span = tracing::info_span!(
"tp_chat_stream",
req_id = %req_id,
model = %model_id
);
let req_start = std::time::Instant::now();
let (vram_free_mb, vram_total_mb) = tp.query_vram().await;
tracing::info!(
parent: &span,
prompt_len,
max_new,
temperature,
?top_p,
?eos_id,
vram_free_mb,
vram_total_mb,
"TP chat_completion (stream): starting"
);
validate_request(prompt_len, vram_free_mb)?;
let tp_for_task = Arc::clone(&tp);
tokio::spawn(
async move {
let mut failure: Option<String> = None;
let mut pool = acquire_pool_lock(&tp_for_task.pool, &model_id).await;
let leader_handle = tp_for_task.leader_handle;
let mut all_tokens: Vec<u32> = Vec::new();
let mut decoded_prefix = String::new();
let mut finish_reason = "length".to_string();
'work: {
if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await {
failure = Some(format!("clear_kv_cache: {e:#}"));
break 'work;
}
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
// Chunked prefill — see `chunked_prefill_tp`. Each
// chunk fans out to every rank with a growing
// offset; only the final chunk's logits are kept
// for the first sample.
let logits_vec = match chunked_prefill_tp(
&mut pool,
&model_id,
leader_handle,
&prompt_tokens,
)
.await
{
Ok(l) => l,
Err(e) => {
failure = Some(format!("prefill: {e:#}"));
break 'work;
}
};
let (post_prefill_vram_free_mb, _) = tp_for_task.query_vram().await;
tracing::info!(
model = %model_id,
prompt_len,
vram_free_mb = post_prefill_vram_free_mb,
"TP chat_completion (stream): prefill complete"
);
let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) {
Ok(t) => t,
Err(e) => {
failure = Some(format!("prefill build cpu logits: {e:#}"));
break 'work;
}
};
let mut next_token =
match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
model = %model_id,
?health,
"TP chat_completion (stream): prefill sample failed; logits unhealthy"
);
failure = Some(format!("prefill sample: {e:#}"));
break 'work;
}
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
} else {
all_tokens.push(next_token);
if !emit_chunk(
&all_tokens,
&mut decoded_prefix,
&tokenizer,
&tx,
&id,
created,
&model_id,
)
.await
{
// Client gone — treat as normal stream end,
// not a failure. No log spam.
break 'work;
}
for index in 0..max_new.saturating_sub(1) {
let logits_vec = match pool
.generate_step(
&model_id,
leader_handle,
vec![next_token],
prompt_len + index,
)
.await
{
Ok(l) => l,
Err(e) => {
failure = Some(format!("decode step {index}: {e:#}"));
break 'work;
}
};
let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) {
Ok(t) => t,
Err(e) => {
failure =
Some(format!("decode build cpu logits {index}: {e:#}"));
break 'work;
}
};
next_token = match sample_with_penalty(
&logits,
&all_tokens,
&mut logits_processor,
) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
model = %model_id,
step = index,
?health,
"TP chat_completion (stream): decode sample failed; logits unhealthy"
);
failure = Some(format!("decode sample {index}: {e:#}"));
break 'work;
}
};
// Always await the query (even when the
// trace! is filtered out by RUST_LOG): the
// channel hop is ~tens of µs, comparable to
// the previous in-line bind+query cost, and
// making the call conditional adds complexity
// for negligible win. Revisit if it shows up
// in a hot-path profile.
let step_vram_free_mb = tp_for_task.query_vram().await.0;
tracing::trace!(
model = %model_id,
step = index,
next_token,
vram_free_mb = step_vram_free_mb,
"TP chat_completion (stream): decode step"
);
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
}
all_tokens.push(next_token);
if !emit_chunk(
&all_tokens,
&mut decoded_prefix,
&tokenizer,
&tx,
&id,
created,
&model_id,
)
.await
{
break 'work;
}
}
}
}
// One terminal line per request, success or failure. The
// success branch was previously implicit (the SSE final
// chunk went out and the spawned task just ended); now
// there's always a log line for the operator.
if let Some(err) = &failure {
tp_for_task.poisoned.store(true, Ordering::Release);
tracing::error!(
error = %err,
completion_tokens = all_tokens.len(),
total_ms = req_start.elapsed().as_millis(),
"TP chat_completion (stream): failed, model marked poisoned"
);
} else {
tracing::info!(
prompt_tokens = prompt_len,
completion_tokens = all_tokens.len(),
finish_reason = %finish_reason,
total_ms = req_start.elapsed().as_millis(),
"TP chat_completion (stream): done"
);
}
// Final chunk carrying finish_reason — only on the success
// path. On failure we drop the channel so the client sees
// the SSE stream end abruptly (matches pre-change behaviour
// when the failed-path early-returned without final chunk).
if failure.is_none() {
let final_chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".into(),
created,
model: model_id.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: serde_json::Value::Object(Default::default()),
finish_reason: Some(finish_reason),
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
let _ = tx.send(final_chunk).await;
}
}
.instrument(span),
);
Ok(rx)
}
}
/// Body of the TP non-streaming chat completion, hoisted out of
/// `CandleHarness::chat_completion_tp` so it can run inside
/// `tokio::spawn` (which requires a `'static` future) and survive
/// HTTP-layer cancellation.
///
/// Tracing strategy: `info` for request entry/exit so journalctl
/// always shows when an inference started and finished; `debug` for
/// per-step timing so an operator running with `RUST_LOG=debug` sees
/// where the request actually spends its time without needing to
/// instrument the model code.
#[cfg(feature = "cuda")]
async fn chat_completion_tp_inner(
tp: Arc<TpLoadedModel>,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, InferenceError> {
let req_start = std::time::Instant::now();
let model_id = request.model.clone();
let prompt = format_qwen3_prompt(&request.messages);
let encoding = tp
.tokenizer
.encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
let prompt_len = prompt_tokens.len();
let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(512) as usize;
let seed = unix_subsec_nanos();
let eos_id = tp
.tokenizer
.token_to_id("<|im_end|>")
.or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
let (vram_free_mb, vram_total_mb) = tp.query_vram().await;
tracing::info!(
model = %model_id,
prompt_len,
max_new,
temperature,
?top_p,
?eos_id,
vram_free_mb,
vram_total_mb,
"TP chat_completion: starting"
);
validate_request(prompt_len, vram_free_mb)?;
// Acquire the pool lock for the duration of the request. After
// Phase 3 the leader's TpLeaderModel lives in the device worker
// thread, so the pool lock now serialises only subprocess RPC
// traffic — but holding it for the whole request still keeps
// concurrent chat_completions against the same TP model from
// interleaving prefill/decode jobs.
let mut pool = acquire_pool_lock(&tp.pool, &model_id).await;
let leader_handle = tp.leader_handle;
// Reset every rank's KV cache so this request doesn't attend
// over the previous request's tokens.
let clear_start = std::time::Instant::now();
pool.clear_kv_cache(&model_id, leader_handle)
.await
.map_err(InferenceError::Other)?;
tracing::debug!(
model = %model_id,
elapsed_ms = clear_start.elapsed().as_millis(),
"TP chat_completion: kv cache cleared"
);
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut generated: Vec<u32> = Vec::new();
let mut finish_reason = "length".to_string();
// Prefill: chunk the prompt through `chunked_prefill_tp` so
// activation memory is bounded by chunk size rather than the full
// prompt length. Every rank still sees the prompt in order, just
// spread across multiple `generate_step` calls with monotonically
// growing offsets.
let prefill_start = std::time::Instant::now();
let logits_vec = chunked_prefill_tp(&mut pool, &model_id, leader_handle, &prompt_tokens)
.await
.map_err(InferenceError::Other)?;
let (post_prefill_vram_free_mb, _) = tp.query_vram().await;
tracing::info!(
model = %model_id,
prompt_len,
elapsed_ms = prefill_start.elapsed().as_millis(),
vram_free_mb = post_prefill_vram_free_mb,
"TP chat_completion: prefill complete"
);
// Wrap the CPU-side logits in a CPU candle Tensor for sampling.
// No device touch on the async caller's thread — sampling reads
// from CPU memory only.
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("build cpu logits: {e}")))?;
let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
// Logits health snapshot — the surrounding wrapper logs
// "failed, model marked poisoned" with the error chain;
// this WARN sits just above that and carries the actual
// numerical state so an operator can tell at a glance
// whether it was a NaN cascade, an Inf, or something else.
let health = logits_health_slice(&logits_vec);
tracing::warn!(
model = %model_id,
?health,
"TP chat_completion: prefill sample failed; logits unhealthy"
);
return Err(InferenceError::Other(e));
}
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
} else {
generated.push(next_token);
let decode_start = std::time::Instant::now();
for index in 0..max_new.saturating_sub(1) {
let step_start = std::time::Instant::now();
let logits_vec = pool
.generate_step(
&model_id,
leader_handle,
vec![next_token],
prompt_len + index,
)
.await
.map_err(InferenceError::Other)?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu).map_err(|e| {
InferenceError::Other(anyhow::anyhow!("build cpu logits step {index}: {e}"))
})?;
next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
model = %model_id,
step = index,
?health,
"TP chat_completion: decode sample failed; logits unhealthy"
);
return Err(InferenceError::Other(e));
}
};
let step_vram_free_mb = tp.query_vram().await.0;
tracing::trace!(
model = %model_id,
step = index,
next_token,
step_ms = step_start.elapsed().as_millis(),
vram_free_mb = step_vram_free_mb,
"TP chat_completion: decode step"
);
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
}
generated.push(next_token);
}
tracing::info!(
model = %model_id,
generated = generated.len(),
elapsed_ms = decode_start.elapsed().as_millis(),
"TP chat_completion: decode complete"
);
}
drop(pool);
let completion_text = tp
.tokenizer
.decode(&generated, true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
let usage = Usage {
prompt_tokens: prompt_len as u64,
completion_tokens: generated.len() as u64,
total_tokens: (prompt_len + generated.len()) as u64,
};
tracing::info!(
model = %model_id,
prompt_tokens = prompt_len,
completion_tokens = generated.len(),
finish_reason = %finish_reason,
total_ms = req_start.elapsed().as_millis(),
"TP chat_completion: done"
);
Ok(ChatCompletionResponse {
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
object: "chat.completion".into(),
created: unix_now_secs(),
model: model_id,
choices: vec![ChatCompletionChoice {
index: 0,
message: ChatMessage {
role: "assistant".into(),
content: MessageContent::Text(completion_text),
extra: serde_json::Value::Object(Default::default()),
},
finish_reason: Some(finish_reason),
extra: serde_json::Value::Object(Default::default()),
}],
usage: Some(usage),
extra: serde_json::Value::Object(Default::default()),
})
}
/// Decode the cumulative token list, emit the delta (substring appended
/// since the last chunk) as a `chat.completion.chunk`. Returns `false`
/// if the receiver has hung up — the caller should bail.
#[cfg(feature = "cuda")]
async fn emit_chunk(
all_tokens: &[u32],
decoded_prefix: &mut String,
tokenizer: &Tokenizer,
tx: &mpsc::Sender<ChatCompletionChunk>,
id: &str,
created: u64,
model_id: &str,
) -> bool {
let full = match tokenizer.decode(all_tokens, true) {
Ok(s) => s,
Err(e) => {
tracing::warn!(error = %e, "TP stream: decode failed");
return false;
}
};
if full.len() > decoded_prefix.len() {
let delta = full[decoded_prefix.len()..].to_string();
*decoded_prefix = full;
let chunk = ChatCompletionChunk {
id: id.into(),
object: "chat.completion.chunk".into(),
created,
model: model_id.into(),
choices: vec![ChunkChoice {
index: 0,
delta: json!({ "content": delta }),
finish_reason: None,
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
if tx.send(chunk).await.is_err() {
return false;
}
}
true
}
/// Errors returned by `CandleHarness::chat_completion`. The
/// `ModelNotLoaded`, `PromptTooLong`, and `InsufficientVram` variants
/// let the HTTP handler map cleanly to 404 / 400 / 503 without
/// string-matching on anyhow messages.
#[derive(Debug, thiserror::Error)]
pub enum InferenceError {
#[error("model '{0}' not loaded on this neuron")]
ModelNotLoaded(String),
#[error("prompt has {prompt_len} tokens but max is {max}")]
PromptTooLong { prompt_len: usize, max: usize },
#[error(
"insufficient free VRAM for prefill: {free_mb} MiB free, need at least {required_mb} MiB"
)]
InsufficientVram { free_mb: u64, required_mb: u64 },
#[error(transparent)]
Other(#[from] anyhow::Error),
}
/// Apply the Qwen3 chat template:
///
/// ```text
/// <|im_start|>{role}\n{content}<|im_end|>\n
/// ...
/// <|im_start|>assistant\n
/// ```
///
/// The trailing `<|im_start|>assistant\n` cues the model to begin a turn.
/// Non-text content parts (vision blocks) are joined as text only; full
/// multimodal handling is out of scope for Stage 3.
fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
let mut prompt = String::new();
for msg in messages {
let content = match &msg.content {
MessageContent::Text(s) => s.clone(),
MessageContent::Parts(parts) => parts
.iter()
.filter_map(|p| p.get("text").and_then(|v| v.as_str()))
.collect::<Vec<_>>()
.join(""),
};
prompt.push_str("<|im_start|>");
prompt.push_str(&msg.role);
prompt.push('\n');
prompt.push_str(&content);
prompt.push_str("<|im_end|>\n");
}
prompt.push_str("<|im_start|>assistant\n");
prompt
}
#[allow(clippy::too_many_arguments)]
/// Run the full single-GPU inference loop via the device worker.
///
/// Mirrors `run_inference`'s logic but routes each forward step
/// through `worker.forward_logits()` (returns CPU-side `Vec<f32>`)
/// and runs `apply_repeat_penalty` + sampling on a CPU candle tensor.
/// The device-resident logits tensor never escapes the worker thread.
///
/// Used by the CUDA path of `chat_completion`. The CPU path keeps
/// `run_inference` (spawn_blocking against `Arc<Mutex<ModelArch>>`)
/// because there's no CUDA context to own and the worker indirection
/// would only add channel overhead with no diagnostic benefit.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
async fn run_inference_via_worker(
worker: &super::device_worker::DeviceWorkerHandle,
handle: super::device_worker::ArchHandle,
prompt_tokens: &[u32],
max_new: usize,
temperature: f64,
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
) -> Result<(Vec<u32>, String)> {
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut generated: Vec<u32> = Vec::new();
let prompt_len = prompt_tokens.len();
worker
.clear_kv_cache(handle)
.await
.map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
// Prefill the prompt in `prefill_chunk_tokens()`-sized chunks so
// activation memory is bounded per step rather than scaling with
// prompt length. The KV cache accumulates across chunks; we keep
// only the final chunk's logits for sampling the first generated
// token.
let logits_vec = chunked_prefill_via_worker(worker, handle, prompt_tokens).await?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
?health,
"chat_completion (worker): prefill sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
}
generated.push(next_token);
for index in 0..max_new.saturating_sub(1) {
let logits_vec = worker
.forward_logits(handle, vec![next_token], prompt_len + index)
.await
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
step = index,
?health,
"chat_completion (worker): decode sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
}
generated.push(next_token);
}
Ok((generated, "length".into()))
}
/// Streaming counterpart of [`run_inference_via_worker`]. Emits one
/// `ChatCompletionChunk` per generated token via `tx`; routes every
/// forward step through `worker.forward_logits()`. Same per-step
/// CPU-side sampling discipline — no device tensor escapes the
/// worker thread.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
async fn stream_inference_via_worker(
worker: Arc<super::device_worker::DeviceWorkerHandle>,
handle: super::device_worker::ArchHandle,
tokenizer: Tokenizer,
prompt_tokens: Vec<u32>,
max_new: usize,
temperature: f64,
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
id: String,
created: u64,
model_id: String,
tx: mpsc::Sender<ChatCompletionChunk>,
) -> Result<String> {
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut all_tokens: Vec<u32> = Vec::new();
let mut decoded_prefix = String::new();
let prompt_len = prompt_tokens.len();
let mut finish_reason = "length".to_string();
worker
.clear_kv_cache(handle)
.await
.map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
// Chunked prefill (see `chunked_prefill_via_worker`). The owning
// `prompt_tokens: Vec<u32>` is borrowed for the loop's duration;
// we still need `prompt_len` (already extracted above) for the
// decode-step offset arithmetic.
let logits_vec = chunked_prefill_via_worker(&*worker, handle, &prompt_tokens).await?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
let mut next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
?health,
"chat_completion (stream/worker): prefill sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
} else {
all_tokens.push(next_token);
if !emit_chunk(
&all_tokens,
&mut decoded_prefix,
&tokenizer,
&tx,
&id,
created,
&model_id,
)
.await
{
return Ok(finish_reason); // Client gone — clean stream end.
}
for index in 0..max_new.saturating_sub(1) {
let logits_vec = worker
.forward_logits(handle, vec![next_token], prompt_len + index)
.await
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
step = index,
?health,
"chat_completion (stream/worker): decode sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
}
all_tokens.push(next_token);
if !emit_chunk(
&all_tokens,
&mut decoded_prefix,
&tokenizer,
&tx,
&id,
created,
&model_id,
)
.await
{
return Ok(finish_reason);
}
}
}
// Final chunk carrying finish_reason. Matches the run_inference_streaming
// shape so the SSE consumer sees an identical termination sequence.
let final_chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".into(),
created,
model: model_id.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: serde_json::Value::Object(Default::default()),
finish_reason: Some(finish_reason.clone()),
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
let _ = tx.send(final_chunk).await;
Ok(finish_reason)
}
#[allow(clippy::too_many_arguments)]
fn run_inference(
arch: &mut ModelArch,
device: &Device,
prompt_tokens: &[u32],
max_new: usize,
temperature: f64,
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
) -> Result<(Vec<u32>, String)> {
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut generated: Vec<u32> = Vec::new();
arch.clear_kv_cache()?;
let logits = chunked_prefill_local(arch, device, prompt_tokens)?;
let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
}
generated.push(next_token);
for index in 0..max_new.saturating_sub(1) {
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
let logits = arch.forward(&input, prompt_tokens.len() + index)?;
next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
}
generated.push(next_token);
}
Ok((generated, "length".into()))
}
/// Streaming counterpart to `run_inference`. Emits chunks via `tx` as
/// tokens are generated and exits on EOS, max_new, or receiver drop.
///
/// Detokenization tracks the cumulative decoded prefix so each chunk's
/// `content` delta is the substring appended since the last chunk —
/// safe across BPE byte-fallback boundaries.
#[allow(clippy::too_many_arguments)]
fn run_inference_streaming(
arch: &mut ModelArch,
device: &Device,
tokenizer: &Tokenizer,
prompt_tokens: &[u32],
max_new: usize,
temperature: f64,
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
id: &str,
created: u64,
model_id: &str,
tx: &mpsc::Sender<ChatCompletionChunk>,
) -> Result<()> {
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut all_tokens: Vec<u32> = Vec::new();
let mut decoded_prefix = String::new();
let mut finish_reason = "length".to_string();
arch.clear_kv_cache()?;
let logits = chunked_prefill_local(arch, device, prompt_tokens)?;
let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result<bool> {
let full = tokenizer
.decode(all_tokens, true)
.map_err(|e| anyhow::anyhow!("decode: {e}"))?;
if full.len() > decoded_prefix.len() {
let delta = full[decoded_prefix.len()..].to_string();
*decoded_prefix = full;
let chunk = ChatCompletionChunk {
id: id.into(),
object: "chat.completion.chunk".into(),
created,
model: model_id.into(),
choices: vec![ChunkChoice {
index: 0,
delta: json!({ "content": delta }),
finish_reason: None,
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
// blocking_send returns Err if the consumer hung up — signal
// the caller to stop generating.
if tx.blocking_send(chunk).is_err() {
return Ok(false);
}
}
Ok(true)
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
} else {
all_tokens.push(next_token);
if !emit_token(&all_tokens, &mut decoded_prefix)? {
return Ok(());
}
for index in 0..max_new.saturating_sub(1) {
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
let logits = arch.forward(&input, prompt_tokens.len() + index)?;
next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
}
all_tokens.push(next_token);
if !emit_token(&all_tokens, &mut decoded_prefix)? {
return Ok(());
}
}
}
let final_chunk = ChatCompletionChunk {
id: id.into(),
object: "chat.completion.chunk".into(),
created,
model: model_id.into(),
choices: vec![ChunkChoice {
index: 0,
delta: serde_json::Value::Object(Default::default()),
finish_reason: Some(finish_reason),
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
let _ = tx.blocking_send(final_chunk);
Ok(())
}
fn unix_now_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0)
}
fn unix_subsec_nanos() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_nanos() as u64)
.unwrap_or(0)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn check_dense_config_accepts_qwen3() {
let cfg = r#"{
"model_type": "qwen3",
"vocab_size": 151936,
"architectures": ["Qwen3ForCausalLM"]
}"#;
check_dense_config_supported(cfg, "Qwen/Qwen3-1.7B").expect("qwen3 should pass");
}
#[test]
fn check_dense_config_rejects_unsupported_arch_with_clear_message() {
// Use a deliberately-fake model_type so this test stays
// meaningful as the supported set grows. (qwen3_5 was the
// motivating real example but now lives in the supported set
// as a Stage 8c scaffold.)
let cfg = r#"{
"model_type": "fictional_arch_99",
"architectures": ["FictionalArch99ForCausalLM"]
}"#;
let err = check_dense_config_supported(cfg, "Fake/Model-99")
.expect_err("fictional_arch_99 should be rejected");
let msg = format!("{err}");
assert!(
msg.contains("unsupported model_type 'fictional_arch_99'"),
"message should name the rejected type: {msg}"
);
assert!(
msg.contains("Fake/Model-99"),
"message should echo the model id: {msg}"
);
assert!(
msg.contains("qwen3"),
"message should list the supported set: {msg}"
);
}
#[test]
fn check_dense_config_accepts_qwen3_5() {
// Sanity: Stage 8c scaffold means qwen3_5 deserialises into the
// supported set. Forward still bails (covered by tests on the
// architecture module itself), but the dispatch gate must let
// it through.
let cfg = r#"{
"model_type": "qwen3_5",
"architectures": ["Qwen3_5ForConditionalGeneration"],
"text_config": {"hidden_size": 5120}
}"#;
check_dense_config_supported(cfg, "Qwen/Qwen3.6-27B")
.expect("qwen3_5 should be in the supported set as of Stage 8c scaffold");
}
#[test]
fn check_dense_config_rejects_missing_model_type() {
let cfg = r#"{ "vocab_size": 1234 }"#;
let err = check_dense_config_supported(cfg, "anon/no-type")
.expect_err("missing model_type should be rejected");
assert!(
format!("{err}").contains("missing `model_type`"),
"message should call out the missing field"
);
}
#[test]
fn check_dense_config_rejects_invalid_json() {
let err = check_dense_config_supported("not json", "anon/bad-json")
.expect_err("malformed JSON should be rejected");
assert!(
format!("{err:#}").contains("config.json"),
"message should mention config.json"
);
}
}