helexa/crates/neuron/tests/tp_worker_lifecycle_cuda.rs
rob thijssen 76ab24d98c
refactor(neuron): phase 3 — TP forward + NCCL state move onto device worker
Third slice of the per-device CUDA context-ownership refactor planned at
~/.claude/plans/plan-the-per-device-worker-abstract-micali.md. The
leader's `NcclState`, every `Comm::all_reduce` issued by the TP layers,
the leader-side KV cache reset, and the TP forward step itself now all
run on the per-device worker thread — the same OS thread that bound
the leader's `CudaContext` at startup.

What this phase changes:

- `Job` gains `NcclInit`, `NcclSanity`, `CloneLeaderComm` (a Phase 3
  bridge, removed in Phase 4), `TransferInTp`, `DropTp`, `TpClearKv`
  and `TpForwardLogits`, plus a new opaque key type, `TpHandle(u64)`.
- `DeviceWorkerState` gains `nccl: NcclState` and
  `tp_models: HashMap<TpHandle, Box<TpLeaderModel>>` (+ counter).
- `WorkerPool` loses its `leader_nccl` field; gains a
  `leader_worker: Arc<DeviceWorkerHandle>` passed at construction.
  `init_nccl`, `nccl_sanity_check`, `load_dense_shard`,
  `generate_step`, `clear_kv_cache` all route their leader-side ops
  through `Job::Nccl*` / `Job::Tp*` instead of spawn_blocking against
  a Mutex-wrapped state. `generate_step` returns `Vec<f32>` instead
  of a device-resident `Tensor`: the worker copies the logits to the
  CPU before replying, so the async caller can sample on a CPU candle
  tensor without touching any device context.
- `TpLoadedModel.leader_model: Arc<Mutex<TpLeaderModel>>` → opaque
  `leader_handle: TpHandle`. The boxed `TpLeaderModel` lives in the
  worker thread's slab; both the model's CUDA tensors and the
  embedded `Arc<Comm>` clones are released on the same thread that
  allocated them, as cudarc's Drop semantics require.
- `Job::CloneLeaderComm` is a Phase 3 bridge: the TP shard load still
  runs in spawn_blocking and needs the leader's `Arc<Comm>` to build
  the row-parallel layers' AllReduce ops. The Job clones the Comm
  out of the worker's NcclState and ships it back as `SendComm`.
  Phase 4 deletes this bridge when the load itself moves onto the
  worker.
- `Job::NcclInit` and `Job::NcclSanity` are not gated on the `cuda`
  feature, so the no-cuda `NcclState` stubs (which reply with
  `cuda_feature_not_enabled`) still flow uniformly through the same
  channel; the cuda-only TP variants (`CloneLeaderComm` and
  `TransferInTp`/`DropTp`/`TpClearKv`/`TpForwardLogits`) remain gated.
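The ownership rule described in this list (every device object is created, used and dropped on one worker thread, while callers only hold opaque handles) can be sketched with nothing but `std`. This is an illustration under stated assumptions, not the crate's code: `FakeTpModel`, the `Job` payload fields, `DeviceWorkerState::run_job`, `spawn_worker` and `argmax` are hypothetical, `std::sync::mpsc` stands in for whatever channel `DeviceWorkerHandle` really uses, and no CUDA, NCCL or candle types are involved.

```rust
// Hypothetical sketch of the worker-owned slab; not the neuron crate's code.
use std::collections::HashMap;
use std::sync::mpsc::{channel, Sender};
use std::thread::{self, JoinHandle, ThreadId};

/// Opaque key into the worker's model slab (modelled on `TpHandle(u64)`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TpHandle(u64);

/// Hypothetical stand-in for `TpLeaderModel`: remembers the thread that
/// created it and reports the thread that drops it.
pub struct FakeTpModel {
    created_on: ThreadId,
    drop_report: Sender<ThreadId>,
}

impl Drop for FakeTpModel {
    fn drop(&mut self) {
        // A real model would release CUDA tensors and `Arc<Comm>` clones here.
        let _ = self.drop_report.send(thread::current().id());
    }
}

/// Illustrative subset of the job kinds; the payload shapes are assumptions.
pub enum Job {
    TransferInTp { drop_report: Sender<ThreadId>, reply: Sender<TpHandle> },
    TpForwardLogits { handle: TpHandle, reply: Sender<Result<Vec<f32>, String>> },
    DropTp { handle: TpHandle, reply: Sender<bool> },
}

/// Worker-owned state: the model slab plus its handle counter.
struct DeviceWorkerState {
    tp_models: HashMap<TpHandle, Box<FakeTpModel>>,
    next_tp_handle: u64,
}

impl DeviceWorkerState {
    fn run_job(&mut self, job: Job) {
        match job {
            Job::TransferInTp { drop_report, reply } => {
                let handle = TpHandle(self.next_tp_handle);
                self.next_tp_handle += 1;
                let model = FakeTpModel { created_on: thread::current().id(), drop_report };
                self.tp_models.insert(handle, Box::new(model));
                let _ = reply.send(handle);
            }
            Job::TpForwardLogits { handle, reply } => {
                let result = match self.tp_models.get(&handle) {
                    Some(model) => {
                        assert_eq!(model.created_on, thread::current().id());
                        // Placeholder logits: the real worker runs the TP forward
                        // and copies the logits to the CPU before replying.
                        Ok(vec![0.1, 2.5, -1.0, 0.7])
                    }
                    None => Err(format!("unknown handle {handle:?}")),
                };
                let _ = reply.send(result);
            }
            Job::DropTp { handle, reply } => {
                let removed = self.tp_models.remove(&handle);
                let existed = removed.is_some();
                drop(removed); // `Drop` runs here, on the worker thread.
                let _ = reply.send(existed);
            }
        }
    }
}

/// Spawns the single OS thread that owns all device state.
pub fn spawn_worker() -> (Sender<Job>, JoinHandle<()>) {
    let (tx, rx) = channel::<Job>();
    let join = thread::spawn(move || {
        let mut state = DeviceWorkerState { tp_models: HashMap::new(), next_tp_handle: 0 };
        // The loop ends once every Sender is dropped; leftover models drop here too.
        for job in rx {
            state.run_job(job);
        }
    });
    (tx, join)
}

/// Greedy sampling on the CPU copy of the logits; no device context needed.
pub fn argmax(logits: &[f32]) -> Option<usize> {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
}
```

Because the only way to reach a model is to send a job naming its `TpHandle`, no async task ever holds a reference to device memory, and everything in the slab, including whatever is still there when the channel closes, is dropped on the worker thread.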

What this phase doesn't touch (yet):

- TP shard load itself — still spawn_blocking, bridged via
  `CloneLeaderComm`. Phase 4 moves it to `Job::TpLoadShard` and
  reads `state.nccl.comm()` directly inside the worker.
- Single-GPU model loads — still spawn_blocking, transferred via
  `Job::TransferIn`. Phase 4 moves them.
- `device_vram_mb` / `cuda_mem_mb` / `log_construction_complete`
  helpers — still present, used inside spawn_blocking load closures.
  Phase 4 cleanup folds them into `dispatch.rs`.
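The `CloneLeaderComm` bridge amounts to: the worker keeps the canonical `Arc` inside its NCCL state and hands a refcounted clone to a load that still runs off-thread. The sketch below assumes a hypothetical `FakeComm` in place of the real communicator, uses `std::thread::spawn` in place of `tokio::task::spawn_blocking`, and ships a plain `Arc` rather than the crate's `SendComm` wrapper, whose definition this commit does not show. Only the `state.nccl.comm()` accessor name comes from the message above.

```rust
// Hypothetical sketch of the Phase 3 bridge; not the neuron crate's code.
use std::sync::mpsc::{channel, Sender};
use std::sync::Arc;
use std::thread::{self, JoinHandle};

/// Hypothetical stand-in for the leader's NCCL communicator.
pub struct FakeComm {
    pub rank: usize,
}

/// Worker-owned NCCL state; `comm()` mirrors the accessor named for Phase 4.
pub struct NcclState {
    comm: Option<Arc<FakeComm>>,
}

impl NcclState {
    pub fn comm(&self) -> Option<&Arc<FakeComm>> {
        self.comm.as_ref()
    }
}

pub enum Job {
    NcclInit { rank: usize, reply: Sender<Result<(), String>> },
    CloneLeaderComm { reply: Sender<Result<Arc<FakeComm>, String>> },
}

pub fn spawn_worker() -> (Sender<Job>, JoinHandle<()>) {
    let (tx, rx) = channel::<Job>();
    let join = thread::spawn(move || {
        let mut nccl = NcclState { comm: None };
        for job in rx {
            match job {
                Job::NcclInit { rank, reply } => {
                    nccl.comm = Some(Arc::new(FakeComm { rank }));
                    let _ = reply.send(Ok(()));
                }
                Job::CloneLeaderComm { reply } => {
                    // The bridge: bump the refcount and ship the clone out.
                    let cloned = nccl
                        .comm()
                        .cloned()
                        .ok_or_else(|| "nccl not initialised".to_string());
                    let _ = reply.send(cloned);
                }
            }
        }
        // `nccl`, and the worker's own Arc, is dropped here on the worker thread.
    });
    (tx, join)
}

/// Hypothetical shard whose row-parallel layers keep a comm clone.
pub struct FakeShard {
    pub comm: Arc<FakeComm>,
}

/// Stand-in for the off-thread shard load.
pub fn load_shard_off_thread(comm: Arc<FakeComm>) -> JoinHandle<FakeShard> {
    thread::spawn(move || FakeShard { comm })
}
```

While the bridge exists, comm clones leave the worker thread inside the loaded shard; per the commit they are only released on the allocating thread once the loaded model is handed back into the worker's slab, which is why Phase 4 moves the load itself onto the worker.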

`tp/mod.rs::WorkerPool::spawn` gained a required
`leader_worker: Arc<DeviceWorkerHandle>` argument. Three external
callers were updated: `CandleHarness::load_tp` (passes the cached
device worker), `main.rs::tp_smoke` (spawns a fresh worker), and
the two `tp_worker_lifecycle*.rs` integration tests.
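One plausible reading of the new required argument is plain constructor injection: the caller owns the leader's device worker and the pool shares it through an `Arc` instead of creating leader-side state of its own. The types below are simplified, hypothetical stand-ins (synchronous, infallible `spawn`, no binary path or device list) rather than the crate's real signatures.

```rust
// Hypothetical sketch; the real `WorkerPool::spawn` is async and fallible.
use std::sync::Arc;

/// Hypothetical stand-in for `DeviceWorkerHandle` (one per CUDA device).
pub struct DeviceWorkerHandle {
    pub device: usize,
}

pub struct WorkerPool {
    pub world_size: usize,
    pub leader_worker: Arc<DeviceWorkerHandle>,
}

impl WorkerPool {
    pub fn spawn(world_size: usize, leader_worker: Arc<DeviceWorkerHandle>) -> Self {
        // The pool reuses the worker it is given, so leader-side work stays
        // on the thread that bound the leader's CUDA context.
        WorkerPool { world_size, leader_worker }
    }
}

/// Hypothetical harness that caches one worker per device.
pub struct CandleHarness {
    pub device_worker: Arc<DeviceWorkerHandle>,
}

impl CandleHarness {
    pub fn load_tp(&self, world_size: usize) -> WorkerPool {
        WorkerPool::spawn(world_size, Arc::clone(&self.device_worker))
    }
}
```

Making the argument required, rather than letting the pool spawn a worker when none is passed, rules out accidentally running a second thread against the leader's device.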

Public API unchanged. fmt + clippy clean; 37 lib tests + all
integration tests pass. CUDA-only TP integration smoke deferred to
the next deploy on beast.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-27 10:16:02 +03:00

//! Stage 7a-ii: real NCCL handshake across the worker pool.
//!
//! Gated behind the `cuda-integration` feature because it requires
//! libnccl AND multiple CUDA devices on the running host. Run on
//! beast (2× RTX 5090) via:
//!
//!     cargo test -p neuron --features cuda-integration \
//!         --test tp_worker_lifecycle_cuda
//!
//! Steps: spawn the leader's device worker (rank 0) and N-1 workers,
//! call `init_nccl`, run `nccl_sanity_check` (every rank `all_reduce`s
//! `1u32` with Sum; expected total = world_size), and shut down cleanly.
#![cfg(feature = "cuda-integration")]

use neuron::harness::device_worker::DeviceWorkerHandle;
use neuron::harness::tp::WorkerPool;

const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");

#[tokio::test]
async fn test_init_and_sanity_check_two_ranks() {
    let _ = tracing_subscriber::fmt()
        .with_test_writer()
        .with_env_filter("info,neuron=debug")
        .try_init();

    // 2 ranks: leader = rank 0 on device 0, worker = rank 1 on device 1.
    let leader_worker = DeviceWorkerHandle::spawn(0).expect("spawn leader device worker");
    let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1], leader_worker)
        .await
        .expect("spawn worker pool");

    pool.ping_all().await.expect("pong all workers");
    pool.init_nccl(0)
        .await
        .expect("init_nccl: NCCL handshake across all ranks");
    pool.nccl_sanity_check()
        .await
        .expect("nccl_sanity_check: observed_sum == world_size on all ranks");
    pool.shutdown().await.expect("clean shutdown");
}
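The arithmetic behind `nccl_sanity_check` can be modelled on the CPU: each rank contributes `1u32`, a Sum all-reduce leaves every rank holding the same total, so a healthy world of N ranks should observe N on every rank. In this sketch a mutex and a barrier stand in for the collective; it exercises no NCCL code, and `simulated_sanity_check` is a hypothetical helper, not part of the test above.

```rust
// Hypothetical CPU model of one Sum all-reduce round; no NCCL involved.
use std::sync::{Arc, Barrier, Mutex};
use std::thread;

/// Returns the value each rank observes after the simulated all-reduce.
fn simulated_sanity_check(world_size: usize) -> Vec<u32> {
    let total = Arc::new(Mutex::new(0u32));
    let barrier = Arc::new(Barrier::new(world_size));
    let ranks: Vec<_> = (0..world_size)
        .map(|_rank| {
            let total = Arc::clone(&total);
            let barrier = Arc::clone(&barrier);
            thread::spawn(move || {
                *total.lock().unwrap() += 1; // every rank contributes 1u32
                barrier.wait(); // all contributions are in before anyone reads
                let observed = *total.lock().unwrap();
                observed
            })
        })
        .collect();
    ranks.into_iter().map(|r| r.join().unwrap()).collect()
}
```

A rank that reported anything other than `world_size` would point at a rank missing from the communicator or at a mismatched reduction op, which is what the real check is there to catch before any TP forward runs.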