diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 3d73c38..404a27c 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -27,8 +27,8 @@ use cortex_core::openai::{ }; use crate::wire::{ - FinishReason, InferenceEvent, ReasoningTokenPair, detect_reasoning_token_pair, - openai_chat as wire_chat, + FinishReason, InferenceEvent, ReasoningTokenPair, ToolCallTokenPair, + detect_reasoning_token_pair, detect_tool_call_token_pair, openai_chat as wire_chat, }; use std::collections::HashMap; use std::path::PathBuf; @@ -160,6 +160,13 @@ pub struct LoadedModel { /// [`InferenceEvent::ReasoningDelta`] at the token boundary; /// when `None` everything is `TextDelta`. pub reasoning_tokens: Option, + /// Open/close token IDs for the model's tool-call marker + /// pair (`` / `` on Qwen3-Coder / Hermes + /// / DeepSeek / gpt-oss). `None` for models that don't emit + /// structured tool calls in this convention; output passes + /// through as plain text in that case and the consumer parses + /// the markers itself if it knows how. + pub tool_call_tokens: Option, } impl LoadedModel { @@ -220,6 +227,8 @@ pub struct TpLoadedModel { /// load time. `None` when the model declares no reasoning /// markers. pub reasoning_tokens: Option, + /// Same shape as [`LoadedModel::tool_call_tokens`]. 
+ pub tool_call_tokens: Option, } #[cfg(feature = "cuda")] @@ -1786,6 +1795,7 @@ impl CandleHarness { { let prompt_tokens = prompt_tokens.clone(); let reasoning_tokens_inner = loaded.reasoning_tokens.clone(); + let tool_call_tokens_inner = loaded.tool_call_tokens.clone(); tokio::spawn( async move { let _inference_guard = loaded_for_task.inference_lock.lock().await; @@ -1800,6 +1810,7 @@ impl CandleHarness { seed, eos_id, reasoning_tokens_inner, + tool_call_tokens_inner, tx, ) .await @@ -1840,6 +1851,7 @@ impl CandleHarness { } } else if let Some(arch_arc) = loaded.arch.clone() { let reasoning_tokens_inner = loaded.reasoning_tokens.clone(); + let tool_call_tokens_inner = loaded.tool_call_tokens.clone(); tokio::task::spawn_blocking(move || { let _g = span_for_task.enter(); // `blocking_lock` is safe here: spawn_blocking runs on @@ -1858,6 +1870,7 @@ impl CandleHarness { seed, eos_id, reasoning_tokens_inner.as_ref(), + tool_call_tokens_inner.as_ref(), &tx, ) { Ok(()) => tracing::info!( @@ -2057,6 +2070,17 @@ impl Harness for CandleHarness { "reasoning markers detected — streaming will route ReasoningDelta separately" ); } + let tool_call_tokens = detect_tool_call_token_pair(|s| tokenizer.token_to_id(s)); + if let Some(ref pair) = tool_call_tokens { + tracing::info!( + model = %spec.model_id, + open = %pair.open_text, + close = %pair.close_text, + open_id = pair.open_id, + close_id = pair.close_id, + "tool-call markers detected — streaming will emit structured ToolCall events" + ); + } let loaded = Arc::new(LoadedModel { model_id: spec.model_id.clone(), @@ -2070,6 +2094,7 @@ impl Harness for CandleHarness { arch_handle, inference_lock: tokio::sync::Mutex::new(()), reasoning_tokens, + tool_call_tokens, }); let mut models = self.models.write().await; @@ -2242,8 +2267,9 @@ impl CandleHarness { // 6. Tokenizer (same as single-GPU path). 
let tokenizer = Tokenizer::from_file(&tokenizer_path) .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?; - // Reasoning-marker probe — identical to the single-GPU - // path. See `LoadedModel.reasoning_tokens` for the why. + // Reasoning + tool-call marker probes — identical to the + // single-GPU path. See LoadedModel's matching fields for + // the why. let reasoning_tokens = detect_reasoning_token_pair(|s| tokenizer.token_to_id(s)); if let Some(ref pair) = reasoning_tokens { tracing::info!( @@ -2253,6 +2279,15 @@ impl CandleHarness { "TP load: reasoning markers detected" ); } + let tool_call_tokens = detect_tool_call_token_pair(|s| tokenizer.token_to_id(s)); + if let Some(ref pair) = tool_call_tokens { + tracing::info!( + model = %spec.model_id, + open = %pair.open_text, + close = %pair.close_text, + "TP load: tool-call markers detected" + ); + } let tp_loaded = StdArc::new(TpLoadedModel { model_id: spec.model_id.clone(), @@ -2267,6 +2302,7 @@ impl CandleHarness { // TpLoadedModel so they reference the same thread. worker: leader_worker, reasoning_tokens, + tool_call_tokens, }); let mut models = self.models.write().await; @@ -2416,6 +2452,7 @@ impl CandleHarness { let created = unix_now_secs(); let tokenizer = tp.tokenizer.clone(); let reasoning_tokens = tp.reasoning_tokens.clone(); + let tool_call_tokens = tp.tool_call_tokens.clone(); // The spawned orchestration task below consumes both `id` // and `model_id` (tracing, pool lookups, NCCL ops use them // heavily). The wire projector at the bottom of this fn @@ -2481,10 +2518,13 @@ impl CandleHarness { // split a multi-byte char across tokens. let mut decode_stream = tokenizer.decode_stream(true); let mut finish_reason = FinishReason::Length; - // Reasoning marker state machine — same as the - // single-GPU path. The TP path needs its own copy - // because the spawn closure owns it. + // Reasoning + tool-call state machines — same as + // the single-GPU path. 
The TP path needs its own + // copies because the spawn closure owns them. let mut in_reasoning = false; + let mut in_tool_call = false; + let mut tool_call_buf = String::new(); + let mut tool_call_idx: usize = 0; 'work: { if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await { @@ -2553,28 +2593,82 @@ impl CandleHarness { if Some(next_token) == eos_id { finish_reason = FinishReason::Stop; - } else if handle_reasoning_marker( - next_token, - reasoning_tokens.as_ref(), - &mut in_reasoning, - ) { - all_tokens.push(next_token); } else { all_tokens.push(next_token); - match decode_stream.step(next_token) { - Ok(Some(delta)) => { - if !emit_delta(&delta, &tx, in_reasoning).await { - // Client gone — treat as normal stream end, - // not a failure. No log spam. - break 'work; + match handle_tool_call_marker( + next_token, + tool_call_tokens.as_ref(), + &mut in_tool_call, + &mut tool_call_buf, + ) { + ToolCallMarker::Enter => {} + ToolCallMarker::Exit { buffer } => { + let idx = tool_call_idx; + tool_call_idx += 1; + match parse_tool_call_body(&buffer, idx) { + Some((id, name, arguments)) => { + if tx + .send(InferenceEvent::ToolCall { + index: idx, + id, + name, + arguments, + }) + .await + .is_err() + { + break 'work; + } + } + None => { + let open = tool_call_tokens + .as_ref() + .map(|p| p.open_text.as_str()) + .unwrap_or(""); + let close = tool_call_tokens + .as_ref() + .map(|p| p.close_text.as_str()) + .unwrap_or(""); + let raw = format!("{open}{buffer}{close}"); + if !emit_delta(&raw, &tx, in_reasoning).await { + break 'work; + } + } + } + } + ToolCallMarker::None => { + if in_tool_call { + match decode_stream.step(next_token) { + Ok(Some(s)) => tool_call_buf.push_str(&s), + Ok(None) => {} + Err(e) => tracing::warn!( + model = %model_id, + error = %e, + "TP stream: decode_stream step failed (in tool_call)" + ), + } + } else if handle_reasoning_marker( + next_token, + reasoning_tokens.as_ref(), + &mut in_reasoning, + ) { + // marker — nothing to emit + } 
else { + match decode_stream.step(next_token) { + Ok(Some(delta)) => { + if !emit_delta(&delta, &tx, in_reasoning).await { + break 'work; + } + } + Ok(None) => {} + Err(e) => tracing::warn!( + model = %model_id, + error = %e, + "TP stream: decode_stream step failed" + ), + } } } - Ok(None) => {} - Err(e) => tracing::warn!( - model = %model_id, - error = %e, - "TP stream: decode_stream step failed" - ), } for index in 0..max_new.saturating_sub(1) { @@ -2638,15 +2732,70 @@ impl CandleHarness { finish_reason = FinishReason::Stop; break; } + all_tokens.push(next_token); + match handle_tool_call_marker( + next_token, + tool_call_tokens.as_ref(), + &mut in_tool_call, + &mut tool_call_buf, + ) { + ToolCallMarker::Enter => continue, + ToolCallMarker::Exit { buffer } => { + let idx = tool_call_idx; + tool_call_idx += 1; + match parse_tool_call_body(&buffer, idx) { + Some((id, name, arguments)) => { + if tx + .send(InferenceEvent::ToolCall { + index: idx, + id, + name, + arguments, + }) + .await + .is_err() + { + break 'work; + } + } + None => { + let open = tool_call_tokens + .as_ref() + .map(|p| p.open_text.as_str()) + .unwrap_or(""); + let close = tool_call_tokens + .as_ref() + .map(|p| p.close_text.as_str()) + .unwrap_or(""); + let raw = format!("{open}{buffer}{close}"); + if !emit_delta(&raw, &tx, in_reasoning).await { + break 'work; + } + } + } + continue; + } + ToolCallMarker::None => {} + } + if in_tool_call { + match decode_stream.step(next_token) { + Ok(Some(s)) => tool_call_buf.push_str(&s), + Ok(None) => {} + Err(e) => tracing::warn!( + model = %model_id, + error = %e, + "TP stream: decode_stream step failed (in tool_call)" + ), + } + continue; + } if handle_reasoning_marker( next_token, reasoning_tokens.as_ref(), &mut in_reasoning, ) { - all_tokens.push(next_token); continue; } - all_tokens.push(next_token); match decode_stream.step(next_token) { Ok(Some(delta)) => { if !emit_delta(&delta, &tx, in_reasoning).await { @@ -3013,6 +3162,68 @@ fn 
handle_reasoning_marker(
     false
 }
 
+/// Outcome of checking a sampled token against the model's
+/// tool-call markers.
+enum ToolCallMarker {
+    /// Not a tool-call marker, or a stray close marker outside any
+    /// block — caller takes the normal detokenize-and-emit path.
+    None,
+    /// `<tool_call>` token — caller buffers subsequent detokenized
+    /// text instead of emitting it. The token itself produces no output.
+    Enter,
+    /// `</tool_call>` token closing an open block. `buffer` is the
+    /// accumulated body, moved out with `std::mem::take`; the caller
+    /// emits an `InferenceEvent::ToolCall`, or on parse failure the
+    /// original `<tool_call>{buf}</tool_call>` as text.
+    Exit { buffer: String },
+}
+
+/// Classify `next_token` against the tool-call marker `pair` and
+/// update the caller's state. Always `None` when `pair` is `None`.
+/// A close marker only counts while a block is open, so a stray one
+/// cannot produce an empty `Exit` (and consume a call index).
+fn handle_tool_call_marker(
+    next_token: u32,
+    pair: Option<&ToolCallTokenPair>,
+    in_tool_call: &mut bool,
+    buffer: &mut String,
+) -> ToolCallMarker {
+    let Some(pair) = pair else {
+        return ToolCallMarker::None;
+    };
+    if next_token == pair.open_id {
+        *in_tool_call = true;
+        buffer.clear();
+        return ToolCallMarker::Enter;
+    }
+    if *in_tool_call && next_token == pair.close_id {
+        *in_tool_call = false;
+        return ToolCallMarker::Exit {
+            buffer: std::mem::take(buffer),
+        };
+    }
+    ToolCallMarker::None
+}
+
+/// Parse a `<tool_call>{json}</tool_call>` body into the `(id, name,
+/// arguments)` fields of `InferenceEvent::ToolCall`. Returns `None`
+/// when the body isn't valid JSON or has no string `name`; the caller
+/// then passes the original text through so downstream consumers
+/// (helexa-acp's `ToolCallParser`) can take another swing. String
+/// `arguments` are forwarded verbatim; missing or `null` become `{}`.
+/// The id is `call_<hex>_<index>`, built from `unix_subsec_nanos()`.
+fn parse_tool_call_body(body: &str, index: usize) -> Option<(String, String, String)> {
+    let value: serde_json::Value = serde_json::from_str(body.trim()).ok()?;
+    let name = value.get("name")?.as_str()?.to_string();
+    let arguments = match value.get("arguments") {
+        Some(serde_json::Value::String(s)) => s.clone(),
+        Some(serde_json::Value::Null) | None => "{}".to_string(),
+        Some(other) => other.to_string(),
+    };
+    let id = format!("call_{:x}_{}", unix_subsec_nanos(), index);
+    Some((id, name, arguments))
+}
+
 /// Errors returned by `CandleHarness::chat_completion`. The
 /// `ModelNotLoaded`, `PromptTooLong`, and `InsufficientVram` variants
 /// let the HTTP handler map cleanly to 404 / 400 / 503 without
@@ -3176,6 +3387,7 @@ async fn stream_inference_via_worker(
     seed: u64,
     eos_id: Option,
     reasoning_tokens: Option,
+    tool_call_tokens: Option,
     tx: mpsc::Sender,
 ) -> Result {
     let mut logits_processor = {
@@ -3200,10 +3412,14 @@ async fn stream_inference_via_worker(
     let mut decode_stream = tokenizer.decode_stream(true);
     let prompt_len = prompt_tokens.len();
     let mut finish_reason = FinishReason::Length;
-    // Reasoning marker state machine — see `run_inference_streaming`
-    // for the why. Markers never reach `decode_stream`; they only
-    // toggle the variant `emit_delta` produces.
+    // Reasoning + tool-call state machines — see
+    // `run_inference_streaming` for the why. Markers never reach
+    // `decode_stream`; they toggle state. Tool-call content
+    // accumulates into `tool_call_buf` until the close marker.
     let mut in_reasoning = false;
+    let mut in_tool_call = false;
+    let mut tool_call_buf = String::new();
+    let mut tool_call_idx: usize = 0;
 
     worker
         .clear_kv_cache(handle)
@@ -3228,21 +3444,101 @@ async fn stream_inference_via_worker(
         }
     };
 
-    if Some(next_token) == eos_id {
-        finish_reason = FinishReason::Stop;
-    } else if handle_reasoning_marker(next_token, reasoning_tokens.as_ref(), &mut in_reasoning) {
-        all_tokens.push(next_token);
-    } else {
+    // Inlined per-token routing — parallel to the TP path.
Macro + // approach used in the CPU path doesn't translate cleanly + // here because the emit is async (.await) and macros don't + // tolerate `.await` inside reused expansions across two + // call sites well. + async fn route_token( + next_token: u32, + all_tokens: &mut Vec, + in_reasoning: &mut bool, + in_tool_call: &mut bool, + tool_call_buf: &mut String, + tool_call_idx: &mut usize, + reasoning_tokens: Option<&ReasoningTokenPair>, + tool_call_tokens: Option<&ToolCallTokenPair>, + decode_stream: &mut tokenizers::DecodeStream<'_>, + tx: &mpsc::Sender, + ) -> bool { all_tokens.push(next_token); + match handle_tool_call_marker(next_token, tool_call_tokens, in_tool_call, tool_call_buf) { + ToolCallMarker::Enter => return true, + ToolCallMarker::Exit { buffer } => { + let idx = *tool_call_idx; + *tool_call_idx += 1; + match parse_tool_call_body(&buffer, idx) { + Some((id, name, arguments)) => { + if tx + .send(InferenceEvent::ToolCall { + index: idx, + id, + name, + arguments, + }) + .await + .is_err() + { + return false; + } + } + None => { + let open = tool_call_tokens + .map(|p| p.open_text.as_str()) + .unwrap_or(""); + let close = tool_call_tokens + .map(|p| p.close_text.as_str()) + .unwrap_or(""); + let raw = format!("{open}{buffer}{close}"); + if !emit_delta(&raw, tx, *in_reasoning).await { + return false; + } + } + } + return true; + } + ToolCallMarker::None => {} + } + if *in_tool_call { + match decode_stream.step(next_token) { + Ok(Some(s)) => tool_call_buf.push_str(&s), + Ok(None) => {} + Err(e) => tracing::warn!(error = %e, "decode_stream step failed (in tool_call)"), + } + return true; + } + if handle_reasoning_marker(next_token, reasoning_tokens, in_reasoning) { + return true; + } match decode_stream.step(next_token) { Ok(Some(delta)) => { - if !emit_delta(&delta, &tx, in_reasoning).await { - return Ok(finish_reason.as_openai_str().to_string()); + if !emit_delta(&delta, tx, *in_reasoning).await { + return false; } } Ok(None) => {} Err(e) => 
tracing::warn!(error = %e, "decode_stream step failed"), } + true + } + + if Some(next_token) == eos_id { + finish_reason = FinishReason::Stop; + } else if !route_token( + next_token, + &mut all_tokens, + &mut in_reasoning, + &mut in_tool_call, + &mut tool_call_buf, + &mut tool_call_idx, + reasoning_tokens.as_ref(), + tool_call_tokens.as_ref(), + &mut decode_stream, + &tx, + ) + .await + { + return Ok(finish_reason.as_openai_str().to_string()); } for index in 0..max_new.saturating_sub(1) { @@ -3267,19 +3563,21 @@ async fn stream_inference_via_worker( finish_reason = FinishReason::Stop; break; } - if handle_reasoning_marker(next_token, reasoning_tokens.as_ref(), &mut in_reasoning) { - all_tokens.push(next_token); - continue; - } - all_tokens.push(next_token); - match decode_stream.step(next_token) { - Ok(Some(delta)) => { - if !emit_delta(&delta, &tx, in_reasoning).await { - return Ok(finish_reason.as_openai_str().to_string()); - } - } - Ok(None) => {} - Err(e) => tracing::warn!(error = %e, "decode_stream step failed"), + if !route_token( + next_token, + &mut all_tokens, + &mut in_reasoning, + &mut in_tool_call, + &mut tool_call_buf, + &mut tool_call_idx, + reasoning_tokens.as_ref(), + tool_call_tokens.as_ref(), + &mut decode_stream, + &tx, + ) + .await + { + return Ok(finish_reason.as_openai_str().to_string()); } } @@ -3360,6 +3658,7 @@ fn run_inference_streaming( seed: u64, eos_id: Option, reasoning_tokens: Option<&ReasoningTokenPair>, + tool_call_tokens: Option<&ToolCallTokenPair>, tx: &mpsc::Sender, ) -> Result<()> { let mut logits_processor = { @@ -3387,26 +3686,101 @@ fn run_inference_streaming( // `decode_stream` — they aren't part of any visible output, // they exist purely as state transitions. let mut in_reasoning = false; + // Tool-call state. 
While `in_tool_call`, content tokens get + // accumulated into `tool_call_buf` instead of emitted; on the + // close marker we parse the buffer and emit a structured + // ToolCall event (or fall back to passing the raw text + // through if the buffer doesn't parse). + let mut in_tool_call = false; + let mut tool_call_buf = String::new(); + let mut tool_call_idx: usize = 0; arch.clear_kv_cache()?; let logits = chunked_prefill_local(arch, device, prompt_tokens)?; let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?; - if Some(next_token) == eos_id { - finish_reason = FinishReason::Stop; - } else if handle_reasoning_marker(next_token, reasoning_tokens, &mut in_reasoning) { - all_tokens.push(next_token); - } else { - all_tokens.push(next_token); - match decode_stream.step(next_token) { - Ok(Some(delta)) => { - if !emit_delta_blocking(&delta, tx, in_reasoning) { - return Ok(()); + // Per-token routing block, used at both the prefill-sample + // tail and the decode loop. Macros are ugly but Rust's + // closure inference fights `&mut DecodeStream<'_>` capture + + // mutable borrows of the surrounding `tool_call_buf` / + // `in_reasoning` / etc. Inline the body via a macro and live + // with the duplication of the call sites instead. + macro_rules! route_token { + ($next_token:expr) => {{ + let nt = $next_token; + all_tokens.push(nt); + match handle_tool_call_marker(nt, tool_call_tokens, &mut in_tool_call, &mut tool_call_buf) { + ToolCallMarker::Enter => {} + ToolCallMarker::Exit { buffer } => { + let idx = tool_call_idx; + tool_call_idx += 1; + match parse_tool_call_body(&buffer, idx) { + Some((id, name, arguments)) => { + if tx + .blocking_send(InferenceEvent::ToolCall { + index: idx, + id, + name, + arguments, + }) + .is_err() + { + return Ok(()); + } + } + None => { + // Malformed JSON — pass the block + // through as text so consumer parsers + // can try their own repair. 
+ let open = tool_call_tokens + .map(|p| p.open_text.as_str()) + .unwrap_or(""); + let close = tool_call_tokens + .map(|p| p.close_text.as_str()) + .unwrap_or(""); + let raw = format!("{open}{buffer}{close}"); + if !emit_delta_blocking(&raw, tx, in_reasoning) { + return Ok(()); + } + } + } + } + ToolCallMarker::None => { + if in_tool_call { + // Buffer JSON content without emitting. + match decode_stream.step(nt) { + Ok(Some(s)) => tool_call_buf.push_str(&s), + Ok(None) => {} + Err(e) => tracing::warn!( + error = %e, + "stream: decode_stream step failed (in tool_call)" + ), + } + } else if handle_reasoning_marker(nt, reasoning_tokens, &mut in_reasoning) { + // marker — nothing to emit + } else { + match decode_stream.step(nt) { + Ok(Some(delta)) => { + if !emit_delta_blocking(&delta, tx, in_reasoning) { + return Ok(()); + } + } + Ok(None) => {} + Err(e) => tracing::warn!( + error = %e, + "stream: decode_stream step failed" + ), + } + } } } - Ok(None) => {} - Err(e) => tracing::warn!(error = %e, "stream: decode_stream step failed"), - } + }}; + } + + if Some(next_token) == eos_id { + finish_reason = FinishReason::Stop; + } else { + route_token!(next_token); } for index in 0..max_new.saturating_sub(1) { @@ -3417,20 +3791,7 @@ fn run_inference_streaming( finish_reason = FinishReason::Stop; break; } - if handle_reasoning_marker(next_token, reasoning_tokens, &mut in_reasoning) { - all_tokens.push(next_token); - continue; - } - all_tokens.push(next_token); - match decode_stream.step(next_token) { - Ok(Some(delta)) => { - if !emit_delta_blocking(&delta, tx, in_reasoning) { - return Ok(()); - } - } - Ok(None) => {} - Err(e) => tracing::warn!(error = %e, "stream: decode_stream step failed"), - } + route_token!(next_token); } let _ = tx.blocking_send(InferenceEvent::Finish { diff --git a/crates/neuron/src/wire/event.rs b/crates/neuron/src/wire/event.rs index ff39a00..4344a30 100644 --- a/crates/neuron/src/wire/event.rs +++ b/crates/neuron/src/wire/event.rs @@ -44,16 +44,29 
@@ pub enum InferenceEvent { /// concatenate into the complete reply. TextDelta(String), /// Reasoning / scratchpad text the model emitted inside a - /// `` block (or equivalent). Producers that don't - /// surface reasoning separately use [`TextDelta`] for - /// everything; future split lives here. - /// - /// Not yet emitted by the candle harness — present so future - /// stages (qwen3 `` routing, OpenAI o-series reasoning) - /// have a typed home without breaking the existing - /// projections. - #[allow(dead_code)] + /// `` block (or equivalent). The harness routes + /// content between marker tokens here so wire projectors can + /// decide what to do with it (chat completions drops by + /// default; Responses API has a dedicated event family). ReasoningDelta(String), + /// A tool call has been parsed out of a `{json}` + /// block. Carries the parsed name + arguments JSON string + /// (Anthropic / OpenAI projectors emit their own wire shape + /// from this). + /// + /// `index` is the call slot — incremented per tool call in a + /// turn so wire formats that order calls by index + /// (OpenAI chat completions) can correlate. + ToolCall { + index: usize, + id: String, + name: String, + /// Complete JSON arguments string. The model could in + /// principle stream these token-by-token, but our + /// extraction buffers the whole block until `` + /// arrives and emits exactly one event per call. + arguments: String, + }, /// The stream is complete. Carries the reason so wire formats /// that use it (OpenAI's `finish_reason`, Anthropic's /// `stop_reason`) can render it without re-parsing. @@ -137,6 +150,51 @@ const KNOWN_REASONING_MARKERS: &[(&str, &str)] = &[ ("", ""), ]; +/// Open/close token IDs for the model's tool-call marker +/// convention (or `None` for models that don't emit structured +/// tool calls). 
Same shape as [`ReasoningTokenPair`]: probed once
+/// at load time, consumed by the inference loop to switch between
+/// "emit visible deltas" and "buffer JSON for the next tool
+/// call".
+#[derive(Debug, Clone)]
+pub struct ToolCallTokenPair {
+    pub open_id: u32,
+    pub close_id: u32,
+    pub open_text: String,
+    pub close_text: String,
+}
+
+/// Tool-call marker conventions. Open-weight tool-use models
+/// converged on `<tool_call>` / `</tool_call>` (Qwen3-Coder /
+/// -Instruct, the Hermes function-call format, DeepSeek-Coder,
+/// gpt-oss). The pair lives alongside the reasoning markers in
+/// the same `added_tokens` table.
+const KNOWN_TOOL_CALL_MARKERS: &[(&str, &str)] = &[("<tool_call>", "</tool_call>")];
+
+/// Probe a tokenizer for known tool-call marker pairs. Mirrors
+/// [`detect_reasoning_token_pair`] — both open AND close must
+/// resolve for the pair to be returned. `None` means the model
+/// doesn't emit structured tool calls (or its tokenizer split
+/// the markers across tokens).
+pub fn detect_tool_call_token_pair<F>(token_to_id: F) -> Option<ToolCallTokenPair>
+where
+    F: Fn(&str) -> Option<u32>,
+{
+    for (open_text, close_text) in KNOWN_TOOL_CALL_MARKERS {
+        let open_id = token_to_id(open_text);
+        let close_id = token_to_id(close_text);
+        if let (Some(open_id), Some(close_id)) = (open_id, close_id) {
+            return Some(ToolCallTokenPair {
+                open_id,
+                close_id,
+                open_text: (*open_text).into(),
+                close_text: (*close_text).into(),
+            });
+        }
+    }
+    None
+}
+
 /// Inspect a tokenizer for known reasoning-marker pairs and return
 /// the first match. The tokenizer types this trait is defined over
 /// just need to expose `token_to_id(&str) -> Option` so this
@@ -213,6 +271,24 @@ mod tests {
         assert!(detect_reasoning_token_pair(lookup(&m)).is_none());
     }
 
+    #[test]
+    fn detects_tool_call_markers() {
+        let mut m = HashMap::new();
+        m.insert("<tool_call>", 151657);
+        m.insert("</tool_call>", 151658);
+        let pair = detect_tool_call_token_pair(lookup(&m)).expect("pair detected");
+        assert_eq!(pair.open_id, 151657);
+        assert_eq!(pair.close_id, 151658);
+        assert_eq!(pair.open_text, "<tool_call>");
+        assert_eq!(pair.close_text, "</tool_call>");
+    }
+
+    #[test]
+    fn returns_none_for_non_tool_use_tokenizer() {
+        let m: HashMap<&'static str, u32> = HashMap::new();
+        assert!(detect_tool_call_token_pair(lookup(&m)).is_none());
+    }
+
     #[test]
     fn first_match_wins_when_multiple_pairs_declared() {
         // Hypothetical tokenizer with both Qwen-style AND Mistral-style
diff --git a/crates/neuron/src/wire/mod.rs b/crates/neuron/src/wire/mod.rs
index b531a91..ee4ccc4 100644
--- a/crates/neuron/src/wire/mod.rs
+++ b/crates/neuron/src/wire/mod.rs
@@ -21,4 +21,7 @@ pub mod event;
 pub mod openai_chat;
 pub mod openai_responses;
 
-pub use event::{FinishReason, InferenceEvent, ReasoningTokenPair, detect_reasoning_token_pair};
+pub use event::{
+    FinishReason, InferenceEvent, ReasoningTokenPair, ToolCallTokenPair,
+    detect_reasoning_token_pair, detect_tool_call_token_pair,
+};
diff --git a/crates/neuron/src/wire/openai_chat.rs b/crates/neuron/src/wire/openai_chat.rs
index 868e25c..b43173d 100644
--- a/crates/neuron/src/wire/openai_chat.rs
+++ b/crates/neuron/src/wire/openai_chat.rs
@@ -172,6 +172,22 @@ pub fn project_chat_stream_with(
                     was_in_reasoning = true;
                     chunks
                 }
+                InferenceEvent::ToolCall {
+                    index,
+                    id: call_id,
+                    name,
+                    arguments,
+                } => {
+                    // OpenAI streaming shape for tool calls:
+                    // `delta.tool_calls[]` with id + function.name
+                    // on the first chunk per index, then
+                    // function.arguments deltas.
We have the
+                    // complete arguments buffered already, so one
+                    // delta carries everything.
+                    vec![tool_call_chunk(
+                        &id, created, &model_id, index, &call_id, &name, &arguments,
+                    )]
+                }
                 InferenceEvent::Finish { reason } => {
                     vec![final_chunk(&id, created, &model_id, reason)]
                 }
@@ -222,6 +238,47 @@ fn content_chunk(id: &str, created: u64, model_id: &str, text: &str) -> ChatComp
     }
 }
 
+/// OpenAI chat streaming shape for a tool call. One chunk per
+/// call slot, carrying id + name + the complete arguments JSON.
+/// Mirrors the format real OpenAI emits on the streaming path,
+/// minus the per-token arguments-streaming complication (we have
+/// the whole buffer already after the model finishes the
+/// `<tool_call>...</tool_call>` block).
+fn tool_call_chunk(
+    id: &str,
+    created: u64,
+    model_id: &str,
+    index: usize,
+    call_id: &str,
+    name: &str,
+    arguments: &str,
+) -> ChatCompletionChunk {
+    ChatCompletionChunk {
+        id: id.into(),
+        object: "chat.completion.chunk".into(),
+        created,
+        model: model_id.into(),
+        choices: vec![ChunkChoice {
+            index: 0,
+            delta: json!({
+                "tool_calls": [{
+                    "index": index,
+                    "id": call_id,
+                    "type": "function",
+                    "function": {
+                        "name": name,
+                        "arguments": arguments,
+                    }
+                }],
+            }),
+            finish_reason: None,
+            extra: serde_json::Value::Object(Default::default()),
+        }],
+        usage: None,
+        extra: serde_json::Value::Object(Default::default()),
+    }
+}
+
 fn final_chunk(
     id: &str,
     created: u64,
diff --git a/crates/neuron/src/wire/openai_responses.rs b/crates/neuron/src/wire/openai_responses.rs
index be6b54d..e7ade8f 100644
--- a/crates/neuron/src/wire/openai_responses.rs
+++ b/crates/neuron/src/wire/openai_responses.rs
@@ -296,6 +296,13 @@ async fn run_projection(
                 // Stage where it'd land: a `response.reasoning_*`
                 // event family alongside `response.output_text.*`.
             }
+            InferenceEvent::ToolCall { .. } => {
+                // Responses-side tool-call routing not wired yet
+                // (would emit response.function_call_arguments.*
+                // events). Drop for now; the chat-completions
+                // projector handles tool calls. Future work
+                // tracked in #7 alongside the in_progress event.
+            }
             InferenceEvent::Finish { reason } => {
                 finish = Some(reason);
             }