feat(neuron): TP-vision Stage 2 — per-rank image RPC + worker plumbing
Carry image content through the TP forward path so every rank encodes and splices locally (replicated tower, no embedding broadcast).

- rpc.rs: new WorkerRequest::GenerateStepWithImages carrying the source image data URIs + image_token_id for the single-shot vision prefill; worker still replies GenerateStepOk. Round-trip test added.
- tp_qwen3_5.rs: TpQwen3_5ForCausalLM::forward_with_images: encode each preprocessed image through the rank's replicated tower, cat, splice, forward. Shared by leader and worker so every rank runs identical work.
- tp/mod.rs: TpLeaderModel::forward_with_images and WorkerPool::generate_step_with_images (mirrors generate_step: fan out GenerateStepWithImages to subprocess ranks, run the leader's image forward on its device worker thread, drain, combine).
- worker.rs: WorkerModel::forward_with_images + handle_generate_step_with_images: each subprocess rank preprocesses the same data URIs via the shared deterministic preprocess_data_uri, encodes, splices, forwards.
- device_worker: Job::TpForwardLogitsWithImages + tp_forward_logits_with_images dispatch handler + DeviceWorkerHandle::tp_forward_logits_with_images.

Determinism: every rank runs the same preprocess on the same source URIs through the same replicated tower, so the spliced hidden state matches across ranks, preserving the replicated-hidden-state invariant the row-parallel AllReduce relies on, with no NCCL broadcast.

No caller yet: Stage 3 wires the TP chat/stream entry points to invoke generate_step_with_images for image prefill.

cuda-gated plumbing covered by CI's CUDA type-check; rpc/route/forward_with_images compile on the non-cuda build.

Refs TP-vision plan Stage 2.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
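The "splice" step named in the message can be pictured with a small std-only sketch. This is not the commit's implementation: `splice_image_embeddings` and the nested `Vec<Vec<f32>>` layout are hypothetical stand-ins for the real tensor code, assuming one patch embedding row per `image_token_id` position in the expanded token stream.

```rust
/// Hypothetical helper (not from the commit): overwrite the embedding row at
/// every `image_token_id` position with the next image patch row, in order.
fn splice_image_embeddings(
    tokens: &[u32],
    text_embeds: &[Vec<f32>],
    image_embeds: &[Vec<f32>],
    image_token_id: u32,
) -> Result<Vec<Vec<f32>>, String> {
    // One text embedding row is expected per token.
    if tokens.len() != text_embeds.len() {
        return Err(format!(
            "{} tokens but {} embedding rows",
            tokens.len(),
            text_embeds.len()
        ));
    }
    // Assumption: the token stream is already expanded, so the number of
    // sentinel positions equals the number of patch rows.
    let slots = tokens.iter().filter(|&&t| t == image_token_id).count();
    if slots != image_embeds.len() {
        return Err(format!(
            "{slots} image slots but {} patch rows",
            image_embeds.len()
        ));
    }
    let mut patches = image_embeds.iter();
    let mut out = Vec::with_capacity(tokens.len());
    for (tok, row) in tokens.iter().zip(text_embeds) {
        if *tok == image_token_id {
            // Counts were checked above, so a patch row is always available.
            out.push(patches.next().expect("slot count checked").clone());
        } else {
            out.push(row.clone());
        }
    }
    Ok(out)
}
```

Because the function depends only on its inputs, running it on every rank with the same tokens and the same patch rows gives the same spliced rows everywhere, which is the property the message relies on to skip the broadcast.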
@@ -88,6 +88,29 @@ pub enum WorkerRequest {
         offset: usize,
     },
 
+    /// Like `GenerateStep` but the prefill carries image content. Every
+    /// rank preprocesses the same `image_data_uris` through its
+    /// *replicated* vision tower, splices the resulting patch embeddings
+    /// at `image_token_id` positions, and runs the forward; the
+    /// row-parallel `AllReduce`s still synchronise every rank. Because
+    /// the tower is replicated and `preprocess_data_uri` is
+    /// deterministic, the spliced hidden state is identical on every
+    /// rank, so no embedding broadcast is needed. Sent only for the
+    /// (single-shot) image-bearing prefill; decode steps use plain
+    /// `GenerateStep`. Worker replies with the same `GenerateStepOk`.
+    GenerateStepWithImages {
+        model_id: String,
+        tokens: Vec<u32>,
+        offset: usize,
+        /// `<|image_pad|>` sentinel id (248056 for Qwen3.6); splice
+        /// target in the expanded token stream.
+        image_token_id: u32,
+        /// Source image data URIs (`data:image/...;base64,...`), one per
+        /// image in prompt order. Each rank decodes + preprocesses these
+        /// identically; tens of KB each, so cheap over the stdin pipe.
+        image_data_uris: Vec<String>,
+    },
+
     /// Reset the KV cache for this model on this rank. Sent at the
     /// start of every inference so a fresh request doesn't accidentally
     /// attend over the previous one's tokens.
@@ -191,6 +214,32 @@ mod tests {
         assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
     }
 
+    #[test]
+    fn request_generate_step_with_images_round_trip() {
+        let req = WorkerRequest::GenerateStepWithImages {
+            model_id: "Qwen/Qwen3.6-27B".into(),
+            tokens: vec![1, 2, 248056, 3],
+            offset: 0,
+            image_token_id: 248056,
+            image_data_uris: vec!["data:image/png;base64,AAA=".into()],
+        };
+        let wire = serde_json::to_string(&req).unwrap();
+        assert!(wire.contains(r#""op":"generate_step_with_images""#));
+        match roundtrip(&req) {
+            WorkerRequest::GenerateStepWithImages {
+                tokens,
+                image_token_id,
+                image_data_uris,
+                ..
+            } => {
+                assert_eq!(tokens, vec![1, 2, 248056, 3]);
+                assert_eq!(image_token_id, 248056);
+                assert_eq!(image_data_uris.len(), 1);
+            }
+            other => panic!("expected GenerateStepWithImages, got {other:?}"),
+        }
+    }
+
     #[test]
     fn request_shutdown_round_trip() {
         assert_eq!(
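The fan-out described for `WorkerPool::generate_step_with_images` (send the same request to every rank, run the leader's forward, drain, combine) might look roughly like the sketch below. It rests on loud assumptions: threads and an `mpsc` channel stand in for the subprocess ranks and their stdin pipe, `ImageStep` is a hypothetical reduced request type, and `rank_forward` is a placeholder FNV-1a digest rather than a real preprocess, encode, splice, and forward pass.

```rust
use std::sync::mpsc;
use std::thread;

/// Hypothetical request shape mirroring the new variant's image fields.
#[derive(Clone)]
struct ImageStep {
    tokens: Vec<u32>,
    image_token_id: u32,
    image_data_uris: Vec<String>,
}

/// Placeholder for the per-rank work: a deterministic FNV-1a digest of the
/// request, so identical inputs always produce identical outputs.
fn rank_forward(req: &ImageStep) -> u64 {
    let mut h: u64 = 0xcbf29ce484222325;
    let mut mix = |b: u8| {
        h ^= b as u64;
        h = h.wrapping_mul(0x100000001b3);
    };
    for t in &req.tokens {
        for b in t.to_le_bytes() {
            mix(b);
        }
    }
    for b in req.image_token_id.to_le_bytes() {
        mix(b);
    }
    for uri in &req.image_data_uris {
        for b in uri.bytes() {
            mix(b);
        }
    }
    h
}

/// Fan the same request out to ranks 1..world_size, run the leader's share
/// locally, drain every reply, and check that no rank diverged.
fn generate_step_with_images(req: &ImageStep, world_size: usize) -> Result<u64, String> {
    let (tx, rx) = mpsc::channel();
    let mut handles = Vec::new();
    for rank in 1..world_size {
        let tx = tx.clone();
        let req = req.clone();
        handles.push(thread::spawn(move || {
            // Ignore send errors: they only occur if the leader went away.
            let _ = tx.send((rank, rank_forward(&req)));
        }));
    }
    drop(tx); // only the workers hold senders now, so draining terminates
    let leader = rank_forward(req);
    let replies: Vec<(usize, u64)> = rx.iter().collect();
    for h in handles {
        h.join().map_err(|_| "worker panicked".to_string())?;
    }
    for (rank, digest) in replies {
        if digest != leader {
            return Err(format!("rank {rank} diverged from leader"));
        }
    }
    Ok(leader)
}
```

In the real path the ranks never compare outputs like this; the divergence check is only here to make the replicated-hidden-state invariant visible in a toy setting.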