Some checks failed
CI / CUDA type-check (push) Failing after 19s
build-prerelease / Resolve version stamps (push) Successful in 43s
CI / Format (push) Successful in 50s
CI / Clippy (push) Failing after 57s
build-prerelease / Build neuron-ada (push) Failing after 48s
build-prerelease / Build cortex binary (push) Successful in 5m5s
build-prerelease / Build neuron-blackwell (push) Successful in 6m38s
build-prerelease / Package cortex RPM (push) Successful in 1m27s
build-prerelease / Build neuron-ampere (push) Successful in 7m27s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
CI / Test (push) Successful in 10m27s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
One-shot, env-gated fault injector for beast verification: when NEURON_DEBUG_POISON names a model, the first request for it triggers the auto-recovery path as if a device fault had occurred — exercising unload→reload→healthy without corrupting the GPU. Latched so it fires exactly once (no recovery loop). No-op unless the env var is set; wired into both the single-GPU and TP chat poison gates. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
5288 lines
224 KiB
Rust
5288 lines
224 KiB
Rust
//! Candle harness — in-process inference using huggingface/candle.
|
||
//!
|
||
//! This is the sole `Harness` implementation. Inference runs inside
|
||
//! the neuron process; there is no external subprocess.
|
||
//!
|
||
//! - Stage 2 wired GGUF load/unload, initially for Qwen3 only via
//!   `quantized_qwen3` (the Llama and Qwen3-MoE GGUF variants came later).
//! - Stage 3 added `chat_completion` — a non-streaming OpenAI-compatible
//!   chat completion routed to the loaded model's forward pass on a
//!   per-model serialised generation loop.
|
||
|
||
use anyhow::{Context, Result};
|
||
use async_trait::async_trait;
|
||
use candle_core::quantized::gguf_file;
|
||
use candle_core::{DType, Device, Tensor};
|
||
use candle_nn::VarBuilder;
|
||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||
use candle_transformers::models::llama as llama_dense;
|
||
use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaWeights;
|
||
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
|
||
use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
|
||
use candle_transformers::models::qwen3 as qwen3_dense;
|
||
use candle_transformers::models::qwen3_moe as qwen3_moe_dense;
|
||
use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec};
|
||
use cortex_core::openai::{
|
||
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse,
|
||
ChatMessage, MessageContent, Usage,
|
||
};
|
||
|
||
use crate::wire::{
|
||
FinishReason, InferenceEvent, ReasoningTokenPair, ToolCallTokenPair,
|
||
detect_reasoning_token_pair, detect_tool_call_token_pair, openai_chat as wire_chat,
|
||
};
|
||
use std::collections::HashMap;
|
||
use std::path::PathBuf;
|
||
use std::sync::Arc;
|
||
use std::sync::atomic::{AtomicBool, Ordering};
|
||
#[cfg(feature = "cuda")]
|
||
use std::time::Duration;
|
||
use std::time::{SystemTime, UNIX_EPOCH};
|
||
use tokenizers::Tokenizer;
|
||
use tokio::sync::{Mutex, RwLock, mpsc};
|
||
use tracing::Instrument;
|
||
|
||
/// In-process candle harness. Owns the loaded model registry.
|
||
pub struct CandleHarness {
|
||
models: Arc<RwLock<HashMap<String, LoadedHandle>>>,
|
||
/// Post-resolution source map: scheme → endpoint/token/cache. Built
|
||
/// in `new()` from the operator's `CandleHarnessConfig`; auth tokens
|
||
/// are read from their configured env vars at startup so secrets
|
||
/// don't leak through the config file.
|
||
sources: HashMap<String, ResolvedSource>,
|
||
/// Scheme to substitute for bare `org/name` model ids.
|
||
default_source: String,
|
||
bind_url: String,
|
||
/// One worker thread per CUDA device index that owns its
|
||
/// `CudaContext` for the daemon's lifetime. Populated lazily by
|
||
/// `ensure_device_worker()` when a model is loaded onto a CUDA
|
||
/// device. CPU `Device::Cpu` loads don't get an entry; they have
|
||
/// no context to own. Unused on the no-cuda build (the harness
|
||
/// can still load on CPU for tests, just without worker threads).
|
||
#[allow(dead_code)]
|
||
device_workers: Arc<RwLock<HashMap<u32, Arc<super::device_worker::DeviceWorkerHandle>>>>,
|
||
/// Auto-recovery (#17): model ids whose poisoned context is being
|
||
/// rebuilt via unload+reload. Insert is the single-flight gate (one
|
||
/// recovery per model in flight); membership also lets the request
|
||
/// path answer "recovering, retry shortly" during the reload gap
|
||
/// rather than a bare "not loaded".
|
||
recovering: Arc<RwLock<std::collections::HashSet<String>>>,
|
||
/// Sender to the background recovery task. The request path enqueues
|
||
/// a poisoned model id here; the task (holding a `Weak<Self>`) runs
|
||
/// the unload→reload→health-gate. Unbounded + tiny (model ids), and
|
||
/// the `recovering` set dedupes, so it can't back up.
|
||
recovery_tx: tokio::sync::mpsc::UnboundedSender<String>,
|
||
}
|
||
|
||
/// One entry in the harness's loaded-model registry. Single-GPU loads
|
||
/// land in `Single`; loads with `tensor_parallel > 1` land in `Tp`.
|
||
/// The two variants share the same `model_id` key in the map, so
|
||
/// `list_models`, `unload_model`, and `inference_endpoint` can walk
|
||
/// them uniformly without branching the storage layout.
|
||
///
|
||
/// `Clone` is cheap: both variants hold `Arc<_>` and cloning just bumps
|
||
/// the refcount.
|
||
#[derive(Clone)]
|
||
pub enum LoadedHandle {
|
||
Single(Arc<LoadedModel>),
|
||
#[cfg(feature = "cuda")]
|
||
Tp(Arc<TpLoadedModel>),
|
||
}
|
||
|
||
impl LoadedHandle {
|
||
pub fn model_id(&self) -> &str {
|
||
match self {
|
||
LoadedHandle::Single(m) => &m.model_id,
|
||
#[cfg(feature = "cuda")]
|
||
LoadedHandle::Tp(m) => &m.model_id,
|
||
}
|
||
}
|
||
|
||
/// The spec this model was loaded from (for auto-recovery #17).
|
||
pub fn spec(&self) -> &ModelSpec {
|
||
match self {
|
||
LoadedHandle::Single(m) => &m.spec,
|
||
#[cfg(feature = "cuda")]
|
||
LoadedHandle::Tp(m) => &m.spec,
|
||
}
|
||
}
|
||
|
||
pub fn devices(&self) -> Vec<u32> {
|
||
match self {
|
||
LoadedHandle::Single(m) => m.devices.clone(),
|
||
#[cfg(feature = "cuda")]
|
||
LoadedHandle::Tp(m) => m.devices.clone(),
|
||
}
|
||
}
|
||
|
||
/// True if an earlier inference left the device context in an
|
||
/// unrecoverable state. Surfaced in `/models` so cortex (and an
|
||
/// operator running `curl beast:13131/models`) can see at a glance
|
||
/// that the model needs unload+reload.
|
||
pub fn is_poisoned(&self) -> bool {
|
||
match self {
|
||
LoadedHandle::Single(m) => m.poisoned.load(Ordering::Acquire),
|
||
#[cfg(feature = "cuda")]
|
||
LoadedHandle::Tp(m) => m.poisoned.load(Ordering::Acquire),
|
||
}
|
||
}
|
||
|
||
/// Modalities the loaded model supports. Stage B7 (single-GPU) +
|
||
/// TP-vision (#12) — both single-GPU and TP loads advertise
|
||
/// `"vision"` when a replicated vision tower materialised.
|
||
pub fn capabilities(&self) -> Vec<String> {
|
||
let mut caps = vec!["text".to_string()];
|
||
let has_vision = match self {
|
||
LoadedHandle::Single(m) => m.has_vision,
|
||
#[cfg(feature = "cuda")]
|
||
LoadedHandle::Tp(m) => m.has_vision,
|
||
};
|
||
if has_vision {
|
||
caps.push("vision".to_string());
|
||
}
|
||
caps
|
||
}
|
||
}
|
||
|
||
/// A loaded model with its tokenizer, device placement, and architecture-
|
||
/// specific weights. The `arch` is `Arc<Mutex<>>` so the lock guard can be
|
||
/// moved into `spawn_blocking` for synchronous candle forward passes.
|
||
pub struct LoadedModel {
|
||
pub model_id: String,
|
||
/// Local (async-side) handle to the model architecture. `Some`
|
||
/// only when the model loaded onto the CPU device (no CUDA
|
||
/// available); the inference path then takes this mutex via
|
||
/// `spawn_blocking` and runs candle ops on the CPU backend.
|
||
/// `None` when the model loaded onto a CUDA device — in that case
|
||
/// the architecture lives in the worker thread's slab and is
|
||
/// addressed via [`Self::arch_handle`].
|
||
pub arch: Option<Arc<Mutex<ModelArch>>>,
|
||
pub tokenizer: Tokenizer,
|
||
pub device: Device,
|
||
pub quant: Option<String>,
|
||
pub devices: Vec<u32>,
|
||
/// Set to `true` after any forward / kv-cache call fails. A CUDA
|
||
/// driver error (OOM, illegal address) leaves the device's context
|
||
/// in an unrecoverable state — subsequent kernels can hang, return
|
||
/// garbage, or hit another illegal address. The harness refuses
|
||
/// further inference against a poisoned model and reports a clear
|
||
/// error so an operator knows to unload+reload to recover. See
|
||
/// the 2026-05-26 beast incident where a 14k-token prefill OOM
|
||
/// silently turned every subsequent request into a stuck wait.
|
||
pub poisoned: AtomicBool,
|
||
/// Handle to the per-device CUDA worker thread for this model's
|
||
/// device. `None` for CPU loads (no context to own). VRAM queries
|
||
/// and — for CUDA loads — forward / kv-cache / drop ops route
|
||
/// through this handle so the device's CUDA context stays bound
|
||
/// to one OS thread for the daemon's lifetime.
|
||
pub worker: Option<Arc<super::device_worker::DeviceWorkerHandle>>,
|
||
/// Index into the worker's `ModelArch` slab. `Some` iff the model
|
||
/// loaded onto a CUDA device and was successfully transferred to
|
||
/// the worker; in that case [`Self::arch`] is `None`. The two
|
||
/// fields are mutually exclusive.
|
||
pub arch_handle: Option<super::device_worker::ArchHandle>,
|
||
/// Serialises chat-completion requests against this model. Held
|
||
/// from the start of `clear_kv_cache` through the last decode
|
||
/// step, so concurrent requests can't interleave their KV-cache
|
||
/// mutations. Without this, two requests' chunked-prefill
|
||
/// `clear → forward(chunk0) → forward(chunk1) → ...` sequences
|
||
/// could end up sharing a cache between them — the device worker
|
||
/// channel serialises individual jobs, but not the sequence
|
||
/// boundary. Observed on benjy 2026-05-27 18:41 when agent-zero's
|
||
/// memorize extensions fired in parallel and produced a
|
||
/// shape-mismatch failure mid-prefill. Mirrors TpLoadedModel.pool
|
||
/// for the TP path (which already had this invariant by accident
|
||
/// because the pool lock covered the same window).
|
||
pub inference_lock: tokio::sync::Mutex<()>,
|
||
/// Open/close token IDs for the reasoning marker this model
|
||
/// emits, populated once at load time by probing the tokenizer's
|
||
/// added-tokens table. `None` for non-reasoning models or
|
||
/// reasoning models whose markers aren't single tokens. When
|
||
/// `Some`, the streaming inference loop splits output into
|
||
/// [`InferenceEvent::TextDelta`] and
|
||
/// [`InferenceEvent::ReasoningDelta`] at the token boundary;
|
||
/// when `None` everything is `TextDelta`.
|
||
pub reasoning_tokens: Option<ReasoningTokenPair>,
|
||
/// Open/close token IDs for the model's tool-call marker
|
||
/// pair (`<tool_call>` / `</tool_call>` on Qwen3-Coder / Hermes
|
||
/// / DeepSeek / gpt-oss). `None` for models that don't emit
|
||
/// structured tool calls in this convention; output passes
|
||
/// through as plain text in that case and the consumer parses
|
||
/// the markers itself if it knows how.
|
||
pub tool_call_tokens: Option<ToolCallTokenPair>,
|
||
/// Raw Jinja `chat_template` string loaded from this model's
|
||
/// `tokenizer_config.json` at load time. `None` when the file
|
||
/// is absent / unparseable / lacks the field. When `Some`,
|
||
/// the prompt-build path renders it through `minijinja` with
|
||
/// `chat_template_kwargs` from the request body; when `None`,
|
||
/// the hardcoded Qwen3 ChatML fallback (`format_qwen3_prompt`)
|
||
/// is used. The `NEURON_USE_CHAT_TEMPLATE=false` env var
|
||
/// forces the fallback path even when `Some`.
|
||
pub chat_template: Option<String>,
|
||
/// Vision capability flag derived at load time. `true` iff the
|
||
/// loaded `ModelArch` exposes a vision tower (Stage A4 wires this
|
||
/// from `Qwen3_5ForCausalLM::has_vision`). Used by the chat
|
||
/// completion handler to reject image content on non-vision
|
||
/// models with a structured 400 (Stage B6) and by `/v1/models`
|
||
/// to advertise `capabilities: ["text", "vision"]` (Stage B7).
|
||
pub has_vision: bool,
|
||
/// `<|image_pad|>` token id from `config.json::image_token_id`.
|
||
/// The Stage B prompt-builder uses this to compute expansion
|
||
/// targets and the worker forward uses it to locate splice
|
||
/// positions in the LM input embeddings.
|
||
pub image_token_id: Option<u32>,
|
||
/// `patch_size × spatial_merge_size` — divides a resized pixel
|
||
/// dimension into LM-grid units. Per-image LM token count is
|
||
/// `(h/factor) × (w/factor)` (#14 dynamic resolution). `None` for
|
||
/// text-only models. Set at load time.
|
||
pub image_grid_factor: Option<usize>,
|
||
/// The spec this model was loaded from — retained so auto-recovery
|
||
/// (#17) can `unload_model` + `load_model(spec)` a poisoned model
|
||
/// without an operator reconstructing it.
|
||
pub spec: ModelSpec,
|
||
}
|
||
|
||
impl LoadedModel {
|
||
/// Free / total VRAM on this model's device in MiB. Routes the
|
||
/// query through the device worker thread (where the CUDA context
|
||
/// is already bound) rather than rebinding on whatever tokio
|
||
/// thread the caller happens to be on. Returns `(0, 0)` on CPU
|
||
/// loads, or if the worker is gone / poisoned / the cudarc call
|
||
/// itself failed — same sentinel the previous `device_vram_mb`
|
||
/// helper returned, so log field values stay comparable.
|
||
pub async fn query_vram(&self) -> (u64, u64) {
|
||
match &self.worker {
|
||
Some(w) => w.query_vram().await.unwrap_or((0, 0)),
|
||
None => (0, 0),
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Tensor-parallel loaded model.
///
/// Rank 0's `TpLeaderModel` lives in the leader device worker's TP slab
/// and is addressed through `leader_handle`. Forward / clear_kv / unload
/// for the leader are `Job::Tp*` requests against that handle. Every
/// non-zero rank is driven by the `WorkerPool` over its RPC channels.
/// The pool sits behind a tokio `Mutex`, which serialises concurrent
/// inference requests against the same model. Loads of *different*
/// models each get their own pool.
#[cfg(feature = "cuda")]
pub struct TpLoadedModel {
    pub model_id: String,
    pub tokenizer: Tokenizer,
    pub devices: Vec<u32>,
    /// One end-to-end gate: the pool's RPC stream to the subprocess
    /// workers is not safe to use concurrently. Since Phase 3 the leader's
    /// `TpLeaderModel` lives in the worker thread's slab, so this Mutex no
    /// longer covers the leader's KV cache. It only serialises subprocess
    /// RPC traffic on the pool's `Vec<Worker>` channels.
    pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
    /// Opaque index into the leader device worker's TP slab. The boxed
    /// `TpLeaderModel`, with its embedded `Arc<Comm>` clones and per-rank
    /// CUDA tensors, lives on the worker thread. Forward / clear_kv /
    /// unload all go through `Job::Tp*` against this handle.
    pub leader_handle: super::device_worker::TpHandle,
    /// Candle device for rank 0. It mirrors what
    /// `TpLeaderModel::device()` would return, and is kept here so the
    /// request path can name the device without an RPC.
    pub leader_device: Device,
    /// Same poisoning gate as [`LoadedModel::poisoned`]. A TP forward
    /// failure (CUDA OOM on any rank, NCCL desync, illegal address) is
    /// terminal: the leader's and workers' CUDA contexts cannot be reset
    /// reliably without restarting the worker subprocesses.
    pub poisoned: AtomicBool,
    /// Worker thread for the leader's CUDA device. It owns the leader's
    /// `CudaContext`, `NcclState` and the boxed `TpLeaderModel` referenced
    /// by `leader_handle`.
    pub worker: Arc<super::device_worker::DeviceWorkerHandle>,
    /// Same shape as [`LoadedModel::reasoning_tokens`]: open/close
    /// reasoning marker token IDs probed from the tokenizer at load time.
    /// `None` when the model declares no reasoning markers.
    pub reasoning_tokens: Option<ReasoningTokenPair>,
    /// Same shape as [`LoadedModel::tool_call_tokens`].
    pub tool_call_tokens: Option<ToolCallTokenPair>,
    /// Same shape as [`LoadedModel::chat_template`].
    pub chat_template: Option<String>,
    /// TP-vision capability flag: `true` iff every rank materialised a
    /// replicated vision tower. It mirrors [`LoadedModel::has_vision`] and
    /// drives both capability advertising and the TP vision dispatch.
    pub has_vision: bool,
    /// `<|image_pad|>` token id — same as [`LoadedModel::image_token_id`].
    pub image_token_id: Option<u32>,
    /// Pixel → LM-grid divisor — same as
    /// [`LoadedModel::image_grid_factor`].
    pub image_grid_factor: Option<usize>,
    /// Loading spec, retained for auto-recovery (#17) — see
    /// [`LoadedModel::spec`].
    pub spec: ModelSpec,
}
|
||
|
||
#[cfg(feature = "cuda")]
impl TpLoadedModel {
    /// Free / total VRAM of the leader's device, in MiB.
    ///
    /// This follows the same routing and `(0, 0)` failure sentinel as
    /// [`LoadedModel::query_vram`]. A TP model always has a worker, because
    /// the harness rejects TP without CUDA at load time.
    pub async fn query_vram(&self) -> (u64, u64) {
        let reply = self.worker.query_vram().await;
        reply.unwrap_or((0, 0))
    }
}
|
||
|
||
/// Architecture-specific weights. Each variant covers one (family,
|
||
/// source-format) pair; the dense variants take the safetensors path
|
||
/// and the `Quantized*` variants take the GGUF path.
|
||
///
|
||
/// TP currently only works through `Qwen3Dense` (see `tp_qwen3.rs`);
|
||
/// every other variant is single-GPU. Quantized variants can't shard
|
||
/// across GPUs at all — slicing GGUF super-blocks is intractable —
|
||
/// and the new dense families (Llama, Qwen3 MoE) lack their own
|
||
/// TP-aware modules yet.
|
||
pub enum ModelArch {
|
||
// Qwen3 family
|
||
Qwen3Quantized(QuantizedQwen3Weights),
|
||
Qwen3Dense(qwen3_dense::ModelForCausalLM),
|
||
Qwen3MoeQuantized(GGUFQWenMoE),
|
||
Qwen3MoeDense(qwen3_moe_dense::ModelForCausalLM),
|
||
|
||
// Llama family (covers Llama 1/2/3/3.1/3.3). Boxed because the
|
||
// wrapper carries an inline Cache + Config — without indirection
|
||
// the enum's `LlamaDense` variant is several hundred bytes larger
|
||
// than the others (clippy::large_enum_variant).
|
||
LlamaQuantized(QuantizedLlamaWeights),
|
||
LlamaDense(Box<LlamaDense>),
|
||
|
||
// Qwen3-Next family (model_type "qwen3_5") — Qwen3.6's
|
||
// architecture. Stage 8c scaffolding only: dispatch + config parse
|
||
// are real; forward bails "not implemented yet". See
|
||
// `arch/qwen3_5.rs` for the open architecture work.
|
||
Qwen3_5Dense(super::arch::qwen3_5::Qwen3_5ForCausalLM),
|
||
}
|
||
|
||
impl ModelArch {
|
||
/// One forward step on this arch with the rank-1 vocab logits
|
||
/// extracted. Hides per-family shape differences (some return
|
||
/// `[B, V]`, others `[B, 1, V]`) — every caller gets a `[V]`
|
||
/// tensor ready for sampling.
|
||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
|
||
let raw = match self {
|
||
ModelArch::Qwen3Quantized(m) => m.forward(input, offset)?,
|
||
ModelArch::Qwen3Dense(m) => m.forward(input, offset)?,
|
||
ModelArch::Qwen3MoeQuantized(m) => m.forward(input, offset)?,
|
||
ModelArch::Qwen3MoeDense(m) => m.forward(input, offset)?,
|
||
ModelArch::LlamaQuantized(m) => m.forward(input, offset)?,
|
||
ModelArch::LlamaDense(m) => m.forward(input, offset)?,
|
||
ModelArch::Qwen3_5Dense(m) => m.forward(input, offset)?,
|
||
};
|
||
squeeze_to_vocab(&raw)
|
||
}
|
||
|
||
/// Reset the KV cache before each new request so we don't attend
|
||
/// over a previous request's tokens. Some architectures have an
|
||
/// in-place reset; Llama needs a Cache rebuild (held inline in
|
||
/// the wrapper).
|
||
pub fn clear_kv_cache(&mut self) -> Result<()> {
|
||
match self {
|
||
ModelArch::Qwen3Quantized(_) => Ok(()), /* keeps cache by design;
|
||
* forward() handles offset */
|
||
ModelArch::Qwen3Dense(m) => {
|
||
m.clear_kv_cache();
|
||
Ok(())
|
||
}
|
||
ModelArch::Qwen3MoeQuantized(_) => Ok(()),
|
||
ModelArch::Qwen3MoeDense(m) => {
|
||
m.clear_kv_cache();
|
||
Ok(())
|
||
}
|
||
ModelArch::LlamaQuantized(_) => Ok(()),
|
||
ModelArch::LlamaDense(m) => m.clear_kv_cache(),
|
||
ModelArch::Qwen3_5Dense(m) => {
|
||
m.clear_kv_cache();
|
||
Ok(())
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Forward step that splices vision-tower output at
|
||
/// `<|image_pad|>` token positions. Stage B2.
|
||
///
|
||
/// Only `Qwen3_5Dense` supports this — other architectures error
|
||
/// because they don't have a vision tower. The HTTP layer is
|
||
/// expected to have rejected image content for non-vision models
|
||
/// already (Stage B6); this is a defence-in-depth error path.
|
||
///
|
||
/// Returns rank-1 `[vocab_size]` logits, same shape contract as
|
||
/// `forward`.
|
||
pub fn forward_with_vision(
|
||
&mut self,
|
||
input: &Tensor,
|
||
offset: usize,
|
||
image_embeds: &Tensor,
|
||
image_token_id: u32,
|
||
grids: &[(usize, usize)],
|
||
) -> Result<Tensor> {
|
||
let raw = match self {
|
||
ModelArch::Qwen3_5Dense(m) => {
|
||
m.forward_with_vision(input, offset, image_embeds, image_token_id, grids)?
|
||
}
|
||
other => anyhow::bail!(
|
||
"forward_with_vision: architecture {} has no vision tower",
|
||
std::any::type_name_of_val(other)
|
||
),
|
||
};
|
||
squeeze_to_vocab(&raw)
|
||
}
|
||
|
||
/// `patch_size × spatial_merge_size` for the loaded vision tower —
|
||
/// divides a resized pixel dim into LM-grid units (an image of
|
||
/// resized `(h, w)` yields the LM grid `(h/factor, w/factor)`).
|
||
/// `None` for architectures/checkpoints without a vision tower.
|
||
pub fn vision_grid_factor(&self) -> Option<usize> {
|
||
match self {
|
||
ModelArch::Qwen3_5Dense(m) => m.vision().map(|v| {
|
||
let c = v.config();
|
||
c.patch_size * c.spatial_merge_size
|
||
}),
|
||
_ => None,
|
||
}
|
||
}
|
||
|
||
/// Encode a preprocessed image into LM-side token embeddings via
|
||
/// the loaded vision tower. Stage A5.
|
||
///
|
||
/// `image`: device-resident `(C, H, W)` f32 tensor — caller has
|
||
/// already preprocessed via `harness::preprocess::preprocess` and
|
||
/// uploaded to the worker's device. Returns
|
||
/// `(N_lm_tokens, hidden_size)`.
|
||
///
|
||
/// Errors when the loaded architecture has no vision tower
|
||
/// (text-only checkpoint, or architecture that doesn't support
|
||
/// vision at all). The HTTP layer maps this to a 400 with
|
||
/// `vision_unsupported` so clients see a clean rejection rather
|
||
/// than a confident text-only hallucination.
|
||
pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
|
||
match self {
|
||
ModelArch::Qwen3_5Dense(m) => m
|
||
.vision()
|
||
.ok_or_else(|| {
|
||
anyhow::anyhow!(
|
||
"encode_image: this Qwen3.6 checkpoint was loaded without a vision \
|
||
tower (config.json::vision_config absent or weights missing)"
|
||
)
|
||
})?
|
||
.forward(image),
|
||
other => anyhow::bail!(
|
||
"encode_image: architecture {} has no vision tower",
|
||
std::any::type_name_of_val(other)
|
||
),
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Squeeze any leading singleton dims off the logits tensor so the
|
||
/// caller gets a rank-1 `[vocab_size]` slice ready for sampling. Bails
|
||
/// on a non-singleton leading dim (would mean a batched forward, which
|
||
/// no caller emits today).
|
||
fn squeeze_to_vocab(t: &Tensor) -> Result<Tensor> {
|
||
let mut t = t.clone();
|
||
while t.dims().len() > 1 {
|
||
if t.dims()[0] != 1 {
|
||
anyhow::bail!(
|
||
"logits expected to start with a singleton dim, got shape {:?}",
|
||
t.dims()
|
||
);
|
||
}
|
||
t = t.squeeze(0)?;
|
||
}
|
||
Ok(t)
|
||
}
|
||
|
||
/// Llama dense wrapper. Bundles candle's `Llama` model with its
|
||
/// externally-managed `Cache` plus enough config to rebuild the
|
||
/// cache on `clear_kv_cache` (Llama's Cache doesn't expose a reset).
|
||
pub struct LlamaDense {
|
||
model: llama_dense::Llama,
|
||
cache: llama_dense::Cache,
|
||
config: llama_dense::Config,
|
||
dtype: DType,
|
||
device: Device,
|
||
}
|
||
|
||
impl LlamaDense {
|
||
/// Constructor used by the dispatch-side loader. Keeps the field
|
||
/// names private while letting the worker thread build a
|
||
/// `LlamaDense` from already-loaded weights without going through
|
||
/// async candle code.
|
||
pub(crate) fn from_parts(
|
||
model: llama_dense::Llama,
|
||
cache: llama_dense::Cache,
|
||
config: llama_dense::Config,
|
||
dtype: DType,
|
||
device: Device,
|
||
) -> Self {
|
||
Self {
|
||
model,
|
||
cache,
|
||
config,
|
||
dtype,
|
||
device,
|
||
}
|
||
}
|
||
|
||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
|
||
Ok(self.model.forward(input, offset, &mut self.cache)?)
|
||
}
|
||
|
||
pub fn clear_kv_cache(&mut self) -> Result<()> {
|
||
self.cache = llama_dense::Cache::new(true, self.dtype, &self.config, &self.device)
|
||
.context("rebuild Llama Cache for new request")?;
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
/// Repetition penalty applied to recently generated tokens before
/// sampling.
///
/// A value of 1.0 disables the penalty; values above 1.0 make recently
/// emitted tokens less likely. 1.1 follows the default historically used
/// by mistral.rs and llama.cpp (check current upstream defaults before
/// relying on that). It is enough to stop small quantized models from
/// degenerating into "Wait, no, no..." loops without distorting normal
/// output.
const REPEAT_PENALTY: f32 = 1.1;
|
||
|
||
/// Window size (in generated tokens) that the repetition penalty looks
/// back over. This is the same default as candle's quantized-qwen3
/// example.
const REPEAT_LAST_N: usize = 64;
|
||
|
||
/// `config.json#/model_type` values that the dense safetensors path
/// knows how to build. Keep it alphabetical, with one entry per type.
///
/// Adding an entry also means adding a `ModelArch` variant and a
/// dispatch branch in `load_arch_dense`. For TP support, a parallel
/// pattern is needed in `tp_qwen3.rs` as well.
const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["llama", "qwen3", "qwen3_5", "qwen3_moe"];
|
||
|
||
/// Pre-flight check the operator's `config.json` against the set of
|
||
/// architectures the dense path actually knows how to build. Surfaces
|
||
/// architecture mismatches as a single clean error before the serde
|
||
/// deserializer trips on missing fields — the latter happens because
|
||
/// every architecture has different hyperparameter names, so when the
|
||
/// JSON is e.g. Qwen3.6 wrapped under `text_config: {...}`, candle's
|
||
/// `qwen3::Config` finds none of its expected top-level fields and
|
||
/// fails with a cryptic `missing field 'vocab_size' at line N col 1`.
|
||
///
|
||
/// The result message names the model_type we saw, the supported set,
|
||
/// and points at the files an operator (or future contributor) needs
|
||
/// to touch to grow the supported set.
|
||
pub(crate) fn check_dense_config_supported(config_json: &str, model_id: &str) -> Result<()> {
|
||
let v: serde_json::Value = serde_json::from_str(config_json)
|
||
.with_context(|| format!("parse config.json for '{model_id}' as JSON"))?;
|
||
let model_type = v.get("model_type").and_then(|x| x.as_str()).unwrap_or("");
|
||
if model_type.is_empty() {
|
||
anyhow::bail!(
|
||
"config.json for '{model_id}' is missing `model_type`; the dense \
|
||
path needs it to gate architecture support (supported: {:?})",
|
||
DENSE_SUPPORTED_MODEL_TYPES
|
||
);
|
||
}
|
||
if DENSE_SUPPORTED_MODEL_TYPES.contains(&model_type) {
|
||
return Ok(());
|
||
}
|
||
// Bonus context: the model usually also lists architectures, which
|
||
// is what `transformers` keys on. Including it makes the error
|
||
// self-contained.
|
||
let architectures = v
|
||
.get("architectures")
|
||
.and_then(|x| x.as_array())
|
||
.map(|a| {
|
||
a.iter()
|
||
.filter_map(|v| v.as_str().map(String::from))
|
||
.collect::<Vec<_>>()
|
||
})
|
||
.unwrap_or_default();
|
||
anyhow::bail!(
|
||
"unsupported model_type '{model_type}' for '{model_id}' \
|
||
(architectures={architectures:?}); the dense path supports {:?}. \
|
||
Add a `ModelArch` variant + load/forward branches in \
|
||
crates/neuron/src/harness/candle.rs (and the TP analogue in \
|
||
tp_qwen3.rs) to extend coverage.",
|
||
DENSE_SUPPORTED_MODEL_TYPES
|
||
);
|
||
}
|
||
|
||
/// `model_type` values the tensor-parallel path can load and run.
///
/// This is a subset of `DENSE_SUPPORTED_MODEL_TYPES`. The single-GPU path
/// covers more families because every TP-aware module is a substantial
/// piece of work; `tp_qwen3.rs` is the one shipped today.
#[cfg(feature = "cuda")]
const TP_SUPPORTED_MODEL_TYPES: &[&str] = &["qwen3", "qwen3_5"];
|
||
|
||
/// TP counterpart to `check_dense_config_supported`, applying a narrower
/// architecture gate for `load_tp`.
///
/// The single-GPU dense path can build a Llama model, but the worker
/// pool's `load_dense_shard` rebuilds the config as Qwen3, and there is no
/// `tp_llama.rs` yet. Rejecting such configs here, before any workers are
/// spawned or NCCL handshakes paid for, is much kinder than the per-rank
/// deserialise failure that would follow.
///
/// # Errors
///
/// Fails when `config_json` is not valid JSON, or when its `model_type`
/// (treated as empty if missing) is not in `TP_SUPPORTED_MODEL_TYPES`.
#[cfg(feature = "cuda")]
fn check_tp_arch_supported(config_json: &str, model_id: &str) -> Result<()> {
    let parsed: serde_json::Value = serde_json::from_str(config_json)
        .with_context(|| format!("parse config.json for '{model_id}' as JSON"))?;
    let model_type = parsed
        .get("model_type")
        .and_then(serde_json::Value::as_str)
        .unwrap_or("");

    if !TP_SUPPORTED_MODEL_TYPES.contains(&model_type) {
        anyhow::bail!(
            "tensor_parallel requested for '{model_id}' (model_type='{model_type}') but \
             the TP path supports only {TP_SUPPORTED_MODEL_TYPES:?}. Adding a new \
             TP-aware architecture needs a `harness/tp/tp_<family>.rs` module mirroring \
             `tp_qwen3.rs` (sharded linears, AllReduce, per-rank head counts) and a \
             dispatch in `WorkerPool::load_dense_shard`. For models that fit on one \
             GPU, drop `tensor_parallel` to use the single-GPU dense path."
        );
    }
    Ok(())
}
|
||
|
||
/// Works out which HuggingFace cache directory the candle harness should
/// use. The first source that yields a value wins:
///
/// 1. `explicit` — `hf_cache` from `[harness.candle]` in `neuron.toml`.
///    The operator's choice always takes priority.
/// 2. `HF_HUB_CACHE`, which points straight at the cache root. Python's
///    `huggingface_hub` reads it, but the Rust `hf-hub` crate does not,
///    so it is bridged here. That lets a neuron host share one cache
///    with Python tooling and other harnesses without per-tool config.
/// 3. `HF_HOME`, the canonical HuggingFace base directory, with the cache
///    at `$HF_HOME/hub`. `hf-hub` honours this itself; it is resolved
///    here too so the effective path is logged like the other cases.
/// 4. `None` — defer to `hf-hub`'s built-in default
///    (`~/.cache/huggingface/hub`).
///
/// An environment variable that is set but empty, or not valid unicode,
/// counts as unset.
fn resolve_hf_cache(explicit: Option<PathBuf>) -> Option<PathBuf> {
    let non_empty_env = |key: &str| std::env::var(key).ok().filter(|v| !v.is_empty());

    explicit
        .or_else(|| non_empty_env("HF_HUB_CACHE").map(PathBuf::from))
        .or_else(|| non_empty_env("HF_HOME").map(|home| PathBuf::from(home).join("hub")))
}
|
||
|
||
/// Summary stats over a 1-D logits tensor, used for the failure log
/// when sampling rejects the distribution. Gathers NaN / +Inf / -Inf /
/// negative counts and finite min/max/mean — enough to distinguish a NaN
/// cascade (all-NaN, typical of softmax overflow propagating) from
/// an Inf at a single position (numerical edge case) from negative
/// weights (different bug entirely).
///
/// Computed only on the failure path, so the to_vec1 copy cost is
/// paid at most once per poisoned model.
#[derive(Debug)]
// The fields are only read through the derived `Debug` impl, which the
// dead-code lint does not count as a use; without this they warn as unread.
#[allow(dead_code)]
struct LogitsHealth {
    // Total number of values inspected.
    len: usize,
    // Number of NaN values.
    nan: usize,
    // Number of +Inf values.
    pos_inf: usize,
    // Number of -Inf values.
    neg_inf: usize,
    // Number of finite values strictly below zero.
    neg: usize,
    // Smallest finite value; `None` when no value was finite.
    finite_min: Option<f32>,
    // Largest finite value; `None` when no value was finite.
    finite_max: Option<f32>,
    // Mean of the finite values; `None` when no value was finite.
    finite_mean: Option<f32>,
}
|
||
|
||
#[allow(dead_code)]
|
||
fn logits_health(t: &Tensor) -> LogitsHealth {
|
||
let values: Vec<f32> = match t
|
||
.to_dtype(candle_core::DType::F32)
|
||
.and_then(|t| t.flatten_all())
|
||
.and_then(|t| t.to_vec1::<f32>())
|
||
{
|
||
Ok(v) => v,
|
||
Err(_) => {
|
||
return LogitsHealth {
|
||
len: 0,
|
||
nan: 0,
|
||
pos_inf: 0,
|
||
neg_inf: 0,
|
||
neg: 0,
|
||
finite_min: None,
|
||
finite_max: None,
|
||
finite_mean: None,
|
||
};
|
||
}
|
||
};
|
||
logits_health_slice(&values)
|
||
}
|
||
|
||
/// Same diagnostic as [`logits_health`] but operates directly on a
|
||
/// `[f32]` slice. Used by the worker-routed inference paths where the
|
||
/// device → host copy has already happened on the worker thread and
|
||
/// the async caller has the values in hand. Avoids the round-trip of
|
||
/// rebuilding a Tensor just to call to_vec1 again.
|
||
#[allow(dead_code)]
|
||
fn logits_health_slice(values: &[f32]) -> LogitsHealth {
|
||
let mut nan = 0usize;
|
||
let mut pos_inf = 0usize;
|
||
let mut neg_inf = 0usize;
|
||
let mut neg = 0usize;
|
||
let mut finite_min = f32::INFINITY;
|
||
let mut finite_max = f32::NEG_INFINITY;
|
||
let mut finite_sum = 0.0_f64;
|
||
let mut finite_count = 0usize;
|
||
for &v in values {
|
||
if v.is_nan() {
|
||
nan += 1;
|
||
} else if v == f32::INFINITY {
|
||
pos_inf += 1;
|
||
} else if v == f32::NEG_INFINITY {
|
||
neg_inf += 1;
|
||
} else {
|
||
if v < 0.0 {
|
||
neg += 1;
|
||
}
|
||
if v < finite_min {
|
||
finite_min = v;
|
||
}
|
||
if v > finite_max {
|
||
finite_max = v;
|
||
}
|
||
finite_sum += v as f64;
|
||
finite_count += 1;
|
||
}
|
||
}
|
||
let finite_mean = if finite_count > 0 {
|
||
Some((finite_sum / finite_count as f64) as f32)
|
||
} else {
|
||
None
|
||
};
|
||
LogitsHealth {
|
||
len: values.len(),
|
||
nan,
|
||
pos_inf,
|
||
neg_inf,
|
||
neg,
|
||
finite_min: (finite_count > 0).then_some(finite_min),
|
||
finite_max: (finite_count > 0).then_some(finite_max),
|
||
finite_mean,
|
||
}
|
||
}
|
||
|
||
/// Decide whether an inference failure should mark the model poisoned.
///
/// `chain_text` should be the whole error chain. That is either
/// `format!("{err:#}")` of an `anyhow::Error`, or the already-stringified
/// error on paths that flatten failures to text (e.g. the TP streaming
/// task). Matching is a case-insensitive substring search over that text.
/// The classification therefore survives `.context("…")` and
/// `format!("…: {e}")` wrappers at call sites, as long as the inner message
/// still appears in the rendered text.
///
/// The default answer is `true`. Failing to poison a genuinely corrupt
/// device context (the next request hangs or returns garbage) costs more
/// than over-poisoning (an unload + reload). Only failures known not to
/// touch device state opt out.
fn is_device_fault(chain_text: &str) -> bool {
    // Lower-case markers of failures that leave the device context healthy:
    // - shape / broadcast errors are raised before any kernel launches;
    // - "logits unhealthy" (NaN logits) is detected on the CPU after the forward;
    // - tokenize / detokenize / decode_stream are pure CPU work;
    // - missing-handle lookups and empty prompts fail before dispatch.
    // Some entries are substrings of others ("broadcast" vs "cannot broadcast",
    // "tokenize" vs "detokenize"); they are kept to document intent.
    const NON_DEVICE_MARKERS: &[&str] = &[
        "shape mismatch",
        "broadcast",
        "cannot broadcast",
        "logits unhealthy",
        "tokenize",
        "detokenize",
        "decode_stream",
        "no model for handle",
        "no tp model for handle",
        "empty prompt",
    ];
    let haystack = chain_text.to_lowercase();
    NON_DEVICE_MARKERS
        .iter()
        .all(|marker| !haystack.contains(marker))
}
|
||
|
||
/// Build the InferenceError reported to a client when their request
|
||
/// hits a model that's been marked poisoned by an earlier driver
|
||
/// failure. The message names the model and the recovery procedure so
|
||
/// the operator doesn't have to chase the original failure to know
|
||
/// what to do.
|
||
fn poisoned_error(model_id: &str) -> InferenceError {
|
||
InferenceError::Other(anyhow::anyhow!(
|
||
"model '{model_id}' is in a poisoned state \
|
||
(an earlier inference hit a CUDA driver error and the device \
|
||
context cannot be safely reused); unload and reload the model \
|
||
to recover"
|
||
))
|
||
}
|
||
|
||
/// Reported while auto-recovery (#17) is rebuilding a poisoned model's
|
||
/// context. Unlike [`poisoned_error`] this is a *transient* state — the
|
||
/// model is being reloaded automatically; the client should retry.
|
||
fn recovering_error(model_id: &str) -> InferenceError {
|
||
InferenceError::Other(anyhow::anyhow!(
|
||
"model '{model_id}' is recovering (its device context was poisoned \
|
||
by an earlier failure and is being automatically rebuilt); retry \
|
||
shortly"
|
||
))
|
||
}
|
||
|
||
/// Verification hook for #17 auto-recovery.
///
/// Returns `true` exactly once per process: for the first call whose
/// `model_id` equals the value of `NEURON_DEBUG_POISON`. The request path
/// can then trigger recovery as if a device fault had occurred, exercising
/// the unload → reload → healthy cycle without corrupting the GPU.
///
/// The process-wide latch is only consumed on a matching call. This stops
/// the hook from looping the model through endless recoveries. When the
/// variable is unset (or not valid Unicode) this always returns `false`.
fn debug_poison_armed(model_id: &str) -> bool {
    static ALREADY_FIRED: std::sync::atomic::AtomicBool =
        std::sync::atomic::AtomicBool::new(false);
    match std::env::var("NEURON_DEBUG_POISON") {
        // `swap` returns the previous value, so only the first matching
        // caller observes `false` and fires.
        Ok(target) if target == model_id => !ALREADY_FIRED.swap(true, Ordering::Relaxed),
        _ => false,
    }
}
|
||
|
||
/// Background auto-recovery task (#17). Drains poisoned model ids and
|
||
/// rebuilds each via [`CandleHarness::recover_one`]. Holds a `Weak` so a
|
||
/// shutting-down harness lets the task exit; processes one id at a time,
|
||
/// which (with the `recovering` set deduping enqueues) keeps recovery
|
||
/// single-flight per model.
|
||
async fn recovery_loop(
|
||
weak: std::sync::Weak<CandleHarness>,
|
||
mut rx: tokio::sync::mpsc::UnboundedReceiver<String>,
|
||
) {
|
||
while let Some(model_id) = rx.recv().await {
|
||
let Some(this) = weak.upgrade() else {
|
||
break;
|
||
};
|
||
this.recover_one(&model_id).await;
|
||
}
|
||
}
|
||
|
||
/// Free and total VRAM of `device`, in MiB.
///
/// Returns `(0, 0)` for a non-CUDA device, or when binding the CUDA context
/// or querying memory fails, so logging never crashes the request path.
/// Mirrors the helper in `tp_qwen3_5.rs`; kept separate so the inference
/// path does not depend on the TP-specific module.
#[cfg(feature = "cuda")]
fn device_vram_mb(device: &Device) -> (u64, u64) {
    use candle_core::cuda::cudarc::driver::result;
    use candle_core::cuda_backend::WrapErr;

    let Device::Cuda(cuda) = device else {
        return (0, 0);
    };
    // The memory query reports on the context current to this thread, so
    // bind this device's context first.
    if cuda.cuda_stream().context().bind_to_thread().w().is_err() {
        return (0, 0);
    }
    result::mem_get_info().map_or((0, 0), |(free, total)| {
        ((free / (1024 * 1024)) as u64, (total / (1024 * 1024)) as u64)
    })
}
|
||
|
||
#[cfg(not(feature = "cuda"))]
|
||
#[allow(dead_code)]
|
||
fn device_vram_mb(_device: &Device) -> (u64, u64) {
|
||
(0, 0)
|
||
}
|
||
|
||
/// A short hex tag used to group every log line emitted on behalf of
|
||
/// one chat-completion request. Six hex digits is unique enough across
|
||
/// a 4-hour journal window (24 bits ≈ 16M values, while a busy neuron
|
||
/// sees ~10³ requests/hour) and fits cleanly inside `req_id=…` in the
|
||
/// fmt subscriber's span-prefix output.
|
||
fn new_req_id() -> String {
|
||
format!("{:06x}", unix_subsec_nanos() & 0xFFFFFF)
|
||
}
|
||
|
||
/// Read a strictly positive `usize` from the environment variable `name`.
///
/// Falls back to `default` when the variable is unset, not valid Unicode,
/// unparseable, or zero. Used for runtime tuning knobs that operators
/// should be able to adjust without a recompile.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name).ok().and_then(|raw| raw.parse::<usize>().ok()) {
        Some(value) if value > 0 => value,
        _ => default,
    }
}
|
||
|
||
/// Read a `u64` from the environment variable `name`.
///
/// Falls back to `default` when the variable is unset, not valid Unicode,
/// or does not parse as a `u64`.
///
/// Unlike [`env_usize`], **zero is accepted** and returned as-is. This
/// matters for the VRAM knobs read through this helper. For example,
/// `NEURON_MIN_FREE_VRAM_MB=0` turns the minimum-free-VRAM check in
/// `validate_request` into a no-op, because no reading is below zero.
fn env_u64(name: &str, default: u64) -> u64 {
    std::env::var(name)
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok())
        .unwrap_or(default)
}
|
||
|
||
/// Prefill chunk size in tokens. The initial forward over a long prompt
|
||
/// is split into windows of this many tokens, each with a monotonically
|
||
/// growing offset, so activation memory is bounded by chunk × layers ×
|
||
/// hidden instead of prompt × layers × hidden. The default (512) keeps
|
||
/// activation peaks under ~1 GiB on a 27B Qwen-class model while
|
||
/// keeping the per-step overhead negligible vs. one big prefill.
|
||
fn prefill_chunk_tokens() -> usize {
|
||
env_usize("NEURON_PREFILL_CHUNK_TOKENS", 512)
|
||
}
|
||
|
||
/// Maximum allowed prompt length, in tokens. Requests above this are
|
||
/// rejected with [`InferenceError::PromptTooLong`] before any device
|
||
/// work — this is the explicit upper bound on context size, separate
|
||
/// from the model's `max_position_embeddings` (which can be much
|
||
/// larger than what fits in VRAM in practice).
|
||
fn max_prompt_tokens() -> usize {
|
||
env_usize("NEURON_MAX_PROMPT_TOKENS", 16384)
|
||
}
|
||
|
||
/// Minimum free VRAM (MiB) required to even attempt a prefill. Requests
|
||
/// below this are rejected with [`InferenceError::InsufficientVram`]
|
||
/// before any device work. Acts as a backstop when concurrent requests
|
||
/// have eaten the headroom; intentionally conservative — a request
|
||
/// that gets past this can still OOM, but the rejection is a clean 503
|
||
/// rather than a poisoned context.
|
||
fn min_free_vram_mb() -> u64 {
|
||
env_u64("NEURON_MIN_FREE_VRAM_MB", 1500)
|
||
}
|
||
|
||
/// Pre-flight check: reject the request if the prompt exceeds the
|
||
/// configured max, or if there isn't enough free VRAM to safely start a
|
||
/// prefill. Called from every chat_completion entry point right after
|
||
/// the VRAM query. A `prompt_len == 0` is accepted (some clients send
|
||
/// empty inputs to probe the endpoint); the prefill loop handles it.
|
||
/// Rough MiB of VRAM a vision prefill needs per 1000 prompt tokens
|
||
/// (accumulating KV cache + per-chunk activation headroom). Tunable;
|
||
/// the default is deliberately permissive so the guard rejects only
|
||
/// clearly-too-large requests, not ones the chunked prefill handles.
|
||
fn vision_prefill_mb_per_1k_tokens() -> u64 {
|
||
env_u64("NEURON_VISION_PREFILL_MB_PER_1K_TOKENS", 500)
|
||
}
|
||
|
||
/// Fixed VRAM overhead (MiB) a vision prefill reserves on top of the
|
||
/// per-token estimate — image encode buffers + one chunk's activations.
|
||
fn vision_prefill_base_mb() -> u64 {
|
||
env_u64("NEURON_VISION_PREFILL_BASE_MB", 2000)
|
||
}
|
||
|
||
/// Pre-flight check specific to vision prefills. Even with the chunked
|
||
/// prefill bounding per-step activation, the accumulating KV cache for
|
||
/// a long prompt can exhaust VRAM mid-forward — and on the TP path a
|
||
/// mid-forward OOM strands the NCCL collective (one rank dies, the other
|
||
/// hangs on the all-reduce, holding the pool lock). Reject up front with
|
||
/// a clean `InsufficientVram` when the estimated footprint exceeds free
|
||
/// VRAM, so a doomed request fails fast instead of hanging the daemon.
|
||
///
|
||
/// Heuristic and tunable (`NEURON_VISION_PREFILL_*`); the default errs
|
||
/// permissive. Skipped on the CPU sentinel (`vram_free_mb == 0`).
|
||
fn validate_vision_prefill(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> {
|
||
if vram_free_mb == 0 {
|
||
return Ok(());
|
||
}
|
||
let required_mb = vision_prefill_base_mb()
|
||
+ (prompt_len as u64).saturating_mul(vision_prefill_mb_per_1k_tokens()) / 1000;
|
||
if required_mb > vram_free_mb {
|
||
return Err(InferenceError::InsufficientVram {
|
||
free_mb: vram_free_mb,
|
||
required_mb,
|
||
});
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
fn validate_request(prompt_len: usize, vram_free_mb: u64) -> Result<(), InferenceError> {
|
||
let max = max_prompt_tokens();
|
||
if prompt_len > max {
|
||
return Err(InferenceError::PromptTooLong { prompt_len, max });
|
||
}
|
||
// VRAM check is skipped on CPU loads (vram_free_mb == 0 sentinel)
|
||
// because the (0, 0) reply from `query_vram` is also what a missing
|
||
// worker returns. The CPU path has no per-GPU memory limit anyway —
|
||
// host RAM is bounded by the OOM killer, not this check.
|
||
let min = min_free_vram_mb();
|
||
if vram_free_mb != 0 && vram_free_mb < min {
|
||
return Err(InferenceError::InsufficientVram {
|
||
free_mb: vram_free_mb,
|
||
required_mb: min,
|
||
});
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
/// How long `pool.lock().await` may block before it is worth a warning.
///
/// Healthy concurrent requests serialise behind the pool in single-digit
/// milliseconds. A wait past this (2 s) is either a huge in-flight prompt
/// or, more often, a stuck request holding the lock against a poisoned CUDA
/// context. See the 2026-05-26 4-hour silence on beast, where dozens of
/// requests piled up invisibly here.
#[cfg(feature = "cuda")]
const POOL_LOCK_WARN_THRESHOLD: Duration = Duration::from_millis(2_000);
|
||
|
||
/// Acquire the TP pool lock, emitting warn-level breadcrumbs if the wait
/// exceeds [`POOL_LOCK_WARN_THRESHOLD`].
///
/// The helper exists so the warning is emitted at the call site: the
/// request whose lock wait is slow is the one that knows its prompt_len and
/// other context.
///
/// While still waiting, a "still waiting" warning fires after the threshold
/// and then at doubling intervals (2×, 4×, … the threshold), capped at one
/// minute. A stuck request stays visible in journalctl even before the lock
/// is freed. Meanwhile a pile-up of waiters no longer floods the journal
/// with one line per waiter every threshold period, as the previous
/// fixed-interval re-arm did. Once the lock is obtained, one more warning
/// reports the total wait if it exceeded the threshold.
#[cfg(feature = "cuda")]
async fn acquire_pool_lock<'a>(
    pool: &'a tokio::sync::Mutex<super::tp::WorkerPool>,
    model_id: &str,
) -> tokio::sync::MutexGuard<'a, super::tp::WorkerPool> {
    // Upper bound on the gap between consecutive "still waiting" warnings.
    const MAX_REPORT_INTERVAL: Duration = Duration::from_secs(60);

    let start = std::time::Instant::now();
    let mut next_report = POOL_LOCK_WARN_THRESHOLD;
    // Pinned once so the same lock future (and its queue position) is
    // polled across loop iterations.
    tokio::pin! {
        let lock = pool.lock();
    }
    loop {
        tokio::select! {
            guard = &mut lock => {
                let elapsed = start.elapsed();
                if elapsed >= POOL_LOCK_WARN_THRESHOLD {
                    tracing::warn!(
                        model = %model_id,
                        waited_ms = elapsed.as_millis(),
                        "TP chat_completion: pool lock acquired after long wait"
                    );
                }
                return guard;
            }
            _ = tokio::time::sleep(next_report) => {
                tracing::warn!(
                    model = %model_id,
                    waited_ms = start.elapsed().as_millis(),
                    "TP chat_completion: still waiting on pool lock"
                );
                next_report = next_report.saturating_mul(2).min(MAX_REPORT_INTERVAL);
            }
        }
    }
}
|
||
|
||
/// Apply the repetition penalty (if any) to the prediction logits and
|
||
/// then sample. Centralises the prefill / generation-loop call sites
|
||
/// so they share identical sampling behaviour.
|
||
fn sample_with_penalty(
|
||
logits: &Tensor,
|
||
history: &[u32],
|
||
logits_processor: &mut LogitsProcessor,
|
||
) -> Result<u32> {
|
||
let penalised = if (REPEAT_PENALTY - 1.0).abs() < f32::EPSILON || history.is_empty() {
|
||
logits.clone()
|
||
} else {
|
||
let start = history.len().saturating_sub(REPEAT_LAST_N);
|
||
candle_transformers::utils::apply_repeat_penalty(logits, REPEAT_PENALTY, &history[start..])?
|
||
};
|
||
Ok(logits_processor.sample(&penalised)?)
|
||
}
|
||
|
||
/// Chunked prefill against an in-process [`ModelArch`]. Splits
|
||
/// `prompt_tokens` into [`prefill_chunk_tokens()`]-sized windows, runs
|
||
/// each through `arch.forward(chunk, offset)` with a monotonically
|
||
/// growing offset, and returns the last chunk's logits ready for
|
||
/// sampling. Bounds activation memory to O(chunk × layers × hidden)
|
||
/// instead of O(prompt × layers × hidden); the KV cache grows
|
||
/// monotonically so the model sees the full prompt at the final chunk.
|
||
fn chunked_prefill_local(
|
||
arch: &mut ModelArch,
|
||
device: &Device,
|
||
prompt_tokens: &[u32],
|
||
) -> Result<Tensor> {
|
||
let prompt_len = prompt_tokens.len();
|
||
if prompt_len == 0 {
|
||
anyhow::bail!("chunked_prefill_local: empty prompt");
|
||
}
|
||
let chunk_size = prefill_chunk_tokens();
|
||
let mut offset = 0;
|
||
let mut last_logits: Option<Tensor> = None;
|
||
while offset < prompt_len {
|
||
let end = (offset + chunk_size).min(prompt_len);
|
||
let chunk = &prompt_tokens[offset..end];
|
||
let input = Tensor::new(chunk, device)?.unsqueeze(0)?;
|
||
let logits = arch.forward(&input, offset)?;
|
||
if end == prompt_len {
|
||
last_logits = Some(logits);
|
||
}
|
||
offset = end;
|
||
}
|
||
last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_local: no chunks produced"))
|
||
}
|
||
|
||
/// Chunked prefill via the per-device worker. Same shape as
/// [`chunked_prefill_local`], but the forward runs on the worker thread,
/// which replies with a CPU-side `Vec<f32>` of logits at the final chunk's
/// last position. Tensors never escape the worker.
///
/// A failing chunk is reported as `prefill chunk {idx}/{total}: {cause}`,
/// where `idx` is 0-based. The cause is rendered with `{:#}`: if the
/// worker's error is an `anyhow::Error`, its whole context chain is then
/// kept in the message. This matters because [`is_device_fault`] classifies
/// failures by substring over that text, and a plain `{}` would keep only
/// the outermost context. A root cause such as a shape mismatch would be
/// lost, and the model needlessly poisoned.
///
/// Errors with "empty prompt" when `prompt_tokens` is empty.
#[cfg(feature = "cuda")]
async fn chunked_prefill_via_worker(
    worker: &super::device_worker::DeviceWorkerHandle,
    handle: super::device_worker::ArchHandle,
    prompt_tokens: &[u32],
) -> Result<Vec<f32>> {
    let prompt_len = prompt_tokens.len();
    if prompt_len == 0 {
        anyhow::bail!("chunked_prefill_via_worker: empty prompt");
    }
    let chunk_size = prefill_chunk_tokens();
    let total_chunks = prompt_len.div_ceil(chunk_size);
    let mut offset = 0;
    let mut last_logits: Option<Vec<f32>> = None;
    let mut chunk_idx = 0_usize;
    while offset < prompt_len {
        let end = (offset + chunk_size).min(prompt_len);
        let chunk = prompt_tokens[offset..end].to_vec();
        let chunk_len = chunk.len();
        let step_start = std::time::Instant::now();
        let logits = worker
            .forward_logits(handle, chunk, offset)
            .await
            // `{e:#}` keeps the full cause chain for `is_device_fault`.
            .map_err(|e| anyhow::anyhow!("prefill chunk {chunk_idx}/{total_chunks}: {e:#}"))?;
        tracing::debug!(
            chunk_idx,
            total_chunks,
            chunk_len,
            offset,
            elapsed_ms = step_start.elapsed().as_millis(),
            "chunked prefill (worker): chunk done"
        );
        // Only the final chunk's logits are needed for sampling.
        if end == prompt_len {
            last_logits = Some(logits);
        }
        offset = end;
        chunk_idx += 1;
    }
    last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_via_worker: no chunks produced"))
}
|
||
|
||
/// Chunked prefill via the TP `WorkerPool`. Same shape as
/// [`chunked_prefill_via_worker`], but each forward fans out to every rank
/// via `pool.generate_step`. Returns the leader's CPU-side `Vec<f32>` of
/// logits at the final chunk's last position.
///
/// A failing chunk is reported as `TP prefill chunk {idx}/{total}: {cause}`,
/// where `idx` is 0-based. The cause is rendered with `{:#}`: if the pool's
/// error is an `anyhow::Error`, its whole context chain then survives into
/// the text that [`is_device_fault`] inspects. With a plain `{}`, only the
/// outermost context would remain.
///
/// Errors with "empty prompt" when `prompt_tokens` is empty.
#[cfg(feature = "cuda")]
async fn chunked_prefill_tp(
    pool: &mut super::tp::WorkerPool,
    model_id: &str,
    leader_handle: super::device_worker::TpHandle,
    prompt_tokens: &[u32],
) -> Result<Vec<f32>> {
    let prompt_len = prompt_tokens.len();
    if prompt_len == 0 {
        anyhow::bail!("chunked_prefill_tp: empty prompt");
    }
    let chunk_size = prefill_chunk_tokens();
    let total_chunks = prompt_len.div_ceil(chunk_size);
    let mut offset = 0;
    let mut last_logits: Option<Vec<f32>> = None;
    let mut chunk_idx = 0_usize;
    while offset < prompt_len {
        let end = (offset + chunk_size).min(prompt_len);
        let chunk = prompt_tokens[offset..end].to_vec();
        let chunk_len = chunk.len();
        let step_start = std::time::Instant::now();
        let logits = pool
            .generate_step(model_id, leader_handle, chunk, offset)
            .await
            // `{e:#}` keeps the full cause chain for `is_device_fault`.
            .map_err(|e| anyhow::anyhow!("TP prefill chunk {chunk_idx}/{total_chunks}: {e:#}"))?;
        tracing::debug!(
            chunk_idx,
            total_chunks,
            chunk_len,
            offset,
            elapsed_ms = step_start.elapsed().as_millis(),
            "chunked prefill (TP): chunk done"
        );
        // Only the final chunk's logits are needed for sampling.
        if end == prompt_len {
            last_logits = Some(logits);
        }
        offset = end;
        chunk_idx += 1;
    }
    last_logits.ok_or_else(|| anyhow::anyhow!("chunked_prefill_tp: no chunks produced"))
}
|
||
|
||
/// Per-scheme source after env-var resolution.
///
/// `auth_token` is the already-read env-var value, or `None` for anonymous
/// access. `cache_dir` is the post-`resolve_hf_cache` path for the
/// huggingface scheme, and the operator's literal value for every other
/// scheme.
///
/// `Debug` is implemented by hand so the bearer token can never reach a log
/// line through `{:?}`. It renders as `Some("<redacted>")` or `None`, so
/// whether a token is present stays visible.
#[derive(Clone)]
struct ResolvedSource {
    endpoint: String,
    auth_token: Option<String>,
    cache_dir: Option<PathBuf>,
}

impl std::fmt::Debug for ResolvedSource {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedSource")
            .field("endpoint", &self.endpoint)
            // The value is a credential; only report whether one is set.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .field("cache_dir", &self.cache_dir)
            .finish()
    }
}
|
||
|
||
impl CandleHarness {
|
||
/// Construct a new harness for `bind_url` using `config`. Resolves
|
||
/// every configured source's auth env var and cache dir up front so
|
||
/// the hot load path (`hf_api_for`) is a pure HashMap lookup.
|
||
pub fn new(bind_url: String, config: &crate::config::CandleHarnessConfig) -> Arc<Self> {
|
||
let raw_sources = config.effective_sources();
|
||
let default_source = config.effective_default_source().to_string();
|
||
let mut sources = HashMap::with_capacity(raw_sources.len());
|
||
for (scheme, src) in raw_sources.into_iter() {
|
||
// Only the huggingface source gets the legacy
|
||
// HF_HUB_CACHE/HF_HOME env-var fallback chain — other
|
||
// schemes resolve to whatever the operator typed.
|
||
let cache_dir = if scheme == crate::config::DEFAULT_SOURCE_SCHEME {
|
||
resolve_hf_cache(src.cache_dir.clone())
|
||
} else {
|
||
src.cache_dir.clone()
|
||
};
|
||
let auth_token = src
|
||
.auth_env
|
||
.as_deref()
|
||
.and_then(|var| std::env::var(var).ok())
|
||
.filter(|v| !v.is_empty());
|
||
if let Some(p) = &cache_dir {
|
||
tracing::info!(
|
||
scheme = %scheme,
|
||
endpoint = %src.endpoint,
|
||
cache = %p.display(),
|
||
auth = auth_token.is_some(),
|
||
"candle harness source resolved"
|
||
);
|
||
} else {
|
||
tracing::info!(
|
||
scheme = %scheme,
|
||
endpoint = %src.endpoint,
|
||
auth = auth_token.is_some(),
|
||
"candle harness source resolved (no cache dir; using hf-hub default)"
|
||
);
|
||
}
|
||
sources.insert(
|
||
scheme,
|
||
ResolvedSource {
|
||
endpoint: src.endpoint,
|
||
auth_token,
|
||
cache_dir,
|
||
},
|
||
);
|
||
}
|
||
if !sources.contains_key(&default_source) {
|
||
tracing::warn!(
|
||
default_source,
|
||
"configured default_source has no matching [harness.candle.sources.*] entry; \
|
||
bare model ids will fail to resolve until this is fixed"
|
||
);
|
||
}
|
||
let (recovery_tx, recovery_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||
let this = Arc::new(Self {
|
||
models: Arc::new(RwLock::new(HashMap::new())),
|
||
sources,
|
||
default_source,
|
||
bind_url,
|
||
device_workers: Arc::new(RwLock::new(HashMap::new())),
|
||
recovering: Arc::new(RwLock::new(std::collections::HashSet::new())),
|
||
recovery_tx,
|
||
});
|
||
// Background auto-recovery task (#17). Holds a `Weak` so it can't
|
||
// keep the harness alive. Spawned only when a tokio runtime is
|
||
// present — sync unit tests that build a harness without one
|
||
// simply skip it (they don't exercise recovery).
|
||
if tokio::runtime::Handle::try_current().is_ok() {
|
||
let weak = Arc::downgrade(&this);
|
||
tokio::spawn(recovery_loop(weak, recovery_rx));
|
||
}
|
||
this
|
||
}
|
||
|
||
/// Scheme to substitute for bare `org/name` model ids. Mirrors the
|
||
/// effective default from the operator's config, exposed for the
|
||
/// load path's `ModelSourceId::with_default_scheme`.
|
||
pub(crate) fn default_source_scheme(&self) -> &str {
|
||
&self.default_source
|
||
}
|
||
|
||
/// Pick a candle `Device` for the requested indices. Without the
|
||
/// `cuda` feature, or if CUDA initialisation fails, falls back to CPU.
|
||
fn pick_device(devices: &[u32]) -> Result<Device> {
|
||
let _idx = devices.first().copied().unwrap_or(0) as usize;
|
||
#[cfg(feature = "cuda")]
|
||
{
|
||
match Device::new_cuda(_idx) {
|
||
Ok(d) => return Ok(d),
|
||
Err(e) => tracing::warn!(
|
||
device = _idx,
|
||
error = %e,
|
||
"CUDA device unavailable, falling back to CPU"
|
||
),
|
||
}
|
||
}
|
||
Ok(Device::Cpu)
|
||
}
|
||
|
||
/// Return the worker handle for `device_index`, spawning it on
|
||
/// first request. The handle is cached on `self` so subsequent
|
||
/// loads against the same device share the same thread. Used to
|
||
/// populate `LoadedModel::worker` and `TpLoadedModel::worker` at
|
||
/// load time; in later refactor phases the worker also owns the
|
||
/// `ModelArch` and `TpLeaderModel` slabs.
|
||
#[allow(dead_code)]
|
||
async fn ensure_device_worker(
|
||
&self,
|
||
device_index: u32,
|
||
) -> Result<Arc<super::device_worker::DeviceWorkerHandle>> {
|
||
{
|
||
let workers = self.device_workers.read().await;
|
||
if let Some(w) = workers.get(&device_index) {
|
||
return Ok(Arc::clone(w));
|
||
}
|
||
}
|
||
// Write-lock acquired separately so the read path stays cheap.
|
||
// The `get` is repeated under the write lock to handle the
|
||
// race where two loads against a fresh device land here at
|
||
// once — the second caller sees the first's insertion and
|
||
// skips the second spawn.
|
||
let mut workers = self.device_workers.write().await;
|
||
if let Some(w) = workers.get(&device_index) {
|
||
return Ok(Arc::clone(w));
|
||
}
|
||
let handle = super::device_worker::DeviceWorkerHandle::spawn(device_index)
|
||
.with_context(|| format!("spawn device worker for cuda:{device_index}"))?;
|
||
workers.insert(device_index, Arc::clone(&handle));
|
||
tracing::info!(device_index, "spawned device worker");
|
||
Ok(handle)
|
||
}
|
||
|
||
/// Build an hf-hub API client for the given scheme. The scheme
|
||
/// must be present in the operator's configured `sources` table
|
||
/// (the synth `huggingface` entry counts). Each source carries its
|
||
/// own endpoint, optional bearer token, and cache directory, so
|
||
/// the same `org/name` served by two registries cannot collide on
|
||
/// disk.
|
||
pub(crate) fn hf_api_for(&self, scheme: &str) -> Result<hf_hub::api::tokio::Api> {
|
||
let src = self.sources.get(scheme).ok_or_else(|| {
|
||
let mut configured: Vec<&str> = self.sources.keys().map(String::as_str).collect();
|
||
configured.sort();
|
||
anyhow::anyhow!(
|
||
"no source configured for scheme '{scheme}'; \
|
||
configured: {configured:?}. Add a \
|
||
[harness.candle.sources.{scheme}] block to neuron.toml \
|
||
with endpoint = '...'."
|
||
)
|
||
})?;
|
||
let mut builder = hf_hub::api::tokio::ApiBuilder::new().with_endpoint(src.endpoint.clone());
|
||
if let Some(cache) = &src.cache_dir {
|
||
builder = builder.with_cache_dir(cache.clone());
|
||
}
|
||
if let Some(token) = &src.auth_token {
|
||
builder = builder.with_token(Some(token.clone()));
|
||
}
|
||
builder
|
||
.build()
|
||
.with_context(|| format!("build hf-hub API for scheme '{scheme}'"))
|
||
}
|
||
|
||
/// Resolve a dense (bf16/fp16 safetensors) model to its local file
|
||
/// paths.
|
||
///
|
||
/// Handles both sharded repos (`model.safetensors.index.json` plus
|
||
/// several `model-*.safetensors`) and the single-file layout
|
||
/// (`model.safetensors`). Returns the safetensors paths in
|
||
/// arbitrary order — `VarBuilder` unifies them into one tensor view.
|
||
async fn resolve_dense_files(
|
||
&self,
|
||
spec: &ModelSpec,
|
||
source_id: &cortex_core::source::ModelSourceId,
|
||
) -> Result<(PathBuf, PathBuf, Vec<PathBuf>)> {
|
||
let api = self.hf_api_for(&source_id.scheme)?;
|
||
let repo = api.model(source_id.repo_path());
|
||
let display_id = source_id.to_string();
|
||
let _ = spec; // reserved for future use (quant-aware filtering)
|
||
|
||
let config_path = repo
|
||
.get("config.json")
|
||
.await
|
||
.with_context(|| format!("fetch config.json from {display_id}"))?;
|
||
let tokenizer_path = repo
|
||
.get("tokenizer.json")
|
||
.await
|
||
.with_context(|| format!("fetch tokenizer.json from {display_id}"))?;
|
||
|
||
// Prefer the sharded layout (most HF dense models > 5B ship it).
|
||
let safetensors_paths = match repo.get("model.safetensors.index.json").await {
|
||
Ok(index_path) => {
|
||
let index_text = std::fs::read_to_string(&index_path)
|
||
.context("read model.safetensors.index.json")?;
|
||
let index: serde_json::Value = serde_json::from_str(&index_text)
|
||
.context("parse model.safetensors.index.json")?;
|
||
let weight_map = index
|
||
.get("weight_map")
|
||
.and_then(|v| v.as_object())
|
||
.ok_or_else(|| {
|
||
anyhow::anyhow!("safetensors index missing weight_map object")
|
||
})?;
|
||
let unique: std::collections::BTreeSet<String> = weight_map
|
||
.values()
|
||
.filter_map(|v| v.as_str().map(String::from))
|
||
.collect();
|
||
let mut paths = Vec::with_capacity(unique.len());
|
||
for fname in unique {
|
||
let p = repo
|
||
.get(&fname)
|
||
.await
|
||
.with_context(|| format!("fetch sharded safetensors {fname}"))?;
|
||
paths.push(p);
|
||
}
|
||
paths
|
||
}
|
||
Err(_) => {
|
||
// Single-file fallback.
|
||
let p = repo
|
||
.get("model.safetensors")
|
||
.await
|
||
.context("fetch model.safetensors (single-file layout)")?;
|
||
vec![p]
|
||
}
|
||
};
|
||
Ok((config_path, tokenizer_path, safetensors_paths))
|
||
}
|
||
|
||
/// Resolve + load a GGUF (pre-quantized) Qwen3. Returns the
|
||
/// tokenizer.json path so the caller can construct the Tokenizer
|
||
/// uniformly across source formats.
|
||
async fn load_arch_gguf(
|
||
&self,
|
||
spec: &ModelSpec,
|
||
source_id: &cortex_core::source::ModelSourceId,
|
||
device: &Device,
|
||
) -> Result<(PathBuf, ModelArch)> {
|
||
let (gguf_path, tokenizer_path) = self.resolve_files(spec, source_id).await?;
|
||
let device_for_load = device.clone();
|
||
let gguf_path_for_load = gguf_path.clone();
|
||
let model_id_for_log = spec.model_id.clone();
|
||
let arch = tokio::task::spawn_blocking(move || -> Result<ModelArch> {
|
||
tracing::info!(model = %model_id_for_log, path = ?gguf_path_for_load, "loading GGUF");
|
||
let mut file = std::fs::File::open(&gguf_path_for_load).context("open GGUF file")?;
|
||
let content = gguf_file::Content::read(&mut file)
|
||
.map_err(|e| anyhow::anyhow!("parse GGUF: {e}"))?;
|
||
|
||
let architecture = content
|
||
.metadata
|
||
.get("general.architecture")
|
||
.and_then(|v| v.to_string().ok().cloned())
|
||
.unwrap_or_default();
|
||
tracing::info!(architecture = %architecture, "GGUF architecture");
|
||
|
||
// The `general.architecture` GGUF metadata key follows
|
||
// llama.cpp conventions (lowercase, no underscores in some
|
||
// cases) — `qwen3moe`, not `qwen3_moe`.
|
||
match architecture.as_str() {
|
||
"qwen3" => {
|
||
let weights =
|
||
QuantizedQwen3Weights::from_gguf(content, &mut file, &device_for_load)
|
||
.map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?;
|
||
Ok(ModelArch::Qwen3Quantized(weights))
|
||
}
|
||
"qwen3moe" => {
|
||
// GGUFQWenMoE takes an explicit compute dtype
|
||
// alongside the device — F16 matches the GGUF
|
||
// weights' typical accumulation precision and
|
||
// gives the best tokens/sec on consumer cards.
|
||
let weights =
|
||
GGUFQWenMoE::from_gguf(content, &mut file, &device_for_load, DType::F16)
|
||
.map_err(|e| anyhow::anyhow!("from_gguf qwen3_moe: {e}"))?;
|
||
Ok(ModelArch::Qwen3MoeQuantized(weights))
|
||
}
|
||
"llama" => {
|
||
let weights =
|
||
QuantizedLlamaWeights::from_gguf(content, &mut file, &device_for_load)
|
||
.map_err(|e| anyhow::anyhow!("from_gguf llama: {e}"))?;
|
||
Ok(ModelArch::LlamaQuantized(weights))
|
||
}
|
||
other => anyhow::bail!(
|
||
"unsupported GGUF architecture '{other}'; quantized path supports \
|
||
qwen3, qwen3moe, llama"
|
||
),
|
||
}
|
||
})
|
||
.await
|
||
.context("blocking GGUF load task panicked")??;
|
||
Ok((tokenizer_path, arch))
|
||
}
|
||
|
||
/// Resolve + load a dense Qwen3 from safetensors. Uses
|
||
/// `candle-transformers::models::qwen3::ModelForCausalLM` and
|
||
/// builds a VarBuilder over the mmap'd safetensors files. dtype
|
||
/// is bf16 by default to match the HF distribution dtype for
|
||
/// recent Qwen3 family models; fall back to f16 if the device
|
||
/// doesn't support bf16.
|
||
async fn load_arch_dense(
|
||
&self,
|
||
spec: &ModelSpec,
|
||
source_id: &cortex_core::source::ModelSourceId,
|
||
device: &Device,
|
||
) -> Result<(PathBuf, ModelArch)> {
|
||
let (config_path, tokenizer_path, safetensors_paths) =
|
||
self.resolve_dense_files(spec, source_id).await?;
|
||
let device_for_load = device.clone();
|
||
let model_id_for_log = spec.model_id.clone();
|
||
|
||
let arch = tokio::task::spawn_blocking(move || -> Result<ModelArch> {
|
||
let cfg_text = std::fs::read_to_string(&config_path).context("read config.json")?;
|
||
check_dense_config_supported(&cfg_text, &model_id_for_log)?;
|
||
// Peek at model_type to choose the family before the
|
||
// typed deserialize — each family has its own Config.
|
||
let model_type = serde_json::from_str::<serde_json::Value>(&cfg_text)
|
||
.ok()
|
||
.as_ref()
|
||
.and_then(|v| v.get("model_type"))
|
||
.and_then(|v| v.as_str())
|
||
.unwrap_or("")
|
||
.to_string();
|
||
tracing::info!(
|
||
model = %model_id_for_log,
|
||
model_type = %model_type,
|
||
shards = safetensors_paths.len(),
|
||
"loading dense model from safetensors"
|
||
);
|
||
|
||
// bf16 is the canonical distribution dtype for Qwen3 /
|
||
// Llama 3 / Qwen3 MoE. CUDA on Ada+ has hardware bf16;
|
||
// Ampere has it too. CPU emulates.
|
||
let dtype = DType::BF16;
|
||
// SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files;
|
||
// mutation by another process while we hold the mapping is
|
||
// UB. We trust the HF cache is immutable-by-design.
|
||
let vb = unsafe {
|
||
VarBuilder::from_mmaped_safetensors(&safetensors_paths, dtype, &device_for_load)
|
||
.context("build VarBuilder over safetensors")?
|
||
};
|
||
|
||
match model_type.as_str() {
|
||
"qwen3" => {
|
||
let cfg: qwen3_dense::Config =
|
||
serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
|
||
let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
|
||
.map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
|
||
Ok(ModelArch::Qwen3Dense(model))
|
||
}
|
||
"qwen3_moe" => {
|
||
let cfg: qwen3_moe_dense::Config =
|
||
serde_json::from_str(&cfg_text).context("parse Qwen3 MoE config.json")?;
|
||
let model = qwen3_moe_dense::ModelForCausalLM::new(&cfg, vb)
|
||
.map_err(|e| anyhow::anyhow!("build Qwen3 MoE dense model: {e}"))?;
|
||
Ok(ModelArch::Qwen3MoeDense(model))
|
||
}
|
||
"llama" => {
|
||
let cfg: llama_dense::LlamaConfig =
|
||
serde_json::from_str(&cfg_text).context("parse Llama config.json")?;
|
||
// Llama has multiple sub-variants (Llama 1 has no
|
||
// GQA; Llama 3 does). `LlamaConfig::into_config`
|
||
// resolves the right shape; the `use_flash_attn`
|
||
// arg defaults to false — the flash kernel is a
|
||
// separate feature flag and uses extra VRAM.
|
||
let config = cfg.into_config(false);
|
||
let cache = llama_dense::Cache::new(true, dtype, &config, &device_for_load)
|
||
.context("build Llama Cache")?;
|
||
let model = llama_dense::Llama::load(vb, &config)
|
||
.map_err(|e| anyhow::anyhow!("build Llama dense model: {e}"))?;
|
||
Ok(ModelArch::LlamaDense(Box::new(LlamaDense {
|
||
model,
|
||
cache,
|
||
config,
|
||
dtype,
|
||
device: device_for_load,
|
||
})))
|
||
}
|
||
"qwen3_5" => {
|
||
// Qwen3-Next needs a ShardedVarBuilder because its
|
||
// load functions use the sharded backend (so they
|
||
// can be reused unchanged by the future TP variant).
|
||
// With world_size=1 the backend falls through to
|
||
// the unsharded path, so there is no per-load cost.
|
||
let cfg: super::arch::qwen3_5::Config = serde_json::from_str(&cfg_text)
|
||
.context("parse Qwen3-Next (qwen3_5) config.json")?;
|
||
let sharded_vb = unsafe {
|
||
candle_nn::var_builder::ShardedSafeTensors::var_builder(
|
||
&safetensors_paths,
|
||
dtype,
|
||
&device_for_load,
|
||
)
|
||
.context("build ShardedVarBuilder for Qwen3-Next")?
|
||
};
|
||
let model = super::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, sharded_vb)
|
||
.context("build Qwen3-Next dense model")?;
|
||
Ok(ModelArch::Qwen3_5Dense(model))
|
||
}
|
||
other => {
|
||
// Defensive: `check_dense_config_supported` already
|
||
// gated on the supported set, so this branch is
|
||
// unreachable unless that list and the match here
|
||
// drift apart.
|
||
anyhow::bail!(
|
||
"unrouted supported model_type '{other}' — \
|
||
DENSE_SUPPORTED_MODEL_TYPES and load_arch_dense \
|
||
must stay in sync"
|
||
)
|
||
}
|
||
}
|
||
})
|
||
.await
|
||
.context("blocking dense load task panicked")??;
|
||
Ok((tokenizer_path, arch))
|
||
}
|
||
|
||
/// Resolve a model spec to local GGUF and tokenizer file paths via
|
||
/// hf-hub. Downloads on first use; subsequent calls are cached.
|
||
async fn resolve_files(
|
||
&self,
|
||
spec: &ModelSpec,
|
||
source_id: &cortex_core::source::ModelSourceId,
|
||
) -> Result<(PathBuf, PathBuf)> {
|
||
let api = self.hf_api_for(&source_id.scheme)?;
|
||
let repo_path = source_id.repo_path();
|
||
let repo = api.model(repo_path.clone());
|
||
let display_id = source_id.to_string();
|
||
|
||
let info = repo
|
||
.info()
|
||
.await
|
||
.with_context(|| format!("fetch HF repo info for {display_id}"))?;
|
||
|
||
let quant = spec.quant.as_deref().unwrap_or("");
|
||
let quant_lc = quant.to_lowercase();
|
||
let gguf_filename = info
|
||
.siblings
|
||
.iter()
|
||
.map(|s| s.rfilename.as_str())
|
||
.filter(|name| name.to_lowercase().ends_with(".gguf"))
|
||
.find(|name| quant_lc.is_empty() || name.to_lowercase().contains(&quant_lc))
|
||
.ok_or_else(|| {
|
||
anyhow::anyhow!(
|
||
"no GGUF file matching quant {:?} in repo {display_id}",
|
||
spec.quant,
|
||
)
|
||
})?
|
||
.to_string();
|
||
|
||
tracing::info!(
|
||
model = %display_id,
|
||
file = %gguf_filename,
|
||
"resolving GGUF (may be cached)"
|
||
);
|
||
let gguf_path = repo
|
||
.get(&gguf_filename)
|
||
.await
|
||
.with_context(|| format!("fetch GGUF {gguf_filename}"))?;
|
||
|
||
// GGUF-only HF repos (unsloth/Qwen3-*-GGUF, Qwen/Qwen3-*-GGUF,
|
||
// etc.) ship the .gguf file but not tokenizer.json — the
|
||
// tokenizer.json lives in the base non-GGUF repo. Derive the
|
||
// base repo id by stripping a `-GGUF` / `-gguf` suffix; if
|
||
// there's no such suffix the same repo is used (works for
|
||
// non-GGUF model_ids). Stripping happens on the repo_path
|
||
// (scheme already accounted for) so this composes cleanly with
|
||
// helexa-scheme GGUF repos too.
|
||
let tokenizer_repo_path = repo_path
|
||
.strip_suffix("-GGUF")
|
||
.or_else(|| repo_path.strip_suffix("-gguf"))
|
||
.unwrap_or(&repo_path)
|
||
.to_string();
|
||
let tokenizer_repo = if tokenizer_repo_path == repo_path {
|
||
repo
|
||
} else {
|
||
tracing::debug!(
|
||
from = %repo_path,
|
||
to = %tokenizer_repo_path,
|
||
"tokenizer.json sourced from base repo (GGUF suffix stripped)"
|
||
);
|
||
api.model(tokenizer_repo_path.clone())
|
||
};
|
||
let tokenizer_path = tokenizer_repo
|
||
.get("tokenizer.json")
|
||
.await
|
||
.with_context(|| format!("fetch tokenizer.json from {tokenizer_repo_path}"))?;
|
||
Ok((gguf_path, tokenizer_path))
|
||
}
|
||
|
||
/// Run a non-streaming chat completion against a loaded model.
|
||
///
|
||
/// Returns a typed `InferenceError` when the model isn't loaded so the
|
||
/// handler can map to an appropriate HTTP status without string-matching.
|
||
pub async fn chat_completion(
|
||
&self,
|
||
request: ChatCompletionRequest,
|
||
) -> Result<ChatCompletionResponse, InferenceError> {
|
||
let handle = {
|
||
let models = self.models.read().await;
|
||
models.get(&request.model).cloned()
|
||
};
|
||
let handle = match handle {
|
||
Some(h) => h,
|
||
// Absent from the registry: distinguish a genuinely unloaded
|
||
// model from one whose slot is briefly gone mid auto-recovery
|
||
// (#17), so the client gets a transient "retry shortly" instead
|
||
// of a misleading "not loaded".
|
||
None if self.is_recovering(&request.model).await => {
|
||
return Err(recovering_error(&request.model));
|
||
}
|
||
None => return Err(InferenceError::ModelNotLoaded(request.model.clone())),
|
||
};
|
||
// The match is technically infallible without `cuda` (only Single
|
||
// exists), but the cfg-gated Tp arm makes this the right shape
|
||
// under both feature flags.
|
||
#[allow(clippy::infallible_destructuring_match)]
|
||
let loaded = match handle {
|
||
LoadedHandle::Single(m) => m,
|
||
#[cfg(feature = "cuda")]
|
||
LoadedHandle::Tp(m) => {
|
||
return self.chat_completion_tp(m, request).await;
|
||
}
|
||
};
|
||
|
||
// Span every line of this request with a short req_id +
|
||
// model so `grep req_id=…` over the journal can reconstruct
|
||
// one request even when dozens overlap. Add a terminal log
|
||
// line on both success and failure — the single-GPU path
|
||
// used to log nothing on either side, so a failing request
|
||
// looked exactly like an idle neuron.
|
||
let req_id = new_req_id();
|
||
let model_id = request.model.clone();
|
||
let span = tracing::info_span!("chat", req_id = %req_id, model = %model_id);
|
||
let req_start = std::time::Instant::now();
|
||
|
||
// Refuse the request up front if a prior inference poisoned
|
||
// the device context — otherwise we hand the doomed forward
|
||
// off to spawn_blocking and stall waiting for CUDA to fail.
|
||
if loaded.poisoned.load(Ordering::Acquire) {
|
||
let _g = span.enter();
|
||
tracing::warn!("chat_completion: refusing request, model poisoned");
|
||
return Err(self.trigger_recovery(&model_id).await);
|
||
}
|
||
if debug_poison_armed(&model_id) {
|
||
let _g = span.enter();
|
||
tracing::warn!("NEURON_DEBUG_POISON: forcing auto-recovery (#17 verification)");
|
||
return Err(self.trigger_recovery(&model_id).await);
|
||
}
|
||
|
||
// Serialise concurrent requests against this model. Holds for
|
||
// the duration of clear_kv_cache → prefill → decode so two
|
||
// requests' chunked-prefill sequences can't interleave on the
|
||
// shared KV cache (see `LoadedModel.inference_lock` for the
|
||
// observed failure mode).
|
||
let _inference_guard = loaded.inference_lock.lock().await;
|
||
|
||
let result = async {
|
||
let prompt = build_prompt_for_request(loaded.chat_template.as_deref(), &request);
|
||
|
||
let encoding = loaded
|
||
.tokenizer
|
||
.encode(prompt.as_str(), true)
|
||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
|
||
let mut prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
|
||
|
||
// Stage B: when the request carries images, preprocess
|
||
// them, expand each `<|image_pad|>` sentinel to N copies
|
||
// matching the per-image patch count, and route to the
|
||
// vision-aware worker path. Non-image requests skip all
|
||
// of this and follow the existing text-only flow.
|
||
let vision_route = if request_has_images(&request) {
|
||
// Stage B6: surface a structured `vision_unsupported`
|
||
// rejection when the request asks for vision against a
|
||
// text-only model. Cheap and stops the issue-#3 silent-
|
||
// drop pattern.
|
||
if !loaded.has_vision {
|
||
return Err(InferenceError::VisionUnsupported {
|
||
model_id: request.model.clone(),
|
||
});
|
||
}
|
||
let image_token_id = loaded
|
||
.image_token_id
|
||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||
model_id: request.model.clone(),
|
||
})?;
|
||
let factor = loaded.image_grid_factor.ok_or_else(|| {
|
||
InferenceError::VisionUnsupported {
|
||
model_id: request.model.clone(),
|
||
}
|
||
})?;
|
||
let profile = super::preprocess::PreprocessProfile::qwen3_6();
|
||
let images = extract_images_from_request(&request, &profile).map_err(|e| {
|
||
InferenceError::Other(anyhow::anyhow!("extract_images: {e}"))
|
||
})?;
|
||
if images.is_empty() {
|
||
// request_has_images said true but extract returned
|
||
// empty — defensive bail rather than silently dropping.
|
||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||
"request has image content but extractor produced zero images"
|
||
)));
|
||
}
|
||
// Per-image LM token count from each image's resized grid
|
||
// (#14 dynamic resolution; was a constant 196).
|
||
let per_image_counts: Vec<usize> = images
|
||
.iter()
|
||
.map(|im| (im.h / factor) * (im.w / factor))
|
||
.collect();
|
||
prompt_tokens =
|
||
expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
|
||
.map_err(InferenceError::Other)?;
|
||
Some((images, image_token_id))
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let prompt_len = prompt_tokens.len();
|
||
let temperature = request.temperature.unwrap_or(0.7);
|
||
let top_p = request.top_p;
|
||
let max_new = request.max_tokens.unwrap_or(8192) as usize;
|
||
let seed = unix_subsec_nanos();
|
||
|
||
let eos_id = loaded
|
||
.tokenizer
|
||
.token_to_id("<|im_end|>")
|
||
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
|
||
|
||
let (vram_free_mb, vram_total_mb) = loaded.query_vram().await;
|
||
tracing::info!(
|
||
prompt_len,
|
||
max_new,
|
||
temperature,
|
||
?top_p,
|
||
?eos_id,
|
||
vram_free_mb,
|
||
vram_total_mb,
|
||
vision = vision_route.is_some(),
|
||
"chat_completion: starting"
|
||
);
|
||
|
||
validate_request(prompt_len, vram_free_mb)?;
|
||
if vision_route.is_some() {
|
||
validate_vision_prefill(prompt_len, vram_free_mb)?;
|
||
}
|
||
if vision_route.is_some() {
|
||
validate_vision_prefill(prompt_len, vram_free_mb)?;
|
||
}
|
||
|
||
// Routing: CUDA loads go through the per-device worker
|
||
// thread (introduced in Phase 1; forward/clear added in
|
||
// Phase 2). CPU loads keep the existing spawn_blocking
|
||
// path because there's no context to own and the channel
|
||
// round-trip would only add latency. The two arms produce
|
||
// the same `(Vec<u32>, String)` shape so the rest of the
|
||
// path is shared.
|
||
let (generated_ids, finish_reason) = if let (Some(worker), Some(handle)) =
|
||
(loaded.worker.as_ref(), loaded.arch_handle)
|
||
{
|
||
// Worker path (CUDA).
|
||
#[cfg(feature = "cuda")]
|
||
{
|
||
let result = match &vision_route {
|
||
Some((images, image_token_id)) => {
|
||
run_inference_with_images_via_worker(
|
||
worker,
|
||
handle,
|
||
&prompt_tokens,
|
||
images.clone(),
|
||
*image_token_id,
|
||
max_new,
|
||
temperature,
|
||
top_p,
|
||
seed,
|
||
eos_id,
|
||
)
|
||
.await
|
||
}
|
||
None => {
|
||
run_inference_via_worker(
|
||
worker,
|
||
handle,
|
||
&prompt_tokens,
|
||
max_new,
|
||
temperature,
|
||
top_p,
|
||
seed,
|
||
eos_id,
|
||
)
|
||
.await
|
||
}
|
||
};
|
||
match result {
|
||
Ok(v) => v,
|
||
Err(e) => {
|
||
let chain = format!("{e:#}");
|
||
if is_device_fault(&chain) {
|
||
loaded.poisoned.store(true, Ordering::Release);
|
||
tracing::warn!(
|
||
error = %chain,
|
||
"chat_completion: failed with device fault, model marked poisoned"
|
||
);
|
||
} else {
|
||
tracing::warn!(
|
||
error = %chain,
|
||
"chat_completion: failed (non-device fault); model NOT marked poisoned"
|
||
);
|
||
}
|
||
return Err(InferenceError::Other(e));
|
||
}
|
||
}
|
||
}
|
||
#[cfg(not(feature = "cuda"))]
|
||
{
|
||
// Can't happen: `loaded.worker` is only Some on
|
||
// CUDA builds. The dead branch keeps the no-cuda
|
||
// build well-typed.
|
||
let _ = (worker, handle);
|
||
unreachable!("worker handle present without cuda feature");
|
||
}
|
||
} else if let Some(arch_arc) = loaded.arch.clone() {
|
||
// CPU path: existing spawn_blocking on the local
|
||
// Arc<Mutex<ModelArch>>.
|
||
let device = loaded.device.clone();
|
||
let inference_result =
|
||
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
|
||
let mut guard = arch_arc.blocking_lock();
|
||
run_inference(
|
||
&mut guard,
|
||
&device,
|
||
&prompt_tokens,
|
||
max_new,
|
||
temperature,
|
||
top_p,
|
||
seed,
|
||
eos_id,
|
||
)
|
||
})
|
||
.await;
|
||
|
||
// Distinguish "inference returned Err" (almost always a
|
||
// candle/CUDA failure that propagated through `?`, e.g.
|
||
// an OOM or driver error — the context is unreliable,
|
||
// poison the model) from "spawn_blocking task panicked
|
||
// or was cancelled" (a Rust-level panic in the closure,
|
||
// not a device fault; failing the one request without
|
||
// tearing down the model for everyone else is correct).
|
||
match inference_result {
|
||
Ok(Ok(v)) => v,
|
||
Ok(Err(e)) => {
|
||
let chain = format!("{e:#}");
|
||
if is_device_fault(&chain) {
|
||
loaded.poisoned.store(true, Ordering::Release);
|
||
tracing::warn!(
|
||
error = %chain,
|
||
"chat_completion: failed with device fault, model marked poisoned"
|
||
);
|
||
} else {
|
||
tracing::warn!(
|
||
error = %chain,
|
||
"chat_completion: failed (non-device fault); model NOT marked poisoned"
|
||
);
|
||
}
|
||
return Err(InferenceError::Other(e));
|
||
}
|
||
Err(join_err) => {
|
||
let cause = if join_err.is_panic() {
|
||
"panicked"
|
||
} else if join_err.is_cancelled() {
|
||
"was cancelled"
|
||
} else {
|
||
"ended abnormally"
|
||
};
|
||
tracing::error!(
|
||
cause,
|
||
error = %join_err,
|
||
"chat_completion: inference task {cause}; model NOT marked poisoned"
|
||
);
|
||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||
"inference task {cause}: {join_err}"
|
||
)));
|
||
}
|
||
}
|
||
} else {
|
||
// LoadedModel invariant: exactly one of `worker` /
|
||
// `arch` is Some. Reaching here is a construction bug.
|
||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||
"LoadedModel has neither worker handle nor local arch — load-path bug"
|
||
)));
|
||
};
|
||
|
||
let completion_text = loaded
|
||
.tokenizer
|
||
.decode(&generated_ids, true)
|
||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
|
||
|
||
let usage = Usage {
|
||
prompt_tokens: prompt_len as u64,
|
||
completion_tokens: generated_ids.len() as u64,
|
||
total_tokens: (prompt_len + generated_ids.len()) as u64,
|
||
};
|
||
|
||
tracing::info!(
|
||
prompt_tokens = prompt_len,
|
||
completion_tokens = generated_ids.len(),
|
||
finish_reason = %finish_reason,
|
||
total_ms = req_start.elapsed().as_millis(),
|
||
"chat_completion: done"
|
||
);
|
||
|
||
Ok::<_, InferenceError>(ChatCompletionResponse {
|
||
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
|
||
object: "chat.completion".into(),
|
||
created: unix_now_secs(),
|
||
model: request.model.clone(),
|
||
choices: vec![ChatCompletionChoice {
|
||
index: 0,
|
||
message: ChatMessage {
|
||
role: "assistant".into(),
|
||
content: MessageContent::Text(completion_text),
|
||
extra: serde_json::Value::Object(Default::default()),
|
||
},
|
||
finish_reason: Some(finish_reason),
|
||
extra: serde_json::Value::Object(Default::default()),
|
||
}],
|
||
usage: Some(usage),
|
||
extra: serde_json::Value::Object(Default::default()),
|
||
})
|
||
}
|
||
.instrument(span.clone())
|
||
.await;
|
||
|
||
if let Err(ref e) = result {
|
||
let _g = span.enter();
|
||
tracing::error!(
|
||
error = %format!("{e:#}"),
|
||
total_ms = req_start.elapsed().as_millis(),
|
||
"chat_completion: failed"
|
||
);
|
||
}
|
||
result
|
||
}
|
||
|
||
/// Run a streaming chat completion against a loaded model.
|
||
///
|
||
/// Returns an `mpsc::Receiver` that yields `ChatCompletionChunk`s in
|
||
/// OpenAI SSE format. The first chunk carries the assistant role;
|
||
/// subsequent chunks carry incremental `content` deltas; the final
|
||
/// chunk carries `finish_reason`. The handler is responsible for
|
||
/// wrapping these into an SSE response and appending the `[DONE]`
|
||
/// terminator.
|
||
///
|
||
/// Token-by-token decoding tracks the cumulative decoded prefix so
|
||
/// BPE byte-fallback boundaries don't split a UTF-8 char across
|
||
/// chunks.
|
||
pub async fn chat_completion_stream(
|
||
&self,
|
||
request: ChatCompletionRequest,
|
||
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
|
||
self.chat_completion_stream_with(request, wire_chat::ChatProjectionConfig::default())
|
||
.await
|
||
}
|
||
|
||
/// Same as [`Self::chat_completion_stream`] but lets the caller
|
||
/// pick the projection config — currently used by the HTTP
|
||
/// handler to thread `x-include-thinking` from the request
|
||
/// headers into [`wire_chat::ChatProjectionConfig::include_thinking`].
|
||
pub async fn chat_completion_stream_with(
|
||
&self,
|
||
request: ChatCompletionRequest,
|
||
mut config: wire_chat::ChatProjectionConfig,
|
||
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
|
||
let stream = self.inference_stream(request).await?;
|
||
// Fill in the model's reasoning markers if the caller
|
||
// didn't pre-populate them — they're a property of the
|
||
// loaded model (which the HTTP handler doesn't reach into
|
||
// directly), not of the request.
|
||
if config.reasoning_markers.is_none() {
|
||
config.reasoning_markers = stream.reasoning_markers.clone();
|
||
}
|
||
Ok(wire_chat::project_chat_stream_with(
|
||
stream.events,
|
||
stream.id,
|
||
stream.created,
|
||
stream.model_id,
|
||
config,
|
||
))
|
||
}
|
||
|
||
/// Streaming OpenAI Responses API entry point. Same harness
|
||
/// output as [`Self::chat_completion_stream`], projected into
|
||
/// the named-event SSE frames the Responses API client wants.
|
||
/// `response_id` and `message_item_id` are stamped into every
|
||
/// frame so the consumer can correlate.
|
||
pub async fn responses_stream(
|
||
&self,
|
||
request: ChatCompletionRequest,
|
||
response_id: String,
|
||
message_item_id: String,
|
||
) -> Result<mpsc::Receiver<crate::wire::openai_responses::ResponseStreamFrame>, InferenceError>
|
||
{
|
||
let stream = self.inference_stream(request).await?;
|
||
let meta = crate::wire::openai_responses::ResponseMeta {
|
||
response_id,
|
||
created_at: stream.created,
|
||
model_id: stream.model_id,
|
||
message_item_id,
|
||
};
|
||
Ok(crate::wire::openai_responses::project_responses_stream(
|
||
stream.events,
|
||
meta,
|
||
))
|
||
}
|
||
|
||
/// Format-agnostic streaming inference. Returns the raw
|
||
/// [`InferenceEvent`] receiver plus the per-request metadata
|
||
/// wire projectors stamp onto their frames. Lets every wire
|
||
/// format land on the same harness output without duplicating
|
||
/// setup / dispatch / spawn logic.
|
||
async fn inference_stream(
|
||
&self,
|
||
request: ChatCompletionRequest,
|
||
) -> Result<InferenceStream, InferenceError> {
|
||
let handle = {
|
||
let models = self.models.read().await;
|
||
models.get(&request.model).cloned()
|
||
};
|
||
let handle = match handle {
|
||
Some(h) => h,
|
||
// Absent from the registry: distinguish a genuinely unloaded
|
||
// model from one whose slot is briefly gone mid auto-recovery
|
||
// (#17), so the client gets a transient "retry shortly" instead
|
||
// of a misleading "not loaded".
|
||
None if self.is_recovering(&request.model).await => {
|
||
return Err(recovering_error(&request.model));
|
||
}
|
||
None => return Err(InferenceError::ModelNotLoaded(request.model.clone())),
|
||
};
|
||
// The match is technically infallible without `cuda` (only Single
|
||
// exists), but the cfg-gated Tp arm makes this the right shape
|
||
// under both feature flags.
|
||
#[allow(clippy::infallible_destructuring_match)]
|
||
let loaded = match handle {
|
||
LoadedHandle::Single(m) => m,
|
||
#[cfg(feature = "cuda")]
|
||
LoadedHandle::Tp(m) => {
|
||
return self.inference_tp_stream(m, request).await;
|
||
}
|
||
};
|
||
|
||
let prompt = build_prompt_for_request(loaded.chat_template.as_deref(), &request);
|
||
let encoding = loaded
|
||
.tokenizer
|
||
.encode(prompt.as_str(), true)
|
||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
|
||
let mut prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
|
||
|
||
// Stage C1: vision routing for the streaming path. Mirrors the
|
||
// non-streaming `chat_completion` block — detect image content,
|
||
// reject it against text-only models, preprocess each image and
|
||
// expand its `<|image_pad|>` sentinel to the per-image patch
|
||
// count, then carry the payload through to a single-shot
|
||
// image-spliced prefill. Non-image requests skip all of this.
|
||
// Returning early here (before the `Start` event below) keeps a
|
||
// rejected vision request from opening a half-formed SSE stream.
|
||
let vision_route: Option<(Vec<super::device_worker::jobs::ImageInput>, u32)> =
|
||
if request_has_images(&request) {
|
||
if !loaded.has_vision {
|
||
return Err(InferenceError::VisionUnsupported {
|
||
model_id: request.model.clone(),
|
||
});
|
||
}
|
||
let image_token_id =
|
||
loaded
|
||
.image_token_id
|
||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||
model_id: request.model.clone(),
|
||
})?;
|
||
let factor =
|
||
loaded
|
||
.image_grid_factor
|
||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||
model_id: request.model.clone(),
|
||
})?;
|
||
let profile = super::preprocess::PreprocessProfile::qwen3_6();
|
||
let images = extract_images_from_request(&request, &profile)
|
||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("extract_images: {e}")))?;
|
||
if images.is_empty() {
|
||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||
"request has image content but extractor produced zero images"
|
||
)));
|
||
}
|
||
// Per-image LM token count from each image's resized grid (#14).
|
||
let per_image_counts: Vec<usize> = images
|
||
.iter()
|
||
.map(|im| (im.h / factor) * (im.w / factor))
|
||
.collect();
|
||
prompt_tokens =
|
||
expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
|
||
.map_err(InferenceError::Other)?;
|
||
Some((images, image_token_id))
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let temperature = request.temperature.unwrap_or(0.7);
|
||
let top_p = request.top_p;
|
||
let max_new = request.max_tokens.unwrap_or(8192) as usize;
|
||
let seed = unix_subsec_nanos();
|
||
|
||
let eos_id = loaded
|
||
.tokenizer
|
||
.token_to_id("<|im_end|>")
|
||
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
|
||
|
||
let device = loaded.device.clone();
|
||
let tokenizer = loaded.tokenizer.clone();
|
||
let model_id = request.model.clone();
|
||
let id = format!("chatcmpl-{:x}", unix_subsec_nanos());
|
||
let created = unix_now_secs();
|
||
|
||
// Bounded channel so the producer (blocking inference) is back-
|
||
// pressured by the consumer (SSE writer, via the wire
|
||
// projector). 32 is generous — tokens arrive one at a time
|
||
// and downstream consumption is async.
|
||
let (tx, event_rx) = mpsc::channel::<InferenceEvent>(32);
|
||
|
||
// Refuse if the model is already poisoned. No point opening
|
||
// an SSE stream just to send the Start event and then bail.
|
||
if loaded.poisoned.load(Ordering::Acquire) {
|
||
return Err(self.trigger_recovery(&model_id).await);
|
||
}
|
||
|
||
// Start event: tells the wire projector to emit its
|
||
// format-specific "the assistant is about to speak" frame
|
||
// (an OpenAI `delta: {role: "assistant"}` chunk here; a
|
||
// `response.created` + `response.output_item.added` pair on
|
||
// the Responses path). If sending fails the receiver is
|
||
// already gone; bail before kicking off the heavy work.
|
||
tx.send(InferenceEvent::Start)
|
||
.await
|
||
.map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?;
|
||
|
||
// Span context — spawn_blocking detaches from the async
|
||
// executor so we capture the span explicitly and re-enter it
|
||
// inside the closure to keep the req_id on every emitted line.
|
||
let req_id = new_req_id();
|
||
let span = tracing::info_span!("chat_stream", req_id = %req_id, model = %model_id);
|
||
let prompt_len = prompt_tokens.len();
|
||
let req_start = std::time::Instant::now();
|
||
// Cloned `Arc<LoadedModel>` so the spawned task can mark the
|
||
// model poisoned if its forward fails.
|
||
let loaded_for_task = Arc::clone(&loaded);
|
||
let span_for_starting = span.clone();
|
||
let span_for_task = span.clone();
|
||
// Query VRAM before entering the span so we don't await inside
|
||
// an entered guard (Span::enter creates a synchronous guard
|
||
// that can't span await points). The span gets entered in a
|
||
// separate scope below purely for the log emission.
|
||
let (vram_free_mb, vram_total_mb) = loaded.query_vram().await;
|
||
{
|
||
let _g = span_for_starting.enter();
|
||
tracing::info!(
|
||
prompt_len,
|
||
max_new,
|
||
temperature,
|
||
?top_p,
|
||
?eos_id,
|
||
vram_free_mb,
|
||
vram_total_mb,
|
||
vision = vision_route.is_some(),
|
||
"chat_completion (stream): starting"
|
||
);
|
||
}
|
||
|
||
validate_request(prompt_len, vram_free_mb)?;
|
||
if vision_route.is_some() {
|
||
validate_vision_prefill(prompt_len, vram_free_mb)?;
|
||
}
|
||
|
||
// Routing parallel to the non-streaming chat_completion: CUDA
|
||
// goes through the worker (async task), CPU keeps the
|
||
// spawn_blocking + Arc<Mutex<ModelArch>> path. Both branches
|
||
// acquire `loaded.inference_lock` from inside the spawned
|
||
// task so concurrent stream requests against the same model
|
||
// serialise at the request boundary (preventing the
|
||
// chunked-prefill KV-cache interleave failure mode). The
|
||
// role chunk was already sent above, so the client sees
|
||
// immediate "stream open" feedback even when this request
|
||
// queues behind another for the lock.
|
||
if let (Some(worker), Some(handle)) = (loaded.worker.clone(), loaded.arch_handle) {
|
||
#[cfg(feature = "cuda")]
|
||
{
|
||
let prompt_tokens = prompt_tokens.clone();
|
||
let reasoning_tokens_inner = loaded.reasoning_tokens.clone();
|
||
let tool_call_tokens_inner = loaded.tool_call_tokens.clone();
|
||
tokio::spawn(
|
||
async move {
|
||
let _inference_guard = loaded_for_task.inference_lock.lock().await;
|
||
match stream_inference_via_worker(
|
||
worker,
|
||
handle,
|
||
tokenizer,
|
||
prompt_tokens,
|
||
vision_route,
|
||
max_new,
|
||
temperature,
|
||
top_p,
|
||
seed,
|
||
eos_id,
|
||
reasoning_tokens_inner,
|
||
tool_call_tokens_inner,
|
||
tx,
|
||
)
|
||
.await
|
||
{
|
||
Ok(_finish_reason) => tracing::info!(
|
||
prompt_tokens = prompt_len,
|
||
total_ms = req_start.elapsed().as_millis(),
|
||
"chat_completion (stream): done"
|
||
),
|
||
Err(e) => {
|
||
let chain = format!("{e:#}");
|
||
if is_device_fault(&chain) {
|
||
loaded_for_task.poisoned.store(true, Ordering::Release);
|
||
tracing::error!(
|
||
error = %chain,
|
||
prompt_tokens = prompt_len,
|
||
total_ms = req_start.elapsed().as_millis(),
|
||
"chat_completion (stream): failed with device fault, model marked poisoned"
|
||
);
|
||
} else {
|
||
tracing::error!(
|
||
error = %chain,
|
||
prompt_tokens = prompt_len,
|
||
total_ms = req_start.elapsed().as_millis(),
|
||
"chat_completion (stream): failed (non-device fault); model NOT marked poisoned"
|
||
);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
.instrument(span_for_task),
|
||
);
|
||
}
|
||
#[cfg(not(feature = "cuda"))]
|
||
{
|
||
let _ = (worker, handle, span_for_task);
|
||
unreachable!("worker handle present without cuda feature");
|
||
}
|
||
} else if let Some(arch_arc) = loaded.arch.clone() {
|
||
let reasoning_tokens_inner = loaded.reasoning_tokens.clone();
|
||
let tool_call_tokens_inner = loaded.tool_call_tokens.clone();
|
||
tokio::task::spawn_blocking(move || {
|
||
let _g = span_for_task.enter();
|
||
// `blocking_lock` is safe here: spawn_blocking runs on
|
||
// a dedicated thread, not on the async runtime, so
|
||
// there's no executor to stall.
|
||
let _inference_guard = loaded_for_task.inference_lock.blocking_lock();
|
||
let mut guard = arch_arc.blocking_lock();
|
||
match run_inference_streaming(
|
||
&mut guard,
|
||
&device,
|
||
&tokenizer,
|
||
&prompt_tokens,
|
||
max_new,
|
||
temperature,
|
||
top_p,
|
||
seed,
|
||
eos_id,
|
||
reasoning_tokens_inner.as_ref(),
|
||
tool_call_tokens_inner.as_ref(),
|
||
&tx,
|
||
) {
|
||
Ok(()) => tracing::info!(
|
||
prompt_tokens = prompt_len,
|
||
total_ms = req_start.elapsed().as_millis(),
|
||
"chat_completion (stream): done"
|
||
),
|
||
Err(e) => {
|
||
let chain = format!("{e:#}");
|
||
if is_device_fault(&chain) {
|
||
loaded_for_task.poisoned.store(true, Ordering::Release);
|
||
tracing::error!(
|
||
error = %chain,
|
||
prompt_tokens = prompt_len,
|
||
total_ms = req_start.elapsed().as_millis(),
|
||
"chat_completion (stream): failed with device fault, model marked poisoned"
|
||
);
|
||
} else {
|
||
tracing::error!(
|
||
error = %chain,
|
||
prompt_tokens = prompt_len,
|
||
total_ms = req_start.elapsed().as_millis(),
|
||
"chat_completion (stream): failed (non-device fault); model NOT marked poisoned"
|
||
);
|
||
}
|
||
}
|
||
}
|
||
});
|
||
} else {
|
||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||
"LoadedModel has neither worker handle nor local arch — load-path bug"
|
||
)));
|
||
}
|
||
|
||
// Hand the raw event channel back to the public entry
|
||
// points (chat_completion_stream / responses_stream); they
|
||
// pick the wire projection.
|
||
let reasoning_markers = loaded.reasoning_tokens.clone();
|
||
Ok(InferenceStream {
|
||
events: event_rx,
|
||
id,
|
||
created,
|
||
model_id,
|
||
reasoning_markers,
|
||
})
|
||
}
|
||
}
|
||
|
||
/// The seam between inference (one shape, always) and wire formats
|
||
/// (many shapes, projector-per-format). Public so the format
|
||
/// projectors live outside the harness and the harness's
|
||
/// streaming-inference internals stay encapsulated.
|
||
pub struct InferenceStream {
|
||
/// Stream of model-output events. Producers (the various
|
||
/// inference loops) emit on this; consumers (wire projectors)
|
||
/// read from it.
|
||
pub events: mpsc::Receiver<InferenceEvent>,
|
||
/// Request id stamped into every wire-format frame
|
||
/// (`chatcmpl-…` for chat completions; the Responses path
|
||
/// makes its own `resp_…` id separately and ignores this one).
|
||
pub id: String,
|
||
/// Unix seconds when inference began. Same field threads into
|
||
/// every wire format's `created` / `created_at` slot.
|
||
pub created: u64,
|
||
/// Local model id (no endpoint prefix). Stamped into every
|
||
/// wire-format frame so consumers can correlate.
|
||
pub model_id: String,
|
||
/// Open/close reasoning marker text (and token ids) for the
|
||
/// loaded model, or `None` for non-reasoning models. Used by
|
||
/// the chat-completions projector when `include_thinking` is
|
||
/// set — the projector re-wraps reasoning content with the
|
||
/// literal markers so client-side parsers (helexa-acp's
|
||
/// `ThinkParser`) see the original on-the-wire shape.
|
||
pub reasoning_markers: Option<ReasoningTokenPair>,
|
||
}
|
||
|
||
/// Auto-recovery (#17) — rebuild a poisoned model's device context
|
||
/// automatically instead of leaving it bricked until a human reloads.
|
||
impl CandleHarness {
|
||
/// True while `model_id` is being auto-recovered (its slot is briefly
|
||
/// absent from the registry during the reload).
|
||
pub async fn is_recovering(&self, model_id: &str) -> bool {
|
||
self.recovering.read().await.contains(model_id)
|
||
}
|
||
|
||
/// Single-flight trigger from the request path: enqueue a rebuild for a
|
||
/// poisoned model (only the first caller per model enqueues) and return
|
||
/// the transient "recovering" error to hand back to the client.
|
||
async fn trigger_recovery(&self, model_id: &str) -> InferenceError {
|
||
let newly = self.recovering.write().await.insert(model_id.to_string());
|
||
if newly {
|
||
tracing::warn!(model = %model_id, "auto-recovery: poisoned, enqueueing rebuild");
|
||
if self.recovery_tx.send(model_id.to_string()).is_err() {
|
||
// Background task gone (harness shutting down). Drop the
|
||
// marker and fall back to the manual-reload message.
|
||
self.recovering.write().await.remove(model_id);
|
||
tracing::error!(model = %model_id, "auto-recovery: task unavailable");
|
||
return poisoned_error(model_id);
|
||
}
|
||
}
|
||
recovering_error(model_id)
|
||
}
|
||
|
||
/// Rebuild a poisoned model: `unload_model` (drops it → cudarc aborts
|
||
/// NCCL + releases the context) then `load_model` from the retained
|
||
/// spec. A successful reload re-runs NCCL init + sanity inside the load
|
||
/// path, so it returns a fresh, healthy model; a failed reload leaves
|
||
/// the model unloaded (recoverable by the next load), never poisoned
|
||
/// forever. Runs on the background task — never inline on the request
|
||
/// path (would deadlock on the `models` write lock).
|
||
async fn recover_one(&self, model_id: &str) {
|
||
let spec = {
|
||
let models = self.models.read().await;
|
||
models.get(model_id).map(|h| h.spec().clone())
|
||
};
|
||
let Some(spec) = spec else {
|
||
self.recovering.write().await.remove(model_id);
|
||
return;
|
||
};
|
||
tracing::warn!(model = %model_id, "auto-recovery: unload+reload starting");
|
||
if let Err(e) = self.unload_model(model_id).await {
|
||
tracing::error!(
|
||
model = %model_id,
|
||
error = %format!("{e:#}"),
|
||
"auto-recovery: unload failed (continuing to reload)"
|
||
);
|
||
}
|
||
match self.load_model(&spec).await {
|
||
Ok(()) => tracing::info!(model = %model_id, "auto-recovery: reloaded; model healthy"),
|
||
Err(e) => tracing::error!(
|
||
model = %model_id,
|
||
error = %format!("{e:#}"),
|
||
"auto-recovery: reload failed; model left unloaded"
|
||
),
|
||
}
|
||
self.recovering.write().await.remove(model_id);
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl Harness for CandleHarness {
|
||
fn name(&self) -> &str {
|
||
"candle"
|
||
}
|
||
|
||
async fn health(&self) -> HarnessHealth {
|
||
HarnessHealth {
|
||
name: "candle".into(),
|
||
running: true,
|
||
uptime_secs: None,
|
||
}
|
||
}
|
||
|
||
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
||
let models = self.models.read().await;
|
||
Ok(models
|
||
.values()
|
||
.map(|h| ModelInfo {
|
||
id: h.model_id().into(),
|
||
harness: "candle".into(),
|
||
status: if h.is_poisoned() {
|
||
"poisoned".into()
|
||
} else {
|
||
"loaded".into()
|
||
},
|
||
devices: h.devices(),
|
||
vram_used_mb: None,
|
||
capabilities: h.capabilities(),
|
||
})
|
||
.collect())
|
||
}
|
||
|
||
async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
|
||
if spec.harness != "candle" {
|
||
anyhow::bail!("expected harness=candle, got harness={}", spec.harness);
|
||
}
|
||
|
||
{
|
||
let models = self.models.read().await;
|
||
if models.contains_key(&spec.model_id) {
|
||
anyhow::bail!("model '{}' already loaded", spec.model_id);
|
||
}
|
||
}
|
||
|
||
// Parse the model id, substituting the harness's default
|
||
// source for bare `org/name` entries so existing operator
|
||
// configs keep working unchanged. Stored on the request-local
|
||
// path so downstream resolve_* can ask the right registry.
|
||
let source_id = spec
|
||
.model_id
|
||
.parse::<cortex_core::source::ModelSourceId>()
|
||
.with_context(|| format!("parse model id '{}' as scheme:org/name", spec.model_id))?
|
||
.with_default_scheme(self.default_source_scheme());
|
||
|
||
// Preflight: classify the source repo and apply the
|
||
// tp/quant/source feasibility table before any device
|
||
// allocation, NCCL handshake, or weight fetch. Failures bubble
|
||
// up as `super::preflight::PreflightError` wrapped in anyhow;
|
||
// the api.rs handler downcasts to produce a 422 with structured
|
||
// JSON. The plan it returns is not yet threaded through the
|
||
// dispatch — downstream `resolve_files` / `resolve_dense_files`
|
||
// re-run their own substring match — but the structured error
|
||
// surface is the main payoff.
|
||
let api = self.hf_api_for(&source_id.scheme)?;
|
||
super::preflight::preflight(&api, &source_id, spec)
|
||
.await
|
||
.map_err(anyhow::Error::new)?;
|
||
|
||
let tp_size = spec.tensor_parallel.unwrap_or(1);
|
||
if tp_size > 1 {
|
||
#[cfg(feature = "cuda")]
|
||
{
|
||
return self.load_tp(spec, &source_id, tp_size).await;
|
||
}
|
||
#[cfg(not(feature = "cuda"))]
|
||
{
|
||
anyhow::bail!(
|
||
"tensor_parallel={tp_size} requested for '{}': this neuron \
|
||
binary was built without --features cuda; TP requires CUDA + NCCL",
|
||
spec.model_id
|
||
);
|
||
}
|
||
}
|
||
|
||
let devices = spec.devices.clone().unwrap_or_else(|| vec![0]);
|
||
let device = Self::pick_device(&devices)?;
|
||
|
||
// Phase 4: load directly on the worker thread for CUDA;
|
||
// legacy spawn_blocking + Arc<Mutex<>> only for CPU. Resolve
|
||
// hf-hub paths up front (always async), then either dispatch
|
||
// a load Job (CUDA) or call the legacy local loader (CPU).
|
||
let worker: Option<Arc<super::device_worker::DeviceWorkerHandle>> = match &device {
|
||
#[cfg(feature = "cuda")]
|
||
Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?),
|
||
_ => None,
|
||
};
|
||
|
||
let (tokenizer_path, arch_local, arch_handle, vision_meta) = if let Some(w) = &worker {
|
||
// CUDA path: resolve, then load in the worker.
|
||
if spec.quant.is_some() {
|
||
let (gguf_path, tokenizer_path) = self.resolve_files(spec, &source_id).await?;
|
||
let handle = w
|
||
.load_gguf(gguf_path, spec.model_id.clone())
|
||
.await
|
||
.map_err(|e| anyhow::anyhow!("worker load_gguf: {e}"))?;
|
||
// GGUF Qwen3.6 releases don't ship the vision tower
|
||
// (Qwen-VL weights are in the dense safetensors only),
|
||
// so a GGUF load is text-only by construction.
|
||
(tokenizer_path, None, Some(handle), VisionMeta::default())
|
||
} else {
|
||
let (config_path, tokenizer_path, safetensors_paths) =
|
||
self.resolve_dense_files(spec, &source_id).await?;
|
||
let meta = VisionMeta::from_config_path(&config_path);
|
||
let handle = w
|
||
.load_dense(config_path, safetensors_paths, spec.model_id.clone())
|
||
.await
|
||
.map_err(|e| anyhow::anyhow!("worker load_dense: {e}"))?;
|
||
(tokenizer_path, None, Some(handle), meta)
|
||
}
|
||
} else {
|
||
// CPU path: legacy spawn_blocking + Arc<Mutex<ModelArch>>.
|
||
let (tokenizer_path, arch) = if spec.quant.is_some() {
|
||
self.load_arch_gguf(spec, &source_id, &device).await?
|
||
} else {
|
||
self.load_arch_dense(spec, &source_id, &device).await?
|
||
};
|
||
// CPU Qwen3.6 isn't a supported deployment target — the
|
||
// 27B doesn't fit any reasonable CPU memory budget — so
|
||
// we don't attempt to reach into the arch for vision
|
||
// metadata. Stays text-only.
|
||
(
|
||
tokenizer_path,
|
||
Some(Arc::new(Mutex::new(arch))),
|
||
None,
|
||
VisionMeta::default(),
|
||
)
|
||
};
|
||
|
||
let tokenizer = Tokenizer::from_file(&tokenizer_path)
|
||
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
|
||
|
||
// Probe for reasoning markers in the tokenizer's
|
||
// added-tokens table — `<think>` / `</think>` on Qwen3 +
|
||
// DeepSeek-R1 + gpt-oss, `[THINK]` / `[/THINK]` on
|
||
// Mistral Magistral, etc. `None` for non-reasoning models.
|
||
// The streaming loop uses this to route between TextDelta
|
||
// and ReasoningDelta without any hardcoded model
|
||
// knowledge; wire projectors decide what to do with the
|
||
// split.
|
||
let reasoning_tokens = detect_reasoning_token_pair(|s| tokenizer.token_to_id(s));
|
||
if let Some(ref pair) = reasoning_tokens {
|
||
tracing::info!(
|
||
model = %spec.model_id,
|
||
open = %pair.open_text,
|
||
close = %pair.close_text,
|
||
open_id = pair.open_id,
|
||
close_id = pair.close_id,
|
||
"reasoning markers detected — streaming will route ReasoningDelta separately"
|
||
);
|
||
}
|
||
let tool_call_tokens = detect_tool_call_token_pair(|s| tokenizer.token_to_id(s));
|
||
if let Some(ref pair) = tool_call_tokens {
|
||
tracing::info!(
|
||
model = %spec.model_id,
|
||
open = %pair.open_text,
|
||
close = %pair.close_text,
|
||
open_id = pair.open_id,
|
||
close_id = pair.close_id,
|
||
"tool-call markers detected — streaming will emit structured ToolCall events"
|
||
);
|
||
}
|
||
// Probe `tokenizer_config.json` in the same snapshot dir.
|
||
// When present and non-empty, the inference path renders
|
||
// this Jinja template with the request's
|
||
// `chat_template_kwargs` instead of using the hardcoded
|
||
// ChatML formatter. Best-effort: missing or unparseable
|
||
// configs silently fall through to the legacy path.
|
||
let chat_template = super::chat_template::load_chat_template_alongside(&tokenizer_path);
|
||
if chat_template.is_some() {
|
||
tracing::info!(
|
||
model = %spec.model_id,
|
||
"chat_template loaded from tokenizer_config.json — prompt assembly will use the model's own template"
|
||
);
|
||
}
|
||
|
||
let loaded = Arc::new(LoadedModel {
|
||
model_id: spec.model_id.clone(),
|
||
arch: arch_local,
|
||
tokenizer,
|
||
device,
|
||
quant: spec.quant.clone(),
|
||
devices,
|
||
poisoned: AtomicBool::new(false),
|
||
worker,
|
||
arch_handle,
|
||
inference_lock: tokio::sync::Mutex::new(()),
|
||
reasoning_tokens,
|
||
tool_call_tokens,
|
||
chat_template,
|
||
has_vision: vision_meta.has_vision,
|
||
image_token_id: vision_meta.image_token_id,
|
||
image_grid_factor: vision_meta.image_grid_factor,
|
||
spec: spec.clone(),
|
||
});
|
||
|
||
let mut models = self.models.write().await;
|
||
models.insert(spec.model_id.clone(), LoadedHandle::Single(loaded));
|
||
tracing::info!(model = %spec.model_id, "model loaded");
|
||
Ok(())
|
||
}
|
||
|
||
async fn unload_model(&self, model_id: &str) -> Result<()> {
|
||
let removed = {
|
||
let mut models = self.models.write().await;
|
||
models.remove(model_id)
|
||
};
|
||
let Some(handle) = removed else {
|
||
anyhow::bail!("model '{model_id}' not loaded");
|
||
};
|
||
// Single-GPU drops are immediate — the LoadedModel goes out of
|
||
// scope with the Arc and candle frees VRAM. CUDA loads also
|
||
// ship a `Job::DropArch` to the device worker so the boxed
|
||
// `ModelArch` releases its CUDA allocations on the right
|
||
// thread (with the bound context); without that, the Drop
|
||
// would run on whatever tokio thread happens to be holding
|
||
// the last `Arc<LoadedModel>` clone when this fn returns.
|
||
// TP unloads further coordinate the subprocess pool below.
|
||
match handle {
|
||
LoadedHandle::Single(single) => {
|
||
if let (Some(worker), Some(arch_handle)) =
|
||
(single.worker.as_ref(), single.arch_handle)
|
||
&& let Err(e) = worker.drop_arch(arch_handle).await
|
||
{
|
||
tracing::warn!(
|
||
model = %model_id,
|
||
error = %e,
|
||
"single-GPU unload: DropArch RPC failed (model state may leak in worker slab)"
|
||
);
|
||
}
|
||
}
|
||
#[cfg(feature = "cuda")]
|
||
LoadedHandle::Tp(tp) => {
|
||
// Try to recover the inner TpLoadedModel so we can move
|
||
// the pool and shut it down. If anyone else still holds
|
||
// a clone of the Arc (shouldn't happen — the only owners
|
||
// are the registry and any in-flight chat_completion),
|
||
// bail with a clear marker rather than silently leaking.
|
||
let tp = match Arc::try_unwrap(tp) {
|
||
Ok(t) => t,
|
||
Err(arc) => {
|
||
// Reinsert so we don't leave the registry in an
|
||
// inconsistent state.
|
||
let mut models = self.models.write().await;
|
||
models.insert(model_id.into(), LoadedHandle::Tp(arc));
|
||
anyhow::bail!("cannot unload '{model_id}': inference still in flight");
|
||
}
|
||
};
|
||
// Drop the leader's TpLeaderModel on the device worker
|
||
// thread (CUDA tensors and Arc<Comm> clones release on
|
||
// the same OS thread that allocated them).
|
||
if let Err(e) = tp.worker.drop_tp(tp.leader_handle).await {
|
||
tracing::warn!(
|
||
model = %model_id,
|
||
error = %e,
|
||
"TP unload: DropTp RPC failed (leader model may leak in worker slab)"
|
||
);
|
||
}
|
||
let mut pool = tp.pool.into_inner();
|
||
if let Err(e) = pool.unload_model(model_id).await {
|
||
tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed");
|
||
}
|
||
if let Err(e) = pool.shutdown().await {
|
||
tracing::warn!(model = %model_id, error = %e, "TP pool shutdown failed");
|
||
}
|
||
}
|
||
}
|
||
tracing::info!(model = %model_id, "model unloaded");
|
||
Ok(())
|
||
}
|
||
|
||
async fn inference_endpoint(&self, model_id: &str) -> Option<String> {
|
||
let models = self.models.read().await;
|
||
models.contains_key(model_id).then(|| self.bind_url.clone())
|
||
}
|
||
}
|
||
|
||
impl CandleHarness {
|
||
/// Tensor-parallel load. Resolves dense safetensors via hf-hub the
/// same way the single-GPU dense path does, spins up a TP worker
/// pool sized to `tp_size`, runs the NCCL handshake, then has
/// every rank load its shard of the model.
///
/// `spec.devices` carries the per-rank CUDA device indices (one
/// entry per rank, in rank order); defaults to `0..tp_size`.
///
/// # Errors
/// Fails in any of these cases:
/// - `spec.devices` does not have exactly `tp_size` entries;
/// - the architecture has no TP support;
/// - any resolve, pool, NCCL, shard-load, or tokenizer step fails;
/// - a concurrent load registered `spec.model_id` while this one was in
///   progress. The duplicate pool is torn down before returning.
#[cfg(feature = "cuda")]
async fn load_tp(
    &self,
    spec: &ModelSpec,
    source_id: &cortex_core::source::ModelSourceId,
    tp_size: u32,
) -> Result<()> {
    use std::sync::Arc as StdArc;
    use tokio::sync::Mutex as TMutex;

    // Default per-rank device assignment: 0, 1, ..., tp_size - 1.
    let devices = spec
        .devices
        .clone()
        .unwrap_or_else(|| (0..tp_size).collect());
    if devices.len() as u32 != tp_size {
        anyhow::bail!(
            "tensor_parallel={tp_size} requires {tp_size} entries in devices, got {}",
            devices.len()
        );
    }
    // `quant` on the TP path now means in-situ quantization (ISQ):
    // load safetensors, quantize the per-rank shard to the named
    // GgmlDType at load time. The worker's parse_quant_string
    // accepts the same names (q5k, q8_0, etc.) as the single-GPU
    // path. GGUF-source-file models still aren't TP-loadable, but
    // resolve_dense_files only looks for safetensors so that path
    // errors out cleanly later if no safetensors are present.

    // 1. Resolve config + tokenizer + safetensors via hf-hub.
    let (config_path, tokenizer_path, safetensors_paths) =
        self.resolve_dense_files(spec, source_id).await?;
    let config_json = std::fs::read_to_string(&config_path).context("read config.json")?;
    // Reject unsupported architectures *before* spawning the worker
    // pool and fanning out NCCL — otherwise we'd burn the pool
    // lifecycle on a load that's guaranteed to fail at deserialise
    // time inside every rank.
    check_dense_config_supported(&config_json, &spec.model_id)?;
    // The TP path knows how to ship and reconstruct a Qwen3 dense
    // shard (`tp_qwen3.rs`). Other architectures may pass the
    // single-GPU `check_dense_config_supported` check above but
    // have no TP-aware module — bail with a clear marker pointing
    // at the file the implementer needs to add. This keeps an
    // operator who sets `tensor_parallel=2` on a Llama model from
    // silently routing through `pool.load_dense_shard` (which
    // assumes Qwen3 config shape on the worker side) and producing
    // a confusing config-parse failure inside every rank.
    check_tp_arch_supported(&config_json, &spec.model_id)?;

    // 2. Spawn the worker pool. Rank 0 stays in-process; ranks
    //    1..tp_size are subprocesses, one per device after the
    //    leader's own. The leader's device worker thread is
    //    spawned (or reused) here and passed into the pool so
    //    `init_nccl`, the load, every TP forward, and KV-cache
    //    clears all dispatch from the same OS thread.
    let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
    let leader_worker = self.ensure_device_worker(devices[0]).await?;
    let mut pool =
        super::tp::WorkerPool::spawn(&exe, tp_size, &devices, leader_worker.clone()).await?;

    // 3. NCCL handshake across all ranks.
    let leader_device_idx = devices[0];
    pool.init_nccl(leader_device_idx).await?;

    // 4. Pick the leader's candle Device (same index as init_nccl).
    let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize)
        .context("Device::new_cuda for TP leader")?;

    // 5. Load this rank's shard on every rank. After Phase 3
    //    `load_dense_shard` transfers the freshly-built
    //    `TpLeaderModel` into the device worker's TP slab and
    //    returns the resulting handle.
    let leader_handle = pool
        .load_dense_shard(
            &spec.model_id,
            &config_json,
            &safetensors_paths,
            &leader_device,
            candle_core::DType::BF16,
            spec.quant.clone(),
        )
        .await?;

    // 6. Tokenizer (same as single-GPU path).
    let tokenizer = Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
    // Reasoning + tool-call marker probes — identical to the
    // single-GPU path. See LoadedModel's matching fields for
    // the why.
    let reasoning_tokens = detect_reasoning_token_pair(|s| tokenizer.token_to_id(s));
    if let Some(ref pair) = reasoning_tokens {
        tracing::info!(
            model = %spec.model_id,
            open = %pair.open_text,
            close = %pair.close_text,
            "TP load: reasoning markers detected"
        );
    }
    let tool_call_tokens = detect_tool_call_token_pair(|s| tokenizer.token_to_id(s));
    if let Some(ref pair) = tool_call_tokens {
        tracing::info!(
            model = %spec.model_id,
            open = %pair.open_text,
            close = %pair.close_text,
            "TP load: tool-call markers detected"
        );
    }
    let chat_template = super::chat_template::load_chat_template_alongside(&tokenizer_path);
    if chat_template.is_some() {
        tracing::info!(
            model = %spec.model_id,
            "TP load: chat_template loaded from tokenizer_config.json"
        );
    }

    // Vision metadata from the same config.json the shards loaded
    // from. The TP model builder (Stage 1) materialises a replicated
    // vision tower on every rank when `vision_config` is present, so
    // `has_vision` here is consistent with what each rank loaded.
    let vision_meta = VisionMeta::from_config_path(&config_path);
    if vision_meta.has_vision {
        tracing::info!(
            model = %spec.model_id,
            image_token_id = ?vision_meta.image_token_id,
            image_grid_factor = ?vision_meta.image_grid_factor,
            "TP load: vision tower present, advertising vision capability"
        );
    }

    // 7. Register, re-checking under the write lock. The "already
    //    loaded" guard in `load_model` ran under a read lock that was
    //    released before this (long) load, so a concurrent load of the
    //    same id may have registered first. Overwriting it would drop a
    //    live TP model without the DropTp / pool unload / shutdown
    //    sequence that `unload_model` performs. Instead, the loser tears
    //    down its own freshly-built pool and bails.
    let mut models = self.models.write().await;
    if models.contains_key(&spec.model_id) {
        drop(models);
        tracing::warn!(
            model = %spec.model_id,
            "TP load: lost a race with a concurrent load of the same id; tearing down duplicate pool"
        );
        if let Err(e) = leader_worker.drop_tp(leader_handle).await {
            tracing::warn!(
                model = %spec.model_id,
                error = %e,
                "TP load teardown: DropTp RPC failed (leader model may leak in worker slab)"
            );
        }
        if let Err(e) = pool.unload_model(&spec.model_id).await {
            tracing::warn!(model = %spec.model_id, error = %e, "TP load teardown: unload RPC failed");
        }
        if let Err(e) = pool.shutdown().await {
            tracing::warn!(model = %spec.model_id, error = %e, "TP load teardown: pool shutdown failed");
        }
        anyhow::bail!("model '{}' already loaded", spec.model_id);
    }

    let tp_loaded = StdArc::new(TpLoadedModel {
        model_id: spec.model_id.clone(),
        tokenizer,
        devices: devices.clone(),
        pool: TMutex::new(pool),
        leader_handle,
        leader_device: leader_device.clone(),
        poisoned: AtomicBool::new(false),
        // Same `leader_worker` we passed into the pool above —
        // single `Arc` shared between WorkerPool and
        // TpLoadedModel so they reference the same thread.
        worker: leader_worker,
        reasoning_tokens,
        tool_call_tokens,
        chat_template,
        has_vision: vision_meta.has_vision,
        image_token_id: vision_meta.image_token_id,
        image_grid_factor: vision_meta.image_grid_factor,
        spec: spec.clone(),
    });

    models.insert(spec.model_id.clone(), LoadedHandle::Tp(tp_loaded));
    tracing::info!(
        model = %spec.model_id,
        tp_size,
        ?devices,
        "TP model loaded"
    );
    Ok(())
}
|
||
|
||
/// Non-streaming chat completion against a TP model.
///
/// The actual work runs inside a `tokio::spawn`'d task so the HTTP
/// client disconnecting (curl timeout, browser nav-away, etc.)
/// can't cancel the future mid-`pool.generate_step` and leave the
/// worker subprocesses mid-RPC. If the spawned task is dropped,
/// it still runs to completion and finishes draining the pool —
/// the next inference request finds a clean pool. The HTTP layer
/// just gives up on the response.
///
/// Every step also emits `info`/`debug` tracing so journalctl
/// shows where time went without needing to surface internals in
/// the HTTP error response.
///
/// Span handling: `span.enter()` guards are only held around
/// synchronous logging. The recovery trigger awaits, so it is run
/// under `.instrument(span)` instead. Per the `tracing` docs, holding
/// an entered guard across `.await` leaves the span entered while the
/// task is suspended and mis-attributes other tasks' events.
///
/// # Errors
/// Returns one of:
/// - the recovery error from `trigger_recovery` when the model is
///   poisoned, or when the NEURON_DEBUG_POISON latch fires;
/// - `InferenceError::VisionUnsupported` for image requests against a
///   model without a vision tower;
/// - the inner task's error;
/// - an `InferenceError::Other` when the task panicked or was cancelled.
#[cfg(feature = "cuda")]
async fn chat_completion_tp(
    &self,
    tp: Arc<TpLoadedModel>,
    request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, InferenceError> {
    // Tag every line of this request with a short req_id so a
    // grep over journalctl reconstructs one request even when
    // dozens are queued and interleaved. The span prefix is added
    // by the fmt subscriber to every event emitted within the
    // instrumented future, including events from `WorkerPool::*`
    // since those run on the leader's task.
    let req_id = new_req_id();
    let model_id = request.model.clone();
    let span = tracing::info_span!("tp_chat", req_id = %req_id, model = %model_id);
    let req_start = std::time::Instant::now();

    if tp.poisoned.load(Ordering::Acquire) {
        span.in_scope(|| tracing::warn!("TP chat_completion: refusing request, model poisoned"));
        return Err(self.trigger_recovery(&model_id).instrument(span).await);
    }
    if debug_poison_armed(&model_id) {
        span.in_scope(|| {
            tracing::warn!("NEURON_DEBUG_POISON: forcing auto-recovery (#17 verification)")
        });
        return Err(self.trigger_recovery(&model_id).instrument(span).await);
    }

    // Reject image-bearing requests against a TP model with no
    // vision tower, cleanly (`vision_unsupported`) rather than
    // silently dropping the image. Vision-capable TP loads fall
    // through to the image-aware prefill in chat_completion_tp_inner.
    if request_has_images(&request) && !tp.has_vision {
        span.in_scope(|| {
            tracing::warn!(
                "TP chat_completion: rejecting image request, model has no vision tower"
            )
        });
        return Err(InferenceError::VisionUnsupported { model_id });
    }

    let tp_for_marker = Arc::clone(&tp);
    let handle = tokio::spawn(chat_completion_tp_inner(tp, request).instrument(span.clone()));
    match handle.await {
        Ok(Ok(resp)) => Ok(resp),
        Ok(Err(e)) => {
            // The inner task returned Err. Only poison when the
            // failure indicates a CUDA / NCCL driver fault — shape
            // mismatches, NaN logits, tokenizer errors etc. don't
            // touch the device context and shouldn't take the
            // model down for everyone else.
            let chain = format!("{e:#}");
            // Synchronous logging only below — no `.await` while entered.
            let _g = span.enter();
            if matches!(&e, InferenceError::Other(inner) if is_device_fault(&format!("{inner:#}")))
            {
                tp_for_marker.poisoned.store(true, Ordering::Release);
                tracing::error!(
                    error = %chain,
                    total_ms = req_start.elapsed().as_millis(),
                    "TP chat_completion: failed with device fault, model marked poisoned"
                );
            } else {
                tracing::error!(
                    error = %chain,
                    total_ms = req_start.elapsed().as_millis(),
                    "TP chat_completion: failed (non-device fault); model NOT marked poisoned"
                );
            }
            Err(e)
        }
        Err(join_err) => {
            // JoinError: the spawned task panicked or was cancelled.
            // Tokenizer / sampling / serialisation panics don't touch
            // the device, so don't poison the model — failing this
            // one request is enough. (CUDA failures arrive as Err
            // through `?`, not as panics, and are handled above.)
            let cause = if join_err.is_panic() {
                "panicked"
            } else if join_err.is_cancelled() {
                "was cancelled"
            } else {
                "ended abnormally"
            };
            // Synchronous logging only below — no `.await` while entered.
            let _g = span.enter();
            tracing::error!(
                cause,
                error = %join_err,
                total_ms = req_start.elapsed().as_millis(),
                "TP chat_completion: inference task {cause}; model NOT marked poisoned"
            );
            Err(InferenceError::Other(anyhow::anyhow!(
                "TP inference task {cause}: {join_err}"
            )))
        }
    }
}
|
||
|
||
/// Streaming counterpart to `chat_completion_tp`. Same per-step
|
||
/// orchestration (clear cache, prefill, sample, decode loop) but
|
||
/// emits one `ChatCompletionChunk` per token over an mpsc channel
|
||
/// so the handler can write an SSE stream.
|
||
///
|
||
/// Unlike the single-GPU streaming path (which runs the candle
|
||
/// forward inside `spawn_blocking` and uses `blocking_send`), the
|
||
/// TP loop is itself async — every `pool.generate_step` awaits the
|
||
/// leader's spawn_blocking forward plus every worker's recv_only.
|
||
/// So we `tokio::spawn` the orchestration task and use plain
|
||
/// `Sender::send`.
|
||
#[cfg(feature = "cuda")]
|
||
async fn inference_tp_stream(
|
||
&self,
|
||
tp: Arc<TpLoadedModel>,
|
||
request: ChatCompletionRequest,
|
||
) -> Result<InferenceStream, InferenceError> {
|
||
if tp.poisoned.load(Ordering::Acquire) {
|
||
return Err(self.trigger_recovery(&request.model).await);
|
||
}
|
||
|
||
// Reject image requests against a non-vision TP model before
|
||
// opening the SSE stream. Vision-capable TP loads fall through
|
||
// to the image-aware prefill in the orchestration task below.
|
||
if request_has_images(&request) && !tp.has_vision {
|
||
tracing::warn!(
|
||
"TP chat_completion (stream): rejecting image request, model has no vision tower"
|
||
);
|
||
return Err(InferenceError::VisionUnsupported {
|
||
model_id: request.model.clone(),
|
||
});
|
||
}
|
||
|
||
let prompt = build_prompt_for_request(tp.chat_template.as_deref(), &request);
|
||
let encoding = tp
|
||
.tokenizer
|
||
.encode(prompt.as_str(), true)
|
||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
|
||
let mut prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
|
||
|
||
// TP-vision (streaming): same detection + pad expansion as the
|
||
// non-streaming path. The resulting `vision_route` moves into
|
||
// the orchestration task, which runs a single-shot image prefill
|
||
// when present. Returning early here keeps a rejected request
|
||
// from opening the SSE stream.
|
||
let vision_route: Option<(Vec<String>, u32)> = if request_has_images(&request) {
|
||
if !tp.has_vision {
|
||
return Err(InferenceError::VisionUnsupported {
|
||
model_id: request.model.clone(),
|
||
});
|
||
}
|
||
let image_token_id =
|
||
tp.image_token_id
|
||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||
model_id: request.model.clone(),
|
||
})?;
|
||
let factor = tp
|
||
.image_grid_factor
|
||
.ok_or_else(|| InferenceError::VisionUnsupported {
|
||
model_id: request.model.clone(),
|
||
})?;
|
||
let data_uris = extract_image_data_uris(&request);
|
||
if data_uris.is_empty() {
|
||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||
"request has image content but extractor produced zero data URIs"
|
||
)));
|
||
}
|
||
// Per-image LM token count from each image's resized grid (#14).
|
||
// Decode header + smart_resize only; the workers re-derive the
|
||
// same dims when they preprocess for the replicated tower.
|
||
let profile = super::preprocess::PreprocessProfile::qwen3_6();
|
||
let per_image_counts: Vec<usize> = data_uris
|
||
.iter()
|
||
.enumerate()
|
||
.map(|(i, uri)| {
|
||
let (h, w) =
|
||
super::preprocess::resized_dims_for_uri(uri, &profile).map_err(|e| {
|
||
InferenceError::Other(anyhow::anyhow!("resized_dims image #{i}: {e}"))
|
||
})?;
|
||
Ok::<usize, InferenceError>((h as usize / factor) * (w as usize / factor))
|
||
})
|
||
.collect::<Result<Vec<_>, _>>()?;
|
||
prompt_tokens =
|
||
expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
|
||
.map_err(InferenceError::Other)?;
|
||
Some((data_uris, image_token_id))
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let prompt_len = prompt_tokens.len();
|
||
|
||
let temperature = request.temperature.unwrap_or(0.7);
|
||
let top_p = request.top_p;
|
||
let max_new = request.max_tokens.unwrap_or(8192) as usize;
|
||
let seed = unix_subsec_nanos();
|
||
|
||
let eos_id = tp
|
||
.tokenizer
|
||
.token_to_id("<|im_end|>")
|
||
.or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
|
||
|
||
let model_id = request.model.clone();
|
||
let id = format!("chatcmpl-{:x}", unix_subsec_nanos());
|
||
let created = unix_now_secs();
|
||
let tokenizer = tp.tokenizer.clone();
|
||
let reasoning_tokens = tp.reasoning_tokens.clone();
|
||
let tool_call_tokens = tp.tool_call_tokens.clone();
|
||
// The spawned orchestration task below consumes both `id`
|
||
// and `model_id` (tracing, pool lookups, NCCL ops use them
|
||
// heavily). The wire projector at the bottom of this fn
|
||
// also needs them to stamp request metadata onto every
|
||
// chunk. Clone here so each side owns its copy.
|
||
let projector_id = id.clone();
|
||
let projector_model_id = model_id.clone();
|
||
|
||
// Bounded channel — back-pressures the producer when
|
||
// downstream consumption (wire projector → SSE writer) is
|
||
// slow.
|
||
let (tx, event_rx) = mpsc::channel::<InferenceEvent>(32);
|
||
|
||
// Start event first, before kicking off the heavy work — if
|
||
// the receiver is gone by now there's no point starting
|
||
// inference. The wire projector materialises this as the
|
||
// OpenAI `delta: {role: "assistant"}` chunk.
|
||
tx.send(InferenceEvent::Start)
|
||
.await
|
||
.map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?;
|
||
|
||
// The orchestration task. Holds the pool lock for the lifetime
|
||
// of this inference; concurrent requests against the same TP
|
||
// model serialise behind it.
|
||
//
|
||
// Tagged with the same req_id span as the non-streaming path
|
||
// so the journal can be reconstructed regardless of which API
|
||
// surface the client hit.
|
||
let req_id = new_req_id();
|
||
let span = tracing::info_span!(
|
||
"tp_chat_stream",
|
||
req_id = %req_id,
|
||
model = %model_id
|
||
);
|
||
let req_start = std::time::Instant::now();
|
||
let (vram_free_mb, vram_total_mb) = tp.query_vram().await;
|
||
tracing::info!(
|
||
parent: &span,
|
||
prompt_len,
|
||
max_new,
|
||
temperature,
|
||
?top_p,
|
||
?eos_id,
|
||
vram_free_mb,
|
||
vram_total_mb,
|
||
"TP chat_completion (stream): starting"
|
||
);
|
||
|
||
validate_request(prompt_len, vram_free_mb)?;
|
||
if vision_route.is_some() {
|
||
validate_vision_prefill(prompt_len, vram_free_mb)?;
|
||
}
|
||
|
||
let tp_for_task = Arc::clone(&tp);
|
||
tokio::spawn(
|
||
async move {
|
||
let mut failure: Option<String> = None;
|
||
let mut pool = acquire_pool_lock(&tp_for_task.pool, &model_id).await;
|
||
let leader_handle = tp_for_task.leader_handle;
|
||
|
||
let mut all_tokens: Vec<u32> = Vec::new();
|
||
// Incremental detokenizer. See the equivalent in
|
||
// `stream_inference_via_worker` for the why: the old
|
||
// "full decode + byte-slice delta" pattern panicked on
|
||
// UTF-8 mid-codepoint boundaries when BPE byte-fallback
|
||
// split a multi-byte char across tokens.
|
||
let mut decode_stream = tokenizer.decode_stream(true);
|
||
let mut finish_reason = FinishReason::Length;
|
||
// Reasoning + tool-call state machines — same as
|
||
// the single-GPU path. The TP path needs its own
|
||
// copies because the spawn closure owns them.
|
||
let mut in_reasoning = false;
|
||
let mut in_tool_call = false;
|
||
let mut tool_call_buf = String::new();
|
||
let mut tool_call_idx: usize = 0;
|
||
|
||
'work: {
|
||
if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await {
|
||
failure = Some(format!("clear_kv_cache: {e:#}"));
|
||
break 'work;
|
||
}
|
||
|
||
let mut logits_processor = {
|
||
let sampling = if temperature <= 0.0 {
|
||
Sampling::ArgMax
|
||
} else {
|
||
match top_p {
|
||
Some(p) => Sampling::TopP { p, temperature },
|
||
None => Sampling::All { temperature },
|
||
}
|
||
};
|
||
LogitsProcessor::from_sampling(seed, sampling)
|
||
};
|
||
|
||
// Chunked prefill — see `chunked_prefill_tp`. Each
|
||
// chunk fans out to every rank with a growing
|
||
// offset; only the final chunk's logits are kept
|
||
// for the first sample.
|
||
// Vision requests do a chunked image prefill (encode
|
||
// once, splice per chunk); text requests chunk it the
|
||
// same way. `vision_route` was moved into this task
|
||
// from the synchronous setup above.
|
||
let prefill_result = match &vision_route {
|
||
Some((data_uris, image_token_id)) => {
|
||
pool.generate_step_with_images(
|
||
&model_id,
|
||
leader_handle,
|
||
prompt_tokens.clone(),
|
||
0,
|
||
*image_token_id,
|
||
data_uris.clone(),
|
||
prefill_chunk_tokens(),
|
||
)
|
||
.await
|
||
}
|
||
None => {
|
||
chunked_prefill_tp(&mut pool, &model_id, leader_handle, &prompt_tokens)
|
||
.await
|
||
}
|
||
};
|
||
let logits_vec = match prefill_result {
|
||
Ok(l) => l,
|
||
Err(e) => {
|
||
failure = Some(format!("prefill: {e:#}"));
|
||
break 'work;
|
||
}
|
||
};
|
||
let (post_prefill_vram_free_mb, _) = tp_for_task.query_vram().await;
|
||
tracing::info!(
|
||
model = %model_id,
|
||
prompt_len,
|
||
vram_free_mb = post_prefill_vram_free_mb,
|
||
"TP chat_completion (stream): prefill complete"
|
||
);
|
||
let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) {
|
||
Ok(t) => t,
|
||
Err(e) => {
|
||
failure = Some(format!("prefill build cpu logits: {e:#}"));
|
||
break 'work;
|
||
}
|
||
};
|
||
let mut next_token =
|
||
match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
|
||
Ok(t) => t,
|
||
Err(e) => {
|
||
let health = logits_health_slice(&logits_vec);
|
||
tracing::warn!(
|
||
model = %model_id,
|
||
?health,
|
||
"TP chat_completion (stream): prefill sample failed; logits unhealthy"
|
||
);
|
||
failure = Some(format!("prefill sample: {e:#}"));
|
||
break 'work;
|
||
}
|
||
};
|
||
|
||
if Some(next_token) == eos_id {
|
||
finish_reason = FinishReason::Stop;
|
||
} else {
|
||
all_tokens.push(next_token);
|
||
match handle_tool_call_marker(
|
||
next_token,
|
||
tool_call_tokens.as_ref(),
|
||
&mut in_tool_call,
|
||
&mut tool_call_buf,
|
||
) {
|
||
ToolCallMarker::Enter => {}
|
||
ToolCallMarker::Exit { buffer } => {
|
||
let idx = tool_call_idx;
|
||
tool_call_idx += 1;
|
||
match parse_tool_call_body(&buffer, idx) {
|
||
Some((id, name, arguments)) => {
|
||
if tx
|
||
.send(InferenceEvent::ToolCall {
|
||
index: idx,
|
||
id,
|
||
name,
|
||
arguments,
|
||
})
|
||
.await
|
||
.is_err()
|
||
{
|
||
break 'work;
|
||
}
|
||
}
|
||
None => {
|
||
let open = tool_call_tokens
|
||
.as_ref()
|
||
.map(|p| p.open_text.as_str())
|
||
.unwrap_or("<tool_call>");
|
||
let close = tool_call_tokens
|
||
.as_ref()
|
||
.map(|p| p.close_text.as_str())
|
||
.unwrap_or("</tool_call>");
|
||
let raw = format!("{open}{buffer}{close}");
|
||
if !emit_delta(&raw, &tx, in_reasoning).await {
|
||
break 'work;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
ToolCallMarker::None => {
|
||
if in_tool_call {
|
||
match decode_stream.step(next_token) {
|
||
Ok(Some(s)) => tool_call_buf.push_str(&s),
|
||
Ok(None) => {}
|
||
Err(e) => tracing::warn!(
|
||
model = %model_id,
|
||
error = %e,
|
||
"TP stream: decode_stream step failed (in tool_call)"
|
||
),
|
||
}
|
||
} else if handle_reasoning_marker(
|
||
next_token,
|
||
reasoning_tokens.as_ref(),
|
||
&mut in_reasoning,
|
||
) {
|
||
// marker — nothing to emit
|
||
} else {
|
||
match decode_stream.step(next_token) {
|
||
Ok(Some(delta)) => {
|
||
if !emit_delta(&delta, &tx, in_reasoning).await {
|
||
break 'work;
|
||
}
|
||
}
|
||
Ok(None) => {}
|
||
Err(e) => tracing::warn!(
|
||
model = %model_id,
|
||
error = %e,
|
||
"TP stream: decode_stream step failed"
|
||
),
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
for index in 0..max_new.saturating_sub(1) {
|
||
let logits_vec = match pool
|
||
.generate_step(
|
||
&model_id,
|
||
leader_handle,
|
||
vec![next_token],
|
||
prompt_len + index,
|
||
)
|
||
.await
|
||
{
|
||
Ok(l) => l,
|
||
Err(e) => {
|
||
failure = Some(format!("decode step {index}: {e:#}"));
|
||
break 'work;
|
||
}
|
||
};
|
||
let logits = match Tensor::new(logits_vec.as_slice(), &Device::Cpu) {
|
||
Ok(t) => t,
|
||
Err(e) => {
|
||
failure =
|
||
Some(format!("decode build cpu logits {index}: {e:#}"));
|
||
break 'work;
|
||
}
|
||
};
|
||
next_token = match sample_with_penalty(
|
||
&logits,
|
||
&all_tokens,
|
||
&mut logits_processor,
|
||
) {
|
||
Ok(t) => t,
|
||
Err(e) => {
|
||
let health = logits_health_slice(&logits_vec);
|
||
tracing::warn!(
|
||
model = %model_id,
|
||
step = index,
|
||
?health,
|
||
"TP chat_completion (stream): decode sample failed; logits unhealthy"
|
||
);
|
||
failure = Some(format!("decode sample {index}: {e:#}"));
|
||
break 'work;
|
||
}
|
||
};
|
||
// Always await the query (even when the
|
||
// trace! is filtered out by RUST_LOG): the
|
||
// channel hop is ~tens of µs, comparable to
|
||
// the previous in-line bind+query cost, and
|
||
// making the call conditional adds complexity
|
||
// for negligible win. Revisit if it shows up
|
||
// in a hot-path profile.
|
||
let step_vram_free_mb = tp_for_task.query_vram().await.0;
|
||
tracing::trace!(
|
||
model = %model_id,
|
||
step = index,
|
||
next_token,
|
||
vram_free_mb = step_vram_free_mb,
|
||
"TP chat_completion (stream): decode step"
|
||
);
|
||
if Some(next_token) == eos_id {
|
||
finish_reason = FinishReason::Stop;
|
||
break;
|
||
}
|
||
all_tokens.push(next_token);
|
||
match handle_tool_call_marker(
|
||
next_token,
|
||
tool_call_tokens.as_ref(),
|
||
&mut in_tool_call,
|
||
&mut tool_call_buf,
|
||
) {
|
||
ToolCallMarker::Enter => continue,
|
||
ToolCallMarker::Exit { buffer } => {
|
||
let idx = tool_call_idx;
|
||
tool_call_idx += 1;
|
||
match parse_tool_call_body(&buffer, idx) {
|
||
Some((id, name, arguments)) => {
|
||
if tx
|
||
.send(InferenceEvent::ToolCall {
|
||
index: idx,
|
||
id,
|
||
name,
|
||
arguments,
|
||
})
|
||
.await
|
||
.is_err()
|
||
{
|
||
break 'work;
|
||
}
|
||
}
|
||
None => {
|
||
let open = tool_call_tokens
|
||
.as_ref()
|
||
.map(|p| p.open_text.as_str())
|
||
.unwrap_or("<tool_call>");
|
||
let close = tool_call_tokens
|
||
.as_ref()
|
||
.map(|p| p.close_text.as_str())
|
||
.unwrap_or("</tool_call>");
|
||
let raw = format!("{open}{buffer}{close}");
|
||
if !emit_delta(&raw, &tx, in_reasoning).await {
|
||
break 'work;
|
||
}
|
||
}
|
||
}
|
||
continue;
|
||
}
|
||
ToolCallMarker::None => {}
|
||
}
|
||
if in_tool_call {
|
||
match decode_stream.step(next_token) {
|
||
Ok(Some(s)) => tool_call_buf.push_str(&s),
|
||
Ok(None) => {}
|
||
Err(e) => tracing::warn!(
|
||
model = %model_id,
|
||
error = %e,
|
||
"TP stream: decode_stream step failed (in tool_call)"
|
||
),
|
||
}
|
||
continue;
|
||
}
|
||
if handle_reasoning_marker(
|
||
next_token,
|
||
reasoning_tokens.as_ref(),
|
||
&mut in_reasoning,
|
||
) {
|
||
continue;
|
||
}
|
||
match decode_stream.step(next_token) {
|
||
Ok(Some(delta)) => {
|
||
if !emit_delta(&delta, &tx, in_reasoning).await {
|
||
break 'work;
|
||
}
|
||
}
|
||
Ok(None) => {}
|
||
Err(e) => tracing::warn!(
|
||
model = %model_id,
|
||
error = %e,
|
||
"TP stream: decode_stream step failed"
|
||
),
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// One terminal line per request, success or failure. The
|
||
// success branch was previously implicit (the SSE final
|
||
// chunk went out and the spawned task just ended); now
|
||
// there's always a log line for the operator.
|
||
if let Some(err) = &failure {
|
||
if is_device_fault(err) {
|
||
tp_for_task.poisoned.store(true, Ordering::Release);
|
||
tracing::error!(
|
||
error = %err,
|
||
completion_tokens = all_tokens.len(),
|
||
total_ms = req_start.elapsed().as_millis(),
|
||
"TP chat_completion (stream): failed with device fault, model marked poisoned"
|
||
);
|
||
} else {
|
||
tracing::error!(
|
||
error = %err,
|
||
completion_tokens = all_tokens.len(),
|
||
total_ms = req_start.elapsed().as_millis(),
|
||
"TP chat_completion (stream): failed (non-device fault); model NOT marked poisoned"
|
||
);
|
||
}
|
||
} else {
|
||
tracing::info!(
|
||
prompt_tokens = prompt_len,
|
||
completion_tokens = all_tokens.len(),
|
||
finish_reason = finish_reason.as_openai_str(),
|
||
total_ms = req_start.elapsed().as_millis(),
|
||
"TP chat_completion (stream): done"
|
||
);
|
||
}
|
||
|
||
// Finish event — only on the success path. On
|
||
// failure we drop the channel so the client sees the
|
||
// SSE stream end abruptly (matches the pre-refactor
|
||
// behaviour when the failed-path early-returned
|
||
// without a final chunk).
|
||
if failure.is_none() {
|
||
let _ = tx
|
||
.send(InferenceEvent::Finish {
|
||
reason: finish_reason,
|
||
})
|
||
.await;
|
||
}
|
||
}
|
||
.instrument(span),
|
||
);
|
||
|
||
// Hand the raw event channel back to the public entry
|
||
// points; they pick the wire projection. Uses the clones
|
||
// we stashed before the spawn — the originals were moved
|
||
// into the orchestration task above.
|
||
let reasoning_markers = tp.reasoning_tokens.clone();
|
||
Ok(InferenceStream {
|
||
events: event_rx,
|
||
id: projector_id,
|
||
created,
|
||
model_id: projector_model_id,
|
||
reasoning_markers,
|
||
})
|
||
}
|
||
}
|
||
|
||
/// Body of the TP non-streaming chat completion, hoisted out of
/// `CandleHarness::chat_completion_tp` so it can run inside
/// `tokio::spawn` (which requires a `'static` future) and survive
/// HTTP-layer cancellation.
///
/// Flow:
/// 1. Render the prompt (model chat template or ChatML fallback) and
///    tokenize it.
/// 2. For image requests, expand each `<|image_pad|>` sentinel to the
///    per-image LM token count and keep the data URIs for prefill.
/// 3. Run the pre-flight checks (`validate_request`, plus
///    `validate_vision_prefill` for image requests) against the prompt
///    length and the currently free VRAM.
/// 4. Holding the TP pool lock for the rest of the request: clear every
///    rank's KV cache, do a chunked prefill (image-aware when needed),
///    then decode one token per `generate_step` until EOS or the token
///    budget is spent.
///
/// Token budget: the first completion token is sampled from the
/// prefill logits, and the decode loop runs `max_tokens - 1` further
/// steps (default `max_tokens` is 8192).
///
/// # Errors
///
/// - [`InferenceError::VisionUnsupported`] when the request carries
///   images but the model lacks a vision tower, an image token id, or
///   an image grid factor.
/// - Whatever the pre-flight checks return for an over-long prompt or
///   insufficient free VRAM.
/// - [`InferenceError::Other`] for tokenizer, pool/RPC, sampling, or
///   detokenization failures.
///
/// Tracing strategy: `info` for request entry/exit so journalctl
/// always shows when an inference started and finished; `debug` for
/// per-step timing so an operator running with `RUST_LOG=debug` sees
/// where the request actually spends its time without needing to
/// instrument the model code.
#[cfg(feature = "cuda")]
async fn chat_completion_tp_inner(
    tp: Arc<TpLoadedModel>,
    request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, InferenceError> {
    let req_start = std::time::Instant::now();
    let model_id = request.model.clone();

    let prompt = build_prompt_for_request(tp.chat_template.as_deref(), &request);
    let encoding = tp
        .tokenizer
        .encode(prompt.as_str(), true)
        .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
    let mut prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();

    // TP-vision: when the request carries images (and the model has a
    // replicated tower — enforced by the caller's guard and re-checked
    // here), expand each `<|image_pad|>` sentinel to the per-image
    // patch count and carry the source data URIs through to the image
    // prefill. Mirrors the single-GPU `chat_completion` vision_route
    // block.
    let vision_route: Option<(Vec<String>, u32)> = if request_has_images(&request) {
        if !tp.has_vision {
            return Err(InferenceError::VisionUnsupported {
                model_id: request.model.clone(),
            });
        }
        let image_token_id =
            tp.image_token_id
                .ok_or_else(|| InferenceError::VisionUnsupported {
                    model_id: request.model.clone(),
                })?;
        let factor = tp
            .image_grid_factor
            .ok_or_else(|| InferenceError::VisionUnsupported {
                model_id: request.model.clone(),
            })?;
        let data_uris = extract_image_data_uris(&request);
        if data_uris.is_empty() {
            return Err(InferenceError::Other(anyhow::anyhow!(
                "request has image content but extractor produced zero data URIs"
            )));
        }
        // Per-image LM token count from each image's resized grid (#14).
        let profile = super::preprocess::PreprocessProfile::qwen3_6();
        let per_image_counts: Vec<usize> = data_uris
            .iter()
            .enumerate()
            .map(|(i, uri)| {
                let (h, w) =
                    super::preprocess::resized_dims_for_uri(uri, &profile).map_err(|e| {
                        InferenceError::Other(anyhow::anyhow!("resized_dims image #{i}: {e}"))
                    })?;
                Ok::<usize, InferenceError>((h as usize / factor) * (w as usize / factor))
            })
            .collect::<Result<Vec<_>, _>>()?;
        prompt_tokens = expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
            .map_err(InferenceError::Other)?;
        Some((data_uris, image_token_id))
    } else {
        None
    };

    let prompt_len = prompt_tokens.len();

    let temperature = request.temperature.unwrap_or(0.7);
    let top_p = request.top_p;
    let max_new = request.max_tokens.unwrap_or(8192) as usize;
    let seed = unix_subsec_nanos();

    let eos_id = tp
        .tokenizer
        .token_to_id("<|im_end|>")
        .or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));

    let (vram_free_mb, vram_total_mb) = tp.query_vram().await;
    tracing::info!(
        model = %model_id,
        prompt_len,
        max_new,
        temperature,
        ?top_p,
        ?eos_id,
        vram_free_mb,
        vram_total_mb,
        "TP chat_completion: starting"
    );

    validate_request(prompt_len, vram_free_mb)?;
    if vision_route.is_some() {
        validate_vision_prefill(prompt_len, vram_free_mb)?;
    }

    // Acquire the pool lock for the duration of the request. After
    // Phase 3 the leader's TpLeaderModel lives in the device worker
    // thread, so the pool lock now serialises only subprocess RPC
    // traffic — but holding it for the whole request still keeps
    // concurrent chat_completions against the same TP model from
    // interleaving prefill/decode jobs.
    let mut pool = acquire_pool_lock(&tp.pool, &model_id).await;
    let leader_handle = tp.leader_handle;

    // Reset every rank's KV cache so this request doesn't attend
    // over the previous request's tokens.
    let clear_start = std::time::Instant::now();
    pool.clear_kv_cache(&model_id, leader_handle)
        .await
        .map_err(InferenceError::Other)?;
    tracing::debug!(
        model = %model_id,
        elapsed_ms = clear_start.elapsed().as_millis(),
        "TP chat_completion: kv cache cleared"
    );

    let mut logits_processor = {
        let sampling = if temperature <= 0.0 {
            Sampling::ArgMax
        } else {
            match top_p {
                Some(p) => Sampling::TopP { p, temperature },
                None => Sampling::All { temperature },
            }
        };
        LogitsProcessor::from_sampling(seed, sampling)
    };

    let mut generated: Vec<u32> = Vec::new();
    let mut finish_reason = "length".to_string();

    // Prefill: chunk the prompt so activation memory is bounded by the
    // chunk size rather than the full prompt length. Every rank still
    // sees the prompt in order, just spread across multiple steps with
    // monotonically growing offsets.
    //
    // Vision requests do a chunked image prefill (every rank encodes
    // its replicated tower once, then splices per chunk); text requests
    // go through `chunked_prefill_tp`.
    //
    // `vision_route` and `prompt_tokens` are not needed after this
    // point, so they are moved into the image prefill rather than
    // cloned — the data URIs carry base64 image payloads that can be
    // several MiB each.
    let prefill_start = std::time::Instant::now();
    let logits_vec = match vision_route {
        Some((data_uris, image_token_id)) => pool
            .generate_step_with_images(
                &model_id,
                leader_handle,
                prompt_tokens,
                0,
                image_token_id,
                data_uris,
                prefill_chunk_tokens(),
            )
            .await
            .map_err(InferenceError::Other)?,
        None => chunked_prefill_tp(&mut pool, &model_id, leader_handle, &prompt_tokens)
            .await
            .map_err(InferenceError::Other)?,
    };
    let (post_prefill_vram_free_mb, _) = tp.query_vram().await;
    tracing::info!(
        model = %model_id,
        prompt_len,
        elapsed_ms = prefill_start.elapsed().as_millis(),
        vram_free_mb = post_prefill_vram_free_mb,
        "TP chat_completion: prefill complete"
    );

    // Wrap the CPU-side logits in a CPU candle Tensor for sampling.
    // No device touch on the async caller's thread — sampling reads
    // from CPU memory only.
    let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)
        .map_err(|e| InferenceError::Other(anyhow::anyhow!("build cpu logits: {e}")))?;
    let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
        Ok(t) => t,
        Err(e) => {
            // Logits health snapshot — the surrounding wrapper logs the
            // terminal failure with the error chain; this WARN carries
            // the numerical state so an operator can tell at a glance
            // whether it was a NaN cascade, an Inf, or something else.
            let health = logits_health_slice(&logits_vec);
            tracing::warn!(
                model = %model_id,
                ?health,
                "TP chat_completion: prefill sample failed; logits unhealthy"
            );
            return Err(InferenceError::Other(e));
        }
    };

    if Some(next_token) == eos_id {
        finish_reason = "stop".into();
    } else {
        generated.push(next_token);
        let decode_start = std::time::Instant::now();
        // The first token already came from prefill, hence `- 1`.
        for index in 0..max_new.saturating_sub(1) {
            let step_start = std::time::Instant::now();
            // Position of `next_token` in the sequence: right after the
            // prompt plus every previously fed completion token.
            let logits_vec = pool
                .generate_step(
                    &model_id,
                    leader_handle,
                    vec![next_token],
                    prompt_len + index,
                )
                .await
                .map_err(InferenceError::Other)?;
            let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu).map_err(|e| {
                InferenceError::Other(anyhow::anyhow!("build cpu logits step {index}: {e}"))
            })?;
            next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
                Ok(t) => t,
                Err(e) => {
                    let health = logits_health_slice(&logits_vec);
                    tracing::warn!(
                        model = %model_id,
                        step = index,
                        ?health,
                        "TP chat_completion: decode sample failed; logits unhealthy"
                    );
                    return Err(InferenceError::Other(e));
                }
            };
            let step_vram_free_mb = tp.query_vram().await.0;
            tracing::trace!(
                model = %model_id,
                step = index,
                next_token,
                step_ms = step_start.elapsed().as_millis(),
                vram_free_mb = step_vram_free_mb,
                "TP chat_completion: decode step"
            );
            if Some(next_token) == eos_id {
                finish_reason = "stop".into();
                break;
            }
            generated.push(next_token);
        }
        tracing::info!(
            model = %model_id,
            generated = generated.len(),
            elapsed_ms = decode_start.elapsed().as_millis(),
            "TP chat_completion: decode complete"
        );
    }
    // Release the pool before CPU-only detokenization so queued
    // requests against this model can start.
    drop(pool);

    let completion_text = tp
        .tokenizer
        .decode(&generated, true)
        .map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;

    let usage = Usage {
        prompt_tokens: prompt_len as u64,
        completion_tokens: generated.len() as u64,
        total_tokens: (prompt_len + generated.len()) as u64,
    };

    tracing::info!(
        model = %model_id,
        prompt_tokens = prompt_len,
        completion_tokens = generated.len(),
        finish_reason = %finish_reason,
        total_ms = req_start.elapsed().as_millis(),
        "TP chat_completion: done"
    );

    Ok(ChatCompletionResponse {
        id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
        object: "chat.completion".into(),
        created: unix_now_secs(),
        model: model_id,
        choices: vec![ChatCompletionChoice {
            index: 0,
            message: ChatMessage {
                role: "assistant".into(),
                content: MessageContent::Text(completion_text),
                extra: serde_json::Value::Object(Default::default()),
            },
            finish_reason: Some(finish_reason),
            extra: serde_json::Value::Object(Default::default()),
        }],
        usage: Some(usage),
        extra: serde_json::Value::Object(Default::default()),
    })
}
|
||
|
||
/// Forward one chunk of detokenized text to the stream consumer.
///
/// The chunk is wrapped as [`InferenceEvent::ReasoningDelta`] while the
/// model is inside a reasoning span, and as
/// [`InferenceEvent::TextDelta`] otherwise.
///
/// Returns `false` only when the send failed because the receiver was
/// dropped; the caller should stop generating. An empty `delta` means
/// the DecodeStream is still holding an incomplete UTF-8 sequence.
/// Nothing is sent in that case and `true` is returned, so "nothing to
/// emit yet" looks the same to the caller as "sent successfully".
///
/// Chunk id, `created`, and model id are intentionally absent here —
/// the wire projector in [`crate::wire::openai_chat`] stamps those onto
/// every chunk downstream.
#[cfg(feature = "cuda")]
async fn emit_delta(delta: &str, tx: &mpsc::Sender<InferenceEvent>, in_reasoning: bool) -> bool {
    match (delta.is_empty(), in_reasoning) {
        (true, _) => true,
        (false, true) => tx
            .send(InferenceEvent::ReasoningDelta(delta.into()))
            .await
            .is_ok(),
        (false, false) => tx
            .send(InferenceEvent::TextDelta(delta.into()))
            .await
            .is_ok(),
    }
}
|
||
|
||
/// Sync counterpart of [`emit_delta`] for the CPU path's
|
||
/// `spawn_blocking` closure. Same shape, `blocking_send` instead of
|
||
/// `send`. Kept as a separate fn so the async / blocking-send choice
|
||
/// is local to one place per path.
|
||
fn emit_delta_blocking(delta: &str, tx: &mpsc::Sender<InferenceEvent>, in_reasoning: bool) -> bool {
|
||
if delta.is_empty() {
|
||
return true;
|
||
}
|
||
let event = if in_reasoning {
|
||
InferenceEvent::ReasoningDelta(delta.into())
|
||
} else {
|
||
InferenceEvent::TextDelta(delta.into())
|
||
};
|
||
tx.blocking_send(event).is_ok()
|
||
}
|
||
|
||
/// If `next_token` is one of the loaded model's reasoning markers,
|
||
/// flip `in_reasoning` and return `true` to tell the caller to
|
||
/// skip detokenisation + emission for this token. The markers
|
||
/// themselves never appear in the streamed output — they exist
|
||
/// purely to transition state.
|
||
///
|
||
/// `pair = None` short-circuits to `false` (no reasoning markers
|
||
/// configured for this model → pass-through).
|
||
fn handle_reasoning_marker(
|
||
next_token: u32,
|
||
pair: Option<&ReasoningTokenPair>,
|
||
in_reasoning: &mut bool,
|
||
) -> bool {
|
||
let Some(pair) = pair else { return false };
|
||
if next_token == pair.open_id {
|
||
*in_reasoning = true;
|
||
return true;
|
||
}
|
||
if next_token == pair.close_id {
|
||
*in_reasoning = false;
|
||
return true;
|
||
}
|
||
false
|
||
}
|
||
|
||
/// Outcome of checking a sampled token against the model's tool-call
/// markers, as produced by `handle_tool_call_marker`.
///
/// Derives `Debug` so the outcome can be logged while diagnosing
/// tool-call streaming, and `PartialEq`/`Eq` so it can be compared
/// directly.
#[derive(Debug, PartialEq, Eq)]
enum ToolCallMarker {
    /// Not a tool-call marker. The caller takes its normal
    /// detokenize-and-emit path, or appends to the tool-call buffer if
    /// one is open.
    None,
    /// The `<tool_call>` open token. The caller starts buffering the
    /// detokenized text that follows instead of emitting it. The marker
    /// token itself produces no output.
    Enter,
    /// The `</tool_call>` close token. `buffer` is the accumulated body,
    /// moved out of the caller's state with `std::mem::take`. The caller
    /// parses it and emits either a structured `InferenceEvent::ToolCall`
    /// or, if parsing fails, the original
    /// `<tool_call>{buffer}</tool_call>` text.
    Exit { buffer: String },
}
|
||
|
||
fn handle_tool_call_marker(
|
||
next_token: u32,
|
||
pair: Option<&ToolCallTokenPair>,
|
||
in_tool_call: &mut bool,
|
||
buffer: &mut String,
|
||
) -> ToolCallMarker {
|
||
let Some(pair) = pair else {
|
||
return ToolCallMarker::None;
|
||
};
|
||
if next_token == pair.open_id {
|
||
*in_tool_call = true;
|
||
buffer.clear();
|
||
return ToolCallMarker::Enter;
|
||
}
|
||
if next_token == pair.close_id {
|
||
*in_tool_call = false;
|
||
return ToolCallMarker::Exit {
|
||
buffer: std::mem::take(buffer),
|
||
};
|
||
}
|
||
ToolCallMarker::None
|
||
}
|
||
|
||
/// Parse a `<tool_call>{json}</tool_call>` body into the fields the
|
||
/// `InferenceEvent::ToolCall` variant carries. Returns `None` when
|
||
/// the body isn't valid JSON or doesn't carry a `name`. The caller
|
||
/// falls back to passing the original text through on `None` so
|
||
/// downstream consumers (helexa-acp's existing `ToolCallParser`,
|
||
/// which has its own repair passes) can take another swing.
|
||
///
|
||
/// Generates a fresh `call_<hex>` id per parsed call; the model
|
||
/// itself doesn't include ids in the wire convention we model.
|
||
fn parse_tool_call_body(body: &str, index: usize) -> Option<(String, String, String)> {
|
||
let value: serde_json::Value = serde_json::from_str(body.trim()).ok()?;
|
||
let name = value.get("name")?.as_str()?.to_string();
|
||
let arguments = value
|
||
.get("arguments")
|
||
.map(|v| v.to_string())
|
||
.unwrap_or_else(|| "{}".into());
|
||
let id = format!("call_{:x}_{}", unix_subsec_nanos(), index);
|
||
Some((id, name, arguments))
|
||
}
|
||
|
||
/// Errors returned by `CandleHarness::chat_completion`. The
|
||
/// `ModelNotLoaded`, `PromptTooLong`, and `InsufficientVram` variants
|
||
/// let the HTTP handler map cleanly to 404 / 400 / 503 without
|
||
/// string-matching on anyhow messages.
|
||
#[derive(Debug, thiserror::Error)]
|
||
pub enum InferenceError {
|
||
#[error("model '{0}' not loaded on this neuron")]
|
||
ModelNotLoaded(String),
|
||
#[error("prompt has {prompt_len} tokens but max is {max}")]
|
||
PromptTooLong { prompt_len: usize, max: usize },
|
||
#[error(
|
||
"insufficient free VRAM for prefill: {free_mb} MiB free, need at least {required_mb} MiB"
|
||
)]
|
||
InsufficientVram { free_mb: u64, required_mb: u64 },
|
||
/// Request carried `image_url` content but the loaded model has
|
||
/// no vision tower. Stage B6 — replaces the silent-drop pattern
|
||
/// from issue #3 with an explicit 400 + `vision_unsupported`
|
||
/// error body that clients (litellm, agent0, …) can act on.
|
||
#[error(
|
||
"model '{model_id}' does not support image input; \
|
||
load a vision-capable model (e.g. Qwen/Qwen3.6-27B) or \
|
||
remove the image_url content parts from the request"
|
||
)]
|
||
VisionUnsupported { model_id: String },
|
||
#[error(transparent)]
|
||
Other(#[from] anyhow::Error),
|
||
}
|
||
|
||
/// Build the model's prompt from a [`ChatCompletionRequest`].
|
||
///
|
||
/// Prefers the model's own `chat_template` when one was loaded
|
||
/// from `tokenizer_config.json` at startup and the
|
||
/// `NEURON_USE_CHAT_TEMPLATE` kill switch isn't tripped. The
|
||
/// request's `chat_template_kwargs` (e.g.
|
||
/// `{"enable_thinking": false}` on Qwen3) and `tools` array flow
|
||
/// into the template's Jinja context so model-specific behaviour
|
||
/// like reasoning-suppression-at-generation works.
|
||
///
|
||
/// Falls back to [`format_qwen3_prompt`] (the legacy hardcoded
|
||
/// ChatML glue) on any of:
|
||
///
|
||
/// - no `chat_template` loaded for this model (older quantised
|
||
/// variants, fallback-only models)
|
||
/// - the env kill switch is set to a falsy value
|
||
/// - the template rendered to an error (caller can flip the env
|
||
/// var to force fallback while debugging the template)
|
||
///
|
||
/// Failures are logged at `warn` so an operator running with
|
||
/// `RUST_LOG=neuron=debug` sees which path each request took.
|
||
fn build_prompt_for_request(
|
||
chat_template: Option<&str>,
|
||
request: &ChatCompletionRequest,
|
||
) -> String {
|
||
if !super::chat_template::chat_templates_enabled() {
|
||
return format_qwen3_prompt(&request.messages);
|
||
}
|
||
let Some(tmpl) = chat_template else {
|
||
return format_qwen3_prompt(&request.messages);
|
||
};
|
||
|
||
// Pull `chat_template_kwargs` and `tools` from the request's
|
||
// catch-all `extra` field. Both are optional; absent fields
|
||
// become `Value::Null`, which the renderer skips inserting
|
||
// into the Jinja context.
|
||
let kwargs = request
|
||
.extra
|
||
.get("chat_template_kwargs")
|
||
.cloned()
|
||
.unwrap_or(serde_json::Value::Null);
|
||
let tools = request
|
||
.extra
|
||
.get("tools")
|
||
.cloned()
|
||
.unwrap_or(serde_json::Value::Null);
|
||
|
||
match super::chat_template::render_chat_template(tmpl, &request.messages, &tools, &kwargs) {
|
||
Ok(prompt) => prompt,
|
||
Err(e) => {
|
||
tracing::warn!(
|
||
model = %request.model,
|
||
error = %format!("{e:#}"),
|
||
"chat_template render failed; falling back to format_qwen3_prompt"
|
||
);
|
||
format_qwen3_prompt(&request.messages)
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Vision metadata derived at model-load time. Stashed on
|
||
/// `LoadedModel` so the chat-completion hot path doesn't have to
|
||
/// re-parse `config.json` or reach across the worker thread to peek
|
||
/// at the loaded `ModelArch`.
|
||
#[derive(Debug, Default, Clone, Copy)]
|
||
struct VisionMeta {
|
||
has_vision: bool,
|
||
image_token_id: Option<u32>,
|
||
/// `patch_size × spatial_merge_size` — the divisor that turns a
|
||
/// resized pixel dimension into an LM-grid dimension. An image of
|
||
/// resized `(h, w)` emits `(h/factor) × (w/factor)` LM tokens (#14
|
||
/// dynamic resolution; was a constant 196 at the old fixed 448²).
|
||
/// `None` for text-only models.
|
||
image_grid_factor: Option<usize>,
|
||
}
|
||
|
||
impl VisionMeta {
|
||
/// Peek at `config.json` for vision-related fields. Returns the
|
||
/// default (no-vision) struct on any read/parse error — vision is
|
||
/// best-effort metadata; load can still succeed for text usage.
|
||
fn from_config_path(config_path: &std::path::Path) -> Self {
|
||
let Ok(text) = std::fs::read_to_string(config_path) else {
|
||
return Self::default();
|
||
};
|
||
let Ok(v) = serde_json::from_str::<serde_json::Value>(&text) else {
|
||
return Self::default();
|
||
};
|
||
let Some(vision_config) = v.get("vision_config") else {
|
||
return Self::default();
|
||
};
|
||
let patch_size = vision_config
|
||
.get("patch_size")
|
||
.and_then(|x| x.as_u64())
|
||
.unwrap_or(16) as usize;
|
||
let spatial_merge_size = vision_config
|
||
.get("spatial_merge_size")
|
||
.and_then(|x| x.as_u64())
|
||
.unwrap_or(2) as usize;
|
||
let image_token_id = v
|
||
.get("image_token_id")
|
||
.and_then(|x| x.as_u64())
|
||
.map(|n| n as u32);
|
||
// The pixel→LM-grid divisor. An image resized to (h, w) emits
|
||
// (h/factor) × (w/factor) LM tokens — computed per image at
|
||
// request time now that resolution is dynamic (#14).
|
||
let image_grid_factor = if patch_size > 0 && spatial_merge_size > 0 {
|
||
Some(patch_size * spatial_merge_size)
|
||
} else {
|
||
None
|
||
};
|
||
Self {
|
||
has_vision: true,
|
||
image_token_id,
|
||
image_grid_factor,
|
||
}
|
||
}
|
||
}
|
||
|
||
/// True iff any message in the request carries an `image_url`
|
||
/// content part. The Stage B routing decision in `chat_completion`
|
||
/// dispatches to the vision-aware worker job when this is true.
|
||
fn request_has_images(request: &ChatCompletionRequest) -> bool {
|
||
request.messages.iter().any(|m| {
|
||
matches!(&m.content, MessageContent::Parts(parts)
|
||
if parts.iter().any(|p|
|
||
p.get("type").and_then(|v| v.as_str()) == Some("image_url")))
|
||
})
|
||
}
|
||
|
||
/// Extract `image_url` content parts from a chat request and turn
|
||
/// each one into a preprocessed `ImageInput` ready for the device
|
||
/// worker. Stage B4.
|
||
///
|
||
/// Walks `request.messages`, looking for `MessageContent::Parts` and
|
||
/// pulling out entries whose `type == "image_url"`. Each is run
|
||
/// through `harness::preprocess::decode_data_uri` + `preprocess` with
|
||
/// the supplied `profile` (Stage B always uses
|
||
/// `PreprocessProfile::qwen3_6()` — fixed 448×448 — so every image
|
||
/// produces the same patch count; dynamic resolution per issue #14
|
||
/// would parameterise this).
|
||
///
|
||
/// Returns images in the order they appear in the request, which
|
||
/// matches the order the chat template emits `<|image_pad|>` tokens.
|
||
fn extract_images_from_request(
|
||
request: &ChatCompletionRequest,
|
||
profile: &super::preprocess::PreprocessProfile,
|
||
) -> anyhow::Result<Vec<super::device_worker::jobs::ImageInput>> {
|
||
let mut out = Vec::new();
|
||
for msg in &request.messages {
|
||
if let MessageContent::Parts(parts) = &msg.content {
|
||
for part in parts {
|
||
let kind = part.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||
if kind != "image_url" {
|
||
continue;
|
||
}
|
||
let url = part
|
||
.get("image_url")
|
||
.and_then(|v| v.get("url"))
|
||
.and_then(|v| v.as_str())
|
||
.ok_or_else(|| anyhow::anyhow!("image_url part missing url field"))?;
|
||
let (pixels, h, w) = super::preprocess::preprocess_data_uri(url, profile)
|
||
.with_context(|| format!("preprocess image #{}", out.len()))?;
|
||
out.push(super::device_worker::jobs::ImageInput {
|
||
pixels,
|
||
c: 3,
|
||
h: h as usize,
|
||
w: w as usize,
|
||
});
|
||
}
|
||
}
|
||
}
|
||
Ok(out)
|
||
}
|
||
|
||
/// Collect the raw `image_url.url` strings (data URIs) from a chat
/// request, in prompt order. The TP vision path (Stage C / TP-vision)
/// ships these verbatim to every rank, which each preprocess + encode
/// identically — so unlike `extract_images_from_request` (which
/// preprocesses on the leader for the single-GPU worker job) this
/// keeps the source form for replicated per-rank encoding.
///
/// An `image_url` part with no string `image_url.url` is skipped (the
/// return type is infallible) and the skip is logged at `warn`.
/// `extract_images_from_request` rejects the same input with an error;
/// the log lets any later mismatch between image placeholders and
/// collected images be traced back to the skipped part.
///
/// Cuda-gated: the only callers are the TP entry points, which compile
/// only under the `cuda` feature.
#[cfg(feature = "cuda")]
fn extract_image_data_uris(request: &ChatCompletionRequest) -> Vec<String> {
    let mut out = Vec::new();
    for msg in &request.messages {
        let MessageContent::Parts(parts) = &msg.content else {
            continue;
        };
        for part in parts {
            if part.get("type").and_then(|v| v.as_str()) != Some("image_url") {
                continue;
            }
            let url = part
                .get("image_url")
                .and_then(|v| v.get("url"))
                .and_then(|v| v.as_str());
            match url {
                Some(url) => out.push(url.to_string()),
                // Previously dropped silently; surface it so operators can
                // connect a downstream image-count error to its cause.
                None => tracing::warn!(
                    images_collected = out.len(),
                    "image_url part missing url field; skipping it"
                ),
            }
        }
    }
    out
}
|
||
|
||
/// Expand each occurrence of `image_token_id` in `input_ids` into
|
||
/// `patches_per_image[i]` copies (one expansion per image, in order).
|
||
/// Stage B4 helper.
|
||
///
|
||
/// The chat template emits a single `<|image_pad|>` per image; this
|
||
/// is what fits Qwen3-VL's template-then-runtime-expansion convention.
|
||
/// The runtime (us) is responsible for replacing each one with N
|
||
/// copies based on the corresponding image's patch count.
|
||
///
|
||
/// For Stage B fixed resolution every entry of `patches_per_image`
|
||
/// is the same constant (196 at 448×448). For dynamic resolution
|
||
/// (issue #14) each image gets its own value.
|
||
///
|
||
/// Errors if the number of `image_token_id` occurrences in `input_ids`
|
||
/// doesn't equal `patches_per_image.len()` — a mismatch means the
|
||
/// template emitted the wrong number of pad tokens (operator-visible
|
||
/// template bug, not a runtime error).
|
||
fn expand_image_pad_tokens(
|
||
input_ids: &[u32],
|
||
image_token_id: u32,
|
||
patches_per_image: &[usize],
|
||
) -> anyhow::Result<Vec<u32>> {
|
||
let occurrences = input_ids.iter().filter(|&&t| t == image_token_id).count();
|
||
if occurrences != patches_per_image.len() {
|
||
anyhow::bail!(
|
||
"expand_image_pad_tokens: prompt has {occurrences} image_token_id occurrences but \
|
||
{} images were preprocessed — chat template emitted the wrong number of pad tokens",
|
||
patches_per_image.len()
|
||
);
|
||
}
|
||
let total_extra: usize = patches_per_image.iter().map(|n| n.saturating_sub(1)).sum();
|
||
let mut out = Vec::with_capacity(input_ids.len() + total_extra);
|
||
let mut img_idx = 0;
|
||
for &t in input_ids {
|
||
if t == image_token_id {
|
||
let n = patches_per_image[img_idx];
|
||
for _ in 0..n {
|
||
out.push(image_token_id);
|
||
}
|
||
img_idx += 1;
|
||
} else {
|
||
out.push(t);
|
||
}
|
||
}
|
||
Ok(out)
|
||
}
|
||
|
||
/// Apply the Qwen3 chat template:
|
||
///
|
||
/// ```text
|
||
/// <|im_start|>{role}\n{content}<|im_end|>\n
|
||
/// ...
|
||
/// <|im_start|>assistant\n
|
||
/// ```
|
||
///
|
||
/// The trailing `<|im_start|>assistant\n` cues the model to begin a turn.
|
||
/// Non-text content parts (vision blocks) are joined as text only; full
|
||
/// multimodal handling is out of scope for Stage 3.
|
||
fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
|
||
let mut prompt = String::new();
|
||
for msg in messages {
|
||
let content = match &msg.content {
|
||
MessageContent::Text(s) => s.clone(),
|
||
MessageContent::Parts(parts) => parts
|
||
.iter()
|
||
.filter_map(|p| p.get("text").and_then(|v| v.as_str()))
|
||
.collect::<Vec<_>>()
|
||
.join(""),
|
||
};
|
||
prompt.push_str("<|im_start|>");
|
||
prompt.push_str(&msg.role);
|
||
prompt.push('\n');
|
||
prompt.push_str(&content);
|
||
prompt.push_str("<|im_end|>\n");
|
||
}
|
||
prompt.push_str("<|im_start|>assistant\n");
|
||
prompt
|
||
}
|
||
|
||
/// Vision-aware analogue of `run_inference_via_worker`. Stage B5.
///
/// Single-shot prefill carrying the pre-expanded prompt plus the image
/// payloads. The worker encodes each image through the vision tower,
/// splices the resulting embeddings at `image_token_id` positions, and
/// returns the last-position logits. Decode steps thereafter follow the
/// text-only `forward_logits` path — the KV cache holds the
/// image-conditioned hidden states from prefill, so no further splicing
/// is needed.
///
/// Prefill is deliberately not chunked for vision requests: the whole
/// expanded prompt goes through one forward. That choice was sized
/// against the Stage B fixed-resolution budget (196 image tokens per
/// image at 448×448 plus typical text). NOTE(review): with dynamic
/// resolution (#14) per-image token counts depend on the resized image
/// size — re-check activation memory for large images.
/// Long-prompt-with-images chunking is a Stage D follow-up.
///
/// Returns the generated token ids (EOS excluded) and the OpenAI
/// `finish_reason`: `"stop"` on EOS, `"length"` when `max_new` is hit.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
async fn run_inference_with_images_via_worker(
    worker: &super::device_worker::DeviceWorkerHandle,
    handle: super::device_worker::ArchHandle,
    prompt_tokens: &[u32],
    images: Vec<super::device_worker::jobs::ImageInput>,
    image_token_id: u32,
    max_new: usize,
    temperature: f64,
    top_p: Option<f64>,
    seed: u64,
    eos_id: Option<u32>,
) -> Result<(Vec<u32>, String)> {
    let mut logits_processor = {
        let sampling = if temperature <= 0.0 {
            Sampling::ArgMax
        } else {
            match top_p {
                Some(p) => Sampling::TopP { p, temperature },
                None => Sampling::All { temperature },
            }
        };
        LogitsProcessor::from_sampling(seed, sampling)
    };

    let mut generated: Vec<u32> = Vec::new();
    let prompt_len = prompt_tokens.len();

    worker
        .clear_kv_cache(handle)
        .await
        .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;

    // Single-shot prefill with image splicing.
    let logits_vec = worker
        .forward_logits_with_images(handle, prompt_tokens.to_vec(), 0, images, image_token_id)
        .await
        .map_err(|e| anyhow::anyhow!("forward_logits_with_images: {e}"))?;
    let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
    let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
        Ok(t) => t,
        Err(e) => {
            let health = logits_health_slice(&logits_vec);
            tracing::warn!(
                ?health,
                "chat_completion (worker, vision): prefill sample failed; logits unhealthy"
            );
            return Err(e);
        }
    };

    if Some(next_token) == eos_id {
        return Ok((generated, "stop".into()));
    }
    generated.push(next_token);

    for index in 0..max_new.saturating_sub(1) {
        let logits_vec = worker
            .forward_logits(handle, vec![next_token], prompt_len + index)
            .await
            .map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
        let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
        next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
            Ok(t) => t,
            Err(e) => {
                let health = logits_health_slice(&logits_vec);
                tracing::warn!(
                    step = index,
                    ?health,
                    "chat_completion (worker, vision): decode sample failed; logits unhealthy"
                );
                return Err(e);
            }
        };
        if Some(next_token) == eos_id {
            return Ok((generated, "stop".into()));
        }
        generated.push(next_token);
    }
    Ok((generated, "length".into()))
}

/// Run the full single-GPU inference loop via the device worker.
///
/// Mirrors `run_inference`'s logic but routes each forward step
/// through `worker.forward_logits()` (returns CPU-side `Vec<f32>`)
/// and runs `apply_repeat_penalty` + sampling on a CPU candle tensor.
/// The device-resident logits tensor never escapes the worker thread.
///
/// The prompt is prefilled via `chunked_prefill_via_worker`, so
/// activation memory is bounded per chunk rather than scaling with
/// prompt length.
///
/// Used by the CUDA path of `chat_completion`. The CPU path keeps
/// `run_inference` (spawn_blocking against `Arc<Mutex<ModelArch>>`)
/// because there's no CUDA context to own and the worker indirection
/// would only add channel overhead with no diagnostic benefit.
///
/// Returns the generated token ids (EOS excluded) and the OpenAI
/// `finish_reason`: `"stop"` on EOS, `"length"` when `max_new` is hit.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
async fn run_inference_via_worker(
    worker: &super::device_worker::DeviceWorkerHandle,
    handle: super::device_worker::ArchHandle,
    prompt_tokens: &[u32],
    max_new: usize,
    temperature: f64,
    top_p: Option<f64>,
    seed: u64,
    eos_id: Option<u32>,
) -> Result<(Vec<u32>, String)> {
    let mut logits_processor = {
        let sampling = if temperature <= 0.0 {
            Sampling::ArgMax
        } else {
            match top_p {
                Some(p) => Sampling::TopP { p, temperature },
                None => Sampling::All { temperature },
            }
        };
        LogitsProcessor::from_sampling(seed, sampling)
    };

    let mut generated: Vec<u32> = Vec::new();
    let prompt_len = prompt_tokens.len();

    worker
        .clear_kv_cache(handle)
        .await
        .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;

    // Prefill the prompt in `prefill_chunk_tokens()`-sized chunks so
    // activation memory is bounded per step rather than scaling with
    // prompt length. The KV cache accumulates across chunks; we keep
    // only the final chunk's logits for sampling the first generated
    // token.
    let logits_vec = chunked_prefill_via_worker(worker, handle, prompt_tokens).await?;
    let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
    let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
        Ok(t) => t,
        Err(e) => {
            let health = logits_health_slice(&logits_vec);
            tracing::warn!(
                ?health,
                "chat_completion (worker): prefill sample failed; logits unhealthy"
            );
            return Err(e);
        }
    };

    if Some(next_token) == eos_id {
        return Ok((generated, "stop".into()));
    }
    generated.push(next_token);

    for index in 0..max_new.saturating_sub(1) {
        let logits_vec = worker
            .forward_logits(handle, vec![next_token], prompt_len + index)
            .await
            .map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
        let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
        next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
            Ok(t) => t,
            Err(e) => {
                let health = logits_health_slice(&logits_vec);
                tracing::warn!(
                    step = index,
                    ?health,
                    "chat_completion (worker): decode sample failed; logits unhealthy"
                );
                return Err(e);
            }
        };
        if Some(next_token) == eos_id {
            return Ok((generated, "stop".into()));
        }
        generated.push(next_token);
    }

    Ok((generated, "length".into()))
}
|
||
|
||
/// Streaming counterpart of `run_inference_via_worker`. Emits
/// `InferenceEvent`s through `tx` as tokens are generated and routes
/// every forward step through the device worker. Same per-step
/// CPU-side sampling discipline — no device tensor escapes the worker
/// thread.
///
/// `images` carries the Stage C vision payload. When `Some`, prefill
/// is a single-shot `forward_logits_with_images` that splices image
/// embeddings at `image_token_id` positions (same contract as the
/// non-streaming `run_inference_with_images_via_worker`); image
/// embeddings are prefill-only, so every decode step below takes the
/// plain `forward_logits` path regardless. When `None`, prefill is
/// chunked (`chunked_prefill_via_worker`) to bound activation memory
/// — the original text-only behaviour. The decode loop and the
/// `route_token!` reasoning/tool-call state machine are shared across
/// both prefill shapes, so there's exactly one copy to maintain.
///
/// Termination:
/// - EOS sampled (at prefill or at any decode step): no further forward
///   steps run, and a `Finish { reason: Stop }` event is sent.
/// - `max_new` reached: a `Finish { reason: Length }` event is sent.
/// - a send on `tx` (or `emit_delta`) fails: returns early without
///   sending `Finish`.
///
/// Returns the OpenAI `finish_reason` string for the run.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
async fn stream_inference_via_worker(
    worker: Arc<super::device_worker::DeviceWorkerHandle>,
    handle: super::device_worker::ArchHandle,
    tokenizer: Tokenizer,
    prompt_tokens: Vec<u32>,
    images: Option<(Vec<super::device_worker::jobs::ImageInput>, u32)>,
    max_new: usize,
    temperature: f64,
    top_p: Option<f64>,
    seed: u64,
    eos_id: Option<u32>,
    reasoning_tokens: Option<ReasoningTokenPair>,
    tool_call_tokens: Option<ToolCallTokenPair>,
    tx: mpsc::Sender<InferenceEvent>,
) -> Result<String> {
    let mut logits_processor = {
        let sampling = if temperature <= 0.0 {
            Sampling::ArgMax
        } else {
            match top_p {
                Some(p) => Sampling::TopP { p, temperature },
                None => Sampling::All { temperature },
            }
        };
        LogitsProcessor::from_sampling(seed, sampling)
    };

    let mut all_tokens: Vec<u32> = Vec::new();
    // Incremental detokenizer. Replaces the old "decode cumulative
    // tokens, byte-slice the delta against a stored prefix" pattern
    // that panicked when BPE byte-fallback split a multi-byte UTF-8
    // sequence (e.g. an emoji) across tokens. `step` returns
    // `Ok(Some(delta))` only when the trailing bytes form a complete
    // codepoint; `Ok(None)` while it's buffering an incomplete one.
    let mut decode_stream = tokenizer.decode_stream(true);
    let prompt_len = prompt_tokens.len();
    let mut finish_reason = FinishReason::Length;
    // Reasoning + tool-call state machines — see
    // `run_inference_streaming` for the why. Markers never reach
    // `decode_stream`; they toggle state. Tool-call content
    // accumulates into `tool_call_buf` until the close marker.
    let mut in_reasoning = false;
    let mut in_tool_call = false;
    let mut tool_call_buf = String::new();
    let mut tool_call_idx: usize = 0;

    worker
        .clear_kv_cache(handle)
        .await
        .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;

    // Prefill. Vision-bearing requests (`images = Some`) do a
    // single-shot prefill that splices the image embeddings; text-only
    // requests use chunked prefill (see `chunked_prefill_via_worker`)
    // to bound activation memory. `prompt_tokens` isn't needed after
    // this step (decode offsets use `prompt_len`), so the vision arm
    // moves it into the job instead of cloning it.
    let logits_vec = match images {
        Some((imgs, image_token_id)) => worker
            .forward_logits_with_images(handle, prompt_tokens, 0, imgs, image_token_id)
            .await
            .map_err(|e| anyhow::anyhow!("forward_logits_with_images: {e}"))?,
        None => chunked_prefill_via_worker(&*worker, handle, &prompt_tokens).await?,
    };
    let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
    let mut next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
        Ok(t) => t,
        Err(e) => {
            let health = logits_health_slice(&logits_vec);
            tracing::warn!(
                ?health,
                "chat_completion (stream/worker): prefill sample failed; logits unhealthy"
            );
            return Err(e);
        }
    };

    // Per-token routing, shared by the prefill-sample tail and the
    // decode loop. `tokenizers::DecodeStream` carries five generic
    // parameters (`M, N, PT, PP, D`), which makes naming its type in a
    // helper signature painful, so the body is a macro that expands
    // inline with `decode_stream`'s concrete type inferred at the call
    // site. The body contains `.await`s, so it can only expand inside
    // this `async fn`.
    //
    // The macro evaluates to a `bool`: `true` while the consumer is
    // still connected, `false` once a send on `tx` (or `emit_delta`)
    // failed. Early exits from the routing logic use `break 'route` out
    // of a labelled block so that the trailing `consumer_alive` value is
    // always produced; call sites return when it is `false`.
    macro_rules! route_token {
        ($next_token:expr) => {{
            let nt = $next_token;
            all_tokens.push(nt);
            let mut consumer_alive = true;
            'route: {
                match handle_tool_call_marker(
                    nt,
                    tool_call_tokens.as_ref(),
                    &mut in_tool_call,
                    &mut tool_call_buf,
                ) {
                    ToolCallMarker::Enter => break 'route,
                    ToolCallMarker::Exit { buffer } => {
                        let idx = tool_call_idx;
                        tool_call_idx += 1;
                        match parse_tool_call_body(&buffer, idx) {
                            Some((id, name, arguments)) => {
                                if tx
                                    .send(InferenceEvent::ToolCall {
                                        index: idx,
                                        id,
                                        name,
                                        arguments,
                                    })
                                    .await
                                    .is_err()
                                {
                                    consumer_alive = false;
                                }
                            }
                            None => {
                                let open = tool_call_tokens
                                    .as_ref()
                                    .map(|p| p.open_text.as_str())
                                    .unwrap_or("<tool_call>");
                                let close = tool_call_tokens
                                    .as_ref()
                                    .map(|p| p.close_text.as_str())
                                    .unwrap_or("</tool_call>");
                                let raw = format!("{open}{buffer}{close}");
                                if !emit_delta(&raw, &tx, in_reasoning).await {
                                    consumer_alive = false;
                                }
                            }
                        }
                        break 'route;
                    }
                    ToolCallMarker::None => {}
                }
                if in_tool_call {
                    match decode_stream.step(nt) {
                        Ok(Some(s)) => tool_call_buf.push_str(&s),
                        Ok(None) => {}
                        Err(e) => tracing::warn!(
                            error = %e,
                            "decode_stream step failed (in tool_call)"
                        ),
                    }
                    break 'route;
                }
                if handle_reasoning_marker(nt, reasoning_tokens.as_ref(), &mut in_reasoning) {
                    break 'route;
                }
                match decode_stream.step(nt) {
                    Ok(Some(delta)) => {
                        if !emit_delta(&delta, &tx, in_reasoning).await {
                            consumer_alive = false;
                        }
                    }
                    Ok(None) => {}
                    Err(e) => tracing::warn!(error = %e, "decode_stream step failed"),
                }
            }
            consumer_alive
        }};
    }

    // An EOS sampled straight from prefill ends generation: skip the
    // decode loop entirely rather than feeding the EOS back into the
    // model and streaming whatever it produces afterwards.
    let stopped_at_prefill = Some(next_token) == eos_id;
    if stopped_at_prefill {
        finish_reason = FinishReason::Stop;
    } else if !route_token!(next_token) {
        return Ok(finish_reason.as_openai_str().to_string());
    }

    let decode_steps = if stopped_at_prefill {
        0
    } else {
        max_new.saturating_sub(1)
    };
    for index in 0..decode_steps {
        let logits_vec = worker
            .forward_logits(handle, vec![next_token], prompt_len + index)
            .await
            .map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
        let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
        next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
            Ok(t) => t,
            Err(e) => {
                let health = logits_health_slice(&logits_vec);
                tracing::warn!(
                    step = index,
                    ?health,
                    "chat_completion (stream/worker): decode sample failed; logits unhealthy"
                );
                return Err(e);
            }
        };
        if Some(next_token) == eos_id {
            finish_reason = FinishReason::Stop;
            break;
        }
        if !route_token!(next_token) {
            return Ok(finish_reason.as_openai_str().to_string());
        }
    }

    // Terminal Finish event. The wire projector turns this into a
    // format-specific final chunk (`finish_reason: "stop"` on
    // OpenAI chat, `response.completed` on Responses).
    let _ = tx
        .send(InferenceEvent::Finish {
            reason: finish_reason,
        })
        .await;

    Ok(finish_reason.as_openai_str().to_string())
}
|
||
|
||
#[allow(clippy::too_many_arguments)]
|
||
fn run_inference(
|
||
arch: &mut ModelArch,
|
||
device: &Device,
|
||
prompt_tokens: &[u32],
|
||
max_new: usize,
|
||
temperature: f64,
|
||
top_p: Option<f64>,
|
||
seed: u64,
|
||
eos_id: Option<u32>,
|
||
) -> Result<(Vec<u32>, String)> {
|
||
let mut logits_processor = {
|
||
let sampling = if temperature <= 0.0 {
|
||
Sampling::ArgMax
|
||
} else {
|
||
match top_p {
|
||
Some(p) => Sampling::TopP { p, temperature },
|
||
None => Sampling::All { temperature },
|
||
}
|
||
};
|
||
LogitsProcessor::from_sampling(seed, sampling)
|
||
};
|
||
|
||
let mut generated: Vec<u32> = Vec::new();
|
||
|
||
arch.clear_kv_cache()?;
|
||
let logits = chunked_prefill_local(arch, device, prompt_tokens)?;
|
||
let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;
|
||
|
||
if Some(next_token) == eos_id {
|
||
return Ok((generated, "stop".into()));
|
||
}
|
||
generated.push(next_token);
|
||
|
||
for index in 0..max_new.saturating_sub(1) {
|
||
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
|
||
let logits = arch.forward(&input, prompt_tokens.len() + index)?;
|
||
next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)?;
|
||
if Some(next_token) == eos_id {
|
||
return Ok((generated, "stop".into()));
|
||
}
|
||
generated.push(next_token);
|
||
}
|
||
|
||
Ok((generated, "length".into()))
|
||
}
|
||
|
||
/// Streaming counterpart to `run_inference`. Emits chunks via `tx` as
|
||
/// tokens are generated and exits on EOS, max_new, or receiver drop.
|
||
///
|
||
/// Detokenization tracks the cumulative decoded prefix so each chunk's
|
||
/// `content` delta is the substring appended since the last chunk —
|
||
/// safe across BPE byte-fallback boundaries.
|
||
#[allow(clippy::too_many_arguments)]
|
||
fn run_inference_streaming(
|
||
arch: &mut ModelArch,
|
||
device: &Device,
|
||
tokenizer: &Tokenizer,
|
||
prompt_tokens: &[u32],
|
||
max_new: usize,
|
||
temperature: f64,
|
||
top_p: Option<f64>,
|
||
seed: u64,
|
||
eos_id: Option<u32>,
|
||
reasoning_tokens: Option<&ReasoningTokenPair>,
|
||
tool_call_tokens: Option<&ToolCallTokenPair>,
|
||
tx: &mpsc::Sender<InferenceEvent>,
|
||
) -> Result<()> {
|
||
let mut logits_processor = {
|
||
let sampling = if temperature <= 0.0 {
|
||
Sampling::ArgMax
|
||
} else {
|
||
match top_p {
|
||
Some(p) => Sampling::TopP { p, temperature },
|
||
None => Sampling::All { temperature },
|
||
}
|
||
};
|
||
LogitsProcessor::from_sampling(seed, sampling)
|
||
};
|
||
|
||
let mut all_tokens: Vec<u32> = Vec::new();
|
||
// Incremental detokenizer. See `stream_inference_via_worker` for
|
||
// the same reasoning — `tokenizer.decode_stream(true).step(id)`
|
||
// buffers incomplete multi-byte UTF-8 sequences across token
|
||
// boundaries and only emits when a clean codepoint completes.
|
||
let mut decode_stream = tokenizer.decode_stream(true);
|
||
let mut finish_reason = FinishReason::Length;
|
||
// Reasoning marker state machine. Flips on
|
||
// `next_token == reasoning_tokens.open_id`, off on
|
||
// `.close_id`. The marker tokens themselves never feed into
|
||
// `decode_stream` — they aren't part of any visible output,
|
||
// they exist purely as state transitions.
|
||
let mut in_reasoning = false;
|
||
// Tool-call state. While `in_tool_call`, content tokens get
|
||
// accumulated into `tool_call_buf` instead of emitted; on the
|
||
// close marker we parse the buffer and emit a structured
|
||
// ToolCall event (or fall back to passing the raw text
|
||
// through if the buffer doesn't parse).
|
||
let mut in_tool_call = false;
|
||
let mut tool_call_buf = String::new();
|
||
let mut tool_call_idx: usize = 0;
|
||
|
||
arch.clear_kv_cache()?;
|
||
let logits = chunked_prefill_local(arch, device, prompt_tokens)?;
|
||
let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
|
||
|
||
// Per-token routing block, used at both the prefill-sample
|
||
// tail and the decode loop. Macros are ugly but Rust's
|
||
// closure inference fights `&mut DecodeStream<'_>` capture +
|
||
// mutable borrows of the surrounding `tool_call_buf` /
|
||
// `in_reasoning` / etc. Inline the body via a macro and live
|
||
// with the duplication of the call sites instead.
|
||
macro_rules! route_token {
|
||
($next_token:expr) => {{
|
||
let nt = $next_token;
|
||
all_tokens.push(nt);
|
||
match handle_tool_call_marker(nt, tool_call_tokens, &mut in_tool_call, &mut tool_call_buf) {
|
||
ToolCallMarker::Enter => {}
|
||
ToolCallMarker::Exit { buffer } => {
|
||
let idx = tool_call_idx;
|
||
tool_call_idx += 1;
|
||
match parse_tool_call_body(&buffer, idx) {
|
||
Some((id, name, arguments)) => {
|
||
if tx
|
||
.blocking_send(InferenceEvent::ToolCall {
|
||
index: idx,
|
||
id,
|
||
name,
|
||
arguments,
|
||
})
|
||
.is_err()
|
||
{
|
||
return Ok(());
|
||
}
|
||
}
|
||
None => {
|
||
// Malformed JSON — pass the block
|
||
// through as text so consumer parsers
|
||
// can try their own repair.
|
||
let open = tool_call_tokens
|
||
.map(|p| p.open_text.as_str())
|
||
.unwrap_or("<tool_call>");
|
||
let close = tool_call_tokens
|
||
.map(|p| p.close_text.as_str())
|
||
.unwrap_or("</tool_call>");
|
||
let raw = format!("{open}{buffer}{close}");
|
||
if !emit_delta_blocking(&raw, tx, in_reasoning) {
|
||
return Ok(());
|
||
}
|
||
}
|
||
}
|
||
}
|
||
ToolCallMarker::None => {
|
||
if in_tool_call {
|
||
// Buffer JSON content without emitting.
|
||
match decode_stream.step(nt) {
|
||
Ok(Some(s)) => tool_call_buf.push_str(&s),
|
||
Ok(None) => {}
|
||
Err(e) => tracing::warn!(
|
||
error = %e,
|
||
"stream: decode_stream step failed (in tool_call)"
|
||
),
|
||
}
|
||
} else if handle_reasoning_marker(nt, reasoning_tokens, &mut in_reasoning) {
|
||
// marker — nothing to emit
|
||
} else {
|
||
match decode_stream.step(nt) {
|
||
Ok(Some(delta)) => {
|
||
if !emit_delta_blocking(&delta, tx, in_reasoning) {
|
||
return Ok(());
|
||
}
|
||
}
|
||
Ok(None) => {}
|
||
Err(e) => tracing::warn!(
|
||
error = %e,
|
||
"stream: decode_stream step failed"
|
||
),
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}};
|
||
}
|
||
|
||
if Some(next_token) == eos_id {
|
||
finish_reason = FinishReason::Stop;
|
||
} else {
|
||
route_token!(next_token);
|
||
}
|
||
|
||
for index in 0..max_new.saturating_sub(1) {
|
||
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
|
||
let logits = arch.forward(&input, prompt_tokens.len() + index)?;
|
||
next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
|
||
if Some(next_token) == eos_id {
|
||
finish_reason = FinishReason::Stop;
|
||
break;
|
||
}
|
||
route_token!(next_token);
|
||
}
|
||
|
||
let _ = tx.blocking_send(InferenceEvent::Finish {
|
||
reason: finish_reason,
|
||
});
|
||
Ok(())
|
||
}
|
||
|
||
/// Whole seconds elapsed since the Unix epoch, or `0` when the system
/// clock reads earlier than the epoch.
fn unix_now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
|
||
|
||
/// Nanoseconds elapsed since the Unix epoch, truncated to `u64`, or `0`
/// when the system clock reads earlier than the epoch.
///
/// Despite the name, this is the *total* elapsed nanoseconds
/// (`Duration::as_nanos`), not the sub-second remainder
/// (`Duration::subsec_nanos`). The `u128 → u64` cast keeps the low 64
/// bits; the full value fits until roughly the year 2554.
fn unix_subsec_nanos() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_nanos() as u64)
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// A stock Qwen3 dense config must clear the dispatch gate.
    #[test]
    fn check_dense_config_accepts_qwen3() {
        let config_json = r#"{
            "model_type": "qwen3",
            "vocab_size": 151936,
            "architectures": ["Qwen3ForCausalLM"]
        }"#;
        check_dense_config_supported(config_json, "Qwen/Qwen3-1.7B").expect("qwen3 should pass");
    }

    /// An unknown `model_type` is rejected.
    ///
    /// The error must name the offending type, echo the model id, and list
    /// the supported set.
    #[test]
    fn check_dense_config_rejects_unsupported_arch_with_clear_message() {
        // The type is deliberately fictional so this test keeps its meaning
        // as the supported set grows. qwen3_5 was the real-world trigger, but
        // it has since joined the supported set as a Stage 8c scaffold.
        let config_json = r#"{
            "model_type": "fictional_arch_99",
            "architectures": ["FictionalArch99ForCausalLM"]
        }"#;
        let rendered = check_dense_config_supported(config_json, "Fake/Model-99")
            .expect_err("fictional_arch_99 should be rejected")
            .to_string();

        let expectations = [
            (
                "unsupported model_type 'fictional_arch_99'",
                "message should name the rejected type",
            ),
            ("Fake/Model-99", "message should echo the model id"),
            ("qwen3", "message should list the supported set"),
        ];
        for (needle, why) in expectations {
            assert!(rendered.contains(needle), "{why}: {rendered}");
        }
    }

    /// The dispatch gate must let qwen3_5 through.
    ///
    /// The Stage 8c scaffold puts qwen3_5 in the supported set. Forward still
    /// bails, but that is covered by the architecture module's own tests.
    #[test]
    fn check_dense_config_accepts_qwen3_5() {
        let config_json = r#"{
            "model_type": "qwen3_5",
            "architectures": ["Qwen3_5ForConditionalGeneration"],
            "text_config": {"hidden_size": 5120}
        }"#;
        let verdict = check_dense_config_supported(config_json, "Qwen/Qwen3.6-27B");
        verdict.expect("qwen3_5 should be in the supported set as of Stage 8c scaffold");
    }

    /// A config with no `model_type` is rejected with a message that says so.
    #[test]
    fn check_dense_config_rejects_missing_model_type() {
        let rendered = check_dense_config_supported(r#"{ "vocab_size": 1234 }"#, "anon/no-type")
            .expect_err("missing model_type should be rejected")
            .to_string();
        assert!(
            rendered.contains("missing `model_type`"),
            "message should call out the missing field"
        );
    }

    /// Non-JSON input is rejected, and the error points at config.json.
    #[test]
    fn check_dense_config_rejects_invalid_json() {
        let err = check_dense_config_supported("not json", "anon/bad-json")
            .expect_err("malformed JSON should be rejected");
        // Alternate (`{:#}`) formatting; for anyhow-style errors this also
        // renders any attached context, not just the outermost message.
        let chain = format!("{err:#}");
        assert!(chain.contains("config.json"), "message should mention config.json");
    }

    /// Errors known to originate off-device must not poison the model.
    #[test]
    fn is_device_fault_rejects_known_non_device_errors() {
        let benign = [
            // Shape mismatches happen pre-kernel; the device is healthy.
            "prefill chunk 0/9: shape mismatch in broadcast_add, lhs: [1, 32, 512, 1024], rhs: [1, 1, 512, 512]",
            // NaN logits are a CPU-side numerical problem, not a driver one.
            "prefill sample failed; logits unhealthy nan: 248320/248320",
            // Tokenizer and detokenizer failures are pure host work.
            "tokenize: invalid utf-8 sequence",
            "detokenize: byte fallback failed",
            // A missing handle is a dispatch-side bug, not a device fault.
            "ForwardLogits: no model for handle 42",
            // DecodeStream errors during SSE are not device faults.
            "decode_stream step failed: invalid prefix",
        ];
        for message in benign {
            assert!(!is_device_fault(message), "must not poison: {message}");
        }
    }

    /// Unrecognised errors and genuine driver faults both poison.
    #[test]
    fn is_device_fault_defaults_to_poisoning() {
        let poisoning = [
            // Unknown errors default to "poison". Over-rejecting is better
            // than continuing to serve from a corrupted context.
            "some unrecognised candle error",
            // Real driver faults, in the shape of cudarc's `DriverError`
            // Display output.
            "leader forward failed: DriverError(CUDA_ERROR_OUT_OF_MEMORY, \"out of memory\")",
            "DriverError(CUDA_ERROR_ILLEGAL_ADDRESS, \"an illegal memory access was encountered\")",
        ];
        for message in poisoning {
            assert!(is_device_fault(message), "must poison: {message}");
        }
    }

    /// Phase 1 of plan-source-aware-loader.
    ///
    /// Each configured scheme must resolve to its own endpoint and cache.
    /// A scheme the operator has not configured must be rejected with an
    /// actionable error.
    #[test]
    fn hf_api_for_routes_per_scheme() {
        use crate::config::{CandleHarnessConfig, SourceConfig};
        use std::collections::HashMap;
        use std::path::PathBuf;

        let source = |endpoint: &'static str, cache: &'static str| SourceConfig {
            endpoint: endpoint.into(),
            auth_env: None,
            cache_dir: Some(PathBuf::from(cache)),
        };
        let sources = HashMap::from([
            (
                "huggingface".to_string(),
                source("https://huggingface.example.org", "/tmp/hf-cache"),
            ),
            (
                "helexa".to_string(),
                source("https://registry.helexa.example.ai", "/tmp/helexa-cache"),
            ),
        ]);
        let cfg = CandleHarnessConfig {
            sources,
            default_source: Some("huggingface".into()),
            ..Default::default()
        };
        let harness = CandleHarness::new("http://localhost:13131".into(), &cfg);

        // Every configured scheme builds cleanly.
        for scheme in ["huggingface", "helexa"] {
            harness
                .hf_api_for(scheme)
                .unwrap_or_else(|err| panic!("{scheme} scheme should build: {err:?}"));
        }

        // An unknown scheme fails. The message must name the configured set
        // so the operator can act on it.
        let err = harness
            .hf_api_for("does-not-exist")
            .expect_err("unknown scheme should error");
        let msg = format!("{err:#}");
        let mentions_all = ["does-not-exist", "huggingface", "helexa"]
            .iter()
            .all(|&needle| msg.contains(needle));
        assert!(mentions_all, "error must list configured schemes: {msg}");

        assert_eq!(harness.default_source_scheme(), "huggingface");
    }

    /// Legacy `hf_cache`-only configs still get a working huggingface source.
    ///
    /// With no `sources` table, a `huggingface` source pointed at HF is
    /// synthesised and becomes the default.
    #[test]
    fn hf_api_for_synthesises_huggingface_from_legacy_hf_cache() {
        use crate::config::CandleHarnessConfig;

        let legacy_cfg = CandleHarnessConfig {
            hf_cache: Some(std::path::PathBuf::from("/archive3/llm-cache")),
            ..Default::default()
        };
        let harness = CandleHarness::new("http://localhost:13131".into(), &legacy_cfg);
        harness
            .hf_api_for("huggingface")
            .expect("synth huggingface source should build");
        assert_eq!(harness.default_source_scheme(), "huggingface");
    }

    /// A single image pad token is expanded to one token per patch.
    #[test]
    fn expand_image_pad_replaces_single_token_with_n_copies() {
        // The chat template emits [vision_start, image_pad, vision_end] for
        // each image. With 3 patches for the image, expansion must give
        // [vision_start, pad, pad, pad, vision_end].
        let (vision_start, image_pad, vision_end) = (248053_u32, 248056_u32, 248054_u32);
        let prompt = vec![1, vision_start, image_pad, vision_end, 2];
        let expanded = expand_image_pad_tokens(&prompt, image_pad, &[3]).unwrap();
        assert_eq!(
            expanded,
            vec![1, vision_start, image_pad, image_pad, image_pad, vision_end, 2]
        );
    }

    /// Each pad token is expanded by its own image's patch count.
    #[test]
    fn expand_image_pad_handles_multiple_images() {
        let image_pad = 248056_u32;
        // Two images in one prompt: the first spans 2 patches, the second 3.
        let prompt = vec![image_pad, 99, image_pad];
        let expanded = expand_image_pad_tokens(&prompt, image_pad, &[2, 3]).unwrap();
        assert_eq!(expanded, vec![image_pad, image_pad, 99, image_pad, image_pad, image_pad]);
    }

    /// Supplying more images than the prompt has pad tokens is an error.
    #[test]
    fn expand_image_pad_errors_on_count_mismatch() {
        let image_pad = 248056_u32;
        // The prompt carries 2 pad tokens, but the caller supplies 3 images.
        let prompt = vec![image_pad, 99, image_pad];
        let failure = expand_image_pad_tokens(&prompt, image_pad, &[2, 3, 4]).unwrap_err();
        assert!(format!("{failure:#}").contains("emitted the wrong number"));
    }

    /// With no images, the prompt is returned unchanged.
    #[test]
    fn expand_image_pad_passes_through_when_no_images() {
        let image_pad = 248056_u32;
        let prompt = vec![1, 2, 3];
        let expanded = expand_image_pad_tokens(&prompt, image_pad, &[]).unwrap();
        assert_eq!(expanded, prompt);
    }

    /// `request_has_images` must distinguish three request shapes.
    ///
    /// This predicate routes both the non-streaming (`chat_completion`) and
    /// streaming (`inference_stream`, Stage C1) paths to the vision-aware
    /// prefill. The three shapes are:
    /// - plain-string content,
    /// - a text-only content-parts array,
    /// - a parts array carrying an `image_url`.
    #[test]
    fn request_has_images_detects_image_url_parts() {
        let parse = |body: serde_json::Value| -> ChatCompletionRequest {
            serde_json::from_value(body).unwrap()
        };

        let plain_text = parse(serde_json::json!({
            "model": "m",
            "messages": [{"role": "user", "content": "hello"}],
        }));
        assert!(!request_has_images(&plain_text));

        let text_parts_only = parse(serde_json::json!({
            "model": "m",
            "messages": [{"role": "user", "content": [
                {"type": "text", "text": "hello"}
            ]}],
        }));
        assert!(!request_has_images(&text_parts_only));

        let with_image = parse(serde_json::json!({
            "model": "m",
            "messages": [{"role": "user", "content": [
                {"type": "text", "text": "what is this?"},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAA="}}
            ]}],
        }));
        assert!(request_has_images(&with_image));
    }

    /// The vision pre-flight guard rejects prefills that cannot fit.
    ///
    /// A prefill whose estimated footprint exceeds free VRAM fails cleanly
    /// instead of OOM-hanging the TP collective. A prefill that fits is
    /// admitted. The check is skipped on the CPU sentinel.
    #[test]
    fn vision_prefill_guard_behaviour() {
        // vram_free_mb == 0 is the CPU sentinel, which is always allowed.
        assert!(validate_vision_prefill(10_000_000, 0).is_ok());

        // A clearly oversized prompt against tiny free VRAM is rejected for
        // any non-degenerate config (default: 2000 base + 500 per 1k tokens).
        let verdict = validate_vision_prefill(10_000_000, 50);
        assert!(matches!(verdict, Err(InferenceError::InsufficientVram { .. })));

        // The 12,960-token agent-0 prompt that OOM'd single-shot must be
        // admitted at ~12 GB free. The default estimate is
        // 2000 + 12960 * 500 / 1000 = 8480 MiB, and chunked prefill handles it.
        assert!(validate_vision_prefill(12_960, 12_445).is_ok());
    }
}
|