feat(tp): cancellation-safe inference + structured tracing
All checks were successful
CI / Format (push) Successful in 30s
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Clippy (push) Successful in 2m14s
build-prerelease / Build neuron-blackwell (push) Successful in 3m44s
build-prerelease / Build cortex binary (push) Successful in 4m13s
CI / Test (push) Successful in 4m38s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 5m13s
build-prerelease / Build neuron-ada (push) Successful in 4m47s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m1s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m41s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s

Two changes addressing operator visibility into TP inference + the
HTTP-cancellation poisoning chain:

1. `chat_completion_tp` now runs its body inside `tokio::spawn`. When
   the HTTP client disconnects (curl --max-time, browser nav, etc.)
   the future returned from `chat_completion_tp` gets dropped, but
   the spawned task keeps running to completion — finishing every
   `pool.generate_step` / `pool.clear_kv_cache` to drain the worker
   pipes. The next inference request then finds a clean pool.

   Previously: dropped future left workers still processing the
   in-flight request, the next call's `ClearKvCache` recv would
   read the stale `GenerateStepOk` from the abandoned step ("rank N
   expected KvCacheCleared, got GenerateStepOk"). The drain-on-
   leader-error fix from d1a4aad covered Rust-side leader failures
   but not HTTP-layer cancellation, which is what we actually hit
   on the user's Qwen3.6 test.

2. Tracing throughout the TP path so journalctl shows where an
   inference spends its time without needing to surface harness
   internals via the HTTP error body:

   - `chat_completion_tp_inner` (now a free fn so it can run inside
     spawn): `info` at request start (prompt_len, max_new, temp,
     top_p, eos_id), `info` per major phase (prefill complete with
     elapsed_ms, decode complete with elapsed_ms + token count),
     `info` at completion (total_ms, finish_reason). `debug` for
     pool-lock acquisition + kv-cache clear timing. `trace` per
     decode step (next_token, step_ms).

   - `WorkerPool::generate_step` (leader side): `debug` at fan-out,
     `debug` after leader forward returns with elapsed_ms + ok flag,
     `debug` after drain with errors count + total_ms.

   - `WorkerPool::clear_kv_cache`: matching `debug` at fan-out + drain.

   - `worker::handle_generate_step`: `debug` at forward start + done
     with elapsed_ms, `warn` on forward failure with the full error.

The default log filter is already `info,neuron=debug` so the
operator gets every `info` and `debug` line by default; `trace`
needs RUST_LOG=trace for per-step decode timing.

Stage 7c-ii crash-detection is still future work; this is the
minimum that makes the "where did the 120s go" question answerable
from the logs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 08:22:00 +03:00
parent d1a4aad91d
commit 70eb6af42b
3 changed files with 254 additions and 114 deletions

View File

@@ -604,6 +604,14 @@ impl WorkerPool {
tokens: Vec<u32>,
offset: usize,
) -> Result<candle_core::Tensor> {
let step_start = std::time::Instant::now();
let tokens_len = tokens.len();
tracing::debug!(
model = %model_id,
tokens = tokens_len,
offset,
"WorkerPool::generate_step: fan-out"
);
// 1. Fan-out to workers.
for w in &mut self.workers {
w.send_only(&WorkerRequest::GenerateStep {
@@ -617,6 +625,7 @@ impl WorkerPool {
// 2. Leader's forward in spawn_blocking. The AllReduce CustomOps
// inside the row-parallel layers block until every worker's
// forward issues the matching collective.
let leader_start = std::time::Instant::now();
let leader_result = tokio::task::spawn_blocking(move || -> Result<candle_core::Tensor> {
let mut model = leader_model.blocking_lock();
let device = model.device().clone();
@@ -628,6 +637,14 @@ impl WorkerPool {
})
.await
.context("leader forward task panicked");
let leader_ok = matches!(leader_result, Ok(Ok(_)));
tracing::debug!(
model = %model_id,
tokens = tokens_len,
leader_ms = leader_start.elapsed().as_millis(),
leader_ok,
"WorkerPool::generate_step: leader forward returned"
);
// 3. ALWAYS drain worker responses, regardless of whether the
// leader succeeded. Skipping this on the leader's error
@@ -635,12 +652,20 @@ impl WorkerPool {
// pipes that poison the NEXT request's recv (was seeing
// "ClearKvCache: expected KvCacheCleared, got
// GenerateStepOk" the call after any forward-time failure).
let drain_start = std::time::Instant::now();
let worker_errors = drain_workers(&mut self.workers, |r| match r {
WorkerResponse::GenerateStepOk => Ok(()),
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
other => Err(format!("expected GenerateStepOk, got {other:?}")),
})
.await;
tracing::debug!(
model = %model_id,
drain_ms = drain_start.elapsed().as_millis(),
errors = worker_errors.len(),
total_ms = step_start.elapsed().as_millis(),
"WorkerPool::generate_step: workers drained"
);
combine_leader_workers(leader_result, worker_errors, "GenerateStep")
}
@@ -653,6 +678,8 @@ impl WorkerPool {
model_id: &str,
#[cfg(feature = "cuda")] leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
) -> Result<()> {
let start = std::time::Instant::now();
tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out");
for w in &mut self.workers {
w.send_only(&WorkerRequest::ClearKvCache {
model_id: model_id.to_string(),
@@ -674,6 +701,12 @@ impl WorkerPool {
other => Err(format!("expected KvCacheCleared, got {other:?}")),
})
.await;
tracing::debug!(
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
errors = worker_errors.len(),
"WorkerPool::clear_kv_cache: workers drained"
);
if !worker_errors.is_empty() {
anyhow::bail!("ClearKvCache: {}", worker_errors.join("; "));
}