feat(neuron): strip reasoning from chat completions by default
Some checks failed
CI / CUDA type-check (push) Failing after 18s
build-prerelease / Resolve version stamps (push) Successful in 32s
CI / Format (push) Successful in 32s
CI / Clippy (push) Successful in 2m36s
build-prerelease / Build cortex binary (push) Successful in 4m29s
CI / Test (push) Successful in 5m19s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 5m56s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Build neuron-ampere (push) Successful in 7m45s
build-prerelease / Build neuron-ada (push) Successful in 5m24s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m0s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m2s
Some checks failed
CI / CUDA type-check (push) Failing after 18s
build-prerelease / Resolve version stamps (push) Successful in 32s
CI / Format (push) Successful in 32s
CI / Clippy (push) Successful in 2m36s
build-prerelease / Build cortex binary (push) Successful in 4m29s
CI / Test (push) Successful in 5m19s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 5m56s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Build neuron-ampere (push) Successful in 7m45s
build-prerelease / Build neuron-ada (push) Successful in 5m24s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m0s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m2s
Closes #8. Reasoning-capable models (Qwen3, DeepSeek-R1, gpt-oss, Mistral Magistral, …) emit `<think>...</think>` blocks inline in their content stream. The chat-completions wire format has no slot for reasoning, so until this change every consumer either parsed the markers themselves (helexa-acp) or wrote the raw scratchpad content into their UI (Zed's commit-message generator — visible as the leaked reasoning block on every generated commit message against benjy's Qwen3-8B). ## Implementation, model-agnostic by design The neuron side now does token-level routing without any hardcoded model knowledge: 1. **At load time** (`detect_reasoning_token_pair` in `wire::event`), probe the tokenizer's vocabulary for a known reasoning-marker pair: `<think>` / `</think>` (Qwen3, DeepSeek-R1, gpt-oss), `[THINK]` / `[/THINK]` (Mistral Magistral), and a couple of derivatives. Each marker must resolve to a single token id; if both open and close resolve, stash on `LoadedModel.reasoning_tokens` (similarly `TpLoadedModel`). Non-reasoning models get `None` and pass through unchanged. 2. **At inference time**, the three streaming paths (`run_inference_streaming` CPU, `stream_inference_via_worker` CUDA single-GPU, `chat_completion_tp_stream` CUDA TP) now check each sampled token against the pair via the new `handle_reasoning_marker` helper before feeding it to the detokeniser. Open marker → set `in_reasoning = true`, drop the marker. Close marker → unset, drop. Other tokens go through `emit_delta(_blocking)` which now picks `ReasoningDelta` or `TextDelta` based on state. Markers never appear in the streamed output. 3. **In `wire::openai_chat`**, the projector splits into: - `project_chat_stream` (unchanged signature; default behaviour — drops `ReasoningDelta`) - `project_chat_stream_with(rx, …, ChatProjectionConfig)` — when `include_thinking: true` and `reasoning_markers: Some(_)`, re-wraps reasoning content with the literal open/close marker text and emits as content deltas. Preserves the on-the-wire shape that helexa-acp's `ThinkParser` expects. 4. **HTTP handler** reads `x-include-thinking: true` (case- insensitive `1`/`true`/`yes`) from the request headers and threads it into the projection config. cortex-gateway already forwards arbitrary headers verbatim, so the opt-in works end-to-end without gateway changes. 5. **helexa-acp's `openai_chat` provider** sets `x-include-thinking: true` on every request so its existing `ThinkParser` keeps receiving the marked content stream. `ThinkParser` itself is unchanged — needed for endpoints that aren't reasoning-aware (OpenRouter, OpenAI directly, etc.). ## Acceptance - Zed's commit-message generator (vanilla chat-completions client, no `x-include-thinking`) gets clean commit messages with no `<think>` block. - helexa-acp sessions continue to render thinking in Zed's thought UI via the opt-in path. - Models without reasoning tokens declared in their tokenizer pass through unchanged. - Implementation contains zero references to "qwen3" or any specific model — entirely driven by tokenizer metadata. ## Tests 9 new tests in `wire::event` (token-pair detection across 4 marker conventions, edge cases) and `wire::openai_chat` (default drop, opt-in re-wrap with multi-chunk reasoning, close-marker on Finish, fallback when markers absent, off-switch with markers present). All 213 workspace tests pass; fmt + clippy clean. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -4,7 +4,7 @@ use crate::activation::ActivationTracker;
|
||||
use crate::harness::HarnessRegistry;
|
||||
use crate::harness::candle::{CandleHarness, InferenceError};
|
||||
use crate::health::HealthCache;
|
||||
use crate::wire::openai_responses;
|
||||
use crate::wire::{openai_chat, openai_responses};
|
||||
use axum::Router;
|
||||
use axum::extract::{Path, State};
|
||||
use axum::http::StatusCode;
|
||||
@@ -148,6 +148,7 @@ async fn model_endpoint(
|
||||
/// `ChatCompletionResponse`.
|
||||
async fn chat_completions(
|
||||
State(state): State<Arc<NeuronState>>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(req): Json<ChatCompletionRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
|
||||
@@ -158,8 +159,23 @@ async fn chat_completions(
|
||||
.into_response();
|
||||
};
|
||||
|
||||
// Reasoning-content opt-in. Off by default → naïve clients
|
||||
// (Zed's commit-message generator, vanilla OpenAI clients)
|
||||
// never see `<think>` blocks. On when the caller sends
|
||||
// `x-include-thinking: true` (helexa-acp does this so its
|
||||
// own ThinkParser keeps working unchanged).
|
||||
let include_thinking = headers
|
||||
.get("x-include-thinking")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| matches!(s.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
|
||||
.unwrap_or(false);
|
||||
let chat_config = openai_chat::ChatProjectionConfig {
|
||||
include_thinking,
|
||||
reasoning_markers: None, // filled in from the loaded model inside candle
|
||||
};
|
||||
|
||||
if req.stream.unwrap_or(false) {
|
||||
match candle.chat_completion_stream(req).await {
|
||||
match candle.chat_completion_stream_with(req, chat_config).await {
|
||||
Ok(rx) => {
|
||||
// Each chunk → one SSE `data: {json}` line. After the
|
||||
// channel closes, append the OpenAI [DONE] terminator.
|
||||
|
||||
@@ -26,7 +26,10 @@ use cortex_core::openai::{
|
||||
ChatMessage, MessageContent, Usage,
|
||||
};
|
||||
|
||||
use crate::wire::{FinishReason, InferenceEvent, openai_chat as wire_chat};
|
||||
use crate::wire::{
|
||||
FinishReason, InferenceEvent, ReasoningTokenPair, detect_reasoning_token_pair,
|
||||
openai_chat as wire_chat,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
@@ -148,6 +151,15 @@ pub struct LoadedModel {
|
||||
/// for the TP path (which already had this invariant by accident
|
||||
/// because the pool lock covered the same window).
|
||||
pub inference_lock: tokio::sync::Mutex<()>,
|
||||
/// Open/close token IDs for the reasoning marker this model
|
||||
/// emits, populated once at load time by probing the tokenizer's
|
||||
/// added-tokens table. `None` for non-reasoning models or
|
||||
/// reasoning models whose markers aren't single tokens. When
|
||||
/// `Some`, the streaming inference loop splits output into
|
||||
/// [`InferenceEvent::TextDelta`] and
|
||||
/// [`InferenceEvent::ReasoningDelta`] at the token boundary;
|
||||
/// when `None` everything is `TextDelta`.
|
||||
pub reasoning_tokens: Option<ReasoningTokenPair>,
|
||||
}
|
||||
|
||||
impl LoadedModel {
|
||||
@@ -203,6 +215,11 @@ pub struct TpLoadedModel {
|
||||
/// `CudaContext`, `NcclState`, and the boxed `TpLeaderModel`
|
||||
/// referenced by `leader_handle`.
|
||||
pub worker: Arc<super::device_worker::DeviceWorkerHandle>,
|
||||
/// Same shape as [`LoadedModel::reasoning_tokens`] — open/close
|
||||
/// reasoning marker token IDs probed from the tokenizer at
|
||||
/// load time. `None` when the model declares no reasoning
|
||||
/// markers.
|
||||
pub reasoning_tokens: Option<ReasoningTokenPair>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
@@ -1594,13 +1611,34 @@ impl CandleHarness {
|
||||
pub async fn chat_completion_stream(
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
|
||||
self.chat_completion_stream_with(request, wire_chat::ChatProjectionConfig::default())
|
||||
.await
|
||||
}
|
||||
|
||||
/// Same as [`Self::chat_completion_stream`] but lets the caller
|
||||
/// pick the projection config — currently used by the HTTP
|
||||
/// handler to thread `x-include-thinking` from the request
|
||||
/// headers into [`wire_chat::ChatProjectionConfig::include_thinking`].
|
||||
pub async fn chat_completion_stream_with(
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
mut config: wire_chat::ChatProjectionConfig,
|
||||
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
|
||||
let stream = self.inference_stream(request).await?;
|
||||
Ok(wire_chat::project_chat_stream(
|
||||
// Fill in the model's reasoning markers if the caller
|
||||
// didn't pre-populate them — they're a property of the
|
||||
// loaded model (which the HTTP handler doesn't reach into
|
||||
// directly), not of the request.
|
||||
if config.reasoning_markers.is_none() {
|
||||
config.reasoning_markers = stream.reasoning_markers.clone();
|
||||
}
|
||||
Ok(wire_chat::project_chat_stream_with(
|
||||
stream.events,
|
||||
stream.id,
|
||||
stream.created,
|
||||
stream.model_id,
|
||||
config,
|
||||
))
|
||||
}
|
||||
|
||||
@@ -1747,6 +1785,7 @@ impl CandleHarness {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
let prompt_tokens = prompt_tokens.clone();
|
||||
let reasoning_tokens_inner = loaded.reasoning_tokens.clone();
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let _inference_guard = loaded_for_task.inference_lock.lock().await;
|
||||
@@ -1760,6 +1799,7 @@ impl CandleHarness {
|
||||
top_p,
|
||||
seed,
|
||||
eos_id,
|
||||
reasoning_tokens_inner,
|
||||
tx,
|
||||
)
|
||||
.await
|
||||
@@ -1799,6 +1839,7 @@ impl CandleHarness {
|
||||
unreachable!("worker handle present without cuda feature");
|
||||
}
|
||||
} else if let Some(arch_arc) = loaded.arch.clone() {
|
||||
let reasoning_tokens_inner = loaded.reasoning_tokens.clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let _g = span_for_task.enter();
|
||||
// `blocking_lock` is safe here: spawn_blocking runs on
|
||||
@@ -1816,6 +1857,7 @@ impl CandleHarness {
|
||||
top_p,
|
||||
seed,
|
||||
eos_id,
|
||||
reasoning_tokens_inner.as_ref(),
|
||||
&tx,
|
||||
) {
|
||||
Ok(()) => tracing::info!(
|
||||
@@ -1853,11 +1895,13 @@ impl CandleHarness {
|
||||
// Hand the raw event channel back to the public entry
|
||||
// points (chat_completion_stream / responses_stream); they
|
||||
// pick the wire projection.
|
||||
let reasoning_markers = loaded.reasoning_tokens.clone();
|
||||
Ok(InferenceStream {
|
||||
events: event_rx,
|
||||
id,
|
||||
created,
|
||||
model_id,
|
||||
reasoning_markers,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1881,6 +1925,13 @@ pub struct InferenceStream {
|
||||
/// Local model id (no endpoint prefix). Stamped into every
|
||||
/// wire-format frame so consumers can correlate.
|
||||
pub model_id: String,
|
||||
/// Open/close reasoning marker text (and token ids) for the
|
||||
/// loaded model, or `None` for non-reasoning models. Used by
|
||||
/// the chat-completions projector when `include_thinking` is
|
||||
/// set — the projector re-wraps reasoning content with the
|
||||
/// literal markers so client-side parsers (helexa-acp's
|
||||
/// `ThinkParser`) see the original on-the-wire shape.
|
||||
pub reasoning_markers: Option<ReasoningTokenPair>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -1987,6 +2038,26 @@ impl Harness for CandleHarness {
|
||||
let tokenizer = Tokenizer::from_file(&tokenizer_path)
|
||||
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
|
||||
|
||||
// Probe for reasoning markers in the tokenizer's
|
||||
// added-tokens table — `<think>` / `</think>` on Qwen3 +
|
||||
// DeepSeek-R1 + gpt-oss, `[THINK]` / `[/THINK]` on
|
||||
// Mistral Magistral, etc. `None` for non-reasoning models.
|
||||
// The streaming loop uses this to route between TextDelta
|
||||
// and ReasoningDelta without any hardcoded model
|
||||
// knowledge; wire projectors decide what to do with the
|
||||
// split.
|
||||
let reasoning_tokens = detect_reasoning_token_pair(|s| tokenizer.token_to_id(s));
|
||||
if let Some(ref pair) = reasoning_tokens {
|
||||
tracing::info!(
|
||||
model = %spec.model_id,
|
||||
open = %pair.open_text,
|
||||
close = %pair.close_text,
|
||||
open_id = pair.open_id,
|
||||
close_id = pair.close_id,
|
||||
"reasoning markers detected — streaming will route ReasoningDelta separately"
|
||||
);
|
||||
}
|
||||
|
||||
let loaded = Arc::new(LoadedModel {
|
||||
model_id: spec.model_id.clone(),
|
||||
arch: arch_local,
|
||||
@@ -1998,6 +2069,7 @@ impl Harness for CandleHarness {
|
||||
worker,
|
||||
arch_handle,
|
||||
inference_lock: tokio::sync::Mutex::new(()),
|
||||
reasoning_tokens,
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
@@ -2170,6 +2242,17 @@ impl CandleHarness {
|
||||
// 6. Tokenizer (same as single-GPU path).
|
||||
let tokenizer = Tokenizer::from_file(&tokenizer_path)
|
||||
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
|
||||
// Reasoning-marker probe — identical to the single-GPU
|
||||
// path. See `LoadedModel.reasoning_tokens` for the why.
|
||||
let reasoning_tokens = detect_reasoning_token_pair(|s| tokenizer.token_to_id(s));
|
||||
if let Some(ref pair) = reasoning_tokens {
|
||||
tracing::info!(
|
||||
model = %spec.model_id,
|
||||
open = %pair.open_text,
|
||||
close = %pair.close_text,
|
||||
"TP load: reasoning markers detected"
|
||||
);
|
||||
}
|
||||
|
||||
let tp_loaded = StdArc::new(TpLoadedModel {
|
||||
model_id: spec.model_id.clone(),
|
||||
@@ -2183,6 +2266,7 @@ impl CandleHarness {
|
||||
// single `Arc` shared between WorkerPool and
|
||||
// TpLoadedModel so they reference the same thread.
|
||||
worker: leader_worker,
|
||||
reasoning_tokens,
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
@@ -2331,6 +2415,7 @@ impl CandleHarness {
|
||||
let id = format!("chatcmpl-{:x}", unix_subsec_nanos());
|
||||
let created = unix_now_secs();
|
||||
let tokenizer = tp.tokenizer.clone();
|
||||
let reasoning_tokens = tp.reasoning_tokens.clone();
|
||||
// The spawned orchestration task below consumes both `id`
|
||||
// and `model_id` (tracing, pool lookups, NCCL ops use them
|
||||
// heavily). The wire projector at the bottom of this fn
|
||||
@@ -2396,6 +2481,10 @@ impl CandleHarness {
|
||||
// split a multi-byte char across tokens.
|
||||
let mut decode_stream = tokenizer.decode_stream(true);
|
||||
let mut finish_reason = FinishReason::Length;
|
||||
// Reasoning marker state machine — same as the
|
||||
// single-GPU path. The TP path needs its own copy
|
||||
// because the spawn closure owns it.
|
||||
let mut in_reasoning = false;
|
||||
|
||||
'work: {
|
||||
if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await {
|
||||
@@ -2464,11 +2553,17 @@ impl CandleHarness {
|
||||
|
||||
if Some(next_token) == eos_id {
|
||||
finish_reason = FinishReason::Stop;
|
||||
} else if handle_reasoning_marker(
|
||||
next_token,
|
||||
reasoning_tokens.as_ref(),
|
||||
&mut in_reasoning,
|
||||
) {
|
||||
all_tokens.push(next_token);
|
||||
} else {
|
||||
all_tokens.push(next_token);
|
||||
match decode_stream.step(next_token) {
|
||||
Ok(Some(delta)) => {
|
||||
if !emit_delta(&delta, &tx).await {
|
||||
if !emit_delta(&delta, &tx, in_reasoning).await {
|
||||
// Client gone — treat as normal stream end,
|
||||
// not a failure. No log spam.
|
||||
break 'work;
|
||||
@@ -2543,10 +2638,18 @@ impl CandleHarness {
|
||||
finish_reason = FinishReason::Stop;
|
||||
break;
|
||||
}
|
||||
if handle_reasoning_marker(
|
||||
next_token,
|
||||
reasoning_tokens.as_ref(),
|
||||
&mut in_reasoning,
|
||||
) {
|
||||
all_tokens.push(next_token);
|
||||
continue;
|
||||
}
|
||||
all_tokens.push(next_token);
|
||||
match decode_stream.step(next_token) {
|
||||
Ok(Some(delta)) => {
|
||||
if !emit_delta(&delta, &tx).await {
|
||||
if !emit_delta(&delta, &tx, in_reasoning).await {
|
||||
break 'work;
|
||||
}
|
||||
}
|
||||
@@ -2612,11 +2715,13 @@ impl CandleHarness {
|
||||
// points; they pick the wire projection. Uses the clones
|
||||
// we stashed before the spawn — the originals were moved
|
||||
// into the orchestration task above.
|
||||
let reasoning_markers = tp.reasoning_tokens.clone();
|
||||
Ok(InferenceStream {
|
||||
events: event_rx,
|
||||
id: projector_id,
|
||||
created,
|
||||
model_id: projector_model_id,
|
||||
reasoning_markers,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2855,25 +2960,57 @@ async fn chat_completion_tp_inner(
|
||||
/// [`crate::wire::openai_chat`] stamps it onto every chunk
|
||||
/// downstream.
|
||||
#[cfg(feature = "cuda")]
|
||||
async fn emit_delta(delta: &str, tx: &mpsc::Sender<InferenceEvent>) -> bool {
|
||||
async fn emit_delta(delta: &str, tx: &mpsc::Sender<InferenceEvent>, in_reasoning: bool) -> bool {
|
||||
if delta.is_empty() {
|
||||
return true;
|
||||
}
|
||||
tx.send(InferenceEvent::TextDelta(delta.into()))
|
||||
.await
|
||||
.is_ok()
|
||||
let event = if in_reasoning {
|
||||
InferenceEvent::ReasoningDelta(delta.into())
|
||||
} else {
|
||||
InferenceEvent::TextDelta(delta.into())
|
||||
};
|
||||
tx.send(event).await.is_ok()
|
||||
}
|
||||
|
||||
/// Sync counterpart of [`emit_delta`] for the CPU path's
|
||||
/// `spawn_blocking` closure. Same shape, `blocking_send` instead of
|
||||
/// `send`. Kept as a separate fn so the async / blocking-send choice
|
||||
/// is local to one place per path.
|
||||
fn emit_delta_blocking(delta: &str, tx: &mpsc::Sender<InferenceEvent>) -> bool {
|
||||
fn emit_delta_blocking(delta: &str, tx: &mpsc::Sender<InferenceEvent>, in_reasoning: bool) -> bool {
|
||||
if delta.is_empty() {
|
||||
return true;
|
||||
}
|
||||
tx.blocking_send(InferenceEvent::TextDelta(delta.into()))
|
||||
.is_ok()
|
||||
let event = if in_reasoning {
|
||||
InferenceEvent::ReasoningDelta(delta.into())
|
||||
} else {
|
||||
InferenceEvent::TextDelta(delta.into())
|
||||
};
|
||||
tx.blocking_send(event).is_ok()
|
||||
}
|
||||
|
||||
/// If `next_token` is one of the loaded model's reasoning markers,
|
||||
/// flip `in_reasoning` and return `true` to tell the caller to
|
||||
/// skip detokenisation + emission for this token. The markers
|
||||
/// themselves never appear in the streamed output — they exist
|
||||
/// purely to transition state.
|
||||
///
|
||||
/// `pair = None` short-circuits to `false` (no reasoning markers
|
||||
/// configured for this model → pass-through).
|
||||
fn handle_reasoning_marker(
|
||||
next_token: u32,
|
||||
pair: Option<&ReasoningTokenPair>,
|
||||
in_reasoning: &mut bool,
|
||||
) -> bool {
|
||||
let Some(pair) = pair else { return false };
|
||||
if next_token == pair.open_id {
|
||||
*in_reasoning = true;
|
||||
return true;
|
||||
}
|
||||
if next_token == pair.close_id {
|
||||
*in_reasoning = false;
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Errors returned by `CandleHarness::chat_completion`. The
|
||||
@@ -3038,6 +3175,7 @@ async fn stream_inference_via_worker(
|
||||
top_p: Option<f64>,
|
||||
seed: u64,
|
||||
eos_id: Option<u32>,
|
||||
reasoning_tokens: Option<ReasoningTokenPair>,
|
||||
tx: mpsc::Sender<InferenceEvent>,
|
||||
) -> Result<String> {
|
||||
let mut logits_processor = {
|
||||
@@ -3062,6 +3200,10 @@ async fn stream_inference_via_worker(
|
||||
let mut decode_stream = tokenizer.decode_stream(true);
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut finish_reason = FinishReason::Length;
|
||||
// Reasoning marker state machine — see `run_inference_streaming`
|
||||
// for the why. Markers never reach `decode_stream`; they only
|
||||
// toggle the variant `emit_delta` produces.
|
||||
let mut in_reasoning = false;
|
||||
|
||||
worker
|
||||
.clear_kv_cache(handle)
|
||||
@@ -3088,50 +3230,56 @@ async fn stream_inference_via_worker(
|
||||
|
||||
if Some(next_token) == eos_id {
|
||||
finish_reason = FinishReason::Stop;
|
||||
} else if handle_reasoning_marker(next_token, reasoning_tokens.as_ref(), &mut in_reasoning) {
|
||||
all_tokens.push(next_token);
|
||||
} else {
|
||||
all_tokens.push(next_token);
|
||||
match decode_stream.step(next_token) {
|
||||
Ok(Some(delta)) => {
|
||||
if !emit_delta(&delta, &tx).await {
|
||||
if !emit_delta(&delta, &tx, in_reasoning).await {
|
||||
return Ok(finish_reason.as_openai_str().to_string());
|
||||
}
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => tracing::warn!(error = %e, "decode_stream step failed"),
|
||||
}
|
||||
}
|
||||
|
||||
for index in 0..max_new.saturating_sub(1) {
|
||||
let logits_vec = worker
|
||||
.forward_logits(handle, vec![next_token], prompt_len + index)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
|
||||
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
|
||||
next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let health = logits_health_slice(&logits_vec);
|
||||
tracing::warn!(
|
||||
step = index,
|
||||
?health,
|
||||
"chat_completion (stream/worker): decode sample failed; logits unhealthy"
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
if Some(next_token) == eos_id {
|
||||
finish_reason = FinishReason::Stop;
|
||||
break;
|
||||
for index in 0..max_new.saturating_sub(1) {
|
||||
let logits_vec = worker
|
||||
.forward_logits(handle, vec![next_token], prompt_len + index)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
|
||||
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
|
||||
next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let health = logits_health_slice(&logits_vec);
|
||||
tracing::warn!(
|
||||
step = index,
|
||||
?health,
|
||||
"chat_completion (stream/worker): decode sample failed; logits unhealthy"
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
if Some(next_token) == eos_id {
|
||||
finish_reason = FinishReason::Stop;
|
||||
break;
|
||||
}
|
||||
if handle_reasoning_marker(next_token, reasoning_tokens.as_ref(), &mut in_reasoning) {
|
||||
all_tokens.push(next_token);
|
||||
match decode_stream.step(next_token) {
|
||||
Ok(Some(delta)) => {
|
||||
if !emit_delta(&delta, &tx).await {
|
||||
return Ok(finish_reason.as_openai_str().to_string());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
all_tokens.push(next_token);
|
||||
match decode_stream.step(next_token) {
|
||||
Ok(Some(delta)) => {
|
||||
if !emit_delta(&delta, &tx, in_reasoning).await {
|
||||
return Ok(finish_reason.as_openai_str().to_string());
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => tracing::warn!(error = %e, "decode_stream step failed"),
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => tracing::warn!(error = %e, "decode_stream step failed"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3211,6 +3359,7 @@ fn run_inference_streaming(
|
||||
top_p: Option<f64>,
|
||||
seed: u64,
|
||||
eos_id: Option<u32>,
|
||||
reasoning_tokens: Option<&ReasoningTokenPair>,
|
||||
tx: &mpsc::Sender<InferenceEvent>,
|
||||
) -> Result<()> {
|
||||
let mut logits_processor = {
|
||||
@@ -3232,6 +3381,12 @@ fn run_inference_streaming(
|
||||
// boundaries and only emits when a clean codepoint completes.
|
||||
let mut decode_stream = tokenizer.decode_stream(true);
|
||||
let mut finish_reason = FinishReason::Length;
|
||||
// Reasoning marker state machine. Flips on
|
||||
// `next_token == reasoning_tokens.open_id`, off on
|
||||
// `.close_id`. The marker tokens themselves never feed into
|
||||
// `decode_stream` — they aren't part of any visible output,
|
||||
// they exist purely as state transitions.
|
||||
let mut in_reasoning = false;
|
||||
|
||||
arch.clear_kv_cache()?;
|
||||
let logits = chunked_prefill_local(arch, device, prompt_tokens)?;
|
||||
@@ -3239,36 +3394,42 @@ fn run_inference_streaming(
|
||||
|
||||
if Some(next_token) == eos_id {
|
||||
finish_reason = FinishReason::Stop;
|
||||
} else if handle_reasoning_marker(next_token, reasoning_tokens, &mut in_reasoning) {
|
||||
all_tokens.push(next_token);
|
||||
} else {
|
||||
all_tokens.push(next_token);
|
||||
match decode_stream.step(next_token) {
|
||||
Ok(Some(delta)) => {
|
||||
if !emit_delta_blocking(&delta, tx) {
|
||||
if !emit_delta_blocking(&delta, tx, in_reasoning) {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => tracing::warn!(error = %e, "stream: decode_stream step failed"),
|
||||
}
|
||||
}
|
||||
|
||||
for index in 0..max_new.saturating_sub(1) {
|
||||
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
|
||||
let logits = arch.forward(&input, prompt_tokens.len() + index)?;
|
||||
next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
|
||||
if Some(next_token) == eos_id {
|
||||
finish_reason = FinishReason::Stop;
|
||||
break;
|
||||
}
|
||||
for index in 0..max_new.saturating_sub(1) {
|
||||
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
|
||||
let logits = arch.forward(&input, prompt_tokens.len() + index)?;
|
||||
next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
|
||||
if Some(next_token) == eos_id {
|
||||
finish_reason = FinishReason::Stop;
|
||||
break;
|
||||
}
|
||||
if handle_reasoning_marker(next_token, reasoning_tokens, &mut in_reasoning) {
|
||||
all_tokens.push(next_token);
|
||||
match decode_stream.step(next_token) {
|
||||
Ok(Some(delta)) => {
|
||||
if !emit_delta_blocking(&delta, tx) {
|
||||
return Ok(());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
all_tokens.push(next_token);
|
||||
match decode_stream.step(next_token) {
|
||||
Ok(Some(delta)) => {
|
||||
if !emit_delta_blocking(&delta, tx, in_reasoning) {
|
||||
return Ok(());
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => tracing::warn!(error = %e, "stream: decode_stream step failed"),
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => tracing::warn!(error = %e, "stream: decode_stream step failed"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -97,3 +97,134 @@ impl FinishReason {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Open/close token IDs for the reasoning marker a loaded model uses
|
||||
/// (or `None` for non-reasoning models). The harness reads this once
|
||||
/// at load time from the tokenizer's added-tokens table, then the
|
||||
/// inference loop checks `next_token` against the pair to flip
|
||||
/// between [`InferenceEvent::TextDelta`] and
|
||||
/// [`InferenceEvent::ReasoningDelta`].
|
||||
///
|
||||
/// `open` and `close` text are kept alongside the IDs so wire
|
||||
/// projectors that want to re-emit the literal markers (the
|
||||
/// opt-in `include_thinking` path on chat completions) don't have
|
||||
/// to reach back into the tokenizer for the strings.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ReasoningTokenPair {
|
||||
pub open_id: u32,
|
||||
pub close_id: u32,
|
||||
pub open_text: String,
|
||||
pub close_text: String,
|
||||
}
|
||||
|
||||
/// Known reasoning-marker conventions. Each is a `(open, close)`
|
||||
/// pair of literal token strings. Each modern reasoning model
|
||||
/// declares its markers in the tokenizer's `added_tokens` table;
|
||||
/// at load time we probe for whichever pair the loaded tokenizer
|
||||
/// has and stash both IDs.
|
||||
///
|
||||
/// Ordering matters only for tie-breaking when a model declares
|
||||
/// multiple pairs (shouldn't happen in practice); the first hit
|
||||
/// wins.
|
||||
const KNOWN_REASONING_MARKERS: &[(&str, &str)] = &[
|
||||
// Qwen3, DeepSeek-R1, gpt-oss, and most other open-weight
|
||||
// reasoning models.
|
||||
("<think>", "</think>"),
|
||||
// Mistral Magistral.
|
||||
("[THINK]", "[/THINK]"),
|
||||
// Some older derivatives; harmless to probe.
|
||||
("<thought>", "</thought>"),
|
||||
("<reasoning>", "</reasoning>"),
|
||||
];
|
||||
|
||||
/// Inspect a tokenizer for known reasoning-marker pairs and return
|
||||
/// the first match. The tokenizer types this trait is defined over
|
||||
/// just need to expose `token_to_id(&str) -> Option<u32>` so this
|
||||
/// stays decoupled from the candle crate — the production caller
|
||||
/// passes a `tokenizers::Tokenizer`, but tests can fake one.
|
||||
///
|
||||
/// Returns `None` when no known marker pair is fully declared
|
||||
/// (both open AND close token ids must resolve). That's the
|
||||
/// pass-through case — non-reasoning models, or reasoning models
|
||||
/// whose tokenizer split the markers across multiple tokens (rare
|
||||
/// in practice; modern reasoning tokenizers list them as
|
||||
/// `added_tokens`).
|
||||
pub fn detect_reasoning_token_pair<F>(token_to_id: F) -> Option<ReasoningTokenPair>
|
||||
where
|
||||
F: Fn(&str) -> Option<u32>,
|
||||
{
|
||||
for (open_text, close_text) in KNOWN_REASONING_MARKERS {
|
||||
let open_id = token_to_id(open_text);
|
||||
let close_id = token_to_id(close_text);
|
||||
if let (Some(open_id), Some(close_id)) = (open_id, close_id) {
|
||||
return Some(ReasoningTokenPair {
|
||||
open_id,
|
||||
close_id,
|
||||
open_text: (*open_text).into(),
|
||||
close_text: (*close_text).into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn lookup<'a>(map: &'a HashMap<&'static str, u32>) -> impl Fn(&str) -> Option<u32> + 'a {
|
||||
|s| map.get(s).copied()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_qwen3_style_think_markers() {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("<think>", 151648);
|
||||
m.insert("</think>", 151649);
|
||||
let pair = detect_reasoning_token_pair(lookup(&m)).expect("pair detected");
|
||||
assert_eq!(pair.open_id, 151648);
|
||||
assert_eq!(pair.close_id, 151649);
|
||||
assert_eq!(pair.open_text, "<think>");
|
||||
assert_eq!(pair.close_text, "</think>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_mistral_magistral_markers() {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("[THINK]", 100);
|
||||
m.insert("[/THINK]", 101);
|
||||
let pair = detect_reasoning_token_pair(lookup(&m)).expect("pair detected");
|
||||
assert_eq!(pair.open_text, "[THINK]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_when_only_open_marker_present() {
|
||||
// A pathological tokenizer that has `<think>` but not
|
||||
// `</think>` shouldn't half-detect. Pass-through.
|
||||
let mut m = HashMap::new();
|
||||
m.insert("<think>", 1);
|
||||
assert!(detect_reasoning_token_pair(lookup(&m)).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_for_non_reasoning_tokenizer() {
|
||||
let m: HashMap<&'static str, u32> = HashMap::new();
|
||||
assert!(detect_reasoning_token_pair(lookup(&m)).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_match_wins_when_multiple_pairs_declared() {
|
||||
// Hypothetical tokenizer with both Qwen-style AND Mistral-style
|
||||
// markers — the `<think>` pair is earlier in the convention
|
||||
// table so it wins.
|
||||
let mut m = HashMap::new();
|
||||
m.insert("<think>", 1);
|
||||
m.insert("</think>", 2);
|
||||
m.insert("[THINK]", 3);
|
||||
m.insert("[/THINK]", 4);
|
||||
let pair = detect_reasoning_token_pair(lookup(&m)).unwrap();
|
||||
assert_eq!(pair.open_id, 1);
|
||||
assert_eq!(pair.close_id, 2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,4 +21,4 @@ pub mod event;
|
||||
pub mod openai_chat;
|
||||
pub mod openai_responses;
|
||||
|
||||
pub use event::{FinishReason, InferenceEvent};
|
||||
pub use event::{FinishReason, InferenceEvent, ReasoningTokenPair, detect_reasoning_token_pair};
|
||||
|
||||
@@ -30,13 +30,42 @@ use cortex_core::openai::{ChatCompletionChunk, ChunkChoice};
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::event::{FinishReason, InferenceEvent};
|
||||
use super::event::{FinishReason, InferenceEvent, ReasoningTokenPair};
|
||||
|
||||
/// Output channel buffer size. Mirrors the input side's bound; one
|
||||
/// event maps to at most one chunk, so equal capacity keeps the
|
||||
/// two ends in sync without surprising memory growth.
|
||||
const CHUNK_CHANNEL_CAPACITY: usize = 32;
|
||||
|
||||
/// Per-stream config for the chat projector. Used by the
|
||||
/// production handler to thread per-request choices (currently:
|
||||
/// whether to surface reasoning content) into the projection
|
||||
/// without bloating the function signature.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ChatProjectionConfig {
|
||||
/// When `true`, reasoning content is re-wrapped with the
|
||||
/// model's literal open/close markers and emitted as content
|
||||
/// deltas — preserving the on-the-wire shape that
|
||||
/// reasoning-aware clients like helexa-acp's `ThinkParser`
|
||||
/// expect.
|
||||
///
|
||||
/// When `false` (the default), [`InferenceEvent::ReasoningDelta`]s
|
||||
/// are dropped entirely so consumers that don't know about
|
||||
/// reasoning (Zed's commit-message generator, any vanilla
|
||||
/// OpenAI client) don't have model-internal scratchpad
|
||||
/// material leaking into their UI. The chat-completions wire
|
||||
/// format has no slot for reasoning, so the default chooses
|
||||
/// the safer-for-naïve-clients behaviour.
|
||||
pub include_thinking: bool,
|
||||
/// Open/close marker strings to re-emit when `include_thinking`
|
||||
/// is set. Sourced from the loaded model's
|
||||
/// [`ReasoningTokenPair`]; `None` for non-reasoning models or
|
||||
/// when the caller doesn't have the pair handy (in which case
|
||||
/// `include_thinking` becomes equivalent to dropping reasoning
|
||||
/// because there's nothing to wrap).
|
||||
pub reasoning_markers: Option<ReasoningTokenPair>,
|
||||
}
|
||||
|
||||
/// Project an [`InferenceEvent`] receiver into a
|
||||
/// [`ChatCompletionChunk`] receiver. Spawns one tokio task that
|
||||
/// owns the input receiver for the stream's lifetime and exits
|
||||
@@ -46,15 +75,55 @@ const CHUNK_CHANNEL_CAPACITY: usize = 32;
|
||||
/// chunk so the receiver can stay generic (decoupled from
|
||||
/// per-request metadata).
|
||||
pub fn project_chat_stream(
|
||||
mut rx: mpsc::Receiver<InferenceEvent>,
|
||||
rx: mpsc::Receiver<InferenceEvent>,
|
||||
id: String,
|
||||
created: u64,
|
||||
model_id: String,
|
||||
) -> mpsc::Receiver<ChatCompletionChunk> {
|
||||
// Default config: include_thinking off, no marker rewrap.
|
||||
project_chat_stream_with(rx, id, created, model_id, ChatProjectionConfig::default())
|
||||
}
|
||||
|
||||
/// Same as [`project_chat_stream`] but with a per-stream config
|
||||
/// (currently controlling reasoning surfacing). Production
|
||||
/// callers that need the opt-in path call this directly; the
|
||||
/// shorter wrapper above stays as the no-config convenience.
|
||||
pub fn project_chat_stream_with(
|
||||
mut rx: mpsc::Receiver<InferenceEvent>,
|
||||
id: String,
|
||||
created: u64,
|
||||
model_id: String,
|
||||
config: ChatProjectionConfig,
|
||||
) -> mpsc::Receiver<ChatCompletionChunk> {
|
||||
let (tx, out_rx) = mpsc::channel::<ChatCompletionChunk>(CHUNK_CHANNEL_CAPACITY);
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Track whether the previous event was inside a reasoning
|
||||
// block — used to decide when to emit the literal close
|
||||
// marker on the include_thinking re-wrap path. When this
|
||||
// flips from true → false (a TextDelta or Finish lands
|
||||
// after one or more ReasoningDeltas), we emit the close
|
||||
// marker exactly once.
|
||||
let mut was_in_reasoning = false;
|
||||
|
||||
while let Some(event) = rx.recv().await {
|
||||
// Close-marker insertion: if we're leaving a reasoning
|
||||
// chain, emit the literal close marker before the
|
||||
// current event.
|
||||
if was_in_reasoning && !matches!(event, InferenceEvent::ReasoningDelta(_)) {
|
||||
if let Some(marker) = config
|
||||
.include_thinking
|
||||
.then_some(())
|
||||
.and(config.reasoning_markers.as_ref())
|
||||
{
|
||||
let chunk = content_chunk(&id, created, &model_id, &marker.close_text);
|
||||
if tx.send(chunk).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
was_in_reasoning = false;
|
||||
}
|
||||
|
||||
let chunks = match event {
|
||||
InferenceEvent::Start => vec![role_chunk(&id, created, &model_id)],
|
||||
InferenceEvent::TextDelta(text) => {
|
||||
@@ -66,12 +135,42 @@ pub fn project_chat_stream(
|
||||
}
|
||||
vec![content_chunk(&id, created, &model_id, &text)]
|
||||
}
|
||||
InferenceEvent::ReasoningDelta(_) => {
|
||||
// Reasoning isn't representable in OpenAI chat
|
||||
// streaming today. The o-series uses a separate
|
||||
// `summary` event but it's gated by the
|
||||
// Responses API; chat-completions just drops it.
|
||||
continue;
|
||||
InferenceEvent::ReasoningDelta(text) => {
|
||||
if !config.include_thinking {
|
||||
// Default path — reasoning has no slot in
|
||||
// chat completions, so it's dropped. Naïve
|
||||
// clients (Zed commit-message generator,
|
||||
// any vanilla OpenAI client) get clean
|
||||
// output.
|
||||
continue;
|
||||
}
|
||||
let Some(markers) = config.reasoning_markers.as_ref() else {
|
||||
// Caller asked to include thinking but
|
||||
// didn't supply markers — best we can do
|
||||
// is emit the content as visible text.
|
||||
// Skip the wrap entirely.
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let chunk = content_chunk(&id, created, &model_id, &text);
|
||||
if tx.send(chunk).await.is_err() {
|
||||
return;
|
||||
}
|
||||
continue;
|
||||
};
|
||||
// First chunk of a reasoning block → open
|
||||
// marker prelude. Subsequent reasoning deltas
|
||||
// in the same block reuse `was_in_reasoning`
|
||||
// to skip the prelude.
|
||||
let mut chunks = Vec::new();
|
||||
if !was_in_reasoning {
|
||||
chunks.push(content_chunk(&id, created, &model_id, &markers.open_text));
|
||||
}
|
||||
if !text.is_empty() {
|
||||
chunks.push(content_chunk(&id, created, &model_id, &text));
|
||||
}
|
||||
was_in_reasoning = true;
|
||||
chunks
|
||||
}
|
||||
InferenceEvent::Finish { reason } => {
|
||||
vec![final_chunk(&id, created, &model_id, reason)]
|
||||
@@ -238,4 +337,165 @@ mod tests {
|
||||
assert_eq!(out.len(), 1);
|
||||
assert_eq!(out[0].choices[0].delta["content"], "real");
|
||||
}
|
||||
|
||||
fn pair() -> ReasoningTokenPair {
|
||||
ReasoningTokenPair {
|
||||
open_id: 0,
|
||||
close_id: 1,
|
||||
open_text: "<think>".into(),
|
||||
close_text: "</think>".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn include_thinking_rewraps_reasoning_with_literal_markers() {
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
|
||||
let out_rx = project_chat_stream_with(
|
||||
rx,
|
||||
"id".into(),
|
||||
1,
|
||||
"m".into(),
|
||||
ChatProjectionConfig {
|
||||
include_thinking: true,
|
||||
reasoning_markers: Some(pair()),
|
||||
},
|
||||
);
|
||||
tx.send(InferenceEvent::ReasoningDelta("first ".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::ReasoningDelta("second".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::TextDelta("answer".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::Finish {
|
||||
reason: FinishReason::Stop,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
let out = collect(out_rx).await;
|
||||
// Expected sequence: open marker → reasoning content (2 chunks)
|
||||
// → close marker → visible answer → final chunk.
|
||||
let contents: Vec<&str> = out
|
||||
.iter()
|
||||
.filter_map(|c| c.choices[0].delta["content"].as_str())
|
||||
.collect();
|
||||
assert_eq!(
|
||||
contents,
|
||||
vec!["<think>", "first ", "second", "</think>", "answer"]
|
||||
);
|
||||
assert_eq!(
|
||||
out.last().unwrap().choices[0].finish_reason.as_deref(),
|
||||
Some("stop")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn include_thinking_closes_marker_at_finish_when_no_trailing_text() {
|
||||
// Edge case: stream ends inside a reasoning block (model
|
||||
// hit max_tokens mid-thought, no visible answer ever).
|
||||
// The Finish event still triggers the close marker so the
|
||||
// stream is balanced.
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
|
||||
let out_rx = project_chat_stream_with(
|
||||
rx,
|
||||
"id".into(),
|
||||
1,
|
||||
"m".into(),
|
||||
ChatProjectionConfig {
|
||||
include_thinking: true,
|
||||
reasoning_markers: Some(pair()),
|
||||
},
|
||||
);
|
||||
tx.send(InferenceEvent::ReasoningDelta("thinking...".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::Finish {
|
||||
reason: FinishReason::Length,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
let out = collect(out_rx).await;
|
||||
let contents: Vec<&str> = out
|
||||
.iter()
|
||||
.filter_map(|c| c.choices[0].delta["content"].as_str())
|
||||
.collect();
|
||||
assert_eq!(contents, vec!["<think>", "thinking...", "</think>"]);
|
||||
assert_eq!(
|
||||
out.last().unwrap().choices[0].finish_reason.as_deref(),
|
||||
Some("length")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn include_thinking_without_markers_emits_content_directly() {
|
||||
// Defensive: if the caller asks for thinking but the
|
||||
// model declared no markers, we still emit the content
|
||||
// rather than dropping it. Better to leak than to lose.
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
|
||||
let out_rx = project_chat_stream_with(
|
||||
rx,
|
||||
"id".into(),
|
||||
1,
|
||||
"m".into(),
|
||||
ChatProjectionConfig {
|
||||
include_thinking: true,
|
||||
reasoning_markers: None,
|
||||
},
|
||||
);
|
||||
tx.send(InferenceEvent::ReasoningDelta("raw".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::Finish {
|
||||
reason: FinishReason::Stop,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
let out = collect(out_rx).await;
|
||||
let contents: Vec<&str> = out
|
||||
.iter()
|
||||
.filter_map(|c| c.choices[0].delta["content"].as_str())
|
||||
.collect();
|
||||
assert_eq!(contents, vec!["raw"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn include_thinking_off_drops_reasoning_even_with_markers() {
|
||||
// Default behaviour even when markers happen to be
|
||||
// configured. The flag is the gate, not the marker
|
||||
// presence.
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
|
||||
let out_rx = project_chat_stream_with(
|
||||
rx,
|
||||
"id".into(),
|
||||
1,
|
||||
"m".into(),
|
||||
ChatProjectionConfig {
|
||||
include_thinking: false,
|
||||
reasoning_markers: Some(pair()),
|
||||
},
|
||||
);
|
||||
tx.send(InferenceEvent::ReasoningDelta("hidden".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::TextDelta("visible".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::Finish {
|
||||
reason: FinishReason::Stop,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
let out = collect(out_rx).await;
|
||||
let contents: Vec<&str> = out
|
||||
.iter()
|
||||
.filter_map(|c| c.choices[0].delta["content"].as_str())
|
||||
.collect();
|
||||
assert_eq!(contents, vec!["visible"]);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user