Stage 7a-i: TP worker lifecycle scaffolding
All checks were successful
CI / Format (push) Successful in 36s
build-prerelease / Resolve version stamps (push) Successful in 39s
CI / Clippy (push) Successful in 2m12s
CI / Test (push) Successful in 4m25s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m49s
build-prerelease / Build cortex binary (push) Successful in 4m22s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 5m9s
build-prerelease / Build neuron-ada (push) Successful in 4m59s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m59s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m38s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m8s
Leader → worker process plumbing for tensor parallelism. The neuron
binary picks up two modes: default (the existing daemon, axum + HTTP)
and `--worker` (a bare RPC loop driven over stdin/stdout). The leader
spawns one worker per non-zero NCCL rank via tokio::process::Command
on the same binary path (production: /proc/self/exe; tests:
env!("CARGO_BIN_EXE_neuron")) and talks to each over newline-
delimited JSON.
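The framing described above can be sketched with the standard library alone. This is a minimal synchronous illustration, not the code from this commit (which uses tokio's async I/O and `serde_json`); the function name `exchange` and its generic reader/writer parameters are assumptions made for the example.

```rust
use std::io::{self, BufRead, Write};

/// Send one JSON frame and block until the matching reply frame arrives.
/// `request_json` must be a single-line document; a raw newline inside it
/// would split the frame in two, so it is rejected up front.
pub fn exchange<W: Write, R: BufRead>(
    to_worker: &mut W,
    from_worker: &mut R,
    request_json: &str,
) -> io::Result<String> {
    if request_json.contains('\n') {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "frame must not contain a raw newline",
        ));
    }
    // One frame = payload + '\n', flushed so the peer sees it immediately.
    to_worker.write_all(request_json.as_bytes())?;
    to_worker.write_all(b"\n")?;
    to_worker.flush()?;

    // Exactly one reply line per request; EOF means the worker went away.
    let mut reply = String::new();
    if from_worker.read_line(&mut reply)? == 0 {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "worker closed stdout before replying",
        ));
    }
    // Strip the frame delimiter (and a stray '\r', if any).
    Ok(reply
        .trim_end_matches(|c: char| c == '\n' || c == '\r')
        .to_string())
}
```

Because the reader and writer are generic, the same function can be exercised against in-memory buffers in a test or against a child process's piped stdin/stdout.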
Protocol (harness/tp/rpc.rs) is serde-tagged from the start —
WorkerRequest::{Ping, Init, NcclSanityCheck, Shutdown} and
WorkerResponse::{Pong, InitOk, NcclSanityResult, Bye, Error}, both
`#[serde(tag = "op", rename_all = "snake_case")]`. Adding ops in 7b/7c
is purely additive; unknown ops on the wire fail to parse (verified
in unit tests).
7a-i scope:
- WorkerPool::spawn(binary, world_size, devices) forks ranks 1..N as
subprocesses, captures stdin/stdout, kills on drop.
- ping_all() round-trips a Ping to every worker and validates the
returned rank.
- shutdown() sends Shutdown to each worker, awaits Bye, reaps.
- Worker mode: parse Ping/Shutdown, return Pong/Bye; Init and
NcclSanityCheck return Error{kind="not_implemented_7a_i"} so a 7a-ii
binary speaking the same wire is a drop-in replacement (the kind
field signals "real NCCL lands in the next commit").
- CandleHarness::load_model refuses tensor_parallel > 1 with a clear
message until 7b is in.
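The rank-to-device mapping behind `WorkerPool::spawn` can be sketched as a pure function that only plans the launches. The flag names match this commit; the helper names `worker_argv` and `plan_workers` and the `String` error type are assumptions for illustration, and nothing is actually spawned here.

```rust
/// Build the argument vector for one worker subprocess, using the flag
/// names from this commit.
pub fn worker_argv(rank: u32, world_size: u32, cuda_device: u32) -> Vec<String> {
    vec![
        "--worker".to_string(),
        "--rank".to_string(),
        rank.to_string(),
        "--tp-size".to_string(),
        world_size.to_string(),
        "--cuda-device".to_string(),
        cuda_device.to_string(),
    ]
}

/// Plan the launches for ranks 1..world_size. `cuda_devices` has one entry
/// per rank, rank 0 included, even though rank 0 is never spawned.
pub fn plan_workers(world_size: u32, cuda_devices: &[u32]) -> Result<Vec<Vec<String>>, String> {
    if world_size < 2 {
        return Err(format!(
            "world_size={world_size} needs no worker subprocesses"
        ));
    }
    if cuda_devices.len() != world_size as usize {
        return Err(format!(
            "expected {world_size} cuda_devices entries, got {}",
            cuda_devices.len()
        ));
    }
    // Rank 0 is the in-process leader, so planning starts at rank 1.
    Ok((1..world_size)
        .map(|rank| worker_argv(rank, world_size, cuda_devices[rank as usize]))
        .collect())
}
```

For a world size of 3 this yields two argument vectors, one for rank 1 and one for rank 2, each carrying the device listed at that rank's index.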
Three integration tests in tests/tp_worker_lifecycle.rs cover spawn/
ping/shutdown at world sizes 2 and 3 (one and two spawned workers,
since rank 0 stays in-process), plus the not_implemented_7a_i contract
test for Init. Seven rpc serde unit tests assert the wire shape (op
tags, field names, unknown-op rejection). All pass on the dev host; no
CUDA required.
Stage 7a-ii (next): the real NCCL Comm::from_rank wiring behind the
existing Init/NcclSanityCheck op surface, CUDA-gated. Verifiable on
beast's 2×5090.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -404,6 +404,19 @@ impl Harness for CandleHarness {
             }
         }
 
+        // Stage 7a-i scaffolds tensor-parallel worker subprocesses but
+        // does not yet route inference through them. Refuse TP loads
+        // for now with a clear marker so the request surface is honest.
+        let tp_size = spec.tensor_parallel.unwrap_or(1);
+        if tp_size > 1 {
+            anyhow::bail!(
+                "tensor_parallel={tp_size} requested for '{}': TP worker \
+                 lifecycle is in place (Stage 7a-i) but TP-aware Qwen3 \
+                 inference lands in Stage 7b; single-GPU loads only for now",
+                spec.model_id
+            );
+        }
+
         let devices = spec.devices.clone().unwrap_or_else(|| vec![0]);
         let device = Self::pick_device(&devices)?;
 
@@ -1,6 +1,7 @@
 //! Harness registry — maps harness names to trait implementations.
 
 pub mod candle;
+pub mod tp;
 
 use anyhow::Result;
 use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};
204
crates/neuron/src/harness/tp/mod.rs
Normal file
@@ -0,0 +1,204 @@
+//! Tensor-parallel inference plumbing.
+//!
+//! The leader process (the neuron daemon proper) drives one
+//! subprocess per non-zero NCCL rank — `tokio::process::Command` on
+//! `/proc/self/exe --worker --rank R --tp-size N --cuda-device D` —
+//! and talks to each over a newline-delimited JSON RPC channel on
+//! the worker's stdin/stdout (see `rpc.rs`).
+//!
+//! Sub-staging:
+//!
+//! - **7a-i (this commit):** process lifecycle. `WorkerPool::spawn`
+//!   forks ranks 1..N; `ping_all` round-trips every worker to confirm
+//!   they're alive; `shutdown` cleanly drains and reaps. `Init` /
+//!   `NcclSanityCheck` are stubbed.
+//! - **7a-ii:** real NCCL `Comm` setup via `Init`, sanity check via
+//!   `NcclSanityCheck`. CUDA-gated.
+//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
+//! - **7c:** crash detection, streaming SSE, graceful unload.
+
+pub mod rpc;
+pub mod worker;
+
+use anyhow::{Context, Result};
+use std::path::{Path, PathBuf};
+use std::process::Stdio;
+use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines};
+use tokio::process::{Child, ChildStdin, ChildStdout, Command};
+
+use rpc::{WorkerRequest, WorkerResponse};
+
+/// One worker subprocess plus its bidirectional stdio handles.
+struct Worker {
+    rank: u32,
+    /// Captured so the leader can log "spawned rank N on device M" and
+    /// future stages can re-issue Init after a CUDA reset. Unused in
+    /// the Stage 7a-i RPC paths themselves.
+    #[allow(dead_code)]
+    cuda_device: u32,
+    child: Child,
+    stdin: ChildStdin,
+    stdout: Lines<BufReader<ChildStdout>>,
+}
+
+impl Worker {
+    async fn request(&mut self, req: &WorkerRequest) -> Result<WorkerResponse> {
+        let mut line = serde_json::to_string(req).context("serialise WorkerRequest")?;
+        line.push('\n');
+        self.stdin
+            .write_all(line.as_bytes())
+            .await
+            .with_context(|| format!("write request to rank {}", self.rank))?;
+        self.stdin
+            .flush()
+            .await
+            .with_context(|| format!("flush stdin to rank {}", self.rank))?;
+
+        let reply = self
+            .stdout
+            .next_line()
+            .await
+            .with_context(|| format!("read reply from rank {}", self.rank))?
+            .ok_or_else(|| anyhow::anyhow!("rank {} stdout closed before reply", self.rank))?;
+        serde_json::from_str(&reply)
+            .with_context(|| format!("parse reply from rank {}: {reply:?}", self.rank))
+    }
+}
+
+/// A live pool of worker subprocesses. Owns the `Child` handles so
+/// dropping the pool kills the children; explicit `shutdown()` is
+/// the graceful path.
+pub struct WorkerPool {
+    world_size: u32,
+    workers: Vec<Worker>,
+    /// Path to the neuron binary used to launch workers: the `binary`
+    /// argument to `spawn()` (`/proc/self/exe` in production), so the
+    /// workers run the same binary the leader is running.
+    exe: PathBuf,
+}
+
+impl WorkerPool {
+    /// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
+    /// leader (in-process) and is *not* spawned here — the leader
+    /// holds rank 0's NCCL Comm and shard in its own address space.
+    ///
+    /// `binary` is the path to the neuron executable to run for each
+    /// worker (production passes `/proc/self/exe`; tests pass the
+    /// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`).
+    /// `cuda_devices` is one entry per rank including rank 0. Worker
+    /// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`.
+    pub async fn spawn(binary: &Path, world_size: u32, cuda_devices: &[u32]) -> Result<Self> {
+        if world_size < 2 {
+            anyhow::bail!(
+                "WorkerPool::spawn called with world_size={world_size}; \
+                 use the single-process path for world_size < 2"
+            );
+        }
+        if cuda_devices.len() as u32 != world_size {
+            anyhow::bail!(
+                "expected {world_size} cuda_devices entries, got {}",
+                cuda_devices.len()
+            );
+        }
+        let exe = binary.to_path_buf();
+
+        let mut workers = Vec::with_capacity(world_size as usize - 1);
+        // Rank 0 stays in-process. Spawn ranks 1..world_size.
+        for rank in 1..world_size {
+            let cuda_device = cuda_devices[rank as usize];
+            let mut cmd = Command::new(&exe);
+            cmd.arg("--worker")
+                .arg("--rank")
+                .arg(rank.to_string())
+                .arg("--tp-size")
+                .arg(world_size.to_string())
+                .arg("--cuda-device")
+                .arg(cuda_device.to_string())
+                .stdin(Stdio::piped())
+                .stdout(Stdio::piped())
+                // Inherit stderr so worker tracing surfaces alongside
+                // the leader's journalctl stream.
+                .stderr(Stdio::inherit())
+                .kill_on_drop(true);
+
+            let mut child = cmd
+                .spawn()
+                .with_context(|| format!("spawn worker rank {rank}"))?;
+            let stdin = child
+                .stdin
+                .take()
+                .ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdin handle"))?;
+            let stdout = child
+                .stdout
+                .take()
+                .ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdout handle"))?;
+            let stdout = BufReader::new(stdout).lines();
+
+            workers.push(Worker {
+                rank,
+                cuda_device,
+                child,
+                stdin,
+                stdout,
+            });
+            tracing::info!(rank, cuda_device, "spawned tp worker");
+        }
+
+        Ok(Self {
+            world_size,
+            workers,
+            exe,
+        })
+    }
+
+    /// Ping every worker and return their Pong payloads in rank order.
+    /// Useful right after `spawn` to confirm the lifecycle plumbing is
+    /// intact before kicking off any heavier work.
+    pub async fn ping_all(&mut self) -> Result<Vec<WorkerResponse>> {
+        let mut out = Vec::with_capacity(self.workers.len());
+        for w in &mut self.workers {
+            let resp = w.request(&WorkerRequest::Ping).await?;
+            match &resp {
+                WorkerResponse::Pong { rank, .. } if *rank == w.rank => {}
+                WorkerResponse::Pong { rank, .. } => {
+                    anyhow::bail!("rank mismatch: expected {}, got {rank}", w.rank);
+                }
+                other => anyhow::bail!("expected Pong from rank {}, got {other:?}", w.rank),
+            }
+            out.push(resp);
+        }
+        Ok(out)
+    }
+
+    /// Send `Shutdown` to every worker, await each `Bye`, and reap the
+    /// children. Best-effort — individual worker failures are logged
+    /// but don't abort the rest of the sweep.
+    pub async fn shutdown(mut self) -> Result<()> {
+        for w in &mut self.workers {
+            match w.request(&WorkerRequest::Shutdown).await {
+                Ok(WorkerResponse::Bye) => {}
+                Ok(other) => tracing::warn!(
+                    rank = w.rank,
+                    response = ?other,
+                    "expected Bye on shutdown"
+                ),
+                Err(e) => tracing::warn!(rank = w.rank, error = %e, "shutdown request failed"),
+            }
+        }
+        for w in &mut self.workers {
+            match w.child.wait().await {
+                Ok(status) => tracing::info!(rank = w.rank, %status, "worker exited"),
+                Err(e) => tracing::warn!(rank = w.rank, error = %e, "wait on worker failed"),
+            }
+        }
+        Ok(())
+    }
+
+    pub fn world_size(&self) -> u32 {
+        self.world_size
+    }
+
+    pub fn binary_path(&self) -> &PathBuf {
+        &self.exe
+    }
+}
186
crates/neuron/src/harness/tp/rpc.rs
Normal file
@@ -0,0 +1,186 @@
+//! Wire protocol between the neuron leader process and its
+//! `--worker` subprocesses.
+//!
+//! Every frame is one newline-delimited JSON object on the worker's
+//! stdin (request) or stdout (response). Both directions are tagged
+//! sum types from the start so new ops in Stage 7b/7c slot in without
+//! breaking compatibility — no "14 message types and a version field"
+//! drift later. Adding a new variant is the canonical way to evolve
+//! the protocol; a worker that doesn't recognise an op fails to parse
+//! the frame and replies `WorkerResponse::Error { kind: "bad_request", .. }`.
+//!
+//! The serialised shape uses `tag = "op"` so a request looks like:
+//!     {"op":"ping"}
+//!     {"op":"init","comm_id":"a1b2..."}
+//! and a response:
+//!     {"op":"pong","rank":1,"world_size":2,"cuda_device":1}
+//!     {"op":"error","kind":"nccl_init_failed","message":"..."}
+
+use serde::{Deserialize, Serialize};
+
+/// Leader → worker. Worker handles one at a time; replies with exactly
+/// one `WorkerResponse` per request.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "op", rename_all = "snake_case")]
+pub enum WorkerRequest {
+    /// Liveness probe. Worker replies with `Pong` containing its own
+    /// identity. Used by the leader to confirm the subprocess is up
+    /// and ready before kicking off any heavier work.
+    Ping,
+
+    /// One-shot NCCL communicator setup. The leader generates the
+    /// `comm_id` once (rank 0 of NCCL), broadcasts it to every worker
+    /// via this message, then every rank (leader included) calls
+    /// `Comm::from_rank` with the same id — NCCL blocks until all
+    /// `world_size` ranks check in. The hex-encoded bytes are the
+    /// canonical `cudarc::nccl::Id::internal()` content.
+    Init {
+        /// Hex-encoded NCCL id bytes (128 bytes → 256 hex chars).
+        comm_id: String,
+    },
+
+    /// Sanity check: after Init, every rank runs an `all_reduce` over
+    /// a sentinel value (`1u32`). The expected sum is `world_size`.
+    /// Worker replies with the observed value so the leader can verify
+    /// the NCCL handshake is genuinely live, not just configured.
+    NcclSanityCheck,
+
+    /// Worker should release resources and exit. Worker replies `Bye`
+    /// and then closes stdout / exits zero. The leader reaps the
+    /// child via the `tokio::process::Child` it kept.
+    Shutdown,
+}
+
+/// Worker → leader. Always exactly one of these per `WorkerRequest`.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "op", rename_all = "snake_case")]
+pub enum WorkerResponse {
+    /// Reply to `Ping`. Carries enough identity for the leader to log
+    /// what it actually got back.
+    Pong {
+        rank: u32,
+        world_size: u32,
+        cuda_device: u32,
+    },
+
+    /// Reply to `Init`. Empty payload — success is the absence of
+    /// `Error`. NCCL's internal blocking handshake means by the time
+    /// this comes back, every other rank has also reached
+    /// `Comm::from_rank`.
+    InitOk,
+
+    /// Reply to `NcclSanityCheck`. The observed sum after a single
+    /// `all_reduce(SUM, 1u32)` across all ranks. The leader checks
+    /// this matches `world_size`.
+    NcclSanityResult { observed_sum: u32 },
+
+    /// Reply to `Shutdown`. Worker exits immediately after writing this.
+    Bye,
+
+    /// Any request can produce this instead of its dedicated success
+    /// variant. `kind` is a machine-readable category so the leader
+    /// can branch on failure mode without string-matching `message`.
+    Error {
+        /// Short tag: `nccl_init_failed`, `bad_request`, etc.
+        kind: String,
+        /// Human-readable detail for logs.
+        message: String,
+    },
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn roundtrip<T>(value: &T) -> T
+    where
+        T: Serialize + for<'de> Deserialize<'de>,
+    {
+        serde_json::from_str(&serde_json::to_string(value).expect("serialise"))
+            .expect("deserialise")
+    }
+
+    #[test]
+    fn request_ping_round_trip() {
+        let req = WorkerRequest::Ping;
+        let wire = serde_json::to_string(&req).unwrap();
+        assert_eq!(wire, r#"{"op":"ping"}"#);
+        match roundtrip(&req) {
+            WorkerRequest::Ping => {}
+            other => panic!("expected Ping, got {other:?}"),
+        }
+    }
+
+    #[test]
+    fn request_init_carries_hex_id() {
+        let req = WorkerRequest::Init {
+            comm_id: "deadbeef".into(),
+        };
+        let wire = serde_json::to_string(&req).unwrap();
+        assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
+    }
+
+    #[test]
+    fn request_shutdown_round_trip() {
+        assert_eq!(
+            serde_json::to_string(&WorkerRequest::Shutdown).unwrap(),
+            r#"{"op":"shutdown"}"#
+        );
+    }
+
+    #[test]
+    fn response_pong_round_trip() {
+        let resp = WorkerResponse::Pong {
+            rank: 1,
+            world_size: 4,
+            cuda_device: 1,
+        };
+        let wire = serde_json::to_string(&resp).unwrap();
+        assert!(wire.contains(r#""op":"pong""#));
+        assert!(wire.contains(r#""rank":1"#));
+        assert!(wire.contains(r#""world_size":4"#));
+        match roundtrip(&resp) {
+            WorkerResponse::Pong {
+                rank,
+                world_size,
+                cuda_device,
+            } => {
+                assert_eq!(rank, 1);
+                assert_eq!(world_size, 4);
+                assert_eq!(cuda_device, 1);
+            }
+            other => panic!("expected Pong, got {other:?}"),
+        }
+    }
+
+    #[test]
+    fn response_error_carries_kind_and_message() {
+        let resp = WorkerResponse::Error {
+            kind: "nccl_init_failed".into(),
+            message: "could not bind device".into(),
+        };
+        let wire = serde_json::to_string(&resp).unwrap();
+        assert!(wire.contains(r#""op":"error""#));
+        assert!(wire.contains(r#""kind":"nccl_init_failed""#));
+    }
+
+    #[test]
+    fn response_sanity_result_round_trip() {
+        let resp = WorkerResponse::NcclSanityResult { observed_sum: 4 };
+        match roundtrip(&resp) {
+            WorkerResponse::NcclSanityResult { observed_sum } => {
+                assert_eq!(observed_sum, 4);
+            }
+            other => panic!("expected NcclSanityResult, got {other:?}"),
+        }
+    }
+
+    /// Unknown ops on the wire deserialise to an error rather than
+    /// silently mis-matching — confirms our `serde(tag = "op")`
+    /// configuration rejects unknowns instead of doing fuzzy matching.
+    #[test]
+    fn unknown_op_fails_to_parse() {
+        let result: Result<WorkerRequest, _> = serde_json::from_str(r#"{"op":"explode"}"#);
+        assert!(result.is_err(), "should reject unknown op, got {result:?}");
+    }
+}
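The `comm_id` field above carries 128 id bytes as 256 hex characters. A std-only sketch of that encoding follows; the helper names `hex_encode`, `hex_decode`, and `decode_comm_id` are assumptions for illustration and are not part of this commit, and how the raw id bytes are obtained from cudarc is out of scope here.

```rust
use std::convert::TryFrom;

/// Encode bytes as lowercase hex, two characters per byte.
pub fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(DIGITS[(b >> 4) as usize] as char);
        out.push(DIGITS[(b & 0x0f) as usize] as char);
    }
    out
}

/// Decode a hex string back into bytes, rejecting odd lengths and
/// non-hex characters.
pub fn hex_decode(text: &str) -> Result<Vec<u8>, String> {
    fn nibble(c: u8) -> Result<u8, String> {
        match c {
            b'0'..=b'9' => Ok(c - b'0'),
            b'a'..=b'f' => Ok(c - b'a' + 10),
            b'A'..=b'F' => Ok(c - b'A' + 10),
            _ => Err(format!("invalid hex character {:?}", c as char)),
        }
    }
    let raw = text.as_bytes();
    if raw.len() % 2 != 0 {
        return Err(format!("odd hex length {}", raw.len()));
    }
    raw.chunks(2)
        .map(|pair| -> Result<u8, String> {
            Ok((nibble(pair[0])? << 4) | nibble(pair[1])?)
        })
        .collect()
}

/// Decode a `comm_id` and insist on the 128-byte id size stated in the
/// `Init` doc comment.
pub fn decode_comm_id(comm_id: &str) -> Result<[u8; 128], String> {
    let bytes = hex_decode(comm_id)?;
    <[u8; 128]>::try_from(bytes.as_slice())
        .map_err(|_| format!("expected 128 id bytes, got {}", bytes.len()))
}
```

Validating the length on the worker side would let a malformed id surface as an `Error` reply before any NCCL call is attempted, which is one plausible use of such a helper in 7a-ii.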
105
crates/neuron/src/harness/tp/worker.rs
Normal file
@@ -0,0 +1,105 @@
+//! Entry point for `neuron --worker`.
+//!
+//! Stage 7a-i: bare RPC loop — `Ping` and `Shutdown` work, `Init` and
+//! `NcclSanityCheck` return `Error{kind = "not_implemented_7a_i"}`.
+//! Stage 7a-ii will replace the latter with real `cudarc::nccl` calls
+//! behind the `cuda` feature.
+//!
+//! The worker reads one newline-delimited JSON `WorkerRequest` from
+//! stdin per loop iteration, dispatches synchronously, and writes
+//! exactly one `WorkerResponse` JSON line to stdout. tracing goes to
+//! stderr so it doesn't collide with the RPC stream.
+
+use anyhow::Result;
+use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
+
+use super::rpc::{WorkerRequest, WorkerResponse};
+
+#[derive(Debug, Clone, Copy)]
+pub struct WorkerConfig {
+    pub rank: u32,
+    pub world_size: u32,
+    pub cuda_device: u32,
+}
+
+/// Drive the worker RPC loop until `Shutdown` or EOF on stdin.
+pub async fn run(config: WorkerConfig) -> Result<()> {
+    tracing::info!(
+        rank = config.rank,
+        world_size = config.world_size,
+        cuda_device = config.cuda_device,
+        "tp worker starting"
+    );
+
+    let mut state = WorkerState::new(config);
+    let stdin = tokio::io::stdin();
+    let mut reader = BufReader::new(stdin).lines();
+    let mut stdout = tokio::io::stdout();
+
+    while let Some(line) = reader.next_line().await? {
+        if line.trim().is_empty() {
+            continue;
+        }
+        let req: WorkerRequest = match serde_json::from_str(&line) {
+            Ok(r) => r,
+            Err(e) => {
+                let resp = WorkerResponse::Error {
+                    kind: "bad_request".into(),
+                    message: format!("parse {line:?}: {e}"),
+                };
+                write_response(&mut stdout, &resp).await?;
+                continue;
+            }
+        };
+
+        let resp = state.handle(req).await;
+        let is_bye = matches!(resp, WorkerResponse::Bye);
+        write_response(&mut stdout, &resp).await?;
+        if is_bye {
+            break;
+        }
+    }
+
+    tracing::info!(rank = config.rank, "tp worker exiting");
+    Ok(())
+}
+
+async fn write_response(stdout: &mut tokio::io::Stdout, resp: &WorkerResponse) -> Result<()> {
+    let mut line = serde_json::to_string(resp)?;
+    line.push('\n');
+    stdout.write_all(line.as_bytes()).await?;
+    stdout.flush().await?;
+    Ok(())
+}
+
+/// Per-worker state. In Stage 7a-i this only carries the static
+/// config; 7a-ii adds an `Option<cudarc::nccl::safe::Comm>` populated
+/// by `Init`.
+struct WorkerState {
+    config: WorkerConfig,
+}
+
+impl WorkerState {
+    fn new(config: WorkerConfig) -> Self {
+        Self { config }
+    }
+
+    async fn handle(&mut self, req: WorkerRequest) -> WorkerResponse {
+        match req {
+            WorkerRequest::Ping => WorkerResponse::Pong {
+                rank: self.config.rank,
+                world_size: self.config.world_size,
+                cuda_device: self.config.cuda_device,
+            },
+            WorkerRequest::Init { comm_id: _ } => WorkerResponse::Error {
+                kind: "not_implemented_7a_i".into(),
+                message: "NCCL init lands in Stage 7a-ii (CUDA-gated)".into(),
+            },
+            WorkerRequest::NcclSanityCheck => WorkerResponse::Error {
+                kind: "not_implemented_7a_i".into(),
+                message: "NCCL sanity check lands in Stage 7a-ii (CUDA-gated)".into(),
+            },
+            WorkerRequest::Shutdown => WorkerResponse::Bye,
+        }
+    }
+}
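The request/response contract of the worker loop above (skip blank lines, answer every frame with exactly one flushed line, keep going after a bad frame, stop after `Bye`) can be sketched synchronously with the standard library. This is an illustration under stated assumptions, not the commit's code: `Identity`, `op_of`, `reply_for`, and `serve` are hypothetical names, and `op_of` is a deliberately naive stand-in for `serde_json` that only understands compact frames.

```rust
use std::io::{self, BufRead, Write};

/// Static identity of one worker, mirroring `WorkerConfig` in the diff.
#[derive(Debug, Clone, Copy)]
pub struct Identity {
    pub rank: u32,
    pub world_size: u32,
    pub cuda_device: u32,
}

/// Naive stand-in for `serde_json`: pull the string value of `"op"` out of
/// a compact frame such as `{"op":"ping"}`. Not a general JSON parser.
fn op_of(frame: &str) -> Option<&str> {
    let rest = frame.split("\"op\":\"").nth(1)?;
    rest.split('"').next()
}

/// Produce exactly one reply frame for one request frame.
fn reply_for(frame: &str, id: Identity) -> String {
    match op_of(frame) {
        Some("ping") => format!(
            "{{\"op\":\"pong\",\"rank\":{},\"world_size\":{},\"cuda_device\":{}}}",
            id.rank, id.world_size, id.cuda_device
        ),
        Some("shutdown") => "{\"op\":\"bye\"}".to_string(),
        // Stubbed ops answer with an explicit error instead of a silent no-op.
        Some("init") | Some("nccl_sanity_check") => {
            "{\"op\":\"error\",\"kind\":\"not_implemented_7a_i\",\"message\":\"stub\"}".to_string()
        }
        // Unknown ops and unparseable frames share one error kind.
        _ => "{\"op\":\"error\",\"kind\":\"bad_request\",\"message\":\"unparseable frame\"}"
            .to_string(),
    }
}

/// Read frames until `shutdown` or EOF, writing one flushed reply per frame.
pub fn serve<R: BufRead, W: Write>(input: R, mut output: W, id: Identity) -> io::Result<()> {
    for line in input.lines() {
        let line = line?;
        if line.trim().is_empty() {
            continue; // blank lines are not frames
        }
        let reply = reply_for(&line, id);
        output.write_all(reply.as_bytes())?;
        output.write_all(b"\n")?;
        output.flush()?;
        if reply == "{\"op\":\"bye\"}" {
            break; // `bye` is the last frame this worker writes
        }
    }
    Ok(())
}
```

Keeping the dispatch (`reply_for`) separate from the I/O loop (`serve`) means the per-op behaviour can be checked without spawning a process, while the loop itself only has to guarantee one reply per frame.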
@@ -1,21 +1,52 @@
 use anyhow::Result;
 use clap::Parser;
-use neuron::{api, config::NeuronConfig, discovery, harness::HarnessRegistry, health, startup};
+use neuron::{
+    api,
+    config::NeuronConfig,
+    discovery,
+    harness::{HarnessRegistry, tp},
+    health, startup,
+};
 use std::sync::Arc;
 use std::time::Instant;
 use tokio::sync::RwLock;
 use tracing_subscriber::EnvFilter;
 
+/// Top-level CLI. The same binary runs as either the public neuron
+/// daemon (default) or a tensor-parallel worker subprocess (when
+/// `--worker` is set) spawned by the leader on the same host.
 #[derive(Parser)]
 #[command(name = "neuron")]
 #[command(about = "Per-node daemon for cortex inference clusters")]
 #[command(version)]
 struct Args {
-    /// Port to listen on (overrides config file).
+    /// Run in tensor-parallel worker mode. The leader process spawns
+    /// one of these per non-zero NCCL rank and drives it over
+    /// newline-delimited JSON on stdin/stdout. Worker mode skips
+    /// discovery, the HTTP listener, and the health poller — it's a
+    /// pure RPC loop.
+    #[arg(long, default_value_t = false)]
+    worker: bool,
+
+    /// NCCL rank for worker mode. Ignored when `--worker` is not set.
+    #[arg(long, default_value_t = 0)]
+    rank: u32,
+
+    /// Total NCCL world size for worker mode. Ignored when `--worker`
+    /// is not set.
+    #[arg(long, default_value_t = 1)]
+    tp_size: u32,
+
+    /// CUDA device index for worker mode. Ignored when `--worker` is
+    /// not set.
+    #[arg(long, default_value_t = 0)]
+    cuda_device: u32,
+
+    /// Port to listen on (overrides config file). Daemon mode only.
     #[arg(short, long)]
     port: Option<u16>,
 
-    /// Path to the neuron config file.
+    /// Path to the neuron config file. Daemon mode only.
     #[arg(short, long, default_value = "neuron.toml")]
     config: String,
 }
@@ -23,6 +54,7 @@ struct Args {
 #[tokio::main]
 async fn main() -> Result<()> {
     tracing_subscriber::fmt()
+        .with_writer(std::io::stderr)
         .with_env_filter(
             EnvFilter::try_from_default_env()
                 .unwrap_or_else(|_| EnvFilter::new("info,neuron=debug")),
@@ -31,6 +63,19 @@ async fn main() -> Result<()> {
 
     let args = Args::parse();
 
+    if args.worker {
+        return tp::worker::run(tp::worker::WorkerConfig {
+            rank: args.rank,
+            world_size: args.tp_size,
+            cuda_device: args.cuda_device,
+        })
+        .await;
+    }
+
+    daemon(args).await
+}
+
+async fn daemon(args: Args) -> Result<()> {
     let cfg = NeuronConfig::load(&args.config).unwrap_or_else(|e| {
         tracing::warn!(path = %args.config, error = %e, "config not found, using defaults");
         NeuronConfig::default()
130
crates/neuron/tests/tp_worker_lifecycle.rs
Normal file
@@ -0,0 +1,130 @@
+//! Stage 7a-i: confirm the TP worker subprocess lifecycle round-trips.
+//!
+//! Spawns worker subprocesses via the leader→worker stdio RPC,
+//! pings each, and cleanly shuts them down. No CUDA required —
+//! `Init` and `NcclSanityCheck` are stubbed in 7a-i, so this test
+//! runs on any host the workspace builds on.
+
+use neuron::harness::tp::{WorkerPool, rpc::WorkerResponse};
+
+/// Path to the neuron binary built by cargo for this test process.
+/// cargo populates `CARGO_BIN_EXE_neuron` at compile time for sibling-
+/// binary tests; production paths in main.rs use `/proc/self/exe`.
+const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
+
+/// World size 2 (so we spawn one subprocess: rank 0 is in-process,
+/// rank 1 is the child). Verify the spawned worker responds to Ping
+/// with its own identity, then shut it down cleanly.
+#[tokio::test]
+async fn test_spawn_ping_shutdown() {
+    // cuda_devices: rank 0 → device 0 (leader, unused here),
+    //               rank 1 → device 1 (worker; not actually opened in 7a-i).
+    let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1])
+        .await
+        .expect("spawn worker pool");
+
+    let pongs = pool.ping_all().await.expect("ping all workers");
+    assert_eq!(pongs.len(), 1, "expected one Pong (rank 1 only)");
+    match &pongs[0] {
+        WorkerResponse::Pong {
+            rank,
+            world_size,
+            cuda_device,
+        } => {
+            assert_eq!(*rank, 1);
+            assert_eq!(*world_size, 2);
+            assert_eq!(*cuda_device, 1);
+        }
+        other => panic!("expected Pong, got {other:?}"),
+    }
+
+    pool.shutdown().await.expect("clean shutdown");
+}
+
+/// World size 3 (two subprocesses): exercise the loop in `ping_all` / `shutdown`.
+#[tokio::test]
+async fn test_spawn_three_workers() {
+    let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 3, &[0, 1, 2])
+        .await
+        .expect("spawn worker pool");
+
+    let pongs = pool.ping_all().await.expect("ping all workers");
+    assert_eq!(pongs.len(), 2, "expected two Pongs (ranks 1 and 2)");
+    for (i, resp) in pongs.iter().enumerate() {
+        match resp {
+            WorkerResponse::Pong {
+                rank,
+                world_size,
+                cuda_device,
+            } => {
+                let expected_rank = (i + 1) as u32;
+                assert_eq!(*rank, expected_rank);
+                assert_eq!(*world_size, 3);
+                assert_eq!(*cuda_device, expected_rank);
+            }
+            other => panic!("expected Pong, got {other:?}"),
+        }
+    }
+
+    pool.shutdown().await.expect("clean shutdown");
+}
+
+/// 7a-i's Init/NcclSanityCheck handlers return an error rather than
+/// silently no-op, so the leader can tell the difference between
+/// "haven't implemented yet" and "succeeded vacuously". Confirm the
+/// shape so 7a-ii's replacement is a drop-in (same wire op names).
+#[tokio::test]
+async fn test_init_returns_not_implemented_in_7a_i() {
+    use neuron::harness::tp::rpc::WorkerRequest;
+    use std::process::Stdio;
+    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
+    use tokio::process::Command;
+
+    // Spawn a single worker by hand to send Init directly (the pool's
+    // public API doesn't expose Init yet — that lands in 7a-ii).
+    let mut child = Command::new(NEURON_BIN)
+        .arg("--worker")
+        .arg("--rank")
+        .arg("1")
+        .arg("--tp-size")
+        .arg("2")
+        .arg("--cuda-device")
+        .arg("1")
+        .stdin(Stdio::piped())
+        .stdout(Stdio::piped())
+        .stderr(Stdio::null())
+        .kill_on_drop(true)
+        .spawn()
+        .expect("spawn worker");
+
+    let mut stdin = child.stdin.take().expect("stdin");
+    let stdout = child.stdout.take().expect("stdout");
+    let mut lines = BufReader::new(stdout).lines();
+
+    let req = WorkerRequest::Init {
+        comm_id: "ff".repeat(128),
+    };
+    let mut payload = serde_json::to_string(&req).unwrap();
+    payload.push('\n');
+    stdin.write_all(payload.as_bytes()).await.unwrap();
+    stdin.flush().await.unwrap();
+
+    let reply = lines
+        .next_line()
+        .await
+        .expect("read line")
+        .expect("got line");
+    let resp: WorkerResponse = serde_json::from_str(&reply).expect("parse reply");
+    match resp {
+        WorkerResponse::Error { kind, .. } => {
+            assert_eq!(kind, "not_implemented_7a_i");
+        }
+        other => panic!("expected Error{{kind=not_implemented_7a_i}}, got {other:?}"),
+    }
+
+    // Clean shutdown.
+    stdin.write_all(b"{\"op\":\"shutdown\"}\n").await.unwrap();
+    stdin.flush().await.unwrap();
+    let _ = lines.next_line().await; // Bye
+    let _ = child.wait().await;
+}