feat(neuron): extract <tool_call> blocks to structured tool_calls deltas
Some checks failed
build-prerelease / Build cortex binary (push) Blocked by required conditions
CI / Clippy (push) Waiting to run
CI / Test (push) Waiting to run
CI / CUDA type-check (push) Failing after 17s
build-prerelease / Resolve version stamps (push) Successful in 32s
CI / Format (push) Successful in 32s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-blackwell (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
Some checks failed
build-prerelease / Build cortex binary (push) Blocked by required conditions
CI / Clippy (push) Waiting to run
CI / Test (push) Waiting to run
CI / CUDA type-check (push) Failing after 17s
build-prerelease / Resolve version stamps (push) Successful in 32s
CI / Format (push) Successful in 32s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-blackwell (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
Closes #6. Same model-agnostic seam as #8 but for tool-call markers (`<tool_call>` / `</tool_call>` on Qwen3-Coder, Hermes-format, DeepSeek-Coder, gpt-oss, …). Lets Zed's tool-use feature and any other vanilla OpenAI chat client get structured `tool_calls` deltas out of cortex without having to parse markers themselves. ## Implementation 1. **Tokenizer probe at load time** (`detect_tool_call_token_pair` in `wire::event`) — same shape as the reasoning-marker probe from #8. Both open AND close must resolve to single token ids; non-tool-use models get `None` and pass through unchanged. Stored on `LoadedModel.tool_call_tokens` and the TP analogue. 2. **New `InferenceEvent::ToolCall` variant** — carries `index` (call slot, per-turn counter), generated `id` (`call_<hex>_<idx>`), `name`, and the complete `arguments` JSON string. One event per parsed call. 3. **Token-level state machine** in all three streaming paths (CPU `run_inference_streaming`, CUDA single-GPU `stream_inference_via_worker`, CUDA TP `chat_completion_tp_stream`) layered on top of #8's reasoning routing: - `<tool_call>` token → enter buffering state, clear buffer. - Tokens while buffering → accumulate into `tool_call_buf` via the decoder (so multi-byte UTF-8 still buffers correctly) without emitting anything visible. - `</tool_call>` token → take the buffer, parse with `parse_tool_call_body` (extract `name` + `arguments`), emit a structured `ToolCall` event with a fresh `call_<hex>` id and the parsed fields. - On parse failure → fall back to re-emitting the original `<tool_call>{buf}</tool_call>` block as plain text content so helexa-acp's existing `ToolCallParser` repair passes still have a chance to recover the call. 4. **OpenAI chat projector** emits the OpenAI streaming `tool_calls` delta shape on `InferenceEvent::ToolCall` — `{tool_calls: [{index, id, type:"function", function:{name, arguments}}]}`. One chunk per call slot. 5. **OpenAI Responses projector** drops `ToolCall` events for now (Responses-side function_call event family routing tracked under #7); the chat path is what unblocks Zed's tool use today. ## Acceptance - Vanilla OpenAI chat clients (Zed's tool-use feature, any other OpenAI-compatible tool-call consumer) get structured tool_calls deltas against cortex+neuron without having to parse `<tool_call>` markers in content. - helexa-acp continues to work — when neuron parses cleanly, it consumes the structured deltas through its existing decoder. When the model emits malformed JSON, neuron falls back to text pass-through and helexa-acp's `ToolCallParser` recovers via the same path it always did. - Models without tool-call markers in their tokenizer pass through unchanged. - No hardcoded model knowledge — entirely driven by tokenizer metadata. ## Tests 2 new detection tests in `wire::event` (Qwen3-style marker detection, no-marker case). The streaming paths themselves stay covered by the existing chat-completions integration tests; full end-to-end exercise of the new path requires GPU-loaded models and lives outside the CI test surface. 215 workspace tests pass; clippy + fmt clean across the workspace. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -27,8 +27,8 @@ use cortex_core::openai::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
use crate::wire::{
|
use crate::wire::{
|
||||||
FinishReason, InferenceEvent, ReasoningTokenPair, detect_reasoning_token_pair,
|
FinishReason, InferenceEvent, ReasoningTokenPair, ToolCallTokenPair,
|
||||||
openai_chat as wire_chat,
|
detect_reasoning_token_pair, detect_tool_call_token_pair, openai_chat as wire_chat,
|
||||||
};
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
@@ -160,6 +160,13 @@ pub struct LoadedModel {
|
|||||||
/// [`InferenceEvent::ReasoningDelta`] at the token boundary;
|
/// [`InferenceEvent::ReasoningDelta`] at the token boundary;
|
||||||
/// when `None` everything is `TextDelta`.
|
/// when `None` everything is `TextDelta`.
|
||||||
pub reasoning_tokens: Option<ReasoningTokenPair>,
|
pub reasoning_tokens: Option<ReasoningTokenPair>,
|
||||||
|
/// Open/close token IDs for the model's tool-call marker
|
||||||
|
/// pair (`<tool_call>` / `</tool_call>` on Qwen3-Coder / Hermes
|
||||||
|
/// / DeepSeek / gpt-oss). `None` for models that don't emit
|
||||||
|
/// structured tool calls in this convention; output passes
|
||||||
|
/// through as plain text in that case and the consumer parses
|
||||||
|
/// the markers itself if it knows how.
|
||||||
|
pub tool_call_tokens: Option<ToolCallTokenPair>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LoadedModel {
|
impl LoadedModel {
|
||||||
@@ -220,6 +227,8 @@ pub struct TpLoadedModel {
|
|||||||
/// load time. `None` when the model declares no reasoning
|
/// load time. `None` when the model declares no reasoning
|
||||||
/// markers.
|
/// markers.
|
||||||
pub reasoning_tokens: Option<ReasoningTokenPair>,
|
pub reasoning_tokens: Option<ReasoningTokenPair>,
|
||||||
|
/// Same shape as [`LoadedModel::tool_call_tokens`].
|
||||||
|
pub tool_call_tokens: Option<ToolCallTokenPair>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
@@ -1786,6 +1795,7 @@ impl CandleHarness {
|
|||||||
{
|
{
|
||||||
let prompt_tokens = prompt_tokens.clone();
|
let prompt_tokens = prompt_tokens.clone();
|
||||||
let reasoning_tokens_inner = loaded.reasoning_tokens.clone();
|
let reasoning_tokens_inner = loaded.reasoning_tokens.clone();
|
||||||
|
let tool_call_tokens_inner = loaded.tool_call_tokens.clone();
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
async move {
|
async move {
|
||||||
let _inference_guard = loaded_for_task.inference_lock.lock().await;
|
let _inference_guard = loaded_for_task.inference_lock.lock().await;
|
||||||
@@ -1800,6 +1810,7 @@ impl CandleHarness {
|
|||||||
seed,
|
seed,
|
||||||
eos_id,
|
eos_id,
|
||||||
reasoning_tokens_inner,
|
reasoning_tokens_inner,
|
||||||
|
tool_call_tokens_inner,
|
||||||
tx,
|
tx,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@@ -1840,6 +1851,7 @@ impl CandleHarness {
|
|||||||
}
|
}
|
||||||
} else if let Some(arch_arc) = loaded.arch.clone() {
|
} else if let Some(arch_arc) = loaded.arch.clone() {
|
||||||
let reasoning_tokens_inner = loaded.reasoning_tokens.clone();
|
let reasoning_tokens_inner = loaded.reasoning_tokens.clone();
|
||||||
|
let tool_call_tokens_inner = loaded.tool_call_tokens.clone();
|
||||||
tokio::task::spawn_blocking(move || {
|
tokio::task::spawn_blocking(move || {
|
||||||
let _g = span_for_task.enter();
|
let _g = span_for_task.enter();
|
||||||
// `blocking_lock` is safe here: spawn_blocking runs on
|
// `blocking_lock` is safe here: spawn_blocking runs on
|
||||||
@@ -1858,6 +1870,7 @@ impl CandleHarness {
|
|||||||
seed,
|
seed,
|
||||||
eos_id,
|
eos_id,
|
||||||
reasoning_tokens_inner.as_ref(),
|
reasoning_tokens_inner.as_ref(),
|
||||||
|
tool_call_tokens_inner.as_ref(),
|
||||||
&tx,
|
&tx,
|
||||||
) {
|
) {
|
||||||
Ok(()) => tracing::info!(
|
Ok(()) => tracing::info!(
|
||||||
@@ -2057,6 +2070,17 @@ impl Harness for CandleHarness {
|
|||||||
"reasoning markers detected — streaming will route ReasoningDelta separately"
|
"reasoning markers detected — streaming will route ReasoningDelta separately"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
let tool_call_tokens = detect_tool_call_token_pair(|s| tokenizer.token_to_id(s));
|
||||||
|
if let Some(ref pair) = tool_call_tokens {
|
||||||
|
tracing::info!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
open = %pair.open_text,
|
||||||
|
close = %pair.close_text,
|
||||||
|
open_id = pair.open_id,
|
||||||
|
close_id = pair.close_id,
|
||||||
|
"tool-call markers detected — streaming will emit structured ToolCall events"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let loaded = Arc::new(LoadedModel {
|
let loaded = Arc::new(LoadedModel {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: spec.model_id.clone(),
|
||||||
@@ -2070,6 +2094,7 @@ impl Harness for CandleHarness {
|
|||||||
arch_handle,
|
arch_handle,
|
||||||
inference_lock: tokio::sync::Mutex::new(()),
|
inference_lock: tokio::sync::Mutex::new(()),
|
||||||
reasoning_tokens,
|
reasoning_tokens,
|
||||||
|
tool_call_tokens,
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut models = self.models.write().await;
|
let mut models = self.models.write().await;
|
||||||
@@ -2242,8 +2267,9 @@ impl CandleHarness {
|
|||||||
// 6. Tokenizer (same as single-GPU path).
|
// 6. Tokenizer (same as single-GPU path).
|
||||||
let tokenizer = Tokenizer::from_file(&tokenizer_path)
|
let tokenizer = Tokenizer::from_file(&tokenizer_path)
|
||||||
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
|
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
|
||||||
// Reasoning-marker probe — identical to the single-GPU
|
// Reasoning + tool-call marker probes — identical to the
|
||||||
// path. See `LoadedModel.reasoning_tokens` for the why.
|
// single-GPU path. See LoadedModel's matching fields for
|
||||||
|
// the why.
|
||||||
let reasoning_tokens = detect_reasoning_token_pair(|s| tokenizer.token_to_id(s));
|
let reasoning_tokens = detect_reasoning_token_pair(|s| tokenizer.token_to_id(s));
|
||||||
if let Some(ref pair) = reasoning_tokens {
|
if let Some(ref pair) = reasoning_tokens {
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
@@ -2253,6 +2279,15 @@ impl CandleHarness {
|
|||||||
"TP load: reasoning markers detected"
|
"TP load: reasoning markers detected"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
let tool_call_tokens = detect_tool_call_token_pair(|s| tokenizer.token_to_id(s));
|
||||||
|
if let Some(ref pair) = tool_call_tokens {
|
||||||
|
tracing::info!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
open = %pair.open_text,
|
||||||
|
close = %pair.close_text,
|
||||||
|
"TP load: tool-call markers detected"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let tp_loaded = StdArc::new(TpLoadedModel {
|
let tp_loaded = StdArc::new(TpLoadedModel {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: spec.model_id.clone(),
|
||||||
@@ -2267,6 +2302,7 @@ impl CandleHarness {
|
|||||||
// TpLoadedModel so they reference the same thread.
|
// TpLoadedModel so they reference the same thread.
|
||||||
worker: leader_worker,
|
worker: leader_worker,
|
||||||
reasoning_tokens,
|
reasoning_tokens,
|
||||||
|
tool_call_tokens,
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut models = self.models.write().await;
|
let mut models = self.models.write().await;
|
||||||
@@ -2416,6 +2452,7 @@ impl CandleHarness {
|
|||||||
let created = unix_now_secs();
|
let created = unix_now_secs();
|
||||||
let tokenizer = tp.tokenizer.clone();
|
let tokenizer = tp.tokenizer.clone();
|
||||||
let reasoning_tokens = tp.reasoning_tokens.clone();
|
let reasoning_tokens = tp.reasoning_tokens.clone();
|
||||||
|
let tool_call_tokens = tp.tool_call_tokens.clone();
|
||||||
// The spawned orchestration task below consumes both `id`
|
// The spawned orchestration task below consumes both `id`
|
||||||
// and `model_id` (tracing, pool lookups, NCCL ops use them
|
// and `model_id` (tracing, pool lookups, NCCL ops use them
|
||||||
// heavily). The wire projector at the bottom of this fn
|
// heavily). The wire projector at the bottom of this fn
|
||||||
@@ -2481,10 +2518,13 @@ impl CandleHarness {
|
|||||||
// split a multi-byte char across tokens.
|
// split a multi-byte char across tokens.
|
||||||
let mut decode_stream = tokenizer.decode_stream(true);
|
let mut decode_stream = tokenizer.decode_stream(true);
|
||||||
let mut finish_reason = FinishReason::Length;
|
let mut finish_reason = FinishReason::Length;
|
||||||
// Reasoning marker state machine — same as the
|
// Reasoning + tool-call state machines — same as
|
||||||
// single-GPU path. The TP path needs its own copy
|
// the single-GPU path. The TP path needs its own
|
||||||
// because the spawn closure owns it.
|
// copies because the spawn closure owns them.
|
||||||
let mut in_reasoning = false;
|
let mut in_reasoning = false;
|
||||||
|
let mut in_tool_call = false;
|
||||||
|
let mut tool_call_buf = String::new();
|
||||||
|
let mut tool_call_idx: usize = 0;
|
||||||
|
|
||||||
'work: {
|
'work: {
|
||||||
if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await {
|
if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await {
|
||||||
@@ -2553,19 +2593,70 @@ impl CandleHarness {
|
|||||||
|
|
||||||
if Some(next_token) == eos_id {
|
if Some(next_token) == eos_id {
|
||||||
finish_reason = FinishReason::Stop;
|
finish_reason = FinishReason::Stop;
|
||||||
|
} else {
|
||||||
|
all_tokens.push(next_token);
|
||||||
|
match handle_tool_call_marker(
|
||||||
|
next_token,
|
||||||
|
tool_call_tokens.as_ref(),
|
||||||
|
&mut in_tool_call,
|
||||||
|
&mut tool_call_buf,
|
||||||
|
) {
|
||||||
|
ToolCallMarker::Enter => {}
|
||||||
|
ToolCallMarker::Exit { buffer } => {
|
||||||
|
let idx = tool_call_idx;
|
||||||
|
tool_call_idx += 1;
|
||||||
|
match parse_tool_call_body(&buffer, idx) {
|
||||||
|
Some((id, name, arguments)) => {
|
||||||
|
if tx
|
||||||
|
.send(InferenceEvent::ToolCall {
|
||||||
|
index: idx,
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
break 'work;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
let open = tool_call_tokens
|
||||||
|
.as_ref()
|
||||||
|
.map(|p| p.open_text.as_str())
|
||||||
|
.unwrap_or("<tool_call>");
|
||||||
|
let close = tool_call_tokens
|
||||||
|
.as_ref()
|
||||||
|
.map(|p| p.close_text.as_str())
|
||||||
|
.unwrap_or("</tool_call>");
|
||||||
|
let raw = format!("{open}{buffer}{close}");
|
||||||
|
if !emit_delta(&raw, &tx, in_reasoning).await {
|
||||||
|
break 'work;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ToolCallMarker::None => {
|
||||||
|
if in_tool_call {
|
||||||
|
match decode_stream.step(next_token) {
|
||||||
|
Ok(Some(s)) => tool_call_buf.push_str(&s),
|
||||||
|
Ok(None) => {}
|
||||||
|
Err(e) => tracing::warn!(
|
||||||
|
model = %model_id,
|
||||||
|
error = %e,
|
||||||
|
"TP stream: decode_stream step failed (in tool_call)"
|
||||||
|
),
|
||||||
|
}
|
||||||
} else if handle_reasoning_marker(
|
} else if handle_reasoning_marker(
|
||||||
next_token,
|
next_token,
|
||||||
reasoning_tokens.as_ref(),
|
reasoning_tokens.as_ref(),
|
||||||
&mut in_reasoning,
|
&mut in_reasoning,
|
||||||
) {
|
) {
|
||||||
all_tokens.push(next_token);
|
// marker — nothing to emit
|
||||||
} else {
|
} else {
|
||||||
all_tokens.push(next_token);
|
|
||||||
match decode_stream.step(next_token) {
|
match decode_stream.step(next_token) {
|
||||||
Ok(Some(delta)) => {
|
Ok(Some(delta)) => {
|
||||||
if !emit_delta(&delta, &tx, in_reasoning).await {
|
if !emit_delta(&delta, &tx, in_reasoning).await {
|
||||||
// Client gone — treat as normal stream end,
|
|
||||||
// not a failure. No log spam.
|
|
||||||
break 'work;
|
break 'work;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2576,6 +2667,9 @@ impl CandleHarness {
|
|||||||
"TP stream: decode_stream step failed"
|
"TP stream: decode_stream step failed"
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for index in 0..max_new.saturating_sub(1) {
|
for index in 0..max_new.saturating_sub(1) {
|
||||||
let logits_vec = match pool
|
let logits_vec = match pool
|
||||||
@@ -2638,15 +2732,70 @@ impl CandleHarness {
|
|||||||
finish_reason = FinishReason::Stop;
|
finish_reason = FinishReason::Stop;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
all_tokens.push(next_token);
|
||||||
|
match handle_tool_call_marker(
|
||||||
|
next_token,
|
||||||
|
tool_call_tokens.as_ref(),
|
||||||
|
&mut in_tool_call,
|
||||||
|
&mut tool_call_buf,
|
||||||
|
) {
|
||||||
|
ToolCallMarker::Enter => continue,
|
||||||
|
ToolCallMarker::Exit { buffer } => {
|
||||||
|
let idx = tool_call_idx;
|
||||||
|
tool_call_idx += 1;
|
||||||
|
match parse_tool_call_body(&buffer, idx) {
|
||||||
|
Some((id, name, arguments)) => {
|
||||||
|
if tx
|
||||||
|
.send(InferenceEvent::ToolCall {
|
||||||
|
index: idx,
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
break 'work;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
let open = tool_call_tokens
|
||||||
|
.as_ref()
|
||||||
|
.map(|p| p.open_text.as_str())
|
||||||
|
.unwrap_or("<tool_call>");
|
||||||
|
let close = tool_call_tokens
|
||||||
|
.as_ref()
|
||||||
|
.map(|p| p.close_text.as_str())
|
||||||
|
.unwrap_or("</tool_call>");
|
||||||
|
let raw = format!("{open}{buffer}{close}");
|
||||||
|
if !emit_delta(&raw, &tx, in_reasoning).await {
|
||||||
|
break 'work;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
ToolCallMarker::None => {}
|
||||||
|
}
|
||||||
|
if in_tool_call {
|
||||||
|
match decode_stream.step(next_token) {
|
||||||
|
Ok(Some(s)) => tool_call_buf.push_str(&s),
|
||||||
|
Ok(None) => {}
|
||||||
|
Err(e) => tracing::warn!(
|
||||||
|
model = %model_id,
|
||||||
|
error = %e,
|
||||||
|
"TP stream: decode_stream step failed (in tool_call)"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if handle_reasoning_marker(
|
if handle_reasoning_marker(
|
||||||
next_token,
|
next_token,
|
||||||
reasoning_tokens.as_ref(),
|
reasoning_tokens.as_ref(),
|
||||||
&mut in_reasoning,
|
&mut in_reasoning,
|
||||||
) {
|
) {
|
||||||
all_tokens.push(next_token);
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
all_tokens.push(next_token);
|
|
||||||
match decode_stream.step(next_token) {
|
match decode_stream.step(next_token) {
|
||||||
Ok(Some(delta)) => {
|
Ok(Some(delta)) => {
|
||||||
if !emit_delta(&delta, &tx, in_reasoning).await {
|
if !emit_delta(&delta, &tx, in_reasoning).await {
|
||||||
@@ -3013,6 +3162,68 @@ fn handle_reasoning_marker(
|
|||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Outcome of checking a sampled token against the model's
|
||||||
|
/// tool-call markers.
|
||||||
|
enum ToolCallMarker {
|
||||||
|
/// Not a tool-call marker — caller proceeds with the normal
|
||||||
|
/// detokenize-and-emit path.
|
||||||
|
None,
|
||||||
|
/// `<tool_call>` token — caller starts buffering subsequent
|
||||||
|
/// detokenized text into the tool-call buffer instead of
|
||||||
|
/// emitting it. The token itself produces no output.
|
||||||
|
Enter,
|
||||||
|
/// `</tool_call>` token — caller takes ownership of the
|
||||||
|
/// buffered JSON, parses it, and emits either a structured
|
||||||
|
/// `InferenceEvent::ToolCall` or (on parse failure) the
|
||||||
|
/// original `<tool_call>{buf}</tool_call>` as text. The
|
||||||
|
/// returned buffer is `std::mem::take`-d out of the inner
|
||||||
|
/// state.
|
||||||
|
Exit { buffer: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_tool_call_marker(
|
||||||
|
next_token: u32,
|
||||||
|
pair: Option<&ToolCallTokenPair>,
|
||||||
|
in_tool_call: &mut bool,
|
||||||
|
buffer: &mut String,
|
||||||
|
) -> ToolCallMarker {
|
||||||
|
let Some(pair) = pair else {
|
||||||
|
return ToolCallMarker::None;
|
||||||
|
};
|
||||||
|
if next_token == pair.open_id {
|
||||||
|
*in_tool_call = true;
|
||||||
|
buffer.clear();
|
||||||
|
return ToolCallMarker::Enter;
|
||||||
|
}
|
||||||
|
if next_token == pair.close_id {
|
||||||
|
*in_tool_call = false;
|
||||||
|
return ToolCallMarker::Exit {
|
||||||
|
buffer: std::mem::take(buffer),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
ToolCallMarker::None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a `<tool_call>{json}</tool_call>` body into the fields the
|
||||||
|
/// `InferenceEvent::ToolCall` variant carries. Returns `None` when
|
||||||
|
/// the body isn't valid JSON or doesn't carry a `name`. The caller
|
||||||
|
/// falls back to passing the original text through on `None` so
|
||||||
|
/// downstream consumers (helexa-acp's existing `ToolCallParser`,
|
||||||
|
/// which has its own repair passes) can take another swing.
|
||||||
|
///
|
||||||
|
/// Generates a fresh `call_<hex>` id per parsed call; the model
|
||||||
|
/// itself doesn't include ids in the wire convention we model.
|
||||||
|
fn parse_tool_call_body(body: &str, index: usize) -> Option<(String, String, String)> {
|
||||||
|
let value: serde_json::Value = serde_json::from_str(body.trim()).ok()?;
|
||||||
|
let name = value.get("name")?.as_str()?.to_string();
|
||||||
|
let arguments = value
|
||||||
|
.get("arguments")
|
||||||
|
.map(|v| v.to_string())
|
||||||
|
.unwrap_or_else(|| "{}".into());
|
||||||
|
let id = format!("call_{:x}_{}", unix_subsec_nanos(), index);
|
||||||
|
Some((id, name, arguments))
|
||||||
|
}
|
||||||
|
|
||||||
/// Errors returned by `CandleHarness::chat_completion`. The
|
/// Errors returned by `CandleHarness::chat_completion`. The
|
||||||
/// `ModelNotLoaded`, `PromptTooLong`, and `InsufficientVram` variants
|
/// `ModelNotLoaded`, `PromptTooLong`, and `InsufficientVram` variants
|
||||||
/// let the HTTP handler map cleanly to 404 / 400 / 503 without
|
/// let the HTTP handler map cleanly to 404 / 400 / 503 without
|
||||||
@@ -3176,6 +3387,7 @@ async fn stream_inference_via_worker(
|
|||||||
seed: u64,
|
seed: u64,
|
||||||
eos_id: Option<u32>,
|
eos_id: Option<u32>,
|
||||||
reasoning_tokens: Option<ReasoningTokenPair>,
|
reasoning_tokens: Option<ReasoningTokenPair>,
|
||||||
|
tool_call_tokens: Option<ToolCallTokenPair>,
|
||||||
tx: mpsc::Sender<InferenceEvent>,
|
tx: mpsc::Sender<InferenceEvent>,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
let mut logits_processor = {
|
let mut logits_processor = {
|
||||||
@@ -3200,10 +3412,14 @@ async fn stream_inference_via_worker(
|
|||||||
let mut decode_stream = tokenizer.decode_stream(true);
|
let mut decode_stream = tokenizer.decode_stream(true);
|
||||||
let prompt_len = prompt_tokens.len();
|
let prompt_len = prompt_tokens.len();
|
||||||
let mut finish_reason = FinishReason::Length;
|
let mut finish_reason = FinishReason::Length;
|
||||||
// Reasoning marker state machine — see `run_inference_streaming`
|
// Reasoning + tool-call state machines — see
|
||||||
// for the why. Markers never reach `decode_stream`; they only
|
// `run_inference_streaming` for the why. Markers never reach
|
||||||
// toggle the variant `emit_delta` produces.
|
// `decode_stream`; they toggle state. Tool-call content
|
||||||
|
// accumulates into `tool_call_buf` until the close marker.
|
||||||
let mut in_reasoning = false;
|
let mut in_reasoning = false;
|
||||||
|
let mut in_tool_call = false;
|
||||||
|
let mut tool_call_buf = String::new();
|
||||||
|
let mut tool_call_idx: usize = 0;
|
||||||
|
|
||||||
worker
|
worker
|
||||||
.clear_kv_cache(handle)
|
.clear_kv_cache(handle)
|
||||||
@@ -3228,21 +3444,101 @@ async fn stream_inference_via_worker(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if Some(next_token) == eos_id {
|
// Inlined per-token routing — parallel to the TP path. Macro
|
||||||
finish_reason = FinishReason::Stop;
|
// approach used in the CPU path doesn't translate cleanly
|
||||||
} else if handle_reasoning_marker(next_token, reasoning_tokens.as_ref(), &mut in_reasoning) {
|
// here because the emit is async (.await) and macros don't
|
||||||
all_tokens.push(next_token);
|
// tolerate `.await` inside reused expansions across two
|
||||||
} else {
|
// call sites well.
|
||||||
|
async fn route_token(
|
||||||
|
next_token: u32,
|
||||||
|
all_tokens: &mut Vec<u32>,
|
||||||
|
in_reasoning: &mut bool,
|
||||||
|
in_tool_call: &mut bool,
|
||||||
|
tool_call_buf: &mut String,
|
||||||
|
tool_call_idx: &mut usize,
|
||||||
|
reasoning_tokens: Option<&ReasoningTokenPair>,
|
||||||
|
tool_call_tokens: Option<&ToolCallTokenPair>,
|
||||||
|
decode_stream: &mut tokenizers::DecodeStream<'_>,
|
||||||
|
tx: &mpsc::Sender<InferenceEvent>,
|
||||||
|
) -> bool {
|
||||||
all_tokens.push(next_token);
|
all_tokens.push(next_token);
|
||||||
|
match handle_tool_call_marker(next_token, tool_call_tokens, in_tool_call, tool_call_buf) {
|
||||||
|
ToolCallMarker::Enter => return true,
|
||||||
|
ToolCallMarker::Exit { buffer } => {
|
||||||
|
let idx = *tool_call_idx;
|
||||||
|
*tool_call_idx += 1;
|
||||||
|
match parse_tool_call_body(&buffer, idx) {
|
||||||
|
Some((id, name, arguments)) => {
|
||||||
|
if tx
|
||||||
|
.send(InferenceEvent::ToolCall {
|
||||||
|
index: idx,
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
let open = tool_call_tokens
|
||||||
|
.map(|p| p.open_text.as_str())
|
||||||
|
.unwrap_or("<tool_call>");
|
||||||
|
let close = tool_call_tokens
|
||||||
|
.map(|p| p.close_text.as_str())
|
||||||
|
.unwrap_or("</tool_call>");
|
||||||
|
let raw = format!("{open}{buffer}{close}");
|
||||||
|
if !emit_delta(&raw, tx, *in_reasoning).await {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
ToolCallMarker::None => {}
|
||||||
|
}
|
||||||
|
if *in_tool_call {
|
||||||
|
match decode_stream.step(next_token) {
|
||||||
|
Ok(Some(s)) => tool_call_buf.push_str(&s),
|
||||||
|
Ok(None) => {}
|
||||||
|
Err(e) => tracing::warn!(error = %e, "decode_stream step failed (in tool_call)"),
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if handle_reasoning_marker(next_token, reasoning_tokens, in_reasoning) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
match decode_stream.step(next_token) {
|
match decode_stream.step(next_token) {
|
||||||
Ok(Some(delta)) => {
|
Ok(Some(delta)) => {
|
||||||
if !emit_delta(&delta, &tx, in_reasoning).await {
|
if !emit_delta(&delta, tx, *in_reasoning).await {
|
||||||
return Ok(finish_reason.as_openai_str().to_string());
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(None) => {}
|
Ok(None) => {}
|
||||||
Err(e) => tracing::warn!(error = %e, "decode_stream step failed"),
|
Err(e) => tracing::warn!(error = %e, "decode_stream step failed"),
|
||||||
}
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
if Some(next_token) == eos_id {
|
||||||
|
finish_reason = FinishReason::Stop;
|
||||||
|
} else if !route_token(
|
||||||
|
next_token,
|
||||||
|
&mut all_tokens,
|
||||||
|
&mut in_reasoning,
|
||||||
|
&mut in_tool_call,
|
||||||
|
&mut tool_call_buf,
|
||||||
|
&mut tool_call_idx,
|
||||||
|
reasoning_tokens.as_ref(),
|
||||||
|
tool_call_tokens.as_ref(),
|
||||||
|
&mut decode_stream,
|
||||||
|
&tx,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
return Ok(finish_reason.as_openai_str().to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
for index in 0..max_new.saturating_sub(1) {
|
for index in 0..max_new.saturating_sub(1) {
|
||||||
@@ -3267,21 +3563,23 @@ async fn stream_inference_via_worker(
|
|||||||
finish_reason = FinishReason::Stop;
|
finish_reason = FinishReason::Stop;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if handle_reasoning_marker(next_token, reasoning_tokens.as_ref(), &mut in_reasoning) {
|
if !route_token(
|
||||||
all_tokens.push(next_token);
|
next_token,
|
||||||
continue;
|
&mut all_tokens,
|
||||||
}
|
&mut in_reasoning,
|
||||||
all_tokens.push(next_token);
|
&mut in_tool_call,
|
||||||
match decode_stream.step(next_token) {
|
&mut tool_call_buf,
|
||||||
Ok(Some(delta)) => {
|
&mut tool_call_idx,
|
||||||
if !emit_delta(&delta, &tx, in_reasoning).await {
|
reasoning_tokens.as_ref(),
|
||||||
|
tool_call_tokens.as_ref(),
|
||||||
|
&mut decode_stream,
|
||||||
|
&tx,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
return Ok(finish_reason.as_openai_str().to_string());
|
return Ok(finish_reason.as_openai_str().to_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(None) => {}
|
|
||||||
Err(e) => tracing::warn!(error = %e, "decode_stream step failed"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Terminal Finish event. The wire projector turns this into a
|
// Terminal Finish event. The wire projector turns this into a
|
||||||
// format-specific final chunk (`finish_reason: "stop"` on
|
// format-specific final chunk (`finish_reason: "stop"` on
|
||||||
@@ -3360,6 +3658,7 @@ fn run_inference_streaming(
|
|||||||
seed: u64,
|
seed: u64,
|
||||||
eos_id: Option<u32>,
|
eos_id: Option<u32>,
|
||||||
reasoning_tokens: Option<&ReasoningTokenPair>,
|
reasoning_tokens: Option<&ReasoningTokenPair>,
|
||||||
|
tool_call_tokens: Option<&ToolCallTokenPair>,
|
||||||
tx: &mpsc::Sender<InferenceEvent>,
|
tx: &mpsc::Sender<InferenceEvent>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut logits_processor = {
|
let mut logits_processor = {
|
||||||
@@ -3387,27 +3686,102 @@ fn run_inference_streaming(
|
|||||||
// `decode_stream` — they aren't part of any visible output,
|
// `decode_stream` — they aren't part of any visible output,
|
||||||
// they exist purely as state transitions.
|
// they exist purely as state transitions.
|
||||||
let mut in_reasoning = false;
|
let mut in_reasoning = false;
|
||||||
|
// Tool-call state. While `in_tool_call`, content tokens get
|
||||||
|
// accumulated into `tool_call_buf` instead of emitted; on the
|
||||||
|
// close marker we parse the buffer and emit a structured
|
||||||
|
// ToolCall event (or fall back to passing the raw text
|
||||||
|
// through if the buffer doesn't parse).
|
||||||
|
let mut in_tool_call = false;
|
||||||
|
let mut tool_call_buf = String::new();
|
||||||
|
let mut tool_call_idx: usize = 0;
|
||||||
|
|
||||||
arch.clear_kv_cache()?;
|
arch.clear_kv_cache()?;
|
||||||
let logits = chunked_prefill_local(arch, device, prompt_tokens)?;
|
let logits = chunked_prefill_local(arch, device, prompt_tokens)?;
|
||||||
let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
|
let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
|
||||||
|
|
||||||
if Some(next_token) == eos_id {
|
// Per-token routing block, used at both the prefill-sample
|
||||||
finish_reason = FinishReason::Stop;
|
// tail and the decode loop. Macros are ugly but Rust's
|
||||||
} else if handle_reasoning_marker(next_token, reasoning_tokens, &mut in_reasoning) {
|
// closure inference fights `&mut DecodeStream<'_>` capture +
|
||||||
all_tokens.push(next_token);
|
// mutable borrows of the surrounding `tool_call_buf` /
|
||||||
|
// `in_reasoning` / etc. Inline the body via a macro and live
|
||||||
|
// with the duplication of the call sites instead.
|
||||||
|
macro_rules! route_token {
|
||||||
|
($next_token:expr) => {{
|
||||||
|
let nt = $next_token;
|
||||||
|
all_tokens.push(nt);
|
||||||
|
match handle_tool_call_marker(nt, tool_call_tokens, &mut in_tool_call, &mut tool_call_buf) {
|
||||||
|
ToolCallMarker::Enter => {}
|
||||||
|
ToolCallMarker::Exit { buffer } => {
|
||||||
|
let idx = tool_call_idx;
|
||||||
|
tool_call_idx += 1;
|
||||||
|
match parse_tool_call_body(&buffer, idx) {
|
||||||
|
Some((id, name, arguments)) => {
|
||||||
|
if tx
|
||||||
|
.blocking_send(InferenceEvent::ToolCall {
|
||||||
|
index: idx,
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
// Malformed JSON — pass the block
|
||||||
|
// through as text so consumer parsers
|
||||||
|
// can try their own repair.
|
||||||
|
let open = tool_call_tokens
|
||||||
|
.map(|p| p.open_text.as_str())
|
||||||
|
.unwrap_or("<tool_call>");
|
||||||
|
let close = tool_call_tokens
|
||||||
|
.map(|p| p.close_text.as_str())
|
||||||
|
.unwrap_or("</tool_call>");
|
||||||
|
let raw = format!("{open}{buffer}{close}");
|
||||||
|
if !emit_delta_blocking(&raw, tx, in_reasoning) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ToolCallMarker::None => {
|
||||||
|
if in_tool_call {
|
||||||
|
// Buffer JSON content without emitting.
|
||||||
|
match decode_stream.step(nt) {
|
||||||
|
Ok(Some(s)) => tool_call_buf.push_str(&s),
|
||||||
|
Ok(None) => {}
|
||||||
|
Err(e) => tracing::warn!(
|
||||||
|
error = %e,
|
||||||
|
"stream: decode_stream step failed (in tool_call)"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
} else if handle_reasoning_marker(nt, reasoning_tokens, &mut in_reasoning) {
|
||||||
|
// marker — nothing to emit
|
||||||
} else {
|
} else {
|
||||||
all_tokens.push(next_token);
|
match decode_stream.step(nt) {
|
||||||
match decode_stream.step(next_token) {
|
|
||||||
Ok(Some(delta)) => {
|
Ok(Some(delta)) => {
|
||||||
if !emit_delta_blocking(&delta, tx, in_reasoning) {
|
if !emit_delta_blocking(&delta, tx, in_reasoning) {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(None) => {}
|
Ok(None) => {}
|
||||||
Err(e) => tracing::warn!(error = %e, "stream: decode_stream step failed"),
|
Err(e) => tracing::warn!(
|
||||||
|
error = %e,
|
||||||
|
"stream: decode_stream step failed"
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}};
|
||||||
|
}
|
||||||
|
|
||||||
|
if Some(next_token) == eos_id {
|
||||||
|
finish_reason = FinishReason::Stop;
|
||||||
|
} else {
|
||||||
|
route_token!(next_token);
|
||||||
|
}
|
||||||
|
|
||||||
for index in 0..max_new.saturating_sub(1) {
|
for index in 0..max_new.saturating_sub(1) {
|
||||||
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
|
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
|
||||||
@@ -3417,20 +3791,7 @@ fn run_inference_streaming(
|
|||||||
finish_reason = FinishReason::Stop;
|
finish_reason = FinishReason::Stop;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if handle_reasoning_marker(next_token, reasoning_tokens, &mut in_reasoning) {
|
route_token!(next_token);
|
||||||
all_tokens.push(next_token);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
all_tokens.push(next_token);
|
|
||||||
match decode_stream.step(next_token) {
|
|
||||||
Ok(Some(delta)) => {
|
|
||||||
if !emit_delta_blocking(&delta, tx, in_reasoning) {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(None) => {}
|
|
||||||
Err(e) => tracing::warn!(error = %e, "stream: decode_stream step failed"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let _ = tx.blocking_send(InferenceEvent::Finish {
|
let _ = tx.blocking_send(InferenceEvent::Finish {
|
||||||
|
|||||||
@@ -44,16 +44,29 @@ pub enum InferenceEvent {
|
|||||||
/// concatenate into the complete reply.
|
/// concatenate into the complete reply.
|
||||||
TextDelta(String),
|
TextDelta(String),
|
||||||
/// Reasoning / scratchpad text the model emitted inside a
|
/// Reasoning / scratchpad text the model emitted inside a
|
||||||
/// `<think>` block (or equivalent). Producers that don't
|
/// `<think>` block (or equivalent). The harness routes
|
||||||
/// surface reasoning separately use [`TextDelta`] for
|
/// content between marker tokens here so wire projectors can
|
||||||
/// everything; future split lives here.
|
/// decide what to do with it (chat completions drops by
|
||||||
///
|
/// default; Responses API has a dedicated event family).
|
||||||
/// Not yet emitted by the candle harness — present so future
|
|
||||||
/// stages (qwen3 `<think>` routing, OpenAI o-series reasoning)
|
|
||||||
/// have a typed home without breaking the existing
|
|
||||||
/// projections.
|
|
||||||
#[allow(dead_code)]
|
|
||||||
ReasoningDelta(String),
|
ReasoningDelta(String),
|
||||||
|
/// A tool call has been parsed out of a `<tool_call>{json}</tool_call>`
|
||||||
|
/// block. Carries the parsed name + arguments JSON string
|
||||||
|
/// (Anthropic / OpenAI projectors emit their own wire shape
|
||||||
|
/// from this).
|
||||||
|
///
|
||||||
|
/// `index` is the call slot — incremented per tool call in a
|
||||||
|
/// turn so wire formats that order calls by index
|
||||||
|
/// (OpenAI chat completions) can correlate.
|
||||||
|
ToolCall {
|
||||||
|
index: usize,
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
/// Complete JSON arguments string. The model could in
|
||||||
|
/// principle stream these token-by-token, but our
|
||||||
|
/// extraction buffers the whole block until `</tool_call>`
|
||||||
|
/// arrives and emits exactly one event per call.
|
||||||
|
arguments: String,
|
||||||
|
},
|
||||||
/// The stream is complete. Carries the reason so wire formats
|
/// The stream is complete. Carries the reason so wire formats
|
||||||
/// that use it (OpenAI's `finish_reason`, Anthropic's
|
/// that use it (OpenAI's `finish_reason`, Anthropic's
|
||||||
/// `stop_reason`) can render it without re-parsing.
|
/// `stop_reason`) can render it without re-parsing.
|
||||||
@@ -137,6 +150,51 @@ const KNOWN_REASONING_MARKERS: &[(&str, &str)] = &[
|
|||||||
("<reasoning>", "</reasoning>"),
|
("<reasoning>", "</reasoning>"),
|
||||||
];
|
];
|
||||||
|
|
||||||
|
/// Open/close token IDs for the model's tool-call marker
|
||||||
|
/// convention (or `None` for models that don't emit structured
|
||||||
|
/// tool calls). Same shape as [`ReasoningTokenPair`]: probed once
|
||||||
|
/// at load time, consumed by the inference loop to switch between
|
||||||
|
/// "emit visible deltas" and "buffer JSON for the next tool
|
||||||
|
/// call".
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ToolCallTokenPair {
|
||||||
|
pub open_id: u32,
|
||||||
|
pub close_id: u32,
|
||||||
|
pub open_text: String,
|
||||||
|
pub close_text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tool-call marker conventions. Open-weight tool-use models
|
||||||
|
/// converged on `<tool_call>` / `</tool_call>` (Qwen3-Coder /
|
||||||
|
/// -Instruct, the Hermes function-call format, DeepSeek-Coder,
|
||||||
|
/// gpt-oss). The pair lives alongside the reasoning markers in
|
||||||
|
/// the same `added_tokens` table.
|
||||||
|
const KNOWN_TOOL_CALL_MARKERS: &[(&str, &str)] = &[("<tool_call>", "</tool_call>")];
|
||||||
|
|
||||||
|
/// Probe a tokenizer for known tool-call marker pairs. Mirrors
|
||||||
|
/// [`detect_reasoning_token_pair`] — both open AND close must
|
||||||
|
/// resolve for the pair to be returned. `None` means the model
|
||||||
|
/// doesn't emit structured tool calls (or its tokenizer split
|
||||||
|
/// the markers across tokens).
|
||||||
|
pub fn detect_tool_call_token_pair<F>(token_to_id: F) -> Option<ToolCallTokenPair>
|
||||||
|
where
|
||||||
|
F: Fn(&str) -> Option<u32>,
|
||||||
|
{
|
||||||
|
for (open_text, close_text) in KNOWN_TOOL_CALL_MARKERS {
|
||||||
|
let open_id = token_to_id(open_text);
|
||||||
|
let close_id = token_to_id(close_text);
|
||||||
|
if let (Some(open_id), Some(close_id)) = (open_id, close_id) {
|
||||||
|
return Some(ToolCallTokenPair {
|
||||||
|
open_id,
|
||||||
|
close_id,
|
||||||
|
open_text: (*open_text).into(),
|
||||||
|
close_text: (*close_text).into(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
/// Inspect a tokenizer for known reasoning-marker pairs and return
|
/// Inspect a tokenizer for known reasoning-marker pairs and return
|
||||||
/// the first match. The tokenizer types this trait is defined over
|
/// the first match. The tokenizer types this trait is defined over
|
||||||
/// just need to expose `token_to_id(&str) -> Option<u32>` so this
|
/// just need to expose `token_to_id(&str) -> Option<u32>` so this
|
||||||
@@ -213,6 +271,24 @@ mod tests {
|
|||||||
assert!(detect_reasoning_token_pair(lookup(&m)).is_none());
|
assert!(detect_reasoning_token_pair(lookup(&m)).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn detects_tool_call_markers() {
|
||||||
|
let mut m = HashMap::new();
|
||||||
|
m.insert("<tool_call>", 151657);
|
||||||
|
m.insert("</tool_call>", 151658);
|
||||||
|
let pair = detect_tool_call_token_pair(lookup(&m)).expect("pair detected");
|
||||||
|
assert_eq!(pair.open_id, 151657);
|
||||||
|
assert_eq!(pair.close_id, 151658);
|
||||||
|
assert_eq!(pair.open_text, "<tool_call>");
|
||||||
|
assert_eq!(pair.close_text, "</tool_call>");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn returns_none_for_non_tool_use_tokenizer() {
|
||||||
|
let m: HashMap<&'static str, u32> = HashMap::new();
|
||||||
|
assert!(detect_tool_call_token_pair(lookup(&m)).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn first_match_wins_when_multiple_pairs_declared() {
|
fn first_match_wins_when_multiple_pairs_declared() {
|
||||||
// Hypothetical tokenizer with both Qwen-style AND Mistral-style
|
// Hypothetical tokenizer with both Qwen-style AND Mistral-style
|
||||||
|
|||||||
@@ -21,4 +21,7 @@ pub mod event;
|
|||||||
pub mod openai_chat;
|
pub mod openai_chat;
|
||||||
pub mod openai_responses;
|
pub mod openai_responses;
|
||||||
|
|
||||||
pub use event::{FinishReason, InferenceEvent, ReasoningTokenPair, detect_reasoning_token_pair};
|
pub use event::{
|
||||||
|
FinishReason, InferenceEvent, ReasoningTokenPair, ToolCallTokenPair,
|
||||||
|
detect_reasoning_token_pair, detect_tool_call_token_pair,
|
||||||
|
};
|
||||||
|
|||||||
@@ -172,6 +172,22 @@ pub fn project_chat_stream_with(
|
|||||||
was_in_reasoning = true;
|
was_in_reasoning = true;
|
||||||
chunks
|
chunks
|
||||||
}
|
}
|
||||||
|
InferenceEvent::ToolCall {
|
||||||
|
index,
|
||||||
|
id: call_id,
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
} => {
|
||||||
|
// OpenAI streaming shape for tool calls:
|
||||||
|
// `delta.tool_calls[]` with id + function.name
|
||||||
|
// on the first chunk per index, then
|
||||||
|
// function.arguments deltas. We have the
|
||||||
|
// complete arguments buffered already, so one
|
||||||
|
// delta carries everything.
|
||||||
|
vec![tool_call_chunk(
|
||||||
|
&id, created, &model_id, index, &call_id, &name, &arguments,
|
||||||
|
)]
|
||||||
|
}
|
||||||
InferenceEvent::Finish { reason } => {
|
InferenceEvent::Finish { reason } => {
|
||||||
vec![final_chunk(&id, created, &model_id, reason)]
|
vec![final_chunk(&id, created, &model_id, reason)]
|
||||||
}
|
}
|
||||||
@@ -222,6 +238,47 @@ fn content_chunk(id: &str, created: u64, model_id: &str, text: &str) -> ChatComp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// OpenAI chat streaming shape for a tool call. One chunk per
|
||||||
|
/// call slot, carrying id + name + the complete arguments JSON.
|
||||||
|
/// Mirrors the format real OpenAI emits on the streaming path,
|
||||||
|
/// minus the per-token arguments-streaming complication (we have
|
||||||
|
/// the whole buffer already after the model finishes the
|
||||||
|
/// `<tool_call>...</tool_call>` block).
|
||||||
|
fn tool_call_chunk(
|
||||||
|
id: &str,
|
||||||
|
created: u64,
|
||||||
|
model_id: &str,
|
||||||
|
index: usize,
|
||||||
|
call_id: &str,
|
||||||
|
name: &str,
|
||||||
|
arguments: &str,
|
||||||
|
) -> ChatCompletionChunk {
|
||||||
|
ChatCompletionChunk {
|
||||||
|
id: id.into(),
|
||||||
|
object: "chat.completion.chunk".into(),
|
||||||
|
created,
|
||||||
|
model: model_id.into(),
|
||||||
|
choices: vec![ChunkChoice {
|
||||||
|
index: 0,
|
||||||
|
delta: json!({
|
||||||
|
"tool_calls": [{
|
||||||
|
"index": index,
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": name,
|
||||||
|
"arguments": arguments,
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
}),
|
||||||
|
finish_reason: None,
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn final_chunk(
|
fn final_chunk(
|
||||||
id: &str,
|
id: &str,
|
||||||
created: u64,
|
created: u64,
|
||||||
|
|||||||
@@ -296,6 +296,13 @@ async fn run_projection(
|
|||||||
// Stage where it'd land: a `response.reasoning_*`
|
// Stage where it'd land: a `response.reasoning_*`
|
||||||
// event family alongside `response.output_text.*`.
|
// event family alongside `response.output_text.*`.
|
||||||
}
|
}
|
||||||
|
InferenceEvent::ToolCall { .. } => {
|
||||||
|
// Responses-side tool-call routing not wired yet
|
||||||
|
// (would emit response.function_call_arguments.*
|
||||||
|
// events). Drop for now; the chat-completions
|
||||||
|
// projector handles tool calls. Future work
|
||||||
|
// tracked in #7 alongside the in_progress event.
|
||||||
|
}
|
||||||
InferenceEvent::Finish { reason } => {
|
InferenceEvent::Finish { reason } => {
|
||||||
finish = Some(reason);
|
finish = Some(reason);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user