From 4994b94c84473eea7c8180c9deb6cf9c9893de2c Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Thu, 4 Jun 2026 15:08:08 +0300 Subject: [PATCH] =?UTF-8?q?feat(neuron):=20TP-vision=20Stage=202=20?= =?UTF-8?q?=E2=80=94=20per-rank=20image=20RPC=20+=20worker=20plumbing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Carry image content through the TP forward path so every rank encodes and splices locally (replicated tower, no embedding broadcast). - rpc.rs: new WorkerRequest::GenerateStepWithImages carrying the source image data URIs + image_token_id for the single-shot vision prefill; worker still replies GenerateStepOk. Round-trip test added. - tp_qwen3_5.rs: TpQwen3_5ForCausalLM::forward_with_images — encode each preprocessed image through the rank's replicated tower, cat, splice, forward. Shared by leader and worker so every rank runs identical work. - tp/mod.rs: TpLeaderModel::forward_with_images and WorkerPool::generate_step_with_images (mirrors generate_step: fan out GenerateStepWithImages to subprocess ranks, run the leader's image forward on its device worker thread, drain, combine). - worker.rs: WorkerModel::forward_with_images + handle_generate_step_with_images — each subprocess rank preprocesses the same data URIs via the shared deterministic preprocess_data_uri, encodes, splices, forwards. - device_worker: Job::TpForwardLogitsWithImages + tp_forward_logits_with_images dispatch handler + DeviceWorkerHandle::tp_forward_logits_with_images. Determinism: every rank runs the same preprocess on the same source URIs through the same replicated tower, so the spliced hidden state matches across ranks — preserving the replicated-hidden-state invariant the row-parallel AllReduce relies on, with no NCCL broadcast. No caller yet — Stage 3 wires the TP chat/stream entry points to invoke generate_step_with_images for image prefill. cuda-gated plumbing covered by CI's CUDA type-check; rpc/route/forward_with_images compile on the non-cuda build. Refs TP-vision plan Stage 2. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../src/harness/device_worker/dispatch.rs | 74 +++++++++ .../neuron/src/harness/device_worker/jobs.rs | 17 ++ .../neuron/src/harness/device_worker/mod.rs | 41 +++++ crates/neuron/src/harness/tp/mod.rs | 142 ++++++++++++++++ crates/neuron/src/harness/tp/rpc.rs | 49 ++++++ crates/neuron/src/harness/tp/tp_qwen3_5.rs | 32 ++++ crates/neuron/src/harness/tp/worker.rs | 153 ++++++++++++++++++ 7 files changed, 508 insertions(+) diff --git a/crates/neuron/src/harness/device_worker/dispatch.rs b/crates/neuron/src/harness/device_worker/dispatch.rs index 8113dfd..d83379d 100644 --- a/crates/neuron/src/harness/device_worker/dispatch.rs +++ b/crates/neuron/src/harness/device_worker/dispatch.rs @@ -262,6 +262,25 @@ pub(crate) fn run(device_index: u32, rx: Receiver, poisoned: Arc { + let result = tp_forward_logits_with_images( + &mut state, + handle, + &tokens, + offset, + image_token_id, + &image_data_uris, + ); + let _ = reply.send(result); + } // Handled by the matches!() check above; reaching here // means a Shutdown slipped past which is a bug. Job::Shutdown => unreachable!("Shutdown should break above"), @@ -734,6 +753,61 @@ fn tp_forward_logits( Ok(values) } +/// Image-bearing leader forward (rank 0). Preprocesses each source +/// `image_data_uris` entry through the same deterministic +/// `preprocess_data_uri` every rank runs, uploads to the leader's +/// device, encodes + splices + forwards via +/// `TpLeaderModel::forward_with_images`, and copies the `[vocab]` +/// logits to CPU. Mirrors the single-GPU `forward_logits_with_images` +/// but on the TP leader's replicated tower. +#[cfg(feature = "cuda")] +fn tp_forward_logits_with_images( + state: &mut DeviceWorkerState, + handle: TpHandle, + tokens: &[u32], + offset: usize, + image_token_id: u32, + image_data_uris: &[String], +) -> anyhow::Result> { + use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri}; + use candle_core::{DType, Tensor}; + + if image_data_uris.is_empty() { + anyhow::bail!("TpForwardLogitsWithImages dispatched with zero images"); + } + + // Preprocess every image into a device-resident (C, H, W) tensor. + // Same fixed-resolution profile + decode path the subprocess workers + // run, so the encoded embeddings match across ranks bit-for-bit. + let profile = PreprocessProfile::qwen3_6(); + let (h, w) = ( + profile.target_height as usize, + profile.target_width as usize, + ); + let mut pixels: Vec = Vec::with_capacity(image_data_uris.len()); + for (idx, uri) in image_data_uris.iter().enumerate() { + let px = preprocess_data_uri(uri, &profile) + .with_context(|| format!("preprocess image[{idx}] (TP leader)"))?; + let t = Tensor::from_vec(px, (3, h, w), &state.device)?; + pixels.push(t); + } + + let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?; + + let model = state.tp_models.get_mut(&handle).ok_or_else(|| { + anyhow::anyhow!( + "TpForwardLogitsWithImages: no model for handle {}", + handle.0 + ) + })?; + + let logits = model.forward_with_images(&input, offset, &pixels, image_token_id)?; + let logits = logits.squeeze(0)?.squeeze(0)?; + let logits = logits.to_dtype(DType::F32)?.flatten_all()?; + let values = logits.to_vec1::()?; + Ok(values) +} + /// Forward step + copy the `[vocab]` logits to a CPU `Vec` ready /// for sampling on the async caller. The model's `device()` (CUDA or /// CPU) determines where the kernel runs; this fn doesn't care. diff --git a/crates/neuron/src/harness/device_worker/jobs.rs b/crates/neuron/src/harness/device_worker/jobs.rs index d0b023d..534dabb 100644 --- a/crates/neuron/src/harness/device_worker/jobs.rs +++ b/crates/neuron/src/harness/device_worker/jobs.rs @@ -231,6 +231,23 @@ pub enum Job { offset: usize, reply: oneshot::Sender>>, }, + /// Image-bearing leader (rank 0) forward for the single-shot vision + /// prefill. The handler preprocesses each `image_data_uris` entry + /// (the same deterministic path every rank runs), encodes through + /// the leader's replicated tower, splices at `image_token_id`, and + /// returns CPU-side `[vocab]` logits. Image tensors never escape the + /// worker thread. Caller fans out `GenerateStepWithImages` to the + /// subprocess ranks and drains them; only the leader forward moves + /// here. + #[cfg(feature = "cuda")] + TpForwardLogitsWithImages { + handle: TpHandle, + tokens: Vec, + offset: usize, + image_token_id: u32, + image_data_uris: Vec, + reply: oneshot::Sender>>, + }, /// Tell the worker to break its dispatch loop and exit. Any jobs /// queued after this in the channel reply `Err` to their oneshot /// senders (the senders are dropped on the worker's exit, which diff --git a/crates/neuron/src/harness/device_worker/mod.rs b/crates/neuron/src/harness/device_worker/mod.rs index 48e41df..cab5bb6 100644 --- a/crates/neuron/src/harness/device_worker/mod.rs +++ b/crates/neuron/src/harness/device_worker/mod.rs @@ -572,6 +572,47 @@ impl DeviceWorkerHandle { } } + /// Image-bearing TP leader forward (single-shot vision prefill). + /// Routes `Job::TpForwardLogitsWithImages` onto the worker thread; + /// the handler preprocesses + encodes + splices + forwards and + /// returns CPU-side `[vocab]` logits. The `WorkerPool` fans the + /// matching `GenerateStepWithImages` out to subprocess ranks so the + /// row-parallel collectives complete. + #[cfg(feature = "cuda")] + pub async fn tp_forward_logits_with_images( + &self, + handle: TpHandle, + tokens: Vec, + offset: usize, + image_token_id: u32, + image_data_uris: Vec, + ) -> Result, WorkerError> { + if self.poisoned.load(Ordering::Acquire) { + return Err(WorkerError::Poisoned { + device_index: self.device_index, + }); + } + let (reply_tx, reply_rx) = oneshot::channel(); + self.tx + .send(Job::TpForwardLogitsWithImages { + handle, + tokens, + offset, + image_token_id, + image_data_uris, + reply: reply_tx, + }) + .map_err(|_| WorkerError::Gone { + device_index: self.device_index, + })?; + match reply_rx.await { + Ok(result) => result.map_err(WorkerError::from), + Err(_) => Err(WorkerError::Gone { + device_index: self.device_index, + }), + } + } + /// Send `Job::Shutdown` and join the thread. Idempotent — calling /// twice is a no-op the second time. pub fn shutdown(&self) -> anyhow::Result<()> { diff --git a/crates/neuron/src/harness/tp/mod.rs b/crates/neuron/src/harness/tp/mod.rs index a9849c5..8d6802d 100644 --- a/crates/neuron/src/harness/tp/mod.rs +++ b/crates/neuron/src/harness/tp/mod.rs @@ -62,6 +62,25 @@ impl TpLeaderModel { } } + /// Image-bearing forward on rank 0. Only the vision-capable + /// `qwen3_5` arch supports it; the dense `qwen3` arch has no tower. + pub fn forward_with_images( + &mut self, + input: &candle_core::Tensor, + offset: usize, + image_pixels: &[candle_core::Tensor], + image_token_id: u32, + ) -> candle_core::Result { + match self { + TpLeaderModel::Qwen3_5(m) => { + m.forward_with_images(input, offset, image_pixels, image_token_id) + } + TpLeaderModel::Qwen3(_) => { + candle_core::bail!("forward_with_images: qwen3 (dense) has no vision tower") + } + } + } + pub fn clear_kv_cache(&mut self) { match self { TpLeaderModel::Qwen3(m) => m.clear_kv_cache(), @@ -687,6 +706,129 @@ impl WorkerPool { } } + /// Image-bearing variant of [`Self::generate_step`] for the + /// single-shot vision prefill. Identical fan-out / leader-forward / + /// drain shape, but every rank runs the encode + splice path: + /// + /// - subprocess workers get `GenerateStepWithImages` (carrying the + /// source `image_data_uris`); each preprocesses + encodes through + /// its replicated tower and splices locally; + /// - the leader runs the same encode + splice + forward on its + /// device worker thread via `tp_forward_logits_with_images`. + /// + /// The row-parallel `AllReduce`s synchronise the ranks exactly as in + /// the text path. Because the tower is replicated and the preprocess + /// is deterministic, every rank's spliced hidden state matches — no + /// embedding broadcast. Only used for prefill; decode reuses + /// `generate_step`. + #[cfg(feature = "cuda")] + pub async fn generate_step_with_images( + &mut self, + model_id: &str, + leader_handle: super::device_worker::TpHandle, + tokens: Vec, + offset: usize, + image_token_id: u32, + image_data_uris: Vec, + ) -> Result> { + let step_start = std::time::Instant::now(); + let tokens_len = tokens.len(); + tracing::debug!( + model = %model_id, + tokens = tokens_len, + offset, + images = image_data_uris.len(), + "WorkerPool::generate_step_with_images: fan-out" + ); + + // 1. Fan-out the image-bearing prefill to subprocess workers. + for w in &mut self.workers { + w.send_only(&WorkerRequest::GenerateStepWithImages { + model_id: model_id.to_string(), + tokens: tokens.clone(), + offset, + image_token_id, + image_data_uris: image_data_uris.clone(), + }) + .await?; + } + + // 2. Leader's image forward on its device worker thread. The + // AllReduce CustomOps block until every worker issues the + // matching collective; CPU-side logits keep the device tensor + // from escaping the worker thread. + let leader_start = std::time::Instant::now(); + let leader_result = self + .leader_worker + .tp_forward_logits_with_images( + leader_handle, + tokens, + offset, + image_token_id, + image_data_uris, + ) + .await; + let leader_ok = leader_result.is_ok(); + let leader_ms = leader_start.elapsed().as_millis(); + if !leader_ok { + let detail = leader_result + .as_ref() + .err() + .map(|e| format!("{e:#}")) + .unwrap_or_default(); + tracing::warn!( + model = %model_id, + tokens = tokens_len, + offset, + leader_ms, + error = %detail, + "WorkerPool::generate_step_with_images: leader forward failed" + ); + } + + // 3. ALWAYS drain worker responses, regardless of the leader's + // outcome, so stale GenerateStepOk replies don't poison the + // next request's recv (same invariant as generate_step). + let worker_errors = drain_workers(&mut self.workers, |r| match r { + WorkerResponse::GenerateStepOk => Ok(()), + WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")), + other => Err(format!("expected GenerateStepOk, got {other:?}")), + }) + .await; + tracing::debug!( + model = %model_id, + leader_ms, + leader_ok, + errors = worker_errors.len(), + total_ms = step_start.elapsed().as_millis(), + "WorkerPool::generate_step_with_images: workers drained" + ); + + match leader_result { + Ok(values) => { + if worker_errors.is_empty() { + Ok(values) + } else { + anyhow::bail!( + "GenerateStepWithImages: leader succeeded but workers failed: {}", + worker_errors.join("; ") + ) + } + } + Err(e) => { + if worker_errors.is_empty() { + Err(anyhow::Error::new(e) + .context("GenerateStepWithImages: leader forward failed")) + } else { + Err(anyhow::Error::new(e).context(format!( + "GenerateStepWithImages: leader forward failed and workers also failed: {}", + worker_errors.join("; ") + ))) + } + } + } + } + /// Reset the KV cache for `model_id` on every rank. Called at the /// start of every inference so a fresh request doesn't attend over /// the previous one's tokens. diff --git a/crates/neuron/src/harness/tp/rpc.rs b/crates/neuron/src/harness/tp/rpc.rs index 5c0c540..dd2f243 100644 --- a/crates/neuron/src/harness/tp/rpc.rs +++ b/crates/neuron/src/harness/tp/rpc.rs @@ -88,6 +88,29 @@ pub enum WorkerRequest { offset: usize, }, + /// Like `GenerateStep` but the prefill carries image content. Every + /// rank preprocesses the same `image_data_uris` through its + /// *replicated* vision tower, splices the resulting patch embeddings + /// at `image_token_id` positions, and runs the forward — the + /// row-parallel `AllReduce`s still synchronise every rank. Because + /// the tower is replicated and `preprocess_data_uri` is + /// deterministic, the spliced hidden state is identical on every + /// rank, so no embedding broadcast is needed. Sent only for the + /// (single-shot) image-bearing prefill; decode steps use plain + /// `GenerateStep`. Worker replies with the same `GenerateStepOk`. + GenerateStepWithImages { + model_id: String, + tokens: Vec, + offset: usize, + /// `<|image_pad|>` sentinel id (248056 for Qwen3.6); splice + /// target in the expanded token stream. + image_token_id: u32, + /// Source image data URIs (`data:image/...;base64,...`), one per + /// image in prompt order. Each rank decodes + preprocesses these + /// identically; tens of KB each, so cheap over the stdin pipe. + image_data_uris: Vec, + }, + /// Reset the KV cache for this model on this rank. Sent at the /// start of every inference so a fresh request doesn't accidentally /// attend over the previous one's tokens. @@ -191,6 +214,32 @@ mod tests { assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#); } + #[test] + fn request_generate_step_with_images_round_trip() { + let req = WorkerRequest::GenerateStepWithImages { + model_id: "Qwen/Qwen3.6-27B".into(), + tokens: vec![1, 2, 248056, 3], + offset: 0, + image_token_id: 248056, + image_data_uris: vec!["data:image/png;base64,AAA=".into()], + }; + let wire = serde_json::to_string(&req).unwrap(); + assert!(wire.contains(r#""op":"generate_step_with_images""#)); + match roundtrip(&req) { + WorkerRequest::GenerateStepWithImages { + tokens, + image_token_id, + image_data_uris, + .. + } => { + assert_eq!(tokens, vec![1, 2, 248056, 3]); + assert_eq!(image_token_id, 248056); + assert_eq!(image_data_uris.len(), 1); + } + other => panic!("expected GenerateStepWithImages, got {other:?}"), + } + } + #[test] fn request_shutdown_round_trip() { assert_eq!( diff --git a/crates/neuron/src/harness/tp/tp_qwen3_5.rs b/crates/neuron/src/harness/tp/tp_qwen3_5.rs index 8d812ab..19cc1e5 100644 --- a/crates/neuron/src/harness/tp/tp_qwen3_5.rs +++ b/crates/neuron/src/harness/tp/tp_qwen3_5.rs @@ -1192,6 +1192,38 @@ impl TpQwen3_5ForCausalLM { hidden.i((.., l - 1.., ..))?.apply(&self.lm_head) } + /// End-to-end image prefill on one rank: encode each preprocessed + /// `(C, H, W)` pixel tensor through this rank's replicated tower, + /// concatenate the per-image embeddings along the patch axis, and + /// forward with the splice. Shared by the leader (`TpLeaderModel`) + /// and the subprocess worker (`WorkerModel`) so every rank runs the + /// identical encode → splice → forward and keeps the replicated + /// hidden state in lockstep. Returns last-position logits + /// `(B, 1, vocab)`, same contract as `forward`. + pub fn forward_with_images( + &mut self, + input: &Tensor, + offset: usize, + image_pixels: &[Tensor], + image_token_id: u32, + ) -> candle_core::Result { + if image_pixels.is_empty() { + candle_core::bail!("forward_with_images: called with zero images"); + } + // Encode each image (immutable borrows of the tower) before the + // mutable forward below; the borrows end as each owned embedding + // is pushed. + let mut per_image = Vec::with_capacity(image_pixels.len()); + for (idx, img) in image_pixels.iter().enumerate() { + let embed = self + .encode_image(img) + .map_err(|e| candle_core::Error::Msg(format!("encode image[{idx}]: {e:#}")))?; + per_image.push(embed); + } + let image_embeds = Tensor::cat(&per_image.iter().collect::>(), 0)?; + self.forward_with_vision(input, offset, &image_embeds, image_token_id) + } + pub fn clear_kv_cache(&mut self) { self.base.clear_kv_cache(); } diff --git a/crates/neuron/src/harness/tp/worker.rs b/crates/neuron/src/harness/tp/worker.rs index 6ec9fd8..f2a8077 100644 --- a/crates/neuron/src/harness/tp/worker.rs +++ b/crates/neuron/src/harness/tp/worker.rs @@ -47,6 +47,28 @@ impl WorkerModel { } } + /// Image-bearing forward on this rank. Only the vision-capable + /// `qwen3_5` arch has a replicated tower; the dense `qwen3` arch + /// errors. The returned logits are discarded by the caller (the + /// leader samples from its own rank-0 copy) — the value is the NCCL + /// collectives the forward issues. + fn forward_with_images( + &mut self, + input: &candle_core::Tensor, + offset: usize, + image_pixels: &[candle_core::Tensor], + image_token_id: u32, + ) -> candle_core::Result { + match self { + WorkerModel::Qwen3_5(m) => { + m.forward_with_images(input, offset, image_pixels, image_token_id) + } + WorkerModel::Qwen3(_) => { + candle_core::bail!("forward_with_images: qwen3 (dense) has no vision tower") + } + } + } + fn clear_kv_cache(&mut self) { match self { WorkerModel::Qwen3(m) => m.clear_kv_cache(), @@ -167,6 +189,19 @@ impl WorkerState { tokens, offset, } => self.handle_generate_step(&model_id, tokens, offset), + WorkerRequest::GenerateStepWithImages { + model_id, + tokens, + offset, + image_token_id, + image_data_uris, + } => self.handle_generate_step_with_images( + &model_id, + tokens, + offset, + image_token_id, + image_data_uris, + ), WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id), WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id), WorkerRequest::Shutdown => WorkerResponse::Bye, @@ -418,6 +453,124 @@ impl WorkerState { } } + /// Image-bearing prefill on this rank. Preprocesses each source data + /// URI through the same deterministic `preprocess_data_uri` the + /// leader runs, encodes through this rank's replicated tower, and + /// splices + forwards. The logits are discarded (the leader samples + /// from rank 0); the row-parallel `AllReduce`s are the point. + #[cfg(feature = "cuda")] + fn handle_generate_step_with_images( + &mut self, + model_id: &str, + tokens: Vec, + offset: usize, + image_token_id: u32, + image_data_uris: Vec, + ) -> WorkerResponse { + use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri}; + use candle_core::Tensor; + + if image_data_uris.is_empty() { + return WorkerResponse::Error { + kind: "bad_request".into(), + message: "GenerateStepWithImages with zero images".into(), + }; + } + let Some(model) = self.models.get_mut(model_id) else { + return WorkerResponse::Error { + kind: "model_not_loaded".into(), + message: format!("model '{model_id}' not loaded on rank {}", self.config.rank), + }; + }; + let device = model.device().clone(); + + // Preprocess each image identically to the leader so the encoded + // embeddings — and thus the spliced hidden state — match across + // ranks. Fixed 448×448 profile. + let profile = PreprocessProfile::qwen3_6(); + let (h, w) = ( + profile.target_height as usize, + profile.target_width as usize, + ); + let mut pixels: Vec = Vec::with_capacity(image_data_uris.len()); + for (idx, uri) in image_data_uris.iter().enumerate() { + let px = match preprocess_data_uri(uri, &profile) { + Ok(p) => p, + Err(e) => { + return WorkerResponse::Error { + kind: "bad_request".into(), + message: format!("preprocess image[{idx}]: {e:#}"), + }; + } + }; + match Tensor::from_vec(px, (3, h, w), &device) { + Ok(t) => pixels.push(t), + Err(e) => { + return WorkerResponse::Error { + kind: "forward_failed".into(), + message: format!("build image[{idx}] tensor: {e}"), + }; + } + } + } + + let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) { + Ok(t) => t, + Err(e) => { + return WorkerResponse::Error { + kind: "forward_failed".into(), + message: format!("build input tensor: {e}"), + }; + } + }; + + let start = std::time::Instant::now(); + tracing::debug!( + rank = self.config.rank, + model = %model_id, + tokens = tokens.len(), + offset, + images = pixels.len(), + "worker GenerateStepWithImages: forward starting" + ); + // Drop the logits — the leader samples from its own rank-0 copy. + if let Err(e) = model.forward_with_images(&input, offset, &pixels, image_token_id) { + tracing::warn!( + rank = self.config.rank, + model = %model_id, + elapsed_ms = start.elapsed().as_millis(), + error = %e, + "worker GenerateStepWithImages: forward failed" + ); + return WorkerResponse::Error { + kind: "forward_failed".into(), + message: format!("TP image forward: {e}"), + }; + } + tracing::debug!( + rank = self.config.rank, + model = %model_id, + elapsed_ms = start.elapsed().as_millis(), + "worker GenerateStepWithImages: forward done" + ); + WorkerResponse::GenerateStepOk + } + + #[cfg(not(feature = "cuda"))] + fn handle_generate_step_with_images( + &mut self, + _model_id: &str, + _tokens: Vec, + _offset: usize, + _image_token_id: u32, + _image_data_uris: Vec, + ) -> WorkerResponse { + WorkerResponse::Error { + kind: "cuda_feature_not_enabled".into(), + message: "GenerateStepWithImages requires --features cuda".into(), + } + } + #[cfg(feature = "cuda")] fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse { let Some(model) = self.models.get_mut(model_id) else {