diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 30de984..23d5a71 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -102,7 +102,14 @@ impl LoadedHandle { /// moved into `spawn_blocking` for synchronous candle forward passes. pub struct LoadedModel { pub model_id: String, - pub arch: Arc>, + /// Local (async-side) handle to the model architecture. `Some` + /// only when the model loaded onto the CPU device (no CUDA + /// available); the inference path then takes this mutex via + /// `spawn_blocking` and runs candle ops on the CPU backend. + /// `None` when the model loaded onto a CUDA device — in that case + /// the architecture lives in the worker thread's slab and is + /// addressed via [`Self::arch_handle`]. + pub arch: Option>>, pub tokenizer: Tokenizer, pub device: Device, pub quant: Option, @@ -118,10 +125,15 @@ pub struct LoadedModel { pub poisoned: AtomicBool, /// Handle to the per-device CUDA worker thread for this model's /// device. `None` for CPU loads (no context to own). VRAM queries - /// and — in later refactor phases — forward / kv-cache / unload - /// ops route through this handle so the device's CUDA context - /// stays bound to one OS thread for the daemon's lifetime. + /// and — for CUDA loads — forward / kv-cache / drop ops route + /// through this handle so the device's CUDA context stays bound + /// to one OS thread for the daemon's lifetime. pub worker: Option>, + /// Index into the worker's `ModelArch` slab. `Some` iff the model + /// loaded onto a CUDA device and was successfully transferred to + /// the worker; in that case [`Self::arch`] is `None`. The two + /// fields are mutually exclusive. + pub arch_handle: Option, } impl LoadedModel { @@ -475,6 +487,16 @@ fn logits_health(t: &Tensor) -> LogitsHealth { }; } }; + logits_health_slice(&values) +} + +/// Same diagnostic as [`logits_health`] but operates directly on a +/// `[f32]` slice. Used by the worker-routed inference paths where the +/// device → host copy has already happened on the worker thread and +/// the async caller has the values in hand. Avoids the round-trip of +/// rebuilding a Tensor just to call to_vec1 again. +#[allow(dead_code)] +fn logits_health_slice(values: &[f32]) -> LogitsHealth { let mut nan = 0usize; let mut pos_inf = 0usize; let mut neg_inf = 0usize; @@ -483,7 +505,7 @@ fn logits_health(t: &Tensor) -> LogitsHealth { let mut finite_max = f32::NEG_INFINITY; let mut finite_sum = 0.0_f64; let mut finite_count = 0usize; - for &v in &values { + for &v in values { if v.is_nan() { nan += 1; } else if v == f32::INFINITY { @@ -1104,15 +1126,22 @@ impl CandleHarness { "chat_completion: starting" ); - let arch_arc = Arc::clone(&loaded.arch); - let device = loaded.device.clone(); - - let inference_result = - tokio::task::spawn_blocking(move || -> Result<(Vec, String)> { - let mut guard = arch_arc.blocking_lock(); - run_inference( - &mut guard, - &device, + // Routing: CUDA loads go through the per-device worker + // thread (introduced in Phase 1; forward/clear added in + // Phase 2). CPU loads keep the existing spawn_blocking + // path because there's no context to own and the channel + // round-trip would only add latency. The two arms produce + // the same `(Vec, String)` shape so the rest of the + // path is shared. + let (generated_ids, finish_reason) = if let (Some(worker), Some(handle)) = + (loaded.worker.as_ref(), loaded.arch_handle) + { + // Worker path (CUDA). + #[cfg(feature = "cuda")] + { + match run_inference_via_worker( + worker, + handle, &prompt_tokens, max_new, temperature, @@ -1120,26 +1149,67 @@ impl CandleHarness { seed, eos_id, ) - }) - .await; + .await + { + Ok(v) => v, + Err(e) => { + loaded.poisoned.store(true, Ordering::Release); + return Err(InferenceError::Other(e)); + } + } + } + #[cfg(not(feature = "cuda"))] + { + // Can't happen: `loaded.worker` is only Some on + // CUDA builds. The dead branch keeps the no-cuda + // build well-typed. + let _ = (worker, handle); + unreachable!("worker handle present without cuda feature"); + } + } else if let Some(arch_arc) = loaded.arch.clone() { + // CPU path: existing spawn_blocking on the local + // Arc>. + let device = loaded.device.clone(); + let inference_result = + tokio::task::spawn_blocking(move || -> Result<(Vec, String)> { + let mut guard = arch_arc.blocking_lock(); + run_inference( + &mut guard, + &device, + &prompt_tokens, + max_new, + temperature, + top_p, + seed, + eos_id, + ) + }) + .await; - // Any failure inside the spawn_blocking touched CUDA via - // candle's forward / cache code, so we treat it as a - // device-poisoning event. The terminal log at the bottom - // of the wrapper reports the error; this flag stops the - // NEXT request from going down the same path. - let (generated_ids, finish_reason) = match inference_result { - Ok(Ok(v)) => v, - Ok(Err(e)) => { - loaded.poisoned.store(true, Ordering::Release); - return Err(InferenceError::Other(e)); - } - Err(e) => { - loaded.poisoned.store(true, Ordering::Release); - return Err(InferenceError::Other(anyhow::anyhow!( - "inference task panicked: {e}" - ))); + // Any failure inside the spawn_blocking touched CUDA via + // candle's forward / cache code, so we treat it as a + // device-poisoning event. The terminal log at the bottom + // of the wrapper reports the error; this flag stops the + // NEXT request from going down the same path. + match inference_result { + Ok(Ok(v)) => v, + Ok(Err(e)) => { + loaded.poisoned.store(true, Ordering::Release); + return Err(InferenceError::Other(e)); + } + Err(e) => { + loaded.poisoned.store(true, Ordering::Release); + return Err(InferenceError::Other(anyhow::anyhow!( + "inference task panicked: {e}" + ))); + } } + } else { + // LoadedModel invariant: exactly one of `worker` / + // `arch` is Some. Reaching here is a construction bug. + return Err(InferenceError::Other(anyhow::anyhow!( + "LoadedModel has neither worker handle nor local arch — load-path bug" + ))); }; let completion_text = loaded @@ -1244,7 +1314,6 @@ impl CandleHarness { .token_to_id("<|im_end|>") .or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>")); - let arch_arc = Arc::clone(&loaded.arch); let device = loaded.device.clone(); let tokenizer = loaded.tokenizer.clone(); let model_id = request.model.clone(); @@ -1315,40 +1384,96 @@ impl CandleHarness { "chat_completion (stream): starting" ); } - tokio::task::spawn_blocking(move || { - let _g = span_for_task.enter(); - let mut guard = arch_arc.blocking_lock(); - match run_inference_streaming( - &mut guard, - &device, - &tokenizer, - &prompt_tokens, - max_new, - temperature, - top_p, - seed, - eos_id, - &id, - created, - &model_id, - &tx, - ) { - Ok(()) => tracing::info!( - prompt_tokens = prompt_len, - total_ms = req_start.elapsed().as_millis(), - "chat_completion (stream): done" - ), - Err(e) => { - loaded_for_task.poisoned.store(true, Ordering::Release); - tracing::error!( - error = %format!("{e:#}"), + // Routing parallel to the non-streaming chat_completion: CUDA + // goes through the worker (async task), CPU keeps the + // spawn_blocking + Arc> path. + if let (Some(worker), Some(handle)) = (loaded.worker.clone(), loaded.arch_handle) { + #[cfg(feature = "cuda")] + { + let prompt_tokens = prompt_tokens.clone(); + tokio::spawn( + async move { + match stream_inference_via_worker( + worker, + handle, + tokenizer, + prompt_tokens, + max_new, + temperature, + top_p, + seed, + eos_id, + id, + created, + model_id, + tx, + ) + .await + { + Ok(_finish_reason) => tracing::info!( + prompt_tokens = prompt_len, + total_ms = req_start.elapsed().as_millis(), + "chat_completion (stream): done" + ), + Err(e) => { + loaded_for_task.poisoned.store(true, Ordering::Release); + tracing::error!( + error = %format!("{e:#}"), + prompt_tokens = prompt_len, + total_ms = req_start.elapsed().as_millis(), + "chat_completion (stream): failed, model marked poisoned" + ); + } + } + } + .instrument(span_for_task), + ); + } + #[cfg(not(feature = "cuda"))] + { + let _ = (worker, handle, span_for_task); + unreachable!("worker handle present without cuda feature"); + } + } else if let Some(arch_arc) = loaded.arch.clone() { + tokio::task::spawn_blocking(move || { + let _g = span_for_task.enter(); + let mut guard = arch_arc.blocking_lock(); + match run_inference_streaming( + &mut guard, + &device, + &tokenizer, + &prompt_tokens, + max_new, + temperature, + top_p, + seed, + eos_id, + &id, + created, + &model_id, + &tx, + ) { + Ok(()) => tracing::info!( prompt_tokens = prompt_len, total_ms = req_start.elapsed().as_millis(), - "chat_completion (stream): failed, model marked poisoned" - ); + "chat_completion (stream): done" + ), + Err(e) => { + loaded_for_task.poisoned.store(true, Ordering::Release); + tracing::error!( + error = %format!("{e:#}"), + prompt_tokens = prompt_len, + total_ms = req_start.elapsed().as_millis(), + "chat_completion (stream): failed, model marked poisoned" + ); + } } - } - }); + }); + } else { + return Err(InferenceError::Other(anyhow::anyhow!( + "LoadedModel has neither worker handle nor local arch — load-path bug" + ))); + } Ok(rx) } @@ -1432,22 +1557,37 @@ impl Harness for CandleHarness { // Worker thread for the chosen device. CPU loads (CUDA // unavailable / not requested) skip the worker — there's no - // context to own. - let worker = match &device { + // context to own. For CUDA loads, the arch is transferred + // into the worker's slab now so the inference path can + // reference it via the returned `ArchHandle`. The explicit + // type annotation lets the no-cuda build resolve `None` to + // the right `Option>` type. + let worker: Option> = match &device { #[cfg(feature = "cuda")] Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?), _ => None, }; + let (arch_local, arch_handle) = match &worker { + Some(w) => { + let handle = w + .transfer_in(Box::new(arch)) + .await + .map_err(|e| anyhow::anyhow!("transfer arch into device worker: {e}"))?; + (None, Some(handle)) + } + None => (Some(Arc::new(Mutex::new(arch))), None), + }; let loaded = Arc::new(LoadedModel { model_id: spec.model_id.clone(), - arch: Arc::new(Mutex::new(arch)), + arch: arch_local, tokenizer, device, quant: spec.quant.clone(), devices, poisoned: AtomicBool::new(false), worker, + arch_handle, }); let mut models = self.models.write().await; @@ -1465,13 +1605,26 @@ impl Harness for CandleHarness { anyhow::bail!("model '{model_id}' not loaded"); }; // Single-GPU drops are immediate — the LoadedModel goes out of - // scope with the Arc and candle frees VRAM. TP unloads also - // need to tell every worker to drop its shard before the pool - // itself is dropped (otherwise the workers keep their shards - // around until Shutdown, which is wasteful and would surface - // as VRAM not freed promptly). + // scope with the Arc and candle frees VRAM. CUDA loads also + // ship a `Job::DropArch` to the device worker so the boxed + // `ModelArch` releases its CUDA allocations on the right + // thread (with the bound context); without that, the Drop + // would run on whatever tokio thread happens to be holding + // the last `Arc` clone when this fn returns. + // TP unloads further coordinate the subprocess pool below. match handle { - LoadedHandle::Single(_) => {} + LoadedHandle::Single(single) => { + if let (Some(worker), Some(arch_handle)) = + (single.worker.as_ref(), single.arch_handle) + && let Err(e) = worker.drop_arch(arch_handle).await + { + tracing::warn!( + model = %model_id, + error = %e, + "single-GPU unload: DropArch RPC failed (model state may leak in worker slab)" + ); + } + } #[cfg(feature = "cuda")] LoadedHandle::Tp(tp) => { // Try to recover the inner TpLoadedModel so we can move @@ -2276,6 +2429,239 @@ fn format_qwen3_prompt(messages: &[ChatMessage]) -> String { prompt } +#[allow(clippy::too_many_arguments)] +/// Run the full single-GPU inference loop via the device worker. +/// +/// Mirrors `run_inference`'s logic but routes each forward step +/// through `worker.forward_logits()` (returns CPU-side `Vec`) +/// and runs `apply_repeat_penalty` + sampling on a CPU candle tensor. +/// The device-resident logits tensor never escapes the worker thread. +/// +/// Used by the CUDA path of `chat_completion`. The CPU path keeps +/// `run_inference` (spawn_blocking against `Arc>`) +/// because there's no CUDA context to own and the worker indirection +/// would only add channel overhead with no diagnostic benefit. +#[cfg(feature = "cuda")] +#[allow(clippy::too_many_arguments)] +async fn run_inference_via_worker( + worker: &super::device_worker::DeviceWorkerHandle, + handle: super::device_worker::ArchHandle, + prompt_tokens: &[u32], + max_new: usize, + temperature: f64, + top_p: Option, + seed: u64, + eos_id: Option, +) -> Result<(Vec, String)> { + let mut logits_processor = { + let sampling = if temperature <= 0.0 { + Sampling::ArgMax + } else { + match top_p { + Some(p) => Sampling::TopP { p, temperature }, + None => Sampling::All { temperature }, + } + }; + LogitsProcessor::from_sampling(seed, sampling) + }; + + let mut generated: Vec = Vec::new(); + let prompt_len = prompt_tokens.len(); + + worker + .clear_kv_cache(handle) + .await + .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?; + + // Prefill — every rank embeds the prompt with offset 0. + let logits_vec = worker + .forward_logits(handle, prompt_tokens.to_vec(), 0) + .await + .map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?; + let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?; + let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) { + Ok(t) => t, + Err(e) => { + let health = logits_health_slice(&logits_vec); + tracing::warn!( + ?health, + "chat_completion (worker): prefill sample failed; logits unhealthy" + ); + return Err(e); + } + }; + + if Some(next_token) == eos_id { + return Ok((generated, "stop".into())); + } + generated.push(next_token); + + for index in 0..max_new.saturating_sub(1) { + let logits_vec = worker + .forward_logits(handle, vec![next_token], prompt_len + index) + .await + .map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?; + let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?; + next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) { + Ok(t) => t, + Err(e) => { + let health = logits_health_slice(&logits_vec); + tracing::warn!( + step = index, + ?health, + "chat_completion (worker): decode sample failed; logits unhealthy" + ); + return Err(e); + } + }; + if Some(next_token) == eos_id { + return Ok((generated, "stop".into())); + } + generated.push(next_token); + } + + Ok((generated, "length".into())) +} + +/// Streaming counterpart of [`run_inference_via_worker`]. Emits one +/// `ChatCompletionChunk` per generated token via `tx`; routes every +/// forward step through `worker.forward_logits()`. Same per-step +/// CPU-side sampling discipline — no device tensor escapes the +/// worker thread. +#[cfg(feature = "cuda")] +#[allow(clippy::too_many_arguments)] +async fn stream_inference_via_worker( + worker: Arc, + handle: super::device_worker::ArchHandle, + tokenizer: Tokenizer, + prompt_tokens: Vec, + max_new: usize, + temperature: f64, + top_p: Option, + seed: u64, + eos_id: Option, + id: String, + created: u64, + model_id: String, + tx: mpsc::Sender, +) -> Result { + let mut logits_processor = { + let sampling = if temperature <= 0.0 { + Sampling::ArgMax + } else { + match top_p { + Some(p) => Sampling::TopP { p, temperature }, + None => Sampling::All { temperature }, + } + }; + LogitsProcessor::from_sampling(seed, sampling) + }; + + let mut all_tokens: Vec = Vec::new(); + let mut decoded_prefix = String::new(); + let prompt_len = prompt_tokens.len(); + let mut finish_reason = "length".to_string(); + + worker + .clear_kv_cache(handle) + .await + .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?; + + let logits_vec = worker + .forward_logits(handle, prompt_tokens, 0) + .await + .map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?; + let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?; + let mut next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) { + Ok(t) => t, + Err(e) => { + let health = logits_health_slice(&logits_vec); + tracing::warn!( + ?health, + "chat_completion (stream/worker): prefill sample failed; logits unhealthy" + ); + return Err(e); + } + }; + + if Some(next_token) == eos_id { + finish_reason = "stop".into(); + } else { + all_tokens.push(next_token); + if !emit_chunk( + &all_tokens, + &mut decoded_prefix, + &tokenizer, + &tx, + &id, + created, + &model_id, + ) + .await + { + return Ok(finish_reason); // Client gone — clean stream end. + } + + for index in 0..max_new.saturating_sub(1) { + let logits_vec = worker + .forward_logits(handle, vec![next_token], prompt_len + index) + .await + .map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?; + let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?; + next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) { + Ok(t) => t, + Err(e) => { + let health = logits_health_slice(&logits_vec); + tracing::warn!( + step = index, + ?health, + "chat_completion (stream/worker): decode sample failed; logits unhealthy" + ); + return Err(e); + } + }; + if Some(next_token) == eos_id { + finish_reason = "stop".into(); + break; + } + all_tokens.push(next_token); + if !emit_chunk( + &all_tokens, + &mut decoded_prefix, + &tokenizer, + &tx, + &id, + created, + &model_id, + ) + .await + { + return Ok(finish_reason); + } + } + } + + // Final chunk carrying finish_reason. Matches the run_inference_streaming + // shape so the SSE consumer sees an identical termination sequence. + let final_chunk = ChatCompletionChunk { + id: id.clone(), + object: "chat.completion.chunk".into(), + created, + model: model_id.clone(), + choices: vec![ChunkChoice { + index: 0, + delta: serde_json::Value::Object(Default::default()), + finish_reason: Some(finish_reason.clone()), + extra: serde_json::Value::Object(Default::default()), + }], + usage: None, + extra: serde_json::Value::Object(Default::default()), + }; + let _ = tx.send(final_chunk).await; + + Ok(finish_reason) +} + #[allow(clippy::too_many_arguments)] fn run_inference( arch: &mut ModelArch, diff --git a/crates/neuron/src/harness/device_worker/dispatch.rs b/crates/neuron/src/harness/device_worker/dispatch.rs index 378c048..a6e2e3a 100644 --- a/crates/neuron/src/harness/device_worker/dispatch.rs +++ b/crates/neuron/src/harness/device_worker/dispatch.rs @@ -7,20 +7,41 @@ //! one-line `oneshot::Sender::send` call to ship the reply back, which //! is non-blocking. //! -//! Phase 1 handles only `QueryVram` and `Shutdown`. Later phases add -//! Forward, ClearKv, NCCL, and load handlers as separate match arms. +//! Phase 2 handles QueryVram, TransferIn, DropArch, ClearKv, +//! ForwardLogits, Shutdown. Phase 3 will add the TP variants +//! (NcclInit, NcclSanity, TpLoadShard, TpForward, TpClearKv) and the +//! ARCH model state in this state slab will gain a companion +//! `tp_models: HashMap>`. -use crate::harness::device_worker::jobs::Job; +use crate::harness::candle::ModelArch; +use crate::harness::device_worker::jobs::{ArchHandle, Job}; +use std::collections::HashMap; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::Receiver; /// Per-thread state owned by the worker. On CUDA builds the `Arc` -/// is created and bound at thread startup; on CPU builds the struct is -/// empty save for the device index (kept for log clarity). -#[cfg(feature = "cuda")] +/// is created and bound at thread startup; on CPU builds the struct +/// is mostly empty. struct DeviceWorkerState { + #[allow(dead_code)] device_index: u32, + /// Candle `Device` constructed at startup. Used by handlers (e.g. + /// `ForwardLogits`) to build input tensors against the right + /// device. Falls back to `Device::Cpu` if CUDA init fails. + device: candle_core::Device, + /// Boxed `ModelArch` slab. Indexed by an opaque `ArchHandle` minted + /// by `TransferIn`. The Box means the entry's address is stable + /// across HashMap rehashes (relevant only when we later hand out + /// `&mut ModelArch` references — for Phase 2 every handler runs + /// `&mut` via `get_mut`, no long-lived borrows). + models: HashMap>, + /// Counter for minting fresh `ArchHandle`s. Each `TransferIn` + /// increments and returns the new value. Wraps at u64::MAX after + /// ~10^19 model loads — not a practical concern. + next_handle: u64, + #[cfg(feature = "cuda")] + #[allow(dead_code)] /// `None` only if `CudaContext::new()` failed — in that case the /// thread still runs so the handle's lifecycle stays uniform, but /// every job that touches CUDA falls through to a zero reply with @@ -28,18 +49,12 @@ struct DeviceWorkerState { ctx: Option>, } -#[cfg(not(feature = "cuda"))] -#[allow(dead_code)] -struct DeviceWorkerState { - device_index: u32, -} - /// Worker thread entry point. Runs until `Job::Shutdown` arrives or /// the channel sender is dropped (which happens when the last /// `DeviceWorkerHandle` `Arc` is dropped without an explicit /// `shutdown()`). pub(crate) fn run(device_index: u32, rx: Receiver, poisoned: Arc) { - let state = init_state(device_index); + let mut state = init_state(device_index); tracing::info!(device_index, "device worker started"); while let Ok(job) = rx.recv() { @@ -51,7 +66,7 @@ pub(crate) fn run(device_index: u32, rx: Receiver, poisoned: Arc, poisoned: Arc { + let handle = ArchHandle(state.next_handle); + state.next_handle = state.next_handle.wrapping_add(1); + state.models.insert(handle, arch); + tracing::debug!( + device_index, + handle = handle.0, + slab_size = state.models.len(), + "device worker: model transferred in" + ); + let _ = reply.send(Ok(handle)); + } + Job::DropArch { handle, reply } => { + let removed = state.models.remove(&handle); + let was_present = removed.is_some(); + // Explicit drop on this thread — runs the Box + // Drop with the CUDA context bound here, which frees + // all device tensors on the right context. The Drop is + // implicit on the `removed` value going out of scope at + // the end of the arm; calling drop() explicitly just + // makes the intent visible. + drop(removed); + tracing::debug!( + device_index, + handle = handle.0, + was_present, + slab_size = state.models.len(), + "device worker: model dropped" + ); + let _ = reply.send(()); + } + Job::ClearKv { handle, reply } => { + let result = match state.models.get_mut(&handle) { + Some(arch) => arch.clear_kv_cache(), + None => Err(anyhow::anyhow!("ClearKv: no model for handle {}", handle.0)), + }; + let _ = reply.send(result); + } + Job::ForwardLogits { + handle, + tokens, + offset, + reply, + } => { + let result = forward_logits(&mut state, handle, &tokens, offset); + let _ = reply.send(result); + } // Handled by the matches!() check above; reaching here // means a Shutdown slipped past which is a bug. Job::Shutdown => unreachable!("Shutdown should break above"), } } - tracing::info!(device_index, "device worker exiting"); + tracing::info!( + device_index, + slab_size = state.models.len(), + "device worker exiting; dropping remaining models" + ); + // Drops every model in the slab on this thread before the function + // returns. Critical for CUDA tensors: dropping on a thread that + // doesn't have the context bound is UB. Phase 2 still runs Drop + // via the slab going out of scope, which is correct as long as no + // pre-poisoned state lurks in here — see the poisoned-mode + // semantics in mod.rs for the Phase 3+ refinement. } -#[cfg(feature = "cuda")] fn init_state(device_index: u32) -> DeviceWorkerState { - use candle_core::cuda::cudarc::driver::CudaContext; - match CudaContext::new(device_index as usize) { - Ok(ctx) => { - // Make sure the context is current on this thread. cudarc - // is generally fine with lazy binding, but doing it once - // here gives us a deterministic moment to log "context - // bound" — and makes `mem_get_info()` work without further - // bind dances inside the dispatch handlers. - if let Err(e) = ctx.bind_to_thread() { + #[cfg(feature = "cuda")] + { + use candle_core::cuda::cudarc::driver::CudaContext; + // Construct a candle Device first — cudarc returns the + // primary context for this index on subsequent calls, so + // CudaContext::new and Device::new_cuda end up sharing state. + let (device, ctx) = match candle_core::Device::new_cuda(device_index as usize) { + Ok(device) => match CudaContext::new(device_index as usize) { + Ok(ctx) => { + if let Err(e) = ctx.bind_to_thread() { + tracing::warn!( + device_index, + error = ?e, + "device worker: bind_to_thread failed; \ + operations will still rebind per-call" + ); + } else { + tracing::info!(device_index, "device worker bound CUDA context"); + } + (device, Some(ctx)) + } + Err(e) => { + tracing::warn!( + device_index, + error = ?e, + "device worker: CudaContext::new failed; \ + vram queries will return (0, 0), forward will error" + ); + (device, None) + } + }, + Err(e) => { tracing::warn!( device_index, - error = ?e, - "device worker: bind_to_thread failed; \ - vram queries will still rebind per-call" + error = %e, + "device worker: Device::new_cuda failed; falling back to CPU device" ); - } else { - tracing::info!(device_index, "device worker bound CUDA context"); - } - DeviceWorkerState { - device_index, - ctx: Some(ctx), - } - } - Err(e) => { - tracing::warn!( - device_index, - error = ?e, - "device worker: CudaContext::new failed; \ - vram queries will return (0, 0)" - ); - DeviceWorkerState { - device_index, - ctx: None, + (candle_core::Device::Cpu, None) } + }; + DeviceWorkerState { + device_index, + device, + models: HashMap::new(), + next_handle: 1, + ctx, + } + } + #[cfg(not(feature = "cuda"))] + { + DeviceWorkerState { + device_index, + device: candle_core::Device::Cpu, + models: HashMap::new(), + next_handle: 1, } } -} - -#[cfg(not(feature = "cuda"))] -fn init_state(device_index: u32) -> DeviceWorkerState { - DeviceWorkerState { device_index } } #[cfg(feature = "cuda")] @@ -144,6 +231,42 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> { Ok((0, 0)) } +/// Forward step + copy the `[vocab]` logits to a CPU `Vec` ready +/// for sampling on the async caller. The model's `device()` (CUDA or +/// CPU) determines where the kernel runs; this fn doesn't care. +/// +/// On CUDA, the `to_dtype(F32).flatten_all().to_vec1::()` chain +/// triggers the device → host copy. The copy runs synchronously on +/// this worker thread; the bound context owns the source allocation +/// so the transfer is straightforward. +fn forward_logits( + state: &mut DeviceWorkerState, + handle: ArchHandle, + tokens: &[u32], + offset: usize, +) -> anyhow::Result> { + use candle_core::{DType, Tensor}; + + // Build the input tensor on the worker's own device. cudarc's + // primary-context model means `Device::new_cuda(idx)` shares state + // with the `CudaContext` we bound at startup, so this is the same + // device the ModelArch was loaded against. + let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?; + + let arch = state + .models + .get_mut(&handle) + .ok_or_else(|| anyhow::anyhow!("ForwardLogits: no model for handle {}", handle.0))?; + + let logits = arch.forward(&input, offset)?; + // Copy to CPU f32. logits is already `[vocab]` (squeeze_to_vocab + // inside ModelArch::forward). The to_dtype handles bf16/f16 → + // f32 promotion for the sampler. + let logits = logits.to_dtype(DType::F32)?.flatten_all()?; + let values = logits.to_vec1::()?; + Ok(values) +} + /// Reply to a job with the poisoned-worker error. Used when the worker /// has flipped into drain-only mode after a CUDA driver error. /// @@ -153,11 +276,26 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> { /// poisoned error so callers never hang waiting for a worker that's /// no longer running CUDA. fn drain_poisoned(job: Job, device_index: u32) { + let err = || anyhow::anyhow!("device worker for device {device_index} is poisoned"); match job { Job::QueryVram { reply } => { - let _ = reply.send(Err(anyhow::anyhow!( - "device worker for device {device_index} is poisoned" - ))); + let _ = reply.send(Err(err())); + } + Job::TransferIn { reply, .. } => { + let _ = reply.send(Err(err())); + } + Job::DropArch { reply, .. } => { + // Drop reply is `()` — no error path. Send the unit so the + // caller's await resolves; the model handle is leaked in + // the worker's slab, but the whole slab gets `mem::forget` + // on shutdown anyway per the poisoned-thread design. + let _ = reply.send(()); + } + Job::ClearKv { reply, .. } => { + let _ = reply.send(Err(err())); + } + Job::ForwardLogits { reply, .. } => { + let _ = reply.send(Err(err())); } Job::Shutdown => { // Filtered by the matches!() guard in run(); reaching diff --git a/crates/neuron/src/harness/device_worker/jobs.rs b/crates/neuron/src/harness/device_worker/jobs.rs index 23c2715..2cfbb70 100644 --- a/crates/neuron/src/harness/device_worker/jobs.rs +++ b/crates/neuron/src/harness/device_worker/jobs.rs @@ -4,16 +4,33 @@ //! needs plus a `tokio::sync::oneshot::Sender` for the reply. The //! async-side `DeviceWorkerHandle` constructs a job, sends it down the //! `std::sync::mpsc` channel, and `await`s the oneshot for the reply. -//! -//! Phase 1 includes only `QueryVram` and `Shutdown`. Phases 2–4 add -//! forward, kv-cache clear, drop-arch, NCCL init/sanity, and the load -//! variants. Each new variant lands as a separate PR so the worker -//! thread stays small at every checkpoint. +use crate::harness::candle::ModelArch; use anyhow::Result; use tokio::sync::oneshot; +/// Opaque handle to a `ModelArch` stored in the worker thread's state +/// slab. Cheap to copy; `Send + Sync` so it crosses task boundaries +/// freely. The actual `Box` it points to is owned by the +/// worker thread for the duration of the handle's lifetime — the only +/// way to drop the model is to send `Job::DropArch { handle }` so the +/// `Drop` impl runs on the thread with the bound CUDA context (the +/// invariant the whole refactor exists to guarantee). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ArchHandle(pub u64); + /// One unit of work for the device worker. +/// +/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the +/// single-GPU inference primitives: transfer-in a freshly-loaded +/// `ModelArch`, drop it, clear its KV cache, and run one forward step +/// returning CPU-side logits ready for sampling on the async caller. +/// +/// Sampling stays on the async side intentionally. The worker copies +/// logits to CPU (`Vec`) before reply, so the device-resident +/// tensor never escapes the worker thread and the async caller's +/// `LogitsProcessor::sample` runs entirely on the CPU candle backend +/// — no incidental context binding on a tokio worker thread. pub enum Job { /// Query free / total VRAM on the device. Returns /// `(free_mb, total_mb)`. CPU builds and contexts that failed to @@ -22,10 +39,43 @@ pub enum Job { QueryVram { reply: oneshot::Sender>, }, - /// Tell the worker to break its dispatch loop and exit. The - /// channel is then drained — any further jobs already queued get - /// dropped (their oneshot senders are dropped, causing the async - /// caller's receiver to return `Err` which `DeviceWorkerHandle` - /// maps to `WorkerError::Gone`). + /// Move a freshly-loaded `ModelArch` into the worker's state slab. + /// Returns an `ArchHandle` the caller stores on `LoadedModel` and + /// passes back in subsequent `ClearKv` / `ForwardLogits` / + /// `DropArch` jobs. + TransferIn { + arch: Box, + reply: oneshot::Sender>, + }, + /// Remove the model from the slab and drop it. The `Drop` runs on + /// the worker thread so CUDA tensors release their memory on the + /// same context that allocated them. + DropArch { + handle: ArchHandle, + reply: oneshot::Sender<()>, + }, + /// Reset the KV cache for this model. Called at the start of every + /// chat completion so a new request doesn't attend over the + /// previous one's tokens. + ClearKv { + handle: ArchHandle, + reply: oneshot::Sender>, + }, + /// Run one forward step and copy the resulting `[vocab]` logits to + /// CPU. The caller takes the returned `Vec`, wraps it in a + /// CPU `Tensor`, and runs `apply_repeat_penalty` + sampling + /// without touching the device context. `offset` is the KV-cache + /// position before this step (0 for prefill, `prompt_len + i` for + /// the i-th decode step). + ForwardLogits { + handle: ArchHandle, + tokens: Vec, + offset: usize, + reply: oneshot::Sender>>, + }, + /// Tell the worker to break its dispatch loop and exit. Any jobs + /// queued after this in the channel reply `Err` to their oneshot + /// senders (the senders are dropped on the worker's exit, which + /// the async-side `Receiver::await` maps to `WorkerError::Gone`). Shutdown, } diff --git a/crates/neuron/src/harness/device_worker/mod.rs b/crates/neuron/src/harness/device_worker/mod.rs index 12af5f6..497eeb3 100644 --- a/crates/neuron/src/harness/device_worker/mod.rs +++ b/crates/neuron/src/harness/device_worker/mod.rs @@ -49,7 +49,7 @@ use std::sync::mpsc::{self, Sender}; use std::thread::JoinHandle; use tokio::sync::oneshot; -pub use jobs::Job; +pub use jobs::{ArchHandle, Job}; /// Errors returned by `DeviceWorkerHandle` submit methods. #[derive(Debug, thiserror::Error)] @@ -159,6 +159,124 @@ impl DeviceWorkerHandle { } } + /// Move a freshly-loaded `ModelArch` into the worker's state slab. + /// Returns the `ArchHandle` the caller stores on `LoadedModel`. + /// The `Box` crosses the channel; the worker thread + /// owns it from here on. + pub async fn transfer_in( + &self, + arch: Box, + ) -> Result { + if self.poisoned.load(Ordering::Acquire) { + return Err(WorkerError::Poisoned { + device_index: self.device_index, + }); + } + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::TransferIn { + arch, + reply: reply_tx, + }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + match reply_rx.await { + Ok(result) => result.map_err(WorkerError::from), + Err(_) => Err(WorkerError::Gone { + device_index: self.device_index, + }), + } + } + + /// Tell the worker to drop the `ModelArch` for `handle` on the + /// worker thread (so CUDA tensors release on the right context). + /// Returns `Ok(())` even if the handle wasn't in the slab — Drop + /// is idempotent. Reports `Gone` if the worker isn't running. + pub async fn drop_arch(&self, handle: ArchHandle) -> Result<(), WorkerError> { + // Poisoning doesn't block DropArch — even on a poisoned + // context we want callers to unblock and proceed with the + // unload bookkeeping. The dispatch handler under poison just + // replies `()` without touching the model (the actual Drop + // happens via mem::forget at thread exit per the poison + // protocol). + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::DropArch { + handle, + reply: reply_tx, + }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + match reply_rx.await { + Ok(()) => Ok(()), + Err(_) => Err(WorkerError::Gone { + device_index: self.device_index, + }), + } + } + + /// Reset the KV cache for the model at `handle`. Called at the + /// start of every chat completion so the new prompt doesn't + /// attend over the previous request's tokens. + pub async fn clear_kv_cache(&self, handle: ArchHandle) -> Result<(), WorkerError> { + if self.poisoned.load(Ordering::Acquire) { + return Err(WorkerError::Poisoned { + device_index: self.device_index, + }); + } + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::ClearKv { + handle, + reply: reply_tx, + }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + match reply_rx.await { + Ok(result) => result.map_err(WorkerError::from), + Err(_) => Err(WorkerError::Gone { + device_index: self.device_index, + }), + } + } + + /// Run one forward step and return the resulting `[vocab]` logits + /// as a CPU-side `Vec`. The caller then samples on a CPU + /// candle Tensor without ever binding the device context on its + /// tokio thread. + pub async fn forward_logits( + &self, + handle: ArchHandle, + tokens: Vec, + offset: usize, + ) -> Result, WorkerError> { + if self.poisoned.load(Ordering::Acquire) { + return Err(WorkerError::Poisoned { + device_index: self.device_index, + }); + } + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::ForwardLogits { + handle, + tokens, + offset, + reply: reply_tx, + }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + match reply_rx.await { + Ok(result) => result.map_err(WorkerError::from), + Err(_) => Err(WorkerError::Gone { + device_index: self.device_index, + }), + } + } + /// Send `Job::Shutdown` and join the thread. Idempotent — calling /// twice is a no-op the second time. pub fn shutdown(&self) -> anyhow::Result<()> {