From 435fd109026224ed633f1f2af70a792004b6f201 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Mon, 1 Jun 2026 08:59:56 +0300 Subject: [PATCH] fix(neuron): macro-ify CUDA single-GPU route_token so DecodeStream type stays inferred MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Prerelease build (run 270) failed on commit cb30383 with: error[E0107]: struct takes 5 generic arguments but 0 generic arguments were supplied --> crates/neuron/src/harness/candle.rs:3554:41 | 3554 | decode_stream: &mut tokenizers::DecodeStream<'_>, | ^^^^^^^^^^^^ The Step-2-era refactor for #6's tool-call extraction added a nested `async fn route_token` inside `stream_inference_via_worker` that named `tokenizers::DecodeStream<'_>` as a parameter type. `DecodeStream` actually has five generic type parameters in addition to its lifetime (`'tok, M, N, PT, PP, D`), which makes naming it explicitly painful — the working approach the CPU path uses is a macro, where the body expands inline at the call site and the decoder type stays inferred. This commit replicates the CPU-side macro for the CUDA worker path. Same shape, just with `.await` calls inside (macros tolerate that since they expand inline into the enclosing async context). Control flow uses a labelled-block + `consumer_alive` flag rather than `return` so the macro stays generic over the surrounding return type. The CPU build (default-feature workspace, what `clippy` and `test` jobs exercise) doesn't compile this `#[cfg(feature = "cuda")]` branch, which is why local CI green-lit it. The cuda-check job should catch this category of breakage now that #cb30383+CI-fix landed; this commit just resolves the actual breakage on the prerelease workflow. 
Co-Authored-By: Claude Opus 4.7 --- crates/neuron/src/harness/candle.rs | 190 +++++++++++++--------------- 1 file changed, 90 insertions(+), 100 deletions(-) diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index c5f75bf..f6eaf55 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -3537,100 +3537,103 @@ async fn stream_inference_via_worker( } }; - // Inlined per-token routing — parallel to the TP path. Macro - // approach used in the CPU path doesn't translate cleanly - // here because the emit is async (.await) and macros don't - // tolerate `.await` inside reused expansions across two - // call sites well. - async fn route_token( - next_token: u32, - all_tokens: &mut Vec, - in_reasoning: &mut bool, - in_tool_call: &mut bool, - tool_call_buf: &mut String, - tool_call_idx: &mut usize, - reasoning_tokens: Option<&ReasoningTokenPair>, - tool_call_tokens: Option<&ToolCallTokenPair>, - decode_stream: &mut tokenizers::DecodeStream<'_>, - tx: &mpsc::Sender, - ) -> bool { - all_tokens.push(next_token); - match handle_tool_call_marker(next_token, tool_call_tokens, in_tool_call, tool_call_buf) { - ToolCallMarker::Enter => return true, - ToolCallMarker::Exit { buffer } => { - let idx = *tool_call_idx; - *tool_call_idx += 1; - match parse_tool_call_body(&buffer, idx) { - Some((id, name, arguments)) => { - if tx - .send(InferenceEvent::ToolCall { - index: idx, - id, - name, - arguments, - }) - .await - .is_err() - { - return false; + // Per-token routing. `tokenizers::DecodeStream` carries five + // generic parameters (`M, N, PT, PP, D`) which makes naming + // its type from a helper signature painful. Use a macro + // instead — the body expands inline with `decode_stream`'s + // concrete type inferred from the call site. The macro + // contains `.await` calls, so it can only expand inside an + // `async` context (which both call sites below are). 
+ // + // The macro takes a single `$next_token` expression and + // evaluates to a `bool` (`consumer_alive`): it turns `false` + // when `tx.send(..)` errors or `emit_delta` returns `false`. + // Early exits `break 'route` out of a labelled block instead + // of using `return`, since a `return` in the expansion would + // return from the enclosing function and tie the macro to + // that function's return type. Both call sites test the + // result as `if !route_token!(next_token) { return ... }`. + macro_rules! route_token { + ($next_token:expr) => {{ + let nt = $next_token; + all_tokens.push(nt); + let mut consumer_alive = true; + 'route: { + match handle_tool_call_marker( + nt, + tool_call_tokens.as_ref(), + &mut in_tool_call, + &mut tool_call_buf, + ) { + ToolCallMarker::Enter => break 'route, + ToolCallMarker::Exit { buffer } => { + let idx = tool_call_idx; + tool_call_idx += 1; + match parse_tool_call_body(&buffer, idx) { + Some((id, name, arguments)) => { + if tx + .send(InferenceEvent::ToolCall { + index: idx, + id, + name, + arguments, + }) + .await + .is_err() + { + consumer_alive = false; + } + } + None => { + let open = tool_call_tokens + .as_ref() + .map(|p| p.open_text.as_str()) + .unwrap_or(""); + let close = tool_call_tokens + .as_ref() + .map(|p| p.close_text.as_str()) + .unwrap_or(""); + let raw = format!("{open}{buffer}{close}"); + if !emit_delta(&raw, &tx, in_reasoning).await { + consumer_alive = false; + } + } + } + break 'route; + } + ToolCallMarker::None => {} + } + if in_tool_call { + match decode_stream.step(nt) { + Ok(Some(s)) => tool_call_buf.push_str(&s), + Ok(None) => {} + Err(e) => tracing::warn!( + error = %e, + "decode_stream step failed (in tool_call)" + ), + } + break 'route; + } + if handle_reasoning_marker(nt, reasoning_tokens.as_ref(), &mut in_reasoning) { + break 'route; + } + match decode_stream.step(nt) { + Ok(Some(delta)) => { + if !emit_delta(&delta, &tx, in_reasoning).await { + consumer_alive = false; + } } - None => { - let open = tool_call_tokens - .map(|p| 
p.open_text.as_str()) - .unwrap_or(""); - let close = tool_call_tokens - .map(|p| p.close_text.as_str()) - .unwrap_or(""); - let raw = format!("{open}{buffer}{close}"); - if !emit_delta(&raw, tx, *in_reasoning).await { - return false; - } - } - } - return true; - } - ToolCallMarker::None => {} - } - if *in_tool_call { - match decode_stream.step(next_token) { - Ok(Some(s)) => tool_call_buf.push_str(&s), - Ok(None) => {} - Err(e) => tracing::warn!(error = %e, "decode_stream step failed (in tool_call)"), - } - return true; - } - if handle_reasoning_marker(next_token, reasoning_tokens, in_reasoning) { - return true; - } - match decode_stream.step(next_token) { - Ok(Some(delta)) => { - if !emit_delta(&delta, tx, *in_reasoning).await { - return false; + Ok(None) => {} + Err(e) => tracing::warn!(error = %e, "decode_stream step failed"), } } - Ok(None) => {} - Err(e) => tracing::warn!(error = %e, "decode_stream step failed"), - } - true + consumer_alive + }}; } if Some(next_token) == eos_id { finish_reason = FinishReason::Stop; - } else if !route_token( - next_token, - &mut all_tokens, - &mut in_reasoning, - &mut in_tool_call, - &mut tool_call_buf, - &mut tool_call_idx, - reasoning_tokens.as_ref(), - tool_call_tokens.as_ref(), - &mut decode_stream, - &tx, - ) - .await - { + } else if !route_token!(next_token) { return Ok(finish_reason.as_openai_str().to_string()); } @@ -3656,20 +3659,7 @@ async fn stream_inference_via_worker( finish_reason = FinishReason::Stop; break; } - if !route_token( - next_token, - &mut all_tokens, - &mut in_reasoning, - &mut in_tool_call, - &mut tool_call_buf, - &mut tool_call_idx, - reasoning_tokens.as_ref(), - tool_call_tokens.as_ref(), - &mut decode_stream, - &tx, - ) - .await - { + if !route_token!(next_token) { return Ok(finish_reason.as_openai_str().to_string()); } }