feat(tp): cancellation-safe inference + structured tracing
All checks were successful
CI / Format (push) Successful in 30s
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Clippy (push) Successful in 2m14s
build-prerelease / Build neuron-blackwell (push) Successful in 3m44s
build-prerelease / Build cortex binary (push) Successful in 4m13s
CI / Test (push) Successful in 4m38s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 5m13s
build-prerelease / Build neuron-ada (push) Successful in 4m47s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m1s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m41s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
Two changes addressing operator visibility into TP inference and the
HTTP-cancellation poisoning chain:

1. `chat_completion_tp` now runs its body inside `tokio::spawn`. When
   the HTTP client disconnects (curl --max-time, browser nav, etc.)
   the future returned from `chat_completion_tp` gets dropped, but
   the spawned task keeps running to completion, finishing every
   `pool.generate_step` / `pool.clear_kv_cache` to drain the worker
   pipes. The next inference request then finds a clean pool.

   Previously, a dropped future left the workers still processing the
   in-flight request, and the next call's `ClearKvCache` recv would
   read the stale `GenerateStepOk` from the abandoned step ("rank N
   expected KvCacheCleared, got GenerateStepOk"). The
   drain-on-leader-error fix from d1a4aad covered Rust-side leader
   failures but not HTTP-layer cancellation, which is what we actually
   hit on the user's Qwen3.6 test.

2. Tracing throughout the TP path so journalctl shows where an
   inference spends its time without needing to surface harness
   internals via the HTTP error body:

   - `chat_completion_tp_inner` (now a free fn so it can run inside
     spawn): `info` at request start (prompt_len, max_new, temp,
     top_p, eos_id), `info` per major phase (prefill complete with
     elapsed_ms, decode complete with elapsed_ms + token count),
     `info` at completion (total_ms, finish_reason). `debug` for
     pool-lock acquisition + kv-cache clear timing. `trace` per
     decode step (next_token, step_ms).
   - `WorkerPool::generate_step` (leader side): `debug` at fan-out,
     `debug` after leader forward returns with elapsed_ms + ok flag,
     `debug` after drain with errors count + total_ms.
   - `WorkerPool::clear_kv_cache`: matching `debug` at fan-out + drain.
   - `worker::handle_generate_step`: `debug` at forward start + done
     with elapsed_ms, `warn` on forward failure with the full error.

   The default log filter is already `info,neuron=debug` so the
   operator gets every `info` and `debug` line by default; `trace`
   needs RUST_LOG=trace for per-step decode timing.

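The poisoning chain from item 1 can be reproduced in miniature. The sketch below is an illustration under assumptions, not the project's code: it swaps the worker subprocess pipes for `std::sync::mpsc` channels and a thread, and `Worker`, `spawn_worker`, and the two scenario functions are hypothetical names. It only shows why an unread `GenerateStepOk` is what the next `ClearKvCache` exchange ends up reading on a strictly ordered request/response pipe.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

#[derive(Debug)]
enum Request {
    GenerateStep,
    ClearKvCache,
}

#[derive(Debug, PartialEq)]
enum Response {
    GenerateStepOk,
    KvCacheCleared,
}

/// Leader-side handle to one worker: a request pipe and a response pipe.
struct Worker {
    tx: Sender<Request>,
    rx: Receiver<Response>,
}

/// Start a worker thread that answers every request in order.
fn spawn_worker() -> Worker {
    let (req_tx, req_rx) = channel::<Request>();
    let (resp_tx, resp_rx) = channel::<Response>();
    thread::spawn(move || {
        for req in req_rx {
            let resp = match req {
                Request::GenerateStep => Response::GenerateStepOk,
                Request::ClearKvCache => Response::KvCacheCleared,
            };
            if resp_tx.send(resp).is_err() {
                break; // leader went away
            }
        }
    });
    Worker { tx: req_tx, rx: resp_rx }
}

impl Worker {
    fn send_only(&self, req: Request) {
        self.tx.send(req).expect("worker alive");
    }

    fn recv(&self) -> Response {
        self.rx.recv().expect("worker alive")
    }

    fn clear_kv_cache(&self) -> Result<(), String> {
        self.send_only(Request::ClearKvCache);
        match self.recv() {
            Response::KvCacheCleared => Ok(()),
            other => Err(format!("expected KvCacheCleared, got {other:?}")),
        }
    }
}

/// The request is sent, then the caller is "cancelled" before reading the reply.
fn abandoned_step_poisons_next_call() -> Result<(), String> {
    let w = spawn_worker();
    w.send_only(Request::GenerateStep);
    w.clear_kv_cache() // reads the stale GenerateStepOk first
}

/// Same exchange, but the reply is always drained before the next request.
fn drained_step_leaves_pipe_clean() -> Result<(), String> {
    let w = spawn_worker();
    w.send_only(Request::GenerateStep);
    let _ = w.recv();
    w.clear_kv_cache()
}
```

In the commit itself the draining is guaranteed by moving the whole exchange into a spawned task, so that dropping the HTTP-facing future no longer interrupts it between the send and the receive.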
Stage 7c-ii crash-detection is still future work; this is the
minimum that makes the "where did the 120s go" question answerable
from the logs.
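To make the "where did the 120s go" idea concrete, here is a small std-only sketch of per-phase timing. It is not what the commit does literally: the diff takes a separate `std::time::Instant` per phase and emits each figure through `tracing` macros, whereas `PhaseTimer` below is a hypothetical helper that gathers the same kind of numbers in one place.

```rust
use std::time::{Duration, Instant};

/// Hypothetical helper (not part of the commit): records how long each
/// named phase took, measured between successive `mark` calls.
struct PhaseTimer {
    start: Instant,
    last: Instant,
    phases: Vec<(&'static str, Duration)>,
}

impl PhaseTimer {
    fn new() -> Self {
        let now = Instant::now();
        PhaseTimer { start: now, last: now, phases: Vec::new() }
    }

    /// Close the current phase under `name` and start timing the next one.
    fn mark(&mut self, name: &'static str) {
        let now = Instant::now();
        self.phases.push((name, now.duration_since(self.last)));
        self.last = now;
    }

    /// Render one log-style line of `<phase>_ms=<n>` pairs plus `total_ms`.
    fn summary(&self) -> String {
        let mut parts: Vec<String> = self
            .phases
            .iter()
            .map(|(name, d)| format!("{name}_ms={}", d.as_millis()))
            .collect();
        parts.push(format!("total_ms={}", self.start.elapsed().as_millis()));
        parts.join(" ")
    }
}
```

A caller would invoke `mark("lock")`, `mark("clear_kv")`, `mark("prefill")`, and `mark("decode")` after the corresponding awaits and log `summary()` once at the end. Emitting one event per phase, as the commit does, has the advantage that the earlier phases are still visible in the journal if a later phase hangs.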
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -1167,127 +1167,32 @@ impl CandleHarness {
         Ok(())
     }

-    /// Non-streaming chat completion against a TP model. Pattern mirrors
-    /// the single-GPU `run_inference`: tokenize, prefill, sample, decode
-    /// loop, detokenize. Each forward step fans out to every rank via
-    /// the WorkerPool and uses the leader's last-position logits to
-    /// sample.
+    /// Non-streaming chat completion against a TP model.
+    ///
+    /// The actual work runs inside a `tokio::spawn`'d task so the HTTP
+    /// client disconnecting (curl timeout, browser nav-away, etc.)
+    /// can't cancel the future mid-`pool.generate_step` and leave the
+    /// worker subprocesses mid-RPC. If the spawned task is dropped,
+    /// it still runs to completion and finishes draining the pool —
+    /// the next inference request finds a clean pool. The HTTP layer
+    /// just gives up on the response.
+    ///
+    /// Every step also emits `info`/`debug` tracing so journalctl
+    /// shows where time went without needing to surface internals in
+    /// the HTTP error response.
     #[cfg(feature = "cuda")]
     async fn chat_completion_tp(
         &self,
         tp: Arc<TpLoadedModel>,
         request: ChatCompletionRequest,
     ) -> Result<ChatCompletionResponse, InferenceError> {
-        let prompt = format_qwen3_prompt(&request.messages);
-        let encoding = tp
-            .tokenizer
-            .encode(prompt.as_str(), true)
-            .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
-        let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
-        let prompt_len = prompt_tokens.len();
-
-        let temperature = request.temperature.unwrap_or(0.7);
-        let top_p = request.top_p;
-        let max_new = request.max_tokens.unwrap_or(512) as usize;
-        let seed = unix_subsec_nanos();
-
-        let eos_id = tp
-            .tokenizer
-            .token_to_id("<|im_end|>")
-            .or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
-
-        let model_id = request.model.clone();
-
-        // Acquire the pool lock for the duration of the request. The
-        // leader_model's own Mutex is acquired step-by-step inside
-        // pool.generate_step (so spawn_blocking can grab it without
-        // holding the pool lock across the blocking_lock call).
-        let mut pool = tp.pool.lock().await;
-        let leader_arc = tp.leader_model.clone();
-
-        // Reset every rank's KV cache so this request doesn't attend
-        // over the previous request's tokens.
-        pool.clear_kv_cache(&model_id, leader_arc.clone())
-            .await
-            .map_err(InferenceError::Other)?;
-
-        let mut logits_processor = {
-            let sampling = if temperature <= 0.0 {
-                Sampling::ArgMax
-            } else {
-                match top_p {
-                    Some(p) => Sampling::TopP { p, temperature },
-                    None => Sampling::All { temperature },
-                }
-            };
-            LogitsProcessor::from_sampling(seed, sampling)
-        };
-
-        let mut generated: Vec<u32> = Vec::new();
-        let mut finish_reason = "length".to_string();
-
-        // Prefill: every rank embeds the whole prompt, offset = 0.
-        let logits = pool
-            .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
-            .await
-            .map_err(InferenceError::Other)?;
-        let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
-            .map_err(InferenceError::Other)?;
-
-        if Some(next_token) == eos_id {
-            finish_reason = "stop".into();
-        } else {
-            generated.push(next_token);
-            for index in 0..max_new.saturating_sub(1) {
-                let logits = pool
-                    .generate_step(
-                        &model_id,
-                        leader_arc.clone(),
-                        vec![next_token],
-                        prompt_len + index,
-                    )
-                    .await
-                    .map_err(InferenceError::Other)?;
-                next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
-                    .map_err(InferenceError::Other)?;
-                if Some(next_token) == eos_id {
-                    finish_reason = "stop".into();
-                    break;
-                }
-                generated.push(next_token);
-            }
+        let handle = tokio::spawn(chat_completion_tp_inner(tp, request));
+        match handle.await {
+            Ok(result) => result,
+            Err(join_err) => Err(InferenceError::Other(anyhow::anyhow!(
+                "TP inference task panicked or was cancelled: {join_err}"
+            ))),
         }
-        drop(pool);
-
-        let completion_text = tp
-            .tokenizer
-            .decode(&generated, true)
-            .map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
-
-        let usage = Usage {
-            prompt_tokens: prompt_len as u64,
-            completion_tokens: generated.len() as u64,
-            total_tokens: (prompt_len + generated.len()) as u64,
-        };
-
-        Ok(ChatCompletionResponse {
-            id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
-            object: "chat.completion".into(),
-            created: unix_now_secs(),
-            model: model_id,
-            choices: vec![ChatCompletionChoice {
-                index: 0,
-                message: ChatMessage {
-                    role: "assistant".into(),
-                    content: MessageContent::Text(completion_text),
-                    extra: serde_json::Value::Object(Default::default()),
-                },
-                finish_reason: Some(finish_reason),
-                extra: serde_json::Value::Object(Default::default()),
-            }],
-            usage: Some(usage),
-            extra: serde_json::Value::Object(Default::default()),
-        })
     }

     /// Streaming counterpart to `chat_completion_tp`. Same per-step
@@ -1499,6 +1404,187 @@ impl CandleHarness {
     }
 }

+/// Body of the TP non-streaming chat completion, hoisted out of
+/// `CandleHarness::chat_completion_tp` so it can run inside
+/// `tokio::spawn` (which requires a `'static` future) and survive
+/// HTTP-layer cancellation.
+///
+/// Tracing strategy: `info` for request entry/exit so journalctl
+/// always shows when an inference started and finished; `debug` for
+/// per-step timing so an operator running with `RUST_LOG=debug` sees
+/// where the request actually spends its time without needing to
+/// instrument the model code.
+#[cfg(feature = "cuda")]
+async fn chat_completion_tp_inner(
+    tp: Arc<TpLoadedModel>,
+    request: ChatCompletionRequest,
+) -> Result<ChatCompletionResponse, InferenceError> {
+    let req_start = std::time::Instant::now();
+    let model_id = request.model.clone();
+
+    let prompt = format_qwen3_prompt(&request.messages);
+    let encoding = tp
+        .tokenizer
+        .encode(prompt.as_str(), true)
+        .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
+    let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
+    let prompt_len = prompt_tokens.len();
+
+    let temperature = request.temperature.unwrap_or(0.7);
+    let top_p = request.top_p;
+    let max_new = request.max_tokens.unwrap_or(512) as usize;
+    let seed = unix_subsec_nanos();
+
+    let eos_id = tp
+        .tokenizer
+        .token_to_id("<|im_end|>")
+        .or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
+
+    tracing::info!(
+        model = %model_id,
+        prompt_len,
+        max_new,
+        temperature,
+        ?top_p,
+        ?eos_id,
+        "TP chat_completion: starting"
+    );
+
+    // Acquire the pool lock for the duration of the request. The
+    // leader_model's own Mutex is acquired step-by-step inside
+    // pool.generate_step (so spawn_blocking can grab it without
+    // holding the pool lock across the blocking_lock call).
+    let lock_start = std::time::Instant::now();
+    let mut pool = tp.pool.lock().await;
+    tracing::debug!(
+        model = %model_id,
+        elapsed_ms = lock_start.elapsed().as_millis(),
+        "TP chat_completion: pool lock acquired"
+    );
+    let leader_arc = tp.leader_model.clone();
+
+    // Reset every rank's KV cache so this request doesn't attend
+    // over the previous request's tokens.
+    let clear_start = std::time::Instant::now();
+    pool.clear_kv_cache(&model_id, leader_arc.clone())
+        .await
+        .map_err(InferenceError::Other)?;
+    tracing::debug!(
+        model = %model_id,
+        elapsed_ms = clear_start.elapsed().as_millis(),
+        "TP chat_completion: kv cache cleared"
+    );
+
+    let mut logits_processor = {
+        let sampling = if temperature <= 0.0 {
+            Sampling::ArgMax
+        } else {
+            match top_p {
+                Some(p) => Sampling::TopP { p, temperature },
+                None => Sampling::All { temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(seed, sampling)
+    };
+
+    let mut generated: Vec<u32> = Vec::new();
+    let mut finish_reason = "length".to_string();
+
+    // Prefill: every rank embeds the whole prompt, offset = 0.
+    let prefill_start = std::time::Instant::now();
+    let logits = pool
+        .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
+        .await
+        .map_err(InferenceError::Other)?;
+    tracing::info!(
+        model = %model_id,
+        prompt_len,
+        elapsed_ms = prefill_start.elapsed().as_millis(),
+        "TP chat_completion: prefill complete"
+    );
+    let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
+        .map_err(InferenceError::Other)?;
+
+    if Some(next_token) == eos_id {
+        finish_reason = "stop".into();
+    } else {
+        generated.push(next_token);
+        let decode_start = std::time::Instant::now();
+        for index in 0..max_new.saturating_sub(1) {
+            let step_start = std::time::Instant::now();
+            let logits = pool
+                .generate_step(
+                    &model_id,
+                    leader_arc.clone(),
+                    vec![next_token],
+                    prompt_len + index,
+                )
+                .await
+                .map_err(InferenceError::Other)?;
+            next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
+                .map_err(InferenceError::Other)?;
+            tracing::trace!(
+                model = %model_id,
+                step = index,
+                next_token,
+                step_ms = step_start.elapsed().as_millis(),
+                "TP chat_completion: decode step"
+            );
+            if Some(next_token) == eos_id {
+                finish_reason = "stop".into();
+                break;
+            }
+            generated.push(next_token);
+        }
+        tracing::info!(
+            model = %model_id,
+            generated = generated.len(),
+            elapsed_ms = decode_start.elapsed().as_millis(),
+            "TP chat_completion: decode complete"
+        );
+    }
+    drop(pool);
+
+    let completion_text = tp
+        .tokenizer
+        .decode(&generated, true)
+        .map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
+
+    let usage = Usage {
+        prompt_tokens: prompt_len as u64,
+        completion_tokens: generated.len() as u64,
+        total_tokens: (prompt_len + generated.len()) as u64,
+    };
+
+    tracing::info!(
+        model = %model_id,
+        prompt_tokens = prompt_len,
+        completion_tokens = generated.len(),
+        finish_reason = %finish_reason,
+        total_ms = req_start.elapsed().as_millis(),
+        "TP chat_completion: done"
+    );
+
+    Ok(ChatCompletionResponse {
+        id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
+        object: "chat.completion".into(),
+        created: unix_now_secs(),
+        model: model_id,
+        choices: vec![ChatCompletionChoice {
+            index: 0,
+            message: ChatMessage {
+                role: "assistant".into(),
+                content: MessageContent::Text(completion_text),
+                extra: serde_json::Value::Object(Default::default()),
+            },
+            finish_reason: Some(finish_reason),
+            extra: serde_json::Value::Object(Default::default()),
+        }],
+        usage: Some(usage),
+        extra: serde_json::Value::Object(Default::default()),
+    })
+}
+
 /// Decode the cumulative token list, emit the delta (substring appended
 /// since the last chunk) as a `chat.completion.chunk`. Returns `false`
 /// if the receiver has hung up — the caller should bail.
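The token budget and offset arithmetic in the prefill/decode loop are easy to misread, so the following sketch replays just that control flow with a stub in place of the model. It is a simplified stand-in under assumptions: `decode_loop` and its `forward_and_sample` closure are hypothetical names, and the pool, tokenizer, and sampler are all collapsed into the closure, which receives only the position offset.

```rust
/// Replays the control flow of the TP decode loop with a stubbed model.
/// Returns (generated tokens, finish reason, offset passed to each forward).
fn decode_loop(
    prompt_len: usize,
    max_new: usize,
    eos_id: Option<u32>,
    mut forward_and_sample: impl FnMut(usize) -> u32,
) -> (Vec<u32>, &'static str, Vec<usize>) {
    let mut generated = Vec::new();
    let mut offsets = vec![0]; // prefill always runs at offset 0
    let mut finish_reason = "length";

    // Prefill: one forward over the whole prompt yields the first token.
    let mut next_token = forward_and_sample(0);
    if Some(next_token) == eos_id {
        finish_reason = "stop";
    } else {
        generated.push(next_token);
        // Decode: one token per step, positioned right after what is cached.
        for index in 0..max_new.saturating_sub(1) {
            let offset = prompt_len + index;
            offsets.push(offset);
            next_token = forward_and_sample(offset);
            if Some(next_token) == eos_id {
                finish_reason = "stop";
                break; // the EOS token itself is never pushed
            }
            generated.push(next_token);
        }
    }
    (generated, finish_reason, offsets)
}
```

With `prompt_len = 4` and `max_new = 3`, the forwards run at offsets 0, 4, and 5 and three tokens are returned with `finish_reason` "length". One consequence of this shape is that a `max_new` of 0 still yields one token, because the prefill sample is pushed before the loop bound is consulted.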
@@ -604,6 +604,14 @@ impl WorkerPool {
         tokens: Vec<u32>,
         offset: usize,
     ) -> Result<candle_core::Tensor> {
+        let step_start = std::time::Instant::now();
+        let tokens_len = tokens.len();
+        tracing::debug!(
+            model = %model_id,
+            tokens = tokens_len,
+            offset,
+            "WorkerPool::generate_step: fan-out"
+        );
         // 1. Fan-out to workers.
         for w in &mut self.workers {
             w.send_only(&WorkerRequest::GenerateStep {
@@ -617,6 +625,7 @@ impl WorkerPool {
         // 2. Leader's forward in spawn_blocking. The AllReduce CustomOps
         // inside the row-parallel layers block until every worker's
         // forward issues the matching collective.
+        let leader_start = std::time::Instant::now();
         let leader_result = tokio::task::spawn_blocking(move || -> Result<candle_core::Tensor> {
             let mut model = leader_model.blocking_lock();
             let device = model.device().clone();
@@ -628,6 +637,14 @@ impl WorkerPool {
         })
         .await
         .context("leader forward task panicked");
+        let leader_ok = matches!(leader_result, Ok(Ok(_)));
+        tracing::debug!(
+            model = %model_id,
+            tokens = tokens_len,
+            leader_ms = leader_start.elapsed().as_millis(),
+            leader_ok,
+            "WorkerPool::generate_step: leader forward returned"
+        );

         // 3. ALWAYS drain worker responses, regardless of whether the
         // leader succeeded. Skipping this on the leader's error
@@ -635,12 +652,20 @@ impl WorkerPool {
         // pipes that poison the NEXT request's recv (was seeing
         // "ClearKvCache: expected KvCacheCleared, got
         // GenerateStepOk" the call after any forward-time failure).
+        let drain_start = std::time::Instant::now();
         let worker_errors = drain_workers(&mut self.workers, |r| match r {
             WorkerResponse::GenerateStepOk => Ok(()),
             WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
             other => Err(format!("expected GenerateStepOk, got {other:?}")),
         })
         .await;
+        tracing::debug!(
+            model = %model_id,
+            drain_ms = drain_start.elapsed().as_millis(),
+            errors = worker_errors.len(),
+            total_ms = step_start.elapsed().as_millis(),
+            "WorkerPool::generate_step: workers drained"
+        );

         combine_leader_workers(leader_result, worker_errors, "GenerateStep")
     }
@@ -653,6 +678,8 @@ impl WorkerPool {
         model_id: &str,
         #[cfg(feature = "cuda")] leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
     ) -> Result<()> {
+        let start = std::time::Instant::now();
+        tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out");
         for w in &mut self.workers {
             w.send_only(&WorkerRequest::ClearKvCache {
                 model_id: model_id.to_string(),
@@ -674,6 +701,12 @@ impl WorkerPool {
             other => Err(format!("expected KvCacheCleared, got {other:?}")),
         })
         .await;
+        tracing::debug!(
+            model = %model_id,
+            elapsed_ms = start.elapsed().as_millis(),
+            errors = worker_errors.len(),
+            "WorkerPool::clear_kv_cache: workers drained"
+        );
         if !worker_errors.is_empty() {
             anyhow::bail!("ClearKvCache: {}", worker_errors.join("; "));
         }
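The "always drain, then combine" shape of `generate_step` can be sketched without the real pool. Everything below is a hypothetical stand-in: the bodies of `drain_workers` and `combine_leader_workers` are not shown in this diff, so `drain_all` and `combine` only illustrate one plausible behavior, namely inspecting every worker's response without short-circuiting and reporting leader and worker failures together. The rank numbering assumes rank 0 is the leader.

```rust
#[derive(Debug)]
enum WorkerResponse {
    GenerateStepOk,
    KvCacheCleared,
    Error { kind: String, message: String },
}

/// Inspect every worker response (no early return on the first failure)
/// and collect one message per failing rank.
fn drain_all(responses: Vec<WorkerResponse>) -> Vec<String> {
    responses
        .into_iter()
        .enumerate()
        .filter_map(|(i, r)| {
            let rank = i + 1; // assumption: rank 0 is the leader
            match r {
                WorkerResponse::GenerateStepOk => None,
                WorkerResponse::Error { kind, message } => {
                    Some(format!("rank {rank} [{kind}]: {message}"))
                }
                other => Some(format!("rank {rank} expected GenerateStepOk, got {other:?}")),
            }
        })
        .collect()
}

/// Merge the leader's outcome with the worker errors into a single result.
fn combine(leader: Result<u32, String>, worker_errors: Vec<String>, op: &str) -> Result<u32, String> {
    match (leader, worker_errors.is_empty()) {
        (Ok(value), true) => Ok(value),
        (Ok(_), false) => Err(format!("{op}: {}", worker_errors.join("; "))),
        (Err(e), true) => Err(format!("{op}: leader: {e}")),
        (Err(e), false) => Err(format!("{op}: leader: {e}; {}", worker_errors.join("; "))),
    }
}
```

The point of doing the drain before the combine is ordering: by the time an error is returned to the caller, every pipe has already been read, so the next request starts from a clean state regardless of which side failed.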
@@ -344,15 +344,36 @@ impl WorkerState {
                 };
             }
         };
+        let start = std::time::Instant::now();
+        tracing::debug!(
+            rank = self.config.rank,
+            model = %model_id,
+            tokens = tokens.len(),
+            offset,
+            "worker GenerateStep: forward starting"
+        );
         // Drop the resulting logits — the leader uses its own copy from
         // rank 0. The forward's value here is the NCCL collectives it
         // issues, which let the leader's rank-0 forward make progress.
         if let Err(e) = model.forward(&input, offset) {
+            tracing::warn!(
+                rank = self.config.rank,
+                model = %model_id,
+                elapsed_ms = start.elapsed().as_millis(),
+                error = %e,
+                "worker GenerateStep: forward failed"
+            );
             return WorkerResponse::Error {
                 kind: "forward_failed".into(),
                 message: format!("TP forward: {e}"),
             };
         }
+        tracing::debug!(
+            rank = self.config.rank,
+            model = %model_id,
+            elapsed_ms = start.elapsed().as_millis(),
+            "worker GenerateStep: forward done"
+        );
         WorkerResponse::GenerateStepOk
     }
