refactor(neuron): phase 2 — single-GPU forward + clear_kv route through device worker

Second slice of the per-device CUDA context-ownership refactor planned at
~/.claude/plans/plan-the-per-device-worker-abstract-micali.md. The two
spawn_blocking sites in `chat_completion` and `chat_completion_stream`
now route through the device worker thread on CUDA loads. CPU loads
keep the existing spawn_blocking + `Arc<Mutex<ModelArch>>` path; there's
no context to own and the channel hop would only add latency.
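
The routing rule described above can be sketched in isolation. This is a minimal illustration, not the daemon's code: `ModelArch`, `DeviceWorkerHandle` and `ArchHandle` are empty stand-ins, and the `Route` enum and `route` function are hypothetical names introduced only for this example.

```rust
use std::sync::{Arc, Mutex};

/// Stand-in for the real `ModelArch` (hypothetical, illustration only).
struct ModelArch;

/// Stand-in for the handle to the per-device worker thread.
struct DeviceWorkerHandle;

/// Opaque index into the worker's slab; assumed here to be a small Copy type.
#[derive(Clone, Copy, Debug, PartialEq)]
struct ArchHandle(u64);

/// Only the three fields that drive routing are reproduced.
struct LoadedModel {
    arch: Option<Arc<Mutex<ModelArch>>>,
    worker: Option<Arc<DeviceWorkerHandle>>,
    arch_handle: Option<ArchHandle>,
}

/// Hypothetical summary of the three arms taken by the inference entry points.
#[derive(Debug, PartialEq)]
enum Route {
    /// CUDA load: forward steps are sent to the device worker.
    Worker(ArchHandle),
    /// CPU load: spawn_blocking against the local mutex.
    LocalCpu,
    /// Neither is set: a load-path bug.
    Invalid,
}

fn route(m: &LoadedModel) -> Route {
    match (m.worker.as_ref(), m.arch_handle, m.arch.as_ref()) {
        // The worker path is checked first and needs both the worker and the handle.
        (Some(_), Some(h), _) => Route::Worker(h),
        // Otherwise fall back to the local architecture if there is one.
        (_, _, Some(_)) => Route::LocalCpu,
        _ => Route::Invalid,
    }
}
```

Checking the worker pair first mirrors the order of the `if let` chain in the diff, so a model that somehow carried both would still take the worker path.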

What this phase changes:

- `Job` gains `TransferIn`, `DropArch`, `ClearKv`, `ForwardLogits`. The
  worker's dispatch state grows a `HashMap<ArchHandle, Box<ModelArch>>`
  slab and a `next_handle` counter for minting opaque handles.
- `LoadedModel.arch` changes from `Arc<Mutex<ModelArch>>` to
  `Option<Arc<Mutex<ModelArch>>>`, and a new
  `arch_handle: Option<ArchHandle>` field is added. The two are
  mutually exclusive: CUDA loads set `arch_handle = Some(_)` after
  transferring the boxed arch into the worker's slab; CPU loads keep
  `arch = Some(_)` for the legacy spawn_blocking path.
- New `run_inference_via_worker` and `stream_inference_via_worker`
  drive the prefill + decode loop by sending `Job::ForwardLogits` per
  step; the worker copies the resulting `[vocab]` logits to a
  CPU-side `Vec<f32>` before reply, so the async caller never holds a
  device-resident tensor. `apply_repeat_penalty` and
  `LogitsProcessor::sample` run on a CPU candle tensor, so tokio
  worker threads never bind a CUDA context as a side effect.
- `logits_health_slice(&[f32])` complements the existing
  `logits_health(&Tensor)` so the new worker paths can compute
  health stats directly from the CPU vec.
- `unload_model` for the single-GPU CUDA path now sends
  `Job::DropArch { handle }` to the worker so the `Box<ModelArch>`
  drops on the thread that allocated its CUDA tensors. The `Drop` runs
  with the bound context, freeing memory on the right context.
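
As a rough illustration of the slab-plus-jobs shape listed above, here is a std-only sketch. It rests on several assumptions: `ToyArch` is a toy stand-in for `ModelArch`, the `reply` fields and their types are hypothetical, and `std::sync::mpsc` replaces whatever channel types the real worker uses. It shows only the ownership pattern, where the model is built elsewhere, moved into a map owned by one thread, and touched and dropped only there.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Sender};
use std::thread;

/// Toy stand-in for `ModelArch`: tracks how many tokens are "cached".
struct ToyArch {
    vocab: usize,
    cached: usize,
}

impl ToyArch {
    /// Returns one fake logit per vocab entry as a plain host-side Vec.
    fn forward(&mut self, tokens: &[u32], offset: usize) -> Vec<f32> {
        self.cached = offset + tokens.len();
        let hot = self.cached % self.vocab;
        (0..self.vocab).map(|i| if i == hot { 1.0 } else { 0.0 }).collect()
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct ArchHandle(u64);

/// Job variants named after the commit message; payload fields are assumed.
enum Job {
    TransferIn { arch: Box<ToyArch>, reply: Sender<ArchHandle> },
    DropArch { handle: ArchHandle, reply: Sender<bool> },
    ClearKv { handle: ArchHandle, reply: Sender<bool> },
    ForwardLogits {
        handle: ArchHandle,
        tokens: Vec<u32>,
        offset: usize,
        reply: Sender<Option<Vec<f32>>>,
    },
}

/// Spawns the worker; the slab and handle counter live only on that thread.
fn spawn_worker() -> Sender<Job> {
    let (tx, rx) = channel::<Job>();
    thread::spawn(move || {
        let mut slab: HashMap<ArchHandle, Box<ToyArch>> = HashMap::new();
        let mut next_handle = 0u64;
        for job in rx {
            match job {
                Job::TransferIn { arch, reply } => {
                    let handle = ArchHandle(next_handle);
                    next_handle += 1;
                    slab.insert(handle, arch);
                    let _ = reply.send(handle);
                }
                Job::DropArch { handle, reply } => {
                    // The removed Box is dropped here, on the worker thread.
                    let _ = reply.send(slab.remove(&handle).is_some());
                }
                Job::ClearKv { handle, reply } => {
                    let ok = slab.get_mut(&handle).map(|a| a.cached = 0).is_some();
                    let _ = reply.send(ok);
                }
                Job::ForwardLogits { handle, tokens, offset, reply } => {
                    let out = slab.get_mut(&handle).map(|a| a.forward(&tokens, offset));
                    let _ = reply.send(out);
                }
            }
        }
    });
    tx
}
```

A caller sends `Job::TransferIn` once, keeps the returned handle, and then issues one `Job::ForwardLogits` per step; after `Job::DropArch` the handle no longer resolves and a forward request gets `None` back.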

What this phase doesn't touch (yet):

- TP forward, TP load, NCCL bring-up — still on spawn_blocking. Phase 3.
- Single-GPU model load — still spawn_blocking, followed by a
  `Job::TransferIn` to move the freshly-built `ModelArch` into the
  worker slab. Phase 4 moves the load itself onto the worker thread
  and eliminates the bootstrap TransferIn.
- The `device_vram_mb` / `cuda_mem_mb` helpers — still present and
  used by the construction-time logs running inside spawn_blocking
  loads. Phase 4 cleanup folds them into `dispatch.rs`.

Public API unchanged. fmt + clippy clean; 37 lib tests + all
integration tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-27 09:55:08 +03:00
parent 081b532387
commit b179204fd3
4 changed files with 830 additions and 138 deletions


@@ -102,7 +102,14 @@ impl LoadedHandle {
/// moved into `spawn_blocking` for synchronous candle forward passes.
pub struct LoadedModel {
pub model_id: String,
pub arch: Arc<Mutex<ModelArch>>,
/// Local (async-side) handle to the model architecture. `Some`
/// only when the model loaded onto the CPU device (no CUDA
/// available); the inference path then takes this mutex via
/// `spawn_blocking` and runs candle ops on the CPU backend.
/// `None` when the model loaded onto a CUDA device — in that case
/// the architecture lives in the worker thread's slab and is
/// addressed via [`Self::arch_handle`].
pub arch: Option<Arc<Mutex<ModelArch>>>,
pub tokenizer: Tokenizer,
pub device: Device,
pub quant: Option<String>,
@@ -118,10 +125,15 @@ pub struct LoadedModel {
pub poisoned: AtomicBool,
/// Handle to the per-device CUDA worker thread for this model's
/// device. `None` for CPU loads (no context to own). VRAM queries
/// and — in later refactor phases — forward / kv-cache / unload
/// ops route through this handle so the device's CUDA context
/// stays bound to one OS thread for the daemon's lifetime.
/// and — for CUDA loads — forward / kv-cache / drop ops route
/// through this handle so the device's CUDA context stays bound
/// to one OS thread for the daemon's lifetime.
pub worker: Option<Arc<super::device_worker::DeviceWorkerHandle>>,
/// Index into the worker's `ModelArch` slab. `Some` iff the model
/// loaded onto a CUDA device and was successfully transferred to
/// the worker; in that case [`Self::arch`] is `None`. The two
/// fields are mutually exclusive.
pub arch_handle: Option<super::device_worker::ArchHandle>,
}
impl LoadedModel {
@@ -475,6 +487,16 @@ fn logits_health(t: &Tensor) -> LogitsHealth {
};
}
};
logits_health_slice(&values)
}
/// Same diagnostic as [`logits_health`] but operates directly on a
/// `[f32]` slice. Used by the worker-routed inference paths where the
/// device → host copy has already happened on the worker thread and
/// the async caller has the values in hand. Avoids the round-trip of
/// rebuilding a Tensor just to call to_vec1 again.
#[allow(dead_code)]
fn logits_health_slice(values: &[f32]) -> LogitsHealth {
let mut nan = 0usize;
let mut pos_inf = 0usize;
let mut neg_inf = 0usize;
@@ -483,7 +505,7 @@ fn logits_health(t: &Tensor) -> LogitsHealth {
let mut finite_max = f32::NEG_INFINITY;
let mut finite_sum = 0.0_f64;
let mut finite_count = 0usize;
for &v in &values {
for &v in values {
if v.is_nan() {
nan += 1;
} else if v == f32::INFINITY {
@@ -1104,15 +1126,22 @@ impl CandleHarness {
"chat_completion: starting"
);
let arch_arc = Arc::clone(&loaded.arch);
let device = loaded.device.clone();
let inference_result =
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
let mut guard = arch_arc.blocking_lock();
run_inference(
&mut guard,
&device,
// Routing: CUDA loads go through the per-device worker
// thread (introduced in Phase 1; forward/clear added in
// Phase 2). CPU loads keep the existing spawn_blocking
// path because there's no context to own and the channel
// round-trip would only add latency. The two arms produce
// the same `(Vec<u32>, String)` shape so the rest of the
// path is shared.
let (generated_ids, finish_reason) = if let (Some(worker), Some(handle)) =
(loaded.worker.as_ref(), loaded.arch_handle)
{
// Worker path (CUDA).
#[cfg(feature = "cuda")]
{
match run_inference_via_worker(
worker,
handle,
&prompt_tokens,
max_new,
temperature,
@@ -1120,26 +1149,67 @@ impl CandleHarness {
seed,
eos_id,
)
})
.await;
.await
{
Ok(v) => v,
Err(e) => {
loaded.poisoned.store(true, Ordering::Release);
return Err(InferenceError::Other(e));
}
}
}
#[cfg(not(feature = "cuda"))]
{
// Can't happen: `loaded.worker` is only Some on
// CUDA builds. The dead branch keeps the no-cuda
// build well-typed.
let _ = (worker, handle);
unreachable!("worker handle present without cuda feature");
}
} else if let Some(arch_arc) = loaded.arch.clone() {
// CPU path: existing spawn_blocking on the local
// Arc<Mutex<ModelArch>>.
let device = loaded.device.clone();
let inference_result =
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
let mut guard = arch_arc.blocking_lock();
run_inference(
&mut guard,
&device,
&prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
)
})
.await;
// Any failure inside the spawn_blocking touched CUDA via
// candle's forward / cache code, so we treat it as a
// device-poisoning event. The terminal log at the bottom
// of the wrapper reports the error; this flag stops the
// NEXT request from going down the same path.
let (generated_ids, finish_reason) = match inference_result {
Ok(Ok(v)) => v,
Ok(Err(e)) => {
loaded.poisoned.store(true, Ordering::Release);
return Err(InferenceError::Other(e));
}
Err(e) => {
loaded.poisoned.store(true, Ordering::Release);
return Err(InferenceError::Other(anyhow::anyhow!(
"inference task panicked: {e}"
)));
// Any failure inside the spawn_blocking touched CUDA via
// candle's forward / cache code, so we treat it as a
// device-poisoning event. The terminal log at the bottom
// of the wrapper reports the error; this flag stops the
// NEXT request from going down the same path.
match inference_result {
Ok(Ok(v)) => v,
Ok(Err(e)) => {
loaded.poisoned.store(true, Ordering::Release);
return Err(InferenceError::Other(e));
}
Err(e) => {
loaded.poisoned.store(true, Ordering::Release);
return Err(InferenceError::Other(anyhow::anyhow!(
"inference task panicked: {e}"
)));
}
}
} else {
// LoadedModel invariant: exactly one of `arch_handle`
// (together with `worker`) / `arch` is Some. Reaching here
// is a construction bug.
return Err(InferenceError::Other(anyhow::anyhow!(
"LoadedModel has neither worker handle nor local arch — load-path bug"
)));
};
let completion_text = loaded
@@ -1244,7 +1314,6 @@ impl CandleHarness {
.token_to_id("<|im_end|>")
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
let arch_arc = Arc::clone(&loaded.arch);
let device = loaded.device.clone();
let tokenizer = loaded.tokenizer.clone();
let model_id = request.model.clone();
@@ -1315,40 +1384,96 @@ impl CandleHarness {
"chat_completion (stream): starting"
);
}
tokio::task::spawn_blocking(move || {
let _g = span_for_task.enter();
let mut guard = arch_arc.blocking_lock();
match run_inference_streaming(
&mut guard,
&device,
&tokenizer,
&prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
&id,
created,
&model_id,
&tx,
) {
Ok(()) => tracing::info!(
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): done"
),
Err(e) => {
loaded_for_task.poisoned.store(true, Ordering::Release);
tracing::error!(
error = %format!("{e:#}"),
// Routing parallel to the non-streaming chat_completion: CUDA
// goes through the worker (async task), CPU keeps the
// spawn_blocking + Arc<Mutex<ModelArch>> path.
if let (Some(worker), Some(handle)) = (loaded.worker.clone(), loaded.arch_handle) {
#[cfg(feature = "cuda")]
{
let prompt_tokens = prompt_tokens.clone();
tokio::spawn(
async move {
match stream_inference_via_worker(
worker,
handle,
tokenizer,
prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
id,
created,
model_id,
tx,
)
.await
{
Ok(_finish_reason) => tracing::info!(
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): done"
),
Err(e) => {
loaded_for_task.poisoned.store(true, Ordering::Release);
tracing::error!(
error = %format!("{e:#}"),
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): failed, model marked poisoned"
);
}
}
}
.instrument(span_for_task),
);
}
#[cfg(not(feature = "cuda"))]
{
let _ = (worker, handle, span_for_task);
unreachable!("worker handle present without cuda feature");
}
} else if let Some(arch_arc) = loaded.arch.clone() {
tokio::task::spawn_blocking(move || {
let _g = span_for_task.enter();
let mut guard = arch_arc.blocking_lock();
match run_inference_streaming(
&mut guard,
&device,
&tokenizer,
&prompt_tokens,
max_new,
temperature,
top_p,
seed,
eos_id,
&id,
created,
&model_id,
&tx,
) {
Ok(()) => tracing::info!(
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): failed, model marked poisoned"
);
"chat_completion (stream): done"
),
Err(e) => {
loaded_for_task.poisoned.store(true, Ordering::Release);
tracing::error!(
error = %format!("{e:#}"),
prompt_tokens = prompt_len,
total_ms = req_start.elapsed().as_millis(),
"chat_completion (stream): failed, model marked poisoned"
);
}
}
}
});
});
} else {
return Err(InferenceError::Other(anyhow::anyhow!(
"LoadedModel has neither worker handle nor local arch — load-path bug"
)));
}
Ok(rx)
}
@@ -1432,22 +1557,37 @@ impl Harness for CandleHarness {
// Worker thread for the chosen device. CPU loads (CUDA
// unavailable / not requested) skip the worker — there's no
// context to own.
let worker = match &device {
// context to own. For CUDA loads, the arch is transferred
// into the worker's slab now so the inference path can
// reference it via the returned `ArchHandle`. The explicit
// type annotation lets the no-cuda build resolve `None` to
// the right `Option<Arc<DeviceWorkerHandle>>` type.
let worker: Option<Arc<super::device_worker::DeviceWorkerHandle>> = match &device {
#[cfg(feature = "cuda")]
Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?),
_ => None,
};
let (arch_local, arch_handle) = match &worker {
Some(w) => {
let handle = w
.transfer_in(Box::new(arch))
.await
.map_err(|e| anyhow::anyhow!("transfer arch into device worker: {e}"))?;
(None, Some(handle))
}
None => (Some(Arc::new(Mutex::new(arch))), None),
};
let loaded = Arc::new(LoadedModel {
model_id: spec.model_id.clone(),
arch: Arc::new(Mutex::new(arch)),
arch: arch_local,
tokenizer,
device,
quant: spec.quant.clone(),
devices,
poisoned: AtomicBool::new(false),
worker,
arch_handle,
});
let mut models = self.models.write().await;
@@ -1465,13 +1605,26 @@ impl Harness for CandleHarness {
anyhow::bail!("model '{model_id}' not loaded");
};
// Single-GPU drops are immediate — the LoadedModel goes out of
// scope with the Arc and candle frees VRAM. TP unloads also
// need to tell every worker to drop its shard before the pool
// itself is dropped (otherwise the workers keep their shards
// around until Shutdown, which is wasteful and would surface
// as VRAM not freed promptly).
// scope with the Arc and candle frees VRAM. CUDA loads also
// ship a `Job::DropArch` to the device worker so the boxed
// `ModelArch` releases its CUDA allocations on the right
// thread (with the bound context); without that, the Drop
// would run on whatever tokio thread happens to be holding
// the last `Arc<LoadedModel>` clone when this fn returns.
// TP unloads further coordinate the subprocess pool below.
match handle {
LoadedHandle::Single(_) => {}
LoadedHandle::Single(single) => {
if let (Some(worker), Some(arch_handle)) =
(single.worker.as_ref(), single.arch_handle)
&& let Err(e) = worker.drop_arch(arch_handle).await
{
tracing::warn!(
model = %model_id,
error = %e,
"single-GPU unload: DropArch RPC failed (model state may leak in worker slab)"
);
}
}
#[cfg(feature = "cuda")]
LoadedHandle::Tp(tp) => {
// Try to recover the inner TpLoadedModel so we can move
@@ -2276,6 +2429,239 @@ fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
prompt
}
#[allow(clippy::too_many_arguments)]
/// Run the full single-GPU inference loop via the device worker.
///
/// Mirrors `run_inference`'s logic but routes each forward step
/// through `worker.forward_logits()` (returns CPU-side `Vec<f32>`)
/// and runs `apply_repeat_penalty` + sampling on a CPU candle tensor.
/// The device-resident logits tensor never escapes the worker thread.
///
/// Used by the CUDA path of `chat_completion`. The CPU path keeps
/// `run_inference` (spawn_blocking against `Arc<Mutex<ModelArch>>`)
/// because there's no CUDA context to own and the worker indirection
/// would only add channel overhead with no diagnostic benefit.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
async fn run_inference_via_worker(
worker: &super::device_worker::DeviceWorkerHandle,
handle: super::device_worker::ArchHandle,
prompt_tokens: &[u32],
max_new: usize,
temperature: f64,
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
) -> Result<(Vec<u32>, String)> {
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut generated: Vec<u32> = Vec::new();
let prompt_len = prompt_tokens.len();
worker
.clear_kv_cache(handle)
.await
.map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
// Prefill: run the whole prompt through the model at offset 0.
let logits_vec = worker
.forward_logits(handle, prompt_tokens.to_vec(), 0)
.await
.map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
?health,
"chat_completion (worker): prefill sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
}
generated.push(next_token);
for index in 0..max_new.saturating_sub(1) {
let logits_vec = worker
.forward_logits(handle, vec![next_token], prompt_len + index)
.await
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
step = index,
?health,
"chat_completion (worker): decode sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
return Ok((generated, "stop".into()));
}
generated.push(next_token);
}
Ok((generated, "length".into()))
}
/// Streaming counterpart of [`run_inference_via_worker`]. Emits one
/// `ChatCompletionChunk` per generated token via `tx`; routes every
/// forward step through `worker.forward_logits()`. Same per-step
/// CPU-side sampling discipline — no device tensor escapes the
/// worker thread.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
async fn stream_inference_via_worker(
worker: Arc<super::device_worker::DeviceWorkerHandle>,
handle: super::device_worker::ArchHandle,
tokenizer: Tokenizer,
prompt_tokens: Vec<u32>,
max_new: usize,
temperature: f64,
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
id: String,
created: u64,
model_id: String,
tx: mpsc::Sender<ChatCompletionChunk>,
) -> Result<String> {
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut all_tokens: Vec<u32> = Vec::new();
let mut decoded_prefix = String::new();
let prompt_len = prompt_tokens.len();
let mut finish_reason = "length".to_string();
worker
.clear_kv_cache(handle)
.await
.map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
let logits_vec = worker
.forward_logits(handle, prompt_tokens, 0)
.await
.map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
let mut next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
?health,
"chat_completion (stream/worker): prefill sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
} else {
all_tokens.push(next_token);
if !emit_chunk(
&all_tokens,
&mut decoded_prefix,
&tokenizer,
&tx,
&id,
created,
&model_id,
)
.await
{
return Ok(finish_reason); // Client gone — clean stream end.
}
for index in 0..max_new.saturating_sub(1) {
let logits_vec = worker
.forward_logits(handle, vec![next_token], prompt_len + index)
.await
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
step = index,
?health,
"chat_completion (stream/worker): decode sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
}
all_tokens.push(next_token);
if !emit_chunk(
&all_tokens,
&mut decoded_prefix,
&tokenizer,
&tx,
&id,
created,
&model_id,
)
.await
{
return Ok(finish_reason);
}
}
}
// Final chunk carrying finish_reason. Matches the run_inference_streaming
// shape so the SSE consumer sees an identical termination sequence.
let final_chunk = ChatCompletionChunk {
id: id.clone(),
object: "chat.completion.chunk".into(),
created,
model: model_id.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: serde_json::Value::Object(Default::default()),
finish_reason: Some(finish_reason.clone()),
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
};
let _ = tx.send(final_chunk).await;
Ok(finish_reason)
}
#[allow(clippy::too_many_arguments)]
fn run_inference(
arch: &mut ModelArch,