diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index c5f75bf..f6eaf55 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -3537,100 +3537,103 @@ async fn stream_inference_via_worker( } }; - // Inlined per-token routing — parallel to the TP path. Macro - // approach used in the CPU path doesn't translate cleanly - // here because the emit is async (.await) and macros don't - // tolerate `.await` inside reused expansions across two - // call sites well. - async fn route_token( - next_token: u32, - all_tokens: &mut Vec, - in_reasoning: &mut bool, - in_tool_call: &mut bool, - tool_call_buf: &mut String, - tool_call_idx: &mut usize, - reasoning_tokens: Option<&ReasoningTokenPair>, - tool_call_tokens: Option<&ToolCallTokenPair>, - decode_stream: &mut tokenizers::DecodeStream<'_>, - tx: &mpsc::Sender, - ) -> bool { - all_tokens.push(next_token); - match handle_tool_call_marker(next_token, tool_call_tokens, in_tool_call, tool_call_buf) { - ToolCallMarker::Enter => return true, - ToolCallMarker::Exit { buffer } => { - let idx = *tool_call_idx; - *tool_call_idx += 1; - match parse_tool_call_body(&buffer, idx) { - Some((id, name, arguments)) => { - if tx - .send(InferenceEvent::ToolCall { - index: idx, - id, - name, - arguments, - }) - .await - .is_err() - { - return false; + // Per-token routing. `tokenizers::DecodeStream` carries five + // generic parameters (`M, N, PT, PP, D`) which makes naming + // its type from a helper signature painful. Use a macro + // instead — the body expands inline with `decode_stream`'s + // concrete type inferred from the call site. The macro + // contains `.await` calls, so it can only expand inside an + // `async` context (which both call sites below are). 
+ // + // The macro takes a single `$next_token` expression and + // evaluates to the `consumer_alive` flag: `true` unless + // `tx.send(..)` or `emit_delta(..)` reported failure for this + // token. Early exits inside the expansion `break 'route` out of + // a labelled block instead of using `return`, because a + // `return` in the expansion would leave the enclosing function + // rather than just the macro body. Each call site tests the + // result with `if !route_token!(..)` and returns on `false`. + macro_rules! route_token { + ($next_token:expr) => {{ + let nt = $next_token; + all_tokens.push(nt); + let mut consumer_alive = true; + 'route: { + match handle_tool_call_marker( + nt, + tool_call_tokens.as_ref(), + &mut in_tool_call, + &mut tool_call_buf, + ) { + ToolCallMarker::Enter => break 'route, + ToolCallMarker::Exit { buffer } => { + let idx = tool_call_idx; + tool_call_idx += 1; + match parse_tool_call_body(&buffer, idx) { + Some((id, name, arguments)) => { + if tx + .send(InferenceEvent::ToolCall { + index: idx, + id, + name, + arguments, + }) + .await + .is_err() + { + consumer_alive = false; + } + } + None => { + let open = tool_call_tokens + .as_ref() + .map(|p| p.open_text.as_str()) + .unwrap_or(""); + let close = tool_call_tokens + .as_ref() + .map(|p| p.close_text.as_str()) + .unwrap_or(""); + let raw = format!("{open}{buffer}{close}"); + if !emit_delta(&raw, &tx, in_reasoning).await { + consumer_alive = false; + } + } + } + break 'route; + } + ToolCallMarker::None => {} + } + if in_tool_call { + match decode_stream.step(nt) { + Ok(Some(s)) => tool_call_buf.push_str(&s), + Ok(None) => {} + Err(e) => tracing::warn!( + error = %e, + "decode_stream step failed (in tool_call)" + ), + } + break 'route; + } + if handle_reasoning_marker(nt, reasoning_tokens.as_ref(), &mut in_reasoning) { + break 'route; + } + match decode_stream.step(nt) { + Ok(Some(delta)) => { + if !emit_delta(&delta, &tx, in_reasoning).await { + consumer_alive = false; } } - None => { - let open = tool_call_tokens - .map(|p| 
p.open_text.as_str()) - .unwrap_or(""); - let close = tool_call_tokens - .map(|p| p.close_text.as_str()) - .unwrap_or(""); - let raw = format!("{open}{buffer}{close}"); - if !emit_delta(&raw, tx, *in_reasoning).await { - return false; - } - } - } - return true; - } - ToolCallMarker::None => {} - } - if *in_tool_call { - match decode_stream.step(next_token) { - Ok(Some(s)) => tool_call_buf.push_str(&s), - Ok(None) => {} - Err(e) => tracing::warn!(error = %e, "decode_stream step failed (in tool_call)"), - } - return true; - } - if handle_reasoning_marker(next_token, reasoning_tokens, in_reasoning) { - return true; - } - match decode_stream.step(next_token) { - Ok(Some(delta)) => { - if !emit_delta(&delta, tx, *in_reasoning).await { - return false; + Ok(None) => {} + Err(e) => tracing::warn!(error = %e, "decode_stream step failed"), } } - Ok(None) => {} - Err(e) => tracing::warn!(error = %e, "decode_stream step failed"), - } - true + consumer_alive + }}; } if Some(next_token) == eos_id { finish_reason = FinishReason::Stop; - } else if !route_token( - next_token, - &mut all_tokens, - &mut in_reasoning, - &mut in_tool_call, - &mut tool_call_buf, - &mut tool_call_idx, - reasoning_tokens.as_ref(), - tool_call_tokens.as_ref(), - &mut decode_stream, - &tx, - ) - .await - { + } else if !route_token!(next_token) { return Ok(finish_reason.as_openai_str().to_string()); } @@ -3656,20 +3659,7 @@ async fn stream_inference_via_worker( finish_reason = FinishReason::Stop; break; } - if !route_token( - next_token, - &mut all_tokens, - &mut in_reasoning, - &mut in_tool_call, - &mut tool_call_buf, - &mut tool_call_idx, - reasoning_tokens.as_ref(), - tool_call_tokens.as_ref(), - &mut decode_stream, - &tx, - ) - .await - { + if !route_token!(next_token) { return Ok(finish_reason.as_openai_str().to_string()); } }