fix(neuron): serialise single-GPU inference per loaded model
Two concurrent chat_completion requests against the same single-GPU model could interleave their `clear_kv_cache → forward(chunk0) → forward(chunk1) → ...` sequences. The device-worker channel serialises individual jobs but not the sequence boundary, so the cache could end up holding tokens from one request while another's mask was sized for its own prompt — producing a shape mismatch mid-prefill.

Observed on benjy 2026-05-27 18:41:05: agent-zero's `memorize memories` and `memorize solutions` extensions fired 4ms apart against Qwen/Qwen3-8B (a0's utility model). Both prefilled into the same KV cache, and request a08b4a's chunk 0 forward produced scores of shape [1, 32, 512, 1024] against a mask of [1, 1, 512, 512] — broadcast_add failed, both requests bubbled the error up, both flipped the model to poisoned.

Add `LoadedModel.inference_lock: tokio::sync::Mutex<()>`, mirroring the TpLoadedModel.pool lock that the TP path already held. Acquire it at the start of `chat_completion` and inside the spawned task of `chat_completion_stream` (so the role chunk goes out immediately and only the inference work queues behind the lock). The CPU branch uses `blocking_lock` from inside spawn_blocking; the CUDA branch uses async `.lock().await` inside tokio::spawn.

Throughput impact: zero. The GPU was already serialised at the device-worker channel — multiple requests just produced corrupt KV cache state instead of clean serial throughput. The lock makes the existing serialisation honest.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -134,6 +134,19 @@ pub struct LoadedModel {
|
|||||||
/// the worker; in that case [`Self::arch`] is `None`. The two
|
/// the worker; in that case [`Self::arch`] is `None`. The two
|
||||||
/// fields are mutually exclusive.
|
/// fields are mutually exclusive.
|
||||||
pub arch_handle: Option<super::device_worker::ArchHandle>,
|
pub arch_handle: Option<super::device_worker::ArchHandle>,
|
||||||
|
/// Serialises chat-completion requests against this model. Held
|
||||||
|
/// from the start of `clear_kv_cache` through the last decode
|
||||||
|
/// step, so concurrent requests can't interleave their KV-cache
|
||||||
|
/// mutations. Without this, two requests' chunked-prefill
|
||||||
|
/// `clear → forward(chunk0) → forward(chunk1) → ...` sequences
|
||||||
|
/// could end up sharing a cache between them — the device worker
|
||||||
|
/// channel serialises individual jobs, but not the sequence
|
||||||
|
/// boundary. Observed on benjy 2026-05-27 18:41 when agent-zero's
|
||||||
|
/// memorize extensions fired in parallel and produced a
|
||||||
|
/// shape-mismatch failure mid-prefill. Mirrors TpLoadedModel.pool
|
||||||
|
/// for the TP path (which already had this invariant by accident
|
||||||
|
/// because the pool lock covered the same window).
|
||||||
|
pub inference_lock: tokio::sync::Mutex<()>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LoadedModel {
|
impl LoadedModel {
|
||||||
@@ -1314,6 +1327,13 @@ impl CandleHarness {
|
|||||||
return Err(poisoned_error(&model_id));
|
return Err(poisoned_error(&model_id));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Serialise concurrent requests against this model. Holds for
|
||||||
|
// the duration of clear_kv_cache → prefill → decode so two
|
||||||
|
// requests' chunked-prefill sequences can't interleave on the
|
||||||
|
// shared KV cache (see `LoadedModel.inference_lock` for the
|
||||||
|
// observed failure mode).
|
||||||
|
let _inference_guard = loaded.inference_lock.lock().await;
|
||||||
|
|
||||||
let result = async {
|
let result = async {
|
||||||
let prompt = format_qwen3_prompt(&request.messages);
|
let prompt = format_qwen3_prompt(&request.messages);
|
||||||
|
|
||||||
@@ -1624,13 +1644,21 @@ impl CandleHarness {
|
|||||||
|
|
||||||
// Routing parallel to the non-streaming chat_completion: CUDA
|
// Routing parallel to the non-streaming chat_completion: CUDA
|
||||||
// goes through the worker (async task), CPU keeps the
|
// goes through the worker (async task), CPU keeps the
|
||||||
// spawn_blocking + Arc<Mutex<ModelArch>> path.
|
// spawn_blocking + Arc<Mutex<ModelArch>> path. Both branches
|
||||||
|
// acquire `loaded.inference_lock` from inside the spawned
|
||||||
|
// task so concurrent stream requests against the same model
|
||||||
|
// serialise at the request boundary (preventing the
|
||||||
|
// chunked-prefill KV-cache interleave failure mode). The
|
||||||
|
// role chunk was already sent above, so the client sees
|
||||||
|
// immediate "stream open" feedback even when this request
|
||||||
|
// queues behind another for the lock.
|
||||||
if let (Some(worker), Some(handle)) = (loaded.worker.clone(), loaded.arch_handle) {
|
if let (Some(worker), Some(handle)) = (loaded.worker.clone(), loaded.arch_handle) {
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
{
|
{
|
||||||
let prompt_tokens = prompt_tokens.clone();
|
let prompt_tokens = prompt_tokens.clone();
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
async move {
|
async move {
|
||||||
|
let _inference_guard = loaded_for_task.inference_lock.lock().await;
|
||||||
match stream_inference_via_worker(
|
match stream_inference_via_worker(
|
||||||
worker,
|
worker,
|
||||||
handle,
|
handle,
|
||||||
@@ -1675,6 +1703,10 @@ impl CandleHarness {
|
|||||||
} else if let Some(arch_arc) = loaded.arch.clone() {
|
} else if let Some(arch_arc) = loaded.arch.clone() {
|
||||||
tokio::task::spawn_blocking(move || {
|
tokio::task::spawn_blocking(move || {
|
||||||
let _g = span_for_task.enter();
|
let _g = span_for_task.enter();
|
||||||
|
// `blocking_lock` is safe here: spawn_blocking runs on
|
||||||
|
// a blocking-pool thread, not an async worker thread, so
|
||||||
|
// there's no executor to stall.
|
||||||
|
let _inference_guard = loaded_for_task.inference_lock.blocking_lock();
|
||||||
let mut guard = arch_arc.blocking_lock();
|
let mut guard = arch_arc.blocking_lock();
|
||||||
match run_inference_streaming(
|
match run_inference_streaming(
|
||||||
&mut guard,
|
&mut guard,
|
||||||
@@ -1831,6 +1863,7 @@ impl Harness for CandleHarness {
|
|||||||
poisoned: AtomicBool::new(false),
|
poisoned: AtomicBool::new(false),
|
||||||
worker,
|
worker,
|
||||||
arch_handle,
|
arch_handle,
|
||||||
|
inference_lock: tokio::sync::Mutex::new(()),
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut models = self.models.write().await;
|
let mut models = self.models.write().await;
|
||||||
|
|||||||
Reference in New Issue
Block a user