diff --git a/crates/helexa-acp/src/path_util.rs b/crates/helexa-acp/src/path_util.rs index 6cdb12c..8995c40 100644 --- a/crates/helexa-acp/src/path_util.rs +++ b/crates/helexa-acp/src/path_util.rs @@ -26,6 +26,16 @@ use std::path::{Path, PathBuf}; +/// Process-global lock for tests that mutate `HOME`. Anyone in the +/// crate touching `HOME` must hold this for the duration of the +/// read-modify-restore window — otherwise concurrent `cargo test` +/// workers race and flake. +/// +/// Only built into the test binaries. Production code never mutates +/// env vars. +#[cfg(test)] +pub(crate) static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); + /// Expand `~`, `~/`, `$HOME`, and `$HOME/` prefixes against the /// current user's home directory. All other inputs pass through /// unchanged. @@ -56,13 +66,11 @@ mod tests { use super::*; /// Set HOME for the duration of the test. Tests using this run - /// serially under one mutex because env mutation isn't - /// thread-safe — `cargo test` parallel workers would race - /// without it. + /// serially under the crate-wide [`ENV_LOCK`] because env + /// mutation isn't thread-safe — `cargo test` parallel workers + /// would race without it. fn with_home(home: &str, body: F) { - use std::sync::Mutex; - static LOCK: Mutex<()> = Mutex::new(()); - let _g = LOCK.lock().unwrap(); + let _g = ENV_LOCK.lock().unwrap(); let prior = std::env::var("HOME").ok(); // SAFETY: tests touch process-global env. The mutex // serialises access; sub-threads in other test modules @@ -148,10 +156,10 @@ mod tests { #[test] fn no_home_env_passes_through() { - // Lock + clear HOME for this one. - use std::sync::Mutex; - static LOCK: Mutex<()> = Mutex::new(()); - let _g = LOCK.lock().unwrap(); + // Share the same crate-wide lock as `with_home` — otherwise + // a parallel test setting HOME races this clear-and-assert + // window. + let _g = ENV_LOCK.lock().unwrap(); let prior = std::env::var("HOME").ok(); // SAFETY: serialised by LOCK above. unsafe { diff --git a/crates/helexa-acp/src/tool_runner.rs b/crates/helexa-acp/src/tool_runner.rs index eca03e1..e41368b 100644 --- a/crates/helexa-acp/src/tool_runner.rs +++ b/crates/helexa-acp/src/tool_runner.rs @@ -1251,7 +1251,20 @@ mod tests { } #[tokio::test] + // Holds the env lock across an await — the await is the + // tool dispatch, which itself re-reads HOME via plan_dir_for. + // Releasing the lock would let another test mutate HOME + // between this test's setup and the gate's lookup. + #[allow(clippy::await_holding_lock)] async fn plan_mode_allows_write_inside_plan_dir_without_permission() { + // Plan-mode gate calls store::plan_dir_for at runtime + // (which reads HOME). If a parallel test mutates HOME + // mid-flight, the gate's plan_dir would differ from the + // one we computed up here and the path check would fail. + // Share the crate-wide env lock so we and any HOME-mutator + // serialise. + let _g = crate::path_util::ENV_LOCK.lock().unwrap(); + // Skip if we can't resolve a plan dir in this environment // (would happen with no HOME / XDG_DATA_HOME — neither // realistic in CI nor for an interactive run). @@ -1321,11 +1334,9 @@ mod tests { // correct *default*; this is the documented exception. #[allow(clippy::await_holding_lock)] async fn read_file_expands_tilde_before_dispatch() { - // HOME mutation is process-global; serialise tests that - // touch it under a single std::sync::Mutex. - use std::sync::Mutex; - static LOCK: Mutex<()> = Mutex::new(()); - let _g = LOCK.lock().unwrap(); + // HOME mutation is process-global; share the crate-wide + // ENV_LOCK with path_util's tests so workers don't race. + let _g = crate::path_util::ENV_LOCK.lock().unwrap(); let prior = std::env::var("HOME").ok(); unsafe { std::env::set_var("HOME", "/home/me"); diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 35e66e8..357e15a 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -23,9 +23,10 @@ use candle_transformers::models::qwen3_moe as qwen3_moe_dense; use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec}; use cortex_core::openai::{ ChatCompletionChoice, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, - ChatMessage, ChunkChoice, MessageContent, Usage, + ChatMessage, MessageContent, Usage, }; -use serde_json::json; + +use crate::wire::{FinishReason, InferenceEvent, openai_chat as wire_chat}; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; @@ -1635,36 +1636,24 @@ impl CandleHarness { let created = unix_now_secs(); // Bounded channel so the producer (blocking inference) is back- - // pressured by the consumer (SSE writer). 32 is generous — - // tokens arrive one at a time and the SSE writer is async. - let (tx, rx) = mpsc::channel::(32); + // pressured by the consumer (SSE writer, via the wire + // projector). 32 is generous — tokens arrive one at a time + // and downstream consumption is async. + let (tx, event_rx) = mpsc::channel::(32); - // Lead chunk: announce the assistant role per OpenAI streaming - // conventions. Tools that auto-detect a streaming reply expect - // this before any content delta. - let role_chunk = ChatCompletionChunk { - id: id.clone(), - object: "chat.completion.chunk".into(), - created, - model: model_id.clone(), - choices: vec![ChunkChoice { - index: 0, - delta: json!({"role": "assistant"}), - finish_reason: None, - extra: serde_json::Value::Object(Default::default()), - }], - usage: None, - extra: serde_json::Value::Object(Default::default()), - }; // Refuse if the model is already poisoned. No point opening - // an SSE stream just to send the role chunk and then bail. + // an SSE stream just to send the Start event and then bail. if loaded.poisoned.load(Ordering::Acquire) { return Err(poisoned_error(&model_id)); } - // If sending the role chunk fails the receiver is already gone; - // bail before kicking off the heavy blocking work. - tx.send(role_chunk) + // Start event: tells the wire projector to emit its + // format-specific "the assistant is about to speak" frame + // (an OpenAI `delta: {role: "assistant"}` chunk here; a + // `response.created` + `response.output_item.added` pair on + // the Responses path). If sending fails the receiver is + // already gone; bail before kicking off the heavy work. + tx.send(InferenceEvent::Start) .await .map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?; @@ -1728,9 +1717,6 @@ impl CandleHarness { top_p, seed, eos_id, - id, - created, - model_id, tx, ) .await @@ -1787,9 +1773,6 @@ impl CandleHarness { top_p, seed, eos_id, - &id, - created, - &model_id, &tx, ) { Ok(()) => tracing::info!( @@ -1824,6 +1807,12 @@ impl CandleHarness { ))); } + // Wrap the InferenceEvent receiver in the OpenAI chat + // projection so the HTTP handler keeps receiving + // ChatCompletionChunks bit-for-bit identical to before. + // The id/created/model_id snapshot taken at request setup + // gets stamped into every emitted chunk. + let rx = wire_chat::project_chat_stream(event_rx, id, created, model_id); Ok(rx) } } @@ -2277,27 +2266,16 @@ impl CandleHarness { let created = unix_now_secs(); let tokenizer = tp.tokenizer.clone(); - // Bounded channel — back-pressures the producer when the SSE - // writer is slow. - let (tx, rx) = mpsc::channel::(32); + // Bounded channel — back-pressures the producer when + // downstream consumption (wire projector → SSE writer) is + // slow. + let (tx, event_rx) = mpsc::channel::(32); - // Role chunk first, before kicking off the heavy work — if the - // receiver is gone by now there's no point starting inference. - let role_chunk = ChatCompletionChunk { - id: id.clone(), - object: "chat.completion.chunk".into(), - created, - model: model_id.clone(), - choices: vec![ChunkChoice { - index: 0, - delta: json!({"role": "assistant"}), - finish_reason: None, - extra: serde_json::Value::Object(Default::default()), - }], - usage: None, - extra: serde_json::Value::Object(Default::default()), - }; - tx.send(role_chunk) + // Start event first, before kicking off the heavy work — if + // the receiver is gone by now there's no point starting + // inference. The wire projector materialises this as the + // OpenAI `delta: {role: "assistant"}` chunk. + tx.send(InferenceEvent::Start) .await .map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?; @@ -2344,7 +2322,7 @@ impl CandleHarness { // UTF-8 mid-codepoint boundaries when BPE byte-fallback // split a multi-byte char across tokens. let mut decode_stream = tokenizer.decode_stream(true); - let mut finish_reason = "length".to_string(); + let mut finish_reason = FinishReason::Length; 'work: { if let Err(e) = pool.clear_kv_cache(&model_id, leader_handle).await { @@ -2412,12 +2390,12 @@ impl CandleHarness { }; if Some(next_token) == eos_id { - finish_reason = "stop".into(); + finish_reason = FinishReason::Stop; } else { all_tokens.push(next_token); match decode_stream.step(next_token) { Ok(Some(delta)) => { - if !emit_delta(&delta, &tx, &id, created, &model_id).await { + if !emit_delta(&delta, &tx).await { // Client gone — treat as normal stream end, // not a failure. No log spam. break 'work; @@ -2489,13 +2467,13 @@ impl CandleHarness { "TP chat_completion (stream): decode step" ); if Some(next_token) == eos_id { - finish_reason = "stop".into(); + finish_reason = FinishReason::Stop; break; } all_tokens.push(next_token); match decode_stream.step(next_token) { Ok(Some(delta)) => { - if !emit_delta(&delta, &tx, &id, created, &model_id).await { + if !emit_delta(&delta, &tx).await { break 'work; } } @@ -2535,37 +2513,32 @@ impl CandleHarness { tracing::info!( prompt_tokens = prompt_len, completion_tokens = all_tokens.len(), - finish_reason = %finish_reason, + finish_reason = finish_reason.as_openai_str(), total_ms = req_start.elapsed().as_millis(), "TP chat_completion (stream): done" ); } - // Final chunk carrying finish_reason — only on the success - // path. On failure we drop the channel so the client sees - // the SSE stream end abruptly (matches pre-change behaviour - // when the failed-path early-returned without final chunk). + // Finish event — only on the success path. On + // failure we drop the channel so the client sees the + // SSE stream end abruptly (matches the pre-refactor + // behaviour when the failed-path early-returned + // without a final chunk). if failure.is_none() { - let final_chunk = ChatCompletionChunk { - id: id.clone(), - object: "chat.completion.chunk".into(), - created, - model: model_id.clone(), - choices: vec![ChunkChoice { - index: 0, - delta: serde_json::Value::Object(Default::default()), - finish_reason: Some(finish_reason), - extra: serde_json::Value::Object(Default::default()), - }], - usage: None, - extra: serde_json::Value::Object(Default::default()), - }; - let _ = tx.send(final_chunk).await; + let _ = tx + .send(InferenceEvent::Finish { + reason: finish_reason, + }) + .await; } } .instrument(span), ); + // Wrap the InferenceEvent receiver in the OpenAI chat + // projection so the HTTP handler keeps consuming + // ChatCompletionChunks unchanged. + let rx = wire_chat::project_chat_stream(event_rx, id, created, model_id); Ok(rx) } } @@ -2793,68 +2766,36 @@ async fn chat_completion_tp_inner( }) } -/// Send `delta` as a `chat.completion.chunk`. Returns `false` if the -/// receiver has hung up — the caller should bail. Empty deltas (the -/// DecodeStream is buffering an incomplete UTF-8 sequence) are a -/// no-op return-true so the caller can treat "no delta yet" and "tx -/// still live" uniformly. +/// Send `delta` as an [`InferenceEvent::TextDelta`]. Returns `false` +/// if the receiver has hung up — the caller should bail. Empty +/// deltas (the DecodeStream is buffering an incomplete UTF-8 +/// sequence) are a no-op return-true so the caller can treat "no +/// delta yet" and "tx still live" uniformly. +/// +/// Wire-format-specific metadata (chunk id, created, model_id) +/// stays out of this function — the wire projector in +/// [`crate::wire::openai_chat`] stamps it onto every chunk +/// downstream. #[cfg(feature = "cuda")] -async fn emit_delta( - delta: &str, - tx: &mpsc::Sender, - id: &str, - created: u64, - model_id: &str, -) -> bool { +async fn emit_delta(delta: &str, tx: &mpsc::Sender) -> bool { if delta.is_empty() { return true; } - let chunk = ChatCompletionChunk { - id: id.into(), - object: "chat.completion.chunk".into(), - created, - model: model_id.into(), - choices: vec![ChunkChoice { - index: 0, - delta: json!({ "content": delta }), - finish_reason: None, - extra: serde_json::Value::Object(Default::default()), - }], - usage: None, - extra: serde_json::Value::Object(Default::default()), - }; - tx.send(chunk).await.is_ok() + tx.send(InferenceEvent::TextDelta(delta.into())) + .await + .is_ok() } /// Sync counterpart of [`emit_delta`] for the CPU path's /// `spawn_blocking` closure. Same shape, `blocking_send` instead of /// `send`. Kept as a separate fn so the async / blocking-send choice /// is local to one place per path. -fn emit_delta_blocking( - delta: &str, - tx: &mpsc::Sender, - id: &str, - created: u64, - model_id: &str, -) -> bool { +fn emit_delta_blocking(delta: &str, tx: &mpsc::Sender) -> bool { if delta.is_empty() { return true; } - let chunk = ChatCompletionChunk { - id: id.into(), - object: "chat.completion.chunk".into(), - created, - model: model_id.into(), - choices: vec![ChunkChoice { - index: 0, - delta: json!({ "content": delta }), - finish_reason: None, - extra: serde_json::Value::Object(Default::default()), - }], - usage: None, - extra: serde_json::Value::Object(Default::default()), - }; - tx.blocking_send(chunk).is_ok() + tx.blocking_send(InferenceEvent::TextDelta(delta.into())) + .is_ok() } /// Errors returned by `CandleHarness::chat_completion`. The @@ -3019,10 +2960,7 @@ async fn stream_inference_via_worker( top_p: Option, seed: u64, eos_id: Option, - id: String, - created: u64, - model_id: String, - tx: mpsc::Sender, + tx: mpsc::Sender, ) -> Result { let mut logits_processor = { let sampling = if temperature <= 0.0 { @@ -3045,7 +2983,7 @@ async fn stream_inference_via_worker( // codepoint; `Ok(None)` while it's buffering an incomplete one. let mut decode_stream = tokenizer.decode_stream(true); let prompt_len = prompt_tokens.len(); - let mut finish_reason = "length".to_string(); + let mut finish_reason = FinishReason::Length; worker .clear_kv_cache(handle) @@ -3071,13 +3009,13 @@ async fn stream_inference_via_worker( }; if Some(next_token) == eos_id { - finish_reason = "stop".into(); + finish_reason = FinishReason::Stop; } else { all_tokens.push(next_token); match decode_stream.step(next_token) { Ok(Some(delta)) => { - if !emit_delta(&delta, &tx, &id, created, &model_id).await { - return Ok(finish_reason); + if !emit_delta(&delta, &tx).await { + return Ok(finish_reason.as_openai_str().to_string()); } } Ok(None) => {} @@ -3103,14 +3041,14 @@ async fn stream_inference_via_worker( } }; if Some(next_token) == eos_id { - finish_reason = "stop".into(); + finish_reason = FinishReason::Stop; break; } all_tokens.push(next_token); match decode_stream.step(next_token) { Ok(Some(delta)) => { - if !emit_delta(&delta, &tx, &id, created, &model_id).await { - return Ok(finish_reason); + if !emit_delta(&delta, &tx).await { + return Ok(finish_reason.as_openai_str().to_string()); } } Ok(None) => {} @@ -3119,25 +3057,16 @@ async fn stream_inference_via_worker( } } - // Final chunk carrying finish_reason. Matches the run_inference_streaming - // shape so the SSE consumer sees an identical termination sequence. - let final_chunk = ChatCompletionChunk { - id: id.clone(), - object: "chat.completion.chunk".into(), - created, - model: model_id.clone(), - choices: vec![ChunkChoice { - index: 0, - delta: serde_json::Value::Object(Default::default()), - finish_reason: Some(finish_reason.clone()), - extra: serde_json::Value::Object(Default::default()), - }], - usage: None, - extra: serde_json::Value::Object(Default::default()), - }; - let _ = tx.send(final_chunk).await; + // Terminal Finish event. The wire projector turns this into a + // format-specific final chunk (`finish_reason: "stop"` on + // OpenAI chat, `response.completed` on Responses). + let _ = tx + .send(InferenceEvent::Finish { + reason: finish_reason, + }) + .await; - Ok(finish_reason) + Ok(finish_reason.as_openai_str().to_string()) } #[allow(clippy::too_many_arguments)] @@ -3204,10 +3133,7 @@ fn run_inference_streaming( top_p: Option, seed: u64, eos_id: Option, - id: &str, - created: u64, - model_id: &str, - tx: &mpsc::Sender, + tx: &mpsc::Sender, ) -> Result<()> { let mut logits_processor = { let sampling = if temperature <= 0.0 { @@ -3227,19 +3153,19 @@ fn run_inference_streaming( // buffers incomplete multi-byte UTF-8 sequences across token // boundaries and only emits when a clean codepoint completes. let mut decode_stream = tokenizer.decode_stream(true); - let mut finish_reason = "length".to_string(); + let mut finish_reason = FinishReason::Length; arch.clear_kv_cache()?; let logits = chunked_prefill_local(arch, device, prompt_tokens)?; let mut next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?; if Some(next_token) == eos_id { - finish_reason = "stop".into(); + finish_reason = FinishReason::Stop; } else { all_tokens.push(next_token); match decode_stream.step(next_token) { Ok(Some(delta)) => { - if !emit_delta_blocking(&delta, tx, id, created, model_id) { + if !emit_delta_blocking(&delta, tx) { return Ok(()); } } @@ -3252,13 +3178,13 @@ fn run_inference_streaming( let logits = arch.forward(&input, prompt_tokens.len() + index)?; next_token = sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?; if Some(next_token) == eos_id { - finish_reason = "stop".into(); + finish_reason = FinishReason::Stop; break; } all_tokens.push(next_token); match decode_stream.step(next_token) { Ok(Some(delta)) => { - if !emit_delta_blocking(&delta, tx, id, created, model_id) { + if !emit_delta_blocking(&delta, tx) { return Ok(()); } } @@ -3268,21 +3194,9 @@ fn run_inference_streaming( } } - let final_chunk = ChatCompletionChunk { - id: id.into(), - object: "chat.completion.chunk".into(), - created, - model: model_id.into(), - choices: vec![ChunkChoice { - index: 0, - delta: serde_json::Value::Object(Default::default()), - finish_reason: Some(finish_reason), - extra: serde_json::Value::Object(Default::default()), - }], - usage: None, - extra: serde_json::Value::Object(Default::default()), - }; - let _ = tx.blocking_send(final_chunk); + let _ = tx.blocking_send(InferenceEvent::Finish { + reason: finish_reason, + }); Ok(()) } diff --git a/crates/neuron/src/lib.rs b/crates/neuron/src/lib.rs index de600ca..5739c94 100644 --- a/crates/neuron/src/lib.rs +++ b/crates/neuron/src/lib.rs @@ -6,3 +6,4 @@ pub mod discovery; pub mod harness; pub mod health; pub mod startup; +pub mod wire; diff --git a/crates/neuron/src/wire/event.rs b/crates/neuron/src/wire/event.rs new file mode 100644 index 0000000..f92f451 --- /dev/null +++ b/crates/neuron/src/wire/event.rs @@ -0,0 +1,99 @@ +//! Format-agnostic inference event stream. +//! +//! The candle harness emits a sequence of these for every streaming +//! request. Wire-format projections in sibling modules +//! ([`super::openai_chat`], the eventual `openai_responses` / +//! `anthropic_messages` projections) read this stream and produce +//! the chunks / events their HTTP clients expect. +//! +//! Design notes: +//! +//! - [`Start`] carries no token of its own. It only signals "the +//! model has accepted the prompt and is about to begin emitting +//! text". OpenAI chat materialises this as a `role: assistant` +//! chunk; OpenAI Responses as the `response.created` + +//! `response.output_item.added` pair; Anthropic as +//! `message_start`. All three of those would otherwise have to +//! peek at the *first* token to know when to emit, which couples +//! the wire layer to the producer's pacing. +//! - [`TextDelta`] is *visible* output. Reasoning / `` +//! blocks go through a future [`ReasoningDelta`] variant once +//! the harness learns to split them (today they pass through as +//! plain text inside `TextDelta`; helexa-acp picks them apart on +//! the consumer side). +//! - [`Finish`] is the only place a stream is allowed to end +//! cleanly. Projections rely on this to emit final usage +//! bookkeeping; absence means the producer crashed and the +//! consumer should treat the stream as truncated. +//! +//! [`Start`]: InferenceEvent::Start +//! [`TextDelta`]: InferenceEvent::TextDelta +//! [`Finish`]: InferenceEvent::Finish + +/// One unit of output from the inference loop. +/// +/// Producers send these on an `mpsc::Sender`; +/// projection layers in sibling modules consume them and emit +/// wire-format-specific frames downstream. +#[derive(Debug, Clone)] +pub enum InferenceEvent { + /// The producer has accepted the prompt and is about to emit + /// the first token. Sent at most once per stream. + Start, + /// A piece of visible assistant text. Multiple deltas + /// concatenate into the complete reply. + TextDelta(String), + /// Reasoning / scratchpad text the model emitted inside a + /// `` block (or equivalent). Producers that don't + /// surface reasoning separately use [`TextDelta`] for + /// everything; future split lives here. + /// + /// Not yet emitted by the candle harness — present so future + /// stages (qwen3 `` routing, OpenAI o-series reasoning) + /// have a typed home without breaking the existing + /// projections. + #[allow(dead_code)] + ReasoningDelta(String), + /// The stream is complete. Carries the reason so wire formats + /// that use it (OpenAI's `finish_reason`, Anthropic's + /// `stop_reason`) can render it without re-parsing. + Finish { reason: FinishReason }, +} + +/// Why a stream stopped. Stays small on purpose — anything that +/// doesn't map cleanly to one of these collapses to [`Stop`]. +/// +/// Mappings to wire formats: +/// +/// | variant | OpenAI `finish_reason` | OpenAI Responses `status` | Anthropic `stop_reason` | +/// |---------|------------------------|---------------------------|-------------------------| +/// | `Stop` | `"stop"` | `"completed"` | `"end_turn"` | +/// | `Length`| `"length"` | `"incomplete"` | `"max_tokens"` | +/// | `ToolCalls` | `"tool_calls"` | `"completed"` | `"tool_use"` | +/// +/// [`Stop`]: FinishReason::Stop +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FinishReason { + /// Model emitted EOS naturally. + Stop, + /// Hit `max_tokens` before EOS. + Length, + /// Stopped because the model called a tool and is waiting for + /// the result. Not yet emitted by the candle harness — + /// reserved for the day tool-call extraction lands. + #[allow(dead_code)] + ToolCalls, +} + +impl FinishReason { + /// String form used by OpenAI chat completions and OpenAI + /// completions. Wire modules can call this directly or do their + /// own mapping for non-string formats. + pub fn as_openai_str(self) -> &'static str { + match self { + FinishReason::Stop => "stop", + FinishReason::Length => "length", + FinishReason::ToolCalls => "tool_calls", + } + } +} diff --git a/crates/neuron/src/wire/mod.rs b/crates/neuron/src/wire/mod.rs new file mode 100644 index 0000000..f91f164 --- /dev/null +++ b/crates/neuron/src/wire/mod.rs @@ -0,0 +1,23 @@ +//! Wire-format projection layer. +//! +//! The candle harness produces a single, format-agnostic stream of +//! [`InferenceEvent`]s. Each wire format (OpenAI chat completions, +//! OpenAI Responses, Anthropic messages, …) lives in its own module +//! under `wire::` and projects that event stream into the chunks / +//! events its HTTP clients expect. +//! +//! The benefit over translating *between* wire shapes (OpenAI chat +//! → Anthropic, etc.) is that we never have to reason about a +//! wire-N → wire-M conversion: every translation is wire-N ↔ the +//! internal event currency, and the projections are independent. A +//! new wire format adds a new file under `wire::`; nothing else +//! needs to know about it. +//! +//! Today: [`openai_chat`]. Stage 2 adds `openai_responses`. Stage 3 +//! could add a native Anthropic projection that replaces the +//! gateway-side translation. + +pub mod event; +pub mod openai_chat; + +pub use event::{FinishReason, InferenceEvent}; diff --git a/crates/neuron/src/wire/openai_chat.rs b/crates/neuron/src/wire/openai_chat.rs new file mode 100644 index 0000000..def289b --- /dev/null +++ b/crates/neuron/src/wire/openai_chat.rs @@ -0,0 +1,241 @@ +//! OpenAI chat completions projection. +//! +//! Reads [`InferenceEvent`]s from a receiver and produces +//! [`ChatCompletionChunk`]s in the shape `POST /v1/chat/completions` +//! clients expect on its streaming SSE response. The HTTP handler in +//! [`crate::api`] wraps the resulting receiver in axum's +//! `Sse::new(...)` adapter; nothing in this module touches HTTP +//! framing or `data:` lines. +//! +//! Per the OpenAI streaming spec, three chunk shapes appear: +//! +//! 1. **Role chunk** — `delta: { "role": "assistant" }`, no content, +//! sent once at stream start. We emit this on [`InferenceEvent::Start`]. +//! 2. **Content chunks** — `delta: { "content": "" }`, one per +//! [`InferenceEvent::TextDelta`]. +//! 3. **Final chunk** — empty `delta`, `finish_reason` populated. +//! Emitted on [`InferenceEvent::Finish`]. +//! +//! `usage` stays `None` on every chunk; the legacy candle paths +//! never surfaced usage on the streaming endpoint and we keep that +//! behaviour bit-for-bit so existing clients see no diff. +//! +//! Back-pressure: the projection task awaits both `rx.recv()` and +//! `tx.send()`. A slow consumer fills the output channel → the +//! task blocks on send → it stops reading from the input → the +//! producer blocks on its own send. The bounded channels +//! propagate without us writing any logic. + +use cortex_core::openai::{ChatCompletionChunk, ChunkChoice}; +use serde_json::json; +use tokio::sync::mpsc; + +use super::event::{FinishReason, InferenceEvent}; + +/// Output channel buffer size. Mirrors the input side's bound; one +/// event maps to at most one chunk, so equal capacity keeps the +/// two ends in sync without surprising memory growth. +const CHUNK_CHANNEL_CAPACITY: usize = 32; + +/// Project an [`InferenceEvent`] receiver into a +/// [`ChatCompletionChunk`] receiver. Spawns one tokio task that +/// owns the input receiver for the stream's lifetime and exits +/// when either side closes. +/// +/// `id`, `created`, and `model_id` are stamped into every emitted +/// chunk so the receiver can stay generic (decoupled from +/// per-request metadata). +pub fn project_chat_stream( + mut rx: mpsc::Receiver, + id: String, + created: u64, + model_id: String, +) -> mpsc::Receiver { + let (tx, out_rx) = mpsc::channel::(CHUNK_CHANNEL_CAPACITY); + + tokio::spawn(async move { + while let Some(event) = rx.recv().await { + let chunks = match event { + InferenceEvent::Start => vec![role_chunk(&id, created, &model_id)], + InferenceEvent::TextDelta(text) => { + if text.is_empty() { + // DecodeStream is buffering a multi-byte + // codepoint; don't bother sending an empty + // chunk downstream. + continue; + } + vec![content_chunk(&id, created, &model_id, &text)] + } + InferenceEvent::ReasoningDelta(_) => { + // Reasoning isn't representable in OpenAI chat + // streaming today. The o-series uses a separate + // `summary` event but it's gated by the + // Responses API; chat-completions just drops it. + continue; + } + InferenceEvent::Finish { reason } => { + vec![final_chunk(&id, created, &model_id, reason)] + } + }; + for chunk in chunks { + if tx.send(chunk).await.is_err() { + // Consumer hung up; nothing more to do. + return; + } + } + } + }); + + out_rx +} + +fn role_chunk(id: &str, created: u64, model_id: &str) -> ChatCompletionChunk { + ChatCompletionChunk { + id: id.into(), + object: "chat.completion.chunk".into(), + created, + model: model_id.into(), + choices: vec![ChunkChoice { + index: 0, + delta: json!({ "role": "assistant" }), + finish_reason: None, + extra: serde_json::Value::Object(Default::default()), + }], + usage: None, + extra: serde_json::Value::Object(Default::default()), + } +} + +fn content_chunk(id: &str, created: u64, model_id: &str, text: &str) -> ChatCompletionChunk { + ChatCompletionChunk { + id: id.into(), + object: "chat.completion.chunk".into(), + created, + model: model_id.into(), + choices: vec![ChunkChoice { + index: 0, + delta: json!({ "content": text }), + finish_reason: None, + extra: serde_json::Value::Object(Default::default()), + }], + usage: None, + extra: serde_json::Value::Object(Default::default()), + } +} + +fn final_chunk( + id: &str, + created: u64, + model_id: &str, + reason: FinishReason, +) -> ChatCompletionChunk { + ChatCompletionChunk { + id: id.into(), + object: "chat.completion.chunk".into(), + created, + model: model_id.into(), + choices: vec![ChunkChoice { + index: 0, + delta: serde_json::Value::Object(Default::default()), + finish_reason: Some(reason.as_openai_str().to_string()), + extra: serde_json::Value::Object(Default::default()), + }], + usage: None, + extra: serde_json::Value::Object(Default::default()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Drain the projection's output into a Vec for assertion. + async fn collect(mut rx: mpsc::Receiver) -> Vec { + let mut out = Vec::new(); + while let Some(chunk) = rx.recv().await { + out.push(chunk); + } + out + } + + #[tokio::test] + async fn empty_event_stream_yields_no_chunks() { + let (tx, rx) = mpsc::channel::(4); + drop(tx); + let out = collect(project_chat_stream(rx, "id-1".into(), 1700, "m".into())).await; + assert!(out.is_empty()); + } + + #[tokio::test] + async fn start_text_finish_produces_three_chunks() { + let (tx, rx) = mpsc::channel::(4); + let out_rx = project_chat_stream(rx, "id-1".into(), 1700, "m".into()); + + tx.send(InferenceEvent::Start).await.unwrap(); + tx.send(InferenceEvent::TextDelta("hello".into())) + .await + .unwrap(); + tx.send(InferenceEvent::Finish { + reason: FinishReason::Stop, + }) + .await + .unwrap(); + drop(tx); + + let out = collect(out_rx).await; + assert_eq!(out.len(), 3); + assert_eq!(out[0].choices[0].delta["role"], "assistant"); + assert_eq!(out[1].choices[0].delta["content"], "hello"); + assert_eq!(out[2].choices[0].finish_reason.as_deref(), Some("stop")); + // Every chunk carries the stamped metadata. + for chunk in &out { + assert_eq!(chunk.id, "id-1"); + assert_eq!(chunk.created, 1700); + assert_eq!(chunk.model, "m"); + assert_eq!(chunk.object, "chat.completion.chunk"); + } + } + + #[tokio::test] + async fn empty_text_delta_is_dropped() { + let (tx, rx) = mpsc::channel::(4); + let out_rx = project_chat_stream(rx, "id".into(), 1, "m".into()); + tx.send(InferenceEvent::TextDelta(String::new())) + .await + .unwrap(); + drop(tx); + let out = collect(out_rx).await; + assert!(out.is_empty(), "empty deltas must not produce chunks"); + } + + #[tokio::test] + async fn finish_length_maps_to_openai_string() { + let (tx, rx) = mpsc::channel::(4); + let out_rx = project_chat_stream(rx, "id".into(), 1, "m".into()); + tx.send(InferenceEvent::Finish { + reason: FinishReason::Length, + }) + .await + .unwrap(); + drop(tx); + let out = collect(out_rx).await; + assert_eq!(out.len(), 1); + assert_eq!(out[0].choices[0].finish_reason.as_deref(), Some("length")); + } + + #[tokio::test] + async fn reasoning_delta_is_dropped_in_chat_projection() { + let (tx, rx) = mpsc::channel::(4); + let out_rx = project_chat_stream(rx, "id".into(), 1, "m".into()); + tx.send(InferenceEvent::ReasoningDelta("".into())) + .await + .unwrap(); + tx.send(InferenceEvent::TextDelta("real".into())) + .await + .unwrap(); + drop(tx); + let out = collect(out_rx).await; + assert_eq!(out.len(), 1); + assert_eq!(out[0].choices[0].delta["content"], "real"); + } +}