Stage 7a-i: TP worker lifecycle scaffolding
All checks were successful
CI / Format (push) Successful in 36s
build-prerelease / Resolve version stamps (push) Successful in 39s
CI / Clippy (push) Successful in 2m12s
CI / Test (push) Successful in 4m25s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m49s
build-prerelease / Build cortex binary (push) Successful in 4m22s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 5m9s
build-prerelease / Build neuron-ada (push) Successful in 4m59s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m59s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m38s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m8s

Leader → worker process plumbing for tensor parallelism. The neuron
binary picks up two modes: default (the existing daemon, axum + HTTP)
and `--worker` (a bare RPC loop driven over stdin/stdout). The leader
spawns one worker per non-zero NCCL rank via tokio::process::Command
on the same binary path (production: /proc/self/exe; tests:
env!("CARGO_BIN_EXE_neuron")) and talks to each over newline-
delimited JSON.

Protocol (harness/tp/rpc.rs) is serde-tagged from the start —
WorkerRequest::{Ping, Init, NcclSanityCheck, Shutdown} and
WorkerResponse::{Pong, InitOk, NcclSanityResult, Bye, Error}, both
`#[serde(tag = "op", rename_all = "snake_case")]`. Adding ops in 7b/7c
is purely additive; unknown ops on the wire fail to parse (verified
in unit tests).

7a-i scope:
- WorkerPool::spawn(binary, world_size, devices) forks ranks 1..N as
  subprocesses, captures stdin/stdout, kills on drop.
- ping_all() round-trips a Ping to every worker and validates the
  returned rank.
- shutdown() sends Shutdown to each worker, awaits Bye, reaps.
- Worker mode: parse Ping/Shutdown, return Pong/Bye; Init and
  NcclSanityCheck return Error{kind="not_implemented_7a_i"} so a 7a-ii
  binary speaking the same wire is a drop-in replacement (the kind
  field signals "real NCCL lands in the next commit").
- CandleHarness::load_model refuses tensor_parallel > 1 with a clear
  message until 7b is in.

Three integration tests in tests/tp_worker_lifecycle.rs cover spawn/
ping/shutdown for 2- and 3-worker pools, plus the
not_implemented_7a_i contract test for Init. Seven rpc serde unit
tests assert the wire shape (op tags, field names, unknown-op
rejection). All pass on the dev host; no CUDA required.

Stage 7a-ii (next): the real NCCL Comm::from_rank wiring behind the
existing Init/NcclSanityCheck op surface, CUDA-gated. Verifiable on
beast's 2×5090.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-19 15:53:00 +03:00
parent 18ae3c30ee
commit 2a7ede0232
7 changed files with 687 additions and 3 deletions

View File

@@ -0,0 +1,186 @@
//! Wire protocol between the neuron leader process and its
//! `--worker` subprocesses.
//!
//! Every frame is one newline-delimited JSON object on the worker's
//! stdin (request) or stdout (response). Both directions are tagged
//! sum types from the start so new ops in Stage 7b/7c slot in without
//! breaking compatibility — no "14 message types and a version field"
//! drift later. Adding a new variant is the canonical way to evolve
//! the protocol; existing peers that don't recognise an op return
//! `WorkerResponse::Error { kind: "unknown_op", .. }`.
//!
//! The serialised shape uses `tag = "op"` so a request looks like:
//! {"op":"ping"}
//! {"op":"init","comm_id":"a1b2..."}
//! and a response:
//! {"op":"pong","rank":0,"world_size":2,"cuda_device":0}
//! {"op":"error","kind":"nccl_init_failed","message":"..."}
use serde::{Deserialize, Serialize};
/// Leader → worker. Worker handles one at a time; replies with exactly
/// one `WorkerResponse` per request.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum WorkerRequest {
/// Liveness probe. Worker replies with `Pong` containing its own
/// identity. Used by the leader to confirm the subprocess is up
/// and ready before kicking off any heavier work.
Ping,
/// One-shot NCCL communicator setup. The leader generates the
/// `comm_id` once (rank 0 of NCCL), broadcasts it to every worker
/// via this message, then every rank (leader included) calls
/// `Comm::from_rank` with the same id — NCCL blocks until all
/// `world_size` ranks check in. The hex-encoded bytes are the
/// canonical `cudarc::nccl::Id::internal()` content.
Init {
/// Hex-encoded NCCL id bytes (128 bytes → 256 hex chars).
comm_id: String,
},
/// Sanity check: after Init, every rank runs an `all_reduce` over
/// a sentinel value (`1u32`). The expected sum is `world_size`.
/// Worker replies with the observed value so the leader can verify
/// the NCCL handshake is genuinely live, not just configured.
NcclSanityCheck,
/// Worker should release resources and exit. Worker replies `Bye`
/// and then closes stdout / exits zero. The leader reaps the
/// child via the `tokio::process::Child` it kept.
Shutdown,
}
/// Worker → leader. Always exactly one of these per `WorkerRequest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum WorkerResponse {
/// Reply to `Ping`. Carries enough identity for the leader to log
/// what it actually got back.
Pong {
rank: u32,
world_size: u32,
cuda_device: u32,
},
/// Reply to `Init`. Empty payload — success is the absence of
/// `Error`. NCCL's internal blocking handshake means by the time
/// this comes back, every other rank has also reached
/// `Comm::from_rank`.
InitOk,
/// Reply to `NcclSanityCheck`. The observed sum after a single
/// `all_reduce(SUM, 1u32)` across all ranks. The leader checks
/// this matches `world_size`.
NcclSanityResult { observed_sum: u32 },
/// Reply to `Shutdown`. Worker exits immediately after writing this.
Bye,
/// Any request can produce this instead of its dedicated success
/// variant. `kind` is a machine-readable category so the leader
/// can branch on failure mode without string-matching `message`.
Error {
/// Short tag — `nccl_init_failed`, `unknown_op`, etc.
kind: String,
/// Human-readable detail for logs.
message: String,
},
}
#[cfg(test)]
mod tests {
use super::*;
fn roundtrip<T>(value: &T) -> T
where
T: Serialize + for<'de> Deserialize<'de>,
{
serde_json::from_str(&serde_json::to_string(value).expect("serialise"))
.expect("deserialise")
}
#[test]
fn request_ping_round_trip() {
let req = WorkerRequest::Ping;
let wire = serde_json::to_string(&req).unwrap();
assert_eq!(wire, r#"{"op":"ping"}"#);
match roundtrip(&req) {
WorkerRequest::Ping => {}
other => panic!("expected Ping, got {other:?}"),
}
}
#[test]
fn request_init_carries_hex_id() {
let req = WorkerRequest::Init {
comm_id: "deadbeef".into(),
};
let wire = serde_json::to_string(&req).unwrap();
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
}
#[test]
fn request_shutdown_round_trip() {
assert_eq!(
serde_json::to_string(&WorkerRequest::Shutdown).unwrap(),
r#"{"op":"shutdown"}"#
);
}
#[test]
fn response_pong_round_trip() {
let resp = WorkerResponse::Pong {
rank: 1,
world_size: 4,
cuda_device: 1,
};
let wire = serde_json::to_string(&resp).unwrap();
assert!(wire.contains(r#""op":"pong""#));
assert!(wire.contains(r#""rank":1"#));
assert!(wire.contains(r#""world_size":4"#));
match roundtrip(&resp) {
WorkerResponse::Pong {
rank,
world_size,
cuda_device,
} => {
assert_eq!(rank, 1);
assert_eq!(world_size, 4);
assert_eq!(cuda_device, 1);
}
other => panic!("expected Pong, got {other:?}"),
}
}
#[test]
fn response_error_carries_kind_and_message() {
let resp = WorkerResponse::Error {
kind: "nccl_init_failed".into(),
message: "could not bind device".into(),
};
let wire = serde_json::to_string(&resp).unwrap();
assert!(wire.contains(r#""op":"error""#));
assert!(wire.contains(r#""kind":"nccl_init_failed""#));
}
#[test]
fn response_sanity_result_round_trip() {
let resp = WorkerResponse::NcclSanityResult { observed_sum: 4 };
match roundtrip(&resp) {
WorkerResponse::NcclSanityResult { observed_sum } => {
assert_eq!(observed_sum, 4);
}
other => panic!("expected NcclSanityResult, got {other:?}"),
}
}
/// Unknown ops on the wire deserialise to an error rather than
/// silently mis-matching — confirms our `serde(tag = "op")`
/// configuration rejects unknowns instead of doing fuzzy matching.
#[test]
fn unknown_op_fails_to_parse() {
let result: Result<WorkerRequest, _> = serde_json::from_str(r#"{"op":"explode"}"#);
assert!(result.is_err(), "should reject unknown op, got {result:?}");
}
}