refactor(neuron): phase 2 — single-GPU forward + clear_kv route through device worker
Some checks failed
build-prerelease / Package helexa-neuron-ada RPM (push) Blocked by required conditions
CI / Format (push) Successful in 34s
CI / Clippy (push) Successful in 2m12s
build-prerelease / Resolve version stamps (push) Successful in 3m41s
CI / Test (push) Successful in 5m1s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m32s
build-prerelease / Build neuron-ampere (push) Successful in 5m20s
build-prerelease / Build cortex binary (push) Successful in 12m20s
build-prerelease / Build neuron-ada (push) Successful in 5m17s
build-prerelease / Package cortex RPM (push) Successful in 1m25s
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled

Second slice of the per-device CUDA context-ownership refactor planned at
~/.claude/plans/plan-the-per-device-worker-abstract-micali.md. The two
spawn_blocking sites in `chat_completion` and `chat_completion_stream`
now route through the device worker thread on CUDA loads. CPU loads
keep the existing spawn_blocking + `Arc<Mutex<ModelArch>>` path; there's
no context to own and the channel hop would only add latency.

What this phase changes:

- `Job` gains `TransferIn`, `DropArch`, `ClearKv`, `ForwardLogits`. The
  worker's dispatch state grows a `HashMap<ArchHandle, Box<ModelArch>>`
  slab and a `next_handle` counter for minting opaque handles.
- `LoadedModel.arch: Arc<Mutex<ModelArch>>` → `Option<Arc<Mutex<>>>`,
  plus a new `arch_handle: Option<ArchHandle>` field. The two are
  mutually exclusive: CUDA loads set `arch_handle = Some(_)` after
  transferring the boxed arch into the worker's slab; CPU loads keep
  `arch = Some(_)` for the legacy spawn_blocking path.
- New `run_inference_via_worker` and `stream_inference_via_worker`
  drive the prefill + decode loop by sending `Job::ForwardLogits` per
  step; the worker copies the resulting `[vocab]` logits to a
  CPU-side `Vec<f32>` before reply, so the async caller never holds a
  device-resident tensor. `apply_repeat_penalty` and
  `LogitsProcessor::sample` run on a CPU candle tensor; no context
  binding side-effects on tokio worker threads.
- `logits_health_slice(&[f32])` complements the existing
  `logits_health(&Tensor)` so the new worker paths can compute
  health stats directly from the CPU vec.
- `unload_model` for the single-GPU CUDA path now sends
  `Job::DropArch { handle }` to the worker so the `Box<ModelArch>`
  drops on the thread that allocated its CUDA tensors. The `Drop` runs
  with the bound context, freeing memory on the right context.

What this phase doesn't touch (yet):

- TP forward, TP load, NCCL bring-up — still on spawn_blocking. Phase 3.
- Single-GPU model load — still spawn_blocking, followed by a
  `Job::TransferIn` to move the freshly-built `ModelArch` into the
  worker slab. Phase 4 moves the load itself onto the worker thread
  and eliminates the bootstrap TransferIn.
- The `device_vram_mb` / `cuda_mem_mb` helpers — still present and
  used by the construction-time logs running inside spawn_blocking
  loads. Phase 4 cleanup folds them into `dispatch.rs`.

Public API unchanged. fmt + clippy clean; 37 lib tests + all
integration tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-27 09:55:08 +03:00
parent 081b532387
commit b179204fd3
4 changed files with 830 additions and 138 deletions

View File

@@ -102,7 +102,14 @@ impl LoadedHandle {
/// moved into `spawn_blocking` for synchronous candle forward passes. /// moved into `spawn_blocking` for synchronous candle forward passes.
pub struct LoadedModel { pub struct LoadedModel {
pub model_id: String, pub model_id: String,
pub arch: Arc<Mutex<ModelArch>>, /// Local (async-side) handle to the model architecture. `Some`
/// only when the model loaded onto the CPU device (no CUDA
/// available); the inference path then takes this mutex via
/// `spawn_blocking` and runs candle ops on the CPU backend.
/// `None` when the model loaded onto a CUDA device — in that case
/// the architecture lives in the worker thread's slab and is
/// addressed via [`Self::arch_handle`].
pub arch: Option<Arc<Mutex<ModelArch>>>,
pub tokenizer: Tokenizer, pub tokenizer: Tokenizer,
pub device: Device, pub device: Device,
pub quant: Option<String>, pub quant: Option<String>,
@@ -118,10 +125,15 @@ pub struct LoadedModel {
pub poisoned: AtomicBool, pub poisoned: AtomicBool,
/// Handle to the per-device CUDA worker thread for this model's /// Handle to the per-device CUDA worker thread for this model's
/// device. `None` for CPU loads (no context to own). VRAM queries /// device. `None` for CPU loads (no context to own). VRAM queries
/// and — in later refactor phases — forward / kv-cache / unload /// and — for CUDA loads — forward / kv-cache / drop ops route
/// ops route through this handle so the device's CUDA context /// through this handle so the device's CUDA context stays bound
/// stays bound to one OS thread for the daemon's lifetime. /// to one OS thread for the daemon's lifetime.
pub worker: Option<Arc<super::device_worker::DeviceWorkerHandle>>, pub worker: Option<Arc<super::device_worker::DeviceWorkerHandle>>,
/// Index into the worker's `ModelArch` slab. `Some` iff the model
/// loaded onto a CUDA device and was successfully transferred to
/// the worker; in that case [`Self::arch`] is `None`. The two
/// fields are mutually exclusive.
pub arch_handle: Option<super::device_worker::ArchHandle>,
} }
impl LoadedModel { impl LoadedModel {
@@ -475,6 +487,16 @@ fn logits_health(t: &Tensor) -> LogitsHealth {
}; };
} }
}; };
logits_health_slice(&values)
}
/// Same diagnostic as [`logits_health`] but operates directly on a
/// `[f32]` slice. Used by the worker-routed inference paths where the
/// device → host copy has already happened on the worker thread and
/// the async caller has the values in hand. Avoids the round-trip of
/// rebuilding a Tensor just to call to_vec1 again.
#[allow(dead_code)]
fn logits_health_slice(values: &[f32]) -> LogitsHealth {
let mut nan = 0usize; let mut nan = 0usize;
let mut pos_inf = 0usize; let mut pos_inf = 0usize;
let mut neg_inf = 0usize; let mut neg_inf = 0usize;
@@ -483,7 +505,7 @@ fn logits_health(t: &Tensor) -> LogitsHealth {
let mut finite_max = f32::NEG_INFINITY; let mut finite_max = f32::NEG_INFINITY;
let mut finite_sum = 0.0_f64; let mut finite_sum = 0.0_f64;
let mut finite_count = 0usize; let mut finite_count = 0usize;
for &v in &values { for &v in values {
if v.is_nan() { if v.is_nan() {
nan += 1; nan += 1;
} else if v == f32::INFINITY { } else if v == f32::INFINITY {
@@ -1104,9 +1126,50 @@ impl CandleHarness {
"chat_completion: starting" "chat_completion: starting"
); );
let arch_arc = Arc::clone(&loaded.arch); // Routing: CUDA loads go through the per-device worker
// thread (introduced in Phase 1; forward/clear added in
// Phase 2). CPU loads keep the existing spawn_blocking
// path because there's no context to own and the channel
// round-trip would only add latency. The two arms produce
// the same `(Vec<u32>, String)` shape so the rest of the
// path is shared.
let (generated_ids, finish_reason) = if let (Some(worker), Some(handle)) =
(loaded.worker.as_ref(), loaded.arch_handle)
{
// Worker path (CUDA).
#[cfg(feature = "cuda")]
{
match run_inference_via_worker(
worker,
handle,
&prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
)
.await
{
Ok(v) => v,
Err(e) => {
loaded.poisoned.store(true, Ordering::Release);
return Err(InferenceError::Other(e));
}
}
}
#[cfg(not(feature = "cuda"))]
{
// Can't happen: `loaded.worker` is only Some on
// CUDA builds. The dead branch keeps the no-cuda
// build well-typed.
let _ = (worker, handle);
unreachable!("worker handle present without cuda feature");
}
} else if let Some(arch_arc) = loaded.arch.clone() {
// CPU path: existing spawn_blocking on the local
// Arc<Mutex<ModelArch>>.
let device = loaded.device.clone(); let device = loaded.device.clone();
let inference_result = let inference_result =
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> { tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
let mut guard = arch_arc.blocking_lock(); let mut guard = arch_arc.blocking_lock();
@@ -1128,7 +1191,7 @@ impl CandleHarness {
// device-poisoning event. The terminal log at the bottom // device-poisoning event. The terminal log at the bottom
// of the wrapper reports the error; this flag stops the // of the wrapper reports the error; this flag stops the
// NEXT request from going down the same path. // NEXT request from going down the same path.
let (generated_ids, finish_reason) = match inference_result { match inference_result {
Ok(Ok(v)) => v, Ok(Ok(v)) => v,
Ok(Err(e)) => { Ok(Err(e)) => {
loaded.poisoned.store(true, Ordering::Release); loaded.poisoned.store(true, Ordering::Release);
@@ -1140,6 +1203,13 @@ impl CandleHarness {
"inference task panicked: {e}" "inference task panicked: {e}"
))); )));
} }
}
} else {
// LoadedModel invariant: exactly one of `worker` /
// `arch` is Some. Reaching here is a construction bug.
return Err(InferenceError::Other(anyhow::anyhow!(
"LoadedModel has neither worker handle nor local arch — load-path bug"
)));
}; };
let completion_text = loaded let completion_text = loaded
@@ -1244,7 +1314,6 @@ impl CandleHarness {
.token_to_id("<|im_end|>") .token_to_id("<|im_end|>")
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>")); .or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
let arch_arc = Arc::clone(&loaded.arch);
let device = loaded.device.clone(); let device = loaded.device.clone();
let tokenizer = loaded.tokenizer.clone(); let tokenizer = loaded.tokenizer.clone();
let model_id = request.model.clone(); let model_id = request.model.clone();
@@ -1315,6 +1384,57 @@ impl CandleHarness {
"chat_completion (stream): starting" "chat_completion (stream): starting"
); );
} }
// Routing parallel to the non-streaming chat_completion: CUDA
// goes through the worker (async task), CPU keeps the
// spawn_blocking + Arc<Mutex<ModelArch>> path.
if let (Some(worker), Some(handle)) = (loaded.worker.clone(), loaded.arch_handle) {
#[cfg(feature = "cuda")]
{
let prompt_tokens = prompt_tokens.clone();
tokio::spawn(
async move {
match stream_inference_via_worker(
worker,
handle,
tokenizer,
prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
id,
created,
model_id,
tx,
)
.await
{
Ok(_finish_reason) => tracing::info!(
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): done"
),
Err(e) => {
loaded_for_task.poisoned.store(true, Ordering::Release);
tracing::error!(
error = %format!("{e:#}"),
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): failed, model marked poisoned"
);
}
}
}
.instrument(span_for_task),
);
}
#[cfg(not(feature = "cuda"))]
{
let _ = (worker, handle, span_for_task);
unreachable!("worker handle present without cuda feature");
}
} else if let Some(arch_arc) = loaded.arch.clone() {
tokio::task::spawn_blocking(move || { tokio::task::spawn_blocking(move || {
let _g = span_for_task.enter(); let _g = span_for_task.enter();
let mut guard = arch_arc.blocking_lock(); let mut guard = arch_arc.blocking_lock();
@@ -1349,6 +1469,11 @@ impl CandleHarness {
} }
} }
}); });
} else {
return Err(InferenceError::Other(anyhow::anyhow!(
"LoadedModel has neither worker handle nor local arch — load-path bug"
)));
}
Ok(rx) Ok(rx)
} }
@@ -1432,22 +1557,37 @@ impl Harness for CandleHarness {
// Worker thread for the chosen device. CPU loads (CUDA // Worker thread for the chosen device. CPU loads (CUDA
// unavailable / not requested) skip the worker — there's no // unavailable / not requested) skip the worker — there's no
// context to own. // context to own. For CUDA loads, the arch is transferred
let worker = match &device { // into the worker's slab now so the inference path can
// reference it via the returned `ArchHandle`. The explicit
// type annotation lets the no-cuda build resolve `None` to
// the right `Option<Arc<DeviceWorkerHandle>>` type.
let worker: Option<Arc<super::device_worker::DeviceWorkerHandle>> = match &device {
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?), Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?),
_ => None, _ => None,
}; };
let (arch_local, arch_handle) = match &worker {
Some(w) => {
let handle = w
.transfer_in(Box::new(arch))
.await
.map_err(|e| anyhow::anyhow!("transfer arch into device worker: {e}"))?;
(None, Some(handle))
}
None => (Some(Arc::new(Mutex::new(arch))), None),
};
let loaded = Arc::new(LoadedModel { let loaded = Arc::new(LoadedModel {
model_id: spec.model_id.clone(), model_id: spec.model_id.clone(),
arch: Arc::new(Mutex::new(arch)), arch: arch_local,
tokenizer, tokenizer,
device, device,
quant: spec.quant.clone(), quant: spec.quant.clone(),
devices, devices,
poisoned: AtomicBool::new(false), poisoned: AtomicBool::new(false),
worker, worker,
arch_handle,
}); });
let mut models = self.models.write().await; let mut models = self.models.write().await;
@@ -1465,13 +1605,26 @@ impl Harness for CandleHarness {
anyhow::bail!("model '{model_id}' not loaded"); anyhow::bail!("model '{model_id}' not loaded");
}; };
// Single-GPU drops are immediate — the LoadedModel goes out of // Single-GPU drops are immediate — the LoadedModel goes out of
// scope with the Arc and candle frees VRAM. TP unloads also // scope with the Arc and candle frees VRAM. CUDA loads also
// need to tell every worker to drop its shard before the pool // ship a `Job::DropArch` to the device worker so the boxed
// itself is dropped (otherwise the workers keep their shards // `ModelArch` releases its CUDA allocations on the right
// around until Shutdown, which is wasteful and would surface // thread (with the bound context); without that, the Drop
// as VRAM not freed promptly). // would run on whatever tokio thread happens to be holding
// the last `Arc<LoadedModel>` clone when this fn returns.
// TP unloads further coordinate the subprocess pool below.
match handle { match handle {
LoadedHandle::Single(_) => {} LoadedHandle::Single(single) => {
if let (Some(worker), Some(arch_handle)) =
(single.worker.as_ref(), single.arch_handle)
&& let Err(e) = worker.drop_arch(arch_handle).await
{
tracing::warn!(
model = %model_id,
error = %e,
"single-GPU unload: DropArch RPC failed (model state may leak in worker slab)"
);
}
}
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
LoadedHandle::Tp(tp) => { LoadedHandle::Tp(tp) => {
// Try to recover the inner TpLoadedModel so we can move // Try to recover the inner TpLoadedModel so we can move
@@ -2276,6 +2429,239 @@ fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
prompt prompt
} }
#[allow(clippy::too_many_arguments)]
/// Run the full single-GPU inference loop via the device worker.
///
/// Mirrors `run_inference`'s logic but routes each forward step
/// through `worker.forward_logits()` (returns CPU-side `Vec<f32>`)
/// and runs `apply_repeat_penalty` + sampling on a CPU candle tensor.
/// The device-resident logits tensor never escapes the worker thread.
///
/// Used by the CUDA path of `chat_completion`. The CPU path keeps
/// `run_inference` (spawn_blocking against `Arc<Mutex<ModelArch>>`)
/// because there's no CUDA context to own and the worker indirection
/// would only add channel overhead with no diagnostic benefit.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
async fn run_inference_via_worker(
worker: &super::device_worker::DeviceWorkerHandle,
handle: super::device_worker::ArchHandle,
prompt_tokens: &[u32],
max_new: usize,
temperature: f64,
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
) -> Result<(Vec<u32>, String)> {
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut generated: Vec<u32> = Vec::new();
let prompt_len = prompt_tokens.len();
worker
.clear_kv_cache(handle)
.await
.map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
// Prefill — every rank embeds the prompt with offset 0.
let logits_vec = worker
.forward_logits(handle, prompt_tokens.to_vec(), 0)
.await
.map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
?health,
"chat_completion (worker): prefill sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
}
generated.push(next_token);
for index in 0..max_new.saturating_sub(1) {
let logits_vec = worker
.forward_logits(handle, vec![next_token], prompt_len + index)
.await
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
step = index,
?health,
"chat_completion (worker): decode sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
}
generated.push(next_token);
}
Ok((generated, "length".into()))
}
/// Streaming counterpart of [`run_inference_via_worker`]. Emits one
/// `ChatCompletionChunk` per generated token via `tx`; routes every
/// forward step through `worker.forward_logits()`. Same per-step
/// CPU-side sampling discipline — no device tensor escapes the
/// worker thread.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
async fn stream_inference_via_worker(
worker: Arc<super::device_worker::DeviceWorkerHandle>,
handle: super::device_worker::ArchHandle,
tokenizer: Tokenizer,
prompt_tokens: Vec<u32>,
max_new: usize,
temperature: f64,
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
id: String,
created: u64,
model_id: String,
tx: mpsc::Sender<ChatCompletionChunk>,
) -> Result<String> {
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut all_tokens: Vec<u32> = Vec::new();
let mut decoded_prefix = String::new();
let prompt_len = prompt_tokens.len();
let mut finish_reason = "length".to_string();
worker
.clear_kv_cache(handle)
.await
.map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
let logits_vec = worker
.forward_logits(handle, prompt_tokens, 0)
.await
.map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
let mut next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
?health,
"chat_completion (stream/worker): prefill sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
} else {
all_tokens.push(next_token);
if !emit_chunk(
&all_tokens,
&mut decoded_prefix,
&tokenizer,
&tx,
&id,
created,
&model_id,
)
.await
{
return Ok(finish_reason); // Client gone — clean stream end.
}
for index in 0..max_new.saturating_sub(1) {
let logits_vec = worker
.forward_logits(handle, vec![next_token], prompt_len + index)
.await
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
step = index,
?health,
"chat_completion (stream/worker): decode sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
}
all_tokens.push(next_token);
if !emit_chunk(
&all_tokens,
&mut decoded_prefix,
&tokenizer,
&tx,
&id,
created,
&model_id,
)
.await
{
return Ok(finish_reason);
}
}
}
// Final chunk carrying finish_reason. Matches the run_inference_streaming
// shape so the SSE consumer sees an identical termination sequence.
let final_chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".into(),
created,
model: model_id.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: serde_json::Value::Object(Default::default()),
finish_reason: Some(finish_reason.clone()),
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
let _ = tx.send(final_chunk).await;
Ok(finish_reason)
}
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn run_inference( fn run_inference(
arch: &mut ModelArch, arch: &mut ModelArch,

View File

@@ -7,20 +7,41 @@
//! one-line `oneshot::Sender::send` call to ship the reply back, which //! one-line `oneshot::Sender::send` call to ship the reply back, which
//! is non-blocking. //! is non-blocking.
//! //!
//! Phase 1 handles only `QueryVram` and `Shutdown`. Later phases add //! Phase 2 handles QueryVram, TransferIn, DropArch, ClearKv,
//! Forward, ClearKv, NCCL, and load handlers as separate match arms. //! ForwardLogits, Shutdown. Phase 3 will add the TP variants
//! (NcclInit, NcclSanity, TpLoadShard, TpForward, TpClearKv) and the
//! ARCH model state in this state slab will gain a companion
//! `tp_models: HashMap<TpHandle, Box<TpLeaderModel>>`.
use crate::harness::device_worker::jobs::Job; use crate::harness::candle::ModelArch;
use crate::harness::device_worker::jobs::{ArchHandle, Job};
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::Receiver; use std::sync::mpsc::Receiver;
/// Per-thread state owned by the worker. On CUDA builds the `Arc<CudaContext>` /// Per-thread state owned by the worker. On CUDA builds the `Arc<CudaContext>`
/// is created and bound at thread startup; on CPU builds the struct is /// is created and bound at thread startup; on CPU builds the struct
/// empty save for the device index (kept for log clarity). /// is mostly empty.
#[cfg(feature = "cuda")]
struct DeviceWorkerState { struct DeviceWorkerState {
#[allow(dead_code)]
device_index: u32, device_index: u32,
/// Candle `Device` constructed at startup. Used by handlers (e.g.
/// `ForwardLogits`) to build input tensors against the right
/// device. Falls back to `Device::Cpu` if CUDA init fails.
device: candle_core::Device,
/// Boxed `ModelArch` slab. Indexed by an opaque `ArchHandle` minted
/// by `TransferIn`. The Box means the entry's address is stable
/// across HashMap rehashes (relevant only when we later hand out
/// `&mut ModelArch` references — for Phase 2 every handler runs
/// `&mut` via `get_mut`, no long-lived borrows).
models: HashMap<ArchHandle, Box<ModelArch>>,
/// Counter for minting fresh `ArchHandle`s. Each `TransferIn`
/// increments and returns the new value. Wraps at u64::MAX after
/// ~10^19 model loads — not a practical concern.
next_handle: u64,
#[cfg(feature = "cuda")]
#[allow(dead_code)]
/// `None` only if `CudaContext::new()` failed — in that case the /// `None` only if `CudaContext::new()` failed — in that case the
/// thread still runs so the handle's lifecycle stays uniform, but /// thread still runs so the handle's lifecycle stays uniform, but
/// every job that touches CUDA falls through to a zero reply with /// every job that touches CUDA falls through to a zero reply with
@@ -28,18 +49,12 @@ struct DeviceWorkerState {
ctx: Option<Arc<candle_core::cuda::cudarc::driver::CudaContext>>, ctx: Option<Arc<candle_core::cuda::cudarc::driver::CudaContext>>,
} }
#[cfg(not(feature = "cuda"))]
#[allow(dead_code)]
struct DeviceWorkerState {
device_index: u32,
}
/// Worker thread entry point. Runs until `Job::Shutdown` arrives or /// Worker thread entry point. Runs until `Job::Shutdown` arrives or
/// the channel sender is dropped (which happens when the last /// the channel sender is dropped (which happens when the last
/// `DeviceWorkerHandle` `Arc` is dropped without an explicit /// `DeviceWorkerHandle` `Arc` is dropped without an explicit
/// `shutdown()`). /// `shutdown()`).
pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool>) { pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool>) {
let state = init_state(device_index); let mut state = init_state(device_index);
tracing::info!(device_index, "device worker started"); tracing::info!(device_index, "device worker started");
while let Ok(job) = rx.recv() { while let Ok(job) = rx.recv() {
@@ -51,7 +66,7 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
} }
if poisoned.load(Ordering::Acquire) { if poisoned.load(Ordering::Acquire) {
// Drain-only mode: reply with a poisoned error without // Drain-only mode: reply with a poisoned error without
// touching CUDA. Phase 1 never sets the flag from the // touching CUDA. Phase 1/2 never set the flag from the
// dispatch loop itself (no driver errors classified yet), // dispatch loop itself (no driver errors classified yet),
// but tests use `DeviceWorkerHandle::set_poisoned()` to // but tests use `DeviceWorkerHandle::set_poisoned()` to
// simulate this state. // simulate this state.
@@ -66,58 +81,130 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
// discard the reply. // discard the reply.
let _ = reply.send(result); let _ = reply.send(result);
} }
Job::TransferIn { arch, reply } => {
let handle = ArchHandle(state.next_handle);
state.next_handle = state.next_handle.wrapping_add(1);
state.models.insert(handle, arch);
tracing::debug!(
device_index,
handle = handle.0,
slab_size = state.models.len(),
"device worker: model transferred in"
);
let _ = reply.send(Ok(handle));
}
Job::DropArch { handle, reply } => {
let removed = state.models.remove(&handle);
let was_present = removed.is_some();
// Explicit drop on this thread — runs the Box<ModelArch>
// Drop with the CUDA context bound here, which frees
// all device tensors on the right context. The Drop is
// implicit on the `removed` value going out of scope at
// the end of the arm; calling drop() explicitly just
// makes the intent visible.
drop(removed);
tracing::debug!(
device_index,
handle = handle.0,
was_present,
slab_size = state.models.len(),
"device worker: model dropped"
);
let _ = reply.send(());
}
Job::ClearKv { handle, reply } => {
let result = match state.models.get_mut(&handle) {
Some(arch) => arch.clear_kv_cache(),
None => Err(anyhow::anyhow!("ClearKv: no model for handle {}", handle.0)),
};
let _ = reply.send(result);
}
Job::ForwardLogits {
handle,
tokens,
offset,
reply,
} => {
let result = forward_logits(&mut state, handle, &tokens, offset);
let _ = reply.send(result);
}
// Handled by the matches!() check above; reaching here // Handled by the matches!() check above; reaching here
// means a Shutdown slipped past which is a bug. // means a Shutdown slipped past which is a bug.
Job::Shutdown => unreachable!("Shutdown should break above"), Job::Shutdown => unreachable!("Shutdown should break above"),
} }
} }
tracing::info!(device_index, "device worker exiting"); tracing::info!(
device_index,
slab_size = state.models.len(),
"device worker exiting; dropping remaining models"
);
// Drops every model in the slab on this thread before the function
// returns. Critical for CUDA tensors: dropping on a thread that
// doesn't have the context bound is UB. Phase 2 still runs Drop
// via the slab going out of scope, which is correct as long as no
// pre-poisoned state lurks in here — see the poisoned-mode
// semantics in mod.rs for the Phase 3+ refinement.
} }
#[cfg(feature = "cuda")]
fn init_state(device_index: u32) -> DeviceWorkerState { fn init_state(device_index: u32) -> DeviceWorkerState {
#[cfg(feature = "cuda")]
{
use candle_core::cuda::cudarc::driver::CudaContext; use candle_core::cuda::cudarc::driver::CudaContext;
match CudaContext::new(device_index as usize) { // Construct a candle Device first — cudarc returns the
// primary context for this index on subsequent calls, so
// CudaContext::new and Device::new_cuda end up sharing state.
let (device, ctx) = match candle_core::Device::new_cuda(device_index as usize) {
Ok(device) => match CudaContext::new(device_index as usize) {
Ok(ctx) => { Ok(ctx) => {
// Make sure the context is current on this thread. cudarc
// is generally fine with lazy binding, but doing it once
// here gives us a deterministic moment to log "context
// bound" — and makes `mem_get_info()` work without further
// bind dances inside the dispatch handlers.
if let Err(e) = ctx.bind_to_thread() { if let Err(e) = ctx.bind_to_thread() {
tracing::warn!( tracing::warn!(
device_index, device_index,
error = ?e, error = ?e,
"device worker: bind_to_thread failed; \ "device worker: bind_to_thread failed; \
vram queries will still rebind per-call" operations will still rebind per-call"
); );
} else { } else {
tracing::info!(device_index, "device worker bound CUDA context"); tracing::info!(device_index, "device worker bound CUDA context");
} }
DeviceWorkerState { (device, Some(ctx))
device_index,
ctx: Some(ctx),
}
} }
Err(e) => { Err(e) => {
tracing::warn!( tracing::warn!(
device_index, device_index,
error = ?e, error = ?e,
"device worker: CudaContext::new failed; \ "device worker: CudaContext::new failed; \
vram queries will return (0, 0)" vram queries will return (0, 0), forward will error"
); );
(device, None)
}
},
Err(e) => {
tracing::warn!(
device_index,
error = %e,
"device worker: Device::new_cuda failed; falling back to CPU device"
);
(candle_core::Device::Cpu, None)
}
};
DeviceWorkerState { DeviceWorkerState {
device_index, device_index,
ctx: None, device,
models: HashMap::new(),
next_handle: 1,
ctx,
} }
} }
#[cfg(not(feature = "cuda"))]
{
DeviceWorkerState {
device_index,
device: candle_core::Device::Cpu,
models: HashMap::new(),
next_handle: 1,
}
} }
}
#[cfg(not(feature = "cuda"))]
fn init_state(device_index: u32) -> DeviceWorkerState {
DeviceWorkerState { device_index }
} }
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
@@ -144,6 +231,42 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
Ok((0, 0)) Ok((0, 0))
} }
/// Forward step + copy the `[vocab]` logits to a CPU `Vec<f32>` ready
/// for sampling on the async caller. The model's `device()` (CUDA or
/// CPU) determines where the kernel runs; this fn doesn't care.
///
/// On CUDA, the `to_dtype(F32).flatten_all().to_vec1::<f32>()` chain
/// triggers the device → host copy. The copy runs synchronously on
/// this worker thread; the bound context owns the source allocation
/// so the transfer is straightforward.
fn forward_logits(
state: &mut DeviceWorkerState,
handle: ArchHandle,
tokens: &[u32],
offset: usize,
) -> anyhow::Result<Vec<f32>> {
use candle_core::{DType, Tensor};
// Build the input tensor on the worker's own device. cudarc's
// primary-context model means `Device::new_cuda(idx)` shares state
// with the `CudaContext` we bound at startup, so this is the same
// device the ModelArch was loaded against.
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
let arch = state
.models
.get_mut(&handle)
.ok_or_else(|| anyhow::anyhow!("ForwardLogits: no model for handle {}", handle.0))?;
let logits = arch.forward(&input, offset)?;
// Copy to CPU f32. logits is already `[vocab]` (squeeze_to_vocab
// inside ModelArch::forward). The to_dtype handles bf16/f16 →
// f32 promotion for the sampler.
let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
let values = logits.to_vec1::<f32>()?;
Ok(values)
}
/// Reply to a job with the poisoned-worker error. Used when the worker /// Reply to a job with the poisoned-worker error. Used when the worker
/// has flipped into drain-only mode after a CUDA driver error. /// has flipped into drain-only mode after a CUDA driver error.
/// ///
@@ -153,11 +276,26 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
/// poisoned error so callers never hang waiting for a worker that's /// poisoned error so callers never hang waiting for a worker that's
/// no longer running CUDA. /// no longer running CUDA.
fn drain_poisoned(job: Job, device_index: u32) { fn drain_poisoned(job: Job, device_index: u32) {
let err = || anyhow::anyhow!("device worker for device {device_index} is poisoned");
match job { match job {
Job::QueryVram { reply } => { Job::QueryVram { reply } => {
let _ = reply.send(Err(anyhow::anyhow!( let _ = reply.send(Err(err()));
"device worker for device {device_index} is poisoned" }
))); Job::TransferIn { reply, .. } => {
let _ = reply.send(Err(err()));
}
Job::DropArch { reply, .. } => {
// Drop reply is `()` — no error path. Send the unit so the
// caller's await resolves; the model handle is leaked in
// the worker's slab, but the whole slab gets `mem::forget`
// on shutdown anyway per the poisoned-thread design.
let _ = reply.send(());
}
Job::ClearKv { reply, .. } => {
let _ = reply.send(Err(err()));
}
Job::ForwardLogits { reply, .. } => {
let _ = reply.send(Err(err()));
} }
Job::Shutdown => { Job::Shutdown => {
// Filtered by the matches!() guard in run(); reaching // Filtered by the matches!() guard in run(); reaching

View File

@@ -4,16 +4,33 @@
//! needs plus a `tokio::sync::oneshot::Sender` for the reply. The //! needs plus a `tokio::sync::oneshot::Sender` for the reply. The
//! async-side `DeviceWorkerHandle` constructs a job, sends it down the //! async-side `DeviceWorkerHandle` constructs a job, sends it down the
//! `std::sync::mpsc` channel, and `await`s the oneshot for the reply. //! `std::sync::mpsc` channel, and `await`s the oneshot for the reply.
//!
//! Phase 1 includes only `QueryVram` and `Shutdown`. Phases 24 add
//! forward, kv-cache clear, drop-arch, NCCL init/sanity, and the load
//! variants. Each new variant lands as a separate PR so the worker
//! thread stays small at every checkpoint.
use crate::harness::candle::ModelArch;
use anyhow::Result; use anyhow::Result;
use tokio::sync::oneshot; use tokio::sync::oneshot;
/// Opaque handle to a `ModelArch` stored in the worker thread's state
/// slab. Cheap to copy; `Send + Sync` so it crosses task boundaries
/// freely. The actual `Box<ModelArch>` it points to is owned by the
/// worker thread for the duration of the handle's lifetime — the only
/// way to drop the model is to send `Job::DropArch { handle }` so the
/// `Drop` impl runs on the thread with the bound CUDA context (the
/// invariant the whole refactor exists to guarantee).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ArchHandle(pub u64);
/// One unit of work for the device worker. /// One unit of work for the device worker.
///
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
/// single-GPU inference primitives: transfer-in a freshly-loaded
/// `ModelArch`, drop it, clear its KV cache, and run one forward step
/// returning CPU-side logits ready for sampling on the async caller.
///
/// Sampling stays on the async side intentionally. The worker copies
/// logits to CPU (`Vec<f32>`) before reply, so the device-resident
/// tensor never escapes the worker thread and the async caller's
/// `LogitsProcessor::sample` runs entirely on the CPU candle backend
/// — no incidental context binding on a tokio worker thread.
pub enum Job { pub enum Job {
/// Query free / total VRAM on the device. Returns /// Query free / total VRAM on the device. Returns
/// `(free_mb, total_mb)`. CPU builds and contexts that failed to /// `(free_mb, total_mb)`. CPU builds and contexts that failed to
@@ -22,10 +39,43 @@ pub enum Job {
QueryVram { QueryVram {
reply: oneshot::Sender<Result<(u64, u64)>>, reply: oneshot::Sender<Result<(u64, u64)>>,
}, },
/// Tell the worker to break its dispatch loop and exit. The /// Move a freshly-loaded `ModelArch` into the worker's state slab.
/// channel is then drained — any further jobs already queued get /// Returns an `ArchHandle` the caller stores on `LoadedModel` and
/// dropped (their oneshot senders are dropped, causing the async /// passes back in subsequent `ClearKv` / `ForwardLogits` /
/// caller's receiver to return `Err` which `DeviceWorkerHandle` /// `DropArch` jobs.
/// maps to `WorkerError::Gone`). TransferIn {
arch: Box<ModelArch>,
reply: oneshot::Sender<Result<ArchHandle>>,
},
/// Remove the model from the slab and drop it. The `Drop` runs on
/// the worker thread so CUDA tensors release their memory on the
/// same context that allocated them.
DropArch {
handle: ArchHandle,
reply: oneshot::Sender<()>,
},
/// Reset the KV cache for this model. Called at the start of every
/// chat completion so a new request doesn't attend over the
/// previous one's tokens.
ClearKv {
handle: ArchHandle,
reply: oneshot::Sender<Result<()>>,
},
/// Run one forward step and copy the resulting `[vocab]` logits to
/// CPU. The caller takes the returned `Vec<f32>`, wraps it in a
/// CPU `Tensor`, and runs `apply_repeat_penalty` + sampling
/// without touching the device context. `offset` is the KV-cache
/// position before this step (0 for prefill, `prompt_len + i` for
/// the i-th decode step).
ForwardLogits {
handle: ArchHandle,
tokens: Vec<u32>,
offset: usize,
reply: oneshot::Sender<Result<Vec<f32>>>,
},
/// Tell the worker to break its dispatch loop and exit. Any jobs
/// queued after this in the channel reply `Err` to their oneshot
/// senders (the senders are dropped on the worker's exit, which
/// the async-side `Receiver::await` maps to `WorkerError::Gone`).
Shutdown, Shutdown,
} }

View File

@@ -49,7 +49,7 @@ use std::sync::mpsc::{self, Sender};
use std::thread::JoinHandle; use std::thread::JoinHandle;
use tokio::sync::oneshot; use tokio::sync::oneshot;
pub use jobs::Job; pub use jobs::{ArchHandle, Job};
/// Errors returned by `DeviceWorkerHandle` submit methods. /// Errors returned by `DeviceWorkerHandle` submit methods.
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@@ -159,6 +159,124 @@ impl DeviceWorkerHandle {
} }
} }
/// Move a freshly-loaded `ModelArch` into the worker's state slab.
/// Returns the `ArchHandle` the caller stores on `LoadedModel`.
/// The `Box<ModelArch>` crosses the channel; the worker thread
/// owns it from here on.
pub async fn transfer_in(
&self,
arch: Box<crate::harness::candle::ModelArch>,
) -> Result<ArchHandle, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::TransferIn {
arch,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Tell the worker to drop the `ModelArch` for `handle` on the
/// worker thread (so CUDA tensors release on the right context).
/// Returns `Ok(())` even if the handle wasn't in the slab — Drop
/// is idempotent. Reports `Gone` if the worker isn't running.
pub async fn drop_arch(&self, handle: ArchHandle) -> Result<(), WorkerError> {
// Poisoning doesn't block DropArch — even on a poisoned
// context we want callers to unblock and proceed with the
// unload bookkeeping. The dispatch handler under poison just
// replies `()` without touching the model (the actual Drop
// happens via mem::forget at thread exit per the poison
// protocol).
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::DropArch {
handle,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(()) => Ok(()),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Reset the KV cache for the model at `handle`. Called at the
/// start of every chat completion so the new prompt doesn't
/// attend over the previous request's tokens.
pub async fn clear_kv_cache(&self, handle: ArchHandle) -> Result<(), WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::ClearKv {
handle,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Run one forward step and return the resulting `[vocab]` logits
/// as a CPU-side `Vec<f32>`. The caller then samples on a CPU
/// candle Tensor without ever binding the device context on its
/// tokio thread.
pub async fn forward_logits(
&self,
handle: ArchHandle,
tokens: Vec<u32>,
offset: usize,
) -> Result<Vec<f32>, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::ForwardLogits {
handle,
tokens,
offset,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Send `Job::Shutdown` and join the thread. Idempotent — calling /// Send `Job::Shutdown` and join the thread. Idempotent — calling
/// twice is a no-op the second time. /// twice is a no-op the second time.
pub fn shutdown(&self) -> anyhow::Result<()> { pub fn shutdown(&self) -> anyhow::Result<()> {