refactor(neuron): phase 2 — single-GPU forward + clear_kv route through device worker
Some checks failed
build-prerelease / Package helexa-neuron-ada RPM (push) Blocked by required conditions
CI / Format (push) Successful in 34s
CI / Clippy (push) Successful in 2m12s
build-prerelease / Resolve version stamps (push) Successful in 3m41s
CI / Test (push) Successful in 5m1s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m32s
build-prerelease / Build neuron-ampere (push) Successful in 5m20s
build-prerelease / Build cortex binary (push) Successful in 12m20s
build-prerelease / Build neuron-ada (push) Successful in 5m17s
build-prerelease / Package cortex RPM (push) Successful in 1m25s
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
Second slice of the per-device CUDA context-ownership refactor planned at
~/.claude/plans/plan-the-per-device-worker-abstract-micali.md. The two
spawn_blocking sites in `chat_completion` and `chat_completion_stream`
now route through the device worker thread on CUDA loads. CPU loads
keep the existing spawn_blocking + `Arc<Mutex<ModelArch>>` path; there's
no context to own and the channel hop would only add latency.
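The routing rule described above can be sketched as a small decision function. This is a minimal illustration and not the code in this commit: `Route`, `route`, and the unit-struct `ModelArch` are hypothetical stand-ins, and the real `LoadedModel` carries more fields (including the `worker` handle that the actual check also consults).

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical stand-in for the real `ModelArch`.
struct ModelArch;

/// Hypothetical opaque index into the worker's slab.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct ArchHandle(u64);

/// Which execution path a request should take (illustrative only).
#[derive(Debug, PartialEq, Eq)]
enum Route {
    /// CUDA load: send jobs to the device worker thread.
    Worker(ArchHandle),
    /// CPU load: lock the local mutex inside a blocking task.
    LocalBlocking,
    /// Neither or both fields set: a load-path bug.
    Invalid,
}

/// Reduced model record holding only the two mutually exclusive fields.
struct LoadedModel {
    arch: Option<Arc<Mutex<ModelArch>>>,
    arch_handle: Option<ArchHandle>,
}

/// Pick the path from the two fields; anything other than
/// "exactly one is Some" is reported as invalid.
fn route(model: &LoadedModel) -> Route {
    match (model.arch_handle, model.arch.as_ref()) {
        (Some(handle), None) => Route::Worker(handle),
        (None, Some(_)) => Route::LocalBlocking,
        _ => Route::Invalid,
    }
}
```

Modelling the outcome as an enum makes the "neither set" case an explicit value rather than a fall-through, which is the same situation the diff reports as a load-path bug.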
What this phase changes:
- `Job` gains `TransferIn`, `DropArch`, `ClearKv`, `ForwardLogits`. The
worker's dispatch state grows a `HashMap<ArchHandle, Box<ModelArch>>`
slab and a `next_handle` counter for minting opaque handles.
- `LoadedModel.arch: Arc<Mutex<ModelArch>>` → `Option<Arc<Mutex<ModelArch>>>`,
plus a new `arch_handle: Option<ArchHandle>` field. The two are
mutually exclusive: CUDA loads set `arch_handle = Some(_)` after
transferring the boxed arch into the worker's slab; CPU loads keep
`arch = Some(_)` for the legacy spawn_blocking path.
- New `run_inference_via_worker` and `stream_inference_via_worker`
drive the prefill + decode loop by sending `Job::ForwardLogits` per
step; the worker copies the resulting `[vocab]` logits to a
CPU-side `Vec<f32>` before replying, so the async caller never holds a
device-resident tensor. `apply_repeat_penalty` and
`LogitsProcessor::sample` run on a CPU candle tensor, so tokio worker
threads never bind a CUDA context as a side effect.
- `logits_health_slice(&[f32])` complements the existing
`logits_health(&Tensor)` so the new worker paths can compute
health stats directly from the CPU vec.
- `unload_model` for the single-GPU CUDA path now sends
`Job::DropArch { handle }` to the worker so the `Box<ModelArch>`
drops on the worker thread that owns the device's CUDA context. The
`Drop` therefore runs with that context bound, and the memory is freed
against the correct context.
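The slab-and-handle design in the list above can be sketched with plain `std` threads and channels. This is a sketch under stated assumptions, not the worker in this commit: the real one is driven from async code and holds candle tensors, whereas here `ModelArch` is a hypothetical struct whose "forward" only tracks a KV length, and the reply channels and field names on `Job` are invented for illustration.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Sender};
use std::thread;

/// Hypothetical stand-in for `ModelArch`; it only tracks a KV-cache length.
struct ModelArch {
    kv_len: usize,
}

/// Opaque handle minted by the worker; callers never see the boxed arch.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct ArchHandle(u64);

/// Job variants named after the commit message; payload shapes are assumptions.
enum Job {
    TransferIn { arch: Box<ModelArch>, reply: Sender<ArchHandle> },
    ClearKv { handle: ArchHandle, reply: Sender<Result<(), String>> },
    ForwardLogits { handle: ArchHandle, tokens: Vec<u32>, reply: Sender<Result<Vec<f32>, String>> },
    DropArch { handle: ArchHandle, reply: Sender<bool> },
}

/// Spawn one OS thread that owns the slab for its whole lifetime.
fn spawn_worker() -> Sender<Job> {
    let (tx, rx) = channel::<Job>();
    thread::spawn(move || {
        let mut slab: HashMap<ArchHandle, Box<ModelArch>> = HashMap::new();
        let mut next_handle: u64 = 0;
        // The loop ends when every sender has been dropped.
        for job in rx {
            match job {
                Job::TransferIn { arch, reply } => {
                    let handle = ArchHandle(next_handle);
                    next_handle += 1;
                    slab.insert(handle, arch);
                    let _ = reply.send(handle);
                }
                Job::ClearKv { handle, reply } => {
                    let res = match slab.get_mut(&handle) {
                        Some(arch) => {
                            arch.kv_len = 0;
                            Ok(())
                        }
                        None => Err(format!("unknown handle {handle:?}")),
                    };
                    let _ = reply.send(res);
                }
                Job::ForwardLogits { handle, tokens, reply } => {
                    let res = match slab.get_mut(&handle) {
                        Some(arch) => {
                            arch.kv_len += tokens.len();
                            // Placeholder "logits": plain host data, as in the CPU-side copy.
                            Ok(vec![arch.kv_len as f32])
                        }
                        None => Err(format!("unknown handle {handle:?}")),
                    };
                    let _ = reply.send(res);
                }
                Job::DropArch { handle, reply } => {
                    // The Box is dropped here, on the worker thread.
                    let _ = reply.send(slab.remove(&handle).is_some());
                }
            }
        }
    });
    tx
}
```

Because the boxed value only ever lives inside the worker's map, both its use and its `Drop` happen on that one thread; callers hold nothing but a copyable handle and host-side reply data.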
What this phase doesn't touch (yet):
- TP forward, TP load, NCCL bring-up — still on spawn_blocking. Phase 3.
- Single-GPU model load — still spawn_blocking, followed by a
`Job::TransferIn` to move the freshly-built `ModelArch` into the
worker slab. Phase 4 moves the load itself onto the worker thread
and eliminates the bootstrap TransferIn.
- The `device_vram_mb` / `cuda_mem_mb` helpers — still present and
used by the construction-time logs running inside spawn_blocking
loads. Phase 4 cleanup folds them into `dispatch.rs`.
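Returning to the `logits_health_slice` helper from the change list: a slice-based health check of this kind might look roughly like the following. The counter names follow the ones visible in the diff; the `LogitsHealth` fields, `finite_min`, and the mean calculation are assumptions, since the real struct is not shown in this commit.

```rust
/// Hypothetical summary type; the real `LogitsHealth` may carry different fields.
#[derive(Debug, PartialEq)]
struct LogitsHealth {
    nan: usize,
    pos_inf: usize,
    neg_inf: usize,
    finite_min: f32,
    finite_max: f32,
    finite_mean: f64,
}

/// Classify every value once and summarise the finite ones.
/// Works directly on host memory, so no tensor round-trip is needed.
fn logits_health_slice(values: &[f32]) -> LogitsHealth {
    let mut nan = 0usize;
    let mut pos_inf = 0usize;
    let mut neg_inf = 0usize;
    let mut finite_min = f32::INFINITY;
    let mut finite_max = f32::NEG_INFINITY;
    let mut finite_sum = 0.0_f64;
    let mut finite_count = 0usize;
    for &v in values {
        if v.is_nan() {
            nan += 1;
        } else if v == f32::INFINITY {
            pos_inf += 1;
        } else if v == f32::NEG_INFINITY {
            neg_inf += 1;
        } else {
            finite_min = finite_min.min(v);
            finite_max = finite_max.max(v);
            finite_sum += f64::from(v);
            finite_count += 1;
        }
    }
    // Assumption: report 0.0 when there are no finite values at all.
    let finite_mean = if finite_count > 0 {
        finite_sum / finite_count as f64
    } else {
        0.0
    };
    LogitsHealth { nan, pos_inf, neg_inf, finite_min, finite_max, finite_mean }
}
```

Taking `&[f32]` lets the tensor-based variant delegate to the same function after its own host copy, which matches the delegation visible in the diff.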
Public API unchanged. fmt + clippy clean; 37 lib tests + all
integration tests pass.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -102,7 +102,14 @@ impl LoadedHandle {
 /// moved into `spawn_blocking` for synchronous candle forward passes.
 pub struct LoadedModel {
     pub model_id: String,
-    pub arch: Arc<Mutex<ModelArch>>,
+    /// Local (async-side) handle to the model architecture. `Some`
+    /// only when the model loaded onto the CPU device (no CUDA
+    /// available); the inference path then takes this mutex via
+    /// `spawn_blocking` and runs candle ops on the CPU backend.
+    /// `None` when the model loaded onto a CUDA device — in that case
+    /// the architecture lives in the worker thread's slab and is
+    /// addressed via [`Self::arch_handle`].
+    pub arch: Option<Arc<Mutex<ModelArch>>>,
     pub tokenizer: Tokenizer,
     pub device: Device,
     pub quant: Option<String>,
@@ -118,10 +125,15 @@ pub struct LoadedModel {
     pub poisoned: AtomicBool,
     /// Handle to the per-device CUDA worker thread for this model's
     /// device. `None` for CPU loads (no context to own). VRAM queries
-    /// and — in later refactor phases — forward / kv-cache / unload
-    /// ops route through this handle so the device's CUDA context
-    /// stays bound to one OS thread for the daemon's lifetime.
+    /// and — for CUDA loads — forward / kv-cache / drop ops route
+    /// through this handle so the device's CUDA context stays bound
+    /// to one OS thread for the daemon's lifetime.
     pub worker: Option<Arc<super::device_worker::DeviceWorkerHandle>>,
+    /// Index into the worker's `ModelArch` slab. `Some` iff the model
+    /// loaded onto a CUDA device and was successfully transferred to
+    /// the worker; in that case [`Self::arch`] is `None`. The two
+    /// fields are mutually exclusive.
+    pub arch_handle: Option<super::device_worker::ArchHandle>,
 }
 
 impl LoadedModel {
@@ -475,6 +487,16 @@ fn logits_health(t: &Tensor) -> LogitsHealth {
             };
         }
     };
+    logits_health_slice(&values)
+}
+
+/// Same diagnostic as [`logits_health`] but operates directly on a
+/// `[f32]` slice. Used by the worker-routed inference paths where the
+/// device → host copy has already happened on the worker thread and
+/// the async caller has the values in hand. Avoids the round-trip of
+/// rebuilding a Tensor just to call to_vec1 again.
+#[allow(dead_code)]
+fn logits_health_slice(values: &[f32]) -> LogitsHealth {
     let mut nan = 0usize;
     let mut pos_inf = 0usize;
     let mut neg_inf = 0usize;
@@ -483,7 +505,7 @@ fn logits_health(t: &Tensor) -> LogitsHealth {
     let mut finite_max = f32::NEG_INFINITY;
     let mut finite_sum = 0.0_f64;
     let mut finite_count = 0usize;
-    for &v in &values {
+    for &v in values {
         if v.is_nan() {
             nan += 1;
         } else if v == f32::INFINITY {
@@ -1104,15 +1126,22 @@ impl CandleHarness {
             "chat_completion: starting"
         );
 
-        let arch_arc = Arc::clone(&loaded.arch);
-        let device = loaded.device.clone();
-
-        let inference_result =
-            tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
-                let mut guard = arch_arc.blocking_lock();
-                run_inference(
-                    &mut guard,
-                    &device,
+        // Routing: CUDA loads go through the per-device worker
+        // thread (introduced in Phase 1; forward/clear added in
+        // Phase 2). CPU loads keep the existing spawn_blocking
+        // path because there's no context to own and the channel
+        // round-trip would only add latency. The two arms produce
+        // the same `(Vec<u32>, String)` shape so the rest of the
+        // path is shared.
+        let (generated_ids, finish_reason) = if let (Some(worker), Some(handle)) =
+            (loaded.worker.as_ref(), loaded.arch_handle)
+        {
+            // Worker path (CUDA).
+            #[cfg(feature = "cuda")]
+            {
+                match run_inference_via_worker(
+                    worker,
+                    handle,
                     &prompt_tokens,
                     max_new,
                     temperature,
@@ -1120,26 +1149,67 @@ impl CandleHarness {
                     seed,
                     eos_id,
                 )
-            })
-            .await;
+                .await
+                {
+                    Ok(v) => v,
+                    Err(e) => {
+                        loaded.poisoned.store(true, Ordering::Release);
+                        return Err(InferenceError::Other(e));
+                    }
+                }
+            }
+            #[cfg(not(feature = "cuda"))]
+            {
+                // Can't happen: `loaded.worker` is only Some on
+                // CUDA builds. The dead branch keeps the no-cuda
+                // build well-typed.
+                let _ = (worker, handle);
+                unreachable!("worker handle present without cuda feature");
+            }
+        } else if let Some(arch_arc) = loaded.arch.clone() {
+            // CPU path: existing spawn_blocking on the local
+            // Arc<Mutex<ModelArch>>.
+            let device = loaded.device.clone();
+            let inference_result =
+                tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
+                    let mut guard = arch_arc.blocking_lock();
+                    run_inference(
+                        &mut guard,
+                        &device,
+                        &prompt_tokens,
+                        max_new,
+                        temperature,
+                        top_p,
+                        seed,
+                        eos_id,
+                    )
+                })
+                .await;
 
-        // Any failure inside the spawn_blocking touched CUDA via
-        // candle's forward / cache code, so we treat it as a
-        // device-poisoning event. The terminal log at the bottom
-        // of the wrapper reports the error; this flag stops the
-        // NEXT request from going down the same path.
-        let (generated_ids, finish_reason) = match inference_result {
-            Ok(Ok(v)) => v,
-            Ok(Err(e)) => {
-                loaded.poisoned.store(true, Ordering::Release);
-                return Err(InferenceError::Other(e));
-            }
-            Err(e) => {
-                loaded.poisoned.store(true, Ordering::Release);
-                return Err(InferenceError::Other(anyhow::anyhow!(
-                    "inference task panicked: {e}"
-                )));
+            // Any failure inside the spawn_blocking touched CUDA via
+            // candle's forward / cache code, so we treat it as a
+            // device-poisoning event. The terminal log at the bottom
+            // of the wrapper reports the error; this flag stops the
+            // NEXT request from going down the same path.
+            match inference_result {
+                Ok(Ok(v)) => v,
+                Ok(Err(e)) => {
+                    loaded.poisoned.store(true, Ordering::Release);
+                    return Err(InferenceError::Other(e));
+                }
+                Err(e) => {
+                    loaded.poisoned.store(true, Ordering::Release);
+                    return Err(InferenceError::Other(anyhow::anyhow!(
+                        "inference task panicked: {e}"
+                    )));
+                }
             }
+        } else {
+            // LoadedModel invariant: exactly one of `worker` /
+            // `arch` is Some. Reaching here is a construction bug.
+            return Err(InferenceError::Other(anyhow::anyhow!(
+                "LoadedModel has neither worker handle nor local arch — load-path bug"
+            )));
         };
 
         let completion_text = loaded
@@ -1244,7 +1314,6 @@ impl CandleHarness {
             .token_to_id("<|im_end|>")
             .or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
 
-        let arch_arc = Arc::clone(&loaded.arch);
         let device = loaded.device.clone();
         let tokenizer = loaded.tokenizer.clone();
         let model_id = request.model.clone();
@@ -1315,40 +1384,96 @@ impl CandleHarness {
                 "chat_completion (stream): starting"
             );
         }
-        tokio::task::spawn_blocking(move || {
-            let _g = span_for_task.enter();
-            let mut guard = arch_arc.blocking_lock();
-            match run_inference_streaming(
-                &mut guard,
-                &device,
-                &tokenizer,
-                &prompt_tokens,
-                max_new,
-                temperature,
-                top_p,
-                seed,
-                eos_id,
-                &id,
-                created,
-                &model_id,
-                &tx,
-            ) {
-                Ok(()) => tracing::info!(
-                    prompt_tokens = prompt_len,
-                    total_ms = req_start.elapsed().as_millis(),
-                    "chat_completion (stream): done"
-                ),
-                Err(e) => {
-                    loaded_for_task.poisoned.store(true, Ordering::Release);
-                    tracing::error!(
-                        error = %format!("{e:#}"),
+        // Routing parallel to the non-streaming chat_completion: CUDA
+        // goes through the worker (async task), CPU keeps the
+        // spawn_blocking + Arc<Mutex<ModelArch>> path.
+        if let (Some(worker), Some(handle)) = (loaded.worker.clone(), loaded.arch_handle) {
+            #[cfg(feature = "cuda")]
+            {
+                let prompt_tokens = prompt_tokens.clone();
+                tokio::spawn(
+                    async move {
+                        match stream_inference_via_worker(
+                            worker,
+                            handle,
+                            tokenizer,
+                            prompt_tokens,
+                            max_new,
+                            temperature,
+                            top_p,
+                            seed,
+                            eos_id,
+                            id,
+                            created,
+                            model_id,
+                            tx,
+                        )
+                        .await
+                        {
+                            Ok(_finish_reason) => tracing::info!(
+                                prompt_tokens = prompt_len,
+                                total_ms = req_start.elapsed().as_millis(),
+                                "chat_completion (stream): done"
+                            ),
+                            Err(e) => {
+                                loaded_for_task.poisoned.store(true, Ordering::Release);
+                                tracing::error!(
+                                    error = %format!("{e:#}"),
+                                    prompt_tokens = prompt_len,
+                                    total_ms = req_start.elapsed().as_millis(),
+                                    "chat_completion (stream): failed, model marked poisoned"
+                                );
+                            }
+                        }
+                    }
+                    .instrument(span_for_task),
+                );
+            }
+            #[cfg(not(feature = "cuda"))]
+            {
+                let _ = (worker, handle, span_for_task);
+                unreachable!("worker handle present without cuda feature");
+            }
+        } else if let Some(arch_arc) = loaded.arch.clone() {
+            tokio::task::spawn_blocking(move || {
+                let _g = span_for_task.enter();
+                let mut guard = arch_arc.blocking_lock();
+                match run_inference_streaming(
+                    &mut guard,
+                    &device,
+                    &tokenizer,
+                    &prompt_tokens,
+                    max_new,
+                    temperature,
+                    top_p,
+                    seed,
+                    eos_id,
+                    &id,
+                    created,
+                    &model_id,
+                    &tx,
+                ) {
+                    Ok(()) => tracing::info!(
                         prompt_tokens = prompt_len,
                         total_ms = req_start.elapsed().as_millis(),
-                        "chat_completion (stream): failed, model marked poisoned"
-                    );
+                        "chat_completion (stream): done"
+                    ),
+                    Err(e) => {
+                        loaded_for_task.poisoned.store(true, Ordering::Release);
+                        tracing::error!(
+                            error = %format!("{e:#}"),
+                            prompt_tokens = prompt_len,
+                            total_ms = req_start.elapsed().as_millis(),
+                            "chat_completion (stream): failed, model marked poisoned"
+                        );
+                    }
                 }
-            }
-        });
+            });
+        } else {
+            return Err(InferenceError::Other(anyhow::anyhow!(
+                "LoadedModel has neither worker handle nor local arch — load-path bug"
+            )));
+        }
 
         Ok(rx)
     }
@@ -1432,22 +1557,37 @@ impl Harness for CandleHarness {
 
         // Worker thread for the chosen device. CPU loads (CUDA
         // unavailable / not requested) skip the worker — there's no
-        // context to own.
-        let worker = match &device {
+        // context to own. For CUDA loads, the arch is transferred
+        // into the worker's slab now so the inference path can
+        // reference it via the returned `ArchHandle`. The explicit
+        // type annotation lets the no-cuda build resolve `None` to
+        // the right `Option<Arc<DeviceWorkerHandle>>` type.
+        let worker: Option<Arc<super::device_worker::DeviceWorkerHandle>> = match &device {
             #[cfg(feature = "cuda")]
             Device::Cuda(_) => Some(self.ensure_device_worker(devices[0]).await?),
             _ => None,
         };
+        let (arch_local, arch_handle) = match &worker {
+            Some(w) => {
+                let handle = w
+                    .transfer_in(Box::new(arch))
+                    .await
+                    .map_err(|e| anyhow::anyhow!("transfer arch into device worker: {e}"))?;
+                (None, Some(handle))
+            }
+            None => (Some(Arc::new(Mutex::new(arch))), None),
+        };
 
         let loaded = Arc::new(LoadedModel {
             model_id: spec.model_id.clone(),
-            arch: Arc::new(Mutex::new(arch)),
+            arch: arch_local,
             tokenizer,
             device,
             quant: spec.quant.clone(),
             devices,
             poisoned: AtomicBool::new(false),
             worker,
+            arch_handle,
         });
 
         let mut models = self.models.write().await;
@@ -1465,13 +1605,26 @@ impl Harness for CandleHarness {
             anyhow::bail!("model '{model_id}' not loaded");
         };
         // Single-GPU drops are immediate — the LoadedModel goes out of
-        // scope with the Arc and candle frees VRAM. TP unloads also
-        // need to tell every worker to drop its shard before the pool
-        // itself is dropped (otherwise the workers keep their shards
-        // around until Shutdown, which is wasteful and would surface
-        // as VRAM not freed promptly).
+        // scope with the Arc and candle frees VRAM. CUDA loads also
+        // ship a `Job::DropArch` to the device worker so the boxed
+        // `ModelArch` releases its CUDA allocations on the right
+        // thread (with the bound context); without that, the Drop
+        // would run on whatever tokio thread happens to be holding
+        // the last `Arc<LoadedModel>` clone when this fn returns.
+        // TP unloads further coordinate the subprocess pool below.
         match handle {
-            LoadedHandle::Single(_) => {}
+            LoadedHandle::Single(single) => {
+                if let (Some(worker), Some(arch_handle)) =
+                    (single.worker.as_ref(), single.arch_handle)
+                    && let Err(e) = worker.drop_arch(arch_handle).await
+                {
+                    tracing::warn!(
+                        model = %model_id,
+                        error = %e,
+                        "single-GPU unload: DropArch RPC failed (model state may leak in worker slab)"
+                    );
+                }
+            }
             #[cfg(feature = "cuda")]
             LoadedHandle::Tp(tp) => {
                 // Try to recover the inner TpLoadedModel so we can move
@@ -2276,6 +2429,239 @@ fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
     prompt
 }
 
+#[allow(clippy::too_many_arguments)]
+/// Run the full single-GPU inference loop via the device worker.
+///
+/// Mirrors `run_inference`'s logic but routes each forward step
+/// through `worker.forward_logits()` (returns CPU-side `Vec<f32>`)
+/// and runs `apply_repeat_penalty` + sampling on a CPU candle tensor.
+/// The device-resident logits tensor never escapes the worker thread.
+///
+/// Used by the CUDA path of `chat_completion`. The CPU path keeps
+/// `run_inference` (spawn_blocking against `Arc<Mutex<ModelArch>>`)
+/// because there's no CUDA context to own and the worker indirection
+/// would only add channel overhead with no diagnostic benefit.
+#[cfg(feature = "cuda")]
+#[allow(clippy::too_many_arguments)]
+async fn run_inference_via_worker(
+    worker: &super::device_worker::DeviceWorkerHandle,
+    handle: super::device_worker::ArchHandle,
+    prompt_tokens: &[u32],
+    max_new: usize,
+    temperature: f64,
+    top_p: Option<f64>,
+    seed: u64,
+    eos_id: Option<u32>,
+) -> Result<(Vec<u32>, String)> {
+    let mut logits_processor = {
+        let sampling = if temperature <= 0.0 {
+            Sampling::ArgMax
+        } else {
+            match top_p {
+                Some(p) => Sampling::TopP { p, temperature },
+                None => Sampling::All { temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(seed, sampling)
+    };
+
+    let mut generated: Vec<u32> = Vec::new();
+    let prompt_len = prompt_tokens.len();
+
+    worker
+        .clear_kv_cache(handle)
+        .await
+        .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
+
+    // Prefill — every rank embeds the prompt with offset 0.
+    let logits_vec = worker
+        .forward_logits(handle, prompt_tokens.to_vec(), 0)
+        .await
+        .map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
+    let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
+    let mut next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
+        Ok(t) => t,
+        Err(e) => {
+            let health = logits_health_slice(&logits_vec);
+            tracing::warn!(
+                ?health,
+                "chat_completion (worker): prefill sample failed; logits unhealthy"
+            );
+            return Err(e);
+        }
+    };
+
+    if Some(next_token) == eos_id {
+        return Ok((generated, "stop".into()));
+    }
+    generated.push(next_token);
+
+    for index in 0..max_new.saturating_sub(1) {
+        let logits_vec = worker
+            .forward_logits(handle, vec![next_token], prompt_len + index)
+            .await
+            .map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
+        let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
+        next_token = match sample_with_penalty(&logits, &generated, &mut logits_processor) {
+            Ok(t) => t,
+            Err(e) => {
+                let health = logits_health_slice(&logits_vec);
+                tracing::warn!(
+                    step = index,
+                    ?health,
+                    "chat_completion (worker): decode sample failed; logits unhealthy"
+                );
+                return Err(e);
+            }
+        };
+        if Some(next_token) == eos_id {
+            return Ok((generated, "stop".into()));
+        }
+        generated.push(next_token);
+    }
+
+    Ok((generated, "length".into()))
+}
+
+/// Streaming counterpart of [`run_inference_via_worker`]. Emits one
+/// `ChatCompletionChunk` per generated token via `tx`; routes every
+/// forward step through `worker.forward_logits()`. Same per-step
+/// CPU-side sampling discipline — no device tensor escapes the
+/// worker thread.
+#[cfg(feature = "cuda")]
+#[allow(clippy::too_many_arguments)]
+async fn stream_inference_via_worker(
+    worker: Arc<super::device_worker::DeviceWorkerHandle>,
+    handle: super::device_worker::ArchHandle,
+    tokenizer: Tokenizer,
+    prompt_tokens: Vec<u32>,
+    max_new: usize,
+    temperature: f64,
+    top_p: Option<f64>,
+    seed: u64,
+    eos_id: Option<u32>,
+    id: String,
+    created: u64,
+    model_id: String,
+    tx: mpsc::Sender<ChatCompletionChunk>,
+) -> Result<String> {
+    let mut logits_processor = {
+        let sampling = if temperature <= 0.0 {
+            Sampling::ArgMax
+        } else {
+            match top_p {
+                Some(p) => Sampling::TopP { p, temperature },
+                None => Sampling::All { temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(seed, sampling)
+    };
+
+    let mut all_tokens: Vec<u32> = Vec::new();
+    let mut decoded_prefix = String::new();
+    let prompt_len = prompt_tokens.len();
+    let mut finish_reason = "length".to_string();
+
+    worker
+        .clear_kv_cache(handle)
+        .await
+        .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?;
+
+    let logits_vec = worker
+        .forward_logits(handle, prompt_tokens, 0)
+        .await
+        .map_err(|e| anyhow::anyhow!("prefill forward: {e}"))?;
+    let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
+    let mut next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
+        Ok(t) => t,
+        Err(e) => {
+            let health = logits_health_slice(&logits_vec);
+            tracing::warn!(
+                ?health,
+                "chat_completion (stream/worker): prefill sample failed; logits unhealthy"
+            );
+            return Err(e);
+        }
+    };
+
+    if Some(next_token) == eos_id {
+        finish_reason = "stop".into();
+    } else {
+        all_tokens.push(next_token);
+        if !emit_chunk(
+            &all_tokens,
+            &mut decoded_prefix,
+            &tokenizer,
+            &tx,
+            &id,
+            created,
+            &model_id,
+        )
+        .await
+        {
+            return Ok(finish_reason); // Client gone — clean stream end.
+        }
+
+        for index in 0..max_new.saturating_sub(1) {
+            let logits_vec = worker
+                .forward_logits(handle, vec![next_token], prompt_len + index)
+                .await
+                .map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
+            let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
+            next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
+                Ok(t) => t,
+                Err(e) => {
+                    let health = logits_health_slice(&logits_vec);
+                    tracing::warn!(
+                        step = index,
+                        ?health,
+                        "chat_completion (stream/worker): decode sample failed; logits unhealthy"
+                    );
+                    return Err(e);
+                }
+            };
+            if Some(next_token) == eos_id {
+                finish_reason = "stop".into();
+                break;
+            }
+            all_tokens.push(next_token);
+            if !emit_chunk(
+                &all_tokens,
+                &mut decoded_prefix,
+                &tokenizer,
+                &tx,
+                &id,
+                created,
+                &model_id,
+            )
+            .await
+            {
+                return Ok(finish_reason);
+            }
+        }
+    }
+
+    // Final chunk carrying finish_reason. Matches the run_inference_streaming
+    // shape so the SSE consumer sees an identical termination sequence.
+    let final_chunk = ChatCompletionChunk {
+        id: id.clone(),
+        object: "chat.completion.chunk".into(),
+        created,
+        model: model_id.clone(),
+        choices: vec![ChunkChoice {
+            index: 0,
+            delta: serde_json::Value::Object(Default::default()),
+            finish_reason: Some(finish_reason.clone()),
+            extra: serde_json::Value::Object(Default::default()),
+        }],
+        usage: None,
+        extra: serde_json::Value::Object(Default::default()),
+    };
+    let _ = tx.send(final_chunk).await;
+
+    Ok(finish_reason)
+}
+
 #[allow(clippy::too_many_arguments)]
 fn run_inference(
     arch: &mut ModelArch,