feat(neuron): TP-vision Stage 3 — wire TP chat + stream vision prefill
Some checks failed
CI / Format (push) Successful in 30s
CI / Clippy (push) Successful in 2m51s
CI / Test (push) Successful in 5m52s
CI / CUDA type-check (push) Failing after 50s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped

End-to-end TP-vision: an image request to a TP-loaded Qwen3.6-27B now
conditions on the image across both ranks.

- TpLoadedModel carries has_vision / image_token_id / lm_tokens_per_image,
  populated at load via the shared VisionMeta::from_config_path (same
  config.json the shards loaded from; Stage 1 materialises the replicated
  tower on every rank).
- LoadedHandle::capabilities() now advertises "vision" for TP loads with
  a tower (cortex-gateway already unions this into /v1/models via C3).
- The TP rejection guards (chat_completion_tp + inference_tp_stream) are
  now conditional on !has_vision — text-only TP models still 400 cleanly,
  vision-capable ones fall through.
- chat_completion_tp_inner and the streaming orchestration task detect
  images (request_has_images), expand <|image_pad|> to the per-image
  patch count, and run a single-shot generate_step_with_images prefill
  (every rank encodes + splices its replicated tower) before the
  unchanged decode loop. Text requests keep chunked_prefill_tp.
- extract_image_data_uris ships the source data URIs to every rank for
  identical per-rank preprocessing.

prompt_tokens now reflects the patch expansion, so usage accounting and
KV offsets match the single-GPU baseline.

TP entry points are cuda-gated (validated by CI's CUDA type-check);
capabilities() + extract_image_data_uris + VisionMeta reuse compile on
the non-cuda build. Full workspace test green.

Refs TP-vision plan Stage 3. Implements #12.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-04 15:14:44 +03:00
parent 4994b94c84
commit ed2d09864e

View File

@@ -106,18 +106,18 @@ impl LoadedHandle {
} }
} }
/// Modalities the loaded model supports. Stage B7. TP models are /// Modalities the loaded model supports. Stage B7 (single-GPU) +
/// always text-only today — TP-vision is tracked under issue #12. /// TP-vision (#12) — both single-GPU and TP loads advertise
/// `"vision"` when a replicated vision tower materialised.
pub fn capabilities(&self) -> Vec<String> { pub fn capabilities(&self) -> Vec<String> {
let mut caps = vec!["text".to_string()]; let mut caps = vec!["text".to_string()];
match self { let has_vision = match self {
LoadedHandle::Single(m) => { LoadedHandle::Single(m) => m.has_vision,
if m.has_vision {
caps.push("vision".to_string());
}
}
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
LoadedHandle::Tp(_) => {} LoadedHandle::Tp(m) => m.has_vision,
};
if has_vision {
caps.push("vision".to_string());
} }
caps caps
} }
@@ -281,6 +281,16 @@ pub struct TpLoadedModel {
pub tool_call_tokens: Option<ToolCallTokenPair>, pub tool_call_tokens: Option<ToolCallTokenPair>,
/// Same shape as [`LoadedModel::chat_template`]. /// Same shape as [`LoadedModel::chat_template`].
pub chat_template: Option<String>, pub chat_template: Option<String>,
/// Vision capability flag (TP-vision). `true` iff every rank
/// materialised a replicated vision tower. Mirrors
/// [`LoadedModel::has_vision`]; drives capability advertising and
/// the TP vision dispatch.
pub has_vision: bool,
/// `<|image_pad|>` token id — same as [`LoadedModel::image_token_id`].
pub image_token_id: Option<u32>,
/// LM-side tokens per image at the fixed 448×448 resolution — same
/// as [`LoadedModel::lm_tokens_per_image`].
pub lm_tokens_per_image: Option<usize>,
} }
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
@@ -2675,6 +2685,20 @@ impl CandleHarness {
); );
} }
// Vision metadata from the same config.json the shards loaded
// from. The TP model builder (Stage 1) materialises a replicated
// vision tower on every rank when `vision_config` is present, so
// `has_vision` here is consistent with what each rank loaded.
let vision_meta = VisionMeta::from_config_path(&config_path);
if vision_meta.has_vision {
tracing::info!(
model = %spec.model_id,
image_token_id = ?vision_meta.image_token_id,
lm_tokens_per_image = ?vision_meta.lm_tokens_per_image,
"TP load: vision tower present, advertising vision capability"
);
}
let tp_loaded = StdArc::new(TpLoadedModel { let tp_loaded = StdArc::new(TpLoadedModel {
model_id: spec.model_id.clone(), model_id: spec.model_id.clone(),
tokenizer, tokenizer,
@@ -2690,6 +2714,9 @@ impl CandleHarness {
reasoning_tokens, reasoning_tokens,
tool_call_tokens, tool_call_tokens,
chat_template, chat_template,
has_vision: vision_meta.has_vision,
image_token_id: vision_meta.image_token_id,
lm_tokens_per_image: vision_meta.lm_tokens_per_image,
}); });
let mut models = self.models.write().await; let mut models = self.models.write().await;
@@ -2739,15 +2766,15 @@ impl CandleHarness {
return Err(poisoned_error(&model_id)); return Err(poisoned_error(&model_id));
} }
// Stage 0 (TP-vision): the TP path has no vision tower yet, so // Reject image-bearing requests against a TP model with no
// an image-bearing request can't be honoured. Reject it cleanly // vision tower, cleanly (`vision_unsupported`) rather than
// with `vision_unsupported` instead of silently dropping the // silently dropping the image. Vision-capable TP loads fall
// image and answering from text alone (the issue-#3 confident- // through to the image-aware prefill in chat_completion_tp_inner.
// hallucination pattern). Made conditional on the TP model's if request_has_images(&request) && !tp.has_vision {
// `has_vision` once Stage 3 wires real TP-vision.
if request_has_images(&request) {
let _g = span.enter(); let _g = span.enter();
tracing::warn!("TP chat_completion: rejecting image request, TP vision unsupported"); tracing::warn!(
"TP chat_completion: rejecting image request, model has no vision tower"
);
return Err(InferenceError::VisionUnsupported { model_id }); return Err(InferenceError::VisionUnsupported { model_id });
} }
@@ -2828,14 +2855,12 @@ impl CandleHarness {
return Err(poisoned_error(&request.model)); return Err(poisoned_error(&request.model));
} }
// Stage 0 (TP-vision): reject image requests on the TP streaming // Reject image requests against a non-vision TP model before
// path before opening the SSE stream — the TP path has no vision // opening the SSE stream. Vision-capable TP loads fall through
// tower yet, so honouring the image is impossible and silently // to the image-aware prefill in the orchestration task below.
// dropping it would hallucinate. Returns a clean 400; made if request_has_images(&request) && !tp.has_vision {
// conditional on `has_vision` in Stage 3.
if request_has_images(&request) {
tracing::warn!( tracing::warn!(
"TP chat_completion (stream): rejecting image request, TP vision unsupported" "TP chat_completion (stream): rejecting image request, model has no vision tower"
); );
return Err(InferenceError::VisionUnsupported { return Err(InferenceError::VisionUnsupported {
model_id: request.model.clone(), model_id: request.model.clone(),
@@ -2847,7 +2872,44 @@ impl CandleHarness {
.tokenizer .tokenizer
.encode(prompt.as_str(), true) .encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?; .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec(); let mut prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
// TP-vision (streaming): same detection + pad expansion as the
// non-streaming path. The resulting `vision_route` moves into
// the orchestration task, which runs a single-shot image prefill
// when present. Returning early here keeps a rejected request
// from opening the SSE stream.
let vision_route: Option<(Vec<String>, u32)> = if request_has_images(&request) {
if !tp.has_vision {
return Err(InferenceError::VisionUnsupported {
model_id: request.model.clone(),
});
}
let image_token_id =
tp.image_token_id
.ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(),
})?;
let patches_per_image =
tp.lm_tokens_per_image
.ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(),
})?;
let data_uris = extract_image_data_uris(&request);
if data_uris.is_empty() {
return Err(InferenceError::Other(anyhow::anyhow!(
"request has image content but extractor produced zero data URIs"
)));
}
let per_image_counts: Vec<usize> = vec![patches_per_image; data_uris.len()];
prompt_tokens =
expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
.map_err(InferenceError::Other)?;
Some((data_uris, image_token_id))
} else {
None
};
let prompt_len = prompt_tokens.len(); let prompt_len = prompt_tokens.len();
let temperature = request.temperature.unwrap_or(0.7); let temperature = request.temperature.unwrap_or(0.7);
@@ -2961,14 +3023,27 @@ impl CandleHarness {
// chunk fans out to every rank with a growing // chunk fans out to every rank with a growing
// offset; only the final chunk's logits are kept // offset; only the final chunk's logits are kept
// for the first sample. // for the first sample.
let logits_vec = match chunked_prefill_tp( // Vision requests do a single-shot image prefill;
&mut pool, // text requests chunk it. `vision_route` was moved
&model_id, // into this task from the synchronous setup above.
leader_handle, let prefill_result = match &vision_route {
&prompt_tokens, Some((data_uris, image_token_id)) => {
) pool.generate_step_with_images(
.await &model_id,
{ leader_handle,
prompt_tokens.clone(),
0,
*image_token_id,
data_uris.clone(),
)
.await
}
None => {
chunked_prefill_tp(&mut pool, &model_id, leader_handle, &prompt_tokens)
.await
}
};
let logits_vec = match prefill_result {
Ok(l) => l, Ok(l) => l,
Err(e) => { Err(e) => {
failure = Some(format!("prefill: {e:#}")); failure = Some(format!("prefill: {e:#}"));
@@ -3311,7 +3386,43 @@ async fn chat_completion_tp_inner(
.tokenizer .tokenizer
.encode(prompt.as_str(), true) .encode(prompt.as_str(), true)
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?; .map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec(); let mut prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
// TP-vision: when the request carries images (and the model has a
// replicated tower — enforced by the caller's guard), expand each
// `<|image_pad|>` sentinel to the per-image patch count and carry
// the source data URIs through to the single-shot image prefill.
// Mirrors the single-GPU `chat_completion` vision_route block.
let vision_route: Option<(Vec<String>, u32)> = if request_has_images(&request) {
if !tp.has_vision {
return Err(InferenceError::VisionUnsupported {
model_id: request.model.clone(),
});
}
let image_token_id =
tp.image_token_id
.ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(),
})?;
let patches_per_image =
tp.lm_tokens_per_image
.ok_or_else(|| InferenceError::VisionUnsupported {
model_id: request.model.clone(),
})?;
let data_uris = extract_image_data_uris(&request);
if data_uris.is_empty() {
return Err(InferenceError::Other(anyhow::anyhow!(
"request has image content but extractor produced zero data URIs"
)));
}
let per_image_counts: Vec<usize> = vec![patches_per_image; data_uris.len()];
prompt_tokens = expand_image_pad_tokens(&prompt_tokens, image_token_id, &per_image_counts)
.map_err(InferenceError::Other)?;
Some((data_uris, image_token_id))
} else {
None
};
let prompt_len = prompt_tokens.len(); let prompt_len = prompt_tokens.len();
let temperature = request.temperature.unwrap_or(0.7); let temperature = request.temperature.unwrap_or(0.7);
@@ -3381,9 +3492,24 @@ async fn chat_completion_tp_inner(
// spread across multiple `generate_step` calls with monotonically // spread across multiple `generate_step` calls with monotonically
// growing offsets. // growing offsets.
let prefill_start = std::time::Instant::now(); let prefill_start = std::time::Instant::now();
let logits_vec = chunked_prefill_tp(&mut pool, &model_id, leader_handle, &prompt_tokens) // Vision requests do a single-shot image prefill (every rank encodes
.await // + splices its replicated tower); text requests chunk the prefill.
.map_err(InferenceError::Other)?; let logits_vec = match &vision_route {
Some((data_uris, image_token_id)) => pool
.generate_step_with_images(
&model_id,
leader_handle,
prompt_tokens.clone(),
0,
*image_token_id,
data_uris.clone(),
)
.await
.map_err(InferenceError::Other)?,
None => chunked_prefill_tp(&mut pool, &model_id, leader_handle, &prompt_tokens)
.await
.map_err(InferenceError::Other)?,
};
let (post_prefill_vram_free_mb, _) = tp.query_vram().await; let (post_prefill_vram_free_mb, _) = tp.query_vram().await;
tracing::info!( tracing::info!(
model = %model_id, model = %model_id,
@@ -3841,6 +3967,37 @@ fn extract_images_from_request(
Ok(out) Ok(out)
} }
/// Collect the raw `image_url.url` strings (data URIs) from a chat
/// request, in prompt order. The TP vision path (Stage C / TP-vision)
/// ships these verbatim to every rank, which each preprocess + encode
/// identically — so unlike `extract_images_from_request` (which
/// preprocesses on the leader for the single-GPU worker job) this
/// keeps the source form for replicated per-rank encoding.
///
/// Cuda-gated: the only callers are the TP entry points, which compile
/// only under the `cuda` feature.
#[cfg(feature = "cuda")]
fn extract_image_data_uris(request: &ChatCompletionRequest) -> Vec<String> {
let mut out = Vec::new();
for msg in &request.messages {
if let MessageContent::Parts(parts) = &msg.content {
for part in parts {
if part.get("type").and_then(|v| v.as_str()) != Some("image_url") {
continue;
}
if let Some(url) = part
.get("image_url")
.and_then(|v| v.get("url"))
.and_then(|v| v.as_str())
{
out.push(url.to_string());
}
}
}
}
out
}
/// Expand each occurrence of `image_token_id` in `input_ids` into /// Expand each occurrence of `image_token_id` in `input_ids` into
/// `patches_per_image[i]` copies (one expansion per image, in order). /// `patches_per_image[i]` copies (one expansion per image, in order).
/// Stage B4 helper. /// Stage B4 helper.