feat(neuron,candle): req_id spans, terminal failure logs, pool-lock warnings
Every chat completion path (single-GPU + TP, streaming + non-streaming)
now opens an `info_span!` carrying `req_id=…` and `model=…` (named `chat`,
`chat_stream`, `tp_chat` or `tp_chat_stream` depending on the path). The fmt
subscriber prefixes every event with that span so `grep req_id=…` over
journalctl reconstructs one request even when dozens overlap.
Every path also emits a terminal log line on both success ("done", with
prompt_tokens/completion_tokens/finish_reason/total_ms where the path
tracks them — the single-GPU streaming path logs only prompt_tokens and
total_ms) and failure ("failed", with full anyhow chain + total_ms).
Failures used to vanish silently — a request that hit a CUDA OOM left
"starting" in the journal and no further trace.
New `acquire_pool_lock` helper replaces the bare `tp.pool.lock().await`
in both TP paths. It warns at 2s ("still waiting on pool lock") and
re-warns every 2s thereafter, so queued requests stuck behind a
deadlocked holder are visible immediately instead of looking like idle
silence.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -29,9 +29,12 @@ use serde_json::json;
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use std::time::Duration;
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tokio::sync::{Mutex, RwLock, mpsc};
|
use tokio::sync::{Mutex, RwLock, mpsc};
|
||||||
|
use tracing::Instrument;
|
||||||
|
|
||||||
/// In-process candle harness. Owns the loaded model registry.
|
/// In-process candle harness. Owns the loaded model registry.
|
||||||
pub struct CandleHarness {
|
pub struct CandleHarness {
|
||||||
@@ -351,6 +354,64 @@ fn resolve_hf_cache(explicit: Option<PathBuf>) -> Option<PathBuf> {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A short hex tag used to group every log line emitted on behalf of
|
||||||
|
/// one chat-completion request. Six hex digits is unique enough across
|
||||||
|
/// a 4-hour journal window (24 bits ≈ 16M values, while a busy neuron
|
||||||
|
/// sees ~10³ requests/hour) and fits cleanly inside `req_id=…` in the
|
||||||
|
/// fmt subscriber's span-prefix output.
|
||||||
|
fn new_req_id() -> String {
|
||||||
|
format!("{:06x}", unix_subsec_nanos() & 0xFFFFFF)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Threshold above which `pool.lock().await` blocking is interesting
|
||||||
|
/// enough to warn about. Healthy concurrent requests serialise behind
|
||||||
|
/// the pool in single-digit ms — anything past 2 seconds is either a
|
||||||
|
/// huge in-flight prompt or, more often, a stuck request holding the
|
||||||
|
/// lock against a poisoned CUDA context. See the 2026-05-26 4-hour
|
||||||
|
/// silence on beast where dozens of requests piled up invisibly here.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
const POOL_LOCK_WARN_THRESHOLD: Duration = Duration::from_secs(2);
|
||||||
|
|
||||||
|
/// Acquire the TP pool lock, emitting a warn-level breadcrumb if the
|
||||||
|
/// wait exceeds [`POOL_LOCK_WARN_THRESHOLD`]. Wrapped in a helper so
|
||||||
|
/// the warn happens at the call site — the request whose lock-wait is
|
||||||
|
/// slow is the one that knows its prompt_len and other context.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
async fn acquire_pool_lock(
|
||||||
|
pool: &tokio::sync::Mutex<super::tp::WorkerPool>,
|
||||||
|
model_id: &str,
|
||||||
|
) -> tokio::sync::MutexGuard<'_, super::tp::WorkerPool> {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
// Tick every threshold interval so a stuck request shows up in
|
||||||
|
// journalctl even while it's still waiting. Without this the wait
|
||||||
|
// looks like silence in the log right up until the lock is freed.
|
||||||
|
tokio::pin! {
|
||||||
|
let lock = pool.lock();
|
||||||
|
}
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
guard = &mut lock => {
|
||||||
|
let elapsed = start.elapsed();
|
||||||
|
if elapsed >= POOL_LOCK_WARN_THRESHOLD {
|
||||||
|
tracing::warn!(
|
||||||
|
model = %model_id,
|
||||||
|
waited_ms = elapsed.as_millis(),
|
||||||
|
"TP chat_completion: pool lock acquired after long wait"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return guard;
|
||||||
|
}
|
||||||
|
_ = tokio::time::sleep(POOL_LOCK_WARN_THRESHOLD) => {
|
||||||
|
tracing::warn!(
|
||||||
|
model = %model_id,
|
||||||
|
waited_ms = start.elapsed().as_millis(),
|
||||||
|
"TP chat_completion: still waiting on pool lock"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Apply the repetition penalty (if any) to the prediction logits and
|
/// Apply the repetition penalty (if any) to the prediction logits and
|
||||||
/// then sample. Centralises the prefill / generation-loop call sites
|
/// then sample. Centralises the prefill / generation-loop call sites
|
||||||
/// so they share identical sampling behaviour.
|
/// so they share identical sampling behaviour.
|
||||||
@@ -746,76 +807,119 @@ impl CandleHarness {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let prompt = format_qwen3_prompt(&request.messages);
|
// Span every line of this request with a short req_id +
|
||||||
|
// model so `grep req_id=…` over the journal can reconstruct
|
||||||
let encoding = loaded
|
// one request even when dozens overlap. Add a terminal log
|
||||||
.tokenizer
|
// line on both success and failure — the single-GPU path
|
||||||
.encode(prompt.as_str(), true)
|
// used to log nothing on either side, so a failing request
|
||||||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
|
// looked exactly like an idle neuron.
|
||||||
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
|
let req_id = new_req_id();
|
||||||
let prompt_len = prompt_tokens.len();
|
|
||||||
|
|
||||||
let temperature = request.temperature.unwrap_or(0.7);
|
|
||||||
let top_p = request.top_p;
|
|
||||||
let max_new = request.max_tokens.unwrap_or(512) as usize;
|
|
||||||
let seed = unix_subsec_nanos();
|
|
||||||
|
|
||||||
let eos_id = loaded
|
|
||||||
.tokenizer
|
|
||||||
.token_to_id("<|im_end|>")
|
|
||||||
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
|
|
||||||
|
|
||||||
let arch_arc = Arc::clone(&loaded.arch);
|
|
||||||
let device = loaded.device.clone();
|
|
||||||
let model_id = request.model.clone();
|
let model_id = request.model.clone();
|
||||||
|
let span = tracing::info_span!("chat", req_id = %req_id, model = %model_id);
|
||||||
|
let req_start = std::time::Instant::now();
|
||||||
|
|
||||||
let (generated_ids, finish_reason) =
|
let result = async {
|
||||||
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
|
let prompt = format_qwen3_prompt(&request.messages);
|
||||||
let mut guard = arch_arc.blocking_lock();
|
|
||||||
run_inference(
|
|
||||||
&mut guard,
|
|
||||||
&device,
|
|
||||||
&prompt_tokens,
|
|
||||||
max_new,
|
|
||||||
temperature,
|
|
||||||
top_p,
|
|
||||||
seed,
|
|
||||||
eos_id,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("inference task panicked: {e}")))?
|
|
||||||
.map_err(InferenceError::Other)?;
|
|
||||||
|
|
||||||
let completion_text = loaded
|
let encoding = loaded
|
||||||
.tokenizer
|
.tokenizer
|
||||||
.decode(&generated_ids, true)
|
.encode(prompt.as_str(), true)
|
||||||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
|
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
|
||||||
|
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
|
||||||
|
let prompt_len = prompt_tokens.len();
|
||||||
|
|
||||||
let usage = Usage {
|
let temperature = request.temperature.unwrap_or(0.7);
|
||||||
prompt_tokens: prompt_len as u64,
|
let top_p = request.top_p;
|
||||||
completion_tokens: generated_ids.len() as u64,
|
let max_new = request.max_tokens.unwrap_or(512) as usize;
|
||||||
total_tokens: (prompt_len + generated_ids.len()) as u64,
|
let seed = unix_subsec_nanos();
|
||||||
};
|
|
||||||
|
|
||||||
Ok(ChatCompletionResponse {
|
let eos_id = loaded
|
||||||
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
|
.tokenizer
|
||||||
object: "chat.completion".into(),
|
.token_to_id("<|im_end|>")
|
||||||
created: unix_now_secs(),
|
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
|
||||||
model: model_id,
|
|
||||||
choices: vec![ChatCompletionChoice {
|
tracing::info!(
|
||||||
index: 0,
|
prompt_len,
|
||||||
message: ChatMessage {
|
max_new,
|
||||||
role: "assistant".into(),
|
temperature,
|
||||||
content: MessageContent::Text(completion_text),
|
?top_p,
|
||||||
|
?eos_id,
|
||||||
|
"chat_completion: starting"
|
||||||
|
);
|
||||||
|
|
||||||
|
let arch_arc = Arc::clone(&loaded.arch);
|
||||||
|
let device = loaded.device.clone();
|
||||||
|
|
||||||
|
let (generated_ids, finish_reason) =
|
||||||
|
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
|
||||||
|
let mut guard = arch_arc.blocking_lock();
|
||||||
|
run_inference(
|
||||||
|
&mut guard,
|
||||||
|
&device,
|
||||||
|
&prompt_tokens,
|
||||||
|
max_new,
|
||||||
|
temperature,
|
||||||
|
top_p,
|
||||||
|
seed,
|
||||||
|
eos_id,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
InferenceError::Other(anyhow::anyhow!("inference task panicked: {e}"))
|
||||||
|
})?
|
||||||
|
.map_err(InferenceError::Other)?;
|
||||||
|
|
||||||
|
let completion_text = loaded
|
||||||
|
.tokenizer
|
||||||
|
.decode(&generated_ids, true)
|
||||||
|
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
|
||||||
|
|
||||||
|
let usage = Usage {
|
||||||
|
prompt_tokens: prompt_len as u64,
|
||||||
|
completion_tokens: generated_ids.len() as u64,
|
||||||
|
total_tokens: (prompt_len + generated_ids.len()) as u64,
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
prompt_tokens = prompt_len,
|
||||||
|
completion_tokens = generated_ids.len(),
|
||||||
|
finish_reason = %finish_reason,
|
||||||
|
total_ms = req_start.elapsed().as_millis(),
|
||||||
|
"chat_completion: done"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok::<_, InferenceError>(ChatCompletionResponse {
|
||||||
|
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
|
||||||
|
object: "chat.completion".into(),
|
||||||
|
created: unix_now_secs(),
|
||||||
|
model: request.model.clone(),
|
||||||
|
choices: vec![ChatCompletionChoice {
|
||||||
|
index: 0,
|
||||||
|
message: ChatMessage {
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: MessageContent::Text(completion_text),
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
},
|
||||||
|
finish_reason: Some(finish_reason),
|
||||||
extra: serde_json::Value::Object(Default::default()),
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
},
|
}],
|
||||||
finish_reason: Some(finish_reason),
|
usage: Some(usage),
|
||||||
extra: serde_json::Value::Object(Default::default()),
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
}],
|
})
|
||||||
usage: Some(usage),
|
}
|
||||||
extra: serde_json::Value::Object(Default::default()),
|
.instrument(span.clone())
|
||||||
})
|
.await;
|
||||||
|
|
||||||
|
if let Err(ref e) = result {
|
||||||
|
let _g = span.enter();
|
||||||
|
tracing::error!(
|
||||||
|
error = %format!("{e:#}"),
|
||||||
|
total_ms = req_start.elapsed().as_millis(),
|
||||||
|
"chat_completion: failed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run a streaming chat completion against a loaded model.
|
/// Run a streaming chat completion against a loaded model.
|
||||||
@@ -903,9 +1007,30 @@ impl CandleHarness {
|
|||||||
.await
|
.await
|
||||||
.map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?;
|
.map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?;
|
||||||
|
|
||||||
|
// Span context — spawn_blocking detaches from the async
|
||||||
|
// executor so we capture the span explicitly and re-enter it
|
||||||
|
// inside the closure to keep the req_id on every emitted line.
|
||||||
|
let req_id = new_req_id();
|
||||||
|
let span = tracing::info_span!("chat_stream", req_id = %req_id, model = %model_id);
|
||||||
|
let prompt_len = prompt_tokens.len();
|
||||||
|
let req_start = std::time::Instant::now();
|
||||||
|
let span_for_starting = span.clone();
|
||||||
|
let span_for_task = span.clone();
|
||||||
|
{
|
||||||
|
let _g = span_for_starting.enter();
|
||||||
|
tracing::info!(
|
||||||
|
prompt_len,
|
||||||
|
max_new,
|
||||||
|
temperature,
|
||||||
|
?top_p,
|
||||||
|
?eos_id,
|
||||||
|
"chat_completion (stream): starting"
|
||||||
|
);
|
||||||
|
}
|
||||||
tokio::task::spawn_blocking(move || {
|
tokio::task::spawn_blocking(move || {
|
||||||
|
let _g = span_for_task.enter();
|
||||||
let mut guard = arch_arc.blocking_lock();
|
let mut guard = arch_arc.blocking_lock();
|
||||||
if let Err(e) = run_inference_streaming(
|
match run_inference_streaming(
|
||||||
&mut guard,
|
&mut guard,
|
||||||
&device,
|
&device,
|
||||||
&tokenizer,
|
&tokenizer,
|
||||||
@@ -920,7 +1045,17 @@ impl CandleHarness {
|
|||||||
&model_id,
|
&model_id,
|
||||||
&tx,
|
&tx,
|
||||||
) {
|
) {
|
||||||
tracing::warn!(model = %model_id, error = %e, "streaming inference failed");
|
Ok(()) => tracing::info!(
|
||||||
|
prompt_tokens = prompt_len,
|
||||||
|
total_ms = req_start.elapsed().as_millis(),
|
||||||
|
"chat_completion (stream): done"
|
||||||
|
),
|
||||||
|
Err(e) => tracing::error!(
|
||||||
|
error = %format!("{e:#}"),
|
||||||
|
prompt_tokens = prompt_len,
|
||||||
|
total_ms = req_start.elapsed().as_millis(),
|
||||||
|
"chat_completion (stream): failed"
|
||||||
|
),
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -1187,13 +1322,32 @@ impl CandleHarness {
|
|||||||
tp: Arc<TpLoadedModel>,
|
tp: Arc<TpLoadedModel>,
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> Result<ChatCompletionResponse, InferenceError> {
|
) -> Result<ChatCompletionResponse, InferenceError> {
|
||||||
let handle = tokio::spawn(chat_completion_tp_inner(tp, request));
|
// Tag every line of this request with a short req_id so a
|
||||||
match handle.await {
|
// grep over journalctl reconstructs one request even when
|
||||||
Ok(result) => result,
|
// dozens are queued and interleaved. The span prefix is added
|
||||||
|
// by the fmt subscriber to every event emitted within the
|
||||||
|
// instrumented future, including events from `WorkerPool::*`
|
||||||
|
// since those run on the leader's task.
|
||||||
|
let req_id = new_req_id();
|
||||||
|
let model_id = request.model.clone();
|
||||||
|
let span = tracing::info_span!("tp_chat", req_id = %req_id, model = %model_id);
|
||||||
|
let req_start = std::time::Instant::now();
|
||||||
|
let handle = tokio::spawn(chat_completion_tp_inner(tp, request).instrument(span.clone()));
|
||||||
|
let result = match handle.await {
|
||||||
|
Ok(r) => r,
|
||||||
Err(join_err) => Err(InferenceError::Other(anyhow::anyhow!(
|
Err(join_err) => Err(InferenceError::Other(anyhow::anyhow!(
|
||||||
"TP inference task panicked or was cancelled: {join_err}"
|
"TP inference task panicked or was cancelled: {join_err}"
|
||||||
))),
|
))),
|
||||||
|
};
|
||||||
|
if let Err(ref e) = result {
|
||||||
|
let _g = span.enter();
|
||||||
|
tracing::error!(
|
||||||
|
error = %format!("{e:#}"),
|
||||||
|
total_ms = req_start.elapsed().as_millis(),
|
||||||
|
"TP chat_completion: failed"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Streaming counterpart to `chat_completion_tp`. Same per-step
|
/// Streaming counterpart to `chat_completion_tp`. Same per-step
|
||||||
@@ -1263,143 +1417,189 @@ impl CandleHarness {
|
|||||||
// The orchestration task. Holds the pool lock for the lifetime
|
// The orchestration task. Holds the pool lock for the lifetime
|
||||||
// of this inference; concurrent requests against the same TP
|
// of this inference; concurrent requests against the same TP
|
||||||
// model serialise behind it.
|
// model serialise behind it.
|
||||||
|
//
|
||||||
|
// Tagged with the same req_id span as the non-streaming path
|
||||||
|
// so the journal can be reconstructed regardless of which API
|
||||||
|
// surface the client hit.
|
||||||
|
let req_id = new_req_id();
|
||||||
|
let span = tracing::info_span!(
|
||||||
|
"tp_chat_stream",
|
||||||
|
req_id = %req_id,
|
||||||
|
model = %model_id
|
||||||
|
);
|
||||||
|
let req_start = std::time::Instant::now();
|
||||||
|
tracing::info!(
|
||||||
|
parent: &span,
|
||||||
|
prompt_len,
|
||||||
|
max_new,
|
||||||
|
temperature,
|
||||||
|
?top_p,
|
||||||
|
?eos_id,
|
||||||
|
"TP chat_completion (stream): starting"
|
||||||
|
);
|
||||||
let tp_for_task = Arc::clone(&tp);
|
let tp_for_task = Arc::clone(&tp);
|
||||||
tokio::spawn(async move {
|
tokio::spawn(
|
||||||
let mut pool = tp_for_task.pool.lock().await;
|
async move {
|
||||||
let leader_arc = tp_for_task.leader_model.clone();
|
let mut failure: Option<String> = None;
|
||||||
|
let mut pool = acquire_pool_lock(&tp_for_task.pool, &model_id).await;
|
||||||
|
let leader_arc = tp_for_task.leader_model.clone();
|
||||||
|
|
||||||
if let Err(e) = pool.clear_kv_cache(&model_id, leader_arc.clone()).await {
|
let mut all_tokens: Vec<u32> = Vec::new();
|
||||||
tracing::warn!(model = %model_id, error = %e, "TP stream: clear_kv_cache failed");
|
let mut decoded_prefix = String::new();
|
||||||
return;
|
let mut finish_reason = "length".to_string();
|
||||||
}
|
|
||||||
|
|
||||||
let mut logits_processor = {
|
'work: {
|
||||||
let sampling = if temperature <= 0.0 {
|
if let Err(e) = pool.clear_kv_cache(&model_id, leader_arc.clone()).await {
|
||||||
Sampling::ArgMax
|
failure = Some(format!("clear_kv_cache: {e:#}"));
|
||||||
} else {
|
break 'work;
|
||||||
match top_p {
|
|
||||||
Some(p) => Sampling::TopP { p, temperature },
|
|
||||||
None => Sampling::All { temperature },
|
|
||||||
}
|
}
|
||||||
};
|
|
||||||
LogitsProcessor::from_sampling(seed, sampling)
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut all_tokens: Vec<u32> = Vec::new();
|
let mut logits_processor = {
|
||||||
let mut decoded_prefix = String::new();
|
let sampling = if temperature <= 0.0 {
|
||||||
let mut finish_reason = "length".to_string();
|
Sampling::ArgMax
|
||||||
|
} else {
|
||||||
|
match top_p {
|
||||||
|
Some(p) => Sampling::TopP { p, temperature },
|
||||||
|
None => Sampling::All { temperature },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
LogitsProcessor::from_sampling(seed, sampling)
|
||||||
|
};
|
||||||
|
|
||||||
// Prefill — every rank embeds the prompt, offset = 0.
|
// Prefill — every rank embeds the prompt, offset = 0.
|
||||||
let logits = match pool
|
|
||||||
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(l) => l,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::warn!(model = %model_id, error = %e, "TP stream: prefill failed");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let mut next_token = match sample_with_penalty(
|
|
||||||
&logits,
|
|
||||||
&all_tokens,
|
|
||||||
&mut logits_processor,
|
|
||||||
) {
|
|
||||||
Ok(t) => t,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::warn!(model = %model_id, error = %e, "TP stream: prefill sample failed");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if Some(next_token) == eos_id {
|
|
||||||
finish_reason = "stop".into();
|
|
||||||
} else {
|
|
||||||
all_tokens.push(next_token);
|
|
||||||
if !emit_chunk(
|
|
||||||
&all_tokens,
|
|
||||||
&mut decoded_prefix,
|
|
||||||
&tokenizer,
|
|
||||||
&tx,
|
|
||||||
&id,
|
|
||||||
created,
|
|
||||||
&model_id,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for index in 0..max_new.saturating_sub(1) {
|
|
||||||
let logits = match pool
|
let logits = match pool
|
||||||
.generate_step(
|
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
|
||||||
&model_id,
|
|
||||||
leader_arc.clone(),
|
|
||||||
vec![next_token],
|
|
||||||
prompt_len + index,
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(l) => l,
|
Ok(l) => l,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(
|
failure = Some(format!("prefill: {e:#}"));
|
||||||
model = %model_id,
|
break 'work;
|
||||||
error = %e,
|
|
||||||
"TP stream: decode step failed"
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
next_token =
|
let mut next_token =
|
||||||
match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
|
match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
|
||||||
Ok(t) => t,
|
Ok(t) => t,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(
|
failure = Some(format!("prefill sample: {e:#}"));
|
||||||
model = %model_id,
|
break 'work;
|
||||||
error = %e,
|
|
||||||
"TP stream: decode sample failed"
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if Some(next_token) == eos_id {
|
if Some(next_token) == eos_id {
|
||||||
finish_reason = "stop".into();
|
finish_reason = "stop".into();
|
||||||
break;
|
} else {
|
||||||
}
|
all_tokens.push(next_token);
|
||||||
all_tokens.push(next_token);
|
if !emit_chunk(
|
||||||
if !emit_chunk(
|
&all_tokens,
|
||||||
&all_tokens,
|
&mut decoded_prefix,
|
||||||
&mut decoded_prefix,
|
&tokenizer,
|
||||||
&tokenizer,
|
&tx,
|
||||||
&tx,
|
&id,
|
||||||
&id,
|
created,
|
||||||
created,
|
&model_id,
|
||||||
&model_id,
|
)
|
||||||
)
|
.await
|
||||||
.await
|
{
|
||||||
{
|
// Client gone — treat as normal stream end,
|
||||||
return;
|
// not a failure. No log spam.
|
||||||
|
break 'work;
|
||||||
|
}
|
||||||
|
|
||||||
|
for index in 0..max_new.saturating_sub(1) {
|
||||||
|
let logits = match pool
|
||||||
|
.generate_step(
|
||||||
|
&model_id,
|
||||||
|
leader_arc.clone(),
|
||||||
|
vec![next_token],
|
||||||
|
prompt_len + index,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(l) => l,
|
||||||
|
Err(e) => {
|
||||||
|
failure = Some(format!("decode step {index}: {e:#}"));
|
||||||
|
break 'work;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
next_token = match sample_with_penalty(
|
||||||
|
&logits,
|
||||||
|
&all_tokens,
|
||||||
|
&mut logits_processor,
|
||||||
|
) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
failure = Some(format!("decode sample {index}: {e:#}"));
|
||||||
|
break 'work;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if Some(next_token) == eos_id {
|
||||||
|
finish_reason = "stop".into();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
all_tokens.push(next_token);
|
||||||
|
if !emit_chunk(
|
||||||
|
&all_tokens,
|
||||||
|
&mut decoded_prefix,
|
||||||
|
&tokenizer,
|
||||||
|
&tx,
|
||||||
|
&id,
|
||||||
|
created,
|
||||||
|
&model_id,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
break 'work;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Final chunk carrying finish_reason.
|
// One terminal line per request, success or failure. The
|
||||||
let final_chunk = ChatCompletionChunk {
|
// success branch was previously implicit (the SSE final
|
||||||
id: id.clone(),
|
// chunk went out and the spawned task just ended); now
|
||||||
object: "chat.completion.chunk".into(),
|
// there's always a log line for the operator.
|
||||||
created,
|
if let Some(err) = &failure {
|
||||||
model: model_id.clone(),
|
tracing::error!(
|
||||||
choices: vec![ChunkChoice {
|
error = %err,
|
||||||
index: 0,
|
completion_tokens = all_tokens.len(),
|
||||||
delta: serde_json::Value::Object(Default::default()),
|
total_ms = req_start.elapsed().as_millis(),
|
||||||
finish_reason: Some(finish_reason),
|
"TP chat_completion (stream): failed"
|
||||||
extra: serde_json::Value::Object(Default::default()),
|
);
|
||||||
}],
|
} else {
|
||||||
usage: None,
|
tracing::info!(
|
||||||
extra: serde_json::Value::Object(Default::default()),
|
prompt_tokens = prompt_len,
|
||||||
};
|
completion_tokens = all_tokens.len(),
|
||||||
let _ = tx.send(final_chunk).await;
|
finish_reason = %finish_reason,
|
||||||
});
|
total_ms = req_start.elapsed().as_millis(),
|
||||||
|
"TP chat_completion (stream): done"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final chunk carrying finish_reason — only on the success
|
||||||
|
// path. On failure we drop the channel so the client sees
|
||||||
|
// the SSE stream end abruptly (matches pre-change behaviour
|
||||||
|
// when the failed-path early-returned without final chunk).
|
||||||
|
if failure.is_none() {
|
||||||
|
let final_chunk = ChatCompletionChunk {
|
||||||
|
id: id.clone(),
|
||||||
|
object: "chat.completion.chunk".into(),
|
||||||
|
created,
|
||||||
|
model: model_id.clone(),
|
||||||
|
choices: vec![ChunkChoice {
|
||||||
|
index: 0,
|
||||||
|
delta: serde_json::Value::Object(Default::default()),
|
||||||
|
finish_reason: Some(finish_reason),
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
};
|
||||||
|
let _ = tx.send(final_chunk).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.instrument(span),
|
||||||
|
);
|
||||||
|
|
||||||
Ok(rx)
|
Ok(rx)
|
||||||
}
|
}
|
||||||
@@ -1455,13 +1655,10 @@ async fn chat_completion_tp_inner(
|
|||||||
// leader_model's own Mutex is acquired step-by-step inside
|
// leader_model's own Mutex is acquired step-by-step inside
|
||||||
// pool.generate_step (so spawn_blocking can grab it without
|
// pool.generate_step (so spawn_blocking can grab it without
|
||||||
// holding the pool lock across the blocking_lock call).
|
// holding the pool lock across the blocking_lock call).
|
||||||
let lock_start = std::time::Instant::now();
|
// `acquire_pool_lock` warns periodically while we wait so a
|
||||||
let mut pool = tp.pool.lock().await;
|
// stuck holder doesn't make the queueing requests look like
|
||||||
tracing::debug!(
|
// silence in the journal.
|
||||||
model = %model_id,
|
let mut pool = acquire_pool_lock(&tp.pool, &model_id).await;
|
||||||
elapsed_ms = lock_start.elapsed().as_millis(),
|
|
||||||
"TP chat_completion: pool lock acquired"
|
|
||||||
);
|
|
||||||
let leader_arc = tp.leader_model.clone();
|
let leader_arc = tp.leader_model.clone();
|
||||||
|
|
||||||
// Reset every rank's KV cache so this request doesn't attend
|
// Reset every rank's KV cache so this request doesn't attend
|
||||||
|
|||||||
Reference in New Issue
Block a user