feat(neuron): strip reasoning from chat completions by default
Some checks failed
CI / CUDA type-check (push) Failing after 18s
build-prerelease / Resolve version stamps (push) Successful in 32s
CI / Format (push) Successful in 32s
CI / Clippy (push) Successful in 2m36s
build-prerelease / Build cortex binary (push) Successful in 4m29s
CI / Test (push) Successful in 5m19s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 5m56s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Build neuron-ampere (push) Successful in 7m45s
build-prerelease / Build neuron-ada (push) Successful in 5m24s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m0s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m2s

Closes #8.

Reasoning-capable models (Qwen3, DeepSeek-R1, gpt-oss, Mistral
Magistral, …) emit `<think>...</think>` blocks inline in their
content stream. The chat-completions wire format has no slot for
reasoning, so until this change every consumer either parsed the
markers themselves (helexa-acp) or wrote the raw scratchpad
content into their UI (Zed's commit-message generator — visible
as the leaked reasoning block on every generated commit message
against benjy's Qwen3-8B).

## Implementation, model-agnostic by design

The neuron side now does token-level routing without any
hardcoded model knowledge:

1. **At load time** (`detect_reasoning_token_pair` in
   `wire::event`), probe the tokenizer's vocabulary for a known
   reasoning-marker pair: `<think>` / `</think>` (Qwen3,
   DeepSeek-R1, gpt-oss), `[THINK]` / `[/THINK]` (Mistral
   Magistral), and a couple of derivatives. Each marker must
   resolve to a single token id; if both open and close resolve,
   stash on `LoadedModel.reasoning_tokens` (similarly
   `TpLoadedModel`). Non-reasoning models get `None` and pass
   through unchanged.

2. **At inference time**, the three streaming paths
   (`run_inference_streaming` CPU, `stream_inference_via_worker`
   CUDA single-GPU, `chat_completion_tp_stream` CUDA TP) now
   check each sampled token against the pair via the new
   `handle_reasoning_marker` helper before feeding it to the
   detokeniser. Open marker → set `in_reasoning = true`, drop
   the marker. Close marker → unset, drop. Other tokens go
   through `emit_delta(_blocking)` which now picks
   `ReasoningDelta` or `TextDelta` based on state. Markers
   never appear in the streamed output.

3. **In `wire::openai_chat`**, the projector splits into:
   - `project_chat_stream` (unchanged signature; default
     behaviour — drops `ReasoningDelta`)
   - `project_chat_stream_with(rx, …, ChatProjectionConfig)` —
     when `include_thinking: true` and `reasoning_markers:
     Some(_)`, re-wraps reasoning content with the literal
     open/close marker text and emits as content deltas.
     Preserves the on-the-wire shape that helexa-acp's
     `ThinkParser` expects.

4. **HTTP handler** reads `x-include-thinking: true` (case-
   insensitive `1`/`true`/`yes`) from the request headers and
   threads it into the projection config. cortex-gateway already
   forwards arbitrary headers verbatim, so the opt-in works
   end-to-end without gateway changes.

5. **helexa-acp's `openai_chat` provider** sets
   `x-include-thinking: true` on every request so its existing
   `ThinkParser` keeps receiving the marked content stream.
   `ThinkParser` itself is unchanged — needed for endpoints that
   aren't reasoning-aware (OpenRouter, OpenAI directly, etc.).

## Acceptance

- Zed's commit-message generator (vanilla chat-completions
  client, no `x-include-thinking`) gets clean commit messages
  with no `<think>` block.
- helexa-acp sessions continue to render thinking in Zed's
  thought UI via the opt-in path.
- Models without reasoning tokens declared in their tokenizer
  pass through unchanged.
- Implementation contains zero references to "qwen3" or any
  specific model — entirely driven by tokenizer metadata.

## Tests

9 new tests in `wire::event` (token-pair detection across 4
marker conventions, edge cases) and `wire::openai_chat` (default
drop, opt-in re-wrap with multi-chunk reasoning, close-marker on
Finish, fallback when markers absent, off-switch with markers
present). All 213 workspace tests pass; fmt + clippy clean.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-31 17:55:04 +03:00
parent fdc0adb738
commit 7733eecba5
6 changed files with 645 additions and 67 deletions

View File

@@ -26,7 +26,10 @@ use cortex_core::openai::{
ChatMessage, MessageContent, Usage,
};
use crate::wire::{FinishReason, InferenceEvent, openai_chat as wire_chat};
use crate::wire::{
FinishReason, InferenceEvent, ReasoningTokenPair, detect_reasoning_token_pair,
openai_chat as wire_chat,
};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
@@ -148,6 +151,15 @@ pub struct LoadedModel {
/// for the TP path (which already had this invariant by accident
/// because the pool lock covered the same window).
pub inference_lock: tokio::sync::Mutex<()>,
/// Open/close token IDs for the reasoning marker this model
/// emits, populated once at load time by probing the tokenizer's
/// added-tokens table. `None` for non-reasoning models or
/// reasoning models whose markers aren't single tokens. When
/// `Some`, the streaming inference loop splits output into
/// [`InferenceEvent::TextDelta`] and
/// [`InferenceEvent::ReasoningDelta`] at the token boundary;
/// when `None` everything is `TextDelta`.
pub reasoning_tokens: Option<ReasoningTokenPair>,
}
impl LoadedModel {
@@ -203,6 +215,11 @@ pub struct TpLoadedModel {
/// `CudaContext`, `NcclState`, and the boxed `TpLeaderModel`
/// referenced by `leader_handle`.
pub worker: Arc<super::device_worker::DeviceWorkerHandle>,
/// Same shape as [`LoadedModel::reasoning_tokens`] — open/close
/// reasoning marker token IDs probed from the tokenizer at
/// load time. `None` when the model declares no reasoning
/// markers.
pub reasoning_tokens: Option<ReasoningTokenPair>,
}
#[cfg(feature = "cuda")]
@@ -1594,13 +1611,34 @@ impl CandleHarness {
pub async fn chat_completion_stream(
&self,
request: ChatCompletionRequest,
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
self.chat_completion_stream_with(request, wire_chat::ChatProjectionConfig::default())
.await
}
/// Same as [`Self::chat_completion_stream`] but lets the caller
/// pick the projection config — currently used by the HTTP
/// handler to thread `x-include-thinking` from the request
/// headers into [`wire_chat::ChatProjectionConfig::include_thinking`].
pub async fn chat_completion_stream_with(
&self,
request: ChatCompletionRequest,
mut config: wire_chat::ChatProjectionConfig,
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
let stream = self.inference_stream(request).await?;
Ok(wire_chat::project_chat_stream(
// Fill in the model's reasoning markers if the caller
// didn't pre-populate them — they're a property of the
// loaded model (which the HTTP handler doesn't reach into
// directly), not of the request.
if config.reasoning_markers.is_none() {
config.reasoning_markers = stream.reasoning_markers.clone();
}
Ok(wire_chat::project_chat_stream_with(
stream.events,
stream.id,
stream.created,
stream.model_id,
config,
))
}
@@ -1747,6 +1785,7 @@ impl CandleHarness {
#[cfg(feature = "cuda")]
{
let prompt_tokens = prompt_tokens.clone();
let reasoning_tokens_inner = loaded.reasoning_tokens.clone();
tokio::spawn(
async move {
let _inference_guard = loaded_for_task.inference_lock.lock().await;
@@ -1760,6 +1799,7 @@ impl CandleHarness {
top_p,
seed,
eos_id,
reasoning_tokens_inner,
tx,
)
.await
@@ -1799,6 +1839,7 @@ impl CandleHarness {
unreachable!("worker handle present without cuda feature");
}
} else if let Some(arch_arc) = loaded.arch.clone() {
let reasoning_tokens_inner = loaded.reasoning_tokens.clone();
tokio::task::spawn_blocking(move || {
let _g = span_for_task.enter();
// `blocking_lock` is safe here: spawn_blocking runs on
@@ -1816,6 +1857,7 @@ impl CandleHarness {
top_p,
seed,
eos_id,
reasoning_tokens_inner.as_ref(),
&tx,
) {
Ok(()) => tracing::info!(
@@ -1853,11 +1895,13 @@ impl CandleHarness {
// Hand the raw event channel back to the public entry
// points (chat_completion_stream / responses_stream); they
// pick the wire projection.
let reasoning_markers = loaded.reasoning_tokens.clone();
Ok(InferenceStream {
events: event_rx,
id,
created,
model_id,
reasoning_markers,
})
}
}
@@ -1881,6 +1925,13 @@ pub struct InferenceStream {
/// Local model id (no endpoint prefix). Stamped into every
/// wire-format frame so consumers can correlate.
pub model_id: String,
/// Open/close reasoning marker text (and token ids) for the
/// loaded model, or `None` for non-reasoning models. Used by
/// the chat-completions projector when `include_thinking` is
/// set — the projector re-wraps reasoning content with the
/// literal markers so client-side parsers (helexa-acp's
/// `ThinkParser`) see the original on-the-wire shape.
pub reasoning_markers: Option<ReasoningTokenPair>,
}
#[async_trait]
@@ -1987,6 +2038,26 @@ impl Harness for CandleHarness {
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
// Probe for reasoning markers in the tokenizer's
// added-tokens table — `<think>` / `</think>` on Qwen3 +
// DeepSeek-R1 + gpt-oss, `[THINK]` / `[/THINK]` on
// Mistral Magistral, etc. `None` for non-reasoning models.
// The streaming loop uses this to route between TextDelta
// and ReasoningDelta without any hardcoded model
// knowledge; wire projectors decide what to do with the
// split.
let reasoning_tokens = detect_reasoning_token_pair(|s| tokenizer.token_to_id(s));
if let Some(ref pair) = reasoning_tokens {
tracing::info!(
model = %spec.model_id,
open = %pair.open_text,
close = %pair.close_text,
open_id = pair.open_id,
close_id = pair.close_id,
"reasoning markers detected — streaming will route ReasoningDelta separately"
);
}
let loaded = Arc::new(LoadedModel {
model_id: spec.model_id.clone(),
arch: arch_local,
@@ -1998,6 +2069,7 @@ impl Harness for CandleHarness {
worker,
arch_handle,
inference_lock: tokio::sync::Mutex::new(()),
reasoning_tokens,
});
let mut models = self.models.write().await;
@@ -2170,6 +2242,17 @@ impl CandleHarness {
// 6. Tokenizer (same as single-GPU path).
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
// Reasoning-marker probe — identical to the single-GPU
// path. See `LoadedModel.reasoning_tokens` for the why.
let reasoning_tokens = detect_reasoning_token_pair(|s| tokenizer.token_to_id(s));
if let Some(ref pair) = reasoning_tokens {
tracing::info!(
model = %spec.model_id,
open = %pair.open_text,
close = %pair.close_text,
"TP load: reasoning markers detected"
);
}
let tp_loaded = StdArc::new(TpLoadedModel {
model_id: spec.model_id.clone(),
@@ -2183,6 +2266,7 @@ impl CandleHarness {
// single `Arc` shared between WorkerPool and
// TpLoadedModel so they reference the same thread.
worker: leader_worker,
reasoning_tokens,
});
let mut models = self.models.write().await;
@@ -2331,6 +2415,7 @@ impl CandleHarness {
let id = format!("chatcmpl-{:x}", unix_subsec_nanos());
let created = unix_now_secs();
let tokenizer = tp.tokenizer.clone();
let reasoning_tokens = tp.reasoning_tokens.clone();
// The spawned orchestration task below consumes both `id`
// and `model_id` (tracing, pool lookups, NCCL ops use them
// heavily). The wire projector at the bottom of this fn
@@ -2396,6 +2481,10 @@ impl CandleHarness {
// split a multi-byte char across tokens.
let mut decode_stream = tokenizer.decode_stream(true);
let mut finish_reason = FinishReason::Length;
// Reasoning marker state machine — same as the
// single-GPU path. The TP path needs its own copy
// because the spawn closure owns it.
let mut in_reasoning = false;
'work: {
if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await {
@@ -2464,11 +2553,17 @@ impl CandleHarness {
if Some(next_token) == eos_id {
finish_reason = FinishReason::Stop;
} else if handle_reasoning_marker(
next_token,
reasoning_tokens.as_ref(),
&mut in_reasoning,
) {
all_tokens.push(next_token);
} else {
all_tokens.push(next_token);
match decode_stream.step(next_token) {
Ok(Some(delta)) => {
if !emit_delta(&delta, &tx).await {
if !emit_delta(&delta, &tx, in_reasoning).await {
// Client gone — treat as normal stream end,
// not a failure. No log spam.
break 'work;
@@ -2543,10 +2638,18 @@ impl CandleHarness {
finish_reason = FinishReason::Stop;
break;
}
if handle_reasoning_marker(
next_token,
reasoning_tokens.as_ref(),
&mut in_reasoning,
) {
all_tokens.push(next_token);
continue;
}
all_tokens.push(next_token);
match decode_stream.step(next_token) {
Ok(Some(delta)) => {
if !emit_delta(&delta, &tx).await {
if !emit_delta(&delta, &tx, in_reasoning).await {
break 'work;
}
}
@@ -2612,11 +2715,13 @@ impl CandleHarness {
// points; they pick the wire projection. Uses the clones
// we stashed before the spawn — the originals were moved
// into the orchestration task above.
let reasoning_markers = tp.reasoning_tokens.clone();
Ok(InferenceStream {
events: event_rx,
id: projector_id,
created,
model_id: projector_model_id,
reasoning_markers,
})
}
}
@@ -2855,25 +2960,57 @@ async fn chat_completion_tp_inner(
/// [`crate::wire::openai_chat`] stamps it onto every chunk
/// downstream.
#[cfg(feature = "cuda")]
async fn emit_delta(delta: &str, tx: &mpsc::Sender<InferenceEvent>) -> bool {
async fn emit_delta(delta: &str, tx: &mpsc::Sender<InferenceEvent>, in_reasoning: bool) -> bool {
if delta.is_empty() {
return true;
}
tx.send(InferenceEvent::TextDelta(delta.into()))
.await
.is_ok()
let event = if in_reasoning {
InferenceEvent::ReasoningDelta(delta.into())
} else {
InferenceEvent::TextDelta(delta.into())
};
tx.send(event).await.is_ok()
}
/// Sync counterpart of [`emit_delta`] for the CPU path's
/// `spawn_blocking` closure. Same shape, `blocking_send` instead of
/// `send`. Kept as a separate fn so the async / blocking-send choice
/// is local to one place per path.
fn emit_delta_blocking(delta: &str, tx: &mpsc::Sender<InferenceEvent>) -> bool {
fn emit_delta_blocking(delta: &str, tx: &mpsc::Sender<InferenceEvent>, in_reasoning: bool) -> bool {
if delta.is_empty() {
return true;
}
tx.blocking_send(InferenceEvent::TextDelta(delta.into()))
.is_ok()
let event = if in_reasoning {
InferenceEvent::ReasoningDelta(delta.into())
} else {
InferenceEvent::TextDelta(delta.into())
};
tx.blocking_send(event).is_ok()
}
/// If `next_token` is one of the loaded model's reasoning markers,
/// flip `in_reasoning` and return `true` to tell the caller to
/// skip detokenisation + emission for this token. The markers
/// themselves never appear in the streamed output — they exist
/// purely to transition state.
///
/// `pair = None` short-circuits to `false` (no reasoning markers
/// configured for this model → pass-through).
fn handle_reasoning_marker(
next_token: u32,
pair: Option<&ReasoningTokenPair>,
in_reasoning: &mut bool,
) -> bool {
let Some(pair) = pair else { return false };
if next_token == pair.open_id {
*in_reasoning = true;
return true;
}
if next_token == pair.close_id {
*in_reasoning = false;
return true;
}
false
}
/// Errors returned by `CandleHarness::chat_completion`. The
@@ -3038,6 +3175,7 @@ async fn stream_inference_via_worker(
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
reasoning_tokens: Option<ReasoningTokenPair>,
tx: mpsc::Sender<InferenceEvent>,
) -> Result<String> {
let mut logits_processor = {
@@ -3062,6 +3200,10 @@ async fn stream_inference_via_worker(
let mut decode_stream = tokenizer.decode_stream(true);
let prompt_len = prompt_tokens.len();
let mut finish_reason = FinishReason::Length;
// Reasoning marker state machine — see `run_inference_streaming`
// for the why. Markers never reach `decode_stream`; they only
// toggle the variant `emit_delta` produces.
let mut in_reasoning = false;
worker
.clear_kv_cache(handle)
@@ -3088,50 +3230,56 @@ async fn stream_inference_via_worker(
if Some(next_token) == eos_id {
finish_reason = FinishReason::Stop;
} else if handle_reasoning_marker(next_token, reasoning_tokens.as_ref(), &mut in_reasoning) {
all_tokens.push(next_token);
} else {
all_tokens.push(next_token);
match decode_stream.step(next_token) {
Ok(Some(delta)) => {
if !emit_delta(&delta, &tx).await {
if !emit_delta(&delta, &tx, in_reasoning).await {
return Ok(finish_reason.as_openai_str().to_string());
}
}
Ok(None) => {}
Err(e) => tracing::warn!(error = %e, "decode_stream step failed"),
}
}
for index in 0..max_new.saturating_sub(1) {
let logits_vec = worker
.forward_logits(handle, vec![next_token], prompt_len + index)
.await
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
step = index,
?health,
"chat_completion (stream/worker): decode sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
finish_reason = FinishReason::Stop;
break;
for index in 0..max_new.saturating_sub(1) {
let logits_vec = worker
.forward_logits(handle, vec![next_token], prompt_len + index)
.await
.map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?;
let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?;
next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) {
Ok(t) => t,
Err(e) => {
let health = logits_health_slice(&logits_vec);
tracing::warn!(
step = index,
?health,
"chat_completion (stream/worker): decode sample failed; logits unhealthy"
);
return Err(e);
}
};
if Some(next_token) == eos_id {
finish_reason = FinishReason::Stop;
break;
}
if handle_reasoning_marker(next_token, reasoning_tokens.as_ref(), &mut in_reasoning) {
all_tokens.push(next_token);
match decode_stream.step(next_token) {
Ok(Some(delta)) => {
if !emit_delta(&delta, &tx).await {
return Ok(finish_reason.as_openai_str().to_string());
}
continue;
}
all_tokens.push(next_token);
match decode_stream.step(next_token) {
Ok(Some(delta)) => {
if !emit_delta(&delta, &tx, in_reasoning).await {
return Ok(finish_reason.as_openai_str().to_string());
}
Ok(None) => {}
Err(e) => tracing::warn!(error = %e, "decode_stream step failed"),
}
Ok(None) => {}
Err(e) => tracing::warn!(error = %e, "decode_stream step failed"),
}
}
@@ -3211,6 +3359,7 @@ fn run_inference_streaming(
top_p: Option<f64>,
seed: u64,
eos_id: Option<u32>,
reasoning_tokens: Option<&ReasoningTokenPair>,
tx: &mpsc::Sender<InferenceEvent>,
) -> Result<()> {
let mut logits_processor = {
@@ -3232,6 +3381,12 @@ fn run_inference_streaming(
// boundaries and only emits when a clean codepoint completes.
let mut decode_stream = tokenizer.decode_stream(true);
let mut finish_reason = FinishReason::Length;
// Reasoning marker state machine. Flips on
// `next_token == reasoning_tokens.open_id`, off on
// `.close_id`. The marker tokens themselves never feed into
// `decode_stream` — they aren't part of any visible output,
// they exist purely as state transitions.
let mut in_reasoning = false;
arch.clear_kv_cache()?;
let logits = chunked_prefill_local(arch, device, prompt_tokens)?;
@@ -3239,36 +3394,42 @@ fn run_inference_streaming(
if Some(next_token) == eos_id {
finish_reason = FinishReason::Stop;
} else if handle_reasoning_marker(next_token, reasoning_tokens, &mut in_reasoning) {
all_tokens.push(next_token);
} else {
all_tokens.push(next_token);
match decode_stream.step(next_token) {
Ok(Some(delta)) => {
if !emit_delta_blocking(&delta, tx) {
if !emit_delta_blocking(&delta, tx, in_reasoning) {
return Ok(());
}
}
Ok(None) => {}
Err(e) => tracing::warn!(error = %e, "stream: decode_stream step failed"),
}
}
for index in 0..max_new.saturating_sub(1) {
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
let logits = arch.forward(&input, prompt_tokens.len() + index)?;
next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
if Some(next_token) == eos_id {
finish_reason = FinishReason::Stop;
break;
}
for index in 0..max_new.saturating_sub(1) {
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
let logits = arch.forward(&input, prompt_tokens.len() + index)?;
next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?;
if Some(next_token) == eos_id {
finish_reason = FinishReason::Stop;
break;
}
if handle_reasoning_marker(next_token, reasoning_tokens, &mut in_reasoning) {
all_tokens.push(next_token);
match decode_stream.step(next_token) {
Ok(Some(delta)) => {
if !emit_delta_blocking(&delta, tx) {
return Ok(());
}
continue;
}
all_tokens.push(next_token);
match decode_stream.step(next_token) {
Ok(Some(delta)) => {
if !emit_delta_blocking(&delta, tx, in_reasoning) {
return Ok(());
}
Ok(None) => {}
Err(e) => tracing::warn!(error = %e, "stream: decode_stream step failed"),
}
Ok(None) => {}
Err(e) => tracing::warn!(error = %e, "stream: decode_stream step failed"),
}
}