From 081b5323873c2db9fed33d08a76e92eaeb7840e4 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Wed, 27 May 2026 09:40:34 +0300 Subject: [PATCH] =?UTF-8?q?refactor(neuron):=20phase=201=20=E2=80=94=20per?= =?UTF-8?q?-device=20worker=20thread,=20VRAM=20queries=20route=20through?= =?UTF-8?q?=20it?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit First slice of the per-device CUDA context-ownership refactor planned at ~/.claude/plans/plan-the-per-device-worker-abstract-micali.md. Adds the infrastructure for a dedicated OS thread per CUDA device that owns the device's `CudaContext` for the daemon's lifetime, and routes the 8 async-context `device_vram_mb()` call sites in candle.rs through it. What this phase changes: - New module `harness/device_worker/` (mod.rs, jobs.rs, dispatch.rs). `DeviceWorkerHandle::spawn(idx)` creates a named OS thread (`cuda-dev-N`), binds `CudaContext::new(idx)` once at startup, and enters a dispatch loop reading `Job`s off a `std::sync::mpsc` channel. Replies cross back via `tokio::sync::oneshot::Sender` so async callers await without parking a tokio worker. - Two Job variants: `QueryVram` and `Shutdown`. Phases 2–4 add Forward, ClearKv, NCCL init/sanity, and load variants. - `LoadedModel` and `TpLoadedModel` gain a `worker` field populated at load time by a new `CandleHarness::ensure_device_worker(idx)` method that lazily spawns + caches one worker per device index. - Per-model `query_vram()` convenience method on both struct types so the 8 call sites in chat_completion / chat_completion_stream / chat_completion_tp_inner / chat_completion_tp_stream become `loaded.query_vram().await` (or `tp.query_vram().await`) — same field values logged, just sourced from the owner thread instead of the caller thread. What this phase doesn't touch (yet): - Forward, kv-cache clear, model load, NCCL — still on `spawn_blocking`. 
Phase 2 moves the single-GPU forward + clear; Phase 3 moves the TP forward + NCCL bring-up; Phase 4 moves the loads and deletes the now- unused `device_vram_mb` / `cuda_mem_mb` helpers. - Public API — unchanged. `Harness::load_model`, `chat_completion`, HTTP routes all keep identical shapes. Tests: - 5 new unit tests in `device_worker/mod.rs::tests` cover spawn → query → shutdown round-trip, thread naming, post-shutdown submit returns `Gone`, poisoned flag fast-rejects, and concurrent jobs drain across a Shutdown. CPU build (the only one CI runs) is enough to exercise channel mechanics. - All 37 lib tests + all integration tests pass; fmt + clippy clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/neuron/src/harness/candle.rs | 126 ++++++++- .../src/harness/device_worker/dispatch.rs | 168 +++++++++++ .../neuron/src/harness/device_worker/jobs.rs | 31 +++ .../neuron/src/harness/device_worker/mod.rs | 263 ++++++++++++++++++ crates/neuron/src/harness/mod.rs | 1 + 5 files changed, 580 insertions(+), 9 deletions(-) create mode 100644 crates/neuron/src/harness/device_worker/dispatch.rs create mode 100644 crates/neuron/src/harness/device_worker/jobs.rs create mode 100644 crates/neuron/src/harness/device_worker/mod.rs diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 886307f..30de984 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -42,6 +42,14 @@ pub struct CandleHarness { models: Arc>>, hf_cache: Option, bind_url: String, + /// One worker thread per CUDA device index that owns its + /// `CudaContext` for the daemon's lifetime. Populated lazily by + /// `ensure_device_worker()` when a model is loaded onto a CUDA + /// device. CPU `Device::Cpu` loads don't get an entry; they have + /// no context to own. Unused on the no-cuda build (the harness + /// can still load on CPU for tests, just without worker threads). 
+ #[allow(dead_code)] + device_workers: Arc>>>, } /// One entry in the harness's loaded-model registry. Single-GPU loads @@ -108,6 +116,28 @@ pub struct LoadedModel { /// the 2026-05-26 beast incident where a 14k-token prefill OOM /// silently turned every subsequent request into a stuck wait. pub poisoned: AtomicBool, + /// Handle to the per-device CUDA worker thread for this model's + /// device. `None` for CPU loads (no context to own). VRAM queries + /// and — in later refactor phases — forward / kv-cache / unload + /// ops route through this handle so the device's CUDA context + /// stays bound to one OS thread for the daemon's lifetime. + pub worker: Option>, +} + +impl LoadedModel { + /// Free / total VRAM on this model's device in MiB. Routes the + /// query through the device worker thread (where the CUDA context + /// is already bound) rather than rebinding on whatever tokio + /// thread the caller happens to be on. Returns `(0, 0)` on CPU + /// loads, or if the worker is gone / poisoned / the cudarc call + /// itself failed — same sentinel the previous `device_vram_mb` + /// helper returned, so log field values stay comparable. + pub async fn query_vram(&self) -> (u64, u64) { + match &self.worker { + Some(w) => w.query_vram().await.unwrap_or((0, 0)), + None => (0, 0), + } + } } /// Tensor-parallel loaded model. Holds the leader's rank-0 shard @@ -137,6 +167,22 @@ pub struct TpLoadedModel { /// terminal: the leader's and workers' CUDA contexts cannot be /// reliably reset without restarting the worker subprocesses. pub poisoned: AtomicBool, + /// Worker thread for the leader's CUDA device. Owns the leader's + /// `CudaContext` for the daemon's lifetime. VRAM queries route + /// through it; in later refactor phases the forward, kv-cache + /// clear, and shard unload route through it too. + pub worker: Arc, +} + +#[cfg(feature = "cuda")] +impl TpLoadedModel { + /// Free / total VRAM on the leader's device in MiB. 
See + /// [`LoadedModel::query_vram`] for rationale and sentinel + /// semantics — same pattern, TP just always has a worker because + /// the harness rejects TP without CUDA at load time. + pub async fn query_vram(&self) -> (u64, u64) { + self.worker.query_vram().await.unwrap_or((0, 0)) + } } /// Architecture-specific weights. Each variant covers one (family, @@ -604,6 +650,7 @@ impl CandleHarness { models: Arc::new(RwLock::new(HashMap::new())), hf_cache, bind_url, + device_workers: Arc::new(RwLock::new(HashMap::new())), } } @@ -625,6 +672,39 @@ impl CandleHarness { Ok(Device::Cpu) } + /// Return the worker handle for `device_index`, spawning it on + /// first request. The handle is cached on `self` so subsequent + /// loads against the same device share the same thread. Used to + /// populate `LoadedModel::worker` and `TpLoadedModel::worker` at + /// load time; in later refactor phases the worker also owns the + /// `ModelArch` and `TpLeaderModel` slabs. + #[allow(dead_code)] + async fn ensure_device_worker( + &self, + device_index: u32, + ) -> Result> { + { + let workers = self.device_workers.read().await; + if let Some(w) = workers.get(&device_index) { + return Ok(Arc::clone(w)); + } + } + // Write-lock acquired separately so the read path stays cheap. + // The `get` is repeated under the write lock to handle the + // race where two loads against a fresh device land here at + // once — the second caller sees the first's insertion and + // skips the second spawn. 
+ let mut workers = self.device_workers.write().await; + if let Some(w) = workers.get(&device_index) { + return Ok(Arc::clone(w)); + } + let handle = super::device_worker::DeviceWorkerHandle::spawn(device_index) + .with_context(|| format!("spawn device worker for cuda:{device_index}"))?; + workers.insert(device_index, Arc::clone(&handle)); + tracing::info!(device_index, "spawned device worker"); + Ok(handle) + } + /// Build an hf-hub API client pre-configured with the harness's /// `hf_cache` (when one is set). fn hf_api(&self) -> Result { @@ -1012,7 +1092,7 @@ impl CandleHarness { .token_to_id("<|im_end|>") .or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>")); - let (vram_free_mb, vram_total_mb) = device_vram_mb(&loaded.device); + let (vram_free_mb, vram_total_mb) = loaded.query_vram().await; tracing::info!( prompt_len, max_new, @@ -1217,9 +1297,13 @@ impl CandleHarness { let loaded_for_task = Arc::clone(&loaded); let span_for_starting = span.clone(); let span_for_task = span.clone(); + // Query VRAM before entering the span so we don't await inside + // an entered guard (Span::enter creates a synchronous guard + // that can't span await points). The span gets entered in a + // separate scope below purely for the log emission. + let (vram_free_mb, vram_total_mb) = loaded.query_vram().await; { let _g = span_for_starting.enter(); - let (vram_free_mb, vram_total_mb) = device_vram_mb(&loaded.device); tracing::info!( prompt_len, max_new, @@ -1346,6 +1430,15 @@ impl Harness for CandleHarness { let tokenizer = Tokenizer::from_file(&tokenizer_path) .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?; + // Worker thread for the chosen device. CPU loads (CUDA + // unavailable / not requested) skip the worker — there's no + // context to own. 
+ let worker = match &device { + #[cfg(feature = "cuda")] + Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?), + _ => None, + }; + let loaded = Arc::new(LoadedModel { model_id: spec.model_id.clone(), arch: Arc::new(Mutex::new(arch)), @@ -1354,6 +1447,7 @@ impl Harness for CandleHarness { quant: spec.quant.clone(), devices, poisoned: AtomicBool::new(false), + worker, }); let mut models = self.models.write().await; @@ -1496,6 +1590,12 @@ impl CandleHarness { let tokenizer = Tokenizer::from_file(&tokenizer_path) .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?; + // 7. Worker thread for the leader's CUDA device. TP always + // runs on CUDA — the harness rejects TP without the cuda + // feature earlier in this function — so we always have a + // device to own. + let worker = self.ensure_device_worker(devices[0]).await?; + let tp_loaded = StdArc::new(TpLoadedModel { model_id: spec.model_id.clone(), tokenizer, @@ -1504,6 +1604,7 @@ impl CandleHarness { leader_model, leader_device: leader_device.clone(), poisoned: AtomicBool::new(false), + worker, }); let mut models = self.models.write().await; @@ -1661,7 +1762,7 @@ impl CandleHarness { model = %model_id ); let req_start = std::time::Instant::now(); - let (vram_free_mb, vram_total_mb) = device_vram_mb(&tp.leader_device); + let (vram_free_mb, vram_total_mb) = tp.query_vram().await; tracing::info!( parent: &span, prompt_len, @@ -1713,8 +1814,7 @@ impl CandleHarness { break 'work; } }; - let (post_prefill_vram_free_mb, _) = - device_vram_mb(&tp_for_task.leader_device); + let (post_prefill_vram_free_mb, _) = tp_for_task.query_vram().await; tracing::info!( model = %model_id, prompt_len, @@ -1790,11 +1890,19 @@ impl CandleHarness { break 'work; } }; + // Always await the query (even when the + // trace! is filtered out by RUST_LOG): the + // channel hop is ~tens of µs, comparable to + // the previous in-line bind+query cost, and + // making the call conditional adds complexity + // for negligible win. 
Revisit if it shows up + // in a hot-path profile. + let step_vram_free_mb = tp_for_task.query_vram().await.0; tracing::trace!( model = %model_id, step = index, next_token, - vram_free_mb = device_vram_mb(&tp_for_task.leader_device).0, + vram_free_mb = step_vram_free_mb, "TP chat_completion (stream): decode step" ); if Some(next_token) == eos_id { @@ -1906,7 +2014,7 @@ async fn chat_completion_tp_inner( .token_to_id("<|im_end|>") .or_else(|| tp.tokenizer.token_to_id("<|endoftext|>")); - let (vram_free_mb, vram_total_mb) = device_vram_mb(&tp.leader_device); + let (vram_free_mb, vram_total_mb) = tp.query_vram().await; tracing::info!( model = %model_id, prompt_len, @@ -1962,7 +2070,7 @@ async fn chat_completion_tp_inner( .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0) .await .map_err(InferenceError::Other)?; - let (post_prefill_vram_free_mb, _) = device_vram_mb(&tp.leader_device); + let (post_prefill_vram_free_mb, _) = tp.query_vram().await; tracing::info!( model = %model_id, prompt_len, @@ -2017,7 +2125,7 @@ async fn chat_completion_tp_inner( return Err(InferenceError::Other(e)); } }; - let step_vram_free_mb = device_vram_mb(&tp.leader_device).0; + let step_vram_free_mb = tp.query_vram().await.0; tracing::trace!( model = %model_id, step = index, diff --git a/crates/neuron/src/harness/device_worker/dispatch.rs b/crates/neuron/src/harness/device_worker/dispatch.rs new file mode 100644 index 0000000..378c048 --- /dev/null +++ b/crates/neuron/src/harness/device_worker/dispatch.rs @@ -0,0 +1,168 @@ +//! Synchronous dispatch loop running on the device worker thread. +//! +//! `run()` is the thread's entry point. It binds the CUDA context for +//! its device on startup, then pulls `Job`s off the channel one at a +//! time and runs the corresponding handler. The handlers are +//! synchronous by design — the only async on this thread is the +//! one-line `oneshot::Sender::send` call to ship the reply back, which +//! is non-blocking. +//! +//! 
Phase 1 handles only `QueryVram` and `Shutdown`. Later phases add
+//! Forward, ClearKv, NCCL, and load handlers as separate match arms.
+
+use crate::harness::device_worker::jobs::Job;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::mpsc::Receiver;
+
+/// Per-thread state owned by the worker. On CUDA builds the `Arc<CudaContext>`
+/// is created and bound at thread startup; on CPU builds the struct is
+/// empty save for the device index (kept for log clarity).
+#[cfg(feature = "cuda")]
+struct DeviceWorkerState {
+    device_index: u32,
+    /// `None` only if `CudaContext::new()` failed — in that case the
+    /// thread still runs so the handle's lifecycle stays uniform, but
+    /// every job that touches CUDA falls through to a zero reply with
+    /// a log warning.
+    ctx: Option<Arc<candle_core::cuda::cudarc::driver::CudaContext>>,
+}
+
+#[cfg(not(feature = "cuda"))]
+#[allow(dead_code)]
+struct DeviceWorkerState {
+    device_index: u32,
+}
+
+/// Worker thread entry point. Runs until `Job::Shutdown` arrives or
+/// the channel sender is dropped (which happens when the last
+/// `DeviceWorkerHandle` `Arc` is dropped without an explicit
+/// `shutdown()`).
+pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool>) {
+    let state = init_state(device_index);
+    tracing::info!(device_index, "device worker started");
+
+    while let Ok(job) = rx.recv() {
+        // Shutdown is processed unconditionally so a poisoned worker
+        // still exits when asked. Matching by reference first so we
+        // can fall through to the consume-match below.
+        if matches!(&job, Job::Shutdown) {
+            break;
+        }
+        if poisoned.load(Ordering::Acquire) {
+            // Drain-only mode: reply with a poisoned error without
+            // touching CUDA. Phase 1 never sets the flag from the
+            // dispatch loop itself (no driver errors classified yet),
+            // but tests use `DeviceWorkerHandle::set_poisoned()` to
+            // simulate this state.
+            drain_poisoned(job, device_index);
+            continue;
+        }
+        match job {
+            Job::QueryVram { reply } => {
+                let result = query_vram(&state);
+                // If the caller dropped its receiver (request cancelled,
+                // gateway timed out) the send fails — fine, we just
+                // discard the reply.
+                let _ = reply.send(result);
+            }
+            // Handled by the matches!() check above; reaching here
+            // means a Shutdown slipped past which is a bug.
+            Job::Shutdown => unreachable!("Shutdown should break above"),
+        }
+    }
+
+    tracing::info!(device_index, "device worker exiting");
+}
+
+#[cfg(feature = "cuda")]
+fn init_state(device_index: u32) -> DeviceWorkerState {
+    use candle_core::cuda::cudarc::driver::CudaContext;
+    match CudaContext::new(device_index as usize) {
+        Ok(ctx) => {
+            // Make sure the context is current on this thread. cudarc
+            // is generally fine with lazy binding, but doing it once
+            // here gives us a deterministic moment to log "context
+            // bound" — and makes `mem_get_info()` work without further
+            // bind dances inside the dispatch handlers.
+            if let Err(e) = ctx.bind_to_thread() {
+                tracing::warn!(
+                    device_index,
+                    error = ?e,
+                    "device worker: bind_to_thread failed; \
+                     vram queries will still rebind per-call"
+                );
+            } else {
+                tracing::info!(device_index, "device worker bound CUDA context");
+            }
+            DeviceWorkerState {
+                device_index,
+                ctx: Some(ctx),
+            }
+        }
+        Err(e) => {
+            tracing::warn!(
+                device_index,
+                error = ?e,
+                "device worker: CudaContext::new failed; \
+                 vram queries will return (0, 0)"
+            );
+            DeviceWorkerState {
+                device_index,
+                ctx: None,
+            }
+        }
+    }
+}
+
+#[cfg(not(feature = "cuda"))]
+fn init_state(device_index: u32) -> DeviceWorkerState {
+    DeviceWorkerState { device_index }
+}
+
+#[cfg(feature = "cuda")]
+fn query_vram(state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
+    use candle_core::cuda::cudarc::driver::result;
+    let Some(ctx) = state.ctx.as_ref() else {
+        return Ok((0, 0));
+    };
+    // Rebind before every query: `init_state` keeps the context even if
+    // its startup bind failed, and its warning promises a per-call rebind.
+    ctx.bind_to_thread()
+        .map_err(|e| anyhow::anyhow!("bind_to_thread: {e:?}"))?;
+    match result::mem_get_info() {
+        Ok((free, total)) => Ok((
+            (free / (1024 * 1024)) as u64,
+            (total / (1024 * 1024)) as u64,
+        )),
+        Err(e) => Err(anyhow::anyhow!("mem_get_info: {e:?}")),
+    }
+}
+
+#[cfg(not(feature = "cuda"))]
+fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
+    Ok((0, 0))
+}
+
+/// Reply to a job with the poisoned-worker error. Used when the worker
+/// has flipped into drain-only mode after a CUDA driver error.
+///
+/// `Job::Shutdown` is filtered before reaching this fn so the match
+/// only needs the data-carrying variants. As phases 2–4 add more
+/// variants the match here grows; every variant must reply with the
+/// poisoned error so callers never hang waiting for a worker that's
+/// no longer running CUDA.
+fn drain_poisoned(job: Job, device_index: u32) {
+    match job {
+        Job::QueryVram { reply } => {
+            let _ = reply.send(Err(anyhow::anyhow!(
+                "device worker for device {device_index} is poisoned"
+            )));
+        }
+        Job::Shutdown => {
+            // Filtered by the matches!() guard in run(); reaching
+            // here would be a logic error.
+            unreachable!("Shutdown is filtered before drain_poisoned");
+        }
+    }
+}
diff --git a/crates/neuron/src/harness/device_worker/jobs.rs b/crates/neuron/src/harness/device_worker/jobs.rs
new file mode 100644
index 0000000..23c2715
--- /dev/null
+++ b/crates/neuron/src/harness/device_worker/jobs.rs
@@ -0,0 +1,31 @@
+//! Job variants accepted by the per-device worker thread.
+//!
+//! Each variant carries the inputs the synchronous dispatch handler
+//! needs plus a `tokio::sync::oneshot::Sender` for the reply. The
+//! async-side `DeviceWorkerHandle` constructs a job, sends it down the
+//! `std::sync::mpsc` channel, and `await`s the oneshot for the reply.
+//!
+//! Phase 1 includes only `QueryVram` and `Shutdown`.
Phases 2–4 add
+//! forward, kv-cache clear, drop-arch, NCCL init/sanity, and the load
+//! variants. Each new variant lands as a separate PR so the worker
+//! thread stays small at every checkpoint.
+
+use anyhow::Result;
+use tokio::sync::oneshot;
+
+/// One unit of work for the device worker.
+pub enum Job {
+    /// Query free / total VRAM on the device. Returns
+    /// `(free_mb, total_mb)`. CPU builds and contexts that failed to
+    /// initialise reply with `(0, 0)` — matches today's
+    /// `device_vram_mb` sentinel so the log field values don't change.
+    QueryVram {
+        reply: oneshot::Sender<Result<(u64, u64)>>,
+    },
+    /// Tell the worker to break its dispatch loop and exit. The
+    /// channel is then drained — any further jobs already queued get
+    /// dropped (their oneshot senders are dropped, causing the async
+    /// caller's receiver to return `Err` which `DeviceWorkerHandle`
+    /// maps to `WorkerError::Gone`).
+    Shutdown,
+}
diff --git a/crates/neuron/src/harness/device_worker/mod.rs b/crates/neuron/src/harness/device_worker/mod.rs
new file mode 100644
index 0000000..12af5f6
--- /dev/null
+++ b/crates/neuron/src/harness/device_worker/mod.rs
@@ -0,0 +1,263 @@
+//! Per-device CUDA worker thread.
+//!
+//! One dedicated OS thread per CUDA device the leader uses. The thread
+//! binds the device's `CudaContext` once at startup and owns it for the
+//! daemon's lifetime; all GPU operations and VRAM queries for that
+//! device route through a `std::sync::mpsc` channel into this thread.
+//! Tensors never escape the thread alive — replies cross the channel
+//! as plain values (`u32` tokens, `(u64, u64)` mb numbers, `()`).
+//!
+//! Rationale, in order of weight:
+//!
+//! 1. **Context locality.** cudarc binds the CUDA context per OS thread
+//!    via `cuCtxSetCurrent`. With `tokio::task::spawn_blocking`, the
+//!    blocking thread chosen is arbitrary, so the context gets bound
+//!    onto a different thread each time and `device_vram_mb()` from an
+//!    
async task binds it again on the *caller's* thread as a side
+//!    effect. Pinning the context to one named thread ends that.
+//!
+//! 2. **Drop safety.** `cudarc::driver::CudaContext`, every `CudaSlice`
+//!    inside a `Tensor`, and every `cudarc::nccl::Comm` call `cuMemFree`
+//!    / `cuCtxDestroy` / `ncclCommDestroy` during `Drop`. These must
+//!    run with the right context current. Owning everything in this
+//!    thread's state slab and dropping it via `Job::DropArch` /
+//!    `Job::Shutdown` is the only safe pattern.
+//!
+//! 3. **Poisoning blast radius.** When a CUDA driver error (illegal
+//!    address, OOM cascade) makes the context unrecoverable, today the
+//!    spawn_blocking thread carrying that bad state simply returns to
+//!    tokio's pool — invisible. With the per-device thread, the
+//!    poisoned flag lives on the thread itself; subsequent
+//!    `submit()` calls fast-reject at the channel boundary with a
+//!    clear "device worker is poisoned" error before any further CUDA
+//!    work is attempted.
+//!
+//! The TP worker subprocesses (`harness/tp/worker.rs`) are already this
+//! pattern, just out-of-process. The in-process variant uses the same
+//! discipline for rank 0.
+//!
+//! Phase 1 of the refactor exposes only `Job::QueryVram` + `Job::Shutdown`.
+//! Forward, kv-cache clear, model load, and NCCL bring-up move in later
+//! phases. See `~/.claude/plans/plan-the-per-device-worker-abstract-micali.md`.
+
+pub mod dispatch;
+pub mod jobs;
+
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::mpsc::{self, Sender};
+use std::thread::JoinHandle;
+use tokio::sync::oneshot;
+
+pub use jobs::Job;
+
+/// Errors returned by `DeviceWorkerHandle` submit methods.
+#[derive(Debug, thiserror::Error)]
+pub enum WorkerError {
+    /// The worker's CUDA context was poisoned by an earlier driver
+    /// error.
The thread is still alive (dropping it would re-touch
+    /// the broken context); it returns this error for every job
+    /// submitted until the daemon is restarted.
+    #[error(
+        "device worker for device {device_index} is poisoned \
+         (a prior CUDA driver error left the context unrecoverable); \
+         restart the daemon to recover"
+    )]
+    Poisoned { device_index: u32 },
+    /// The worker thread has exited (`Job::Shutdown` was processed or
+    /// the thread panicked). Subsequent `submit()` calls fail here
+    /// rather than blocking forever.
+    #[error("device worker for device {device_index} is no longer running")]
+    Gone { device_index: u32 },
+    /// The dispatched job returned an `Err`. Forwarded verbatim.
+    #[error(transparent)]
+    Job(#[from] anyhow::Error),
+}
+
+/// Shared handle to a per-device CUDA worker thread.
+///
+/// Cloning the `Arc` lets multiple `LoadedModel`s (and `TpLoadedModel`s)
+/// share the same worker — there's one worker per CUDA device index,
+/// not one per model.
+pub struct DeviceWorkerHandle {
+    device_index: u32,
+    tx: Sender<Job>,
+    poisoned: Arc<AtomicBool>,
+    /// `Mutex<Option<JoinHandle>>` so `shutdown()` can take the handle
+    /// out without `&mut self` and so the inevitable `Drop` after
+    /// `shutdown()` doesn't double-join. The mutex is uncontended in
+    /// practice: only one caller ever takes the handle.
+    join: std::sync::Mutex<Option<JoinHandle<()>>>,
+}
+
+impl DeviceWorkerHandle {
+    /// Spawn a new worker for the given CUDA device index.
+    ///
+    /// The thread is named `cuda-dev-N` so it shows up legibly in
+    /// `top -H`, `pidstat -t`, and gdb backtraces. On CUDA builds, the
+    /// thread binds `CudaContext::new(N)` on startup; on CPU builds
+    /// (`--no-default-features`) the thread runs without a context and
+    /// every job that touches CUDA falls through to a zero return.
+    pub fn spawn(device_index: u32) -> anyhow::Result<Arc<Self>> {
+        let (tx, rx) = mpsc::channel::<Job>();
+        let poisoned = Arc::new(AtomicBool::new(false));
+        let poisoned_for_thread = Arc::clone(&poisoned);
+        let join = std::thread::Builder::new()
+            .name(format!("cuda-dev-{device_index}"))
+            .spawn(move || {
+                dispatch::run(device_index, rx, poisoned_for_thread);
+            })?;
+        Ok(Arc::new(Self {
+            device_index,
+            tx,
+            poisoned,
+            join: std::sync::Mutex::new(Some(join)),
+        }))
+    }
+
+    pub fn device_index(&self) -> u32 {
+        self.device_index
+    }
+
+    pub fn is_poisoned(&self) -> bool {
+        self.poisoned.load(Ordering::Acquire)
+    }
+
+    /// Mark the worker's context as poisoned. Future `submit()` calls
+    /// short-circuit to `WorkerError::Poisoned` before sending. The
+    /// dispatch loop also flips into drain-only mode when it sees this
+    /// flag, so any jobs already in flight on the channel reply with
+    /// the same error without touching CUDA.
+    #[allow(dead_code)]
+    pub(crate) fn set_poisoned(&self) {
+        self.poisoned.store(true, Ordering::Release);
+    }
+
+    /// Send `Job::QueryVram`, await the worker's reply.
+    ///
+    /// Returns `Ok((free_mb, total_mb))` on success, `Ok((0, 0))` on
+    /// CPU builds or when the device lacks a bound context, or an
+    /// error if the worker is poisoned, gone, or the query itself
+    /// failed inside cudarc.
+    pub async fn query_vram(&self) -> Result<(u64, u64), WorkerError> {
+        if self.poisoned.load(Ordering::Acquire) {
+            return Err(WorkerError::Poisoned {
+                device_index: self.device_index,
+            });
+        }
+        let (reply_tx, reply_rx) = oneshot::channel();
+        self.tx
+            .send(Job::QueryVram { reply: reply_tx })
+            .map_err(|_| WorkerError::Gone {
+                device_index: self.device_index,
+            })?;
+        match reply_rx.await {
+            Ok(result) => result.map_err(WorkerError::from),
+            Err(_) => Err(WorkerError::Gone {
+                device_index: self.device_index,
+            }),
+        }
+    }
+
+    /// Send `Job::Shutdown` and join the thread. Idempotent — calling
+    /// twice is a no-op the second time.
+    pub fn shutdown(&self) -> anyhow::Result<()> {
+        // Best-effort send: if the channel is already closed (thread
+        // exited after a prior shutdown or panic) the send fails and
+        // we fall through to the join which returns the panic, if any.
+        let _ = self.tx.send(Job::Shutdown);
+        let join = self.join.lock().unwrap().take();
+        if let Some(j) = join {
+            j.join()
+                .map_err(|_| anyhow::anyhow!("worker thread panicked during shutdown"))?;
+        }
+        Ok(())
+    }
+}
+
+impl Drop for DeviceWorkerHandle {
+    fn drop(&mut self) {
+        // Best-effort: send Shutdown so the thread breaks its loop
+        // and exits. We do NOT join here — Drop may run on a tokio
+        // worker thread, and joining a thread that's still processing
+        // the last job would block the runtime. The OS reaps the
+        // thread on detach.
+        let _ = self.tx.send(Job::Shutdown);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::time::Duration;
+
+    #[tokio::test]
+    async fn spawn_query_vram_shutdown() {
+        let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
+        // CPU build (the only one CI runs) returns (0, 0) by design;
+        // a CUDA build with a real device would return real values.
+        let result = handle.query_vram().await.expect("query ok");
+        // Free memory can never exceed the total, on any build.
+        let (free_mb, total_mb) = result;
+        assert!(free_mb <= total_mb, "free {free_mb} MiB > total {total_mb} MiB");
+        handle.shutdown().expect("shutdown ok");
+    }
+
+    #[tokio::test]
+    async fn thread_is_named_correctly() {
+        // The name is what `top -H` / pidstat / gdb display; read it back
+        // from the stored `JoinHandle` so the check works on every platform.
+        let handle = DeviceWorkerHandle::spawn(7).expect("spawn ok");
+        let guard = handle.join.lock().unwrap();
+        let name = guard.as_ref().and_then(|j| j.thread().name());
+        assert_eq!(name, Some("cuda-dev-7"));
+        // Release the lock first: `shutdown()` takes the same mutex.
+        drop(guard);
+        handle.shutdown().expect("shutdown ok");
+    }
+
+    #[tokio::test]
+    async fn submit_after_shutdown_returns_gone() {
+        let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
+        handle.shutdown().expect("shutdown ok");
+        // Channel closed; submit should map to Gone rather than block.
+        let result = handle.query_vram().await;
+        match result {
+            Err(WorkerError::Gone { device_index: 0 }) => {}
+            other => panic!("expected Gone, got {other:?}"),
+        }
+    }
+
+    #[tokio::test]
+    async fn poisoned_flag_short_circuits_submit() {
+        let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
+        handle.set_poisoned();
+        let result = handle.query_vram().await;
+        match result {
+            Err(WorkerError::Poisoned { device_index: 0 }) => {}
+            other => panic!("expected Poisoned, got {other:?}"),
+        }
+        // The channel is still alive; shutdown should still succeed.
+        handle.shutdown().expect("shutdown ok");
+    }
+
+    #[tokio::test]
+    async fn shutdown_drains_pending_jobs() {
+        let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
+        // Submit many concurrent jobs; they should all complete even
+        // though a Shutdown is racing them.
+        let mut futures = Vec::new();
+        for _ in 0..16 {
+            let h = Arc::clone(&handle);
+            futures.push(tokio::spawn(async move { h.query_vram().await }));
+        }
+        // Small yield to give the senders a chance to actually send
+        // before we issue the shutdown; not strictly necessary because
+        // the channel is FIFO, but makes the test's intent clearer.
+        tokio::time::sleep(Duration::from_millis(10)).await;
+        handle.shutdown().expect("shutdown ok");
+        for f in futures {
+            // Each query should have completed (Ok or Gone, never panic).
+ let _ = f.await.expect("task did not panic"); + } + } +} diff --git a/crates/neuron/src/harness/mod.rs b/crates/neuron/src/harness/mod.rs index 831cbe0..7b1fb45 100644 --- a/crates/neuron/src/harness/mod.rs +++ b/crates/neuron/src/harness/mod.rs @@ -2,6 +2,7 @@ pub mod arch; pub mod candle; +pub mod device_worker; pub mod tp; use anyhow::Result;