feat(neuron): TP-vision Stage 2 — per-rank image RPC + worker plumbing

Carry image content through the TP forward path so every rank encodes
and splices locally (replicated tower, no embedding broadcast).

- rpc.rs: new WorkerRequest::GenerateStepWithImages carrying the source
  image data URIs + image_token_id for the single-shot vision prefill;
  worker still replies GenerateStepOk. Round-trip test added.
- tp_qwen3_5.rs: TpQwen3_5ForCausalLM::forward_with_images — encode each
  preprocessed image through the rank's replicated tower, cat, splice,
  forward. Shared by leader and worker so every rank runs identical work.
- tp/mod.rs: TpLeaderModel::forward_with_images and
  WorkerPool::generate_step_with_images (mirrors generate_step: fan out
  GenerateStepWithImages to subprocess ranks, run the leader's image
  forward on its device worker thread, drain, combine).
- worker.rs: WorkerModel::forward_with_images + handle_generate_step_with_images
  — each subprocess rank preprocesses the same data URIs via the shared
  deterministic preprocess_data_uri, encodes, splices, forwards.
- device_worker: Job::TpForwardLogitsWithImages + tp_forward_logits_with_images
  dispatch handler + DeviceWorkerHandle::tp_forward_logits_with_images.

Determinism: every rank runs the same preprocess on the same source
URIs through the same replicated tower, so the spliced hidden state
matches across ranks — preserving the replicated-hidden-state invariant
the row-parallel AllReduce relies on, with no NCCL broadcast.

No caller yet — Stage 3 wires the TP chat/stream entry points to invoke
generate_step_with_images for image prefill. cuda-gated plumbing covered
by CI's CUDA type-check; rpc/route/forward_with_images compile on the
non-cuda build.

Refs TP-vision plan Stage 2.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-04 15:08:08 +03:00
parent 9a24b05866
commit 4994b94c84
7 changed files with 508 additions and 0 deletions

View File

@@ -88,6 +88,29 @@ pub enum WorkerRequest {
offset: usize,
},
/// Like `GenerateStep` but the prefill carries image content. Every
/// rank preprocesses the same `image_data_uris` through its
/// *replicated* vision tower, splices the resulting patch embeddings
/// at `image_token_id` positions, and runs the forward — the
/// row-parallel `AllReduce`s still synchronise every rank. Because
/// the tower is replicated and `preprocess_data_uri` is
/// deterministic, the spliced hidden state is identical on every
/// rank, so no embedding broadcast is needed. Sent only for the
/// (single-shot) image-bearing prefill; decode steps use plain
/// `GenerateStep`. Worker replies with the same `GenerateStepOk`.
GenerateStepWithImages {
model_id: String,
tokens: Vec<u32>,
offset: usize,
/// `<|image_pad|>` sentinel id (248056 for Qwen3.6); splice
/// target in the expanded token stream.
image_token_id: u32,
/// Source image data URIs (`data:image/...;base64,...`), one per
/// image in prompt order. Each rank decodes + preprocesses these
/// identically; tens of KB each, so cheap over the stdin pipe.
image_data_uris: Vec<String>,
},
/// Reset the KV cache for this model on this rank. Sent at the
/// start of every inference so a fresh request doesn't accidentally
/// attend over the previous one's tokens.
@@ -191,6 +214,32 @@ mod tests {
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
}
#[test]
fn request_generate_step_with_images_round_trip() {
let req = WorkerRequest::GenerateStepWithImages {
model_id: "Qwen/Qwen3.6-27B".into(),
tokens: vec![1, 2, 248056, 3],
offset: 0,
image_token_id: 248056,
image_data_uris: vec!["data:image/png;base64,AAA=".into()],
};
let wire = serde_json::to_string(&req).unwrap();
assert!(wire.contains(r#""op":"generate_step_with_images""#));
match roundtrip(&req) {
WorkerRequest::GenerateStepWithImages {
tokens,
image_token_id,
image_data_uris,
..
} => {
assert_eq!(tokens, vec![1, 2, 248056, 3]);
assert_eq!(image_token_id, 248056);
assert_eq!(image_data_uris.len(), 1);
}
other => panic!("expected GenerateStepWithImages, got {other:?}"),
}
}
#[test]
fn request_shutdown_round_trip() {
assert_eq!(