fix(neuron): trim cudarc mempool after clear_kv_cache to release VRAM
cudarc's stream-ordered memory pool retains freed blocks (cuMemFreeAsync returns memory to the device's default mempool, not to the OS), so mem_get_info under-reports free VRAM between requests. With Qwen/Qwen3.6-27B TP=2, the second consecutive chat completion saw ~4.5 GB of "missing" free VRAM and either OOMed or tripped cuBLAS into CUBLAS_STATUS_INTERNAL_ERROR depending on quant. Add a cuda-gated trim_device_pool helper that, after each successful clear_kv_cache, synchronizes the context and calls cuMemPoolTrimTo(pool, 0) against the device's default mempool. Failures (no async-alloc support, transient driver errors) are non-fatal and log at debug. The before/after free-VRAM delta is logged so an operator can correlate the trim with the next request's prefill VRAM. ConcatKvCache::reset() in candle-nn 0.10.2 already drops its tensors correctly; the leak was strictly at the cudarc pool layer. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -144,6 +144,9 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
||||
Some(arch) => arch.clear_kv_cache(),
|
||||
None => Err(anyhow::anyhow!("ClearKv: no model for handle {}", handle.0)),
|
||||
};
|
||||
if result.is_ok() {
|
||||
trim_device_pool(&state);
|
||||
}
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
Job::ForwardLogits {
|
||||
@@ -214,6 +217,9 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
||||
handle.0
|
||||
)),
|
||||
};
|
||||
if result.is_ok() {
|
||||
trim_device_pool(&state);
|
||||
}
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
@@ -338,6 +344,75 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
|
||||
Ok((0, 0))
|
||||
}
|
||||
|
||||
/// Force cudarc's stream-ordered memory pool to release every block it
|
||||
/// is holding back to the system. After `ConcatKvCache::reset()` drops
|
||||
/// its tensors, the underlying `CudaSlice::drop` calls `cuMemFreeAsync`,
|
||||
/// which returns the blocks to the device's default mempool but not to
|
||||
/// the OS — `mem_get_info` still reports them as used. The next
|
||||
/// request's prefill then sees a falsely-small free pool and either
|
||||
/// OOMs or trips cuBLAS into `CUBLAS_STATUS_INTERNAL_ERROR`.
|
||||
///
|
||||
/// Calling `cuMemPoolTrimTo(pool, 0)` after each `clear_kv_cache`
|
||||
/// returns those blocks. We synchronize first so any pending
|
||||
/// `cuMemFreeAsync` operations have settled. Failures are non-fatal:
|
||||
/// the pool may not exist on legacy drivers, or a transient driver
|
||||
/// error may prevent the trim — neither breaks correctness, the next
|
||||
/// request just sees a less-recovered free pool.
|
||||
#[cfg(feature = "cuda")]
|
||||
fn trim_device_pool(state: &DeviceWorkerState) {
|
||||
use candle_core::cuda::cudarc::driver::result::{device, mem_pool};
|
||||
let Some(ctx) = state.ctx.as_ref() else {
|
||||
return;
|
||||
};
|
||||
let (before_free, _) = match query_vram(state) {
|
||||
Ok(v) => v,
|
||||
Err(_) => (0, 0),
|
||||
};
|
||||
if let Err(e) = ctx.synchronize() {
|
||||
tracing::debug!(
|
||||
device_index = state.device_index,
|
||||
error = ?e,
|
||||
"trim_device_pool: synchronize failed; skipping trim"
|
||||
);
|
||||
return;
|
||||
}
|
||||
let dev = ctx.cu_device();
|
||||
let pool = match unsafe { device::get_default_mem_pool(dev) } {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
device_index = state.device_index,
|
||||
error = ?e,
|
||||
"trim_device_pool: get_default_mem_pool failed"
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
if let Err(e) = unsafe { mem_pool::trim_to(pool, 0) } {
|
||||
tracing::debug!(
|
||||
device_index = state.device_index,
|
||||
error = ?e,
|
||||
"trim_device_pool: cuMemPoolTrimTo failed"
|
||||
);
|
||||
return;
|
||||
}
|
||||
let (after_free, _) = match query_vram(state) {
|
||||
Ok(v) => v,
|
||||
Err(_) => (0, 0),
|
||||
};
|
||||
let freed_mb = after_free.saturating_sub(before_free);
|
||||
tracing::debug!(
|
||||
device_index = state.device_index,
|
||||
before_free_mb = before_free,
|
||||
after_free_mb = after_free,
|
||||
freed_mb,
|
||||
"trim_device_pool: trimmed pool"
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn trim_device_pool(_state: &DeviceWorkerState) {}
|
||||
|
||||
/// Insert a freshly-built `ModelArch` into the slab and mint a fresh
|
||||
/// `ArchHandle`. Used by both `LoadGguf` and `LoadDense` dispatch
|
||||
/// handlers — they differ only in *how* the arch is built; the
|
||||
|
||||
Reference in New Issue
Block a user