feat(neuron): TP-vision Stage 2 — per-rank image RPC + worker plumbing
Carry image content through the TP forward path so every rank encodes and splices locally (replicated tower, no embedding broadcast).

- rpc.rs: new WorkerRequest::GenerateStepWithImages carrying the source image data URIs + image_token_id for the single-shot vision prefill; worker still replies GenerateStepOk. Round-trip test added.
- tp_qwen3_5.rs: TpQwen3_5ForCausalLM::forward_with_images — encode each preprocessed image through the rank's replicated tower, cat, splice, forward. Shared by leader and worker so every rank runs identical work.
- tp/mod.rs: TpLeaderModel::forward_with_images and WorkerPool::generate_step_with_images (mirrors generate_step: fan out GenerateStepWithImages to subprocess ranks, run the leader's image forward on its device worker thread, drain, combine).
- worker.rs: WorkerModel::forward_with_images + handle_generate_step_with_images — each subprocess rank preprocesses the same data URIs via the shared deterministic preprocess_data_uri, encodes, splices, forwards.
- device_worker: Job::TpForwardLogitsWithImages + tp_forward_logits_with_images dispatch handler + DeviceWorkerHandle::tp_forward_logits_with_images.

Determinism: every rank runs the same preprocess on the same source URIs through the same replicated tower, so the spliced hidden state matches across ranks — preserving the replicated-hidden-state invariant the row-parallel AllReduce relies on, with no NCCL broadcast.

No caller yet — Stage 3 wires the TP chat/stream entry points to invoke generate_step_with_images for image prefill. cuda-gated plumbing covered by CI's CUDA type-check; rpc/route/forward_with_images compile on the non-cuda build.

Refs TP-vision plan Stage 2.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -262,6 +262,25 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
|||||||
let result = tp_forward_logits(&mut state, handle, &tokens, offset);
|
let result = tp_forward_logits(&mut state, handle, &tokens, offset);
|
||||||
let _ = reply.send(result);
|
let _ = reply.send(result);
|
||||||
}
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::TpForwardLogitsWithImages {
|
||||||
|
handle,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
image_token_id,
|
||||||
|
image_data_uris,
|
||||||
|
reply,
|
||||||
|
} => {
|
||||||
|
let result = tp_forward_logits_with_images(
|
||||||
|
&mut state,
|
||||||
|
handle,
|
||||||
|
&tokens,
|
||||||
|
offset,
|
||||||
|
image_token_id,
|
||||||
|
&image_data_uris,
|
||||||
|
);
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
// Handled by the matches!() check above; reaching here
|
// Handled by the matches!() check above; reaching here
|
||||||
// means a Shutdown slipped past which is a bug.
|
// means a Shutdown slipped past which is a bug.
|
||||||
Job::Shutdown => unreachable!("Shutdown should break above"),
|
Job::Shutdown => unreachable!("Shutdown should break above"),
|
||||||
@@ -734,6 +753,61 @@ fn tp_forward_logits(
|
|||||||
Ok(values)
|
Ok(values)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Runs the image-bearing forward for the model registered under
/// `handle` and returns its logits as a CPU-side `Vec<f32>`.
///
/// Steps, in order:
/// 1. reject an empty `image_data_uris` slice;
/// 2. run `preprocess_data_uri` with `PreprocessProfile::qwen3_6()` on
///    every URI and wrap each result in a `(3, H, W)` tensor on
///    `state.device`, where `H`/`W` are the profile's target size;
/// 3. build the token tensor from `tokens` with a leading axis added;
/// 4. look up the model in `state.tp_models` and call its
///    `forward_with_images`;
/// 5. squeeze the two leading axes, cast to `f32`, flatten, copy out.
///
/// # Errors
/// Fails when no images are supplied, when preprocessing or any tensor
/// operation fails, or when `handle` has no entry in `state.tp_models`.
/// Preprocessing happens before the model lookup, so a bad image is
/// reported ahead of an unknown handle.
#[cfg(feature = "cuda")]
fn tp_forward_logits_with_images(
    state: &mut DeviceWorkerState,
    handle: TpHandle,
    tokens: &[u32],
    offset: usize,
    image_token_id: u32,
    image_data_uris: &[String],
) -> anyhow::Result<Vec<f32>> {
    use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri};
    use candle_core::{DType, Tensor};

    if image_data_uris.is_empty() {
        anyhow::bail!("TpForwardLogitsWithImages dispatched with zero images");
    }

    // One profile for every image; its target size fixes the tensor shape.
    let profile = PreprocessProfile::qwen3_6();
    let height = profile.target_height as usize;
    let width = profile.target_width as usize;

    // Decode + upload each image; the first failure aborts the whole call.
    let pixels = image_data_uris
        .iter()
        .enumerate()
        .map(|(idx, uri)| -> anyhow::Result<Tensor> {
            let px = preprocess_data_uri(uri, &profile)
                .with_context(|| format!("preprocess image[{idx}] (TP leader)"))?;
            Ok(Tensor::from_vec(px, (3, height, width), &state.device)?)
        })
        .collect::<anyhow::Result<Vec<Tensor>>>()?;

    let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;

    let Some(model) = state.tp_models.get_mut(&handle) else {
        anyhow::bail!(
            "TpForwardLogitsWithImages: no model for handle {}",
            handle.0
        );
    };

    // Strip the two leading axes, then normalise dtype and layout so the
    // result can be copied out as a flat f32 vector.
    let flat = model
        .forward_with_images(&input, offset, &pixels, image_token_id)?
        .squeeze(0)?
        .squeeze(0)?
        .to_dtype(DType::F32)?
        .flatten_all()?;
    Ok(flat.to_vec1::<f32>()?)
}
|
||||||
|
|
||||||
/// Forward step + copy the `[vocab]` logits to a CPU `Vec<f32>` ready
|
/// Forward step + copy the `[vocab]` logits to a CPU `Vec<f32>` ready
|
||||||
/// for sampling on the async caller. The model's `device()` (CUDA or
|
/// for sampling on the async caller. The model's `device()` (CUDA or
|
||||||
/// CPU) determines where the kernel runs; this fn doesn't care.
|
/// CPU) determines where the kernel runs; this fn doesn't care.
|
||||||
|
|||||||
@@ -231,6 +231,23 @@ pub enum Job {
|
|||||||
offset: usize,
|
offset: usize,
|
||||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
},
|
},
|
||||||
|
/// Image-bearing leader (rank 0) forward for the single-shot vision
|
||||||
|
/// prefill. The handler preprocesses each `image_data_uris` entry
|
||||||
|
/// (the same deterministic path every rank runs), encodes through
|
||||||
|
/// the leader's replicated tower, splices at `image_token_id`, and
|
||||||
|
/// returns CPU-side `[vocab]` logits. Image tensors never escape the
|
||||||
|
/// worker thread. Caller fans out `GenerateStepWithImages` to the
|
||||||
|
/// subprocess ranks and drains them; only the leader forward moves
|
||||||
|
/// here.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
TpForwardLogitsWithImages {
|
||||||
|
handle: TpHandle,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
image_token_id: u32,
|
||||||
|
image_data_uris: Vec<String>,
|
||||||
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
|
},
|
||||||
/// Tell the worker to break its dispatch loop and exit. Any jobs
|
/// Tell the worker to break its dispatch loop and exit. Any jobs
|
||||||
/// queued after this in the channel reply `Err` to their oneshot
|
/// queued after this in the channel reply `Err` to their oneshot
|
||||||
/// senders (the senders are dropped on the worker's exit, which
|
/// senders (the senders are dropped on the worker's exit, which
|
||||||
|
|||||||
@@ -572,6 +572,47 @@ impl DeviceWorkerHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Submits a `Job::TpForwardLogitsWithImages` to the worker thread and
/// awaits its reply.
///
/// # Errors
/// - `WorkerError::Poisoned` if the handle's `poisoned` flag is already
///   set; no job is sent in that case.
/// - `WorkerError::Gone` if the job cannot be sent or the reply channel
///   is dropped before a result arrives.
/// - Otherwise, any error returned by the handler, converted through
///   `WorkerError::from`.
#[cfg(feature = "cuda")]
pub async fn tp_forward_logits_with_images(
    &self,
    handle: TpHandle,
    tokens: Vec<u32>,
    offset: usize,
    image_token_id: u32,
    image_data_uris: Vec<String>,
) -> Result<Vec<f32>, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (reply, pending) = oneshot::channel();
    let job = Job::TpForwardLogitsWithImages {
        handle,
        tokens,
        offset,
        image_token_id,
        image_data_uris,
        reply,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    // Outer error: the reply sender was dropped. Inner error: the handler failed.
    let outcome = pending
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
|
||||||
|
|
||||||
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
|
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
|
||||||
/// twice is a no-op the second time.
|
/// twice is a no-op the second time.
|
||||||
pub fn shutdown(&self) -> anyhow::Result<()> {
|
pub fn shutdown(&self) -> anyhow::Result<()> {
|
||||||
|
|||||||
@@ -62,6 +62,25 @@ impl TpLeaderModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Image-bearing forward on rank 0. Only the vision-capable
|
||||||
|
/// `qwen3_5` arch supports it; the dense `qwen3` arch has no tower.
|
||||||
|
pub fn forward_with_images(
|
||||||
|
&mut self,
|
||||||
|
input: &candle_core::Tensor,
|
||||||
|
offset: usize,
|
||||||
|
image_pixels: &[candle_core::Tensor],
|
||||||
|
image_token_id: u32,
|
||||||
|
) -> candle_core::Result<candle_core::Tensor> {
|
||||||
|
match self {
|
||||||
|
TpLeaderModel::Qwen3_5(m) => {
|
||||||
|
m.forward_with_images(input, offset, image_pixels, image_token_id)
|
||||||
|
}
|
||||||
|
TpLeaderModel::Qwen3(_) => {
|
||||||
|
candle_core::bail!("forward_with_images: qwen3 (dense) has no vision tower")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn clear_kv_cache(&mut self) {
|
pub fn clear_kv_cache(&mut self) {
|
||||||
match self {
|
match self {
|
||||||
TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
|
TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
|
||||||
@@ -687,6 +706,129 @@ impl WorkerPool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Image-bearing counterpart of `generate_step`: sends one
/// `GenerateStepWithImages` request to every subprocess worker, runs
/// the leader's image forward through
/// `leader_worker.tp_forward_logits_with_images`, drains one reply per
/// worker, and combines the outcomes.
///
/// Returns the leader's logits only when the leader succeeded and no
/// worker reported an error.
///
/// # Errors
/// - A failed `send_only` during fan-out is returned immediately.
/// - If the leader succeeds but any worker fails, the worker errors are
///   joined into one message.
/// - If the leader fails, its error is returned with context; worker
///   errors, if any, are appended to that context.
#[cfg(feature = "cuda")]
pub async fn generate_step_with_images(
    &mut self,
    model_id: &str,
    leader_handle: super::device_worker::TpHandle,
    tokens: Vec<u32>,
    offset: usize,
    image_token_id: u32,
    image_data_uris: Vec<String>,
) -> Result<Vec<f32>> {
    let step_start = std::time::Instant::now();
    let tokens_len = tokens.len();
    tracing::debug!(
        model = %model_id,
        tokens = tokens_len,
        offset,
        images = image_data_uris.len(),
        "WorkerPool::generate_step_with_images: fan-out"
    );

    // 1. Fan-out the image-bearing prefill to subprocess workers. The
    //    request is built once and sent by reference, so the token
    //    buffer and the image data URIs are not cloned per worker.
    let request = WorkerRequest::GenerateStepWithImages {
        model_id: model_id.to_string(),
        tokens,
        offset,
        image_token_id,
        image_data_uris,
    };
    for w in &mut self.workers {
        // NOTE(review): a send failure returns here, before step 3 runs;
        // workers that already received the request are not drained —
        // confirm the caller discards the pool on this error.
        w.send_only(&request).await?;
    }
    // Move the buffers back out of the request for the leader call.
    let WorkerRequest::GenerateStepWithImages {
        tokens,
        image_data_uris,
        ..
    } = request
    else {
        unreachable!("request was constructed as GenerateStepWithImages above");
    };

    // 2. Leader's image forward on its device worker thread; only
    //    CPU-side logits come back from it.
    let leader_start = std::time::Instant::now();
    let leader_result = self
        .leader_worker
        .tp_forward_logits_with_images(
            leader_handle,
            tokens,
            offset,
            image_token_id,
            image_data_uris,
        )
        .await;
    let leader_ok = leader_result.is_ok();
    let leader_ms = leader_start.elapsed().as_millis();
    if let Err(err) = &leader_result {
        let detail = format!("{err:#}");
        tracing::warn!(
            model = %model_id,
            tokens = tokens_len,
            offset,
            leader_ms,
            error = %detail,
            "WorkerPool::generate_step_with_images: leader forward failed"
        );
    }

    // 3. ALWAYS drain worker responses, regardless of the leader's
    //    outcome, so stale GenerateStepOk replies don't poison the
    //    next request's recv (same invariant as generate_step).
    let worker_errors = drain_workers(&mut self.workers, |r| match r {
        WorkerResponse::GenerateStepOk => Ok(()),
        WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
        other => Err(format!("expected GenerateStepOk, got {other:?}")),
    })
    .await;
    tracing::debug!(
        model = %model_id,
        leader_ms,
        leader_ok,
        errors = worker_errors.len(),
        total_ms = step_start.elapsed().as_millis(),
        "WorkerPool::generate_step_with_images: workers drained"
    );

    // 4. Combine: success requires both the leader and every worker.
    match (leader_result, worker_errors.is_empty()) {
        (Ok(values), true) => Ok(values),
        (Ok(_), false) => Err(anyhow::anyhow!(
            "GenerateStepWithImages: leader succeeded but workers failed: {}",
            worker_errors.join("; ")
        )),
        (Err(e), true) => {
            Err(anyhow::Error::new(e).context("GenerateStepWithImages: leader forward failed"))
        }
        (Err(e), false) => Err(anyhow::Error::new(e).context(format!(
            "GenerateStepWithImages: leader forward failed and workers also failed: {}",
            worker_errors.join("; ")
        ))),
    }
}
|
||||||
|
|
||||||
/// Reset the KV cache for `model_id` on every rank. Called at the
|
/// Reset the KV cache for `model_id` on every rank. Called at the
|
||||||
/// start of every inference so a fresh request doesn't attend over
|
/// start of every inference so a fresh request doesn't attend over
|
||||||
/// the previous one's tokens.
|
/// the previous one's tokens.
|
||||||
|
|||||||
@@ -88,6 +88,29 @@ pub enum WorkerRequest {
|
|||||||
offset: usize,
|
offset: usize,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/// Like `GenerateStep` but the prefill carries image content. Every
|
||||||
|
/// rank preprocesses the same `image_data_uris` through its
|
||||||
|
/// *replicated* vision tower, splices the resulting patch embeddings
|
||||||
|
/// at `image_token_id` positions, and runs the forward — the
|
||||||
|
/// row-parallel `AllReduce`s still synchronise every rank. Because
|
||||||
|
/// the tower is replicated and `preprocess_data_uri` is
|
||||||
|
/// deterministic, the spliced hidden state is identical on every
|
||||||
|
/// rank, so no embedding broadcast is needed. Sent only for the
|
||||||
|
/// (single-shot) image-bearing prefill; decode steps use plain
|
||||||
|
/// `GenerateStep`. Worker replies with the same `GenerateStepOk`.
|
||||||
|
GenerateStepWithImages {
|
||||||
|
model_id: String,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
/// `<|image_pad|>` sentinel id (248056 for Qwen3.6); splice
|
||||||
|
/// target in the expanded token stream.
|
||||||
|
image_token_id: u32,
|
||||||
|
/// Source image data URIs (`data:image/...;base64,...`), one per
|
||||||
|
/// image in prompt order. Each rank decodes + preprocesses these
|
||||||
|
/// identically; tens of KB each, so cheap over the stdin pipe.
|
||||||
|
image_data_uris: Vec<String>,
|
||||||
|
},
|
||||||
|
|
||||||
/// Reset the KV cache for this model on this rank. Sent at the
|
/// Reset the KV cache for this model on this rank. Sent at the
|
||||||
/// start of every inference so a fresh request doesn't accidentally
|
/// start of every inference so a fresh request doesn't accidentally
|
||||||
/// attend over the previous one's tokens.
|
/// attend over the previous one's tokens.
|
||||||
@@ -191,6 +214,32 @@ mod tests {
|
|||||||
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
|
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// `GenerateStepWithImages` serialises with the
/// `generate_step_with_images` op tag and survives a round trip with
/// its tokens, sentinel id, and image list intact.
#[test]
fn request_generate_step_with_images_round_trip() {
    let original = WorkerRequest::GenerateStepWithImages {
        model_id: "Qwen/Qwen3.6-27B".into(),
        tokens: vec![1, 2, 248056, 3],
        offset: 0,
        image_token_id: 248056,
        image_data_uris: vec!["data:image/png;base64,AAA=".into()],
    };

    // Wire format carries the snake_case op tag.
    let encoded = serde_json::to_string(&original).unwrap();
    assert!(encoded.contains(r#""op":"generate_step_with_images""#));

    let decoded = roundtrip(&original);
    match decoded {
        WorkerRequest::GenerateStepWithImages {
            tokens,
            image_token_id,
            image_data_uris,
            ..
        } => {
            assert_eq!(tokens, vec![1, 2, 248056, 3]);
            assert_eq!(image_token_id, 248056);
            assert_eq!(image_data_uris.len(), 1);
        }
        other => panic!("expected GenerateStepWithImages, got {other:?}"),
    }
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn request_shutdown_round_trip() {
|
fn request_shutdown_round_trip() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
|||||||
@@ -1192,6 +1192,38 @@ impl TpQwen3_5ForCausalLM {
|
|||||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// End-to-end image prefill on one rank: encode each preprocessed
|
||||||
|
/// `(C, H, W)` pixel tensor through this rank's replicated tower,
|
||||||
|
/// concatenate the per-image embeddings along the patch axis, and
|
||||||
|
/// forward with the splice. Shared by the leader (`TpLeaderModel`)
|
||||||
|
/// and the subprocess worker (`WorkerModel`) so every rank runs the
|
||||||
|
/// identical encode → splice → forward and keeps the replicated
|
||||||
|
/// hidden state in lockstep. Returns last-position logits
|
||||||
|
/// `(B, 1, vocab)`, same contract as `forward`.
|
||||||
|
pub fn forward_with_images(
|
||||||
|
&mut self,
|
||||||
|
input: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
image_pixels: &[Tensor],
|
||||||
|
image_token_id: u32,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
if image_pixels.is_empty() {
|
||||||
|
candle_core::bail!("forward_with_images: called with zero images");
|
||||||
|
}
|
||||||
|
// Encode each image (immutable borrows of the tower) before the
|
||||||
|
// mutable forward below; the borrows end as each owned embedding
|
||||||
|
// is pushed.
|
||||||
|
let mut per_image = Vec::with_capacity(image_pixels.len());
|
||||||
|
for (idx, img) in image_pixels.iter().enumerate() {
|
||||||
|
let embed = self
|
||||||
|
.encode_image(img)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("encode image[{idx}]: {e:#}")))?;
|
||||||
|
per_image.push(embed);
|
||||||
|
}
|
||||||
|
let image_embeds = Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)?;
|
||||||
|
self.forward_with_vision(input, offset, &image_embeds, image_token_id)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn clear_kv_cache(&mut self) {
|
pub fn clear_kv_cache(&mut self) {
|
||||||
self.base.clear_kv_cache();
|
self.base.clear_kv_cache();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,6 +47,28 @@ impl WorkerModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Image-bearing forward on this rank. Only the vision-capable
|
||||||
|
/// `qwen3_5` arch has a replicated tower; the dense `qwen3` arch
|
||||||
|
/// errors. The returned logits are discarded by the caller (the
|
||||||
|
/// leader samples from its own rank-0 copy) — the value is the NCCL
|
||||||
|
/// collectives the forward issues.
|
||||||
|
fn forward_with_images(
|
||||||
|
&mut self,
|
||||||
|
input: &candle_core::Tensor,
|
||||||
|
offset: usize,
|
||||||
|
image_pixels: &[candle_core::Tensor],
|
||||||
|
image_token_id: u32,
|
||||||
|
) -> candle_core::Result<candle_core::Tensor> {
|
||||||
|
match self {
|
||||||
|
WorkerModel::Qwen3_5(m) => {
|
||||||
|
m.forward_with_images(input, offset, image_pixels, image_token_id)
|
||||||
|
}
|
||||||
|
WorkerModel::Qwen3(_) => {
|
||||||
|
candle_core::bail!("forward_with_images: qwen3 (dense) has no vision tower")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn clear_kv_cache(&mut self) {
|
fn clear_kv_cache(&mut self) {
|
||||||
match self {
|
match self {
|
||||||
WorkerModel::Qwen3(m) => m.clear_kv_cache(),
|
WorkerModel::Qwen3(m) => m.clear_kv_cache(),
|
||||||
@@ -167,6 +189,19 @@ impl WorkerState {
|
|||||||
tokens,
|
tokens,
|
||||||
offset,
|
offset,
|
||||||
} => self.handle_generate_step(&model_id, tokens, offset),
|
} => self.handle_generate_step(&model_id, tokens, offset),
|
||||||
|
WorkerRequest::GenerateStepWithImages {
|
||||||
|
model_id,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
image_token_id,
|
||||||
|
image_data_uris,
|
||||||
|
} => self.handle_generate_step_with_images(
|
||||||
|
&model_id,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
image_token_id,
|
||||||
|
image_data_uris,
|
||||||
|
),
|
||||||
WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
|
WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
|
||||||
WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
|
WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
|
||||||
WorkerRequest::Shutdown => WorkerResponse::Bye,
|
WorkerRequest::Shutdown => WorkerResponse::Bye,
|
||||||
@@ -418,6 +453,124 @@ impl WorkerState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Handles `GenerateStepWithImages` on this rank.
///
/// Preprocesses every data URI with `preprocess_data_uri` under the
/// `PreprocessProfile::qwen3_6()` profile, uploads each result as a
/// `(3, H, W)` tensor on the model's device, builds the token tensor,
/// and runs the model's `forward_with_images`. The forward's output is
/// dropped; the reply only signals success or failure.
///
/// Replies:
/// - `GenerateStepOk` when the forward completes;
/// - `Error { kind: "bad_request" }` for an empty image list or a
///   preprocessing failure;
/// - `Error { kind: "model_not_loaded" }` when `model_id` is unknown;
/// - `Error { kind: "forward_failed" }` when a tensor cannot be built
///   or the forward itself fails.
#[cfg(feature = "cuda")]
fn handle_generate_step_with_images(
    &mut self,
    model_id: &str,
    tokens: Vec<u32>,
    offset: usize,
    image_token_id: u32,
    image_data_uris: Vec<String>,
) -> WorkerResponse {
    use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri};
    use candle_core::Tensor;

    /// Shorthand for building an error reply.
    fn fail(kind: &'static str, message: String) -> WorkerResponse {
        WorkerResponse::Error {
            kind: kind.into(),
            message,
        }
    }

    if image_data_uris.is_empty() {
        return fail(
            "bad_request",
            "GenerateStepWithImages with zero images".into(),
        );
    }
    let Some(model) = self.models.get_mut(model_id) else {
        return fail(
            "model_not_loaded",
            format!("model '{model_id}' not loaded on rank {}", self.config.rank),
        );
    };
    let device = model.device().clone();

    // One profile for every image; its target size fixes the tensor shape.
    let profile = PreprocessProfile::qwen3_6();
    let height = profile.target_height as usize;
    let width = profile.target_width as usize;

    // Decode + upload each image; the first failure becomes the reply.
    let built: Result<Vec<Tensor>, WorkerResponse> = image_data_uris
        .iter()
        .enumerate()
        .map(|(idx, uri)| -> Result<Tensor, WorkerResponse> {
            let px = preprocess_data_uri(uri, &profile)
                .map_err(|e| fail("bad_request", format!("preprocess image[{idx}]: {e:#}")))?;
            Tensor::from_vec(px, (3, height, width), &device)
                .map_err(|e| fail("forward_failed", format!("build image[{idx}] tensor: {e}")))
        })
        .collect();
    let pixels = match built {
        Ok(tensors) => tensors,
        Err(reply) => return reply,
    };

    let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
        Ok(tensor) => tensor,
        Err(e) => return fail("forward_failed", format!("build input tensor: {e}")),
    };

    let start = std::time::Instant::now();
    tracing::debug!(
        rank = self.config.rank,
        model = %model_id,
        tokens = tokens.len(),
        offset,
        images = pixels.len(),
        "worker GenerateStepWithImages: forward starting"
    );

    // Only success/failure matters here; the forward's output is dropped.
    match model.forward_with_images(&input, offset, &pixels, image_token_id) {
        Ok(_) => {
            tracing::debug!(
                rank = self.config.rank,
                model = %model_id,
                elapsed_ms = start.elapsed().as_millis(),
                "worker GenerateStepWithImages: forward done"
            );
            WorkerResponse::GenerateStepOk
        }
        Err(e) => {
            tracing::warn!(
                rank = self.config.rank,
                model = %model_id,
                elapsed_ms = start.elapsed().as_millis(),
                error = %e,
                "worker GenerateStepWithImages: forward failed"
            );
            fail("forward_failed", format!("TP image forward: {e}"))
        }
    }
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn handle_generate_step_with_images(
|
||||||
|
&mut self,
|
||||||
|
_model_id: &str,
|
||||||
|
_tokens: Vec<u32>,
|
||||||
|
_offset: usize,
|
||||||
|
_image_token_id: u32,
|
||||||
|
_image_data_uris: Vec<String>,
|
||||||
|
) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "GenerateStepWithImages requires --features cuda".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse {
|
fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse {
|
||||||
let Some(model) = self.models.get_mut(model_id) else {
|
let Some(model) = self.models.get_mut(model_id) else {
|
||||||
|
|||||||
Reference in New Issue
Block a user