diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 75323a4..24dbec6 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -1167,127 +1167,32 @@ impl CandleHarness { Ok(()) } - /// Non-streaming chat completion against a TP model. Pattern mirrors - /// the single-GPU `run_inference`: tokenize, prefill, sample, decode - /// loop, detokenize. Each forward step fans out to every rank via - /// the WorkerPool and uses the leader's last-position logits to - /// sample. + /// Non-streaming chat completion against a TP model. + /// + /// The actual work runs inside a `tokio::spawn`'d task so the HTTP + /// client disconnecting (curl timeout, browser nav-away, etc.) + /// can't cancel the future mid-`pool.generate_step` and leave the + /// worker subprocesses mid-RPC. If the caller's future is dropped, + /// the detached task still runs to completion and drains the pool — + /// the next inference request finds a clean pool. The HTTP layer + /// just gives up on the response. + /// + /// Every step also emits `info`/`debug` tracing so journalctl + /// shows where time went without needing to surface internals in + /// the HTTP error response. 
#[cfg(feature = "cuda")] async fn chat_completion_tp( &self, tp: Arc, request: ChatCompletionRequest, ) -> Result { - let prompt = format_qwen3_prompt(&request.messages); - let encoding = tp - .tokenizer - .encode(prompt.as_str(), true) - .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?; - let prompt_tokens: Vec = encoding.get_ids().to_vec(); - let prompt_len = prompt_tokens.len(); - - let temperature = request.temperature.unwrap_or(0.7); - let top_p = request.top_p; - let max_new = request.max_tokens.unwrap_or(512) as usize; - let seed = unix_subsec_nanos(); - - let eos_id = tp - .tokenizer - .token_to_id("<|im_end|>") - .or_else(|| tp.tokenizer.token_to_id("<|endoftext|>")); - - let model_id = request.model.clone(); - - // Acquire the pool lock for the duration of the request. The - // leader_model's own Mutex is acquired step-by-step inside - // pool.generate_step (so spawn_blocking can grab it without - // holding the pool lock across the blocking_lock call). - let mut pool = tp.pool.lock().await; - let leader_arc = tp.leader_model.clone(); - - // Reset every rank's KV cache so this request doesn't attend - // over the previous request's tokens. - pool.clear_kv_cache(&model_id, leader_arc.clone()) - .await - .map_err(InferenceError::Other)?; - - let mut logits_processor = { - let sampling = if temperature <= 0.0 { - Sampling::ArgMax - } else { - match top_p { - Some(p) => Sampling::TopP { p, temperature }, - None => Sampling::All { temperature }, - } - }; - LogitsProcessor::from_sampling(seed, sampling) - }; - - let mut generated: Vec = Vec::new(); - let mut finish_reason = "length".to_string(); - - // Prefill: every rank embeds the whole prompt, offset = 0. 
- let logits = pool - .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0) - .await - .map_err(InferenceError::Other)?; - let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor) - .map_err(InferenceError::Other)?; - - if Some(next_token) == eos_id { - finish_reason = "stop".into(); - } else { - generated.push(next_token); - for index in 0..max_new.saturating_sub(1) { - let logits = pool - .generate_step( - &model_id, - leader_arc.clone(), - vec![next_token], - prompt_len + index, - ) - .await - .map_err(InferenceError::Other)?; - next_token = sample_with_penalty(&logits, &generated, &mut logits_processor) - .map_err(InferenceError::Other)?; - if Some(next_token) == eos_id { - finish_reason = "stop".into(); - break; - } - generated.push(next_token); - } + let handle = tokio::spawn(chat_completion_tp_inner(tp, request)); + match handle.await { + Ok(result) => result, + Err(join_err) => Err(InferenceError::Other(anyhow::anyhow!( + "TP inference task panicked or was cancelled: {join_err}" + ))), } - drop(pool); - - let completion_text = tp - .tokenizer - .decode(&generated, true) - .map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?; - - let usage = Usage { - prompt_tokens: prompt_len as u64, - completion_tokens: generated.len() as u64, - total_tokens: (prompt_len + generated.len()) as u64, - }; - - Ok(ChatCompletionResponse { - id: format!("chatcmpl-{:x}", unix_subsec_nanos()), - object: "chat.completion".into(), - created: unix_now_secs(), - model: model_id, - choices: vec![ChatCompletionChoice { - index: 0, - message: ChatMessage { - role: "assistant".into(), - content: MessageContent::Text(completion_text), - extra: serde_json::Value::Object(Default::default()), - }, - finish_reason: Some(finish_reason), - extra: serde_json::Value::Object(Default::default()), - }], - usage: Some(usage), - extra: serde_json::Value::Object(Default::default()), - }) } /// Streaming counterpart to 
`chat_completion_tp`. Same per-step @@ -1499,6 +1404,187 @@ impl CandleHarness { } } +/// Body of the TP non-streaming chat completion, hoisted out of +/// `CandleHarness::chat_completion_tp` so it can run inside +/// `tokio::spawn` (which requires a `'static` future) and survive +/// HTTP-layer cancellation. +/// +/// Tracing strategy: `info` for request entry/exit and the prefill and +/// decode totals so journalctl always shows when an inference started +/// and finished; `debug` for pool-lock and KV-clear timing, `trace` for +/// each decode step, so an operator raising `RUST_LOG` sees where the +/// request spends its time without instrumenting the model code. +#[cfg(feature = "cuda")] +async fn chat_completion_tp_inner( + tp: Arc, + request: ChatCompletionRequest, +) -> Result { + let req_start = std::time::Instant::now(); + let model_id = request.model.clone(); + + let prompt = format_qwen3_prompt(&request.messages); + let encoding = tp + .tokenizer + .encode(prompt.as_str(), true) + .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?; + let prompt_tokens: Vec = encoding.get_ids().to_vec(); + let prompt_len = prompt_tokens.len(); + + let temperature = request.temperature.unwrap_or(0.7); + let top_p = request.top_p; + let max_new = request.max_tokens.unwrap_or(512) as usize; + let seed = unix_subsec_nanos(); + + let eos_id = tp + .tokenizer + .token_to_id("<|im_end|>") + .or_else(|| tp.tokenizer.token_to_id("<|endoftext|>")); + + tracing::info!( + model = %model_id, + prompt_len, + max_new, + temperature, + ?top_p, + ?eos_id, + "TP chat_completion: starting" + ); + + // Acquire the pool lock for the duration of the request. The + // leader_model's own Mutex is acquired step-by-step inside + // pool.generate_step (so spawn_blocking can grab it without + // holding the pool lock across the blocking_lock call). 
+ let lock_start = std::time::Instant::now(); + let mut pool = tp.pool.lock().await; + tracing::debug!( + model = %model_id, + elapsed_ms = lock_start.elapsed().as_millis(), + "TP chat_completion: pool lock acquired" + ); + let leader_arc = tp.leader_model.clone(); + + // Reset every rank's KV cache so this request doesn't attend + // over the previous request's tokens. + let clear_start = std::time::Instant::now(); + pool.clear_kv_cache(&model_id, leader_arc.clone()) + .await + .map_err(InferenceError::Other)?; + tracing::debug!( + model = %model_id, + elapsed_ms = clear_start.elapsed().as_millis(), + "TP chat_completion: kv cache cleared" + ); + + let mut logits_processor = { + let sampling = if temperature <= 0.0 { + Sampling::ArgMax + } else { + match top_p { + Some(p) => Sampling::TopP { p, temperature }, + None => Sampling::All { temperature }, + } + }; + LogitsProcessor::from_sampling(seed, sampling) + }; + + let mut generated: Vec = Vec::new(); + let mut finish_reason = "length".to_string(); + + // Prefill: every rank embeds the whole prompt, offset = 0. 
+ let prefill_start = std::time::Instant::now(); + let logits = pool + .generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0) + .await + .map_err(InferenceError::Other)?; + tracing::info!( + model = %model_id, + prompt_len, + elapsed_ms = prefill_start.elapsed().as_millis(), + "TP chat_completion: prefill complete" + ); + let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor) + .map_err(InferenceError::Other)?; + + if Some(next_token) == eos_id { + finish_reason = "stop".into(); + } else { + generated.push(next_token); + let decode_start = std::time::Instant::now(); + for index in 0..max_new.saturating_sub(1) { + let step_start = std::time::Instant::now(); + let logits = pool + .generate_step( + &model_id, + leader_arc.clone(), + vec![next_token], + prompt_len + index, + ) + .await + .map_err(InferenceError::Other)?; + next_token = sample_with_penalty(&logits, &generated, &mut logits_processor) + .map_err(InferenceError::Other)?; + tracing::trace!( + model = %model_id, + step = index, + next_token, + step_ms = step_start.elapsed().as_millis(), + "TP chat_completion: decode step" + ); + if Some(next_token) == eos_id { + finish_reason = "stop".into(); + break; + } + generated.push(next_token); + } + tracing::info!( + model = %model_id, + generated = generated.len(), + elapsed_ms = decode_start.elapsed().as_millis(), + "TP chat_completion: decode complete" + ); + } + drop(pool); + + let completion_text = tp + .tokenizer + .decode(&generated, true) + .map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?; + + let usage = Usage { + prompt_tokens: prompt_len as u64, + completion_tokens: generated.len() as u64, + total_tokens: (prompt_len + generated.len()) as u64, + }; + + tracing::info!( + model = %model_id, + prompt_tokens = prompt_len, + completion_tokens = generated.len(), + finish_reason = %finish_reason, + total_ms = req_start.elapsed().as_millis(), + "TP chat_completion: done" + ); + + 
Ok(ChatCompletionResponse { + id: format!("chatcmpl-{:x}", unix_subsec_nanos()), + object: "chat.completion".into(), + created: unix_now_secs(), + model: model_id, + choices: vec![ChatCompletionChoice { + index: 0, + message: ChatMessage { + role: "assistant".into(), + content: MessageContent::Text(completion_text), + extra: serde_json::Value::Object(Default::default()), + }, + finish_reason: Some(finish_reason), + extra: serde_json::Value::Object(Default::default()), + }], + usage: Some(usage), + extra: serde_json::Value::Object(Default::default()), + }) +} + /// Decode the cumulative token list, emit the delta (substring appended /// since the last chunk) as a `chat.completion.chunk`. Returns `false` /// if the receiver has hung up — the caller should bail. diff --git a/crates/neuron/src/harness/tp/mod.rs b/crates/neuron/src/harness/tp/mod.rs index dae6fb9..a53212f 100644 --- a/crates/neuron/src/harness/tp/mod.rs +++ b/crates/neuron/src/harness/tp/mod.rs @@ -604,6 +604,14 @@ impl WorkerPool { tokens: Vec, offset: usize, ) -> Result { + let step_start = std::time::Instant::now(); + let tokens_len = tokens.len(); + tracing::debug!( + model = %model_id, + tokens = tokens_len, + offset, + "WorkerPool::generate_step: fan-out" + ); // 1. Fan-out to workers. for w in &mut self.workers { w.send_only(&WorkerRequest::GenerateStep { @@ -617,6 +625,7 @@ impl WorkerPool { // 2. Leader's forward in spawn_blocking. The AllReduce CustomOps // inside the row-parallel layers block until every worker's // forward issues the matching collective. 
+ let leader_start = std::time::Instant::now(); let leader_result = tokio::task::spawn_blocking(move || -> Result { let mut model = leader_model.blocking_lock(); let device = model.device().clone(); @@ -628,6 +637,14 @@ impl WorkerPool { }) .await .context("leader forward task panicked"); + let leader_ok = matches!(leader_result, Ok(Ok(_))); + tracing::debug!( + model = %model_id, + tokens = tokens_len, + leader_ms = leader_start.elapsed().as_millis(), + leader_ok, + "WorkerPool::generate_step: leader forward returned" + ); // 3. ALWAYS drain worker responses, regardless of whether the // leader succeeded. Skipping this on the leader's error @@ -635,12 +652,20 @@ impl WorkerPool { // pipes that poison the NEXT request's recv (was seeing // "ClearKvCache: expected KvCacheCleared, got // GenerateStepOk" the call after any forward-time failure). + let drain_start = std::time::Instant::now(); let worker_errors = drain_workers(&mut self.workers, |r| match r { WorkerResponse::GenerateStepOk => Ok(()), WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")), other => Err(format!("expected GenerateStepOk, got {other:?}")), }) .await; + tracing::debug!( + model = %model_id, + drain_ms = drain_start.elapsed().as_millis(), + errors = worker_errors.len(), + total_ms = step_start.elapsed().as_millis(), + "WorkerPool::generate_step: workers drained" + ); combine_leader_workers(leader_result, worker_errors, "GenerateStep") } @@ -653,6 +678,8 @@ impl WorkerPool { model_id: &str, #[cfg(feature = "cuda")] leader_model: std::sync::Arc>, ) -> Result<()> { + let start = std::time::Instant::now(); + tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out"); for w in &mut self.workers { w.send_only(&WorkerRequest::ClearKvCache { model_id: model_id.to_string(), @@ -674,6 +701,12 @@ impl WorkerPool { other => Err(format!("expected KvCacheCleared, got {other:?}")), }) .await; + tracing::debug!( + model = %model_id, + elapsed_ms = 
start.elapsed().as_millis(), + errors = worker_errors.len(), + "WorkerPool::clear_kv_cache: workers drained" + ); if !worker_errors.is_empty() { anyhow::bail!("ClearKvCache: {}", worker_errors.join("; ")); } diff --git a/crates/neuron/src/harness/tp/worker.rs b/crates/neuron/src/harness/tp/worker.rs index 82267c5..2d3e444 100644 --- a/crates/neuron/src/harness/tp/worker.rs +++ b/crates/neuron/src/harness/tp/worker.rs @@ -344,15 +344,36 @@ impl WorkerState { }; } }; + let start = std::time::Instant::now(); + tracing::debug!( + rank = self.config.rank, + model = %model_id, + tokens = tokens.len(), + offset, + "worker GenerateStep: forward starting" + ); // Drop the resulting logits — the leader uses its own copy from // rank 0. The forward's value here is the NCCL collectives it // issues, which let the leader's rank-0 forward make progress. if let Err(e) = model.forward(&input, offset) { + tracing::warn!( + rank = self.config.rank, + model = %model_id, + elapsed_ms = start.elapsed().as_millis(), + error = %e, + "worker GenerateStep: forward failed" + ); return WorkerResponse::Error { kind: "forward_failed".into(), message: format!("TP forward: {e}"), }; } + tracing::debug!( + rank = self.config.rank, + model = %model_id, + elapsed_ms = start.elapsed().as_millis(), + "worker GenerateStep: forward done" + ); WorkerResponse::GenerateStepOk }