feat(tp): cancellation-safe inference + structured tracing
All checks were successful
CI / Format (push) Successful in 30s
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Clippy (push) Successful in 2m14s
build-prerelease / Build neuron-blackwell (push) Successful in 3m44s
build-prerelease / Build cortex binary (push) Successful in 4m13s
CI / Test (push) Successful in 4m38s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 5m13s
build-prerelease / Build neuron-ada (push) Successful in 4m47s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m1s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m41s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s

Two changes addressing operator visibility into TP inference + the
HTTP-cancellation poisoning chain:

1. `chat_completion_tp` now runs its body inside `tokio::spawn`. When
   the HTTP client disconnects (curl --max-time, browser nav, etc.)
   the future returned from `chat_completion_tp` gets dropped, but
   the spawned task keeps running to completion — finishing every
   `pool.generate_step` / `pool.clear_kv_cache` to drain the worker
   pipes. The next inference request then finds a clean pool.

   Previously: dropped future left workers still processing the
   in-flight request, the next call's `ClearKvCache` recv would
   read the stale `GenerateStepOk` from the abandoned step ("rank N
   expected KvCacheCleared, got GenerateStepOk"). The drain-on-
   leader-error fix from d1a4aad covered Rust-side leader failures
   but not HTTP-layer cancellation, which is what we actually hit
   on the user's Qwen3.6 test.

2. Tracing throughout the TP path so journalctl shows where an
   inference spends its time without needing to surface harness
   internals via the HTTP error body:

   - `chat_completion_tp_inner` (now a free fn so it can run inside
     spawn): `info` at request start (prompt_len, max_new, temp,
     top_p, eos_id), `info` per major phase (prefill complete with
     elapsed_ms, decode complete with elapsed_ms + token count),
     `info` at completion (total_ms, finish_reason). `debug` for
     pool-lock acquisition + kv-cache clear timing. `trace` per
     decode step (next_token, step_ms).

   - `WorkerPool::generate_step` (leader side): `debug` at fan-out,
     `debug` after leader forward returns with elapsed_ms + ok flag,
     `debug` after drain with errors count + total_ms.

   - `WorkerPool::clear_kv_cache`: matching `debug` at fan-out + drain.

   - `worker::handle_generate_step`: `debug` at forward start + done
     with elapsed_ms, `warn` on forward failure with the full error.

The default log filter is already `info,neuron=debug` so the
operator gets every `info` and `debug` line by default; `trace`
needs RUST_LOG=trace for per-step decode timing.

Stage 7c-ii crash-detection is still future work; this is the
minimum that makes the "where did the 120s go" question answerable
from the logs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 08:22:00 +03:00
parent d1a4aad91d
commit 70eb6af42b
3 changed files with 254 additions and 114 deletions

View File

@@ -1167,127 +1167,32 @@ impl CandleHarness {
Ok(())
}
/// Non-streaming chat completion against a TP model. Pattern mirrors
/// the single-GPU `run_inference`: tokenize, prefill, sample, decode
/// loop, detokenize. Each forward step fans out to every rank via
/// the WorkerPool and uses the leader's last-position logits to
/// sample.
/// Non-streaming chat completion against a TP model.
///
/// The actual work runs inside a `tokio::spawn`'d task so the HTTP
/// client disconnecting (curl timeout, browser nav-away, etc.)
/// can't cancel the future mid-`pool.generate_step` and leave the
/// worker subprocesses mid-RPC. If the spawned task is dropped,
/// it still runs to completion and finishes draining the pool —
/// the next inference request finds a clean pool. The HTTP layer
/// just gives up on the response.
///
/// Every step also emits `info`/`debug` tracing so journalctl
/// shows where time went without needing to surface internals in
/// the HTTP error response.
#[cfg(feature = "cuda")]
async fn chat_completion_tp(
&self,
tp: Arc<TpLoadedModel>,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, InferenceError> {
let prompt = format_qwen3_prompt(&request.messages);
let encoding = tp
.tokenizer
.encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
let prompt_len = prompt_tokens.len();
let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(512) as usize;
let seed = unix_subsec_nanos();
let eos_id = tp
.tokenizer
.token_to_id("<|im_end|>")
.or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
let model_id = request.model.clone();
// Acquire the pool lock for the duration of the request. The
// leader_model's own Mutex is acquired step-by-step inside
// pool.generate_step (so spawn_blocking can grab it without
// holding the pool lock across the blocking_lock call).
let mut pool = tp.pool.lock().await;
let leader_arc = tp.leader_model.clone();
// Reset every rank's KV cache so this request doesn't attend
// over the previous request's tokens.
pool.clear_kv_cache(&model_id, leader_arc.clone())
.await
.map_err(InferenceError::Other)?;
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut generated: Vec<u32> = Vec::new();
let mut finish_reason = "length".to_string();
// Prefill: every rank embeds the whole prompt, offset = 0.
let logits = pool
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
.await
.map_err(InferenceError::Other)?;
let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
.map_err(InferenceError::Other)?;
if Some(next_token) == eos_id {
finish_reason = "stop".into();
} else {
generated.push(next_token);
for index in 0..max_new.saturating_sub(1) {
let logits = pool
.generate_step(
&model_id,
leader_arc.clone(),
vec![next_token],
prompt_len + index,
)
.await
.map_err(InferenceError::Other)?;
next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
.map_err(InferenceError::Other)?;
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
}
generated.push(next_token);
}
let handle = tokio::spawn(chat_completion_tp_inner(tp, request));
match handle.await {
Ok(result) => result,
Err(join_err) => Err(InferenceError::Other(anyhow::anyhow!(
"TP inference task panicked or was cancelled: {join_err}"
))),
}
drop(pool);
let completion_text = tp
.tokenizer
.decode(&generated, true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
let usage = Usage {
prompt_tokens: prompt_len as u64,
completion_tokens: generated.len() as u64,
total_tokens: (prompt_len + generated.len()) as u64,
};
Ok(ChatCompletionResponse {
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
object: "chat.completion".into(),
created: unix_now_secs(),
model: model_id,
choices: vec![ChatCompletionChoice {
index: 0,
message: ChatMessage {
role: "assistant".into(),
content: MessageContent::Text(completion_text),
extra: serde_json::Value::Object(Default::default()),
},
finish_reason: Some(finish_reason),
extra: serde_json::Value::Object(Default::default()),
}],
usage: Some(usage),
extra: serde_json::Value::Object(Default::default()),
})
}
/// Streaming counterpart to `chat_completion_tp`. Same per-step
@@ -1499,6 +1404,187 @@ impl CandleHarness {
}
}
/// Body of the TP non-streaming chat completion, hoisted out of
/// `CandleHarness::chat_completion_tp` so it can run inside
/// `tokio::spawn` (which requires a `'static` future) and survive
/// HTTP-layer cancellation.
///
/// Tracing strategy: `info` for request entry/exit so journalctl
/// always shows when an inference started and finished; `debug` for
/// per-step timing so an operator running with `RUST_LOG=debug` sees
/// where the request actually spends its time without needing to
/// instrument the model code.
#[cfg(feature = "cuda")]
async fn chat_completion_tp_inner(
tp: Arc<TpLoadedModel>,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, InferenceError> {
let req_start = std::time::Instant::now();
let model_id = request.model.clone();
let prompt = format_qwen3_prompt(&request.messages);
let encoding = tp
.tokenizer
.encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
let prompt_len = prompt_tokens.len();
let temperature = request.temperature.unwrap_or(0.7);
let top_p = request.top_p;
let max_new = request.max_tokens.unwrap_or(512) as usize;
let seed = unix_subsec_nanos();
let eos_id = tp
.tokenizer
.token_to_id("<|im_end|>")
.or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
tracing::info!(
model = %model_id,
prompt_len,
max_new,
temperature,
?top_p,
?eos_id,
"TP chat_completion: starting"
);
// Acquire the pool lock for the duration of the request. The
// leader_model's own Mutex is acquired step-by-step inside
// pool.generate_step (so spawn_blocking can grab it without
// holding the pool lock across the blocking_lock call).
let lock_start = std::time::Instant::now();
let mut pool = tp.pool.lock().await;
tracing::debug!(
model = %model_id,
elapsed_ms = lock_start.elapsed().as_millis(),
"TP chat_completion: pool lock acquired"
);
let leader_arc = tp.leader_model.clone();
// Reset every rank's KV cache so this request doesn't attend
// over the previous request's tokens.
let clear_start = std::time::Instant::now();
pool.clear_kv_cache(&model_id, leader_arc.clone())
.await
.map_err(InferenceError::Other)?;
tracing::debug!(
model = %model_id,
elapsed_ms = clear_start.elapsed().as_millis(),
"TP chat_completion: kv cache cleared"
);
let mut logits_processor = {
let sampling = if temperature <= 0.0 {
Sampling::ArgMax
} else {
match top_p {
Some(p) => Sampling::TopP { p, temperature },
None => Sampling::All { temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let mut generated: Vec<u32> = Vec::new();
let mut finish_reason = "length".to_string();
// Prefill: every rank embeds the whole prompt, offset = 0.
let prefill_start = std::time::Instant::now();
let logits = pool
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
.await
.map_err(InferenceError::Other)?;
tracing::info!(
model = %model_id,
prompt_len,
elapsed_ms = prefill_start.elapsed().as_millis(),
"TP chat_completion: prefill complete"
);
let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
.map_err(InferenceError::Other)?;
if Some(next_token) == eos_id {
finish_reason = "stop".into();
} else {
generated.push(next_token);
let decode_start = std::time::Instant::now();
for index in 0..max_new.saturating_sub(1) {
let step_start = std::time::Instant::now();
let logits = pool
.generate_step(
&model_id,
leader_arc.clone(),
vec![next_token],
prompt_len + index,
)
.await
.map_err(InferenceError::Other)?;
next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
.map_err(InferenceError::Other)?;
tracing::trace!(
model = %model_id,
step = index,
next_token,
step_ms = step_start.elapsed().as_millis(),
"TP chat_completion: decode step"
);
if Some(next_token) == eos_id {
finish_reason = "stop".into();
break;
}
generated.push(next_token);
}
tracing::info!(
model = %model_id,
generated = generated.len(),
elapsed_ms = decode_start.elapsed().as_millis(),
"TP chat_completion: decode complete"
);
}
drop(pool);
let completion_text = tp
.tokenizer
.decode(&generated, true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
let usage = Usage {
prompt_tokens: prompt_len as u64,
completion_tokens: generated.len() as u64,
total_tokens: (prompt_len + generated.len()) as u64,
};
tracing::info!(
model = %model_id,
prompt_tokens = prompt_len,
completion_tokens = generated.len(),
finish_reason = %finish_reason,
total_ms = req_start.elapsed().as_millis(),
"TP chat_completion: done"
);
Ok(ChatCompletionResponse {
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
object: "chat.completion".into(),
created: unix_now_secs(),
model: model_id,
choices: vec![ChatCompletionChoice {
index: 0,
message: ChatMessage {
role: "assistant".into(),
content: MessageContent::Text(completion_text),
extra: serde_json::Value::Object(Default::default()),
},
finish_reason: Some(finish_reason),
extra: serde_json::Value::Object(Default::default()),
}],
usage: Some(usage),
extra: serde_json::Value::Object(Default::default()),
})
}
/// Decode the cumulative token list, emit the delta (substring appended
/// since the last chunk) as a `chat.completion.chunk`. Returns `false`
/// if the receiver has hung up — the caller should bail.

View File

@@ -604,6 +604,14 @@ impl WorkerPool {
tokens: Vec<u32>,
offset: usize,
) -> Result<candle_core::Tensor> {
let step_start = std::time::Instant::now();
let tokens_len = tokens.len();
tracing::debug!(
model = %model_id,
tokens = tokens_len,
offset,
"WorkerPool::generate_step: fan-out"
);
// 1. Fan-out to workers.
for w in &mut self.workers {
w.send_only(&WorkerRequest::GenerateStep {
@@ -617,6 +625,7 @@ impl WorkerPool {
// 2. Leader's forward in spawn_blocking. The AllReduce CustomOps
// inside the row-parallel layers block until every worker's
// forward issues the matching collective.
let leader_start = std::time::Instant::now();
let leader_result = tokio::task::spawn_blocking(move || -> Result<candle_core::Tensor> {
let mut model = leader_model.blocking_lock();
let device = model.device().clone();
@@ -628,6 +637,14 @@ impl WorkerPool {
})
.await
.context("leader forward task panicked");
let leader_ok = matches!(leader_result, Ok(Ok(_)));
tracing::debug!(
model = %model_id,
tokens = tokens_len,
leader_ms = leader_start.elapsed().as_millis(),
leader_ok,
"WorkerPool::generate_step: leader forward returned"
);
// 3. ALWAYS drain worker responses, regardless of whether the
// leader succeeded. Skipping this on the leader's error
@@ -635,12 +652,20 @@ impl WorkerPool {
// pipes that poison the NEXT request's recv (was seeing
// "ClearKvCache: expected KvCacheCleared, got
// GenerateStepOk" the call after any forward-time failure).
let drain_start = std::time::Instant::now();
let worker_errors = drain_workers(&mut self.workers, |r| match r {
WorkerResponse::GenerateStepOk => Ok(()),
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
other => Err(format!("expected GenerateStepOk, got {other:?}")),
})
.await;
tracing::debug!(
model = %model_id,
drain_ms = drain_start.elapsed().as_millis(),
errors = worker_errors.len(),
total_ms = step_start.elapsed().as_millis(),
"WorkerPool::generate_step: workers drained"
);
combine_leader_workers(leader_result, worker_errors, "GenerateStep")
}
@@ -653,6 +678,8 @@ impl WorkerPool {
model_id: &str,
#[cfg(feature = "cuda")] leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
) -> Result<()> {
let start = std::time::Instant::now();
tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out");
for w in &mut self.workers {
w.send_only(&WorkerRequest::ClearKvCache {
model_id: model_id.to_string(),
@@ -674,6 +701,12 @@ impl WorkerPool {
other => Err(format!("expected KvCacheCleared, got {other:?}")),
})
.await;
tracing::debug!(
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
errors = worker_errors.len(),
"WorkerPool::clear_kv_cache: workers drained"
);
if !worker_errors.is_empty() {
anyhow::bail!("ClearKvCache: {}", worker_errors.join("; "));
}

View File

@@ -344,15 +344,36 @@ impl WorkerState {
};
}
};
let start = std::time::Instant::now();
tracing::debug!(
rank = self.config.rank,
model = %model_id,
tokens = tokens.len(),
offset,
"worker GenerateStep: forward starting"
);
// Drop the resulting logits — the leader uses its own copy from
// rank 0. The forward's value here is the NCCL collectives it
// issues, which let the leader's rank-0 forward make progress.
if let Err(e) = model.forward(&input, offset) {
tracing::warn!(
rank = self.config.rank,
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
error = %e,
"worker GenerateStep: forward failed"
);
return WorkerResponse::Error {
kind: "forward_failed".into(),
message: format!("TP forward: {e}"),
};
}
tracing::debug!(
rank = self.config.rank,
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
"worker GenerateStep: forward done"
);
WorkerResponse::GenerateStepOk
}