fix(neuron): trim cudarc mempool after clear_kv_cache to release VRAM
cudarc's stream-ordered memory pool retains freed blocks (cuMemFreeAsync returns memory to the device's default mempool, not to the OS), so mem_get_info under-reports free VRAM between requests. With Qwen/Qwen3.6-27B TP=2, the second consecutive chat completion saw ~4.5 GB of "missing" free VRAM and either OOMed or tripped cuBLAS into CUBLAS_STATUS_INTERNAL_ERROR depending on quant. Add a cuda-gated trim_device_pool helper that, after each successful clear_kv_cache, synchronizes the context and calls cuMemPoolTrimTo(pool, 0) against the device's default mempool. Failures (no async-alloc support, transient driver errors) are non-fatal and log at debug. The before/after free-VRAM delta is logged so an operator can correlate the trim with the next request's prefill VRAM. ConcatKvCache::reset() in candle-nn 0.10.2 already drops its tensors correctly; the leak was strictly at the cudarc pool layer. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -144,6 +144,9 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
|||||||
Some(arch) => arch.clear_kv_cache(),
|
Some(arch) => arch.clear_kv_cache(),
|
||||||
None => Err(anyhow::anyhow!("ClearKv: no model for handle {}", handle.0)),
|
None => Err(anyhow::anyhow!("ClearKv: no model for handle {}", handle.0)),
|
||||||
};
|
};
|
||||||
|
if result.is_ok() {
|
||||||
|
trim_device_pool(&state);
|
||||||
|
}
|
||||||
let _ = reply.send(result);
|
let _ = reply.send(result);
|
||||||
}
|
}
|
||||||
Job::ForwardLogits {
|
Job::ForwardLogits {
|
||||||
@@ -214,6 +217,9 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
|||||||
handle.0
|
handle.0
|
||||||
)),
|
)),
|
||||||
};
|
};
|
||||||
|
if result.is_ok() {
|
||||||
|
trim_device_pool(&state);
|
||||||
|
}
|
||||||
let _ = reply.send(result);
|
let _ = reply.send(result);
|
||||||
}
|
}
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
@@ -338,6 +344,75 @@ fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
|
|||||||
Ok((0, 0))
|
Ok((0, 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Force cudarc's stream-ordered memory pool to release every block it
|
||||||
|
/// is holding back to the system. After `ConcatKvCache::reset()` drops
|
||||||
|
/// its tensors, the underlying `CudaSlice::drop` calls `cuMemFreeAsync`,
|
||||||
|
/// which returns the blocks to the device's default mempool but not to
|
||||||
|
/// the OS — `mem_get_info` still reports them as used. The next
|
||||||
|
/// request's prefill then sees a falsely-small free pool and either
|
||||||
|
/// OOMs or trips cuBLAS into `CUBLAS_STATUS_INTERNAL_ERROR`.
|
||||||
|
///
|
||||||
|
/// Calling `cuMemPoolTrimTo(pool, 0)` after each `clear_kv_cache`
|
||||||
|
/// returns those blocks. We synchronize first so any pending
|
||||||
|
/// `cuMemFreeAsync` operations have settled. Failures are non-fatal:
|
||||||
|
/// the pool may not exist on legacy drivers, or a transient driver
|
||||||
|
/// error may prevent the trim — neither breaks correctness, the next
|
||||||
|
/// request just sees a less-recovered free pool.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn trim_device_pool(state: &DeviceWorkerState) {
|
||||||
|
use candle_core::cuda::cudarc::driver::result::{device, mem_pool};
|
||||||
|
let Some(ctx) = state.ctx.as_ref() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let (before_free, _) = match query_vram(state) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => (0, 0),
|
||||||
|
};
|
||||||
|
if let Err(e) = ctx.synchronize() {
|
||||||
|
tracing::debug!(
|
||||||
|
device_index = state.device_index,
|
||||||
|
error = ?e,
|
||||||
|
"trim_device_pool: synchronize failed; skipping trim"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let dev = ctx.cu_device();
|
||||||
|
let pool = match unsafe { device::get_default_mem_pool(dev) } {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(
|
||||||
|
device_index = state.device_index,
|
||||||
|
error = ?e,
|
||||||
|
"trim_device_pool: get_default_mem_pool failed"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if let Err(e) = unsafe { mem_pool::trim_to(pool, 0) } {
|
||||||
|
tracing::debug!(
|
||||||
|
device_index = state.device_index,
|
||||||
|
error = ?e,
|
||||||
|
"trim_device_pool: cuMemPoolTrimTo failed"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let (after_free, _) = match query_vram(state) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => (0, 0),
|
||||||
|
};
|
||||||
|
let freed_mb = after_free.saturating_sub(before_free);
|
||||||
|
tracing::debug!(
|
||||||
|
device_index = state.device_index,
|
||||||
|
before_free_mb = before_free,
|
||||||
|
after_free_mb = after_free,
|
||||||
|
freed_mb,
|
||||||
|
"trim_device_pool: trimmed pool"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn trim_device_pool(_state: &DeviceWorkerState) {}
|
||||||
|
|
||||||
/// Insert a freshly-built `ModelArch` into the slab and mint a fresh
|
/// Insert a freshly-built `ModelArch` into the slab and mint a fresh
|
||||||
/// `ArchHandle`. Used by both `LoadGguf` and `LoadDense` dispatch
|
/// `ArchHandle`. Used by both `LoadGguf` and `LoadDense` dispatch
|
||||||
/// handlers — they differ only in *how* the arch is built; the
|
/// handlers — they differ only in *how* the arch is built; the
|
||||||
|
|||||||
Reference in New Issue
Block a user