refactor(neuron): phase 2 — single-GPU forward + clear_kv route through device worker
Some checks failed
build-prerelease / Package helexa-neuron-ada RPM (push) Blocked by required conditions
CI / Format (push) Successful in 34s
CI / Clippy (push) Successful in 2m12s
build-prerelease / Resolve version stamps (push) Successful in 3m41s
CI / Test (push) Successful in 5m1s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m32s
build-prerelease / Build neuron-ampere (push) Successful in 5m20s
build-prerelease / Build cortex binary (push) Successful in 12m20s
build-prerelease / Build neuron-ada (push) Successful in 5m17s
build-prerelease / Package cortex RPM (push) Successful in 1m25s
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
Some checks failed
build-prerelease / Package helexa-neuron-ada RPM (push) Blocked by required conditions
CI / Format (push) Successful in 34s
CI / Clippy (push) Successful in 2m12s
build-prerelease / Resolve version stamps (push) Successful in 3m41s
CI / Test (push) Successful in 5m1s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m32s
build-prerelease / Build neuron-ampere (push) Successful in 5m20s
build-prerelease / Build cortex binary (push) Successful in 12m20s
build-prerelease / Build neuron-ada (push) Successful in 5m17s
build-prerelease / Package cortex RPM (push) Successful in 1m25s
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
Second slice of the per-device CUDA context-ownership refactor planned at
~/.claude/plans/plan-the-per-device-worker-abstract-micali.md. The two
spawn_blocking sites in `chat_completion` and `chat_completion_stream`
now route through the device worker thread on CUDA loads. CPU loads
keep the existing spawn_blocking + `Arc<Mutex<ModelArch>>` path; there's
no context to own and the channel hop would only add latency.
What this phase changes:
- `Job` gains `TransferIn`, `DropArch`, `ClearKv`, `ForwardLogits`. The
worker's dispatch state grows a `HashMap<ArchHandle, Box<ModelArch>>`
slab and a `next_handle` counter for minting opaque handles.
- `LoadedModel.arch: Arc<Mutex<ModelArch>>` → `Option<Arc<Mutex<>>>`,
plus a new `arch_handle: Option<ArchHandle>` field. The two are
mutually exclusive: CUDA loads set `arch_handle = Some(_)` after
transferring the boxed arch into the worker's slab; CPU loads keep
`arch = Some(_)` for the legacy spawn_blocking path.
- New `run_inference_via_worker` and `stream_inference_via_worker`
drive the prefill + decode loop by sending `Job::ForwardLogits` per
step; the worker copies the resulting `[vocab]` logits to a
CPU-side `Vec<f32>` before reply, so the async caller never holds a
device-resident tensor. `apply_repeat_penalty` and
`LogitsProcessor::sample` run on a CPU candle tensor; no context
binding side-effects on tokio worker threads.
- `logits_health_slice(&[f32])` complements the existing
`logits_health(&Tensor)` so the new worker paths can compute
health stats directly from the CPU vec.
- `unload_model` for the single-GPU CUDA path now sends
`Job::DropArch { handle }` to the worker so the `Box<ModelArch>`
drops on the thread that allocated its CUDA tensors. The `Drop` runs
with the bound context, freeing memory on the right context.
What this phase doesn't touch (yet):
- TP forward, TP load, NCCL bring-up — still on spawn_blocking. Phase 3.
- Single-GPU model load — still spawn_blocking, followed by a
`Job::TransferIn` to move the freshly-built `ModelArch` into the
worker slab. Phase 4 moves the load itself onto the worker thread
and eliminates the bootstrap TransferIn.
- The `device_vram_mb` / `cuda_mem_mb` helpers — still present and
used by the construction-time logs running inside spawn_blocking
loads. Phase 4 cleanup folds them into `dispatch.rs`.
Public API unchanged. fmt + clippy clean; 37 lib tests + all
integration tests pass.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -102,7 +102,14 @@ impl LoadedHandle {
|
||||
/// moved into `spawn_blocking` for synchronous candle forward passes.
|
||||
pub struct LoadedModel {
|
||||
pub model_id: String,
|
||||
pub arch: Arc<Mutex<ModelArch>>,
|
||||
/// Local (async-side) handle to the model architecture. `Some`
|
||||
/// only when the model loaded onto the CPU device (no CUDA
|
||||
/// available); the inference path then takes this mutex via
|
||||
/// `spawn_blocking` and runs candle ops on the CPU backend.
|
||||
/// `None` when the model loaded onto a CUDA device — in that case
|
||||
/// the architecture lives in the worker thread's slab and is
|
||||
/// addressed via [`Self::arch_handle`].
|
||||
pub arch: Option<Arc<Mutex<ModelArch>>>,
|
||||
pub tokenizer: Tokenizer,
|
||||
pub device: Device,
|
||||
pub quant: Option<String>,
|
||||
@@ -118,10 +125,15 @@ pub struct LoadedModel {
|
||||
pub poisoned: AtomicBool,
|
||||
/// Handle to the per-device CUDA worker thread for this model's
|
||||
/// device. `None` for CPU loads (no context to own). VRAM queries
|
||||
/// and — in later refactor phases — forward / kv-cache / unload
|
||||
/// ops route through this handle so the device's CUDA context
|
||||
/// stays bound to one OS thread for the daemon's lifetime.
|
||||
/// and — for CUDA loads — forward / kv-cache / drop ops route
|
||||
/// through this handle so the device's CUDA context stays bound
|
||||
/// to one OS thread for the daemon's lifetime.
|
||||
pub worker: Option<Arc<super::device_worker::DeviceWorkerHandle>>,
|
||||
/// Index into the worker's `ModelArch` slab. `Some` iff the model
|
||||
/// loaded onto a CUDA device and was successfully transferred to
|
||||
/// the worker; in that case [`Self::arch`] is `None`. The two
|
||||
/// fields are mutually exclusive.
|
||||
pub arch_handle: Option<super::device_worker::ArchHandle>,
|
||||
}
|
||||
|
||||
impl LoadedModel {
|
||||
@@ -475,6 +487,16 @@ fn logits_health(t: &Tensor) -> LogitsHealth {
|
||||
};
|
||||
}
|
||||
};
|
||||
logits_health_slice(&values)
|
||||
}
|
||||
|
||||
/// Same diagnostic as [`logits_health`] but operates directly on a
|
||||
/// `[f32]` slice. Used by the worker-routed inference paths where the
|
||||
/// device → host copy has already happened on the worker thread and
|
||||
/// the async caller has the values in hand. Avoids the round-trip of
|
||||
/// rebuilding a Tensor just to call to_vec1 again.
|
||||
#[allow(dead_code)]
|
||||
fn logits_health_slice(values: &[f32]) -> LogitsHealth {
|
||||
let mut nan = 0usize;
|
||||
let mut pos_inf = 0usize;
|
||||
let mut neg_inf = 0usize;
|
||||
@@ -483,7 +505,7 @@ fn logits_health(t: &Tensor) -> LogitsHealth {
|
||||
let mut finite_max = f32::NEG_INFINITY;
|
||||
let mut finite_sum = 0.0_f64;
|
||||
let mut finite_count = 0usize;
|
||||
for &v in &values {
|
||||
for &v in values {
|
||||
if v.is_nan() {
|
||||
nan += 1;
|
||||
} else if v == f32::INFINITY {
|
||||
@@ -1104,9 +1126,50 @@ impl CandleHarness {
|
||||
"chat_completion: starting"
|
||||
);
|
||||
|
||||
let arch_arc = Arc::clone(&loaded.arch);
|
||||
// Routing: CUDA loads go through the per-device worker
|
||||
// thread (introduced in Phase 1; forward/clear added in
|
||||
// Phase 2). CPU loads keep the existing spawn_blocking
|
||||
// path because there's no context to own and the channel
|
||||
// round-trip would only add latency. The two arms produce
|
||||
// the same `(Vec<u32>, String)` shape so the rest of the
|
||||
// path is shared.
|
||||
let (generated_ids, finish_reason) = if let (Some(worker), Some(handle)) =
|
||||
(loaded.worker.as_ref(), loaded.arch_handle)
|
||||
{
|
||||
// Worker path (CUDA).
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
match run_inference_via_worker(
|
||||
worker,
|
||||
handle,
|
||||
&prompt_tokens,
|
||||
max_new,
|
||||
temperature,
|
||||
top_p,
|
||||
seed,
|
||||
eos_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
loaded.poisoned.store(true, Ordering::Release);
|
||||
return Err(InferenceError::Other(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
// Can't happen: `loaded.worker` is only Some on
|
||||
// CUDA builds. The dead branch keeps the no-cuda
|
||||
// build well-typed.
|
||||
let _ = (worker, handle);
|
||||
unreachable!("worker handle present without cuda feature");
|
||||
}
|
||||
} else if let Some(arch_arc) = loaded.arch.clone() {
|
||||
// CPU path: existing spawn_blocking on the local
|
||||
// Arc<Mutex<ModelArch>>.
|
||||
let device = loaded.device.clone();
|
||||
|
||||
let inference_result =
|
||||
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
|
||||
let mut guard = arch_arc.blocking_lock();
|
||||
@@ -1128,7 +1191,7 @@ impl CandleHarness {
|
||||
// device-poisoning event. The terminal log at the bottom
|
||||
// of the wrapper reports the error; this flag stops the
|
||||
// NEXT request from going down the same path.
|
||||
let (generated_ids, finish_reason) = match inference_result {
|
||||
match inference_result {
|
||||
Ok(Ok(v)) => v,
|
||||
Ok(Err(e)) => {
|
||||
loaded.poisoned.store(true, Ordering::Release);
|
||||
@@ -1140,6 +1203,13 @@ impl CandleHarness {
|
||||
"inference task panicked: {e}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// LoadedModel invariant: exactly one of `worker` /
|
||||
// `arch` is Some. Reaching here is a construction bug.
|
||||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||||
"LoadedModel has neither worker handle nor local arch — load-path bug"
|
||||
)));
|
||||
};
|
||||
|
||||
let completion_text = loaded
|
||||
@@ -1244,7 +1314,6 @@ impl CandleHarness {
|
||||
.token_to_id("<|im_end|>")
|
||||
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
|
||||
|
||||
let arch_arc = Arc::clone(&loaded.arch);
|
||||
let device = loaded.device.clone();
|
||||
let tokenizer = loaded.tokenizer.clone();
|
||||
let model_id = request.model.clone();
|
||||
@@ -1315,6 +1384,57 @@ impl CandleHarness {
|
||||
"chat_completion (stream): starting"
|
||||
);
|
||||
}
|
||||
// Routing parallel to the non-streaming chat_completion: CUDA
|
||||
// goes through the worker (async task), CPU keeps the
|
||||
// spawn_blocking + Arc<Mutex<ModelArch>> path.
|
||||
if let (Some(worker), Some(handle)) = (loaded.worker.clone(), loaded.arch_handle) {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
let prompt_tokens = prompt_tokens.clone();
|
||||
tokio::spawn(
|
||||
async move {
|
||||
match stream_inference_via_worker(
|
||||
worker,
|
||||
handle,
|
||||
tokenizer,
|
||||
prompt_tokens,
|
||||
max_new,
|
||||
temperature,
|
||||
top_p,
|
||||
seed,
|
||||
eos_id,
|
||||
id,
|
||||
created,
|
||||
model_id,
|
||||
tx,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_finish_reason) => tracing::info!(
|
||||
prompt_tokens = prompt_len,
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"chat_completion (stream): done"
|
||||
),
|
||||
Err(e) => {
|
||||
loaded_for_task.poisoned.store(true, Ordering::Release);
|
||||
tracing::error!(
|
||||
error = %format!("{e:#}"),
|
||||
prompt_tokens = prompt_len,
|
||||
total_ms = req_start.elapsed().as_millis(),
|
||||
"chat_completion (stream): failed, model marked poisoned"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(span_for_task),
|
||||
);
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
let _ = (worker, handle, span_for_task);
|
||||
unreachable!("worker handle present without cuda feature");
|
||||
}
|
||||
} else if let Some(arch_arc) = loaded.arch.clone() {
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let _g = span_for_task.enter();
|
||||
let mut guard = arch_arc.blocking_lock();
|
||||
@@ -1349,6 +1469,11 @@ impl CandleHarness {
|
||||
}
|
||||
}
|
||||
});
|
||||
} else {
|
||||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||||
"LoadedModel has neither worker handle nor local arch — load-path bug"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(rx)
|
||||
}
|
||||
@@ -1432,22 +1557,37 @@ impl Harness for CandleHarness {
|
||||
|
||||
// Worker thread for the chosen device. CPU loads (CUDA
|
||||
// unavailable / not requested) skip the worker — there's no
|
||||
// context to own.
|
||||
let worker = match &device {
|
||||
// context to own. For CUDA loads, the arch is transferred
|
||||
// into the worker's slab now so the inference path can
|
||||
// reference it via the returned `ArchHandle`. The explicit
|
||||
// type annotation lets the no-cuda build resolve `None` to
|
||||
// the right `Option<Arc<DeviceWorkerHandle>>` type.
|
||||
let worker: Option<Arc<super::device_worker::DeviceWorkerHandle>> = match &device {
|
||||
#[cfg(feature = "cuda")]
|
||||
Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?),
|
||||
_ => None,
|
||||
};
|
||||
let (arch_local, arch_handle) = match &worker {
|
||||
Some(w) => {
|
||||
let handle = w
|
||||
.transfer_in(Box::new(arch))
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("transfer arch into device worker: {e}"))?;
|
||||
(None, Some(handle))
|
||||
}
|
||||
None => (Some(Arc::new(Mutex::new(arch))), None),
|
||||
};
|
||||
|
||||
let loaded = Arc::new(LoadedModel {
|
||||
model_id: spec.model_id.clone(),
|
||||
arch: Arc::new(Mutex::new(arch)),
|
||||
arch: arch_local,
|
||||
tokenizer,
|
||||
device,
|
||||
quant: spec.quant.clone(),
|
||||
devices,
|
||||
poisoned: AtomicBool::new(false),
|
||||
worker,
|
||||
arch_handle,
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
@@ -1465,13 +1605,26 @@ impl Harness for CandleHarness {
|
||||
anyhow::bail!("model '{model_id}' not loaded");
|
||||
};
|
||||
// Single-GPU drops are immediate — the LoadedModel goes out of
|
||||
// scope with the Arc and candle frees VRAM. TP unloads also
|
||||
// need to tell every worker to drop its shard before the pool
|
||||
// itself is dropped (otherwise the workers keep their shards
|
||||
// around until Shutdown, which is wasteful and would surface
|
||||
// as VRAM not freed promptly).
|
||||
// scope with the Arc and candle frees VRAM. CUDA loads also
|
||||
// ship a `Job::DropArch` to the device worker so the boxed
|
||||
// `ModelArch` releases its CUDA allocations on the right
|
||||
// thread (with the bound context); without that, the Drop
|
||||
// would run on whatever tokio thread happens to be holding
|
||||
// the last `Arc<LoadedModel>` clone when this fn returns.
|
||||
// TP unloads further coordinate the subprocess pool below.
|
||||
match handle {
|
||||
LoadedHandle::Single(_) => {}
|
||||
LoadedHandle::Single(single) => {
|
||||
if let (Some(worker), Some(arch_handle)) =
|
||||
(single.worker.as_ref(), single.arch_handle)
|
||||
&& let Err(e) = worker.drop_arch(arch_handle).await
|
||||
{
|
||||
tracing::warn!(
|
||||
model = %model_id,
|
||||
error = %e,
|
||||
"single-GPU unload: DropArch RPC failed (model state may leak in worker slab)"
|
||||
);
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
LoadedHandle::Tp(tp) => {
|
||||
// Try to recover the inner TpLoadedModel so we can move
|
||||
@@ -2276,6 +2429,239 @@ fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
|
||||
prompt
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Run the full single-GPU inference loop via the device worker.
|
||||
///
|
||||
/// Mirrors `run_inference`'s logic but routes each forward step
|
||||
/// through `worker.forward_logits()` (returns CPU-side `Vec<f32>`)
|
||||
/// and runs `apply_repeat_penalty` + sampling on a CPU candle tensor.
|
||||
/// The device-resident logits tensor never escapes the worker thread.
|
||||
///
|
||||
/// Used by the CUDA path of `chat_completion`. The CPU path keeps
|
||||
/// `run_inference` (spawn_blocking against `Arc<Mutex<ModelArch>>`)
|
||||
/// because there's no CUDA context to own and the worker indirection
|
||||
/// would only add channel overhead with no diagnostic benefit.
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn run_inference_via_worker(
|
||||
worker: &super::device_worker::DeviceWorkerHandle,
|
||||
handle: super::device_worker::ArchHandle,
|
||||
prompt_tokens: &[u32],
|
||||
max_new: usize,
|
||||
temperature: f64,
|
||||
top_p: Option<f64>,
|
||||
seed: u64,
|
||||
eos_id: Option<u32>,
|
||||
) -> Result<(Vec<u32>, String)> {
|
||||
let mut logits_processor = {
|
||||
let sampling = if temperature <= 0.0 {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match top_p {
|
||||
Some(p) => Sampling::TopP { p, temperature },
|
||||
None => Sampling::All { temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(seed, sampling)
|
||||
};
|
||||
|
||||
let mut generated: Vec<u32> = Vec::new();
|
||||
let prompt_len = prompt_tokens.len();
|
||||
|
||||
worker
|
||||
.clear_kv_cache(handle)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
|
||||
|
||||
// Prefill — every rank embeds the prompt with offset 0.
|
||||
let logits_vec = worker
|
||||
.forward_logits(handle, prompt_tokens.to_vec(), 0)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
|
||||
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
|
||||
let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let health = logits_health_slice(&logits_vec);
|
||||
tracing::warn!(
|
||||
?health,
|
||||
"chat_completion (worker): prefill sample failed; logits unhealthy"
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
if Some(next_token) == eos_id {
|
||||
return Ok((generated, "stop".into()));
|
||||
}
|
||||
generated.push(next_token);
|
||||
|
||||
for index in 0..max_new.saturating_sub(1) {
|
||||
let logits_vec = worker
|
||||
.forward_logits(handle, vec![next_token], prompt_len + index)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
|
||||
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
|
||||
next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let health = logits_health_slice(&logits_vec);
|
||||
tracing::warn!(
|
||||
step = index,
|
||||
?health,
|
||||
"chat_completion (worker): decode sample failed; logits unhealthy"
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
if Some(next_token) == eos_id {
|
||||
return Ok((generated, "stop".into()));
|
||||
}
|
||||
generated.push(next_token);
|
||||
}
|
||||
|
||||
Ok((generated, "length".into()))
|
||||
}
|
||||
|
||||
/// Streaming counterpart of [`run_inference_via_worker`]. Emits one
|
||||
/// `ChatCompletionChunk` per generated token via `tx`; routes every
|
||||
/// forward step through `worker.forward_logits()`. Same per-step
|
||||
/// CPU-side sampling discipline — no device tensor escapes the
|
||||
/// worker thread.
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn stream_inference_via_worker(
|
||||
worker: Arc<super::device_worker::DeviceWorkerHandle>,
|
||||
handle: super::device_worker::ArchHandle,
|
||||
tokenizer: Tokenizer,
|
||||
prompt_tokens: Vec<u32>,
|
||||
max_new: usize,
|
||||
temperature: f64,
|
||||
top_p: Option<f64>,
|
||||
seed: u64,
|
||||
eos_id: Option<u32>,
|
||||
id: String,
|
||||
created: u64,
|
||||
model_id: String,
|
||||
tx: mpsc::Sender<ChatCompletionChunk>,
|
||||
) -> Result<String> {
|
||||
let mut logits_processor = {
|
||||
let sampling = if temperature <= 0.0 {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match top_p {
|
||||
Some(p) => Sampling::TopP { p, temperature },
|
||||
None => Sampling::All { temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(seed, sampling)
|
||||
};
|
||||
|
||||
let mut all_tokens: Vec<u32> = Vec::new();
|
||||
let mut decoded_prefix = String::new();
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut finish_reason = "length".to_string();
|
||||
|
||||
worker
|
||||
.clear_kv_cache(handle)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
|
||||
|
||||
let logits_vec = worker
|
||||
.forward_logits(handle, prompt_tokens, 0)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
|
||||
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
|
||||
let mut next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let health = logits_health_slice(&logits_vec);
|
||||
tracing::warn!(
|
||||
?health,
|
||||
"chat_completion (stream/worker): prefill sample failed; logits unhealthy"
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
if Some(next_token) == eos_id {
|
||||
finish_reason = "stop".into();
|
||||
} else {
|
||||
all_tokens.push(next_token);
|
||||
if !emit_chunk(
|
||||
&all_tokens,
|
||||
&mut decoded_prefix,
|
||||
&tokenizer,
|
||||
&tx,
|
||||
&id,
|
||||
created,
|
||||
&model_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
return Ok(finish_reason); // Client gone — clean stream end.
|
||||
}
|
||||
|
||||
for index in 0..max_new.saturating_sub(1) {
|
||||
let logits_vec = worker
|
||||
.forward_logits(handle, vec![next_token], prompt_len + index)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
|
||||
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
|
||||
next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let health = logits_health_slice(&logits_vec);
|
||||
tracing::warn!(
|
||||
step = index,
|
||||
?health,
|
||||
"chat_completion (stream/worker): decode sample failed; logits unhealthy"
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
if Some(next_token) == eos_id {
|
||||
finish_reason = "stop".into();
|
||||
break;
|
||||
}
|
||||
all_tokens.push(next_token);
|
||||
if !emit_chunk(
|
||||
&all_tokens,
|
||||
&mut decoded_prefix,
|
||||
&tokenizer,
|
||||
&tx,
|
||||
&id,
|
||||
created,
|
||||
&model_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
return Ok(finish_reason);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final chunk carrying finish_reason. Matches the run_inference_streaming
|
||||
// shape so the SSE consumer sees an identical termination sequence.
|
||||
let final_chunk = ChatCompletionChunk {
|
||||
id: id.clone(),
|
||||
object: "chat.completion.chunk".into(),
|
||||
created,
|
||||
model: model_id.clone(),
|
||||
choices: vec![ChunkChoice {
|
||||
index: 0,
|
||||
delta: serde_json::Value::Object(Default::default()),
|
||||
finish_reason: Some(finish_reason.clone()),
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
}],
|
||||
usage: None,
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
};
|
||||
let _ = tx.send(final_chunk).await;
|
||||
|
||||
Ok(finish_reason)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_inference(
|
||||
arch: &mut ModelArch,
|
||||
|
||||
@@ -7,20 +7,41 @@
|
||||
//! one-line `oneshot::Sender::send` call to ship the reply back, which
|
||||
//! is non-blocking.
|
||||
//!
|
||||
//! Phase 1 handles only `QueryVram` and `Shutdown`. Later phases add
|
||||
//! Forward, ClearKv, NCCL, and load handlers as separate match arms.
|
||||
//! Phase 2 handles QueryVram, TransferIn, DropArch, ClearKv,
|
||||
//! ForwardLogits, Shutdown. Phase 3 will add the TP variants
|
||||
//! (NcclInit, NcclSanity, TpLoadShard, TpForward, TpClearKv) and the
|
||||
//! ARCH model state in this state slab will gain a companion
|
||||
//! `tp_models: HashMap<TpHandle, Box<TpLeaderModel>>`.
|
||||
|
||||
use crate::harness::device_worker::jobs::Job;
|
||||
use crate::harness::candle::ModelArch;
|
||||
use crate::harness::device_worker::jobs::{ArchHandle, Job};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::mpsc::Receiver;
|
||||
|
||||
/// Per-thread state owned by the worker. On CUDA builds the `Arc<CudaContext>`
|
||||
/// is created and bound at thread startup; on CPU builds the struct is
|
||||
/// empty save for the device index (kept for log clarity).
|
||||
#[cfg(feature = "cuda")]
|
||||
/// is created and bound at thread startup; on CPU builds the struct
|
||||
/// is mostly empty.
|
||||
struct DeviceWorkerState {
|
||||
#[allow(dead_code)]
|
||||
device_index: u32,
|
||||
/// Candle `Device` constructed at startup. Used by handlers (e.g.
|
||||
/// `ForwardLogits`) to build input tensors against the right
|
||||
/// device. Falls back to `Device::Cpu` if CUDA init fails.
|
||||
device: candle_core::Device,
|
||||
/// Boxed `ModelArch` slab. Indexed by an opaque `ArchHandle` minted
|
||||
/// by `TransferIn`. The Box means the entry's address is stable
|
||||
/// across HashMap rehashes (relevant only when we later hand out
|
||||
/// `&mut ModelArch` references — for Phase 2 every handler runs
|
||||
/// `&mut` via `get_mut`, no long-lived borrows).
|
||||
models: HashMap<ArchHandle, Box<ModelArch>>,
|
||||
/// Counter for minting fresh `ArchHandle`s. Each `TransferIn`
|
||||
/// increments and returns the new value. Wraps at u64::MAX after
|
||||
/// ~10^19 model loads — not a practical concern.
|
||||
next_handle: u64,
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(dead_code)]
|
||||
/// `None` only if `CudaContext::new()` failed — in that case the
|
||||
/// thread still runs so the handle's lifecycle stays uniform, but
|
||||
/// every job that touches CUDA falls through to a zero reply with
|
||||
@@ -28,18 +49,12 @@ struct DeviceWorkerState {
|
||||
ctx: Option<Arc<candle_core::cuda::cudarc::driver::CudaContext>>,
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
#[allow(dead_code)]
|
||||
struct DeviceWorkerState {
|
||||
device_index: u32,
|
||||
}
|
||||
|
||||
/// Worker thread entry point. Runs until `Job::Shutdown` arrives or
|
||||
/// the channel sender is dropped (which happens when the last
|
||||
/// `DeviceWorkerHandle` `Arc` is dropped without an explicit
|
||||
/// `shutdown()`).
|
||||
pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool>) {
|
||||
let state = init_state(device_index);
|
||||
let mut state = init_state(device_index);
|
||||
tracing::info!(device_index, "device worker started");
|
||||
|
||||
while let Ok(job) = rx.recv() {
|
||||
@@ -51,7 +66,7 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
||||
}
|
||||
if poisoned.load(Ordering::Acquire) {
|
||||
// Drain-only mode: reply with a poisoned error without
|
||||
// touching CUDA. Phase 1 never sets the flag from the
|
||||
// touching CUDA. Phase 1/2 never set the flag from the
|
||||
// dispatch loop itself (no driver errors classified yet),
|
||||
// but tests use `DeviceWorkerHandle::set_poisoned()` to
|
||||
// simulate this state.
|
||||
@@ -66,58 +81,130 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
||||
// discard the reply.
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
Job::TransferIn { arch, reply } => {
|
||||
let handle = ArchHandle(state.next_handle);
|
||||
state.next_handle = state.next_handle.wrapping_add(1);
|
||||
state.models.insert(handle, arch);
|
||||
tracing::debug!(
|
||||
device_index,
|
||||
handle = handle.0,
|
||||
slab_size = state.models.len(),
|
||||
"device worker: model transferred in"
|
||||
);
|
||||
let _ = reply.send(Ok(handle));
|
||||
}
|
||||
Job::DropArch { handle, reply } => {
|
||||
let removed = state.models.remove(&handle);
|
||||
let was_present = removed.is_some();
|
||||
// Explicit drop on this thread — runs the Box<ModelArch>
|
||||
// Drop with the CUDA context bound here, which frees
|
||||
// all device tensors on the right context. The Drop is
|
||||
// implicit on the `removed` value going out of scope at
|
||||
// the end of the arm; calling drop() explicitly just
|
||||
// makes the intent visible.
|
||||
drop(removed);
|
||||
tracing::debug!(
|
||||
device_index,
|
||||
handle = handle.0,
|
||||
was_present,
|
||||
slab_size = state.models.len(),
|
||||
"device worker: model dropped"
|
||||
);
|
||||
let _ = reply.send(());
|
||||
}
|
||||
Job::ClearKv { handle, reply } => {
|
||||
let result = match state.models.get_mut(&handle) {
|
||||
Some(arch) => arch.clear_kv_cache(),
|
||||
None => Err(anyhow::anyhow!("ClearKv: no model for handle {}", handle.0)),
|
||||
};
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
Job::ForwardLogits {
|
||||
handle,
|
||||
tokens,
|
||||
offset,
|
||||
reply,
|
||||
} => {
|
||||
let result = forward_logits(&mut state, handle, &tokens, offset);
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
// Handled by the matches!() check above; reaching here
|
||||
// means a Shutdown slipped past which is a bug.
|
||||
Job::Shutdown => unreachable!("Shutdown should break above"),
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(device_index, "device worker exiting");
|
||||
tracing::info!(
|
||||
device_index,
|
||||
slab_size = state.models.len(),
|
||||
"device worker exiting; dropping remaining models"
|
||||
);
|
||||
// Drops every model in the slab on this thread before the function
|
||||
// returns. Critical for CUDA tensors: dropping on a thread that
|
||||
// doesn't have the context bound is UB. Phase 2 still runs Drop
|
||||
// via the slab going out of scope, which is correct as long as no
|
||||
// pre-poisoned state lurks in here — see the poisoned-mode
|
||||
// semantics in mod.rs for the Phase 3+ refinement.
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn init_state(device_index: u32) -> DeviceWorkerState {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
use candle_core::cuda::cudarc::driver::CudaContext;
|
||||
match CudaContext::new(device_index as usize) {
|
||||
// Construct a candle Device first — cudarc returns the
|
||||
// primary context for this index on subsequent calls, so
|
||||
// CudaContext::new and Device::new_cuda end up sharing state.
|
||||
let (device, ctx) = match candle_core::Device::new_cuda(device_index as usize) {
|
||||
Ok(device) => match CudaContext::new(device_index as usize) {
|
||||
Ok(ctx) => {
|
||||
// Make sure the context is current on this thread. cudarc
|
||||
// is generally fine with lazy binding, but doing it once
|
||||
// here gives us a deterministic moment to log "context
|
||||
// bound" — and makes `mem_get_info()` work without further
|
||||
// bind dances inside the dispatch handlers.
|
||||
if let Err(e) = ctx.bind_to_thread() {
|
||||
tracing::warn!(
|
||||
device_index,
|
||||
error = ?e,
|
||||
"device worker: bind_to_thread failed; \
|
||||
vram queries will still rebind per-call"
|
||||
operations will still rebind per-call"
|
||||
);
|
||||
} else {
|
||||
tracing::info!(device_index, "device worker bound CUDA context");
|
||||
}
|
||||
DeviceWorkerState {
|
||||
device_index,
|
||||
ctx: Some(ctx),
|
||||
}
|
||||
(device, Some(ctx))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
device_index,
|
||||
error = ?e,
|
||||
"device worker: CudaContext::new failed; \
|
||||
vram queries will return (0, 0)"
|
||||
vram queries will return (0, 0), forward will error"
|
||||
);
|
||||
(device, None)
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
device_index,
|
||||
error = %e,
|
||||
"device worker: Device::new_cuda failed; falling back to CPU device"
|
||||
);
|
||||
(candle_core::Device::Cpu, None)
|
||||
}
|
||||
};
|
||||
DeviceWorkerState {
|
||||
device_index,
|
||||
ctx: None,
|
||||
device,
|
||||
models: HashMap::new(),
|
||||
next_handle: 1,
|
||||
ctx,
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
DeviceWorkerState {
|
||||
device_index,
|
||||
device: candle_core::Device::Cpu,
|
||||
models: HashMap::new(),
|
||||
next_handle: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn init_state(device_index: u32) -> DeviceWorkerState {
|
||||
DeviceWorkerState { device_index }
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
@@ -144,6 +231,42 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
|
||||
Ok((0, 0))
|
||||
}
|
||||
|
||||
/// Forward step + copy the `[vocab]` logits to a CPU `Vec<f32>` ready
|
||||
/// for sampling on the async caller. The model's `device()` (CUDA or
|
||||
/// CPU) determines where the kernel runs; this fn doesn't care.
|
||||
///
|
||||
/// On CUDA, the `to_dtype(F32).flatten_all().to_vec1::<f32>()` chain
|
||||
/// triggers the device → host copy. The copy runs synchronously on
|
||||
/// this worker thread; the bound context owns the source allocation
|
||||
/// so the transfer is straightforward.
|
||||
fn forward_logits(
|
||||
state: &mut DeviceWorkerState,
|
||||
handle: ArchHandle,
|
||||
tokens: &[u32],
|
||||
offset: usize,
|
||||
) -> anyhow::Result<Vec<f32>> {
|
||||
use candle_core::{DType, Tensor};
|
||||
|
||||
// Build the input tensor on the worker's own device. cudarc's
|
||||
// primary-context model means `Device::new_cuda(idx)` shares state
|
||||
// with the `CudaContext` we bound at startup, so this is the same
|
||||
// device the ModelArch was loaded against.
|
||||
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
|
||||
|
||||
let arch = state
|
||||
.models
|
||||
.get_mut(&handle)
|
||||
.ok_or_else(|| anyhow::anyhow!("ForwardLogits: no model for handle {}", handle.0))?;
|
||||
|
||||
let logits = arch.forward(&input, offset)?;
|
||||
// Copy to CPU f32. logits is already `[vocab]` (squeeze_to_vocab
|
||||
// inside ModelArch::forward). The to_dtype handles bf16/f16 →
|
||||
// f32 promotion for the sampler.
|
||||
let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
|
||||
let values = logits.to_vec1::<f32>()?;
|
||||
Ok(values)
|
||||
}
|
||||
|
||||
/// Reply to a job with the poisoned-worker error. Used when the worker
|
||||
/// has flipped into drain-only mode after a CUDA driver error.
|
||||
///
|
||||
@@ -153,11 +276,26 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
|
||||
/// poisoned error so callers never hang waiting for a worker that's
|
||||
/// no longer running CUDA.
|
||||
fn drain_poisoned(job: Job, device_index: u32) {
|
||||
let err = || anyhow::anyhow!("device worker for device {device_index} is poisoned");
|
||||
match job {
|
||||
Job::QueryVram { reply } => {
|
||||
let _ = reply.send(Err(anyhow::anyhow!(
|
||||
"device worker for device {device_index} is poisoned"
|
||||
)));
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::TransferIn { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::DropArch { reply, .. } => {
|
||||
// Drop reply is `()` — no error path. Send the unit so the
|
||||
// caller's await resolves; the model handle is leaked in
|
||||
// the worker's slab, but the whole slab gets `mem::forget`
|
||||
// on shutdown anyway per the poisoned-thread design.
|
||||
let _ = reply.send(());
|
||||
}
|
||||
Job::ClearKv { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::ForwardLogits { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::Shutdown => {
|
||||
// Filtered by the matches!() guard in run(); reaching
|
||||
|
||||
@@ -4,16 +4,33 @@
|
||||
//! needs plus a `tokio::sync::oneshot::Sender` for the reply. The
|
||||
//! async-side `DeviceWorkerHandle` constructs a job, sends it down the
|
||||
//! `std::sync::mpsc` channel, and `await`s the oneshot for the reply.
|
||||
//!
|
||||
//! Phase 1 includes only `QueryVram` and `Shutdown`. Phases 2–4 add
|
||||
//! forward, kv-cache clear, drop-arch, NCCL init/sanity, and the load
|
||||
//! variants. Each new variant lands as a separate PR so the worker
|
||||
//! thread stays small at every checkpoint.
|
||||
|
||||
use crate::harness::candle::ModelArch;
|
||||
use anyhow::Result;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
/// Opaque handle to a `ModelArch` stored in the worker thread's state
|
||||
/// slab. Cheap to copy; `Send + Sync` so it crosses task boundaries
|
||||
/// freely. The actual `Box<ModelArch>` it points to is owned by the
|
||||
/// worker thread for the duration of the handle's lifetime — the only
|
||||
/// way to drop the model is to send `Job::DropArch { handle }` so the
|
||||
/// `Drop` impl runs on the thread with the bound CUDA context (the
|
||||
/// invariant the whole refactor exists to guarantee).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct ArchHandle(pub u64);
|
||||
|
||||
/// One unit of work for the device worker.
|
||||
///
|
||||
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
|
||||
/// single-GPU inference primitives: transfer-in a freshly-loaded
|
||||
/// `ModelArch`, drop it, clear its KV cache, and run one forward step
|
||||
/// returning CPU-side logits ready for sampling on the async caller.
|
||||
///
|
||||
/// Sampling stays on the async side intentionally. The worker copies
|
||||
/// logits to CPU (`Vec<f32>`) before reply, so the device-resident
|
||||
/// tensor never escapes the worker thread and the async caller's
|
||||
/// `LogitsProcessor::sample` runs entirely on the CPU candle backend
|
||||
/// — no incidental context binding on a tokio worker thread.
|
||||
pub enum Job {
|
||||
/// Query free / total VRAM on the device. Returns
|
||||
/// `(free_mb, total_mb)`. CPU builds and contexts that failed to
|
||||
@@ -22,10 +39,43 @@ pub enum Job {
|
||||
QueryVram {
|
||||
reply: oneshot::Sender<Result<(u64, u64)>>,
|
||||
},
|
||||
/// Tell the worker to break its dispatch loop and exit. The
|
||||
/// channel is then drained — any further jobs already queued get
|
||||
/// dropped (their oneshot senders are dropped, causing the async
|
||||
/// caller's receiver to return `Err` which `DeviceWorkerHandle`
|
||||
/// maps to `WorkerError::Gone`).
|
||||
/// Move a freshly-loaded `ModelArch` into the worker's state slab.
|
||||
/// Returns an `ArchHandle` the caller stores on `LoadedModel` and
|
||||
/// passes back in subsequent `ClearKv` / `ForwardLogits` /
|
||||
/// `DropArch` jobs.
|
||||
TransferIn {
|
||||
arch: Box<ModelArch>,
|
||||
reply: oneshot::Sender<Result<ArchHandle>>,
|
||||
},
|
||||
/// Remove the model from the slab and drop it. The `Drop` runs on
|
||||
/// the worker thread so CUDA tensors release their memory on the
|
||||
/// same context that allocated them.
|
||||
DropArch {
|
||||
handle: ArchHandle,
|
||||
reply: oneshot::Sender<()>,
|
||||
},
|
||||
/// Reset the KV cache for this model. Called at the start of every
|
||||
/// chat completion so a new request doesn't attend over the
|
||||
/// previous one's tokens.
|
||||
ClearKv {
|
||||
handle: ArchHandle,
|
||||
reply: oneshot::Sender<Result<()>>,
|
||||
},
|
||||
/// Run one forward step and copy the resulting `[vocab]` logits to
|
||||
/// CPU. The caller takes the returned `Vec<f32>`, wraps it in a
|
||||
/// CPU `Tensor`, and runs `apply_repeat_penalty` + sampling
|
||||
/// without touching the device context. `offset` is the KV-cache
|
||||
/// position before this step (0 for prefill, `prompt_len + i` for
|
||||
/// the i-th decode step).
|
||||
ForwardLogits {
|
||||
handle: ArchHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||
},
|
||||
/// Tell the worker to break its dispatch loop and exit. Any jobs
|
||||
/// queued after this in the channel reply `Err` to their oneshot
|
||||
/// senders (the senders are dropped on the worker's exit, which
|
||||
/// the async-side `Receiver::await` maps to `WorkerError::Gone`).
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ use std::sync::mpsc::{self, Sender};
|
||||
use std::thread::JoinHandle;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
pub use jobs::Job;
|
||||
pub use jobs::{ArchHandle, Job};
|
||||
|
||||
/// Errors returned by `DeviceWorkerHandle` submit methods.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@@ -159,6 +159,124 @@ impl DeviceWorkerHandle {
|
||||
}
|
||||
}
|
||||
|
||||
/// Move a freshly-loaded `ModelArch` into the worker's state slab.
|
||||
/// Returns the `ArchHandle` the caller stores on `LoadedModel`.
|
||||
/// The `Box<ModelArch>` crosses the channel; the worker thread
|
||||
/// owns it from here on.
|
||||
pub async fn transfer_in(
|
||||
&self,
|
||||
arch: Box<crate::harness::candle::ModelArch>,
|
||||
) -> Result<ArchHandle, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::TransferIn {
|
||||
arch,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Tell the worker to drop the `ModelArch` for `handle` on the
|
||||
/// worker thread (so CUDA tensors release on the right context).
|
||||
/// Returns `Ok(())` even if the handle wasn't in the slab — Drop
|
||||
/// is idempotent. Reports `Gone` if the worker isn't running.
|
||||
pub async fn drop_arch(&self, handle: ArchHandle) -> Result<(), WorkerError> {
|
||||
// Poisoning doesn't block DropArch — even on a poisoned
|
||||
// context we want callers to unblock and proceed with the
|
||||
// unload bookkeeping. The dispatch handler under poison just
|
||||
// replies `()` without touching the model (the actual Drop
|
||||
// happens via mem::forget at thread exit per the poison
|
||||
// protocol).
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::DropArch {
|
||||
handle,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the KV cache for the model at `handle`. Called at the
|
||||
/// start of every chat completion so the new prompt doesn't
|
||||
/// attend over the previous request's tokens.
|
||||
pub async fn clear_kv_cache(&self, handle: ArchHandle) -> Result<(), WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::ClearKv {
|
||||
handle,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run one forward step and return the resulting `[vocab]` logits
|
||||
/// as a CPU-side `Vec<f32>`. The caller then samples on a CPU
|
||||
/// candle Tensor without ever binding the device context on its
|
||||
/// tokio thread.
|
||||
pub async fn forward_logits(
|
||||
&self,
|
||||
handle: ArchHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
) -> Result<Vec<f32>, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::ForwardLogits {
|
||||
handle,
|
||||
tokens,
|
||||
offset,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
|
||||
/// twice is a no-op the second time.
|
||||
pub fn shutdown(&self) -> anyhow::Result<()> {
|
||||
|
||||
Reference in New Issue
Block a user