feat(neuron): OpenAI-compatible SSE streaming chat completions
Stage 4 of the candle-native pivot. /v1/chat/completions now switches
to text/event-stream when the request sets stream: true, emitting one
chat.completion.chunk per generated token followed by the OpenAI
[DONE] terminator.
Pipeline:
- chat_completion_stream creates a bounded mpsc::channel<ChatCompletionChunk>(32),
sends the leading role chunk, then spawns a blocking task that
acquires the per-model arch lock and runs the streaming generation
loop.
- run_inference_streaming tracks a cumulative decoded prefix so each
chunk's delta.content is the substring added since the last chunk —
safe across BPE byte-fallback boundaries that would otherwise split
multi-byte UTF-8 chars.
- The blocking task aborts cleanly if blocking_send fails (client
disconnected), so generation stops when the SSE consumer hangs up.
- The final chunk carries finish_reason ("stop" on EOS, "length" when
max_tokens is reached). The handler appends data: [DONE] after the
channel closes.
The Stage 3 streaming 501 placeholder test is repurposed: with the
streaming path live, a streaming request for an unloaded model now
returns the same 404 response as the non-streaming path (the model
lookup happens first).
cortex-gateway's existing proxy is unchanged — since the Phase 2 work
it already forwards SSE bytes verbatim, so the candle SSE format passes
through unmodified.
Neuron Cargo.toml gains futures + tokio-stream (both already in
workspace deps) for ReceiverStream and stream combinators.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -6,14 +6,18 @@ use crate::health::HealthCache;
|
||||
use axum::Router;
|
||||
use axum::extract::{Path, State};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use axum::response::{IntoResponse, Json};
|
||||
use axum::routing::{get, post};
|
||||
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
||||
use cortex_core::harness::ModelSpec;
|
||||
use cortex_core::openai::ChatCompletionRequest;
|
||||
use futures::stream::{self, StreamExt};
|
||||
use serde_json::{Value, json};
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
/// Shared state for the neuron HTTP server.
|
||||
pub struct NeuronState {
|
||||
@@ -110,8 +114,9 @@ async fn model_endpoint(
|
||||
}
|
||||
}
|
||||
|
||||
/// OpenAI-compatible chat completions. Non-streaming for Stage 3; the
|
||||
/// streaming path is added in Stage 4.
|
||||
/// OpenAI-compatible chat completions. Dispatches to streaming SSE when
|
||||
/// `stream: true` is set on the request; otherwise returns a single
|
||||
/// `ChatCompletionResponse`.
|
||||
async fn chat_completions(
|
||||
State(state): State<Arc<NeuronState>>,
|
||||
Json(req): Json<ChatCompletionRequest>,
|
||||
@@ -125,24 +130,44 @@ async fn chat_completions(
|
||||
};
|
||||
|
||||
if req.stream.unwrap_or(false) {
|
||||
return (
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
Json(json!({"error": "streaming responses arrive in Stage 4"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
match candle.chat_completion(req).await {
|
||||
Ok(resp) => Json(resp).into_response(),
|
||||
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::Other(e)) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": e.to_string()})),
|
||||
)
|
||||
.into_response(),
|
||||
match candle.chat_completion_stream(req).await {
|
||||
Ok(rx) => {
|
||||
// Each chunk → one SSE `data: {json}` line. After the
|
||||
// channel closes, append the OpenAI [DONE] terminator.
|
||||
let body_stream = ReceiverStream::new(rx).map(|chunk| {
|
||||
let body = serde_json::to_string(&chunk).unwrap_or_default();
|
||||
Ok::<_, Infallible>(Event::default().data(body))
|
||||
});
|
||||
let done_stream =
|
||||
stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) });
|
||||
Sse::new(body_stream.chain(done_stream))
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
}
|
||||
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::Other(e)) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": e.to_string()})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
} else {
|
||||
match candle.chat_completion(req).await {
|
||||
Ok(resp) => Json(resp).into_response(),
|
||||
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::Other(e)) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": e.to_string()})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user