refactor(neuron): phase 2 — single-GPU forward + clear_kv route through device worker
Some checks failed
build-prerelease / Package helexa-neuron-ada RPM (push) Blocked by required conditions
CI / Format (push) Successful in 34s
CI / Clippy (push) Successful in 2m12s
build-prerelease / Resolve version stamps (push) Successful in 3m41s
CI / Test (push) Successful in 5m1s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m32s
build-prerelease / Build neuron-ampere (push) Successful in 5m20s
build-prerelease / Build cortex binary (push) Successful in 12m20s
build-prerelease / Build neuron-ada (push) Successful in 5m17s
build-prerelease / Package cortex RPM (push) Successful in 1m25s
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
Second slice of the per-device CUDA context-ownership refactor planned at
~/.claude/plans/plan-the-per-device-worker-abstract-micali.md. The two
spawn_blocking sites in `chat_completion` and `chat_completion_stream`
now route through the device worker thread on CUDA loads. CPU loads
keep the existing spawn_blocking + `Arc<Mutex<ModelArch>>` path; there's
no context to own and the channel hop would only add latency.
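The routing rule above can be sketched as a small decision function. This is an illustration only: `Route` and `route` are hypothetical names, and the three structs are empty stand-ins for the real `ModelArch`, `ArchHandle`, and `DeviceWorkerHandle` types.

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical stand-in for the real `ModelArch`.
struct ModelArch;

/// Hypothetical stand-in for the opaque slab handle.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct ArchHandle(u64);

/// Hypothetical stand-in for the per-device worker handle.
struct DeviceWorkerHandle;

/// Only the fields that matter for routing.
struct LoadedModel {
    arch: Option<Arc<Mutex<ModelArch>>>,
    worker: Option<Arc<DeviceWorkerHandle>>,
    arch_handle: Option<ArchHandle>,
}

/// Hypothetical enum naming the three outcomes of the routing check.
#[derive(Debug, PartialEq, Eq)]
enum Route {
    /// CUDA load: send jobs to the device worker thread.
    Worker(ArchHandle),
    /// CPU load: lock the local arch inside spawn_blocking.
    LocalCpu,
    /// Neither is set: a load-path bug.
    Invalid,
}

fn route(model: &LoadedModel) -> Route {
    match (model.worker.as_ref(), model.arch_handle, model.arch.as_ref()) {
        // Worker and handle both present: CUDA path.
        (Some(_), Some(handle), _) => Route::Worker(handle),
        // Otherwise fall back to the local arch if there is one.
        (_, _, Some(_)) => Route::LocalCpu,
        _ => Route::Invalid,
    }
}
```

The worker arm is checked first, mirroring the order of the `if let` / `else if let` chain in the diff.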
What this phase changes:
- `Job` gains `TransferIn`, `DropArch`, `ClearKv`, `ForwardLogits`. The
worker's dispatch state grows a `HashMap<ArchHandle, Box<ModelArch>>`
slab and a `next_handle` counter for minting opaque handles.
- `LoadedModel.arch: Arc<Mutex<ModelArch>>` → `Option<Arc<Mutex<ModelArch>>>`,
plus a new `arch_handle: Option<ArchHandle>` field. The two are
mutually exclusive: CUDA loads set `arch_handle = Some(_)` after
transferring the boxed arch into the worker's slab; CPU loads keep
`arch = Some(_)` for the legacy spawn_blocking path.
- New `run_inference_via_worker` and `stream_inference_via_worker`
drive the prefill + decode loop by sending `Job::ForwardLogits` per
step; the worker copies the resulting `[vocab]` logits to a
CPU-side `Vec<f32>` before reply, so the async caller never holds a
device-resident tensor. `apply_repeat_penalty` and
`LogitsProcessor::sample` run on a CPU candle tensor, so sampling has
no context-binding side effects on tokio worker threads.
- `logits_health_slice(&[f32])` complements the existing
`logits_health(&Tensor)` so the new worker paths can compute
health stats directly from the CPU vec.
- `unload_model` for the single-GPU CUDA path now sends
`Job::DropArch { handle }` to the worker so the `Box<ModelArch>`
drops on the thread that allocated its CUDA tensors. The `Drop` then
runs with that thread's context bound, so the memory is released
against the context that allocated it.
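The slab-plus-handle scheme in the list above can be sketched with the standard library alone. This is a minimal sketch under stated assumptions, not the worker in this commit: `Model` and `spawn_worker` are hypothetical, replies travel over `std::sync::mpsc` channels rather than the oneshot senders the real worker uses, `ClearKv` is left out for brevity, and the "forward" step is a toy computation.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{self, Sender};
use std::thread;

/// Hypothetical stand-in for `ModelArch`.
struct Model {
    weights: Vec<f32>,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct ArchHandle(u64);

/// Job names follow the commit message; payloads are assumptions.
enum Job {
    TransferIn { model: Box<Model>, reply: Sender<ArchHandle> },
    ForwardLogits { handle: ArchHandle, token: u32, reply: Sender<Option<Vec<f32>>> },
    DropArch { handle: ArchHandle, reply: Sender<bool> },
    Shutdown,
}

/// Spawn one worker thread that owns every model moved into it.
fn spawn_worker() -> (Sender<Job>, thread::JoinHandle<()>) {
    let (tx, rx) = mpsc::channel::<Job>();
    let join = thread::spawn(move || {
        let mut models: HashMap<ArchHandle, Box<Model>> = HashMap::new();
        let mut next_handle: u64 = 0;
        for job in rx {
            match job {
                Job::TransferIn { model, reply } => {
                    // Increment first, then hand out the new value.
                    next_handle += 1;
                    let handle = ArchHandle(next_handle);
                    models.insert(handle, model);
                    let _ = reply.send(handle);
                }
                Job::ForwardLogits { handle, token, reply } => {
                    // Toy forward pass; the reply is an owned host-side Vec,
                    // so nothing tied to this thread crosses the channel.
                    let logits = models.get_mut(&handle).map(|m| {
                        m.weights.iter().map(|w| w * token as f32).collect::<Vec<f32>>()
                    });
                    let _ = reply.send(logits);
                }
                Job::DropArch { handle, reply } => {
                    // The Box is dropped here, on the worker thread.
                    let removed = models.remove(&handle).is_some();
                    let _ = reply.send(removed);
                }
                Job::Shutdown => break,
            }
        }
    });
    (tx, join)
}
```

Because the caller only ever holds an `ArchHandle`, the model cannot be dropped anywhere except inside the worker's `DropArch` arm (or when the worker thread itself exits).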
What this phase doesn't touch (yet):
- TP forward, TP load, NCCL bring-up — still on spawn_blocking. Phase 3.
- Single-GPU model load — still spawn_blocking, followed by a
`Job::TransferIn` to move the freshly-built `ModelArch` into the
worker slab. Phase 4 moves the load itself onto the worker thread
and eliminates the bootstrap TransferIn.
- The `device_vram_mb` / `cuda_mem_mb` helpers — still present and
used by the construction-time logs running inside spawn_blocking
loads. Phase 4 cleanup folds them into `dispatch.rs`.
Public API unchanged. fmt + clippy clean; 37 lib tests + all
integration tests pass.
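For reference, the slice-based health check named in the change list can be sketched as follows. The `LogitsHealth` fields and the summary step are assumptions made for illustration; only the counting loop follows the context lines visible in the diff.

```rust
/// Hypothetical shape for the diagnostic; the real struct may differ.
#[derive(Debug, PartialEq)]
struct LogitsHealth {
    nan: usize,
    pos_inf: usize,
    neg_inf: usize,
    finite_count: usize,
    finite_min: Option<f32>,
    finite_max: Option<f32>,
    finite_mean: Option<f64>,
}

/// Classify every logit and summarize the finite ones.
fn logits_health_slice(values: &[f32]) -> LogitsHealth {
    let mut nan = 0usize;
    let mut pos_inf = 0usize;
    let mut neg_inf = 0usize;
    let mut finite_min = f32::INFINITY;
    let mut finite_max = f32::NEG_INFINITY;
    let mut finite_sum = 0.0_f64;
    let mut finite_count = 0usize;
    for &v in values {
        if v.is_nan() {
            nan += 1;
        } else if v == f32::INFINITY {
            pos_inf += 1;
        } else if v == f32::NEG_INFINITY {
            neg_inf += 1;
        } else {
            finite_min = finite_min.min(v);
            finite_max = finite_max.max(v);
            finite_sum += f64::from(v);
            finite_count += 1;
        }
    }
    // Min, max, and mean are only meaningful if something finite was seen.
    let has_finite = finite_count > 0;
    LogitsHealth {
        nan,
        pos_inf,
        neg_inf,
        finite_count,
        finite_min: has_finite.then_some(finite_min),
        finite_max: has_finite.then_some(finite_max),
        finite_mean: has_finite.then(|| finite_sum / finite_count as f64),
    }
}
```

Taking `&[f32]` means the worker-routed paths can run the check on the `Vec<f32>` they already hold, without building a tensor first.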
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -102,7 +102,14 @@ impl LoadedHandle {
 /// moved into `spawn_blocking` for synchronous candle forward passes.
 pub struct LoadedModel {
     pub model_id: String,
-    pub arch: Arc<Mutex<ModelArch>>,
+    /// Local (async-side) handle to the model architecture. `Some`
+    /// only when the model loaded onto the CPU device (no CUDA
+    /// available); the inference path then takes this mutex via
+    /// `spawn_blocking` and runs candle ops on the CPU backend.
+    /// `None` when the model loaded onto a CUDA device — in that case
+    /// the architecture lives in the worker thread's slab and is
+    /// addressed via [`Self::arch_handle`].
+    pub arch: Option<Arc<Mutex<ModelArch>>>,
     pub tokenizer: Tokenizer,
     pub device: Device,
     pub quant: Option<String>,
@@ -118,10 +125,15 @@ pub struct LoadedModel {
     pub poisoned: AtomicBool,
     /// Handle to the per-device CUDA worker thread for this model's
     /// device. `None` for CPU loads (no context to own). VRAM queries
-    /// and — in later refactor phases — forward / kv-cache / unload
-    /// ops route through this handle so the device's CUDA context
-    /// stays bound to one OS thread for the daemon's lifetime.
+    /// and — for CUDA loads — forward / kv-cache / drop ops route
+    /// through this handle so the device's CUDA context stays bound
+    /// to one OS thread for the daemon's lifetime.
     pub worker: Option<Arc<super::device_worker::DeviceWorkerHandle>>,
+    /// Index into the worker's `ModelArch` slab. `Some` iff the model
+    /// loaded onto a CUDA device and was successfully transferred to
+    /// the worker; in that case [`Self::arch`] is `None`. The two
+    /// fields are mutually exclusive.
+    pub arch_handle: Option<super::device_worker::ArchHandle>,
 }
 
 impl LoadedModel {
@@ -475,6 +487,16 @@ fn logits_health(t: &Tensor) -> LogitsHealth {
             };
         }
     };
+    logits_health_slice(&values)
+}
+
+/// Same diagnostic as [`logits_health`] but operates directly on a
+/// `[f32]` slice. Used by the worker-routed inference paths where the
+/// device → host copy has already happened on the worker thread and
+/// the async caller has the values in hand. Avoids the round-trip of
+/// rebuilding a Tensor just to call to_vec1 again.
+#[allow(dead_code)]
+fn logits_health_slice(values: &[f32]) -> LogitsHealth {
     let mut nan = 0usize;
     let mut pos_inf = 0usize;
     let mut neg_inf = 0usize;
@@ -483,7 +505,7 @@ fn logits_health(t: &Tensor) -> LogitsHealth {
     let mut finite_max = f32::NEG_INFINITY;
     let mut finite_sum = 0.0_f64;
     let mut finite_count = 0usize;
-    for &v in &values {
+    for &v in values {
         if v.is_nan() {
             nan += 1;
         } else if v == f32::INFINITY {
@@ -1104,15 +1126,22 @@ impl CandleHarness {
             "chat_completion: starting"
         );
 
-        let arch_arc = Arc::clone(&loaded.arch);
-        let device = loaded.device.clone();
-
-        let inference_result =
-            tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
-                let mut guard = arch_arc.blocking_lock();
-                run_inference(
-                    &mut guard,
-                    &device,
+        // Routing: CUDA loads go through the per-device worker
+        // thread (introduced in Phase 1; forward/clear added in
+        // Phase 2). CPU loads keep the existing spawn_blocking
+        // path because there's no context to own and the channel
+        // round-trip would only add latency. The two arms produce
+        // the same `(Vec<u32>, String)` shape so the rest of the
+        // path is shared.
+        let (generated_ids, finish_reason) = if let (Some(worker), Some(handle)) =
+            (loaded.worker.as_ref(), loaded.arch_handle)
+        {
+            // Worker path (CUDA).
+            #[cfg(feature = "cuda")]
+            {
+                match run_inference_via_worker(
+                    worker,
+                    handle,
                     &prompt_tokens,
                     max_new,
                     temperature,
@@ -1120,26 +1149,67 @@ impl CandleHarness {
                     seed,
                     eos_id,
                 )
-            })
-            .await;
+                .await
+                {
+                    Ok(v) => v,
+                    Err(e) => {
+                        loaded.poisoned.store(true, Ordering::Release);
+                        return Err(InferenceError::Other(e));
+                    }
+                }
+            }
+            #[cfg(not(feature = "cuda"))]
+            {
+                // Can't happen: `loaded.worker` is only Some on
+                // CUDA builds. The dead branch keeps the no-cuda
+                // build well-typed.
+                let _ = (worker, handle);
+                unreachable!("worker handle present without cuda feature");
+            }
+        } else if let Some(arch_arc) = loaded.arch.clone() {
+            // CPU path: existing spawn_blocking on the local
+            // Arc<Mutex<ModelArch>>.
+            let device = loaded.device.clone();
+            let inference_result =
+                tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
+                    let mut guard = arch_arc.blocking_lock();
+                    run_inference(
+                        &mut guard,
+                        &device,
+                        &prompt_tokens,
+                        max_new,
+                        temperature,
+                        top_p,
+                        seed,
+                        eos_id,
+                    )
+                })
+                .await;
 
             // Any failure inside the spawn_blocking touched CUDA via
             // candle's forward / cache code, so we treat it as a
             // device-poisoning event. The terminal log at the bottom
             // of the wrapper reports the error; this flag stops the
             // NEXT request from going down the same path.
-        let (generated_ids, finish_reason) = match inference_result {
+            match inference_result {
                 Ok(Ok(v)) => v,
                 Ok(Err(e)) => {
                     loaded.poisoned.store(true, Ordering::Release);
                     return Err(InferenceError::Other(e));
                 }
                 Err(e) => {
                     loaded.poisoned.store(true, Ordering::Release);
                     return Err(InferenceError::Other(anyhow::anyhow!(
                         "inference task panicked: {e}"
                     )));
+                }
             }
+        } else {
+            // LoadedModel invariant: exactly one of `worker` /
+            // `arch` is Some. Reaching here is a construction bug.
+            return Err(InferenceError::Other(anyhow::anyhow!(
+                "LoadedModel has neither worker handle nor local arch — load-path bug"
+            )));
         };
 
         let completion_text = loaded
@@ -1244,7 +1314,6 @@ impl CandleHarness {
             .token_to_id("<|im_end|>")
             .or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
 
-        let arch_arc = Arc::clone(&loaded.arch);
         let device = loaded.device.clone();
         let tokenizer = loaded.tokenizer.clone();
         let model_id = request.model.clone();
@@ -1315,40 +1384,96 @@ impl CandleHarness {
                 "chat_completion (stream): starting"
             );
         }
-        tokio::task::spawn_blocking(move || {
-            let _g = span_for_task.enter();
-            let mut guard = arch_arc.blocking_lock();
-            match run_inference_streaming(
-                &mut guard,
-                &device,
-                &tokenizer,
-                &prompt_tokens,
-                max_new,
-                temperature,
-                top_p,
-                seed,
-                eos_id,
-                &id,
-                created,
-                &model_id,
-                &tx,
-            ) {
-                Ok(()) => tracing::info!(
-                    prompt_tokens = prompt_len,
-                    total_ms = req_start.elapsed().as_millis(),
-                    "chat_completion (stream): done"
-                ),
-                Err(e) => {
-                    loaded_for_task.poisoned.store(true, Ordering::Release);
-                    tracing::error!(
-                        error = %format!("{e:#}"),
+        // Routing parallel to the non-streaming chat_completion: CUDA
+        // goes through the worker (async task), CPU keeps the
+        // spawn_blocking + Arc<Mutex<ModelArch>> path.
+        if let (Some(worker), Some(handle)) = (loaded.worker.clone(), loaded.arch_handle) {
+            #[cfg(feature = "cuda")]
+            {
+                let prompt_tokens = prompt_tokens.clone();
+                tokio::spawn(
+                    async move {
+                        match stream_inference_via_worker(
+                            worker,
+                            handle,
+                            tokenizer,
+                            prompt_tokens,
+                            max_new,
+                            temperature,
+                            top_p,
+                            seed,
+                            eos_id,
+                            id,
+                            created,
+                            model_id,
+                            tx,
+                        )
+                        .await
+                        {
+                            Ok(_finish_reason) => tracing::info!(
+                                prompt_tokens = prompt_len,
+                                total_ms = req_start.elapsed().as_millis(),
+                                "chat_completion (stream): done"
+                            ),
+                            Err(e) => {
+                                loaded_for_task.poisoned.store(true, Ordering::Release);
+                                tracing::error!(
+                                    error = %format!("{e:#}"),
+                                    prompt_tokens = prompt_len,
+                                    total_ms = req_start.elapsed().as_millis(),
+                                    "chat_completion (stream): failed, model marked poisoned"
+                                );
+                            }
+                        }
+                    }
+                    .instrument(span_for_task),
+                );
+            }
+            #[cfg(not(feature = "cuda"))]
+            {
+                let _ = (worker, handle, span_for_task);
+                unreachable!("worker handle present without cuda feature");
+            }
+        } else if let Some(arch_arc) = loaded.arch.clone() {
+            tokio::task::spawn_blocking(move || {
+                let _g = span_for_task.enter();
+                let mut guard = arch_arc.blocking_lock();
+                match run_inference_streaming(
+                    &mut guard,
+                    &device,
+                    &tokenizer,
+                    &prompt_tokens,
+                    max_new,
+                    temperature,
+                    top_p,
+                    seed,
+                    eos_id,
+                    &id,
+                    created,
+                    &model_id,
+                    &tx,
+                ) {
+                    Ok(()) => tracing::info!(
                         prompt_tokens = prompt_len,
                         total_ms = req_start.elapsed().as_millis(),
-                        "chat_completion (stream): failed, model marked poisoned"
-                    );
+                        "chat_completion (stream): done"
+                    ),
+                    Err(e) => {
+                        loaded_for_task.poisoned.store(true, Ordering::Release);
+                        tracing::error!(
+                            error = %format!("{e:#}"),
+                            prompt_tokens = prompt_len,
+                            total_ms = req_start.elapsed().as_millis(),
+                            "chat_completion (stream): failed, model marked poisoned"
+                        );
+                    }
                 }
-            }
-        });
+            });
+        } else {
+            return Err(InferenceError::Other(anyhow::anyhow!(
+                "LoadedModel has neither worker handle nor local arch — load-path bug"
+            )));
+        }
 
         Ok(rx)
     }
@@ -1432,22 +1557,37 @@ impl Harness for CandleHarness {
 
         // Worker thread for the chosen device. CPU loads (CUDA
         // unavailable / not requested) skip the worker — there's no
-        // context to own.
-        let worker = match &device {
+        // context to own. For CUDA loads, the arch is transferred
+        // into the worker's slab now so the inference path can
+        // reference it via the returned `ArchHandle`. The explicit
+        // type annotation lets the no-cuda build resolve `None` to
+        // the right `Option<Arc<DeviceWorkerHandle>>` type.
+        let worker: Option<Arc<super::device_worker::DeviceWorkerHandle>> = match &device {
             #[cfg(feature = "cuda")]
             Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?),
             _ => None,
         };
+        let (arch_local, arch_handle) = match &worker {
+            Some(w) => {
+                let handle = w
+                    .transfer_in(Box::new(arch))
+                    .await
+                    .map_err(|e| anyhow::anyhow!("transfer arch into device worker: {e}"))?;
+                (None, Some(handle))
+            }
+            None => (Some(Arc::new(Mutex::new(arch))), None),
+        };
 
         let loaded = Arc::new(LoadedModel {
             model_id: spec.model_id.clone(),
-            arch: Arc::new(Mutex::new(arch)),
+            arch: arch_local,
             tokenizer,
             device,
             quant: spec.quant.clone(),
             devices,
             poisoned: AtomicBool::new(false),
             worker,
+            arch_handle,
         });
 
         let mut models = self.models.write().await;
@@ -1465,13 +1605,26 @@ impl Harness for CandleHarness {
             anyhow::bail!("model '{model_id}' not loaded");
         };
         // Single-GPU drops are immediate — the LoadedModel goes out of
-        // scope with the Arc and candle frees VRAM. TP unloads also
-        // need to tell every worker to drop its shard before the pool
-        // itself is dropped (otherwise the workers keep their shards
-        // around until Shutdown, which is wasteful and would surface
-        // as VRAM not freed promptly).
+        // scope with the Arc and candle frees VRAM. CUDA loads also
+        // ship a `Job::DropArch` to the device worker so the boxed
+        // `ModelArch` releases its CUDA allocations on the right
+        // thread (with the bound context); without that, the Drop
+        // would run on whatever tokio thread happens to be holding
+        // the last `Arc<LoadedModel>` clone when this fn returns.
+        // TP unloads further coordinate the subprocess pool below.
         match handle {
-            LoadedHandle::Single(_) => {}
+            LoadedHandle::Single(single) => {
+                if let (Some(worker), Some(arch_handle)) =
+                    (single.worker.as_ref(), single.arch_handle)
+                    && let Err(e) = worker.drop_arch(arch_handle).await
+                {
+                    tracing::warn!(
+                        model = %model_id,
+                        error = %e,
+                        "single-GPU unload: DropArch RPC failed (model state may leak in worker slab)"
+                    );
+                }
+            }
             #[cfg(feature = "cuda")]
             LoadedHandle::Tp(tp) => {
                 // Try to recover the inner TpLoadedModel so we can move
@@ -2276,6 +2429,239 @@ fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
     prompt
 }
 
+#[allow(clippy::too_many_arguments)]
+/// Run the full single-GPU inference loop via the device worker.
+///
+/// Mirrors `run_inference`'s logic but routes each forward step
+/// through `worker.forward_logits()` (returns CPU-side `Vec<f32>`)
+/// and runs `apply_repeat_penalty` + sampling on a CPU candle tensor.
+/// The device-resident logits tensor never escapes the worker thread.
+///
+/// Used by the CUDA path of `chat_completion`. The CPU path keeps
+/// `run_inference` (spawn_blocking against `Arc<Mutex<ModelArch>>`)
+/// because there's no CUDA context to own and the worker indirection
+/// would only add channel overhead with no diagnostic benefit.
+#[cfg(feature = "cuda")]
+#[allow(clippy::too_many_arguments)]
+async fn run_inference_via_worker(
+    worker: &super::device_worker::DeviceWorkerHandle,
+    handle: super::device_worker::ArchHandle,
+    prompt_tokens: &[u32],
+    max_new: usize,
+    temperature: f64,
+    top_p: Option<f64>,
+    seed: u64,
+    eos_id: Option<u32>,
+) -> Result<(Vec<u32>, String)> {
+    let mut logits_processor = {
+        let sampling = if temperature <= 0.0 {
+            Sampling::ArgMax
+        } else {
+            match top_p {
+                Some(p) => Sampling::TopP { p, temperature },
+                None => Sampling::All { temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(seed, sampling)
+    };
+
+    let mut generated: Vec<u32> = Vec::new();
+    let prompt_len = prompt_tokens.len();
+
+    worker
+        .clear_kv_cache(handle)
+        .await
+        .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
+
+    // Prefill — every rank embeds the prompt with offset 0.
+    let logits_vec = worker
+        .forward_logits(handle, prompt_tokens.to_vec(), 0)
+        .await
+        .map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
+    let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
+    let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
+        Ok(t) => t,
+        Err(e) => {
+            let health = logits_health_slice(&logits_vec);
+            tracing::warn!(
+                ?health,
+                "chat_completion (worker): prefill sample failed; logits unhealthy"
+            );
+            return Err(e);
+        }
+    };
+
+    if Some(next_token) == eos_id {
+        return Ok((generated, "stop".into()));
+    }
+    generated.push(next_token);
+
+    for index in 0..max_new.saturating_sub(1) {
+        let logits_vec = worker
+            .forward_logits(handle, vec![next_token], prompt_len + index)
+            .await
+            .map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
+        let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
+        next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
+            Ok(t) => t,
+            Err(e) => {
+                let health = logits_health_slice(&logits_vec);
+                tracing::warn!(
+                    step = index,
+                    ?health,
+                    "chat_completion (worker): decode sample failed; logits unhealthy"
+                );
+                return Err(e);
+            }
+        };
+        if Some(next_token) == eos_id {
+            return Ok((generated, "stop".into()));
+        }
+        generated.push(next_token);
+    }
+
+    Ok((generated, "length".into()))
+}
+
+/// Streaming counterpart of [`run_inference_via_worker`]. Emits one
+/// `ChatCompletionChunk` per generated token via `tx`; routes every
+/// forward step through `worker.forward_logits()`. Same per-step
+/// CPU-side sampling discipline — no device tensor escapes the
+/// worker thread.
+#[cfg(feature = "cuda")]
+#[allow(clippy::too_many_arguments)]
+async fn stream_inference_via_worker(
+    worker: Arc<super::device_worker::DeviceWorkerHandle>,
+    handle: super::device_worker::ArchHandle,
+    tokenizer: Tokenizer,
+    prompt_tokens: Vec<u32>,
+    max_new: usize,
+    temperature: f64,
+    top_p: Option<f64>,
+    seed: u64,
+    eos_id: Option<u32>,
+    id: String,
+    created: u64,
+    model_id: String,
+    tx: mpsc::Sender<ChatCompletionChunk>,
+) -> Result<String> {
+    let mut logits_processor = {
+        let sampling = if temperature <= 0.0 {
+            Sampling::ArgMax
+        } else {
+            match top_p {
+                Some(p) => Sampling::TopP { p, temperature },
+                None => Sampling::All { temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(seed, sampling)
+    };
+
+    let mut all_tokens: Vec<u32> = Vec::new();
+    let mut decoded_prefix = String::new();
+    let prompt_len = prompt_tokens.len();
+    let mut finish_reason = "length".to_string();
+
+    worker
+        .clear_kv_cache(handle)
+        .await
+        .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
+
+    let logits_vec = worker
+        .forward_logits(handle, prompt_tokens, 0)
+        .await
+        .map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
+    let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
+    let mut next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
+        Ok(t) => t,
+        Err(e) => {
+            let health = logits_health_slice(&logits_vec);
+            tracing::warn!(
+                ?health,
+                "chat_completion (stream/worker): prefill sample failed; logits unhealthy"
+            );
+            return Err(e);
+        }
+    };
+
+    if Some(next_token) == eos_id {
+        finish_reason = "stop".into();
+    } else {
+        all_tokens.push(next_token);
+        if !emit_chunk(
+            &all_tokens,
+            &mut decoded_prefix,
+            &tokenizer,
+            &tx,
+            &id,
+            created,
+            &model_id,
+        )
+        .await
+        {
+            return Ok(finish_reason); // Client gone — clean stream end.
+        }
+
+        for index in 0..max_new.saturating_sub(1) {
+            let logits_vec = worker
+                .forward_logits(handle, vec![next_token], prompt_len + index)
+                .await
+                .map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
+            let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
+            next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
+                Ok(t) => t,
+                Err(e) => {
+                    let health = logits_health_slice(&logits_vec);
+                    tracing::warn!(
+                        step = index,
+                        ?health,
+                        "chat_completion (stream/worker): decode sample failed; logits unhealthy"
+                    );
+                    return Err(e);
+                }
+            };
+            if Some(next_token) == eos_id {
+                finish_reason = "stop".into();
+                break;
+            }
+            all_tokens.push(next_token);
+            if !emit_chunk(
+                &all_tokens,
+                &mut decoded_prefix,
+                &tokenizer,
+                &tx,
+                &id,
+                created,
+                &model_id,
+            )
+            .await
+            {
+                return Ok(finish_reason);
+            }
+        }
+    }
+
+    // Final chunk carrying finish_reason. Matches the run_inference_streaming
+    // shape so the SSE consumer sees an identical termination sequence.
+    let final_chunk = ChatCompletionChunk {
+        id: id.clone(),
+        object: "chat.completion.chunk".into(),
+        created,
+        model: model_id.clone(),
+        choices: vec![ChunkChoice {
+            index: 0,
+            delta: serde_json::Value::Object(Default::default()),
+            finish_reason: Some(finish_reason.clone()),
+            extra: serde_json::Value::Object(Default::default()),
+        }],
+        usage: None,
+        extra: serde_json::Value::Object(Default::default()),
+    };
+    let _ = tx.send(final_chunk).await;
+
+    Ok(finish_reason)
+}
+
 #[allow(clippy::too_many_arguments)]
 fn run_inference(
     arch: &mut ModelArch,
@@ -7,20 +7,41 @@
|
|||||||
//! one-line `oneshot::Sender::send` call to ship the reply back, which
|
//! one-line `oneshot::Sender::send` call to ship the reply back, which
|
||||||
//! is non-blocking.
|
//! is non-blocking.
|
||||||
//!
|
//!
|
||||||
//! Phase 1 handles only `QueryVram` and `Shutdown`. Later phases add
|
//! Phase 2 handles QueryVram, TransferIn, DropArch, ClearKv,
|
||||||
//! Forward, ClearKv, NCCL, and load handlers as separate match arms.
|
//! ForwardLogits, Shutdown. Phase 3 will add the TP variants
|
||||||
|
//! (NcclInit, NcclSanity, TpLoadShard, TpForward, TpClearKv) and the
|
||||||
|
//! ARCH model state in this state slab will gain a companion
|
||||||
|
//! `tp_models: HashMap<TpHandle, Box<TpLeaderModel>>`.
|
||||||
|
|
||||||
use crate::harness::device_worker::jobs::Job;
|
use crate::harness::candle::ModelArch;
|
||||||
|
use crate::harness::device_worker::jobs::{ArchHandle, Job};
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::mpsc::Receiver;
|
use std::sync::mpsc::Receiver;
|
||||||
|
|
||||||
/// Per-thread state owned by the worker. On CUDA builds the `Arc<CudaContext>`
|
/// Per-thread state owned by the worker. On CUDA builds the `Arc<CudaContext>`
|
||||||
/// is created and bound at thread startup; on CPU builds the struct is
|
/// is created and bound at thread startup; on CPU builds the struct
|
||||||
/// empty save for the device index (kept for log clarity).
|
/// is mostly empty.
|
||||||
#[cfg(feature = "cuda")]
|
|
||||||
struct DeviceWorkerState {
|
struct DeviceWorkerState {
|
||||||
|
#[allow(dead_code)]
|
||||||
device_index: u32,
|
device_index: u32,
|
||||||
|
/// Candle `Device` constructed at startup. Used by handlers (e.g.
|
||||||
|
/// `ForwardLogits`) to build input tensors against the right
|
||||||
|
/// device. Falls back to `Device::Cpu` if CUDA init fails.
|
||||||
|
device: candle_core::Device,
|
||||||
|
/// Boxed `ModelArch` slab. Indexed by an opaque `ArchHandle` minted
|
||||||
|
/// by `TransferIn`. The Box means the entry's address is stable
|
||||||
|
/// across HashMap rehashes (relevant only when we later hand out
|
||||||
|
/// `&mut ModelArch` references; in Phase 2 every handler borrows
|
||||||

/// briefly via `get_mut`, so there are no long-lived borrows).
|
||||||
|
models: HashMap<ArchHandle, Box<ModelArch>>,
|
||||||
|
/// Counter for minting fresh `ArchHandle`s. Each `TransferIn`
|
||||||
|
/// increments and returns the new value. Wraps at u64::MAX after
|
||||||
|
/// ~10^19 model loads — not a practical concern.
|
||||||
|
next_handle: u64,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(dead_code)]
|
||||||
/// `None` only if `CudaContext::new()` failed — in that case the
|
/// `None` only if `CudaContext::new()` failed — in that case the
|
||||||
/// thread still runs so the handle's lifecycle stays uniform, but
|
/// thread still runs so the handle's lifecycle stays uniform, but
|
||||||
/// every job that touches CUDA falls through to a zero reply with
|
/// every job that touches CUDA falls through to a zero reply with
|
||||||
@@ -28,18 +49,12 @@ struct DeviceWorkerState {
|
|||||||
ctx: Option<Arc<candle_core::cuda::cudarc::driver::CudaContext>>,
|
ctx: Option<Arc<candle_core::cuda::cudarc::driver::CudaContext>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
|
||||||
#[allow(dead_code)]
|
|
||||||
struct DeviceWorkerState {
|
|
||||||
device_index: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Worker thread entry point. Runs until `Job::Shutdown` arrives or
|
/// Worker thread entry point. Runs until `Job::Shutdown` arrives or
|
||||||
/// the channel sender is dropped (which happens when the last
|
/// the channel sender is dropped (which happens when the last
|
||||||
/// `DeviceWorkerHandle` `Arc` is dropped without an explicit
|
/// `DeviceWorkerHandle` `Arc` is dropped without an explicit
|
||||||
/// `shutdown()`).
|
/// `shutdown()`).
|
||||||
pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool>) {
|
pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool>) {
|
||||||
let state = init_state(device_index);
|
let mut state = init_state(device_index);
|
||||||
tracing::info!(device_index, "device worker started");
|
tracing::info!(device_index, "device worker started");
|
||||||
|
|
||||||
while let Ok(job) = rx.recv() {
|
while let Ok(job) = rx.recv() {
|
||||||
@@ -51,7 +66,7 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
|||||||
}
|
}
|
||||||
if poisoned.load(Ordering::Acquire) {
|
if poisoned.load(Ordering::Acquire) {
|
||||||
// Drain-only mode: reply with a poisoned error without
|
// Drain-only mode: reply with a poisoned error without
|
||||||
// touching CUDA. Phase 1 never sets the flag from the
|
// touching CUDA. Phase 1/2 never set the flag from the
|
||||||
// dispatch loop itself (no driver errors classified yet),
|
// dispatch loop itself (no driver errors classified yet),
|
||||||
// but tests use `DeviceWorkerHandle::set_poisoned()` to
|
// but tests use `DeviceWorkerHandle::set_poisoned()` to
|
||||||
// simulate this state.
|
// simulate this state.
|
||||||
@@ -66,58 +81,130 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
|||||||
// discard the reply.
|
// discard the reply.
|
||||||
let _ = reply.send(result);
|
let _ = reply.send(result);
|
||||||
}
|
}
|
||||||
|
Job::TransferIn { arch, reply } => {
|
||||||
|
let handle = ArchHandle(state.next_handle);
|
||||||
|
state.next_handle = state.next_handle.wrapping_add(1);
|
||||||
|
state.models.insert(handle, arch);
|
||||||
|
tracing::debug!(
|
||||||
|
device_index,
|
||||||
|
handle = handle.0,
|
||||||
|
slab_size = state.models.len(),
|
||||||
|
"device worker: model transferred in"
|
||||||
|
);
|
||||||
|
let _ = reply.send(Ok(handle));
|
||||||
|
}
|
||||||
|
Job::DropArch { handle, reply } => {
|
||||||
|
let removed = state.models.remove(&handle);
|
||||||
|
let was_present = removed.is_some();
|
||||||
|
// Explicit drop on this thread — runs the Box<ModelArch>
|
||||||
|
// Drop with the CUDA context bound here, which frees
|
||||||
|
// all device tensors on the right context. The Drop is
|
||||||
|
// implicit on the `removed` value going out of scope at
|
||||||
|
// the end of the arm; calling drop() explicitly just
|
||||||
|
// makes the intent visible.
|
||||||
|
drop(removed);
|
||||||
|
tracing::debug!(
|
||||||
|
device_index,
|
||||||
|
handle = handle.0,
|
||||||
|
was_present,
|
||||||
|
slab_size = state.models.len(),
|
||||||
|
"device worker: model dropped"
|
||||||
|
);
|
||||||
|
let _ = reply.send(());
|
||||||
|
}
|
||||||
|
Job::ClearKv { handle, reply } => {
|
||||||
|
let result = match state.models.get_mut(&handle) {
|
||||||
|
Some(arch) => arch.clear_kv_cache(),
|
||||||
|
None => Err(anyhow::anyhow!("ClearKv: no model for handle {}", handle.0)),
|
||||||
|
};
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
|
Job::ForwardLogits {
|
||||||
|
handle,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
reply,
|
||||||
|
} => {
|
||||||
|
let result = forward_logits(&mut state, handle, &tokens, offset);
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
// Handled by the matches!() check above; reaching here
|
// Handled by the matches!() check above; reaching here
|
||||||
// means a Shutdown slipped past which is a bug.
|
// means a Shutdown slipped past which is a bug.
|
||||||
Job::Shutdown => unreachable!("Shutdown should break above"),
|
Job::Shutdown => unreachable!("Shutdown should break above"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::info!(device_index, "device worker exiting");
|
tracing::info!(
|
||||||
|
device_index,
|
||||||
|
slab_size = state.models.len(),
|
||||||
|
"device worker exiting; dropping remaining models"
|
||||||
|
);
|
||||||
|
// Drops every model in the slab on this thread before the function
|
||||||
|
// returns. Critical for CUDA tensors: dropping on a thread that
|
||||||
|
// doesn't have the context bound is UB. Phase 2 still runs Drop
|
||||||
|
// via the slab going out of scope, which is correct as long as no
|
||||||
|
// pre-poisoned state lurks in here — see the poisoned-mode
|
||||||
|
// semantics in mod.rs for the Phase 3+ refinement.
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
|
||||||
fn init_state(device_index: u32) -> DeviceWorkerState {
|
fn init_state(device_index: u32) -> DeviceWorkerState {
|
||||||
use candle_core::cuda::cudarc::driver::CudaContext;
|
#[cfg(feature = "cuda")]
|
||||||
match CudaContext::new(device_index as usize) {
|
{
|
||||||
Ok(ctx) => {
|
use candle_core::cuda::cudarc::driver::CudaContext;
|
||||||
// Make sure the context is current on this thread. cudarc
|
// Construct a candle Device first — cudarc returns the
|
||||||
// is generally fine with lazy binding, but doing it once
|
// primary context for this index on subsequent calls, so
|
||||||
// here gives us a deterministic moment to log "context
|
// CudaContext::new and Device::new_cuda end up sharing state.
|
||||||
// bound" — and makes `mem_get_info()` work without further
|
let (device, ctx) = match candle_core::Device::new_cuda(device_index as usize) {
|
||||||
// bind dances inside the dispatch handlers.
|
Ok(device) => match CudaContext::new(device_index as usize) {
|
||||||
if let Err(e) = ctx.bind_to_thread() {
|
Ok(ctx) => {
|
||||||
|
if let Err(e) = ctx.bind_to_thread() {
|
||||||
|
tracing::warn!(
|
||||||
|
device_index,
|
||||||
|
error = ?e,
|
||||||
|
"device worker: bind_to_thread failed; \
|
||||||
|
operations will still rebind per-call"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
tracing::info!(device_index, "device worker bound CUDA context");
|
||||||
|
}
|
||||||
|
(device, Some(ctx))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
device_index,
|
||||||
|
error = ?e,
|
||||||
|
"device worker: CudaContext::new failed; \
|
||||||
|
vram queries will return (0, 0), forward will error"
|
||||||
|
);
|
||||||
|
(device, None)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
device_index,
|
device_index,
|
||||||
error = ?e,
|
error = %e,
|
||||||
"device worker: bind_to_thread failed; \
|
"device worker: Device::new_cuda failed; falling back to CPU device"
|
||||||
vram queries will still rebind per-call"
|
|
||||||
);
|
);
|
||||||
} else {
|
(candle_core::Device::Cpu, None)
|
||||||
tracing::info!(device_index, "device worker bound CUDA context");
|
|
||||||
}
|
|
||||||
DeviceWorkerState {
|
|
||||||
device_index,
|
|
||||||
ctx: Some(ctx),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::warn!(
|
|
||||||
device_index,
|
|
||||||
error = ?e,
|
|
||||||
"device worker: CudaContext::new failed; \
|
|
||||||
vram queries will return (0, 0)"
|
|
||||||
);
|
|
||||||
DeviceWorkerState {
|
|
||||||
device_index,
|
|
||||||
ctx: None,
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
DeviceWorkerState {
|
||||||
|
device_index,
|
||||||
|
device,
|
||||||
|
models: HashMap::new(),
|
||||||
|
next_handle: 1,
|
||||||
|
ctx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
{
|
||||||
|
DeviceWorkerState {
|
||||||
|
device_index,
|
||||||
|
device: candle_core::Device::Cpu,
|
||||||
|
models: HashMap::new(),
|
||||||
|
next_handle: 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
|
||||||
fn init_state(device_index: u32) -> DeviceWorkerState {
|
|
||||||
DeviceWorkerState { device_index }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
@@ -144,6 +231,42 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
|
|||||||
Ok((0, 0))
|
Ok((0, 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Forward step + copy the `[vocab]` logits to a CPU `Vec<f32>` ready
|
||||||
|
/// for sampling on the async caller. The model's `device()` (CUDA or
|
||||||
|
/// CPU) determines where the kernel runs; this fn doesn't care.
|
||||||
|
///
|
||||||
|
/// On CUDA, the `to_dtype(F32).flatten_all().to_vec1::<f32>()` chain
|
||||||
|
/// triggers the device → host copy. The copy runs synchronously on
|
||||||
|
/// this worker thread; the bound context owns the source allocation
|
||||||
|
/// so the transfer is straightforward.
|
||||||
|
fn forward_logits(
|
||||||
|
state: &mut DeviceWorkerState,
|
||||||
|
handle: ArchHandle,
|
||||||
|
tokens: &[u32],
|
||||||
|
offset: usize,
|
||||||
|
) -> anyhow::Result<Vec<f32>> {
|
||||||
|
use candle_core::{DType, Tensor};
|
||||||
|
|
||||||
|
// Build the input tensor on the worker's own device. cudarc's
|
||||||
|
// primary-context model means `Device::new_cuda(idx)` shares state
|
||||||
|
// with the `CudaContext` we bound at startup, so this is the same
|
||||||
|
// device the ModelArch was loaded against.
|
||||||
|
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
|
||||||
|
|
||||||
|
let arch = state
|
||||||
|
.models
|
||||||
|
.get_mut(&handle)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("ForwardLogits: no model for handle {}", handle.0))?;
|
||||||
|
|
||||||
|
let logits = arch.forward(&input, offset)?;
|
||||||
|
// Copy to CPU f32. logits is already `[vocab]` (squeeze_to_vocab
|
||||||
|
// inside ModelArch::forward). The to_dtype handles bf16/f16 →
|
||||||
|
// f32 promotion for the sampler.
|
||||||
|
let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
|
||||||
|
let values = logits.to_vec1::<f32>()?;
|
||||||
|
Ok(values)
|
||||||
|
}
|
||||||
|
|
||||||
/// Reply to a job with the poisoned-worker error. Used when the worker
|
/// Reply to a job with the poisoned-worker error. Used when the worker
|
||||||
/// has flipped into drain-only mode after a CUDA driver error.
|
/// has flipped into drain-only mode after a CUDA driver error.
|
||||||
///
|
///
|
||||||
@@ -153,11 +276,26 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
|
|||||||
/// poisoned error so callers never hang waiting for a worker that's
|
/// poisoned error so callers never hang waiting for a worker that's
|
||||||
/// no longer running CUDA.
|
/// no longer running CUDA.
|
||||||
fn drain_poisoned(job: Job, device_index: u32) {
|
fn drain_poisoned(job: Job, device_index: u32) {
|
||||||
|
let err = || anyhow::anyhow!("device worker for device {device_index} is poisoned");
|
||||||
match job {
|
match job {
|
||||||
Job::QueryVram { reply } => {
|
Job::QueryVram { reply } => {
|
||||||
let _ = reply.send(Err(anyhow::anyhow!(
|
let _ = reply.send(Err(err()));
|
||||||
"device worker for device {device_index} is poisoned"
|
}
|
||||||
)));
|
Job::TransferIn { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
|
Job::DropArch { reply, .. } => {
|
||||||
|
// Drop reply is `()` — no error path. Send the unit so the
|
||||||
|
// caller's await resolves; the model handle is leaked in
|
||||||
|
// the worker's slab, but the whole slab gets `mem::forget`
|
||||||
|
// on shutdown anyway per the poisoned-thread design.
|
||||||
|
let _ = reply.send(());
|
||||||
|
}
|
||||||
|
Job::ClearKv { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
|
Job::ForwardLogits { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
}
|
}
|
||||||
Job::Shutdown => {
|
Job::Shutdown => {
|
||||||
// Filtered by the matches!() guard in run(); reaching
|
// Filtered by the matches!() guard in run(); reaching
|
||||||
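The slab-ownership pattern in the worker above can be sketched without CUDA or tokio. This is a minimal illustration under stated assumptions, not the commit's code: `ToyModel`, `ToyJob`, `Handle`, and `spawn_worker` are hypothetical names, `std::sync::mpsc` stands in for the `oneshot` reply channel, and the drop job replies with a `bool` (the real `DropArch` replies `()`) so the effect is observable.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;

/// Hypothetical stand-in for `ModelArch`; it holds no device memory.
pub struct ToyModel {
    pub name: String,
}

/// Opaque handle minted by the worker, mirroring `ArchHandle`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Handle(pub u64);

pub enum ToyJob {
    TransferIn { model: Box<ToyModel>, reply: Sender<Handle> },
    /// Replies with whether the handle was present (illustration only).
    DropModel { handle: Handle, reply: Sender<bool> },
    Shutdown,
}

/// Worker loop. The slab is a local of this function, so every model is
/// dropped on the worker thread: either by a `DropModel` job or when the
/// function returns. Returns how many models were still loaded at exit.
pub fn run(rx: Receiver<ToyJob>) -> usize {
    let mut models: HashMap<Handle, Box<ToyModel>> = HashMap::new();
    let mut next_handle: u64 = 1;
    while let Ok(job) = rx.recv() {
        match job {
            ToyJob::TransferIn { model, reply } => {
                let handle = Handle(next_handle);
                next_handle = next_handle.wrapping_add(1);
                models.insert(handle, model);
                // The caller may have gone away; ignore a failed reply.
                let _ = reply.send(handle);
            }
            ToyJob::DropModel { handle, reply } => {
                // The removed Box is dropped here, on the worker thread.
                let was_present = models.remove(&handle).is_some();
                let _ = reply.send(was_present);
            }
            ToyJob::Shutdown => break,
        }
    }
    models.len()
}

/// Spawn the worker and hand back the job sender plus the join handle.
pub fn spawn_worker() -> (Sender<ToyJob>, thread::JoinHandle<usize>) {
    let (tx, rx) = mpsc::channel();
    let join = thread::spawn(move || run(rx));
    (tx, join)
}
```

Because the only way to reach a model is to send a job, no other thread can ever run the model's destructor, which is the invariant the diff's comments describe for CUDA tensors.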
|
|||||||
@@ -4,16 +4,33 @@
|
|||||||
//! needs plus a `tokio::sync::oneshot::Sender` for the reply. The
|
//! needs plus a `tokio::sync::oneshot::Sender` for the reply. The
|
||||||
//! async-side `DeviceWorkerHandle` constructs a job, sends it down the
|
//! async-side `DeviceWorkerHandle` constructs a job, sends it down the
|
||||||
//! `std::sync::mpsc` channel, and `await`s the oneshot for the reply.
|
//! `std::sync::mpsc` channel, and `await`s the oneshot for the reply.
|
||||||
//!
|
|
||||||
//! Phase 1 includes only `QueryVram` and `Shutdown`. Phases 2–4 add
|
|
||||||
//! forward, kv-cache clear, drop-arch, NCCL init/sanity, and the load
|
|
||||||
//! variants. Each new variant lands as a separate PR so the worker
|
|
||||||
//! thread stays small at every checkpoint.
|
|
||||||
|
|
||||||
|
use crate::harness::candle::ModelArch;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
|
|
||||||
|
/// Opaque handle to a `ModelArch` stored in the worker thread's state
|
||||||
|
/// slab. Cheap to copy; `Send + Sync` so it crosses task boundaries
|
||||||
|
/// freely. The actual `Box<ModelArch>` it points to is owned by the
|
||||||
|
/// worker thread for the duration of the handle's lifetime — the only
|
||||||
|
/// way to drop the model is to send `Job::DropArch { handle }` so the
|
||||||
|
/// `Drop` impl runs on the thread with the bound CUDA context (the
|
||||||
|
/// invariant the whole refactor exists to guarantee).
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
pub struct ArchHandle(pub u64);
|
||||||
|
|
||||||
/// One unit of work for the device worker.
|
/// One unit of work for the device worker.
|
||||||
|
///
|
||||||
|
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
|
||||||
|
/// single-GPU inference primitives: transfer-in a freshly-loaded
|
||||||
|
/// `ModelArch`, drop it, clear its KV cache, and run one forward step
|
||||||
|
/// returning CPU-side logits ready for sampling on the async caller.
|
||||||
|
///
|
||||||
|
/// Sampling stays on the async side intentionally. The worker copies
|
||||||
|
/// logits to CPU (`Vec<f32>`) before reply, so the device-resident
|
||||||
|
/// tensor never escapes the worker thread and the async caller's
|
||||||
|
/// `LogitsProcessor::sample` runs entirely on the CPU candle backend
|
||||||
|
/// — no incidental context binding on a tokio worker thread.
|
||||||
pub enum Job {
|
pub enum Job {
|
||||||
/// Query free / total VRAM on the device. Returns
|
/// Query free / total VRAM on the device. Returns
|
||||||
/// `(free_mb, total_mb)`. CPU builds and contexts that failed to
|
/// `(free_mb, total_mb)`. CPU builds and contexts that failed to
|
||||||
@@ -22,10 +39,43 @@ pub enum Job {
|
|||||||
QueryVram {
|
QueryVram {
|
||||||
reply: oneshot::Sender<Result<(u64, u64)>>,
|
reply: oneshot::Sender<Result<(u64, u64)>>,
|
||||||
},
|
},
|
||||||
/// Tell the worker to break its dispatch loop and exit. The
|
/// Move a freshly-loaded `ModelArch` into the worker's state slab.
|
||||||
/// channel is then drained — any further jobs already queued get
|
/// Returns an `ArchHandle` the caller stores on `LoadedModel` and
|
||||||
/// dropped (their oneshot senders are dropped, causing the async
|
/// passes back in subsequent `ClearKv` / `ForwardLogits` /
|
||||||
/// caller's receiver to return `Err` which `DeviceWorkerHandle`
|
/// `DropArch` jobs.
|
||||||
/// maps to `WorkerError::Gone`).
|
TransferIn {
|
||||||
|
arch: Box<ModelArch>,
|
||||||
|
reply: oneshot::Sender<Result<ArchHandle>>,
|
||||||
|
},
|
||||||
|
/// Remove the model from the slab and drop it. The `Drop` runs on
|
||||||
|
/// the worker thread so CUDA tensors release their memory on the
|
||||||
|
/// same context that allocated them.
|
||||||
|
DropArch {
|
||||||
|
handle: ArchHandle,
|
||||||
|
reply: oneshot::Sender<()>,
|
||||||
|
},
|
||||||
|
/// Reset the KV cache for this model. Called at the start of every
|
||||||
|
/// chat completion so a new request doesn't attend over the
|
||||||
|
/// previous one's tokens.
|
||||||
|
ClearKv {
|
||||||
|
handle: ArchHandle,
|
||||||
|
reply: oneshot::Sender<Result<()>>,
|
||||||
|
},
|
||||||
|
/// Run one forward step and copy the resulting `[vocab]` logits to
|
||||||
|
/// CPU. The caller takes the returned `Vec<f32>`, wraps it in a
|
||||||
|
/// CPU `Tensor`, and runs `apply_repeat_penalty` + sampling
|
||||||
|
/// without touching the device context. `offset` is the KV-cache
|
||||||
|
/// position before this step (0 for prefill, `prompt_len + i` for
|
||||||
|
/// the i-th decode step).
|
||||||
|
ForwardLogits {
|
||||||
|
handle: ArchHandle,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
|
},
|
||||||
|
/// Tell the worker to break its dispatch loop and exit. Any jobs
|
||||||

/// still queued behind this one are dropped unprocessed; their
|
||||||

/// oneshot senders drop with them, so each caller's receiver resolves
|
||||||

/// to `Err`, which `DeviceWorkerHandle` maps to `WorkerError::Gone`.
|
||||||
Shutdown,
|
Shutdown,
|
||||||
}
|
}
|
||||||
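The `offset` convention documented on `ForwardLogits` (0 for prefill, `prompt_len + i` for the i-th decode step) can be made concrete with a small planner. This is a sketch only: `Step` and `plan_steps` are hypothetical names, `i` is assumed to be zero-based, and each decode step is assumed to feed exactly one token, which the diff does not state explicitly.

```rust
/// One forward call: how many tokens are fed and the KV-cache position
/// before the call. Hypothetical helper type, not part of the commit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Step {
    pub n_tokens: usize,
    pub offset: usize,
}

/// Plan the forward calls for one completion: a single prefill over the
/// whole prompt at offset 0, then `n_decode` single-token decode steps.
pub fn plan_steps(prompt_len: usize, n_decode: usize) -> Vec<Step> {
    let mut steps = Vec::with_capacity(1 + n_decode);
    // Prefill: the KV cache is empty, so the offset is 0.
    steps.push(Step { n_tokens: prompt_len, offset: 0 });
    for i in 0..n_decode {
        // Before decode step i the cache holds the prompt plus the i
        // tokens fed by earlier decode steps.
        steps.push(Step { n_tokens: 1, offset: prompt_len + i });
    }
    steps
}
```

For a 3-token prompt and two decode steps, the planned offsets are 0, 3, and 4.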
|
|||||||
@@ -49,7 +49,7 @@ use std::sync::mpsc::{self, Sender};
|
|||||||
use std::thread::JoinHandle;
|
use std::thread::JoinHandle;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
|
|
||||||
pub use jobs::Job;
|
pub use jobs::{ArchHandle, Job};
|
||||||
|
|
||||||
/// Errors returned by `DeviceWorkerHandle` submit methods.
|
/// Errors returned by `DeviceWorkerHandle` submit methods.
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
@@ -159,6 +159,124 @@ impl DeviceWorkerHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Move a freshly-loaded `ModelArch` into the worker's state slab.
|
||||||
|
/// Returns the `ArchHandle` the caller stores on `LoadedModel`.
|
||||||
|
/// The `Box<ModelArch>` crosses the channel; the worker thread
|
||||||
|
/// owns it from here on.
|
||||||
|
pub async fn transfer_in(
|
||||||
|
&self,
|
||||||
|
arch: Box<crate::harness::candle::ModelArch>,
|
||||||
|
) -> Result<ArchHandle, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::TransferIn {
|
||||||
|
arch,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tell the worker to drop the `ModelArch` for `handle` on the
|
||||||
|
/// worker thread (so CUDA tensors release on the right context).
|
||||||
|
/// Returns `Ok(())` even if the handle wasn't in the slab — Drop
|
||||||
|
/// is idempotent. Reports `Gone` if the worker isn't running.
|
||||||
|
pub async fn drop_arch(&self, handle: ArchHandle) -> Result<(), WorkerError> {
|
||||||
|
// Poisoning doesn't block DropArch — even on a poisoned
|
||||||
|
// context we want callers to unblock and proceed with the
|
||||||
|
// unload bookkeeping. The dispatch handler under poison just
|
||||||
|
// replies `()` without touching the model (the model is
|
||||||

// deliberately leaked via mem::forget at thread exit per the poison
|
||||||

// protocol).
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::DropArch {
|
||||||
|
handle,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(()) => Ok(()),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reset the KV cache for the model at `handle`. Called at the
|
||||||
|
/// start of every chat completion so the new prompt doesn't
|
||||||
|
/// attend over the previous request's tokens.
|
||||||
|
pub async fn clear_kv_cache(&self, handle: ArchHandle) -> Result<(), WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::ClearKv {
|
||||||
|
handle,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run one forward step and return the resulting `[vocab]` logits
|
||||||
|
/// as a CPU-side `Vec<f32>`. The caller then samples on a CPU
|
||||||
|
/// candle Tensor without ever binding the device context on its
|
||||||
|
/// tokio thread.
|
||||||
|
pub async fn forward_logits(
|
||||||
|
&self,
|
||||||
|
handle: ArchHandle,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<Vec<f32>, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::ForwardLogits {
|
||||||
|
handle,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
|
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
|
||||||
/// twice is a no-op the second time.
|
/// twice is a no-op the second time.
|
||||||
pub fn shutdown(&self) -> anyhow::Result<()> {
|
pub fn shutdown(&self) -> anyhow::Result<()> {
|
||||||
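The submit methods above share one shape: check the poisoned flag, send the job, then wait for the reply, mapping either channel failure to `Gone`. The sketch below shows that shape with the standard library only. `ToyHandle`, `Ping`, and `SubmitError` are hypothetical names, and a blocking `std::sync::mpsc` receive is swapped in for awaiting the tokio `oneshot` receiver that the real handle uses.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{self, Sender};
use std::sync::Arc;

/// Hypothetical error type mirroring the two failure modes in the diff.
#[derive(Debug, PartialEq, Eq)]
pub enum SubmitError {
    Poisoned,
    Gone,
}

/// Hypothetical job: carries only the reply channel.
pub struct Ping {
    pub reply: Sender<u32>,
}

pub struct ToyHandle {
    pub tx: Sender<Ping>,
    pub poisoned: Arc<AtomicBool>,
}

impl ToyHandle {
    pub fn ping(&self) -> Result<u32, SubmitError> {
        // Fail fast without queueing work on a poisoned worker.
        if self.poisoned.load(Ordering::Acquire) {
            return Err(SubmitError::Poisoned);
        }
        let (reply_tx, reply_rx) = mpsc::channel();
        // A send error means the worker's receiver was dropped.
        self.tx
            .send(Ping { reply: reply_tx })
            .map_err(|_| SubmitError::Gone)?;
        // A receive error means the job was dropped without a reply.
        reply_rx.recv().map_err(|_| SubmitError::Gone)
    }
}
```

Both channel errors collapse to the same variant because, from the caller's point of view, they mean the same thing: the worker will never answer this request.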
|
|||||||