refactor(neuron): phase 1 — per-device worker thread, VRAM queries route through it
Some checks failed
CI / Format (push) Successful in 31s
build-prerelease / Resolve version stamps (push) Successful in 36s
CI / Clippy (push) Failing after 59s
build-prerelease / Build neuron-blackwell (push) Successful in 3m30s
CI / Test (push) Successful in 4m47s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m17s
build-prerelease / Package cortex RPM (push) Successful in 1m32s
build-prerelease / Build neuron-ampere (push) Successful in 5m16s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled

First slice of the per-device CUDA context-ownership refactor planned at
~/.claude/plans/plan-the-per-device-worker-abstract-micali.md. Adds the
infrastructure for a dedicated OS thread per CUDA device that owns the
device's `CudaContext` for the daemon's lifetime, and routes the 8
async-context `device_vram_mb()` call sites in candle.rs through it.

What this phase changes:

- New module `harness/device_worker/` (mod.rs, jobs.rs, dispatch.rs).
  `DeviceWorkerHandle::spawn(idx)` creates a named OS thread
  (`cuda-dev-N`), binds `CudaContext::new(idx)` once at startup, and
  enters a dispatch loop reading `Job`s off a `std::sync::mpsc` channel.
  Replies cross back via `tokio::sync::oneshot::Sender` so async callers
  await without parking a tokio worker.
- Two Job variants: `QueryVram` and `Shutdown`. Phases 2–4 add Forward,
  ClearKv, NCCL init/sanity, and load variants.
- `LoadedModel` and `TpLoadedModel` gain a `worker` field populated at
  load time by a new `CandleHarness::ensure_device_worker(idx)` method
  that lazily spawns + caches one worker per device index.
- Per-model `query_vram()` convenience method on both struct types so
  the 8 call sites in chat_completion / chat_completion_stream /
  chat_completion_tp_inner / chat_completion_tp_stream become
  `loaded.query_vram().await` (or `tp.query_vram().await`) — same field
  values logged, just sourced from the owner thread instead of the
  caller thread.

What this phase doesn't touch (yet):

- Forward, kv-cache clear, model load, NCCL — still on `spawn_blocking`.
  Phase 2 moves the single-GPU forward + clear; Phase 3 moves the TP
  forward + NCCL bring-up; Phase 4 moves the loads and deletes the now-
  unused `device_vram_mb` / `cuda_mem_mb` helpers.
- Public API — unchanged. `Harness::load_model`, `chat_completion`,
  HTTP routes all keep identical shapes.

Tests:

- 5 new unit tests in `device_worker/mod.rs::tests` cover spawn → query
  → shutdown round-trip, thread naming, post-shutdown submit returns
  `Gone`, poisoned flag fast-rejects, and concurrent jobs drain across
  a Shutdown. CPU build (the only one CI runs) is enough to exercise
  channel mechanics.
- All 37 lib tests + all integration tests pass; fmt + clippy clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-27 09:40:34 +03:00
parent 7c19da9361
commit 081b532387
5 changed files with 580 additions and 9 deletions

View File

@@ -42,6 +42,14 @@ pub struct CandleHarness {
models: Arc<RwLock<HashMap<String, LoadedHandle>>>,
hf_cache: Option<PathBuf>,
bind_url: String,
/// One worker thread per CUDA device index that owns its
/// `CudaContext` for the daemon's lifetime. Populated lazily by
/// `ensure_device_worker()` when a model is loaded onto a CUDA
/// device. CPU `Device::Cpu` loads don't get an entry; they have
/// no context to own. Unused on the no-cuda build (the harness
/// can still load on CPU for tests, just without worker threads).
#[allow(dead_code)]
device_workers: Arc<RwLock<HashMap<u32, Arc<super::device_worker::DeviceWorkerHandle>>>>,
}
/// One entry in the harness's loaded-model registry. Single-GPU loads
@@ -108,6 +116,28 @@ pub struct LoadedModel {
/// the 2026-05-26 beast incident where a 14k-token prefill OOM
/// silently turned every subsequent request into a stuck wait.
pub poisoned: AtomicBool,
/// Handle to the per-device CUDA worker thread for this model's
/// device. `None` for CPU loads (no context to own). VRAM queries
/// and — in later refactor phases — forward / kv-cache / unload
/// ops route through this handle so the device's CUDA context
/// stays bound to one OS thread for the daemon's lifetime.
pub worker: Option<Arc<super::device_worker::DeviceWorkerHandle>>,
}
impl LoadedModel {
/// Free / total VRAM on this model's device in MiB. Routes the
/// query through the device worker thread (where the CUDA context
/// is already bound) rather than rebinding on whatever tokio
/// thread the caller happens to be on. Returns `(0, 0)` on CPU
/// loads, or if the worker is gone / poisoned / the cudarc call
/// itself failed — same sentinel the previous `device_vram_mb`
/// helper returned, so log field values stay comparable.
pub async fn query_vram(&self) -> (u64, u64) {
match &self.worker {
Some(w) => w.query_vram().await.unwrap_or((0, 0)),
None => (0, 0),
}
}
}
/// Tensor-parallel loaded model. Holds the leader's rank-0 shard
@@ -137,6 +167,22 @@ pub struct TpLoadedModel {
/// terminal: the leader's and workers' CUDA contexts cannot be
/// reliably reset without restarting the worker subprocesses.
pub poisoned: AtomicBool,
/// Worker thread for the leader's CUDA device. Owns the leader's
/// `CudaContext` for the daemon's lifetime. VRAM queries route
/// through it; in later refactor phases the forward, kv-cache
/// clear, and shard unload route through it too.
pub worker: Arc<super::device_worker::DeviceWorkerHandle>,
}
#[cfg(feature = "cuda")]
impl TpLoadedModel {
/// Free / total VRAM on the leader's device in MiB. See
/// [`LoadedModel::query_vram`] for rationale and sentinel
/// semantics — same pattern, TP just always has a worker because
/// the harness rejects TP without CUDA at load time.
pub async fn query_vram(&self) -> (u64, u64) {
self.worker.query_vram().await.unwrap_or((0, 0))
}
}
/// Architecture-specific weights. Each variant covers one (family,
@@ -604,6 +650,7 @@ impl CandleHarness {
models: Arc::new(RwLock::new(HashMap::new())),
hf_cache,
bind_url,
device_workers: Arc::new(RwLock::new(HashMap::new())),
}
}
@@ -625,6 +672,39 @@ impl CandleHarness {
Ok(Device::Cpu)
}
/// Return the worker handle for `device_index`, spawning it on
/// first request. The handle is cached on `self` so subsequent
/// loads against the same device share the same thread. Used to
/// populate `LoadedModel::worker` and `TpLoadedModel::worker` at
/// load time; in later refactor phases the worker also owns the
/// `ModelArch` and `TpLeaderModel` slabs.
#[allow(dead_code)]
async fn ensure_device_worker(
&self,
device_index: u32,
) -> Result<Arc<super::device_worker::DeviceWorkerHandle>> {
{
let workers = self.device_workers.read().await;
if let Some(w) = workers.get(&device_index) {
return Ok(Arc::clone(w));
}
}
// Write-lock acquired separately so the read path stays cheap.
// The `get` is repeated under the write lock to handle the
// race where two loads against a fresh device land here at
// once — the second caller sees the first's insertion and
// skips the second spawn.
let mut workers = self.device_workers.write().await;
if let Some(w) = workers.get(&device_index) {
return Ok(Arc::clone(w));
}
let handle = super::device_worker::DeviceWorkerHandle::spawn(device_index)
.with_context(|| format!("spawn device worker for cuda:{device_index}"))?;
workers.insert(device_index, Arc::clone(&handle));
tracing::info!(device_index, "spawned device worker");
Ok(handle)
}
/// Build an hf-hub API client pre-configured with the harness's
/// `hf_cache` (when one is set).
fn hf_api(&self) -> Result<hf_hub::api::tokio::Api> {
@@ -1012,7 +1092,7 @@ impl CandleHarness {
.token_to_id("<|im_end|>")
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
let (vram_free_mb, vram_total_mb) = device_vram_mb(&loaded.device);
let (vram_free_mb, vram_total_mb) = loaded.query_vram().await;
tracing::info!(
prompt_len,
max_new,
@@ -1217,9 +1297,13 @@ impl CandleHarness {
let loaded_for_task = Arc::clone(&loaded);
let span_for_starting = span.clone();
let span_for_task = span.clone();
// Query VRAM before entering the span so we don't await inside
// an entered guard (Span::enter creates a synchronous guard
// that can't span await points). The span gets entered in a
// separate scope below purely for the log emission.
let (vram_free_mb, vram_total_mb) = loaded.query_vram().await;
{
let _g = span_for_starting.enter();
let (vram_free_mb, vram_total_mb) = device_vram_mb(&loaded.device);
tracing::info!(
prompt_len,
max_new,
@@ -1346,6 +1430,15 @@ impl Harness for CandleHarness {
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
// Worker thread for the chosen device. CPU loads (CUDA
// unavailable / not requested) skip the worker — there's no
// context to own.
let worker = match &device {
#[cfg(feature = "cuda")]
Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?),
_ => None,
};
let loaded = Arc::new(LoadedModel {
model_id: spec.model_id.clone(),
arch: Arc::new(Mutex::new(arch)),
@@ -1354,6 +1447,7 @@ impl Harness for CandleHarness {
quant: spec.quant.clone(),
devices,
poisoned: AtomicBool::new(false),
worker,
});
let mut models = self.models.write().await;
@@ -1496,6 +1590,12 @@ impl CandleHarness {
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
// 7. Worker thread for the leader's CUDA device. TP always
// runs on CUDA — the harness rejects TP without the cuda
// feature earlier in this function — so we always have a
// device to own.
let worker = self.ensure_device_worker(devices[0]).await?;
let tp_loaded = StdArc::new(TpLoadedModel {
model_id: spec.model_id.clone(),
tokenizer,
@@ -1504,6 +1604,7 @@ impl CandleHarness {
leader_model,
leader_device: leader_device.clone(),
poisoned: AtomicBool::new(false),
worker,
});
let mut models = self.models.write().await;
@@ -1661,7 +1762,7 @@ impl CandleHarness {
model = %model_id
);
let req_start = std::time::Instant::now();
let (vram_free_mb, vram_total_mb) = device_vram_mb(&tp.leader_device);
let (vram_free_mb, vram_total_mb) = tp.query_vram().await;
tracing::info!(
parent: &span,
prompt_len,
@@ -1713,8 +1814,7 @@ impl CandleHarness {
break 'work;
}
};
let (post_prefill_vram_free_mb, _) =
device_vram_mb(&tp_for_task.leader_device);
let (post_prefill_vram_free_mb, _) = tp_for_task.query_vram().await;
tracing::info!(
model = %model_id,
prompt_len,
@@ -1790,11 +1890,19 @@ impl CandleHarness {
break 'work;
}
};
// Always await the query (even when the
// trace! is filtered out by RUST_LOG): the
// channel hop is ~tens of µs, comparable to
// the previous in-line bind+query cost, and
// making the call conditional adds complexity
// for negligible win. Revisit if it shows up
// in a hot-path profile.
let step_vram_free_mb = tp_for_task.query_vram().await.0;
tracing::trace!(
model = %model_id,
step = index,
next_token,
vram_free_mb = device_vram_mb(&tp_for_task.leader_device).0,
vram_free_mb = step_vram_free_mb,
"TP chat_completion (stream): decode step"
);
if Some(next_token) == eos_id {
@@ -1906,7 +2014,7 @@ async fn chat_completion_tp_inner(
.token_to_id("<|im_end|>")
.or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
let (vram_free_mb, vram_total_mb) = device_vram_mb(&tp.leader_device);
let (vram_free_mb, vram_total_mb) = tp.query_vram().await;
tracing::info!(
model = %model_id,
prompt_len,
@@ -1962,7 +2070,7 @@ async fn chat_completion_tp_inner(
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
.await
.map_err(InferenceError::Other)?;
let (post_prefill_vram_free_mb, _) = device_vram_mb(&tp.leader_device);
let (post_prefill_vram_free_mb, _) = tp.query_vram().await;
tracing::info!(
model = %model_id,
prompt_len,
@@ -2017,7 +2125,7 @@ async fn chat_completion_tp_inner(
return Err(InferenceError::Other(e));
}
};
let step_vram_free_mb = device_vram_mb(&tp.leader_device).0;
let step_vram_free_mb = tp.query_vram().await.0;
tracing::trace!(
model = %model_id,
step = index,