From 766c20ba47f1d5e17e4cf8769f5447b37eaf25b8 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Thu, 4 Jun 2026 13:57:02 +0300 Subject: [PATCH] =?UTF-8?q?feat(neuron):=20C1=20=E2=80=94=20streaming=20SS?= =?UTF-8?q?E=20chat=20completion=20with=20vision?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The streaming worker path now splices image embeddings on prefill, closing the silent text-only degradation for `stream=true` image requests. `inference_stream` gains the same vision-routing block as the non-streaming `chat_completion`: detect `image_url` content, reject it against text-only models with `VisionUnsupported` (before any SSE frame is sent), preprocess each image and expand its `<|image_pad|>` sentinel to the per-image patch count, then carry the payload through dispatch. Rather than duplicate the 75-line `route_token!` reasoning/tool-call state machine into a sibling streamer, `stream_inference_via_worker` takes an `Option<(Vec, u32)>`: when `Some`, prefill is a single-shot `forward_logits_with_images` splice; when `None`, it runs the original chunked text-only prefill. Image embeddings are prefill-only, so every decode step stays on the plain `forward_logits` path and the shared decode loop is untouched. This keeps exactly one copy of the tool-call/reasoning logic to maintain. The Responses API streaming path (`responses_stream`) inherits vision for free since it drives the same `inference_stream`. A unit test covers `request_has_images` (the shared routing gate); the real-weights SSE smoke test is the manual curl on beast (cuda-integration). Closes part of #16 (Stage C1). 
Co-Authored-By: Claude Opus 4.8 (1M context) --- crates/neuron/src/harness/candle.rs | 112 ++++++++++++++++++++++++++-- 1 file changed, 106 insertions(+), 6 deletions(-) diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 324e860..701f323 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -1981,7 +1981,50 @@ impl CandleHarness { .tokenizer .encode(prompt.as_str(), true) .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?; - let prompt_tokens: Vec = encoding.get_ids().to_vec(); + let mut prompt_tokens: Vec = encoding.get_ids().to_vec(); + + // Stage C1: vision routing for the streaming path. Mirrors the + // non-streaming `chat_completion` block — detect image content, + // reject it against text-only models, preprocess each image and + // expand its `<|image_pad|>` sentinel to the per-image patch + // count, then carry the payload through to a single-shot + // image-spliced prefill. Non-image requests skip all of this. + // Returning early here (before the `Start` event below) keeps a + // rejected vision request from opening a half-formed SSE stream. 
+ let vision_route: Option<(Vec, u32)> = + if request_has_images(&request) { + if !loaded.has_vision { + return Err(InferenceError::VisionUnsupported { + model_id: request.model.clone(), + }); + } + let image_token_id = + loaded + .image_token_id + .ok_or_else(|| InferenceError::VisionUnsupported { + model_id: request.model.clone(), + })?; + let patches_per_image = loaded.lm_tokens_per_image.ok_or_else(|| { + InferenceError::VisionUnsupported { + model_id: request.model.clone(), + } + })?; + let profile = super::preprocess::PreprocessProfile::qwen3_6(); + let images = extract_images_from_request(&request, &profile) + .map_err(|e| InferenceError::Other(anyhow::anyhow!("extract_images: {e}")))?; + if images.is_empty() { + return Err(InferenceError::Other(anyhow::anyhow!( + "request has image content but extractor produced zero images" + ))); + } + let per_image_counts: Vec = vec![patches_per_image; images.len()]; + prompt_tokens = + expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts) + .map_err(InferenceError::Other)?; + Some((images, image_token_id)) + } else { + None + }; let temperature = request.temperature.unwrap_or(0.7); let top_p = request.top_p; @@ -2048,6 +2091,7 @@ impl CandleHarness { ?eos_id, vram_free_mb, vram_total_mb, + vision = vision_route.is_some(), "chat_completion (stream): starting" ); } @@ -2078,6 +2122,7 @@ impl CandleHarness { handle, tokenizer, prompt_tokens, + vision_route, max_new, temperature, top_p, @@ -4046,6 +4091,17 @@ async fn run_inference_via_worker( /// forward step through `worker.forward_logits()`. Same per-step /// CPU-side sampling discipline — no device tensor escapes the /// worker thread. +/// +/// `images` carries the Stage C vision payload. 
When `Some`, prefill +/// is a single-shot `forward_logits_with_images` that splices image +/// embeddings at `image_token_id` positions (same contract as the +/// non-streaming [`run_inference_with_images_via_worker`]); image +/// embeddings are prefill-only, so every decode step below takes the +/// plain `forward_logits` path regardless. When `None`, prefill is +/// chunked (`chunked_prefill_via_worker`) to bound activation memory +/// — the original text-only behaviour, unchanged. The decode loop and +/// the `route_token!` reasoning/tool-call state machine are shared +/// across both prefill shapes, so there's exactly one copy to maintain. #[cfg(feature = "cuda")] #[allow(clippy::too_many_arguments)] async fn stream_inference_via_worker( @@ -4053,6 +4109,7 @@ async fn stream_inference_via_worker( handle: super::device_worker::ArchHandle, tokenizer: Tokenizer, prompt_tokens: Vec, + images: Option<(Vec, u32)>, max_new: usize, temperature: f64, top_p: Option, @@ -4098,11 +4155,19 @@ async fn stream_inference_via_worker( .await .map_err(|e| anyhow::anyhow!("clear_kv_cache: {e}"))?; - // Chunked prefill (see `chunked_prefill_via_worker`). The owning - // `prompt_tokens: Vec` is borrowed for the loop's duration; - // we still need `prompt_len` (already extracted above) for the - // decode-step offset arithmetic. - let logits_vec = chunked_prefill_via_worker(&*worker, handle, &prompt_tokens).await?; + // Prefill. Vision-bearing requests (`images = Some`) do a + // single-shot prefill that splices the image embeddings; text-only + // requests use chunked prefill (see `chunked_prefill_via_worker`) + // to bound activation memory. Either way the owning + // `prompt_tokens: Vec` outlives this step; we use `prompt_len` + // (already extracted above) for the decode-step offset arithmetic. 
+ let logits_vec = match images { + Some((imgs, image_token_id)) => worker + .forward_logits_with_images(handle, prompt_tokens.clone(), 0, imgs, image_token_id) + .await + .map_err(|e| anyhow::anyhow!("forward_logits_with_images: {e}"))?, + None => chunked_prefill_via_worker(&*worker, handle, &prompt_tokens).await?, + }; let logits = Tensor::new(logits_vec.as_slice(), &Device::Cpu)?; let mut next_token = match sample_with_penalty(&logits, &all_tokens, &mut logits_processor) { Ok(t) => t, @@ -4699,4 +4764,39 @@ mod tests { let out = expand_image_pad_tokens(&input, pad, &[]).unwrap(); assert_eq!(out, input); } + + /// `request_has_images` is the gate that routes both the + /// non-streaming (`chat_completion`) and streaming + /// (`inference_stream`, Stage C1) paths to the vision-aware + /// prefill. Exercise the three shapes it must distinguish: plain + /// text, a text-only content-parts array, and a parts array + /// carrying an `image_url`. + #[test] + fn request_has_images_detects_image_url_parts() { + let text_only: ChatCompletionRequest = serde_json::from_value(serde_json::json!({ + "model": "m", + "messages": [{"role": "user", "content": "hello"}], + })) + .unwrap(); + assert!(!request_has_images(&text_only)); + + let parts_text_only: ChatCompletionRequest = serde_json::from_value(serde_json::json!({ + "model": "m", + "messages": [{"role": "user", "content": [ + {"type": "text", "text": "hello"} + ]}], + })) + .unwrap(); + assert!(!request_has_images(&parts_text_only)); + + let with_image: ChatCompletionRequest = serde_json::from_value(serde_json::json!({ + "model": "m", + "messages": [{"role": "user", "content": [ + {"type": "text", "text": "what is this?"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAA="}} + ]}], + })) + .unwrap(); + assert!(request_has_images(&with_image)); + } }