diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml index 7487a9c..c2cff2a 100644 --- a/.gitea/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -97,12 +97,14 @@ jobs: # the category of bug where a refactor compiles fine under the # default feature set (which is what the `clippy` and `test` jobs # exercise) but fails inside a `#[cfg(feature = "cuda")]` block. - # Lives on `runs-on: rpm` because that's where nvcc / cudarc's - # build prerequisites are installed; the generic `rust` runner - # doesn't have them. + # `runs-on: cuda-13.0` selects the runner that ships nvcc / + # cudarc's build prerequisites. The generic `rust` and `rpm` + # runners don't have them (the previous label `rpm` was tried + # first and tripped cudarc's `nvcc --version` build script — + # see commit history). cuda-check: name: CUDA type-check - runs-on: rpm + runs-on: cuda-13.0 steps: - uses: actions/checkout@v4 - name: cargo check --features cuda (with retry) diff --git a/crates/cortex-core/src/lib.rs b/crates/cortex-core/src/lib.rs index 5ab8fef..d721fd1 100644 --- a/crates/cortex-core/src/lib.rs +++ b/crates/cortex-core/src/lib.rs @@ -6,4 +6,5 @@ pub mod harness; pub mod metrics; pub mod node; pub mod openai; +pub mod responses; pub mod translate; diff --git a/crates/cortex-core/src/responses.rs b/crates/cortex-core/src/responses.rs new file mode 100644 index 0000000..588c9e8 --- /dev/null +++ b/crates/cortex-core/src/responses.rs @@ -0,0 +1,341 @@ +//! OpenAI Responses API (`POST /v1/responses`) envelope types. +//! +//! This is OpenAI's newer chat surface, distinct from +//! `/v1/chat/completions` in three ways that matter for us: +//! +//! 1. **Input shape**. Instead of a `messages` array, the request +//! carries `input` — either a plain string (single user turn) +//! or an array of typed items (messages, function calls, +//! function-call outputs, reasoning blocks, …). +//! 2. **Output shape**. The response carries a single `output` +//! array of items, each typed. 
We always emit one +//! `OutputItem::Message` containing the assistant's reply (plus, +//! when we get there, separate `function_call` items). +//! 3. **Streaming events**. Where chat completions stream +//! structurally-identical `chat.completion.chunk` frames over +//! `data:` lines, Responses streams *named* events +//! (`response.created`, `response.output_text.delta`, +//! `response.completed`, …) over `event:` + `data:` SSE pairs. +//! The wire projector in `neuron::wire::openai_responses` builds +//! these from the same [`crate::openai`]-shaped +//! `InferenceEvent` stream the chat projector consumes. +//! +//! Scope cuts for this first cut: +//! +//! - **`previous_response_id` is rejected by the handler (HTTP 400)**. Stateful +//! chained conversations need a persistence layer we don't have. +//! - **Reasoning items are accepted-and-ignored** (no Qwen3 +//! `<think>` routing yet). Audio and embedded resources are +//! rejected as unsupported. +//! - **Tool calls** (function_call / function_call_output) are +//! carried as round-trip types but the candle harness doesn't +//! emit them yet — wired so the surface is in place for the +//! day we add proper tool-call extraction. + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +// ── Request ────────────────────────────────────────────────────────── + +/// Body of a `POST /v1/responses` request. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResponsesRequest { + pub model: String, + pub input: ResponsesInput, + /// System-prompt-style instructions. The Responses API + /// separates these from input so a caller doesn't have to + /// build a `system` message item by hand. 
+ #[serde(default, skip_serializing_if = "Option::is_none")] + pub instructions: Option, + #[serde(default)] + pub stream: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub top_p: Option, + /// Chained-conversation identifier. We don't store responses + /// server-side yet; if this is `Some`, the handler returns 400. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub previous_response_id: Option, + /// Catch-all for anything we don't model yet (tools, tool_choice, + /// reasoning, response_format, …). Lets a client send a + /// forward-compatible request without our parser rejecting it. + #[serde(flatten)] + pub extra: Value, +} + +/// `input` is either a single string or an array of typed items. +/// `#[serde(untagged)]` so the wire shape `"input": "hi"` and +/// `"input": [{...}]` both deserialize. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ResponsesInput { + Text(String), + Items(Vec), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ResponsesInputItem { + /// A user / assistant / system turn. + Message { + role: String, + content: ResponsesMessageContent, + }, + /// Assistant emitted a tool call. Round-trip only — neuron + /// doesn't synthesise these yet. + FunctionCall { + call_id: String, + name: String, + arguments: String, + }, + /// User is feeding a tool result back into the model. + FunctionCallOutput { call_id: String, output: String }, + /// Reasoning items emitted by o-series models. Accepted but + /// not forwarded to the model — neuron's candle path doesn't + /// surface reasoning separately yet. 
+ Reasoning { + #[serde(default)] + content: Vec, + }, +} + +/// Inside a `Message` item, content is either a plain string or an +/// array of typed parts. Mirrors the chat-completions Parts shape. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ResponsesMessageContent { + Text(String), + Parts(Vec), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ResponsesContentPart { + /// Plain text inside a user / system turn. + InputText { text: String }, + /// An image. `image_url` is either a remote URL or a + /// `data:image/png;base64,…` URI; the request translator just + /// forwards the string. + InputImage { + image_url: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + detail: Option, + }, + /// Returned text inside an assistant turn — only relevant when + /// the caller is feeding an assistant turn back in to continue + /// a conversation manually (no `previous_response_id`). + OutputText { + text: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + annotations: Vec, + }, +} + +// ── Response (non-streaming) ───────────────────────────────────────── + +/// Body of a `POST /v1/responses` response. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResponsesResponse { + pub id: String, + /// Always `"response"`. + pub object: String, + pub created_at: u64, + /// `"completed"`, `"incomplete"`, or — for the initial event of + /// a streaming response — `"in_progress"`. + pub status: String, + pub model: String, + pub output: Vec, + /// Populated on completion; `None` while streaming. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ResponsesOutputItem { + Message { + id: String, + /// Always `"assistant"` for model output. + role: String, + /// Output content parts. 
We always emit a single + /// `OutputText` today; multi-part output would land here + /// once we have e.g. image generation. + content: Vec, + /// Item-level status. `"in_progress"` while streaming the + /// content parts, `"completed"` when done. + #[serde(default = "default_item_status")] + status: String, + }, + /// Reserved for the day tool-call extraction lands. The wire + /// shape mirrors `ResponsesInputItem::FunctionCall`. + FunctionCall { + id: String, + call_id: String, + name: String, + arguments: String, + #[serde(default = "default_item_status")] + status: String, + }, +} + +fn default_item_status() -> String { + "completed".into() +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ResponsesOutputContent { + OutputText { + text: String, + /// Citations / inline annotations. Empty today; reserved + /// for the day we wire in web search / file search. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + annotations: Vec, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResponsesUsage { + pub input_tokens: u64, + pub output_tokens: u64, + pub total_tokens: u64, +} + +// ── Streaming event names ──────────────────────────────────────────── + +/// Event names the SSE projector emits, hoisted as constants so +/// the projector and the wire shape stay in sync without +/// string-typos. The strings are dictated by OpenAI's published +/// Responses API. 
+pub mod events { + pub const CREATED: &str = "response.created"; + pub const OUTPUT_ITEM_ADDED: &str = "response.output_item.added"; + pub const CONTENT_PART_ADDED: &str = "response.content_part.added"; + pub const OUTPUT_TEXT_DELTA: &str = "response.output_text.delta"; + pub const OUTPUT_TEXT_DONE: &str = "response.output_text.done"; + pub const CONTENT_PART_DONE: &str = "response.content_part.done"; + pub const OUTPUT_ITEM_DONE: &str = "response.output_item.done"; + pub const COMPLETED: &str = "response.completed"; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn deserialises_input_string_form() { + let raw = r#"{"model": "m", "input": "hello"}"#; + let req: ResponsesRequest = serde_json::from_str(raw).unwrap(); + match req.input { + ResponsesInput::Text(s) => assert_eq!(s, "hello"), + other => panic!("expected Text, got {other:?}"), + } + } + + #[test] + fn deserialises_input_items_form() { + let raw = r#"{ + "model": "m", + "input": [ + {"type": "message", "role": "user", "content": "hi"} + ] + }"#; + let req: ResponsesRequest = serde_json::from_str(raw).unwrap(); + match req.input { + ResponsesInput::Items(items) => { + assert_eq!(items.len(), 1); + match &items[0] { + ResponsesInputItem::Message { role, content } => { + assert_eq!(role, "user"); + match content { + ResponsesMessageContent::Text(t) => assert_eq!(t, "hi"), + other => panic!("expected Text content, got {other:?}"), + } + } + other => panic!("expected Message item, got {other:?}"), + } + } + other => panic!("expected Items, got {other:?}"), + } + } + + #[test] + fn deserialises_input_with_image() { + let raw = r#"{ + "model": "m", + "input": [ + {"type": "message", "role": "user", "content": [ + {"type": "input_text", "text": "what is this"}, + {"type": "input_image", "image_url": "data:image/png;base64,AAA="} + ]} + ] + }"#; + let req: ResponsesRequest = serde_json::from_str(raw).unwrap(); + let items = match req.input { + ResponsesInput::Items(i) => i, + other => 
panic!("expected Items, got {other:?}"), + }; + let parts = match &items[0] { + ResponsesInputItem::Message { + content: ResponsesMessageContent::Parts(p), + .. + } => p, + other => panic!("expected Parts, got {other:?}"), + }; + assert_eq!(parts.len(), 2); + assert!(matches!( + &parts[0], + ResponsesContentPart::InputText { text } if text == "what is this" + )); + assert!(matches!( + &parts[1], + ResponsesContentPart::InputImage { image_url, .. } + if image_url == "data:image/png;base64,AAA=" + )); + } + + #[test] + fn unknown_fields_round_trip_via_extra() { + let raw = r#"{ + "model": "m", + "input": "hi", + "tools": [{"type": "web_search"}], + "reasoning": {"effort": "medium"} + }"#; + let req: ResponsesRequest = serde_json::from_str(raw).unwrap(); + assert!(req.extra.get("tools").is_some()); + assert!(req.extra.get("reasoning").is_some()); + } + + #[test] + fn response_round_trips_through_serde() { + let r = ResponsesResponse { + id: "resp_1".into(), + object: "response".into(), + created_at: 1700, + status: "completed".into(), + model: "m".into(), + output: vec![ResponsesOutputItem::Message { + id: "msg_1".into(), + role: "assistant".into(), + content: vec![ResponsesOutputContent::OutputText { + text: "hi there".into(), + annotations: vec![], + }], + status: "completed".into(), + }], + usage: Some(ResponsesUsage { + input_tokens: 5, + output_tokens: 3, + total_tokens: 8, + }), + }; + let json = serde_json::to_string(&r).unwrap(); + let parsed: ResponsesResponse = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.id, "resp_1"); + assert_eq!(parsed.output.len(), 1); + } +} diff --git a/crates/neuron/src/api.rs b/crates/neuron/src/api.rs index f400563..ec32e62 100644 --- a/crates/neuron/src/api.rs +++ b/crates/neuron/src/api.rs @@ -4,6 +4,7 @@ use crate::activation::ActivationTracker; use crate::harness::HarnessRegistry; use crate::harness::candle::{CandleHarness, InferenceError}; use crate::health::HealthCache; +use crate::wire::openai_responses; use 
axum::Router; use axum::extract::{Path, State}; use axum::http::StatusCode; @@ -12,11 +13,13 @@ use axum::response::{IntoResponse, Json}; use axum::routing::{get, post}; use cortex_core::discovery::{DiscoveryResponse, HealthResponse}; use cortex_core::harness::ModelSpec; -use cortex_core::openai::ChatCompletionRequest; +use cortex_core::openai::{ChatCompletionRequest, MessageContent}; +use cortex_core::responses::{ResponsesRequest, ResponsesUsage}; use futures::stream::{self, StreamExt}; use serde_json::{Value, json}; use std::convert::Infallible; use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; use tokio_stream::wrappers::ReceiverStream; @@ -44,6 +47,7 @@ pub fn neuron_routes() -> Router> { .route("/models/unload", post(unload_model)) .route("/models/{model_id}/endpoint", get(model_endpoint)) .route("/v1/chat/completions", post(chat_completions)) + .route("/v1/responses", post(responses)) } async fn discovery_handler(State(state): State>) -> Json { @@ -246,3 +250,187 @@ async fn chat_completions( } } } + +/// OpenAI Responses API (`POST /v1/responses`). Translates the +/// Responses-shaped request into a chat-completions one the candle +/// harness already understands, then re-projects the harness's +/// event stream into the Responses event family. +async fn responses( + State(state): State>, + Json(req): Json, +) -> impl IntoResponse { + let Some(candle) = state.candle.as_ref().map(Arc::clone) else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({"error": "candle harness not enabled on this neuron"})), + ) + .into_response(); + }; + + let stream_requested = req.stream; + let model_id = req.model.clone(); + let response_id = mint_response_id(); + let message_item_id = mint_message_item_id(); + + // Translate Responses → chat completions. The only failure + // mode today is `previous_response_id` set, which we reject + // with 400 — stateful conversations need a persistence layer + // we haven't built. 
+ let mut chat_req = match openai_responses::request_to_chat(req) { + Ok(r) => r, + Err(openai_responses::TranslateError::ChainedConversationNotSupported) => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "previous_response_id is not supported on this neuron", + "code": "chained_conversation_not_supported" + })), + ) + .into_response(); + } + }; + chat_req.stream = Some(stream_requested); + + if stream_requested { + match candle + .responses_stream(chat_req, response_id, message_item_id) + .await + { + Ok(rx) => { + // Each ResponseStreamFrame → one SSE event carrying + // both an event name and JSON data. The Responses + // API doesn't use a `[DONE]` terminator — clients + // see the `response.completed` event as the end of + // the stream. + let body_stream = ReceiverStream::new(rx).map(|frame| { + let body = serde_json::to_string(&frame.data).unwrap_or_else(|_| "{}".into()); + Ok::<_, Infallible>(Event::default().event(frame.event_name).data(body)) + }); + Sse::new(body_stream) + .keep_alive(KeepAlive::default()) + .into_response() + } + Err(e) => inference_error_response(e), + } + } else { + // Non-streaming: drive the existing chat completion path + // and translate the result. We don't currently re-tokenise + // to compute usage; the harness returns it via the chat + // response and we pass it through. + match candle.chat_completion(chat_req).await { + Ok(chat_resp) => { + // Extract the assistant text (chat completions + // always emits one choice on the candle path). + let text = chat_resp + .choices + .first() + .map(|c| match &c.message.content { + MessageContent::Text(t) => t.clone(), + MessageContent::Parts(_) => { + // Candle output is always text today; + // a Parts response would be surprising. + // Empty-string fallback is safer than + // a panic. 
+ String::new() + } + }) + .unwrap_or_default(); + let finish = chat_resp + .choices + .first() + .and_then(|c| c.finish_reason.as_deref()) + .map(finish_reason_from_str) + .unwrap_or(crate::wire::FinishReason::Stop); + let usage = chat_resp.usage.as_ref().map(|u| ResponsesUsage { + input_tokens: u.prompt_tokens, + output_tokens: u.completion_tokens, + total_tokens: u.prompt_tokens + u.completion_tokens, + }); + let meta = openai_responses::ResponseMeta { + response_id: mint_response_id(), + created_at: unix_now_secs(), + model_id, + message_item_id: mint_message_item_id(), + }; + let _ = chat_resp; // make the borrow-checker happy if `text` consumed it + let resp = openai_responses::build_response(&meta, text, finish, usage); + Json(resp).into_response() + } + Err(e) => inference_error_response(e), + } + } +} + +fn finish_reason_from_str(s: &str) -> crate::wire::FinishReason { + use crate::wire::FinishReason; + match s { + "length" => FinishReason::Length, + "tool_calls" => FinishReason::ToolCalls, + _ => FinishReason::Stop, + } +} + +/// Centralised mapping from [`InferenceError`] to an HTTP response. +/// Lifted out so the chat-completions and responses handlers stay +/// readable and changes to error-code semantics happen in one spot. 
+fn inference_error_response(err: InferenceError) -> axum::response::Response { + match err { + InferenceError::ModelNotLoaded(id) => ( + StatusCode::NOT_FOUND, + Json(json!({"error": format!("model '{id}' not loaded on this neuron")})), + ) + .into_response(), + InferenceError::PromptTooLong { prompt_len, max } => ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": format!("prompt has {prompt_len} tokens but max is {max}"), + "code": "prompt_too_long", + "prompt_len": prompt_len, + "max": max, + })), + ) + .into_response(), + InferenceError::InsufficientVram { + free_mb, + required_mb, + } => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ + "error": format!( + "insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB" + ), + "code": "insufficient_vram", + "free_mb": free_mb, + "required_mb": required_mb, + })), + ) + .into_response(), + InferenceError::Other(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": format!("{e:#}")})), + ) + .into_response(), + } +} + +fn mint_response_id() -> String { + format!("resp_{:x}", unix_subsec_nanos()) +} + +fn mint_message_item_id() -> String { + format!("msg_{:x}", unix_subsec_nanos()) +} + +fn unix_now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0) +} + +fn unix_subsec_nanos() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_nanos() as u64) + .unwrap_or(0) +} diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 8cd095d..30f4a71 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -1595,6 +1595,49 @@ impl CandleHarness { &self, request: ChatCompletionRequest, ) -> Result, InferenceError> { + let stream = self.inference_stream(request).await?; + Ok(wire_chat::project_chat_stream( + stream.events, + stream.id, + stream.created, + stream.model_id, + )) + } + + /// Streaming OpenAI Responses API entry point. 
Same harness + /// output as [`Self::chat_completion_stream`], projected into + /// the named-event SSE frames the Responses API client wants. + /// `response_id` and `message_item_id` are stamped into every + /// frame so the consumer can correlate. + pub async fn responses_stream( + &self, + request: ChatCompletionRequest, + response_id: String, + message_item_id: String, + ) -> Result, InferenceError> + { + let stream = self.inference_stream(request).await?; + let meta = crate::wire::openai_responses::ResponseMeta { + response_id, + created_at: stream.created, + model_id: stream.model_id, + message_item_id, + }; + Ok(crate::wire::openai_responses::project_responses_stream( + stream.events, + meta, + )) + } + + /// Format-agnostic streaming inference. Returns the raw + /// [`InferenceEvent`] receiver plus the per-request metadata + /// wire projectors stamp onto their frames. Lets every wire + /// format land on the same harness output without duplicating + /// setup / dispatch / spawn logic. + async fn inference_stream( + &self, + request: ChatCompletionRequest, + ) -> Result { let handle = { let models = self.models.read().await; models.get(&request.model).cloned() @@ -1608,7 +1651,7 @@ impl CandleHarness { LoadedHandle::Single(m) => m, #[cfg(feature = "cuda")] LoadedHandle::Tp(m) => { - return self.chat_completion_tp_stream(m, request).await; + return self.inference_tp_stream(m, request).await; } }; @@ -1807,16 +1850,39 @@ impl CandleHarness { ))); } - // Wrap the InferenceEvent receiver in the OpenAI chat - // projection so the HTTP handler keeps receiving - // ChatCompletionChunks bit-for-bit identical to before. - // The id/created/model_id snapshot taken at request setup - // gets stamped into every emitted chunk. - let rx = wire_chat::project_chat_stream(event_rx, id, created, model_id); - Ok(rx) + // Hand the raw event channel back to the public entry + // points (chat_completion_stream / responses_stream); they + // pick the wire projection. 
+ Ok(InferenceStream { + events: event_rx, + id, + created, + model_id, + }) } } +/// The seam between inference (one shape, always) and wire formats +/// (many shapes, projector-per-format). Public so the format +/// projectors live outside the harness and the harness's +/// streaming-inference internals stay encapsulated. +pub struct InferenceStream { + /// Stream of model-output events. Producers (the various + /// inference loops) emit on this; consumers (wire projectors) + /// read from it. + pub events: mpsc::Receiver, + /// Request id stamped into every wire-format frame + /// (`chatcmpl-…` for chat completions; the Responses path + /// makes its own `resp_…` id separately and ignores this one). + pub id: String, + /// Unix seconds when inference began. Same field threads into + /// every wire format's `created` / `created_at` slot. + pub created: u64, + /// Local model id (no endpoint prefix). Stamped into every + /// wire-format frame so consumers can correlate. + pub model_id: String, +} + #[async_trait] impl Harness for CandleHarness { fn name(&self) -> &str { @@ -2234,11 +2300,11 @@ impl CandleHarness { /// So we `tokio::spawn` the orchestration task and use plain /// `Sender::send`. #[cfg(feature = "cuda")] - async fn chat_completion_tp_stream( + async fn inference_tp_stream( &self, tp: Arc, request: ChatCompletionRequest, - ) -> Result, InferenceError> { + ) -> Result { if tp.poisoned.load(Ordering::Acquire) { return Err(poisoned_error(&request.model)); } @@ -2542,14 +2608,16 @@ impl CandleHarness { .instrument(span), ); - // Wrap the InferenceEvent receiver in the OpenAI chat - // projection so the HTTP handler keeps consuming - // ChatCompletionChunks unchanged. Uses the clones we - // stashed before the spawn — the originals were moved + // Hand the raw event channel back to the public entry + // points; they pick the wire projection. Uses the clones + // we stashed before the spawn — the originals were moved // into the orchestration task above. 
- let rx = - wire_chat::project_chat_stream(event_rx, projector_id, created, projector_model_id); - Ok(rx) + Ok(InferenceStream { + events: event_rx, + id: projector_id, + created, + model_id: projector_model_id, + }) } } diff --git a/crates/neuron/src/wire/mod.rs b/crates/neuron/src/wire/mod.rs index f91f164..f704279 100644 --- a/crates/neuron/src/wire/mod.rs +++ b/crates/neuron/src/wire/mod.rs @@ -19,5 +19,6 @@ pub mod event; pub mod openai_chat; +pub mod openai_responses; pub use event::{FinishReason, InferenceEvent}; diff --git a/crates/neuron/src/wire/openai_responses.rs b/crates/neuron/src/wire/openai_responses.rs new file mode 100644 index 0000000..be6b54d --- /dev/null +++ b/crates/neuron/src/wire/openai_responses.rs @@ -0,0 +1,847 @@ +//! OpenAI Responses API projection. +//! +//! Two responsibilities: +//! +//! 1. **Translate request shape**: [`request_to_chat`] flattens +//! [`ResponsesRequest`]'s typed `input` items + `instructions` +//! into the [`ChatCompletionRequest`] the candle harness already +//! knows how to run. The Responses-specific shape stops at this +//! function — everything downstream is the same chat path the +//! `/v1/chat/completions` route exercises. +//! +//! 2. **Project event stream**: [`project_responses_stream`] reads +//! [`InferenceEvent`]s from the harness and emits the named SSE +//! events the Responses API client expects +//! (`response.created`, `response.output_text.delta`, +//! `response.completed`, …) along with their JSON payloads. +//! The HTTP handler in [`crate::api`] reads +//! [`ResponseStreamFrame`]s off the receiver and stamps them +//! onto axum SSE frames. +//! +//! Scope cuts (carried over from [`cortex_core::responses`]): +//! +//! - `previous_response_id` is rejected by [`request_to_chat`] +//! with [`TranslateError::ChainedConversationNotSupported`]. +//! - `Reasoning` input items are dropped (no equivalent in chat). +//! - `FunctionCall` / `FunctionCallOutput` items round-trip but the +//! 
harness never emits tool calls today; the synthesis paths are +//! in place so the surface is ready when it does. + +use cortex_core::openai::{ChatCompletionRequest, ChatMessage, MessageContent}; +use cortex_core::responses::{ + ResponsesContentPart, ResponsesInput, ResponsesInputItem, ResponsesMessageContent, + ResponsesOutputContent, ResponsesOutputItem, ResponsesRequest, ResponsesResponse, + ResponsesUsage, events, +}; +use serde_json::{Value, json}; +use tokio::sync::mpsc; + +use super::event::{FinishReason, InferenceEvent}; + +/// Per-request metadata that has to be stamped into every emitted +/// event. The projector spawns a task that owns one of these. +#[derive(Debug, Clone)] +pub struct ResponseMeta { + pub response_id: String, + pub created_at: u64, + pub model_id: String, + /// Item id used inside `output[0]` (the message). All + /// `content_part.*` and `output_text.*` events reference this + /// so the consumer knows which item the delta belongs to. + pub message_item_id: String, +} + +/// Reasons [`request_to_chat`] refuses a request. +#[derive(Debug, thiserror::Error)] +pub enum TranslateError { + #[error( + "previous_response_id is not supported on this neuron; chained \ + conversations require server-side state we don't store yet" + )] + ChainedConversationNotSupported, +} + +/// Flatten a [`ResponsesRequest`] into the chat-completions shape +/// the candle harness already knows how to drive. Keeps the +/// Responses-specific machinery contained to a single function so +/// the harness stays format-agnostic. +/// +/// Semantics: +/// +/// - `instructions` (if set and non-empty) becomes a leading `system` message. +/// - `input: "<text>"` (plain string) becomes a single `user` message. +/// - `input: [items]` flattens each item: +/// - `Message { role, content }` → one `ChatMessage`. +/// - `FunctionCall` → an `assistant` turn whose `extra.tool_calls` +/// carries the call (chat-completions-shaped). 
The harness +/// doesn't act on tool_calls today, but the shape stays +/// consistent with what chat would expect. +/// - `FunctionCallOutput` → a `tool` role message with the +/// output text. Matches OpenAI's chat convention. +/// - `Reasoning` items are dropped (no equivalent in chat). +/// - Text parts within an array `content` collapse to a single +/// string; image parts get rendered as a chat-style content +/// array `[{type:"text"}, {type:"image_url"}]` so the chat +/// handler's existing vision path applies. +pub fn request_to_chat(req: ResponsesRequest) -> Result { + if req.previous_response_id.is_some() { + return Err(TranslateError::ChainedConversationNotSupported); + } + + let mut messages: Vec = Vec::new(); + + if let Some(instructions) = req.instructions + && !instructions.is_empty() + { + messages.push(ChatMessage { + role: "system".into(), + content: MessageContent::Text(instructions), + extra: Value::Object(Default::default()), + }); + } + + match req.input { + ResponsesInput::Text(text) => { + messages.push(ChatMessage { + role: "user".into(), + content: MessageContent::Text(text), + extra: Value::Object(Default::default()), + }); + } + ResponsesInput::Items(items) => { + for item in items { + if let Some(msg) = input_item_to_chat(item) { + messages.push(msg); + } + } + } + } + + Ok(ChatCompletionRequest { + model: req.model, + messages, + temperature: req.temperature, + top_p: req.top_p, + max_tokens: req.max_output_tokens, + stream: Some(req.stream), + extra: Value::Object(Default::default()), + }) +} + +fn input_item_to_chat(item: ResponsesInputItem) -> Option { + match item { + ResponsesInputItem::Message { role, content } => Some(ChatMessage { + role, + content: message_content_to_chat(content), + extra: Value::Object(Default::default()), + }), + ResponsesInputItem::FunctionCall { + call_id, + name, + arguments, + } => { + // Express the call in chat-completions shape via + // `extra.tool_calls`. 
The harness ignores it today but + // the shape is consistent for the day it doesn't. + let mut extra = serde_json::Map::new(); + extra.insert( + "tool_calls".into(), + json!([{ + "id": call_id, + "type": "function", + "function": { "name": name, "arguments": arguments }, + }]), + ); + Some(ChatMessage { + role: "assistant".into(), + content: MessageContent::Text(String::new()), + extra: Value::Object(extra), + }) + } + ResponsesInputItem::FunctionCallOutput { call_id, output } => { + let mut extra = serde_json::Map::new(); + extra.insert("tool_call_id".into(), Value::String(call_id)); + Some(ChatMessage { + role: "tool".into(), + content: MessageContent::Text(output), + extra: Value::Object(extra), + }) + } + // Reasoning items don't have a chat-completions equivalent + // we can faithfully forward. Silently drop — the alternative + // is rejecting a well-formed request, which is worse UX. + ResponsesInputItem::Reasoning { .. } => None, + } +} + +fn message_content_to_chat(content: ResponsesMessageContent) -> MessageContent { + match content { + ResponsesMessageContent::Text(s) => MessageContent::Text(s), + ResponsesMessageContent::Parts(parts) => { + // Collapse to a string when every part is text; emit + // the chat content-array shape only when an image is + // present (some upstreams treat the array form as a + // vision-only signal and reject it for text-only + // models). + let has_image = parts + .iter() + .any(|p| matches!(p, ResponsesContentPart::InputImage { .. })); + if !has_image { + let joined = parts + .into_iter() + .filter_map(|p| match p { + ResponsesContentPart::InputText { text } + | ResponsesContentPart::OutputText { text, .. } => Some(text), + ResponsesContentPart::InputImage { .. } => None, + }) + .collect::>() + .join("\n\n"); + return MessageContent::Text(joined); + } + let mut out: Vec = Vec::with_capacity(parts.len()); + for p in parts { + match p { + ResponsesContentPart::InputText { text } + | ResponsesContentPart::OutputText { text, .. 
} => { + out.push(json!({ "type": "text", "text": text })); + } + ResponsesContentPart::InputImage { image_url, .. } => { + out.push(json!({ + "type": "image_url", + "image_url": { "url": image_url }, + })); + } + } + } + MessageContent::Parts(out) + } + } +} + +// ── Streaming projection ───────────────────────────────────────────── + +/// One frame the projector emits. The HTTP handler maps each into +/// an axum `Sse::Event` with both an `event:` name and a `data:` +/// JSON payload — Responses, unlike chat completions, uses named +/// SSE events. +#[derive(Debug, Clone)] +pub struct ResponseStreamFrame { + pub event_name: &'static str, + pub data: Value, +} + +/// Project an [`InferenceEvent`] receiver into a stream of +/// [`ResponseStreamFrame`]s. The emitted sequence per stream is: +/// +/// 1. `response.created` — shell with `status: "in_progress"`. +/// 2. `response.output_item.added` — empty message item. +/// 3. `response.content_part.added` — empty `output_text` part. +/// 4. `response.output_text.delta` × N — token-by-token text. +/// 5. `response.output_text.done` — full accumulated text. +/// 6. `response.content_part.done` — full part payload. +/// 7. `response.output_item.done` — full message item. +/// 8. `response.completed` — final response with `status:"completed"`. +/// +/// Empty TextDeltas (the harness's incomplete-UTF-8 buffering) are +/// dropped. `ReasoningDelta`s have no representation in the +/// Responses API spec we model yet, so they're dropped too. 
+pub fn project_responses_stream(
+    rx: mpsc::Receiver<InferenceEvent>,
+    meta: ResponseMeta,
+) -> mpsc::Receiver<ResponseStreamFrame> {
+    // Bounded to 64 in-flight frames; the projection itself runs on a
+    // spawned task so the caller gets the receiving end immediately.
+    let (tx, out_rx) = mpsc::channel::<ResponseStreamFrame>(64);
+    tokio::spawn(async move {
+        run_projection(rx, meta, tx).await;
+    });
+    out_rx
+}
+
+/// Drives one projection: reads `rx` until the producer closes it and
+/// writes the corresponding frames to `tx`.
+///
+/// The opening frames are sent exactly once — on the first `Start`,
+/// before the first non-empty text delta if no `Start` was seen, or
+/// just before the closing frames if the producer sent neither. The
+/// closing frames are sent once `rx` is closed; a missing `Finish` is
+/// treated as [`FinishReason::Stop`]. Returns early, sending nothing
+/// further, as soon as a send on `tx` fails.
+async fn run_projection(
+    mut rx: mpsc::Receiver<InferenceEvent>,
+    meta: ResponseMeta,
+    tx: mpsc::Sender<ResponseStreamFrame>,
+) {
+    let mut accumulated = String::new();
+    let mut finish: Option<FinishReason> = None;
+    let mut emitted_start = false;
+
+    while let Some(event) = rx.recv().await {
+        match event {
+            InferenceEvent::Start => {
+                // A repeated Start must not re-open the response.
+                if !ensure_started(&tx, &meta, &mut emitted_start).await {
+                    return;
+                }
+            }
+            InferenceEvent::TextDelta(text) => {
+                if text.is_empty() {
+                    continue;
+                }
+                // A delta may arrive without a preceding Start; open
+                // the response first so no delta frame is ever sent
+                // ahead of the created / added frames.
+                if !ensure_started(&tx, &meta, &mut emitted_start).await {
+                    return;
+                }
+                accumulated.push_str(&text);
+                let frame = ResponseStreamFrame {
+                    event_name: events::OUTPUT_TEXT_DELTA,
+                    data: json!({
+                        "item_id": meta.message_item_id,
+                        "output_index": 0,
+                        "content_index": 0,
+                        "delta": text,
+                    }),
+                };
+                if tx.send(frame).await.is_err() {
+                    return;
+                }
+            }
+            InferenceEvent::ReasoningDelta(_) => {
+                // No representation in our Responses model yet.
+                // Stage where it'd land: a `response.reasoning_*`
+                // event family alongside `response.output_text.*`.
+            }
+            InferenceEvent::Finish { reason } => {
+                // Only recorded here; the closing frames go out after
+                // the producer closes the channel.
+                finish = Some(reason);
+            }
+        }
+    }
+
+    // Producers can drop without ever sending Start (e.g. early
+    // poisoned-model error). Synthesize the open frames so the
+    // consumer at least sees a coherent shell before completed.
+    if !ensure_started(&tx, &meta, &mut emitted_start).await {
+        return;
+    }
+
+    let reason = finish.unwrap_or(FinishReason::Stop);
+    let _ = emit_finish_frames(&tx, &meta, &accumulated, reason).await;
+}
+
+/// Sends the opening frames unless they have already gone out, and
+/// records that they have. Returns `false` when a send on `tx` fails.
+async fn ensure_started(
+    tx: &mpsc::Sender<ResponseStreamFrame>,
+    meta: &ResponseMeta,
+    emitted_start: &mut bool,
+) -> bool {
+    if *emitted_start {
+        return true;
+    }
+    *emitted_start = true;
+    emit_start_frames(tx, meta).await
+}
+
+/// Sends the three opening frames in order: `CREATED` (shell with
+/// status `"in_progress"` and no output), `OUTPUT_ITEM_ADDED` (empty
+/// message item) and `CONTENT_PART_ADDED` (empty `output_text` part).
+/// Returns `false` as soon as a send on `tx` fails.
+async fn emit_start_frames(tx: &mpsc::Sender<ResponseStreamFrame>, meta: &ResponseMeta) -> bool {
+    let shell = response_shell(meta, "in_progress", &[], None);
+    let frames = [
+        ResponseStreamFrame {
+            event_name: events::CREATED,
+            data: json!({ "response": shell }),
+        },
+        ResponseStreamFrame {
+            event_name: events::OUTPUT_ITEM_ADDED,
+            data: json!({
+                "output_index": 0,
+                "item": empty_message_item(&meta.message_item_id),
+            }),
+        },
+        ResponseStreamFrame {
+            event_name: events::CONTENT_PART_ADDED,
+            data: json!({
+                "item_id": meta.message_item_id,
+                "output_index": 0,
+                "content_index": 0,
+                "part": { "type": "output_text", "text": "", "annotations": [] },
+            }),
+        },
+    ];
+    for frame in frames {
+        if tx.send(frame).await.is_err() {
+            return false;
+        }
+    }
+    true
+}
+
+/// Sends the four closing frames in order: `OUTPUT_TEXT_DONE`,
+/// `CONTENT_PART_DONE`, `OUTPUT_ITEM_DONE` and `COMPLETED`, each
+/// carrying `full_text`. The message item and the final response share
+/// the status derived from `reason`. Returns `false` as soon as a send
+/// on `tx` fails.
+async fn emit_finish_frames(
+    tx: &mpsc::Sender<ResponseStreamFrame>,
+    meta: &ResponseMeta,
+    full_text: &str,
+    reason: FinishReason,
+) -> bool {
+    // NOTE(review): a `Length` finish still goes out under the
+    // COMPLETED event name, with status "incomplete" — confirm against
+    // the Responses streaming spec that this is the intended pairing.
+    let status = finish_to_status(reason);
+    let full_part = json!({
+        "type": "output_text",
+        "text": full_text,
+        "annotations": [],
+    });
+    let full_item = json!({
+        "type": "message",
+        "id": meta.message_item_id,
+        "role": "assistant",
+        "content": [full_part.clone()],
+        "status": status,
+    });
+    let frames = [
+        ResponseStreamFrame {
+            event_name: events::OUTPUT_TEXT_DONE,
+            data: json!({
+                "item_id": meta.message_item_id,
+                "output_index": 0,
+                "content_index": 0,
+                "text": full_text,
+            }),
+        },
+        ResponseStreamFrame {
+            event_name: events::CONTENT_PART_DONE,
+            data: json!({
+                "item_id": meta.message_item_id,
+                "output_index": 0,
+                "content_index": 0,
+                "part": full_part,
+            }),
+        },
+        ResponseStreamFrame {
+            event_name: events::OUTPUT_ITEM_DONE,
+            data: json!({
+                "output_index": 0,
+                "item": full_item.clone(),
+            }),
+        },
+        ResponseStreamFrame {
+            event_name: events::COMPLETED,
+            data: json!({
+                "response": response_shell(meta, status, &[full_item], None)
+            }),
+        },
+    ];
+    for frame in frames {
+        if tx.send(frame).await.is_err() {
+            return false;
+        }
+    }
+    true
+}
+
+/// Builds the JSON `response` object used in streamed frames: `id`,
+/// `object` (always `"response"`), `created_at`, `status`, `model` and
+/// `output`. A `usage` object is added only when `usage` is `Some`.
+fn response_shell(
+    meta: &ResponseMeta,
+    status: &str,
+    output: &[Value],
+    usage: Option<&ResponsesUsage>,
+) -> Value {
+    let mut obj = serde_json::Map::new();
+    obj.insert("id".into(), Value::String(meta.response_id.clone()));
+    obj.insert("object".into(), Value::String("response".into()));
+    obj.insert("created_at".into(), json!(meta.created_at));
+    obj.insert("status".into(), Value::String(status.into()));
+    obj.insert("model".into(), Value::String(meta.model_id.clone()));
+    obj.insert("output".into(), Value::Array(output.to_vec()));
+    if let Some(u) = usage {
+        obj.insert(
+            "usage".into(),
+            json!({
+                "input_tokens": u.input_tokens,
+                "output_tokens": u.output_tokens,
+                "total_tokens": u.total_tokens,
+            }),
+        );
+    }
+    Value::Object(obj)
+}
+
+/// JSON for an assistant message item with no content yet and status
+/// `"in_progress"`.
+fn empty_message_item(item_id: &str) -> Value {
+    json!({
+        "type": "message",
+        "id": item_id,
+        "role": "assistant",
+        "content": [],
+        "status": "in_progress",
+    })
+}
+
+/// Maps a finish reason to a status string: `Stop` and `ToolCalls`
+/// give `"completed"`, `Length` gives `"incomplete"`.
+fn finish_to_status(reason: FinishReason) -> &'static str {
+    match reason {
+        FinishReason::Stop | FinishReason::ToolCalls => "completed",
+        FinishReason::Length => "incomplete",
+    }
+}
+
+// ── Non-streaming helpers ────────────────────────────────────────────
+
+/// Collect a chat-completions response into a non-streaming
+/// [`ResponsesResponse`]. Used by the `/v1/responses` handler when
+/// the request doesn't set `stream: true`.
+pub fn build_response(
+    meta: &ResponseMeta,
+    full_text: String,
+    reason: FinishReason,
+    usage: Option<ResponsesUsage>,
+) -> ResponsesResponse {
+    // The response and its single message item carry the same status.
+    let status = finish_to_status(reason).to_string();
+    ResponsesResponse {
+        id: meta.response_id.clone(),
+        object: "response".into(),
+        created_at: meta.created_at,
+        status: status.clone(),
+        model: meta.model_id.clone(),
+        output: vec![ResponsesOutputItem::Message {
+            id: meta.message_item_id.clone(),
+            role: "assistant".into(),
+            content: vec![ResponsesOutputContent::OutputText {
+                text: full_text,
+                annotations: vec![],
+            }],
+            status,
+        }],
+        usage,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use cortex_core::openai::MessageContent;
+
+    /// Fixed metadata shared by the projector and builder tests.
+    fn meta() -> ResponseMeta {
+        ResponseMeta {
+            response_id: "resp_1".into(),
+            created_at: 1700,
+            model_id: "m".into(),
+            message_item_id: "msg_1".into(),
+        }
+    }
+
+    // ── request translator ──────────────────────────────────────────
+
+    #[test]
+    fn translates_text_input_to_single_user_message() {
+        let req = ResponsesRequest {
+            model: "m".into(),
+            input: ResponsesInput::Text("hi".into()),
+            instructions: None,
+            stream: false,
+            max_output_tokens: None,
+            temperature: None,
+            top_p: None,
+            previous_response_id: None,
+            extra: Value::Object(Default::default()),
+        };
+        let chat = request_to_chat(req).unwrap();
+        assert_eq!(chat.messages.len(), 1);
+        assert_eq!(chat.messages[0].role, "user");
+        assert!(matches!(
+            &chat.messages[0].content,
+            MessageContent::Text(t) if t == "hi"
+        ));
+    }
+
+    #[test]
+    fn instructions_become_leading_system_message() {
+        let req = ResponsesRequest {
+            model: "m".into(),
+            input: ResponsesInput::Text("hi".into()),
+            instructions: Some("you are helpful".into()),
+            stream: false,
+            max_output_tokens: None,
+            temperature: None,
+            top_p: None,
+            previous_response_id: None,
+            extra: Value::Object(Default::default()),
+        };
+        let chat = request_to_chat(req).unwrap();
+        assert_eq!(chat.messages.len(), 2);
+        assert_eq!(chat.messages[0].role, "system");
+        assert!(matches!(
+            &chat.messages[0].content,
+            MessageContent::Text(t) if t == "you are helpful"
+        ));
+        assert_eq!(chat.messages[1].role, "user");
+    }
+
+    #[test]
+    fn rejects_previous_response_id() {
+        let req = ResponsesRequest {
+            model: "m".into(),
+            input: ResponsesInput::Text("hi".into()),
+            instructions: None,
+            stream: false,
+            max_output_tokens: None,
+            temperature: None,
+            top_p: None,
+            previous_response_id: Some("resp_prev".into()),
+            extra: Value::Object(Default::default()),
+        };
+        assert!(matches!(
+            request_to_chat(req),
+            Err(TranslateError::ChainedConversationNotSupported)
+        ));
+    }
+
+    #[test]
+    fn translates_input_items_to_chat_messages() {
+        let req = ResponsesRequest {
+            model: "m".into(),
+            input: ResponsesInput::Items(vec![
+                ResponsesInputItem::Message {
+                    role: "user".into(),
+                    content: ResponsesMessageContent::Text("first".into()),
+                },
+                ResponsesInputItem::Message {
+                    role: "assistant".into(),
+                    content: ResponsesMessageContent::Text("reply".into()),
+                },
+                ResponsesInputItem::Message {
+                    role: "user".into(),
+                    content: ResponsesMessageContent::Text("second".into()),
+                },
+            ]),
+            instructions: None,
+            stream: false,
+            max_output_tokens: None,
+            temperature: None,
+            top_p: None,
+            previous_response_id: None,
+            extra: Value::Object(Default::default()),
+        };
+        let chat = request_to_chat(req).unwrap();
+        assert_eq!(chat.messages.len(), 3);
+        let roles: Vec<&str> = chat.messages.iter().map(|m| m.role.as_str()).collect();
+        assert_eq!(roles, vec!["user", "assistant", "user"]);
+    }
+
+    #[test]
+    fn image_input_translates_to_chat_parts_array() {
+        let req = ResponsesRequest {
+            model: "m".into(),
+            input: ResponsesInput::Items(vec![ResponsesInputItem::Message {
+                role: "user".into(),
+                content: ResponsesMessageContent::Parts(vec![
+                    ResponsesContentPart::InputText {
+                        text: "what is this?".into(),
+                    },
+                    ResponsesContentPart::InputImage {
+                        image_url: "data:image/png;base64,AAA=".into(),
+                        detail: None,
+                    },
+                ]),
+            }]),
+            instructions: None,
+            stream: false,
+            max_output_tokens: None,
+            temperature: None,
+            top_p: None,
+            previous_response_id: None,
+            extra: Value::Object(Default::default()),
+        };
+        let chat = request_to_chat(req).unwrap();
+        let parts = match &chat.messages[0].content {
+            MessageContent::Parts(p) => p.clone(),
+            other => panic!("expected Parts, got {other:?}"),
+        };
+        assert_eq!(parts.len(), 2);
+        assert_eq!(parts[0]["type"], "text");
+        assert_eq!(parts[1]["type"], "image_url");
+        assert_eq!(parts[1]["image_url"]["url"], "data:image/png;base64,AAA=");
+    }
+
+    #[test]
+    fn text_only_parts_collapse_to_string() {
+        let req = ResponsesRequest {
+            model: "m".into(),
+            input: ResponsesInput::Items(vec![ResponsesInputItem::Message {
+                role: "user".into(),
+                content: ResponsesMessageContent::Parts(vec![
+                    ResponsesContentPart::InputText {
+                        text: "first".into(),
+                    },
+                    ResponsesContentPart::InputText {
+                        text: "second".into(),
+                    },
+                ]),
+            }]),
+            instructions: None,
+            stream: false,
+            max_output_tokens: None,
+            temperature: None,
+            top_p: None,
+            previous_response_id: None,
+            extra: Value::Object(Default::default()),
+        };
+        let chat = request_to_chat(req).unwrap();
+        assert!(matches!(
+            &chat.messages[0].content,
+            MessageContent::Text(t) if t == "first\n\nsecond"
+        ));
+    }
+
+    #[test]
+    fn reasoning_items_are_silently_dropped() {
+        let req = ResponsesRequest {
+            model: "m".into(),
+            input: ResponsesInput::Items(vec![
+                ResponsesInputItem::Reasoning { content: vec![] },
+                ResponsesInputItem::Message {
+                    role: "user".into(),
+                    content: ResponsesMessageContent::Text("hi".into()),
+                },
+            ]),
+            instructions: None,
+            stream: false,
+            max_output_tokens: None,
+            temperature: None,
+            top_p: None,
+            previous_response_id: None,
+            extra: Value::Object(Default::default()),
+        };
+        let chat = request_to_chat(req).unwrap();
+        assert_eq!(chat.messages.len(), 1);
+        assert_eq!(chat.messages[0].role, "user");
+    }
+
+    // ── streaming projector ─────────────────────────────────────────
+
+    /// Drains `rx` until the projector closes it, returning every
+    /// frame in arrival order.
+    async fn collect(mut rx: mpsc::Receiver<ResponseStreamFrame>) -> Vec<ResponseStreamFrame> {
+        let mut out = Vec::new();
+        while let Some(f) = rx.recv().await {
+            out.push(f);
+        }
+        out
+    }
+
+    #[tokio::test]
+    async fn full_stream_emits_expected_event_sequence() {
+        let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
+        let out = project_responses_stream(rx, meta());
+
+        tx.send(InferenceEvent::Start).await.unwrap();
+        tx.send(InferenceEvent::TextDelta("hel".into()))
+            .await
+            .unwrap();
+        tx.send(InferenceEvent::TextDelta("lo".into()))
+            .await
+            .unwrap();
+        tx.send(InferenceEvent::Finish {
+            reason: FinishReason::Stop,
+        })
+        .await
+        .unwrap();
+        drop(tx);
+
+        let frames = collect(out).await;
+        let names: Vec<&str> = frames.iter().map(|f| f.event_name).collect();
+        assert_eq!(
+            names,
+            vec![
+                events::CREATED,
+                events::OUTPUT_ITEM_ADDED,
+                events::CONTENT_PART_ADDED,
+                events::OUTPUT_TEXT_DELTA,
+                events::OUTPUT_TEXT_DELTA,
+                events::OUTPUT_TEXT_DONE,
+                events::CONTENT_PART_DONE,
+                events::OUTPUT_ITEM_DONE,
+                events::COMPLETED,
+            ]
+        );
+
+        // The two deltas should carry the right text.
+        assert_eq!(frames[3].data["delta"], "hel");
+        assert_eq!(frames[4].data["delta"], "lo");
+
+        // The done event has the full accumulated text.
+        assert_eq!(frames[5].data["text"], "hello");
+
+        // Completed event carries the full message item.
+        let completed = &frames[8].data["response"];
+        assert_eq!(completed["status"], "completed");
+        let output = completed["output"].as_array().unwrap();
+        assert_eq!(output.len(), 1);
+        assert_eq!(output[0]["content"][0]["text"], "hello");
+    }
+
+    #[tokio::test]
+    async fn length_finish_maps_to_incomplete_status() {
+        let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
+        let out = project_responses_stream(rx, meta());
+        tx.send(InferenceEvent::Start).await.unwrap();
+        tx.send(InferenceEvent::Finish {
+            reason: FinishReason::Length,
+        })
+        .await
+        .unwrap();
+        drop(tx);
+        let frames = collect(out).await;
+        let completed = frames
+            .iter()
+            .find(|f| f.event_name == events::COMPLETED)
+            .unwrap();
+        assert_eq!(completed.data["response"]["status"], "incomplete");
+    }
+
+    #[tokio::test]
+    async fn synthesises_start_frames_when_producer_skips_start() {
+        // A producer that drops without sending Start (poisoned
+        // model, immediate disconnect, …) should still produce a
+        // coherent stream — the projector synthesises the
+        // mandatory header frames before COMPLETED so the
+        // consumer never sees an output_text.done without a
+        // matching content_part.added.
+        let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
+        let out = project_responses_stream(rx, meta());
+        drop(tx);
+        let frames = collect(out).await;
+        let names: Vec<&str> = frames.iter().map(|f| f.event_name).collect();
+        assert!(names.contains(&events::CREATED));
+        assert!(names.contains(&events::COMPLETED));
+        assert!(names.contains(&events::OUTPUT_TEXT_DONE));
+    }
+
+    #[tokio::test]
+    async fn empty_text_deltas_are_dropped() {
+        let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
+        let out = project_responses_stream(rx, meta());
+        tx.send(InferenceEvent::Start).await.unwrap();
+        tx.send(InferenceEvent::TextDelta(String::new()))
+            .await
+            .unwrap();
+        tx.send(InferenceEvent::TextDelta("real".into()))
+            .await
+            .unwrap();
+        tx.send(InferenceEvent::Finish {
+            reason: FinishReason::Stop,
+        })
+        .await
+        .unwrap();
+        drop(tx);
+        let frames = collect(out).await;
+        let delta_count = frames
+            .iter()
+            .filter(|f| f.event_name == events::OUTPUT_TEXT_DELTA)
+            .count();
+        assert_eq!(delta_count, 1, "empty delta must not produce a frame");
+    }
+
+    // ── non-streaming builder ───────────────────────────────────────
+
+    #[test]
+    fn build_response_produces_completed_message_with_usage() {
+        let r = build_response(
+            &meta(),
+            "hello".into(),
+            FinishReason::Stop,
+            Some(ResponsesUsage {
+                input_tokens: 5,
+                output_tokens: 1,
+                total_tokens: 6,
+            }),
+        );
+        assert_eq!(r.status, "completed");
+        match &r.output[0] {
+            ResponsesOutputItem::Message {
+                role,
+                content,
+                status,
+                ..
+            } => {
+                assert_eq!(role, "assistant");
+                assert_eq!(status, "completed");
+                match &content[0] {
+                    ResponsesOutputContent::OutputText { text, .. } => {
+                        assert_eq!(text, "hello");
+                    }
+                }
+            }
+            other => panic!("expected Message, got {other:?}"),
+        }
+        let u = r.usage.unwrap();
+        assert_eq!(u.total_tokens, 6);
+    }
+
+    #[test]
+    fn build_response_length_yields_incomplete_status() {
+        let r = build_response(&meta(), "trunc".into(), FinishReason::Length, None);
+        assert_eq!(r.status, "incomplete");
+    }
+}
diff --git a/crates/neuron/tests/api.rs b/crates/neuron/tests/api.rs
index 86b20af..eca3b22 100644
--- a/crates/neuron/tests/api.rs
+++ b/crates/neuron/tests/api.rs
@@ -322,3 +322,168 @@ async fn test_chat_completions_streaming_model_not_loaded() {
         .unwrap();
     assert_eq!(resp.status(), 404);
 }
+
+// ── /v1/responses ────────────────────────────────────────────────────
+
+/// `/v1/responses` returns 503 when no candle harness is registered —
+/// matches the chat-completions error shape so a client can swap
+/// endpoints without re-handling 503s.
+#[tokio::test]
+async fn test_responses_no_candle_harness() {
+    let registry = HarnessRegistry::new();
+    let health_cache = Arc::new(HealthCache::new());
+    let state = Arc::new(NeuronState {
+        discovery: fake_discovery(),
+        health_cache,
+        registry: RwLock::new(registry),
+        candle: None,
+        activation: Arc::new(ActivationTracker::new(&[])),
+    });
+    let app = api::neuron_routes().with_state(state);
+    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let addr = listener.local_addr().unwrap();
+    tokio::spawn(async move {
+        axum::serve(listener, app).await.unwrap();
+    });
+    let url = format!("http://{addr}");
+
+    let resp = reqwest::Client::new()
+        .post(format!("{url}/v1/responses"))
+        .json(&json!({"model": "anything", "input": "hi"}))
+        .send()
+        .await
+        .unwrap();
+    assert_eq!(resp.status(), 503);
+}
+
+/// `previous_response_id` is rejected at translate time with 400 —
+/// we don't store responses server-side yet, so chained
+/// conversations can't be honoured.
+#[tokio::test]
+async fn test_responses_rejects_previous_response_id() {
+    use cortex_core::harness::HarnessConfig;
+    use neuron::config::HarnessSettings;
+
+    // State: a registry built from a single "candle" harness config,
+    // with the candle handle taken from that registry.
+    let harnesses = [HarnessConfig {
+        name: "candle".into(),
+    }];
+    let settings = HarnessSettings::default();
+    let registry = HarnessRegistry::from_configs(&harnesses, "http://localhost:0", &settings);
+    let candle = registry.candle();
+    let state = Arc::new(NeuronState {
+        discovery: fake_discovery(),
+        health_cache: Arc::new(HealthCache::new()),
+        registry: RwLock::new(registry),
+        candle,
+        activation: Arc::new(ActivationTracker::new(&[])),
+    });
+
+    // Serve the routes on an OS-assigned loopback port.
+    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let addr = listener.local_addr().unwrap();
+    let router = api::neuron_routes().with_state(state);
+    tokio::spawn(async move {
+        axum::serve(listener, router).await.unwrap();
+    });
+    let url = format!("http://{addr}");
+
+    // A request that names a previous response must come back as 400
+    // with the dedicated error code.
+    let payload = json!({
+        "model": "anything",
+        "input": "hi",
+        "previous_response_id": "resp_prev_42"
+    });
+    let endpoint = format!("{url}/v1/responses");
+    let resp = reqwest::Client::new()
+        .post(endpoint)
+        .json(&payload)
+        .send()
+        .await
+        .unwrap();
+    assert_eq!(resp.status(), 400);
+    let body: serde_json::Value = resp.json().await.unwrap();
+    assert_eq!(body["code"], "chained_conversation_not_supported");
+}
+
+/// `/v1/responses` returns 404 when the model isn't loaded — same
+/// surface as chat completions.
+#[tokio::test]
+async fn test_responses_model_not_loaded() {
+    use cortex_core::harness::HarnessConfig;
+    use neuron::config::HarnessSettings;
+
+    // State: candle harness registered, but no model is loaded by
+    // this test before the request is made.
+    let candle_config = HarnessConfig {
+        name: "candle".into(),
+    };
+    let registry = HarnessRegistry::from_configs(
+        &[candle_config],
+        "http://localhost:0",
+        &HarnessSettings::default(),
+    );
+    let candle = registry.candle();
+    let activation = Arc::new(ActivationTracker::new(&[]));
+    let health_cache = Arc::new(HealthCache::new());
+    let state = Arc::new(NeuronState {
+        discovery: fake_discovery(),
+        health_cache,
+        registry: RwLock::new(registry),
+        candle,
+        activation,
+    });
+
+    // Serve the routes on an OS-assigned loopback port.
+    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let addr = listener.local_addr().unwrap();
+    let router = api::neuron_routes().with_state(state);
+    tokio::spawn(async move {
+        axum::serve(listener, router).await.unwrap();
+    });
+    let url = format!("http://{addr}");
+
+    let client = reqwest::Client::new();
+    let request = client
+        .post(format!("{url}/v1/responses"))
+        .json(&json!({"model": "not-loaded", "input": "hi"}));
+    let resp = request.send().await.unwrap();
+    assert_eq!(resp.status(), 404);
+}
+
+/// Same model-not-loaded surface on the streaming path. The
+/// stream is opened only after model lookup succeeds, so a
+/// missing model fails fast with a non-SSE 404 response.
+#[tokio::test]
+async fn test_responses_streaming_model_not_loaded() {
+    use cortex_core::harness::HarnessConfig;
+    use neuron::config::HarnessSettings;
+
+    // State: candle harness registered, but no model is loaded by
+    // this test before the request is made.
+    let settings = HarnessSettings::default();
+    let configs = [HarnessConfig {
+        name: "candle".into(),
+    }];
+    let registry = HarnessRegistry::from_configs(&configs, "http://localhost:0", &settings);
+    let candle = registry.candle();
+    let state = Arc::new(NeuronState {
+        discovery: fake_discovery(),
+        health_cache: Arc::new(HealthCache::new()),
+        registry: RwLock::new(registry),
+        candle,
+        activation: Arc::new(ActivationTracker::new(&[])),
+    });
+
+    // Serve the routes on an OS-assigned loopback port.
+    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let addr = listener.local_addr().unwrap();
+    let router = api::neuron_routes().with_state(state);
+    tokio::spawn(async move {
+        axum::serve(listener, router).await.unwrap();
+    });
+    let url = format!("http://{addr}");
+
+    // Same request as the non-streaming case, but with `stream: true`.
+    let payload = json!({
+        "model": "not-loaded",
+        "input": "hi",
+        "stream": true
+    });
+    let resp = reqwest::Client::new()
+        .post(format!("{url}/v1/responses"))
+        .json(&payload)
+        .send()
+        .await
+        .unwrap();
+    assert_eq!(resp.status(), 404);
+}