diff --git a/crates/helexa-acp/src/provider/openai_chat.rs b/crates/helexa-acp/src/provider/openai_chat.rs index 6d5bd8c..d42e9d6 100644 --- a/crates/helexa-acp/src/provider/openai_chat.rs +++ b/crates/helexa-acp/src/provider/openai_chat.rs @@ -109,6 +109,16 @@ impl Provider for OpenAIChatProvider { let mut req = self .http .post(self.endpoint.chat_completions_url()) + // Tell reasoning-aware servers (neuron after issue #8) + // to include the model's `<think>` markers in the + // stream rather than stripping them. helexa-acp's + // ThinkParser routes the marked content to Zed's + // thought UI; without this header neuron would + // default to clean content (the right choice for + // naïve clients like Zed's commit-message generator + // but wrong for us). Servers that don't recognise + // the header ignore it harmlessly. + .header("x-include-thinking", "true") .json(&body); if let Some(key) = &self.api_key { req = req.bearer_auth(key); diff --git a/crates/neuron/src/api.rs b/crates/neuron/src/api.rs index ec32e62..0a81789 100644 --- a/crates/neuron/src/api.rs +++ b/crates/neuron/src/api.rs @@ -4,7 +4,7 @@ use crate::activation::ActivationTracker; use crate::harness::HarnessRegistry; use crate::harness::candle::{CandleHarness, InferenceError}; use crate::health::HealthCache; -use crate::wire::openai_responses; +use crate::wire::{openai_chat, openai_responses}; use axum::Router; use axum::extract::{Path, State}; use axum::http::StatusCode; @@ -148,6 +148,7 @@ async fn model_endpoint( /// `ChatCompletionResponse`. async fn chat_completions( State(state): State>, + headers: axum::http::HeaderMap, Json(req): Json, ) -> impl IntoResponse { let Some(candle) = state.candle.as_ref().map(Arc::clone) else { @@ -158,8 +159,23 @@ async fn chat_completions( .into_response(); }; + // Reasoning-content opt-in. Off by default → naïve clients + // (Zed's commit-message generator, vanilla OpenAI clients) + // never see `<think>` blocks. 
On when the caller sends + // `x-include-thinking: true` (helexa-acp does this so its + // own ThinkParser keeps working unchanged). + let include_thinking = headers + .get("x-include-thinking") + .and_then(|v| v.to_str().ok()) + .map(|s| matches!(s.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes")) + .unwrap_or(false); + let chat_config = openai_chat::ChatProjectionConfig { + include_thinking, + reasoning_markers: None, // filled in from the loaded model inside candle + }; + if req.stream.unwrap_or(false) { - match candle.chat_completion_stream(req).await { + match candle.chat_completion_stream_with(req, chat_config).await { Ok(rx) => { // Each chunk → one SSE `data: {json}` line. After the // channel closes, append the OpenAI [DONE] terminator. diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 30f4a71..3d73c38 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -26,7 +26,10 @@ use cortex_core::openai::{ ChatMessage, MessageContent, Usage, }; -use crate::wire::{FinishReason, InferenceEvent, openai_chat as wire_chat}; +use crate::wire::{ + FinishReason, InferenceEvent, ReasoningTokenPair, detect_reasoning_token_pair, + openai_chat as wire_chat, +}; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; @@ -148,6 +151,15 @@ pub struct LoadedModel { /// for the TP path (which already had this invariant by accident /// because the pool lock covered the same window). pub inference_lock: tokio::sync::Mutex<()>, + /// Open/close token IDs for the reasoning marker this model + /// emits, populated once at load time by probing the tokenizer's + /// added-tokens table. `None` for non-reasoning models or + /// reasoning models whose markers aren't single tokens. 
When + /// `Some`, the streaming inference loop splits output into + /// [`InferenceEvent::TextDelta`] and + /// [`InferenceEvent::ReasoningDelta`] at the token boundary; + /// when `None` everything is `TextDelta`. + pub reasoning_tokens: Option, } impl LoadedModel { @@ -203,6 +215,11 @@ pub struct TpLoadedModel { /// `CudaContext`, `NcclState`, and the boxed `TpLeaderModel` /// referenced by `leader_handle`. pub worker: Arc, + /// Same shape as [`LoadedModel::reasoning_tokens`] — open/close + /// reasoning marker token IDs probed from the tokenizer at + /// load time. `None` when the model declares no reasoning + /// markers. + pub reasoning_tokens: Option, } #[cfg(feature = "cuda")] @@ -1594,13 +1611,34 @@ impl CandleHarness { pub async fn chat_completion_stream( &self, request: ChatCompletionRequest, + ) -> Result, InferenceError> { + self.chat_completion_stream_with(request, wire_chat::ChatProjectionConfig::default()) + .await + } + + /// Same as [`Self::chat_completion_stream`] but lets the caller + /// pick the projection config — currently used by the HTTP + /// handler to thread `x-include-thinking` from the request + /// headers into [`wire_chat::ChatProjectionConfig::include_thinking`]. + pub async fn chat_completion_stream_with( + &self, + request: ChatCompletionRequest, + mut config: wire_chat::ChatProjectionConfig, ) -> Result, InferenceError> { let stream = self.inference_stream(request).await?; - Ok(wire_chat::project_chat_stream( + // Fill in the model's reasoning markers if the caller + // didn't pre-populate them — they're a property of the + // loaded model (which the HTTP handler doesn't reach into + // directly), not of the request. 
+ if config.reasoning_markers.is_none() { + config.reasoning_markers = stream.reasoning_markers.clone(); + } + Ok(wire_chat::project_chat_stream_with( stream.events, stream.id, stream.created, stream.model_id, + config, )) } @@ -1747,6 +1785,7 @@ impl CandleHarness { #[cfg(feature = "cuda")] { let prompt_tokens = prompt_tokens.clone(); + let reasoning_tokens_inner = loaded.reasoning_tokens.clone(); tokio::spawn( async move { let _inference_guard = loaded_for_task.inference_lock.lock().await; @@ -1760,6 +1799,7 @@ impl CandleHarness { top_p, seed, eos_id, + reasoning_tokens_inner, tx, ) .await @@ -1799,6 +1839,7 @@ impl CandleHarness { unreachable!("worker handle present without cuda feature"); } } else if let Some(arch_arc) = loaded.arch.clone() { + let reasoning_tokens_inner = loaded.reasoning_tokens.clone(); tokio::task::spawn_blocking(move || { let _g = span_for_task.enter(); // `blocking_lock` is safe here: spawn_blocking runs on @@ -1816,6 +1857,7 @@ impl CandleHarness { top_p, seed, eos_id, + reasoning_tokens_inner.as_ref(), &tx, ) { Ok(()) => tracing::info!( @@ -1853,11 +1895,13 @@ impl CandleHarness { // Hand the raw event channel back to the public entry // points (chat_completion_stream / responses_stream); they // pick the wire projection. + let reasoning_markers = loaded.reasoning_tokens.clone(); Ok(InferenceStream { events: event_rx, id, created, model_id, + reasoning_markers, }) } } @@ -1881,6 +1925,13 @@ pub struct InferenceStream { /// Local model id (no endpoint prefix). Stamped into every /// wire-format frame so consumers can correlate. pub model_id: String, + /// Open/close reasoning marker text (and token ids) for the + /// loaded model, or `None` for non-reasoning models. Used by + /// the chat-completions projector when `include_thinking` is + /// set — the projector re-wraps reasoning content with the + /// literal markers so client-side parsers (helexa-acp's + /// `ThinkParser`) see the original on-the-wire shape. 
+ pub reasoning_markers: Option<ReasoningTokenPair>, } #[async_trait] @@ -1987,6 +2038,26 @@ impl Harness for CandleHarness { let tokenizer = Tokenizer::from_file(&tokenizer_path) .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?; + // Probe for reasoning markers in the tokenizer's + // added-tokens table — `<think>` / `</think>` on Qwen3 + + // DeepSeek-R1 + gpt-oss, `[THINK]` / `[/THINK]` on + // Mistral Magistral, etc. `None` for non-reasoning models. + // The streaming loop uses this to route between TextDelta + // and ReasoningDelta without any hardcoded model + // knowledge; wire projectors decide what to do with the + // split. + let reasoning_tokens = detect_reasoning_token_pair(|s| tokenizer.token_to_id(s)); + if let Some(ref pair) = reasoning_tokens { + tracing::info!( + model = %spec.model_id, + open = %pair.open_text, + close = %pair.close_text, + open_id = pair.open_id, + close_id = pair.close_id, + "reasoning markers detected — streaming will route ReasoningDelta separately" + ); + } + let loaded = Arc::new(LoadedModel { model_id: spec.model_id.clone(), arch: arch_local, @@ -1998,6 +2069,7 @@ impl Harness for CandleHarness { worker, arch_handle, inference_lock: tokio::sync::Mutex::new(()), + reasoning_tokens, }); let mut models = self.models.write().await; @@ -2170,6 +2242,17 @@ impl CandleHarness { // 6. Tokenizer (same as single-GPU path). let tokenizer = Tokenizer::from_file(&tokenizer_path) .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?; + // Reasoning-marker probe — identical to the single-GPU + // path. See `LoadedModel.reasoning_tokens` for the why. 
+ let reasoning_tokens = detect_reasoning_token_pair(|s| tokenizer.token_to_id(s)); + if let Some(ref pair) = reasoning_tokens { + tracing::info!( + model = %spec.model_id, + open = %pair.open_text, + close = %pair.close_text, + "TP load: reasoning markers detected" + ); + } let tp_loaded = StdArc::new(TpLoadedModel { model_id: spec.model_id.clone(), @@ -2183,6 +2266,7 @@ impl CandleHarness { // single `Arc` shared between WorkerPool and // TpLoadedModel so they reference the same thread. worker: leader_worker, + reasoning_tokens, }); let mut models = self.models.write().await; @@ -2331,6 +2415,7 @@ impl CandleHarness { let id = format!("chatcmpl-{:x}", unix_subsec_nanos()); let created = unix_now_secs(); let tokenizer = tp.tokenizer.clone(); + let reasoning_tokens = tp.reasoning_tokens.clone(); // The spawned orchestration task below consumes both `id` // and `model_id` (tracing, pool lookups, NCCL ops use them // heavily). The wire projector at the bottom of this fn @@ -2396,6 +2481,10 @@ impl CandleHarness { // split a multi-byte char across tokens. let mut decode_stream = tokenizer.decode_stream(true); let mut finish_reason = FinishReason::Length; + // Reasoning marker state machine — same as the + // single-GPU path. The TP path needs its own copy + // because the spawn closure owns it. + let mut in_reasoning = false; 'work: { if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await { @@ -2464,11 +2553,17 @@ impl CandleHarness { if Some(next_token) == eos_id { finish_reason = FinishReason::Stop; + } else if handle_reasoning_marker( + next_token, + reasoning_tokens.as_ref(), + &mut in_reasoning, + ) { + all_tokens.push(next_token); } else { all_tokens.push(next_token); match decode_stream.step(next_token) { Ok(Some(delta)) => { - if !emit_delta(&delta, &tx).await { + if !emit_delta(&delta, &tx, in_reasoning).await { // Client gone — treat as normal stream end, // not a failure. No log spam. 
break 'work; @@ -2543,10 +2638,18 @@ impl CandleHarness { finish_reason = FinishReason::Stop; break; } + if handle_reasoning_marker( + next_token, + reasoning_tokens.as_ref(), + &mut in_reasoning, + ) { + all_tokens.push(next_token); + continue; + } all_tokens.push(next_token); match decode_stream.step(next_token) { Ok(Some(delta)) => { - if !emit_delta(&delta, &tx).await { + if !emit_delta(&delta, &tx, in_reasoning).await { break 'work; } } @@ -2612,11 +2715,13 @@ impl CandleHarness { // points; they pick the wire projection. Uses the clones // we stashed before the spawn — the originals were moved // into the orchestration task above. + let reasoning_markers = tp.reasoning_tokens.clone(); Ok(InferenceStream { events: event_rx, id: projector_id, created, model_id: projector_model_id, + reasoning_markers, }) } } @@ -2855,25 +2960,57 @@ async fn chat_completion_tp_inner( /// [`crate::wire::openai_chat`] stamps it onto every chunk /// downstream. #[cfg(feature = "cuda")] -async fn emit_delta(delta: &str, tx: &mpsc::Sender) -> bool { +async fn emit_delta(delta: &str, tx: &mpsc::Sender, in_reasoning: bool) -> bool { if delta.is_empty() { return true; } - tx.send(InferenceEvent::TextDelta(delta.into())) - .await - .is_ok() + let event = if in_reasoning { + InferenceEvent::ReasoningDelta(delta.into()) + } else { + InferenceEvent::TextDelta(delta.into()) + }; + tx.send(event).await.is_ok() } /// Sync counterpart of [`emit_delta`] for the CPU path's /// `spawn_blocking` closure. Same shape, `blocking_send` instead of /// `send`. Kept as a separate fn so the async / blocking-send choice /// is local to one place per path. 
-fn emit_delta_blocking(delta: &str, tx: &mpsc::Sender) -> bool { +fn emit_delta_blocking(delta: &str, tx: &mpsc::Sender, in_reasoning: bool) -> bool { if delta.is_empty() { return true; } - tx.blocking_send(InferenceEvent::TextDelta(delta.into())) - .is_ok() + let event = if in_reasoning { + InferenceEvent::ReasoningDelta(delta.into()) + } else { + InferenceEvent::TextDelta(delta.into()) + }; + tx.blocking_send(event).is_ok() +} + +/// If `next_token` is one of the loaded model's reasoning markers, +/// flip `in_reasoning` and return `true` to tell the caller to +/// skip detokenisation + emission for this token. The markers +/// themselves never appear in the streamed output — they exist +/// purely to transition state. +/// +/// `pair = None` short-circuits to `false` (no reasoning markers +/// configured for this model → pass-through). +fn handle_reasoning_marker( + next_token: u32, + pair: Option<&ReasoningTokenPair>, + in_reasoning: &mut bool, +) -> bool { + let Some(pair) = pair else { return false }; + if next_token == pair.open_id { + *in_reasoning = true; + return true; + } + if next_token == pair.close_id { + *in_reasoning = false; + return true; + } + false } /// Errors returned by `CandleHarness::chat_completion`. The @@ -3038,6 +3175,7 @@ async fn stream_inference_via_worker( top_p: Option, seed: u64, eos_id: Option, + reasoning_tokens: Option, tx: mpsc::Sender, ) -> Result { let mut logits_processor = { @@ -3062,6 +3200,10 @@ async fn stream_inference_via_worker( let mut decode_stream = tokenizer.decode_stream(true); let prompt_len = prompt_tokens.len(); let mut finish_reason = FinishReason::Length; + // Reasoning marker state machine — see `run_inference_streaming` + // for the why. Markers never reach `decode_stream`; they only + // toggle the variant `emit_delta` produces. 
+ let mut in_reasoning = false; worker .clear_kv_cache(handle) @@ -3088,50 +3230,56 @@ async fn stream_inference_via_worker( if Some(next_token) == eos_id { finish_reason = FinishReason::Stop; + } else if handle_reasoning_marker(next_token, reasoning_tokens.as_ref(), &mut in_reasoning) { + all_tokens.push(next_token); } else { all_tokens.push(next_token); match decode_stream.step(next_token) { Ok(Some(delta)) => { - if !emit_delta(&delta, &tx).await { + if !emit_delta(&delta, &tx, in_reasoning).await { return Ok(finish_reason.as_openai_str().to_string()); } } Ok(None) => {} Err(e) => tracing::warn!(error = %e, "decode_stream step failed"), } + } - for index in 0..max_new.saturating_sub(1) { - let logits_vec = worker - .forward_logits(handle, vec![next_token], prompt_len + index) - .await - .map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?; - let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?; - next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) { - Ok(t) => t, - Err(e) => { - let health = logits_health_slice(&logits_vec); - tracing::warn!( - step = index, - ?health, - "chat_completion (stream/worker): decode sample failed; logits unhealthy" - ); - return Err(e); - } - }; - if Some(next_token) == eos_id { - finish_reason = FinishReason::Stop; - break; + for index in 0..max_new.saturating_sub(1) { + let logits_vec = worker + .forward_logits(handle, vec![next_token], prompt_len + index) + .await + .map_err(|e| anyhow::anyhow!("decode step {index}: {e}"))?; + let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?; + next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) { + Ok(t) => t, + Err(e) => { + let health = logits_health_slice(&logits_vec); + tracing::warn!( + step = index, + ?health, + "chat_completion (stream/worker): decode sample failed; logits unhealthy" + ); + return Err(e); } + }; + if Some(next_token) == eos_id { + finish_reason = FinishReason::Stop; + break; + } + 
if handle_reasoning_marker(next_token, reasoning_tokens.as_ref(), &mut in_reasoning) { all_tokens.push(next_token); - match decode_stream.step(next_token) { - Ok(Some(delta)) => { - if !emit_delta(&delta, &tx).await { - return Ok(finish_reason.as_openai_str().to_string()); - } + continue; + } + all_tokens.push(next_token); + match decode_stream.step(next_token) { + Ok(Some(delta)) => { + if !emit_delta(&delta, &tx, in_reasoning).await { + return Ok(finish_reason.as_openai_str().to_string()); } - Ok(None) => {} - Err(e) => tracing::warn!(error = %e, "decode_stream step failed"), } + Ok(None) => {} + Err(e) => tracing::warn!(error = %e, "decode_stream step failed"), } } @@ -3211,6 +3359,7 @@ fn run_inference_streaming( top_p: Option, seed: u64, eos_id: Option, + reasoning_tokens: Option<&ReasoningTokenPair>, tx: &mpsc::Sender, ) -> Result<()> { let mut logits_processor = { @@ -3232,6 +3381,12 @@ fn run_inference_streaming( // boundaries and only emits when a clean codepoint completes. let mut decode_stream = tokenizer.decode_stream(true); let mut finish_reason = FinishReason::Length; + // Reasoning marker state machine. Flips on + // `next_token == reasoning_tokens.open_id`, off on + // `.close_id`. The marker tokens themselves never feed into + // `decode_stream` — they aren't part of any visible output, + // they exist purely as state transitions. 
+ let mut in_reasoning = false; arch.clear_kv_cache()?; let logits = chunked_prefill_local(arch, device, prompt_tokens)?; @@ -3239,36 +3394,42 @@ fn run_inference_streaming( if Some(next_token) == eos_id { finish_reason = FinishReason::Stop; + } else if handle_reasoning_marker(next_token, reasoning_tokens, &mut in_reasoning) { + all_tokens.push(next_token); } else { all_tokens.push(next_token); match decode_stream.step(next_token) { Ok(Some(delta)) => { - if !emit_delta_blocking(&delta, tx) { + if !emit_delta_blocking(&delta, tx, in_reasoning) { return Ok(()); } } Ok(None) => {} Err(e) => tracing::warn!(error = %e, "stream: decode_stream step failed"), } + } - for index in 0..max_new.saturating_sub(1) { - let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?; - let logits = arch.forward(&input, prompt_tokens.len() + index)?; - next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?; - if Some(next_token) == eos_id { - finish_reason = FinishReason::Stop; - break; - } + for index in 0..max_new.saturating_sub(1) { + let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?; + let logits = arch.forward(&input, prompt_tokens.len() + index)?; + next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?; + if Some(next_token) == eos_id { + finish_reason = FinishReason::Stop; + break; + } + if handle_reasoning_marker(next_token, reasoning_tokens, &mut in_reasoning) { all_tokens.push(next_token); - match decode_stream.step(next_token) { - Ok(Some(delta)) => { - if !emit_delta_blocking(&delta, tx) { - return Ok(()); - } + continue; + } + all_tokens.push(next_token); + match decode_stream.step(next_token) { + Ok(Some(delta)) => { + if !emit_delta_blocking(&delta, tx, in_reasoning) { + return Ok(()); } - Ok(None) => {} - Err(e) => tracing::warn!(error = %e, "stream: decode_stream step failed"), } + Ok(None) => {} + Err(e) => tracing::warn!(error = %e, "stream: decode_stream step failed"), } } diff --git 
a/crates/neuron/src/wire/event.rs b/crates/neuron/src/wire/event.rs index f92f451..ff39a00 100644 --- a/crates/neuron/src/wire/event.rs +++ b/crates/neuron/src/wire/event.rs @@ -97,3 +97,134 @@ impl FinishReason { } } } + +/// Open/close token IDs for the reasoning marker a loaded model uses +/// (or `None` for non-reasoning models). The harness reads this once +/// at load time from the tokenizer's added-tokens table, then the +/// inference loop checks `next_token` against the pair to flip +/// between [`InferenceEvent::TextDelta`] and +/// [`InferenceEvent::ReasoningDelta`]. +/// +/// `open` and `close` text are kept alongside the IDs so wire +/// projectors that want to re-emit the literal markers (the +/// opt-in `include_thinking` path on chat completions) don't have +/// to reach back into the tokenizer for the strings. +#[derive(Debug, Clone)] +pub struct ReasoningTokenPair { + pub open_id: u32, + pub close_id: u32, + pub open_text: String, + pub close_text: String, +} + +/// Known reasoning-marker conventions. Each is a `(open, close)` +/// pair of literal token strings. Each modern reasoning model +/// declares its markers in the tokenizer's `added_tokens` table; +/// at load time we probe for whichever pair the loaded tokenizer +/// has and stash both IDs. +/// +/// Ordering matters only for tie-breaking when a model declares +/// multiple pairs (shouldn't happen in practice); the first hit +/// wins. +const KNOWN_REASONING_MARKERS: &[(&str, &str)] = &[ + // Qwen3, DeepSeek-R1, gpt-oss, and most other open-weight + // reasoning models. + ("<think>", "</think>"), + // Mistral Magistral. + ("[THINK]", "[/THINK]"), + // Some older derivatives; harmless to probe. + ("", ""), + ("", ""), +]; + +/// Inspect a tokenizer for known reasoning-marker pairs and return +/// the first match. 
The tokenizer type this function is generic over +/// just needs to expose `token_to_id(&str) -> Option<u32>` so this +/// stays decoupled from the candle crate — the production caller +/// passes a `tokenizers::Tokenizer`, but tests can fake one. +/// +/// Returns `None` when no known marker pair is fully declared +/// (both open AND close token ids must resolve). That's the +/// pass-through case — non-reasoning models, or reasoning models +/// whose tokenizer split the markers across multiple tokens (rare +/// in practice; modern reasoning tokenizers list them as +/// `added_tokens`). +pub fn detect_reasoning_token_pair<F>(token_to_id: F) -> Option<ReasoningTokenPair> +where + F: Fn(&str) -> Option<u32>, +{ + for (open_text, close_text) in KNOWN_REASONING_MARKERS { + let open_id = token_to_id(open_text); + let close_id = token_to_id(close_text); + if let (Some(open_id), Some(close_id)) = (open_id, close_id) { + return Some(ReasoningTokenPair { + open_id, + close_id, + open_text: (*open_text).into(), + close_text: (*close_text).into(), + }); + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + fn lookup<'a>(map: &'a HashMap<&'static str, u32>) -> impl Fn(&str) -> Option<u32> + 'a { + |s| map.get(s).copied() + } + + #[test] + fn detects_qwen3_style_think_markers() { + let mut m = HashMap::new(); + m.insert("<think>", 151648); + m.insert("</think>", 151649); + let pair = detect_reasoning_token_pair(lookup(&m)).expect("pair detected"); + assert_eq!(pair.open_id, 151648); + assert_eq!(pair.close_id, 151649); + assert_eq!(pair.open_text, "<think>"); + assert_eq!(pair.close_text, "</think>"); + } + + #[test] + fn detects_mistral_magistral_markers() { + let mut m = HashMap::new(); + m.insert("[THINK]", 100); + m.insert("[/THINK]", 101); + let pair = detect_reasoning_token_pair(lookup(&m)).expect("pair detected"); + assert_eq!(pair.open_text, "[THINK]"); + } + + #[test] + fn returns_none_when_only_open_marker_present() { + // A 
pathological tokenizer that has `<think>` but not + // `</think>` shouldn't 
half-detect. Pass-through. + let mut m = HashMap::new(); + m.insert("<think>", 1); + assert!(detect_reasoning_token_pair(lookup(&m)).is_none()); + } + + #[test] + fn returns_none_for_non_reasoning_tokenizer() { + let m: HashMap<&'static str, u32> = HashMap::new(); + assert!(detect_reasoning_token_pair(lookup(&m)).is_none()); + } + + #[test] + fn first_match_wins_when_multiple_pairs_declared() { + // Hypothetical tokenizer with both Qwen-style AND Mistral-style + // markers — the `<think>` pair is earlier in the convention + // table so it wins. + let mut m = HashMap::new(); + m.insert("<think>", 1); + m.insert("</think>", 2); + m.insert("[THINK]", 3); + m.insert("[/THINK]", 4); + let pair = detect_reasoning_token_pair(lookup(&m)).unwrap(); + assert_eq!(pair.open_id, 1); + assert_eq!(pair.close_id, 2); + } +} diff --git a/crates/neuron/src/wire/mod.rs b/crates/neuron/src/wire/mod.rs index f704279..b531a91 100644 --- a/crates/neuron/src/wire/mod.rs +++ b/crates/neuron/src/wire/mod.rs @@ -21,4 +21,4 @@ pub mod event; pub mod openai_chat; pub mod openai_responses; -pub use event::{FinishReason, InferenceEvent}; +pub use event::{FinishReason, InferenceEvent, ReasoningTokenPair, detect_reasoning_token_pair}; diff --git a/crates/neuron/src/wire/openai_chat.rs b/crates/neuron/src/wire/openai_chat.rs index def289b..868e25c 100644 --- a/crates/neuron/src/wire/openai_chat.rs +++ b/crates/neuron/src/wire/openai_chat.rs @@ -30,13 +30,42 @@ use cortex_core::openai::{ChatCompletionChunk, ChunkChoice}; use serde_json::json; use tokio::sync::mpsc; -use super::event::{FinishReason, InferenceEvent}; +use super::event::{FinishReason, InferenceEvent, ReasoningTokenPair}; /// Output channel buffer size. Mirrors the input side's bound; one /// event maps to at most one chunk, so equal capacity keeps the /// two ends in sync without surprising memory growth. const CHUNK_CHANNEL_CAPACITY: usize = 32; +/// Per-stream config for the chat projector. 
Used by the +/// production handler to thread per-request choices (currently: +/// whether to surface reasoning content) into the projection +/// without bloating the function signature. +#[derive(Debug, Clone, Default)] +pub struct ChatProjectionConfig { + /// When `true`, reasoning content is re-wrapped with the + /// model's literal open/close markers and emitted as content + /// deltas — preserving the on-the-wire shape that + /// reasoning-aware clients like helexa-acp's `ThinkParser` + /// expect. + /// + /// When `false` (the default), [`InferenceEvent::ReasoningDelta`]s + /// are dropped entirely so consumers that don't know about + /// reasoning (Zed's commit-message generator, any vanilla + /// OpenAI client) don't have model-internal scratchpad + /// material leaking into their UI. The chat-completions wire + /// format has no slot for reasoning, so the default chooses + /// the safer-for-naïve-clients behaviour. + pub include_thinking: bool, + /// Open/close marker strings to re-emit when `include_thinking` + /// is set. Sourced from the loaded model's + /// [`ReasoningTokenPair`]; `None` for non-reasoning models or + /// when the caller doesn't have the pair handy (in which case + /// reasoning deltas are emitted unwrapped, as plain content, + /// because there's nothing to wrap them with). + pub reasoning_markers: Option<ReasoningTokenPair>, +} + /// Project an [`InferenceEvent`] receiver into a /// [`ChatCompletionChunk`] receiver. Spawns one tokio task that /// owns the input receiver for the stream's lifetime and exits @@ -46,15 +75,55 @@ const CHUNK_CHANNEL_CAPACITY: usize = 32; /// chunk so the receiver can stay generic (decoupled from /// per-request metadata). pub fn project_chat_stream( - mut rx: mpsc::Receiver<InferenceEvent>, + rx: mpsc::Receiver<InferenceEvent>, id: String, created: u64, model_id: String, +) -> mpsc::Receiver<ChatCompletionChunk> { + // Default config: include_thinking off, no marker rewrap. 
+ project_chat_stream_with(rx, id, created, model_id, ChatProjectionConfig::default()) +} + +/// Same as [`project_chat_stream`] but with a per-stream config +/// (currently controlling reasoning surfacing). Production +/// callers that need the opt-in path call this directly; the +/// shorter wrapper above stays as the no-config convenience. +pub fn project_chat_stream_with( + mut rx: mpsc::Receiver<InferenceEvent>, + id: String, + created: u64, + model_id: String, + config: ChatProjectionConfig, ) -> mpsc::Receiver<ChatCompletionChunk> { let (tx, out_rx) = mpsc::channel::<ChatCompletionChunk>(CHUNK_CHANNEL_CAPACITY); tokio::spawn(async move { + // Track whether the previous event was inside a reasoning + // block — used to decide when to emit the literal close + // marker on the include_thinking re-wrap path. When this + // flips from true → false (a TextDelta or Finish lands + // after one or more ReasoningDeltas), we emit the close + // marker exactly once. + let mut was_in_reasoning = false; + while let Some(event) = rx.recv().await { + // Close-marker insertion: if we're leaving a reasoning + // chain, emit the literal close marker before the + // current event. + if was_in_reasoning && !matches!(event, InferenceEvent::ReasoningDelta(_)) { + if let Some(marker) = config + .include_thinking + .then_some(()) + .and(config.reasoning_markers.as_ref()) + { + let chunk = content_chunk(&id, created, &model_id, &marker.close_text); + if tx.send(chunk).await.is_err() { + return; + } + } + was_in_reasoning = false; + } + let chunks = match event { InferenceEvent::Start => vec![role_chunk(&id, created, &model_id)], InferenceEvent::TextDelta(text) => { @@ -66,12 +135,42 @@ pub fn project_chat_stream( } vec![content_chunk(&id, created, &model_id, &text)] } - InferenceEvent::ReasoningDelta(_) => { - // Reasoning isn't representable in OpenAI chat - // streaming today. 
The o-series uses a separate - // `summary` event but it's gated by the - // Responses API; chat-completions just drops it. 
- continue; + InferenceEvent::ReasoningDelta(text) => { + if !config.include_thinking { + // Default path — reasoning has no slot in + // chat completions, so it's dropped. Naïve + // clients (Zed commit-message generator, + // any vanilla OpenAI client) get clean + // output. + continue; + } + let Some(markers) = config.reasoning_markers.as_ref() else { + // Caller asked to include thinking but + // didn't supply markers — best we can do + // is emit the content as visible text. + // Skip the wrap entirely. + if text.is_empty() { + continue; + } + let chunk = content_chunk(&id, created, &model_id, &text); + if tx.send(chunk).await.is_err() { + return; + } + continue; + }; + // First chunk of a reasoning block → open + // marker prelude. Subsequent reasoning deltas + // in the same block reuse `was_in_reasoning` + // to skip the prelude. + let mut chunks = Vec::new(); + if !was_in_reasoning { + chunks.push(content_chunk(&id, created, &model_id, &markers.open_text)); + } + if !text.is_empty() { + chunks.push(content_chunk(&id, created, &model_id, &text)); + } + was_in_reasoning = true; + chunks } InferenceEvent::Finish { reason } => { vec![final_chunk(&id, created, &model_id, reason)] @@ -238,4 +337,165 @@ mod tests { assert_eq!(out.len(), 1); assert_eq!(out[0].choices[0].delta["content"], "real"); } + + fn pair() -> ReasoningTokenPair { + ReasoningTokenPair { + open_id: 0, + close_id: 1, + open_text: "<think>".into(), + close_text: "</think>".into(), + } + } + + #[tokio::test] + async fn include_thinking_rewraps_reasoning_with_literal_markers() { + let (tx, rx) = mpsc::channel::<InferenceEvent>(8); + let out_rx = project_chat_stream_with( + rx, + "id".into(), + 1, + "m".into(), + ChatProjectionConfig { + include_thinking: true, + reasoning_markers: Some(pair()), + }, + ); + tx.send(InferenceEvent::ReasoningDelta("first ".into())) + .await + .unwrap(); + tx.send(InferenceEvent::ReasoningDelta("second".into())) + .await + .unwrap(); + 
tx.send(InferenceEvent::TextDelta("answer".into())) + .await + 
.unwrap(); + tx.send(InferenceEvent::Finish { + reason: FinishReason::Stop, + }) + .await + .unwrap(); + drop(tx); + let out = collect(out_rx).await; + // Expected sequence: open marker → reasoning content (2 chunks) + // → close marker → visible answer → final chunk. + let contents: Vec<&str> = out + .iter() + .filter_map(|c| c.choices[0].delta["content"].as_str()) + .collect(); + assert_eq!( + contents, + vec!["<think>", "first ", "second", "</think>", "answer"] + ); + assert_eq!( + out.last().unwrap().choices[0].finish_reason.as_deref(), + Some("stop") + ); + } + + #[tokio::test] + async fn include_thinking_closes_marker_at_finish_when_no_trailing_text() { + // Edge case: stream ends inside a reasoning block (model + // hit max_tokens mid-thought, no visible answer ever). + // The Finish event still triggers the close marker so the + // stream is balanced. + let (tx, rx) = mpsc::channel::<InferenceEvent>(4); + let out_rx = project_chat_stream_with( + rx, + "id".into(), + 1, + "m".into(), + ChatProjectionConfig { + include_thinking: true, + reasoning_markers: Some(pair()), + }, + ); + tx.send(InferenceEvent::ReasoningDelta("thinking...".into())) + .await + .unwrap(); + tx.send(InferenceEvent::Finish { + reason: FinishReason::Length, + }) + .await + .unwrap(); + drop(tx); + let out = collect(out_rx).await; + let contents: Vec<&str> = out + .iter() + .filter_map(|c| c.choices[0].delta["content"].as_str()) + .collect(); + assert_eq!(contents, vec!["<think>", "thinking...", "</think>"]); + assert_eq!( + out.last().unwrap().choices[0].finish_reason.as_deref(), + Some("length") + ); + } + + #[tokio::test] + async fn include_thinking_without_markers_emits_content_directly() { + // Defensive: if the caller asks for thinking but the + // model declared no markers, we still emit the content + // rather than dropping it. Better to leak than to lose. 
+ let (tx, rx) = mpsc::channel::<InferenceEvent>(4); + let out_rx = project_chat_stream_with( + rx, + "id".into(), + 1, + "m".into(), + ChatProjectionConfig { + include_thinking: true, + reasoning_markers: None, + }, + ); + tx.send(InferenceEvent::ReasoningDelta("raw".into())) + .await + .unwrap(); + tx.send(InferenceEvent::Finish { + reason: FinishReason::Stop, + }) + .await + .unwrap(); + drop(tx); + let out = collect(out_rx).await; + let contents: Vec<&str> = out + .iter() + .filter_map(|c| c.choices[0].delta["content"].as_str()) + .collect(); + assert_eq!(contents, vec!["raw"]); + } + + #[tokio::test] + async fn include_thinking_off_drops_reasoning_even_with_markers() { + // Default behaviour even when markers happen to be + // configured. The flag is the gate, not the marker + // presence. + let (tx, rx) = mpsc::channel::<InferenceEvent>(4); + let out_rx = project_chat_stream_with( + rx, + "id".into(), + 1, + "m".into(), + ChatProjectionConfig { + include_thinking: false, + reasoning_markers: Some(pair()), + }, + ); + tx.send(InferenceEvent::ReasoningDelta("hidden".into())) + .await + .unwrap(); + tx.send(InferenceEvent::TextDelta("visible".into())) + .await + .unwrap(); + tx.send(InferenceEvent::Finish { + reason: FinishReason::Stop, + }) + .await + .unwrap(); + drop(tx); + let out = collect(out_rx).await; + let contents: Vec<&str> = out + .iter() + .filter_map(|c| c.choices[0].delta["content"].as_str()) + .collect(); + assert_eq!(contents, vec!["visible"]); + } }