All checks were successful
CI / Format (push) Successful in 30s
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Clippy (push) Successful in 2m14s
build-prerelease / Build neuron-blackwell (push) Successful in 3m44s
build-prerelease / Build cortex binary (push) Successful in 4m13s
CI / Test (push) Successful in 4m38s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 5m13s
build-prerelease / Build neuron-ada (push) Successful in 4m47s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m1s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m41s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
Two changes addressing operator visibility into TP inference + the
HTTP-cancellation poisoning chain:
1. `chat_completion_tp` now runs its body inside `tokio::spawn`. When
the HTTP client disconnects (curl --max-time, browser nav, etc.)
the future returned from `chat_completion_tp` gets dropped, but
the spawned task keeps running to completion — finishing every
`pool.generate_step` / `pool.clear_kv_cache` to drain the worker
pipes. The next inference request then finds a clean pool.
Previously: dropped future left workers still processing the
in-flight request, the next call's `ClearKvCache` recv would
read the stale `GenerateStepOk` from the abandoned step ("rank N
expected KvCacheCleared, got GenerateStepOk"). The drain-on-
leader-error fix from d1a4aad covered Rust-side leader failures
but not HTTP-layer cancellation, which is what we actually hit
on the user's Qwen3.6 test.
2. Tracing throughout the TP path so journalctl shows where an
inference spends its time without needing to surface harness
internals via the HTTP error body:
- `chat_completion_tp_inner` (now a free fn so it can run inside
spawn): `info` at request start (prompt_len, max_new, temp,
top_p, eos_id), `info` per major phase (prefill complete with
elapsed_ms, decode complete with elapsed_ms + token count),
`info` at completion (total_ms, finish_reason). `debug` for
pool-lock acquisition + kv-cache clear timing. `trace` per
decode step (next_token, step_ms).
- `WorkerPool::generate_step` (leader side): `debug` at fan-out,
`debug` after leader forward returns with elapsed_ms + ok flag,
`debug` after drain with errors count + total_ms.
- `WorkerPool::clear_kv_cache`: matching `debug` at fan-out + drain.
- `worker::handle_generate_step`: `debug` at forward start + done
with elapsed_ms, `warn` on forward failure with the full error.
The default log filter is already `info,neuron=debug` so the
operator gets every `info` and `debug` line by default; `trace`
needs RUST_LOG=trace for per-step decode timing.
Stage 7c-ii crash-detection is still future work; this is the
minimum that makes the "where did the 120s go" question answerable
from the logs.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
774 lines
31 KiB
Rust
774 lines
31 KiB
Rust
//! Tensor-parallel inference plumbing.
|
|
//!
|
|
//! The leader process (the neuron daemon proper) drives one
|
|
//! subprocess per non-zero NCCL rank — `tokio::process::Command` on
|
|
//! `/proc/self/exe --worker --rank N --tp-size N --cuda-device N` —
|
|
//! and talks to each over a newline-delimited JSON RPC channel on
|
|
//! the worker's stdin/stdout (see `rpc.rs`).
|
|
//!
|
|
//! Sub-staging:
|
|
//!
|
|
//! - **7a-i (this commit):** process lifecycle. `WorkerPool::spawn`
|
|
//! forks N workers; `ping` round-trips every worker to confirm
|
|
//! they're alive; `shutdown` cleanly drains and reaps. `Init` /
|
|
//! `NcclSanityCheck` are stubbed.
|
|
//! - **7a-ii:** real NCCL `Comm` setup via `Init`, sanity check via
|
|
//! `NcclSanityCheck`. CUDA-gated.
|
|
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
|
|
//! - **7c:** crash detection, streaming SSE, graceful unload.
|
|
|
|
pub mod all_reduce;
|
|
pub mod nccl_state;
|
|
pub mod rpc;
|
|
pub mod tp_linear;
|
|
pub mod tp_qwen3;
|
|
pub mod tp_qwen3_5;
|
|
pub mod worker;
|
|
|
|
use anyhow::{Context, Result};
|
|
use std::path::{Path, PathBuf};
|
|
use std::process::Stdio;
|
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines};
|
|
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
|
|
|
use rpc::{WorkerRequest, WorkerResponse};
|
|
|
|
/// Leader-side handle for any TP-loaded model. The pool's
|
|
/// `load_dense_shard` dispatches on `config.json#/model_type` to build
|
|
/// the right variant; downstream callers (the harness's
|
|
/// `chat_completion_tp` path, `generate_step`, `clear_kv_cache`,
|
|
/// `unload_model`) all hold this enum and let the variant dispatch
|
|
/// determine the concrete forward.
|
|
///
|
|
/// Variants gated on `cuda` because the underlying TP models hold
|
|
/// `Arc<cudarc::nccl::Comm>` references — irrelevant on CPU builds.
|
|
#[cfg(feature = "cuda")]
|
|
pub enum TpLeaderModel {
|
|
Qwen3(tp_qwen3::TpQwen3ForCausalLM),
|
|
Qwen3_5(tp_qwen3_5::TpQwen3_5ForCausalLM),
|
|
}
|
|
|
|
#[cfg(feature = "cuda")]
|
|
impl TpLeaderModel {
|
|
pub fn forward(
|
|
&mut self,
|
|
input: &candle_core::Tensor,
|
|
offset: usize,
|
|
) -> candle_core::Result<candle_core::Tensor> {
|
|
match self {
|
|
TpLeaderModel::Qwen3(m) => m.forward(input, offset),
|
|
TpLeaderModel::Qwen3_5(m) => m.forward(input, offset),
|
|
}
|
|
}
|
|
|
|
pub fn clear_kv_cache(&mut self) {
|
|
match self {
|
|
TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
|
|
TpLeaderModel::Qwen3_5(m) => m.clear_kv_cache(),
|
|
}
|
|
}
|
|
|
|
pub fn device(&self) -> &candle_core::Device {
|
|
match self {
|
|
TpLeaderModel::Qwen3(m) => m.device(),
|
|
TpLeaderModel::Qwen3_5(m) => m.device(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// One worker subprocess plus its bidirectional stdio handles.
|
|
struct Worker {
|
|
rank: u32,
|
|
/// Captured so the leader can log "spawned rank N on device M" and
|
|
/// future stages can re-issue Init after a CUDA reset. Unused in
|
|
/// the Stage 7a-i RPC paths themselves.
|
|
#[allow(dead_code)]
|
|
cuda_device: u32,
|
|
child: Child,
|
|
stdin: ChildStdin,
|
|
stdout: Lines<BufReader<ChildStdout>>,
|
|
}
|
|
|
|
impl Worker {
|
|
/// Send a request and wait for the response. Used for sequenced
|
|
/// ops like `Ping` / `Shutdown` where the caller doesn't need to
|
|
/// overlap the worker's execution with the leader's.
|
|
async fn request(&mut self, req: &WorkerRequest) -> Result<WorkerResponse> {
|
|
self.send_only(req).await?;
|
|
self.recv_only().await
|
|
}
|
|
|
|
/// Write a request without awaiting its response. Pair with
|
|
/// `recv_only` from the caller when leader and worker need to do
|
|
/// work concurrently — e.g. during `Init`, where the leader
|
|
/// itself calls `Comm::from_rank` on rank 0 in parallel with the
|
|
/// workers, then collects `InitOk` after NCCL completes.
|
|
async fn send_only(&mut self, req: &WorkerRequest) -> Result<()> {
|
|
let mut line = serde_json::to_string(req).context("serialise WorkerRequest")?;
|
|
line.push('\n');
|
|
self.stdin
|
|
.write_all(line.as_bytes())
|
|
.await
|
|
.with_context(|| format!("write request to rank {}", self.rank))?;
|
|
self.stdin
|
|
.flush()
|
|
.await
|
|
.with_context(|| format!("flush stdin to rank {}", self.rank))?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn recv_only(&mut self) -> Result<WorkerResponse> {
|
|
let reply = self
|
|
.stdout
|
|
.next_line()
|
|
.await
|
|
.with_context(|| format!("read reply from rank {}", self.rank))?
|
|
.ok_or_else(|| anyhow::anyhow!("rank {} stdout closed before reply", self.rank))?;
|
|
serde_json::from_str(&reply)
|
|
.with_context(|| format!("parse reply from rank {}: {reply:?}", self.rank))
|
|
}
|
|
}
|
|
|
|
/// Drain one response from every worker, classifying each via the
|
|
/// supplied checker. Always reads from every worker — even if some
|
|
/// fail — so the next call's recv doesn't pick up stale responses
|
|
/// from this one (pipe-poisoning was the cause of the
|
|
/// "ClearKvCache: expected KvCacheCleared, got GenerateStepOk" class
|
|
/// of bugs).
|
|
///
|
|
/// Returns a vector of `rank N: detail` strings for any worker that
|
|
/// errored, expected-mismatched, or failed to respond. Caller decides
|
|
/// how to combine these with the leader's outcome.
|
|
async fn drain_workers(
|
|
workers: &mut [Worker],
|
|
mut check: impl FnMut(WorkerResponse) -> std::result::Result<(), String>,
|
|
) -> Vec<String> {
|
|
let mut errs = Vec::new();
|
|
for w in workers {
|
|
match w.recv_only().await {
|
|
Ok(resp) => {
|
|
if let Err(detail) = check(resp) {
|
|
errs.push(format!("rank {} {detail}", w.rank));
|
|
}
|
|
}
|
|
Err(e) => errs.push(format!("rank {} recv: {e:#}", w.rank)),
|
|
}
|
|
}
|
|
errs
|
|
}
|
|
|
|
/// Combine a leader's `Result<Result<T>>` (the typical
|
|
/// `spawn_blocking → JoinHandle<Result<T>>` shape) with the worker
|
|
/// drain results into a single `Result<T>`. Leader failures take
|
|
/// precedence in the error message but worker errors get appended so
|
|
/// the operator sees both halves.
|
|
#[cfg(feature = "cuda")]
|
|
fn combine_leader_workers<T>(
|
|
leader: Result<Result<T>>,
|
|
worker_errors: Vec<String>,
|
|
op: &str,
|
|
) -> Result<T> {
|
|
match leader {
|
|
Ok(Ok(value)) => {
|
|
if worker_errors.is_empty() {
|
|
Ok(value)
|
|
} else {
|
|
anyhow::bail!(
|
|
"{op}: leader succeeded but workers failed: {}",
|
|
worker_errors.join("; ")
|
|
)
|
|
}
|
|
}
|
|
Ok(Err(e)) => {
|
|
if worker_errors.is_empty() {
|
|
Err(e.context(format!("{op}: leader forward failed")))
|
|
} else {
|
|
Err(e.context(format!(
|
|
"{op}: leader forward failed and workers also failed: {}",
|
|
worker_errors.join("; ")
|
|
)))
|
|
}
|
|
}
|
|
Err(panic_err) => {
|
|
if worker_errors.is_empty() {
|
|
Err(panic_err)
|
|
} else {
|
|
Err(panic_err.context(format!(
|
|
"{op}: leader task panicked and workers failed: {}",
|
|
worker_errors.join("; ")
|
|
)))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A live pool of worker subprocesses. Owns the `Child` handles so
|
|
/// dropping the pool kills the children; explicit `shutdown()` is
|
|
/// the graceful path.
|
|
pub struct WorkerPool {
|
|
world_size: u32,
|
|
workers: Vec<Worker>,
|
|
/// Path to the neuron binary used to launch workers.
|
|
#[allow(dead_code)]
|
|
exe: PathBuf,
|
|
/// Leader's own NCCL rank-0 state. Defaults to empty; populated by
|
|
/// `init_nccl()`. Held here so the leader can participate in
|
|
/// collectives (rank 0) without spawning a fourth subprocess.
|
|
leader_nccl: nccl_state::NcclState,
|
|
}
|
|
|
|
impl WorkerPool {
|
|
/// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
|
|
/// leader (in-process) and is *not* spawned here — the leader
|
|
/// holds rank 0's NCCL Comm and shard in its own address space.
|
|
///
|
|
/// `binary` is the path to the neuron executable to run for each
|
|
/// worker (production passes `/proc/self/exe`; tests pass the
|
|
/// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`).
|
|
/// `cuda_devices` is one entry per rank including rank 0. Worker
|
|
/// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`.
|
|
pub async fn spawn(binary: &Path, world_size: u32, cuda_devices: &[u32]) -> Result<Self> {
|
|
if world_size < 2 {
|
|
anyhow::bail!(
|
|
"WorkerPool::spawn called with world_size={world_size}; \
|
|
use the single-process path for world_size < 2"
|
|
);
|
|
}
|
|
if cuda_devices.len() as u32 != world_size {
|
|
anyhow::bail!(
|
|
"expected {world_size} cuda_devices entries, got {}",
|
|
cuda_devices.len()
|
|
);
|
|
}
|
|
let exe = binary.to_path_buf();
|
|
|
|
let mut workers = Vec::with_capacity(world_size as usize - 1);
|
|
// Rank 0 stays in-process. Spawn ranks 1..world_size.
|
|
for rank in 1..world_size {
|
|
let cuda_device = cuda_devices[rank as usize];
|
|
let mut cmd = Command::new(&exe);
|
|
cmd.arg("--worker")
|
|
.arg("--rank")
|
|
.arg(rank.to_string())
|
|
.arg("--tp-size")
|
|
.arg(world_size.to_string())
|
|
.arg("--cuda-device")
|
|
.arg(cuda_device.to_string())
|
|
.stdin(Stdio::piped())
|
|
.stdout(Stdio::piped())
|
|
// Inherit stderr so worker tracing surfaces alongside
|
|
// the leader's journalctl stream.
|
|
.stderr(Stdio::inherit())
|
|
.kill_on_drop(true);
|
|
|
|
let mut child = cmd
|
|
.spawn()
|
|
.with_context(|| format!("spawn worker rank {rank}"))?;
|
|
let stdin = child
|
|
.stdin
|
|
.take()
|
|
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdin handle"))?;
|
|
let stdout = child
|
|
.stdout
|
|
.take()
|
|
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdout handle"))?;
|
|
let stdout = BufReader::new(stdout).lines();
|
|
|
|
workers.push(Worker {
|
|
rank,
|
|
cuda_device,
|
|
child,
|
|
stdin,
|
|
stdout,
|
|
});
|
|
tracing::info!(rank, cuda_device, "spawned tp worker");
|
|
}
|
|
|
|
Ok(Self {
|
|
world_size,
|
|
workers,
|
|
exe,
|
|
leader_nccl: nccl_state::NcclState::new(),
|
|
})
|
|
}
|
|
|
|
/// Establish the NCCL communicator across the leader (rank 0) and
|
|
/// every worker subprocess. Rendezvous is via a freshly-generated
|
|
/// `Id` broadcast over the RPC stream; the actual handshake blocks
|
|
/// inside `Comm::from_rank` until all `world_size` ranks check in.
|
|
///
|
|
/// `leader_cuda_device` is the CUDA device the leader binds rank 0
|
|
/// to — typically the first entry of the `cuda_devices` slice
|
|
/// originally passed to `spawn()`.
|
|
///
|
|
/// On the non-cuda build this immediately fails because the leader
|
|
/// can't generate an `Id` without libnccl. The same call works in
|
|
/// the worker path (returning a no-cuda error response) so the
|
|
/// failure surface is uniform.
|
|
pub async fn init_nccl(&mut self, leader_cuda_device: u32) -> Result<()> {
|
|
let comm_id = nccl_state::generate_comm_id_hex()
|
|
.map_err(|m| anyhow::anyhow!("generate NCCL id: {m}"))?;
|
|
|
|
// 1. Write Init to every worker's stdin without awaiting the
|
|
// response. Workers will parse and call Comm::from_rank
|
|
// concurrently with the leader below.
|
|
for w in &mut self.workers {
|
|
let req = WorkerRequest::Init {
|
|
comm_id: comm_id.clone(),
|
|
};
|
|
w.send_only(&req).await?;
|
|
}
|
|
|
|
// 2. Leader rank 0 calls Comm::from_rank on its own device.
|
|
// Runs on spawn_blocking because NCCL's init blocks until
|
|
// every rank has called in — that's exactly the workers
|
|
// above. The leader's NcclState is moved through the
|
|
// blocking task and returned to the pool.
|
|
let leader_cfg = worker::WorkerConfig {
|
|
rank: 0,
|
|
world_size: self.world_size,
|
|
cuda_device: leader_cuda_device,
|
|
};
|
|
let comm_id_for_leader = comm_id.clone();
|
|
// Swap out the leader's NcclState into a fresh empty one so we
|
|
// can move it into spawn_blocking; restore after the task
|
|
// returns. (NcclState isn't Clone — it owns a real NCCL Comm.)
|
|
let mut leader_state = std::mem::take(&mut self.leader_nccl);
|
|
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
|
|
let resp = leader_state.init(leader_cfg, &comm_id_for_leader);
|
|
(leader_state, resp)
|
|
})
|
|
.await
|
|
.context("leader NCCL init task panicked")?;
|
|
self.leader_nccl = returned_state;
|
|
match leader_resp {
|
|
rpc::WorkerResponse::InitOk => {}
|
|
rpc::WorkerResponse::Error { kind, message } => {
|
|
anyhow::bail!("leader rank 0 init failed [{kind}]: {message}");
|
|
}
|
|
other => anyhow::bail!("leader rank 0 init: unexpected {other:?}"),
|
|
}
|
|
|
|
// 3. Read InitOk from each worker. By now every worker has
|
|
// completed its Comm::from_rank call (NCCL released them
|
|
// when the leader joined the handshake) and is writing its
|
|
// response.
|
|
for w in &mut self.workers {
|
|
let resp = w.recv_only().await?;
|
|
match &resp {
|
|
rpc::WorkerResponse::InitOk => {}
|
|
rpc::WorkerResponse::Error { kind, message } => {
|
|
anyhow::bail!("worker rank {} init failed [{kind}]: {message}", w.rank);
|
|
}
|
|
other => anyhow::bail!(
|
|
"worker rank {} init: expected InitOk, got {other:?}",
|
|
w.rank
|
|
),
|
|
}
|
|
}
|
|
tracing::info!(
|
|
world_size = self.world_size,
|
|
"NCCL communicator established across all ranks"
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
/// Validate the NCCL communicator: every rank `all_reduce`s a
|
|
/// sentinel `1u32` with `ReduceOp::Sum`; the expected total is
|
|
/// `world_size`. Confirms the handshake is live, not just
|
|
/// configured.
|
|
///
|
|
/// Must be called after `init_nccl()`; before that the leader has
|
|
/// no Comm and the workers reply with `nccl_not_initialised`.
|
|
pub async fn nccl_sanity_check(&mut self) -> Result<()> {
|
|
// 1. Trigger the all_reduce on every worker (write-only).
|
|
for w in &mut self.workers {
|
|
w.send_only(&WorkerRequest::NcclSanityCheck).await?;
|
|
}
|
|
|
|
// 2. Leader's own all_reduce, in spawn_blocking. NCCL operations
|
|
// block until every rank participates.
|
|
let mut leader_state = std::mem::take(&mut self.leader_nccl);
|
|
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
|
|
let resp = leader_state.sanity_check();
|
|
(leader_state, resp)
|
|
})
|
|
.await
|
|
.context("leader NCCL sanity task panicked")?;
|
|
self.leader_nccl = returned_state;
|
|
|
|
let expected = self.world_size;
|
|
let leader_sum = match leader_resp {
|
|
rpc::WorkerResponse::NcclSanityResult { observed_sum } => observed_sum,
|
|
rpc::WorkerResponse::Error { kind, message } => {
|
|
anyhow::bail!("leader rank 0 sanity failed [{kind}]: {message}");
|
|
}
|
|
other => anyhow::bail!("leader rank 0 sanity: unexpected {other:?}"),
|
|
};
|
|
if leader_sum != expected {
|
|
anyhow::bail!("leader observed_sum={leader_sum}, expected {expected}");
|
|
}
|
|
|
|
// 3. Read sanity result from each worker. All must match
|
|
// world_size — anything else means the collective didn't
|
|
// complete consistently across ranks.
|
|
for w in &mut self.workers {
|
|
let resp = w.recv_only().await?;
|
|
match resp {
|
|
rpc::WorkerResponse::NcclSanityResult { observed_sum }
|
|
if observed_sum == expected => {}
|
|
rpc::WorkerResponse::NcclSanityResult { observed_sum } => {
|
|
anyhow::bail!(
|
|
"worker rank {} observed_sum={observed_sum}, expected {expected}",
|
|
w.rank
|
|
);
|
|
}
|
|
rpc::WorkerResponse::Error { kind, message } => {
|
|
anyhow::bail!("worker rank {} sanity failed [{kind}]: {message}", w.rank);
|
|
}
|
|
other => anyhow::bail!("worker rank {} sanity: unexpected {other:?}", w.rank),
|
|
}
|
|
}
|
|
tracing::info!(
|
|
world_size = expected,
|
|
"NCCL sanity check OK across all ranks"
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
/// Ping every worker and return their Pong payloads in rank order.
|
|
/// Useful right after `spawn` to confirm the lifecycle plumbing is
|
|
/// intact before kicking off any heavier work.
|
|
pub async fn ping_all(&mut self) -> Result<Vec<WorkerResponse>> {
|
|
let mut out = Vec::with_capacity(self.workers.len());
|
|
for w in &mut self.workers {
|
|
let resp = w.request(&WorkerRequest::Ping).await?;
|
|
match &resp {
|
|
WorkerResponse::Pong { rank, .. } if *rank == w.rank => {}
|
|
WorkerResponse::Pong { rank, .. } => {
|
|
anyhow::bail!("rank mismatch: expected {}, got {rank}", w.rank);
|
|
}
|
|
other => anyhow::bail!("expected Pong from rank {}, got {other:?}", w.rank),
|
|
}
|
|
out.push(resp);
|
|
}
|
|
Ok(out)
|
|
}
|
|
|
|
/// Load this rank's shard of a dense Qwen3 model on every rank.
|
|
///
|
|
/// The leader builds rank 0's `TpQwen3ForCausalLM` directly into
|
|
/// the returned `Arc<Mutex<_>>` — workers build their rank-local
|
|
/// shards in their own address spaces and confirm via
|
|
/// `LoadDenseShardOk`. All ranks see the same `safetensors_paths`;
|
|
/// `ShardedVarBuilder` slices each tensor by rank at materialisation
|
|
/// time, so the per-rank VRAM footprint is roughly `1/world_size`
|
|
/// of the full model (plus the replicated embedding/norm/lm_head).
|
|
///
|
|
/// `leader_device` is the candle `Device` the leader's shard lives
|
|
/// on — typically `Device::new_cuda(leader_cuda_device)` matching
|
|
/// the same index passed to `init_nccl`. `dtype` is the on-device
|
|
/// element type; bf16 is the canonical Qwen3 distribution dtype.
|
|
///
|
|
/// `init_nccl` must have completed first. Bails if the leader's
|
|
/// NCCL comm isn't set up yet.
|
|
#[cfg(feature = "cuda")]
|
|
pub async fn load_dense_shard(
|
|
&mut self,
|
|
model_id: &str,
|
|
config_json: &str,
|
|
safetensors_paths: &[std::path::PathBuf],
|
|
leader_device: &candle_core::Device,
|
|
dtype: candle_core::DType,
|
|
) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
|
|
use candle_nn::var_builder::ShardedSafeTensors;
|
|
use std::sync::Arc;
|
|
use tokio::sync::Mutex;
|
|
|
|
// Wrap the comm in SendComm immediately so it stays Send across
|
|
// the await points in this method — bare Arc<Comm> would
|
|
// poison the async fn's Send bound (Comm's raw NCCL pointer is
|
|
// !Send). The wrapper's safety contract is satisfied by the
|
|
// pool's outer Mutex serialising callers + the spawn_blocking
|
|
// thread being the only place ops are issued.
|
|
let leader_comm =
|
|
nccl_state::SendComm(self.leader_nccl.comm().ok_or_else(|| {
|
|
anyhow::anyhow!("leader NCCL not initialised; call init_nccl first")
|
|
})?);
|
|
let world_size = self.world_size;
|
|
let safetensors_str: Vec<String> = safetensors_paths
|
|
.iter()
|
|
.map(|p| p.to_string_lossy().into_owned())
|
|
.collect();
|
|
|
|
// 1. Fan out the LoadDenseShard request to every worker without
|
|
// awaiting their replies — they'll build their shards in
|
|
// parallel with the leader below.
|
|
for w in &mut self.workers {
|
|
w.send_only(&WorkerRequest::LoadDenseShard {
|
|
model_id: model_id.to_string(),
|
|
config_json: config_json.to_string(),
|
|
safetensors_paths: safetensors_str.clone(),
|
|
})
|
|
.await?;
|
|
}
|
|
|
|
// 2. Build rank 0's shard on the leader. Dispatch on model_type
|
|
// — for `qwen3` we build a `TpQwen3ForCausalLM`, for
|
|
// `qwen3_5` (Qwen3-Next, Qwen3.6's architecture) we build
|
|
// `TpQwen3_5ForCausalLM`. Both end up wrapped in the
|
|
// `TpLeaderModel` enum so downstream callers don't care.
|
|
let model_type = serde_json::from_str::<serde_json::Value>(config_json)
|
|
.ok()
|
|
.as_ref()
|
|
.and_then(|v| v.get("model_type"))
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("")
|
|
.to_string();
|
|
let paths_for_leader: Vec<std::path::PathBuf> = safetensors_paths.to_vec();
|
|
let device_for_leader = leader_device.clone();
|
|
let comm_for_leader = leader_comm;
|
|
let model_id_for_log = model_id.to_string();
|
|
let config_json_for_leader = config_json.to_string();
|
|
|
|
let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
|
|
// SAFETY: same invariant as the single-GPU dense path —
|
|
// the HF cache files are treated as immutable while the
|
|
// mmap is held.
|
|
let vb = unsafe {
|
|
ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
|
|
.context("build ShardedVarBuilder over safetensors")?
|
|
};
|
|
let comm = comm_for_leader.into_inner();
|
|
let loaded = match model_type.as_str() {
|
|
"qwen3" => {
|
|
let cfg: super::tp::tp_qwen3::Config = serde_json::from_str(&config_json_for_leader)
|
|
.context("parse Qwen3 Config JSON for leader load")?;
|
|
TpLeaderModel::Qwen3(super::tp::tp_qwen3::TpQwen3ForCausalLM::load(
|
|
&cfg, &vb, 0, world_size, comm,
|
|
)?)
|
|
}
|
|
"qwen3_5" => {
|
|
let cfg: super::tp::tp_qwen3_5::Config =
|
|
serde_json::from_str(&config_json_for_leader)
|
|
.context("parse Qwen3-Next Config JSON for leader load")?;
|
|
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
|
|
cfg, &vb, 0, world_size, comm,
|
|
)?)
|
|
}
|
|
other => anyhow::bail!(
|
|
"TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)"
|
|
),
|
|
};
|
|
tracing::info!(rank = 0, model = %model_id_for_log, model_type = %model_type, "loaded TP shard (leader)");
|
|
Ok(loaded)
|
|
})
|
|
.await
|
|
.context("leader load task panicked")??;
|
|
|
|
// 3. Collect worker confirmations. Anything other than
|
|
// LoadDenseShardOk aborts the whole load — the leader's
|
|
// already-loaded shard drops when this fn returns Err.
|
|
for w in &mut self.workers {
|
|
let resp = w.recv_only().await?;
|
|
match resp {
|
|
WorkerResponse::LoadDenseShardOk => {}
|
|
WorkerResponse::Error { kind, message } => {
|
|
anyhow::bail!("worker rank {} LoadDenseShard [{kind}]: {message}", w.rank)
|
|
}
|
|
other => anyhow::bail!(
|
|
"worker rank {} LoadDenseShard: expected LoadDenseShardOk, got {other:?}",
|
|
w.rank
|
|
),
|
|
}
|
|
}
|
|
|
|
Ok(Arc::new(Mutex::new(leader_model)))
|
|
}
|
|
|
|
/// Run one forward step across every rank. The leader's forward
|
|
/// returns the last-position logits as a candle Tensor on the
|
|
/// leader's device; the caller does sampling out-of-band. Workers
|
|
/// run their own forwards (the AllReduce inside row-parallel layers
|
|
/// is what lets the leader's collective complete) and reply with
|
|
/// `GenerateStepOk` — they do not ship logits over the wire.
|
|
///
|
|
/// `tokens` is the input for this step (prompt for prefill, the
|
|
/// previously-sampled token for decode). `offset` is the KV-cache
|
|
/// position before this step.
|
|
#[cfg(feature = "cuda")]
|
|
pub async fn generate_step(
|
|
&mut self,
|
|
model_id: &str,
|
|
leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
|
|
tokens: Vec<u32>,
|
|
offset: usize,
|
|
) -> Result<candle_core::Tensor> {
|
|
let step_start = std::time::Instant::now();
|
|
let tokens_len = tokens.len();
|
|
tracing::debug!(
|
|
model = %model_id,
|
|
tokens = tokens_len,
|
|
offset,
|
|
"WorkerPool::generate_step: fan-out"
|
|
);
|
|
// 1. Fan-out to workers.
|
|
for w in &mut self.workers {
|
|
w.send_only(&WorkerRequest::GenerateStep {
|
|
model_id: model_id.to_string(),
|
|
tokens: tokens.clone(),
|
|
offset,
|
|
})
|
|
.await?;
|
|
}
|
|
|
|
// 2. Leader's forward in spawn_blocking. The AllReduce CustomOps
|
|
// inside the row-parallel layers block until every worker's
|
|
// forward issues the matching collective.
|
|
let leader_start = std::time::Instant::now();
|
|
let leader_result = tokio::task::spawn_blocking(move || -> Result<candle_core::Tensor> {
|
|
let mut model = leader_model.blocking_lock();
|
|
let device = model.device().clone();
|
|
let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
|
// ForCausalLM::forward returns [B, 1, V] — squeeze both
|
|
// leading dims to the rank-1 vocab logits the sampler wants.
|
|
let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?;
|
|
Ok(logits)
|
|
})
|
|
.await
|
|
.context("leader forward task panicked");
|
|
let leader_ok = matches!(leader_result, Ok(Ok(_)));
|
|
tracing::debug!(
|
|
model = %model_id,
|
|
tokens = tokens_len,
|
|
leader_ms = leader_start.elapsed().as_millis(),
|
|
leader_ok,
|
|
"WorkerPool::generate_step: leader forward returned"
|
|
);
|
|
|
|
// 3. ALWAYS drain worker responses, regardless of whether the
|
|
// leader succeeded. Skipping this on the leader's error
|
|
// path leaves stale GenerateStepOk replies in the worker
|
|
// pipes that poison the NEXT request's recv (was seeing
|
|
// "ClearKvCache: expected KvCacheCleared, got
|
|
// GenerateStepOk" the call after any forward-time failure).
|
|
let drain_start = std::time::Instant::now();
|
|
let worker_errors = drain_workers(&mut self.workers, |r| match r {
|
|
WorkerResponse::GenerateStepOk => Ok(()),
|
|
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
|
|
other => Err(format!("expected GenerateStepOk, got {other:?}")),
|
|
})
|
|
.await;
|
|
tracing::debug!(
|
|
model = %model_id,
|
|
drain_ms = drain_start.elapsed().as_millis(),
|
|
errors = worker_errors.len(),
|
|
total_ms = step_start.elapsed().as_millis(),
|
|
"WorkerPool::generate_step: workers drained"
|
|
);
|
|
|
|
combine_leader_workers(leader_result, worker_errors, "GenerateStep")
|
|
}
|
|
|
|
/// Reset the KV cache for `model_id` on every rank. Called at the
|
|
/// start of every inference so a fresh request doesn't attend over
|
|
/// the previous one's tokens.
|
|
pub async fn clear_kv_cache(
|
|
&mut self,
|
|
model_id: &str,
|
|
#[cfg(feature = "cuda")] leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
|
|
) -> Result<()> {
|
|
let start = std::time::Instant::now();
|
|
tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out");
|
|
for w in &mut self.workers {
|
|
w.send_only(&WorkerRequest::ClearKvCache {
|
|
model_id: model_id.to_string(),
|
|
})
|
|
.await?;
|
|
}
|
|
#[cfg(feature = "cuda")]
|
|
{
|
|
let mut m = leader_model.lock().await;
|
|
m.clear_kv_cache();
|
|
}
|
|
// Drain workers — same rationale as `generate_step`. The
|
|
// leader's clear_kv_cache is in-process and infallible, but we
|
|
// still always drain so an error on one worker doesn't leave
|
|
// pending responses for the others.
|
|
let worker_errors = drain_workers(&mut self.workers, |r| match r {
|
|
WorkerResponse::KvCacheCleared => Ok(()),
|
|
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
|
|
other => Err(format!("expected KvCacheCleared, got {other:?}")),
|
|
})
|
|
.await;
|
|
tracing::debug!(
|
|
model = %model_id,
|
|
elapsed_ms = start.elapsed().as_millis(),
|
|
errors = worker_errors.len(),
|
|
"WorkerPool::clear_kv_cache: workers drained"
|
|
);
|
|
if !worker_errors.is_empty() {
|
|
anyhow::bail!("ClearKvCache: {}", worker_errors.join("; "));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Drop this model's shards on every rank. The leader's shard is
|
|
/// expected to have been dropped by the caller (its `Arc` was held
|
|
/// in the TpLoadedModel and goes away when that's removed).
|
|
pub async fn unload_model(&mut self, model_id: &str) -> Result<()> {
|
|
for w in &mut self.workers {
|
|
w.send_only(&WorkerRequest::UnloadModel {
|
|
model_id: model_id.to_string(),
|
|
})
|
|
.await?;
|
|
}
|
|
for w in &mut self.workers {
|
|
let resp = w.recv_only().await?;
|
|
match resp {
|
|
WorkerResponse::Unloaded => {}
|
|
WorkerResponse::Error { kind, message } => {
|
|
anyhow::bail!("worker rank {} UnloadModel [{kind}]: {message}", w.rank)
|
|
}
|
|
other => anyhow::bail!(
|
|
"worker rank {} UnloadModel: expected Unloaded, got {other:?}",
|
|
w.rank
|
|
),
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Send `Shutdown` to every worker, await each `Bye`, and reap the
|
|
/// children. Best-effort — individual worker failures are logged
|
|
/// but don't abort the rest of the sweep.
|
|
pub async fn shutdown(mut self) -> Result<()> {
|
|
for w in &mut self.workers {
|
|
match w.request(&WorkerRequest::Shutdown).await {
|
|
Ok(WorkerResponse::Bye) => {}
|
|
Ok(other) => tracing::warn!(
|
|
rank = w.rank,
|
|
response = ?other,
|
|
"expected Bye on shutdown"
|
|
),
|
|
Err(e) => tracing::warn!(rank = w.rank, error = %e, "shutdown request failed"),
|
|
}
|
|
}
|
|
for w in &mut self.workers {
|
|
match w.child.wait().await {
|
|
Ok(status) => tracing::info!(rank = w.rank, %status, "worker exited"),
|
|
Err(e) => tracing::warn!(rank = w.rank, error = %e, "wait on worker failed"),
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub fn world_size(&self) -> u32 {
|
|
self.world_size
|
|
}
|
|
|
|
pub fn binary_path(&self) -> &PathBuf {
|
|
&self.exe
|
|
}
|
|
}
|