feat(neuron): TP-vision Stage 2 — per-rank image RPC + worker plumbing

Carry image content through the TP forward path so every rank encodes
and splices locally (replicated tower, no embedding broadcast).

- rpc.rs: new WorkerRequest::GenerateStepWithImages carrying the source
  image data URIs + image_token_id for the single-shot vision prefill;
  worker still replies GenerateStepOk. Round-trip test added.
- tp_qwen3_5.rs: TpQwen3_5ForCausalLM::forward_with_images — encode each
  preprocessed image through the rank's replicated tower, cat, splice,
  forward. Shared by leader and worker so every rank runs identical work.
- tp/mod.rs: TpLeaderModel::forward_with_images and
  WorkerPool::generate_step_with_images (mirrors generate_step: fan out
  GenerateStepWithImages to subprocess ranks, run the leader's image
  forward on its device worker thread, drain, combine).
- worker.rs: WorkerModel::forward_with_images + handle_generate_step_with_images
  — each subprocess rank preprocesses the same data URIs via the shared
  deterministic preprocess_data_uri, encodes, splices, forwards.
- device_worker: Job::TpForwardLogitsWithImages + tp_forward_logits_with_images
  dispatch handler + DeviceWorkerHandle::tp_forward_logits_with_images.

Determinism: every rank runs the same preprocess on the same source
URIs through the same replicated tower, so the spliced hidden state
matches across ranks — preserving the replicated-hidden-state invariant
the row-parallel AllReduce relies on, with no NCCL broadcast.

No caller yet — Stage 3 wires the TP chat/stream entry points to invoke
generate_step_with_images for image prefill. cuda-gated plumbing covered
by CI's CUDA type-check; rpc/route/forward_with_images compile on the
non-cuda build.

Refs TP-vision plan Stage 2.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-04 15:08:08 +03:00
parent 9a24b05866
commit 4994b94c84
7 changed files with 508 additions and 0 deletions

View File

@@ -262,6 +262,25 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
let result = tp_forward_logits(&mut state, handle, &tokens, offset);
let _ = reply.send(result);
}
#[cfg(feature = "cuda")]
Job::TpForwardLogitsWithImages {
handle,
tokens,
offset,
image_token_id,
image_data_uris,
reply,
} => {
let result = tp_forward_logits_with_images(
&mut state,
handle,
&tokens,
offset,
image_token_id,
&image_data_uris,
);
let _ = reply.send(result);
}
// Handled by the matches!() check above; reaching here
// means a Shutdown slipped past which is a bug.
Job::Shutdown => unreachable!("Shutdown should break above"),
@@ -734,6 +753,61 @@ fn tp_forward_logits(
Ok(values)
}
/// Image-bearing leader forward (rank 0). Preprocesses each source
/// `image_data_uris` entry through the same deterministic
/// `preprocess_data_uri` every rank runs, uploads to the leader's
/// device, encodes + splices + forwards via
/// `TpLeaderModel::forward_with_images`, and copies the `[vocab]`
/// logits to CPU. Mirrors the single-GPU `forward_logits_with_images`
/// but on the TP leader's replicated tower.
#[cfg(feature = "cuda")]
fn tp_forward_logits_with_images(
state: &mut DeviceWorkerState,
handle: TpHandle,
tokens: &[u32],
offset: usize,
image_token_id: u32,
image_data_uris: &[String],
) -> anyhow::Result<Vec<f32>> {
use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri};
use candle_core::{DType, Tensor};
if image_data_uris.is_empty() {
anyhow::bail!("TpForwardLogitsWithImages dispatched with zero images");
}
// Preprocess every image into a device-resident (C, H, W) tensor.
// Same fixed-resolution profile + decode path the subprocess workers
// run, so the encoded embeddings match across ranks bit-for-bit.
let profile = PreprocessProfile::qwen3_6();
let (h, w) = (
profile.target_height as usize,
profile.target_width as usize,
);
let mut pixels: Vec<Tensor> = Vec::with_capacity(image_data_uris.len());
for (idx, uri) in image_data_uris.iter().enumerate() {
let px = preprocess_data_uri(uri, &profile)
.with_context(|| format!("preprocess image[{idx}] (TP leader)"))?;
let t = Tensor::from_vec(px, (3, h, w), &state.device)?;
pixels.push(t);
}
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
let model = state.tp_models.get_mut(&handle).ok_or_else(|| {
anyhow::anyhow!(
"TpForwardLogitsWithImages: no model for handle {}",
handle.0
)
})?;
let logits = model.forward_with_images(&input, offset, &pixels, image_token_id)?;
let logits = logits.squeeze(0)?.squeeze(0)?;
let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
let values = logits.to_vec1::<f32>()?;
Ok(values)
}
/// Forward step + copy the `[vocab]` logits to a CPU `Vec<f32>` ready
/// for sampling on the async caller. The model's `device()` (CUDA or
/// CPU) determines where the kernel runs; this fn doesn't care.

View File

@@ -231,6 +231,23 @@ pub enum Job {
offset: usize,
reply: oneshot::Sender<Result<Vec<f32>>>,
},
/// Image-bearing leader (rank 0) forward for the single-shot vision
/// prefill. The handler preprocesses each `image_data_uris` entry
/// (the same deterministic path every rank runs), encodes through
/// the leader's replicated tower, splices at `image_token_id`, and
/// returns CPU-side `[vocab]` logits. Image tensors never escape the
/// worker thread. Caller fans out `GenerateStepWithImages` to the
/// subprocess ranks and drains them; only the leader forward moves
/// here.
#[cfg(feature = "cuda")]
TpForwardLogitsWithImages {
handle: TpHandle,
tokens: Vec<u32>,
offset: usize,
image_token_id: u32,
image_data_uris: Vec<String>,
reply: oneshot::Sender<Result<Vec<f32>>>,
},
/// Tell the worker to break its dispatch loop and exit. Any jobs
/// queued after this in the channel reply `Err` to their oneshot
/// senders (the senders are dropped on the worker's exit, which

View File

@@ -572,6 +572,47 @@ impl DeviceWorkerHandle {
}
}
/// Image-bearing TP leader forward (single-shot vision prefill).
/// Routes `Job::TpForwardLogitsWithImages` onto the worker thread;
/// the handler preprocesses + encodes + splices + forwards and
/// returns CPU-side `[vocab]` logits. The `WorkerPool` fans the
/// matching `GenerateStepWithImages` out to subprocess ranks so the
/// row-parallel collectives complete.
#[cfg(feature = "cuda")]
pub async fn tp_forward_logits_with_images(
&self,
handle: TpHandle,
tokens: Vec<u32>,
offset: usize,
image_token_id: u32,
image_data_uris: Vec<String>,
) -> Result<Vec<f32>, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::TpForwardLogitsWithImages {
handle,
tokens,
offset,
image_token_id,
image_data_uris,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
/// twice is a no-op the second time.
pub fn shutdown(&self) -> anyhow::Result<()> {

View File

@@ -62,6 +62,25 @@ impl TpLeaderModel {
}
}
/// Image-bearing forward on rank 0. Only the vision-capable
/// `qwen3_5` arch supports it; the dense `qwen3` arch has no tower.
pub fn forward_with_images(
&mut self,
input: &candle_core::Tensor,
offset: usize,
image_pixels: &[candle_core::Tensor],
image_token_id: u32,
) -> candle_core::Result<candle_core::Tensor> {
match self {
TpLeaderModel::Qwen3_5(m) => {
m.forward_with_images(input, offset, image_pixels, image_token_id)
}
TpLeaderModel::Qwen3(_) => {
candle_core::bail!("forward_with_images: qwen3 (dense) has no vision tower")
}
}
}
pub fn clear_kv_cache(&mut self) {
match self {
TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
@@ -687,6 +706,129 @@ impl WorkerPool {
}
}
/// Image-bearing variant of [`Self::generate_step`] for the
/// single-shot vision prefill. Identical fan-out / leader-forward /
/// drain shape, but every rank runs the encode + splice path:
///
/// - subprocess workers get `GenerateStepWithImages` (carrying the
/// source `image_data_uris`); each preprocesses + encodes through
/// its replicated tower and splices locally;
/// - the leader runs the same encode + splice + forward on its
/// device worker thread via `tp_forward_logits_with_images`.
///
/// The row-parallel `AllReduce`s synchronise the ranks exactly as in
/// the text path. Because the tower is replicated and the preprocess
/// is deterministic, every rank's spliced hidden state matches — no
/// embedding broadcast. Only used for prefill; decode reuses
/// `generate_step`.
#[cfg(feature = "cuda")]
pub async fn generate_step_with_images(
&mut self,
model_id: &str,
leader_handle: super::device_worker::TpHandle,
tokens: Vec<u32>,
offset: usize,
image_token_id: u32,
image_data_uris: Vec<String>,
) -> Result<Vec<f32>> {
let step_start = std::time::Instant::now();
let tokens_len = tokens.len();
tracing::debug!(
model = %model_id,
tokens = tokens_len,
offset,
images = image_data_uris.len(),
"WorkerPool::generate_step_with_images: fan-out"
);
// 1. Fan-out the image-bearing prefill to subprocess workers.
for w in &mut self.workers {
w.send_only(&WorkerRequest::GenerateStepWithImages {
model_id: model_id.to_string(),
tokens: tokens.clone(),
offset,
image_token_id,
image_data_uris: image_data_uris.clone(),
})
.await?;
}
// 2. Leader's image forward on its device worker thread. The
// AllReduce CustomOps block until every worker issues the
// matching collective; CPU-side logits keep the device tensor
// from escaping the worker thread.
let leader_start = std::time::Instant::now();
let leader_result = self
.leader_worker
.tp_forward_logits_with_images(
leader_handle,
tokens,
offset,
image_token_id,
image_data_uris,
)
.await;
let leader_ok = leader_result.is_ok();
let leader_ms = leader_start.elapsed().as_millis();
if !leader_ok {
let detail = leader_result
.as_ref()
.err()
.map(|e| format!("{e:#}"))
.unwrap_or_default();
tracing::warn!(
model = %model_id,
tokens = tokens_len,
offset,
leader_ms,
error = %detail,
"WorkerPool::generate_step_with_images: leader forward failed"
);
}
// 3. ALWAYS drain worker responses, regardless of the leader's
// outcome, so stale GenerateStepOk replies don't poison the
// next request's recv (same invariant as generate_step).
let worker_errors = drain_workers(&mut self.workers, |r| match r {
WorkerResponse::GenerateStepOk => Ok(()),
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
other => Err(format!("expected GenerateStepOk, got {other:?}")),
})
.await;
tracing::debug!(
model = %model_id,
leader_ms,
leader_ok,
errors = worker_errors.len(),
total_ms = step_start.elapsed().as_millis(),
"WorkerPool::generate_step_with_images: workers drained"
);
match leader_result {
Ok(values) => {
if worker_errors.is_empty() {
Ok(values)
} else {
anyhow::bail!(
"GenerateStepWithImages: leader succeeded but workers failed: {}",
worker_errors.join("; ")
)
}
}
Err(e) => {
if worker_errors.is_empty() {
Err(anyhow::Error::new(e)
.context("GenerateStepWithImages: leader forward failed"))
} else {
Err(anyhow::Error::new(e).context(format!(
"GenerateStepWithImages: leader forward failed and workers also failed: {}",
worker_errors.join("; ")
)))
}
}
}
}
/// Reset the KV cache for `model_id` on every rank. Called at the
/// start of every inference so a fresh request doesn't attend over
/// the previous one's tokens.

View File

@@ -88,6 +88,29 @@ pub enum WorkerRequest {
offset: usize,
},
/// Like `GenerateStep` but the prefill carries image content. Every
/// rank preprocesses the same `image_data_uris` through its
/// *replicated* vision tower, splices the resulting patch embeddings
/// at `image_token_id` positions, and runs the forward — the
/// row-parallel `AllReduce`s still synchronise every rank. Because
/// the tower is replicated and `preprocess_data_uri` is
/// deterministic, the spliced hidden state is identical on every
/// rank, so no embedding broadcast is needed. Sent only for the
/// (single-shot) image-bearing prefill; decode steps use plain
/// `GenerateStep`. Worker replies with the same `GenerateStepOk`.
GenerateStepWithImages {
model_id: String,
tokens: Vec<u32>,
offset: usize,
/// `<|image_pad|>` sentinel id (248056 for Qwen3.6); splice
/// target in the expanded token stream.
image_token_id: u32,
/// Source image data URIs (`data:image/...;base64,...`), one per
/// image in prompt order. Each rank decodes + preprocesses these
/// identically; tens of KB each, so cheap over the stdin pipe.
image_data_uris: Vec<String>,
},
/// Reset the KV cache for this model on this rank. Sent at the
/// start of every inference so a fresh request doesn't accidentally
/// attend over the previous one's tokens.
@@ -191,6 +214,32 @@ mod tests {
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
}
#[test]
fn request_generate_step_with_images_round_trip() {
let req = WorkerRequest::GenerateStepWithImages {
model_id: "Qwen/Qwen3.6-27B".into(),
tokens: vec![1, 2, 248056, 3],
offset: 0,
image_token_id: 248056,
image_data_uris: vec!["data:image/png;base64,AAA=".into()],
};
let wire = serde_json::to_string(&req).unwrap();
assert!(wire.contains(r#""op":"generate_step_with_images""#));
match roundtrip(&req) {
WorkerRequest::GenerateStepWithImages {
tokens,
image_token_id,
image_data_uris,
..
} => {
assert_eq!(tokens, vec![1, 2, 248056, 3]);
assert_eq!(image_token_id, 248056);
assert_eq!(image_data_uris.len(), 1);
}
other => panic!("expected GenerateStepWithImages, got {other:?}"),
}
}
#[test]
fn request_shutdown_round_trip() {
assert_eq!(

View File

@@ -1192,6 +1192,38 @@ impl TpQwen3_5ForCausalLM {
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
}
/// End-to-end image prefill on one rank: encode each preprocessed
/// `(C, H, W)` pixel tensor through this rank's replicated tower,
/// concatenate the per-image embeddings along the patch axis, and
/// forward with the splice. Shared by the leader (`TpLeaderModel`)
/// and the subprocess worker (`WorkerModel`) so every rank runs the
/// identical encode → splice → forward and keeps the replicated
/// hidden state in lockstep. Returns last-position logits
/// `(B, 1, vocab)`, same contract as `forward`.
pub fn forward_with_images(
&mut self,
input: &Tensor,
offset: usize,
image_pixels: &[Tensor],
image_token_id: u32,
) -> candle_core::Result<Tensor> {
if image_pixels.is_empty() {
candle_core::bail!("forward_with_images: called with zero images");
}
// Encode each image (immutable borrows of the tower) before the
// mutable forward below; the borrows end as each owned embedding
// is pushed.
let mut per_image = Vec::with_capacity(image_pixels.len());
for (idx, img) in image_pixels.iter().enumerate() {
let embed = self
.encode_image(img)
.map_err(|e| candle_core::Error::Msg(format!("encode image[{idx}]: {e:#}")))?;
per_image.push(embed);
}
let image_embeds = Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)?;
self.forward_with_vision(input, offset, &image_embeds, image_token_id)
}
pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache();
}

View File

@@ -47,6 +47,28 @@ impl WorkerModel {
}
}
/// Image-bearing forward on this rank. Only the vision-capable
/// `qwen3_5` arch has a replicated tower; the dense `qwen3` arch
/// errors. The returned logits are discarded by the caller (the
/// leader samples from its own rank-0 copy) — the value is the NCCL
/// collectives the forward issues.
fn forward_with_images(
&mut self,
input: &candle_core::Tensor,
offset: usize,
image_pixels: &[candle_core::Tensor],
image_token_id: u32,
) -> candle_core::Result<candle_core::Tensor> {
match self {
WorkerModel::Qwen3_5(m) => {
m.forward_with_images(input, offset, image_pixels, image_token_id)
}
WorkerModel::Qwen3(_) => {
candle_core::bail!("forward_with_images: qwen3 (dense) has no vision tower")
}
}
}
fn clear_kv_cache(&mut self) {
match self {
WorkerModel::Qwen3(m) => m.clear_kv_cache(),
@@ -167,6 +189,19 @@ impl WorkerState {
tokens,
offset,
} => self.handle_generate_step(&model_id, tokens, offset),
WorkerRequest::GenerateStepWithImages {
model_id,
tokens,
offset,
image_token_id,
image_data_uris,
} => self.handle_generate_step_with_images(
&model_id,
tokens,
offset,
image_token_id,
image_data_uris,
),
WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
WorkerRequest::Shutdown => WorkerResponse::Bye,
@@ -418,6 +453,124 @@ impl WorkerState {
}
}
/// Image-bearing prefill on this rank. Preprocesses each source data
/// URI through the same deterministic `preprocess_data_uri` the
/// leader runs, encodes through this rank's replicated tower, and
/// splices + forwards. The logits are discarded (the leader samples
/// from rank 0); the row-parallel `AllReduce`s are the point.
#[cfg(feature = "cuda")]
fn handle_generate_step_with_images(
&mut self,
model_id: &str,
tokens: Vec<u32>,
offset: usize,
image_token_id: u32,
image_data_uris: Vec<String>,
) -> WorkerResponse {
use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri};
use candle_core::Tensor;
if image_data_uris.is_empty() {
return WorkerResponse::Error {
kind: "bad_request".into(),
message: "GenerateStepWithImages with zero images".into(),
};
}
let Some(model) = self.models.get_mut(model_id) else {
return WorkerResponse::Error {
kind: "model_not_loaded".into(),
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
};
};
let device = model.device().clone();
// Preprocess each image identically to the leader so the encoded
// embeddings — and thus the spliced hidden state — match across
// ranks. Fixed 448×448 profile.
let profile = PreprocessProfile::qwen3_6();
let (h, w) = (
profile.target_height as usize,
profile.target_width as usize,
);
let mut pixels: Vec<Tensor> = Vec::with_capacity(image_data_uris.len());
for (idx, uri) in image_data_uris.iter().enumerate() {
let px = match preprocess_data_uri(uri, &profile) {
Ok(p) => p,
Err(e) => {
return WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("preprocess image[{idx}]: {e:#}"),
};
}
};
match Tensor::from_vec(px, (3, h, w), &device) {
Ok(t) => pixels.push(t),
Err(e) => {
return WorkerResponse::Error {
kind: "forward_failed".into(),
message: format!("build image[{idx}] tensor: {e}"),
};
}
}
}
let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => {
return WorkerResponse::Error {
kind: "forward_failed".into(),
message: format!("build input tensor: {e}"),
};
}
};
let start = std::time::Instant::now();
tracing::debug!(
rank = self.config.rank,
model = %model_id,
tokens = tokens.len(),
offset,
images = pixels.len(),
"worker GenerateStepWithImages: forward starting"
);
// Drop the logits — the leader samples from its own rank-0 copy.
if let Err(e) = model.forward_with_images(&input, offset, &pixels, image_token_id) {
tracing::warn!(
rank = self.config.rank,
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
error = %e,
"worker GenerateStepWithImages: forward failed"
);
return WorkerResponse::Error {
kind: "forward_failed".into(),
message: format!("TP image forward: {e}"),
};
}
tracing::debug!(
rank = self.config.rank,
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
"worker GenerateStepWithImages: forward done"
);
WorkerResponse::GenerateStepOk
}
#[cfg(not(feature = "cuda"))]
fn handle_generate_step_with_images(
&mut self,
_model_id: &str,
_tokens: Vec<u32>,
_offset: usize,
_image_token_id: u32,
_image_data_uris: Vec<String>,
) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "GenerateStepWithImages requires --features cuda".into(),
}
}
#[cfg(feature = "cuda")]
fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse {
let Some(model) = self.models.get_mut(model_id) else {